import React from 'react';
import { createStyles, makeStyles, Theme } from '@material-ui/core/styles';
import List from '@material-ui/core/List';
import ListItem from '@material-ui/core/ListItem';
import ListItemIcon from '@material-ui/core/ListItemIcon';
import TrackChangesIcon from '@material-ui/icons/TrackChanges';
import ListItemText from '@material-ui/core/ListItemText';
import TrendingDownIcon from '@material-ui/icons/TrendingDown';
import SettingsApplicationsIcon from '@material-ui/icons/SettingsApplications';
import LoopIcon from '@material-ui/icons/Loop';
import EmojiObjectsIcon from '@material-ui/icons/EmojiObjects';

/**
 * Props for `ModelStats`. Every statistic is optional: a value that is
 * `null` or `undefined` is rendered as a "-" placeholder.
 */
export type ModelStatsProps = {
  /** Displayed as "Epochs", unmodified. */
  epochs?: number | null;
  /** Displayed as "Batch Size", unmodified. */
  batchSize?: number | null;
  /** Displayed as "Learning Rate", unmodified. */
  learningRate?: number | null;
  /** Displayed as "Accuracy", rounded to 4 decimal places. */
  testAccuracy?: number | null;
  /** Displayed as "Loss", rounded to 4 decimal places. */
  testLoss?: number | null;
};

/**
 * JSS styles for `ModelStats`.
 * - `lists`: flex container that places the stat lists side by side.
 * - `list`: each list, with a right margin of `theme.spacing(4)`.
 */
const useStyles = makeStyles((theme: Theme) => {
  // Horizontal gap after each list in the flex row.
  const listGap = theme.spacing(4);

  return createStyles({
    lists: {
      display: 'flex',
    },
    list: {
      marginRight: listGap,
    },
  });
});

/** Placeholder shown for any statistic that was not provided (`null`/`undefined`). */
const MISSING_VALUE = '-';

/** Decimal places used when displaying test-set metrics (accuracy and loss). */
const METRIC_DECIMAL_PLACES = 4;

/** One labelled row in a stats list. */
type StatItem = {
  name: string;
  value: number | string;
  icon: JSX.Element;
};

/**
 * Rounds `value` to `decimalPlaces` digits after the decimal point.
 *
 * Lives at module scope because it depends on nothing from props or state,
 * so there is no reason to recreate it on every render. It scales, rounds and
 * unscales, so it inherits ordinary binary floating-point error
 * (e.g. `roundToNDecimals(1.005, 2)` yields `1`); that is acceptable for display.
 */
const roundToNDecimals = (value: number, decimalPlaces = 1): number => {
  const decimalPower = 10 ** decimalPlaces;
  return Math.round(value * decimalPower) / decimalPower;
};

/** Rounds a test-set metric for display, or returns the placeholder when it is absent. */
const formatMetric = (value: number | null | undefined): number | string =>
  value == null ? MISSING_VALUE : roundToNDecimals(value, METRIC_DECIMAL_PLACES);

/**
 * Renders a model's training hyperparameters and test-set metrics as two
 * side-by-side dense lists, one row per statistic.
 *
 * Hyperparameters (epochs, batch size, learning rate) are shown as given.
 * Metrics (accuracy, loss) are rounded to `METRIC_DECIMAL_PLACES` decimals.
 * Any statistic that is `null` or `undefined` is shown as `MISSING_VALUE`.
 */
const ModelStats = ({
  epochs,
  batchSize,
  testLoss,
  testAccuracy,
  learningRate,
}: ModelStatsProps): JSX.Element => {
  const classes = useStyles();

  const groups: StatItem[][] = [
    // Training hyperparameters: displayed verbatim.
    [
      { name: 'Epochs', value: epochs ?? MISSING_VALUE, icon: <LoopIcon /> },
      { name: 'Batch Size', value: batchSize ?? MISSING_VALUE, icon: <SettingsApplicationsIcon /> },
      { name: 'Learning Rate', value: learningRate ?? MISSING_VALUE, icon: <EmojiObjectsIcon /> },
    ],
    // Test-set metrics: rounded for readability.
    [
      { name: 'Accuracy', value: formatMetric(testAccuracy), icon: <TrackChangesIcon /> },
      { name: 'Loss', value: formatMetric(testLoss), icon: <TrendingDownIcon /> },
    ],
  ];

  // React needs a stable `key` on every mapped child. The groups array is
  // static and never reordered, so its index is a stable key. Item names are
  // unique within a group, so they key the rows.
  return (
    <div className={classes.lists}>
      {groups.map((group, groupIndex) => (
        <List dense className={classes.list} key={groupIndex}>
          {group.map(({ icon, name, value }) => (
            <ListItem key={name}>
              <ListItemIcon>{icon}</ListItemIcon>
              <ListItemText primary={`${name}: ${value}`} />
            </ListItem>
          ))}
        </List>
      ))}
    </div>
  );
};

export default ModelStats;
