import React, { useEffect } from 'react';
import * as yup from 'yup';
import { yupResolver } from '@hookform/resolvers/yup';
import { SubmitHandler, useForm } from 'react-hook-form';
import semverValid from 'semver/functions/valid';
import semverInc from 'semver/functions/inc';
import { createStyles, makeStyles, Theme } from '@material-ui/core/styles';
import Typography from '@material-ui/core/Typography';
import Button from '@material-ui/core/Button';
import Alert from '@material-ui/lab/Alert';
import TextField, { TextFieldProps } from '@material-ui/core/TextField';
import TrainIcon from '@material-ui/icons/Train';

import { Label } from '../../hooks/useLabels';
import defaultLabelData from '../../config/defaultLabelData';
import { Model as ModelType } from '../../API';

/**
 * Values produced by the train-model form after validation.
 *
 * The field names match the names the inputs are registered under in the
 * `TrainModel` component.
 */
export type DataForm = {
  /** Value of the "Epochs" numeric input. */
  epochs: number;
  /** Value of the "Batch Size" numeric input. */
  batchSize: number;
  /** Value of the "Learning Rate" numeric input. */
  learningRate: number;
  /** Value of the "Version" input; the form validates it as a semver string. */
  version: string;
  /** Value of the free-text "Description" input (not required by the form schema). */
  description: string;
};

/**
 * Payload handed to the `onSubmit` callback of `TrainModel`.
 *
 * Currently has the same fields as the form values; the component passes a
 * shallow copy of the validated form data.
 */
export type TrainData = {
  /** Value entered in the "Epochs" field. */
  epochs: number;
  /** Value entered in the "Batch Size" field. */
  batchSize: number;
  /** Value entered in the "Learning Rate" field. */
  learningRate: number;
  /** Semver-formatted version string entered in the "Version" field. */
  version: string;
  /** Text entered in the "Description" field. */
  description: string;
};

/** Props accepted by the `TrainModel` component. */
export type TrainModelProps = {
  /** Called with the validated form values when the form is submitted. */
  onSubmit: SubmitHandler<TrainData>;
  /** While true, the "Train" button is disabled. */
  isLoading: boolean;
  /**
   * Existing models. When the first entry has a version, the form pre-fills the
   * version field with its next patch version.
   */
  models?: ModelType[];
  /**
   * Labels whose `frames` are checked to decide whether training can start and
   * whose `name` is compared against the default label.
   */
  labels: Label[];
};

/**
 * Style hook for the train-model form.
 *
 * `alert` gives the warning alert the same vertical margin (one theme spacing
 * step) above and below.
 */
const useStyles = makeStyles((theme: Theme) => {
  const verticalGap = theme.spacing(1);

  return createStyles({
    alert: {
      marginTop: verticalGap,
      marginBottom: verticalGap,
    },
  });
});

/**
 * Validation schema for the train-model form.
 *
 * Declared at module scope so it is built once instead of on every render.
 * The `${path}` placeholder in the semver message is interpolated by yup, which
 * is why it is a plain string rather than a template literal.
 */
const trainModelSchema = yup.object().shape({
  epochs: yup.number().required(),
  batchSize: yup.number().required(),
  learningRate: yup.number().required(),
  version: yup
    .string()
    .required()
    .trim()
    .test(
      'is-semver',
      // eslint-disable-next-line
      '${path} not matched to https://semver.org format.',
      (value) => !!semverValid(value)
    ),
  description: yup.string(),
});

/**
 * Form for configuring and starting a model training run.
 *
 * Renders inputs for epochs, batch size, learning rate, version and
 * description, validates them on blur, and calls `onSubmit` with the validated
 * values. The "Train" button stays disabled while the form is invalid, while
 * `isLoading` is true, when there are no labels, or when no label has any
 * recorded frame. A warning is shown when the default label is missing.
 *
 * @param props.labels Labels used for the frame and default-label checks.
 * @param props.onSubmit Receives a copy of the validated form values.
 * @param props.isLoading Disables the submit button while true.
 * @param props.models Optional models; the first one's version seeds the form.
 */
const TrainModel = ({ labels, onSubmit, isLoading, models }: TrainModelProps) => {
  const classes = useStyles();

  // Typing the form with DataForm keeps register/setValue/handleSubmit checked
  // against the real field names, so no `any` cast is needed on submit.
  const {
    register,
    handleSubmit: handleUseFormSubmit,
    formState: { errors, isValid },
    setValue,
  } = useForm<DataForm>({
    resolver: yupResolver(trainModelSchema),
    mode: 'onBlur',
  });

  /**
   * Builds the shared TextField props for a registered form field.
   *
   * @param name Form field to register.
   * @param extraInputProps Additional native input attributes; merged on top of
   *   the default `min: 0` so passing extras does not drop the default.
   */
  const getRegisterProps = (
    name: keyof DataForm,
    extraInputProps: TextFieldProps['inputProps'] = {}
  ) => {
    // Material-UI expects the ref under `inputRef`, not `ref`.
    const { ref: inputRef, ...registerProps } = register(name);
    return {
      inputRef,
      ...registerProps,
      id: name,
      error: !!errors[name],
      helperText: errors[name]?.message,
      variant: 'standard',
      fullWidth: true,
      margin: 'normal',
      inputProps: {
        min: 0,
        ...extraInputProps,
      },
    } as TextFieldProps;
  };

  // Hand the caller a shallow copy of the validated values.
  const handleSubmit: SubmitHandler<DataForm> = (dataForm, evt) => {
    onSubmit({ ...dataForm }, evt);
  };

  const isAnyFrameRecorded = labels.some((label) => label.frames.length > 0);

  const isDefaultCategoryAvailable = labels.some(
    (label) => label.name === defaultLabelData?.label
  );

  const version = models?.[0]?.version;

  useEffect(() => {
    if (version) {
      // semverInc returns null when `version` is not valid semver; in that case
      // keep the current field value instead of writing null into the form.
      const nextVersion = semverInc(version, 'patch');
      if (nextVersion) {
        setValue('version', nextVersion);
      }
      setValue('description', '');
    }
    // Only re-run when the latest model version changes.
    // eslint-disable-next-line
  }, [version]);

  return (
    <section>
      <Typography component="h3" variant="h5" gutterBottom>
        Train Model
      </Typography>
      <form id="add-project-form" onSubmit={handleUseFormSubmit(handleSubmit)} noValidate>
        <TextField
          {...getRegisterProps('epochs')}
          type="number"
          defaultValue={50}
          label="Epochs"
          required
        />
        <TextField
          {...getRegisterProps('batchSize')}
          type="number"
          defaultValue={32}
          label="Batch Size"
          required
        />
        <TextField
          {...getRegisterProps('learningRate', { step: '0.01' })}
          type="number"
          defaultValue={0.01}
          label="Learning Rate"
          required
        />
        <TextField {...getRegisterProps('version')} defaultValue="1.0.0" label="Version" required />
        <TextField
          {...getRegisterProps('description')}
          label="Description"
          minRows={1}
          maxRows={4}
          multiline
        />
        {!isDefaultCategoryAvailable && (
          <Alert className={classes.alert} severity="warning">
            {defaultLabelData?.warning}
          </Alert>
        )}
        <Button
          disabled={!isValid || !labels.length || isLoading || !isAnyFrameRecorded}
          type="submit"
          startIcon={<TrainIcon />}
          variant="contained"
          color="primary"
        >
          Train
        </Button>
      </form>
    </section>
  );
};

export default TrainModel;
