import {
    ActionNoiseType,
    AlgorithmDTO,
    AlgorithmParameter,
    AlgorithmType,
    FrequencyUnit,
    ReplayBufferClass,
    Scope,
} from '../models/algorithm-form';
import { FormControl, FormGroup, Validators } from '@angular/forms';
import { Injectable } from '@angular/core';
import { AlgorithmValidator } from '../validators/algorithm-validator';
import { nonNil } from '../../shared/utility';

/** Signature shared by every synchronous validator attached to a control. */
type ControlValidator = typeof Validators.required;

/**
 * Builds the reactive form matching an algorithm DTO.
 *
 * The DTO's flat parameter list is first grouped by scope; each scope then
 * becomes a nested `FormGroup` whose controls are pre-filled with the stored
 * parameter values. Numeric parameters are edited as text, so their controls
 * hold strings and rely on `AlgorithmValidator` for format checks.
 */
@Injectable()
export class AlgorithmFormService {
    /**
     * Creates the form for the given algorithm.
     *
     * @param dto algorithm description; `dto.parameters` may be nil, in which
     *     case every control starts out empty.
     * @returns a form group containing one nested group per scope.
     * @throws Error when `dto.type` is not handled by this service.
     */
    getForm(dto: AlgorithmDTO): FormGroup {
        const scopeMapping = new Map<string, Map<string, any>>();
        (dto.parameters ?? []).forEach(
            (algorithmParameter: AlgorithmParameter) => {
                let scopeParameters = scopeMapping.get(
                    algorithmParameter.scope
                );
                if (!scopeParameters) {
                    scopeParameters = new Map<string, any>();
                    scopeMapping.set(
                        algorithmParameter.scope,
                        scopeParameters
                    );
                }
                scopeParameters.set(
                    algorithmParameter.name,
                    algorithmParameter.value
                );
            }
        );

        switch (dto.type) {
            case AlgorithmType.PPO:
                return this.getPPOForm(scopeMapping, false);
            case AlgorithmType.ADAM:
                return this.getADAMForm(scopeMapping);
            case AlgorithmType.SAC:
                return this.getSACForm(scopeMapping, false);
            case AlgorithmType.BC:
                return this.getBCForm(scopeMapping);
            case AlgorithmType.GAIL:
            case AlgorithmType.AIRL:
                return this.getAdversarialTrainerForm(scopeMapping);
            default:
                // Fail loudly instead of implicitly returning undefined from
                // a method whose signature promises a FormGroup.
                throw new Error(
                    `Unsupported algorithm type: ${String(dto.type)}`
                );
        }
    }

    /** Builds the environment scope group (`maxEpisodeLength`). */
    private getEnvironmentFormPrototype(
        environmentParameters: Map<string, any>
    ): FormGroup {
        return new FormGroup({
            maxEpisodeLength: this.requiredTextControl(
                'maxEpisodeLength',
                environmentParameters,
                AlgorithmValidator.validatePositiveInteger
            ),
        });
    }

    /**
     * Builds the PPO training-setup group.
     *
     * @param isLearner when true the `totalTimesteps` control is omitted.
     */
    private getPPOTrainingSetupFormPrototype(
        trainingSetupParameters: Map<string, any>,
        isLearner: boolean
    ): FormGroup {
        const controls: { [name: string]: FormControl } = {
            nEpochs: this.requiredTextControl(
                'nEpochs',
                trainingSetupParameters,
                AlgorithmValidator.validatePositiveInteger
            ),
            nSteps: this.requiredTextControl(
                'nSteps',
                trainingSetupParameters,
                AlgorithmValidator.validatePositiveIntegerGreaterOrEqualThanCustom(
                    2
                )
            ),
        };
        if (!isLearner) {
            controls['totalTimesteps'] = this.requiredTextControl(
                'totalTimesteps',
                trainingSetupParameters,
                AlgorithmValidator.validatePositiveIntegerGreaterOrEqualThanCustom(
                    2
                )
            );
        }
        return new FormGroup(controls);
    }

    /** Builds the PPO algorithm-specific group. */
    private getPPOSpecificFormPrototype(
        specificParameters: Map<string, any>
    ): FormGroup {
        const unitInterval =
            AlgorithmValidator.validatePositiveFloatUpTo1Closed;
        return new FormGroup({
            policy: this.requiredTextControl('policy', specificParameters),
            gamma: this.requiredTextControl(
                'gamma',
                specificParameters,
                unitInterval
            ),
            normalizeAdvantage: this.valueControl<boolean>(
                'normalizeAdvantage',
                specificParameters
            ),
            entCoef: this.requiredTextControl(
                'entCoef',
                specificParameters,
                unitInterval
            ),
            vfCoef: this.requiredTextControl(
                'vfCoef',
                specificParameters,
                unitInterval
            ),
            gaeLambda: this.requiredTextControl(
                'gaeLambda',
                specificParameters,
                unitInterval
            ),
            // Optional: validated only when a value is entered.
            clipRangeVf: this.textControl(
                'clipRangeVf',
                specificParameters,
                unitInterval
            ),
            clipRange: this.requiredTextControl(
                'clipRange',
                specificParameters,
                unitInterval
            ),
            maxGradNorm: this.requiredTextControl(
                'maxGradNorm',
                specificParameters,
                unitInterval
            ),
            targetKl: this.textControl(
                'targetKl',
                specificParameters,
                AlgorithmValidator.validateNonNegativeFloat
            ),
            useSde: this.valueControl<boolean>('useSde', specificParameters),
            sdeSampleFreq: this.requiredTextControl(
                'sdeSampleFreq',
                specificParameters,
                AlgorithmValidator.validatePositiveIntegerOrMinusOne
            ),
        });
    }

    /** Builds the general scope group (`seed`, `batchSize`). */
    getGeneralFormPrototype(generalParameters: Map<string, any>): FormGroup {
        return new FormGroup({
            seed: this.textControl(
                'seed',
                generalParameters,
                AlgorithmValidator.validateNonNegativeInteger
            ),
            batchSize: this.requiredTextControl(
                'batchSize',
                generalParameters,
                AlgorithmValidator.validatePowerOfTwo,
                AlgorithmValidator.validateInteger
            ),
        });
    }

    /** Builds the optimizer controls shared by all algorithms. */
    private getBaseOptimizerFormPrototype(
        optimizerParameters: Map<string, any>
    ): FormGroup {
        return new FormGroup({
            learningRate: this.requiredTextControl(
                'learningRate',
                optimizerParameters,
                AlgorithmValidator.validatePositiveFloatUpTo1Open
            ),
            type: this.textControl('type', optimizerParameters),
        });
    }

    /**
     * Builds the complete PPO form.
     *
     * @param isLearner true when PPO is embedded as the learner of an
     *     adversarial trainer; the learner scopes are used in that case.
     */
    private getPPOForm(
        parameter: Map<string, Map<string, any>>,
        isLearner: boolean
    ): FormGroup {
        const scopes = this.getScopes(isLearner);
        return this.composeAgentForm(
            parameter,
            isLearner,
            this.getPPOTrainingSetupFormPrototype(
                parameter.get(scopes.trainingSetupScope),
                isLearner
            ),
            this.getPPOSpecificFormPrototype(
                parameter.get(scopes.specificScope)
            )
        );
    }

    /** Extends the base optimizer group with the ADAM-only controls. */
    private getADAMOptimizerFormPrototype(
        optimizerParameter: Map<string, any>
    ): FormGroup {
        const optimizerGroup =
            this.getBaseOptimizerFormPrototype(optimizerParameter);
        const openUnitInterval =
            AlgorithmValidator.validatePositiveFloatUpTo1Open;

        optimizerGroup.addControl(
            'weightDecay',
            this.textControl('weightDecay', optimizerParameter)
        );
        optimizerGroup.addControl(
            'beta_1',
            this.requiredTextControl(
                'beta_1',
                optimizerParameter,
                openUnitInterval
            )
        );
        optimizerGroup.addControl(
            'beta_2',
            this.requiredTextControl(
                'beta_2',
                optimizerParameter,
                openUnitInterval
            )
        );
        // The control holds log10 of the stored epsilon (see asString), which
        // is why an integer validator applies here.
        optimizerGroup.addControl(
            'epsilon',
            this.requiredTextControl(
                'epsilon',
                optimizerParameter,
                AlgorithmValidator.validateInteger,
                AlgorithmValidator.validateEpsilonInRange
            )
        );
        optimizerGroup.addControl(
            'amsgrad',
            this.valueControl<boolean>(
                'amsgrad',
                optimizerParameter,
                Validators.required
            )
        );
        return optimizerGroup;
    }

    /** Builds the ADAM training-setup group (`nEpochs`). */
    private getADAMTrainingSetupFormPrototype(
        specificParameters: Map<string, any>
    ): FormGroup {
        return new FormGroup({
            nEpochs: this.requiredTextControl(
                'nEpochs',
                specificParameters,
                AlgorithmValidator.validatePositiveInteger
            ),
        });
    }

    /** Builds the complete ADAM form. */
    private getADAMForm(parameter: Map<string, Map<string, any>>): FormGroup {
        return new FormGroup({
            [Scope.GENERAL]: this.getGeneralFormPrototype(
                parameter.get(Scope.GENERAL)
            ),
            [Scope.OPTIMIZER]: this.getADAMOptimizerFormPrototype(
                parameter.get(Scope.OPTIMIZER)
            ),
            [Scope.TRAINING_SETUP]: this.getADAMTrainingSetupFormPrototype(
                parameter.get(Scope.TRAINING_SETUP)
            ),
        });
    }

    /**
     * Builds the SAC algorithm-specific group.
     *
     * @param isLearner when true the SDE controls (`useSde`,
     *     `sdeSampleFreq`, `useSdeAtWarmup`) are omitted.
     */
    private getSACSpecificFormPrototype(
        specificParameters: Map<string, any>,
        isLearner: boolean
    ): FormGroup {
        const unitInterval =
            AlgorithmValidator.validatePositiveFloatUpTo1Closed;
        const positiveInteger = AlgorithmValidator.validatePositiveInteger;

        const controls: { [name: string]: FormControl } = {
            gamma: this.requiredTextControl(
                'gamma',
                specificParameters,
                unitInterval
            ),
            bufferSize: this.requiredTextControl(
                'bufferSize',
                specificParameters,
                positiveInteger
            ),
            learningStarts: this.requiredTextControl(
                'learningStarts',
                specificParameters,
                positiveInteger
            ),
            tau: this.requiredTextControl(
                'tau',
                specificParameters,
                unitInterval
            ),
            gradientSteps: this.requiredTextControl(
                'gradientSteps',
                specificParameters,
                positiveInteger
            ),
            actionNoiseType: this.valueControl<ActionNoiseType>(
                'actionNoiseType',
                specificParameters,
                Validators.required
            ),
            replayBufferClass: this.valueControl<ReplayBufferClass>(
                'replayBufferClass',
                specificParameters,
                Validators.required
            ),
            optimizeMemoryUsage: this.valueControl<boolean>(
                'optimizeMemoryUsage',
                specificParameters,
                Validators.required
            ),
            entCoef: this.textControl(
                'entCoef',
                specificParameters,
                unitInterval
            ),
            targetUpdateInterval: this.requiredTextControl(
                'targetUpdateInterval',
                specificParameters,
                positiveInteger
            ),
            targetEntropy: this.textControl(
                'targetEntropy',
                specificParameters,
                unitInterval
            ),
        };

        if (!isLearner) {
            controls['useSde'] = this.valueControl<boolean>(
                'useSde',
                specificParameters,
                Validators.required
            );
            controls['sdeSampleFreq'] = this.requiredTextControl(
                'sdeSampleFreq',
                specificParameters,
                AlgorithmValidator.validatePositiveIntegerOrMinusOne
            );
            controls['useSdeAtWarmup'] = this.valueControl<boolean>(
                'useSdeAtWarmup',
                specificParameters,
                Validators.required
            );
        }
        return new FormGroup(controls);
    }

    /**
     * Builds the SAC training-setup group.
     *
     * @param isLearner when true the `totalTimesteps` control is omitted.
     */
    private getSACTrainingSetupFormPrototype(
        trainingSetupParameters: Map<string, any>,
        isLearner: boolean
    ): FormGroup {
        const controls: { [name: string]: FormControl } = {
            trainFreqValue: this.requiredTextControl(
                'trainFreqValue',
                trainingSetupParameters,
                AlgorithmValidator.validatePositiveInteger
            ),
            trainFreqUnit: this.valueControl<FrequencyUnit>(
                'trainFreqUnit',
                trainingSetupParameters,
                Validators.required
            ),
        };
        if (!isLearner) {
            controls['totalTimesteps'] = this.requiredTextControl(
                'totalTimesteps',
                trainingSetupParameters,
                AlgorithmValidator.validatePositiveIntegerGreaterOrEqualThanCustom(
                    2
                )
            );
        }
        return new FormGroup(controls);
    }

    /**
     * Builds the complete SAC form.
     *
     * @param isLearner true when SAC is embedded as the learner of an
     *     adversarial trainer; the learner scopes are used in that case.
     */
    private getSACForm(
        parameter: Map<string, Map<string, any>>,
        isLearner: boolean
    ): FormGroup {
        const scopes = this.getScopes(isLearner);
        return this.composeAgentForm(
            parameter,
            isLearner,
            this.getSACTrainingSetupFormPrototype(
                parameter.get(scopes.trainingSetupScope),
                isLearner
            ),
            this.getSACSpecificFormPrototype(
                parameter.get(scopes.specificScope),
                isLearner
            )
        );
    }

    /** Builds the BC algorithm-specific group. */
    private getBCSpecificFormPrototype(
        specificParameters: Map<string, any>
    ): FormGroup {
        return new FormGroup({
            minibatchSize: this.textControl(
                'minibatchSize',
                specificParameters,
                AlgorithmValidator.validateInteger,
                AlgorithmValidator.validatePowerOfTwo
            ),
            entWeight: this.requiredTextControl(
                'entWeight',
                specificParameters,
                AlgorithmValidator.validatePositiveFloat
            ),
            l2Weight: this.requiredTextControl(
                'l2Weight',
                specificParameters,
                AlgorithmValidator.validateNonNegativeFloat
            ),
            policy: this.textControl('policy', specificParameters),
        });
    }

    /** Builds the BC training-setup group (`nEpochs`). */
    private getBCTrainingSetupFormPrototype(
        trainingSetup: Map<string, any>
    ): FormGroup {
        return new FormGroup({
            nEpochs: this.requiredTextControl(
                'nEpochs',
                trainingSetup,
                AlgorithmValidator.validatePositiveInteger
            ),
        });
    }

    /**
     * Builds the complete BC form, including the cross-field rule that
     * `minibatchSize` must not exceed `batchSize`.
     */
    private getBCForm(parameter: Map<string, Map<string, any>>): FormGroup {
        return new FormGroup(
            {
                [Scope.GENERAL]: this.getGeneralFormPrototype(
                    parameter.get(Scope.GENERAL)
                ),
                [Scope.OPTIMIZER]: this.getBaseOptimizerFormPrototype(
                    parameter.get(Scope.OPTIMIZER)
                ),
                [Scope.SPECIFIC]: this.getBCSpecificFormPrototype(
                    parameter.get(Scope.SPECIFIC)
                ),
                [Scope.TRAINING_SETUP]: this.getBCTrainingSetupFormPrototype(
                    parameter.get(Scope.TRAINING_SETUP)
                ),
            },
            {
                validators:
                    AlgorithmValidator.validateFirstControlNotGreaterThanSecondControl(
                        {
                            firstControl: {
                                parameter: 'minibatchSize',
                                scope: Scope.SPECIFIC,
                            },
                            secondControl: {
                                parameter: 'batchSize',
                                scope: Scope.GENERAL,
                            },
                        }
                    ),
            }
        );
    }

    /**
     * Builds the form shared by the adversarial trainers (GAIL and AIRL).
     *
     * Besides the trainer's own scopes, the nested groups of the learner
     * algorithm selected by the `inputAlgorithm` parameter are merged into
     * the top level of the returned group.
     *
     * @throws Error when `inputAlgorithm` is neither PPO nor SAC.
     */
    private getAdversarialTrainerForm(
        parameter: Map<string, Map<string, any>>
    ): FormGroup {
        const inputAlgorithm = parameter
            .get(Scope.SPECIFIC)
            ?.get('inputAlgorithm');

        let learnerGroup: FormGroup;
        switch (inputAlgorithm) {
            case AlgorithmType.PPO:
                learnerGroup = this.getPPOForm(parameter, true);
                break;
            case AlgorithmType.SAC:
                learnerGroup = this.getSACForm(parameter, true);
                break;
            default:
                // Without this guard the spread below would dereference an
                // unassigned group and fail with an opaque TypeError.
                throw new Error(
                    `Unsupported learner algorithm: ${String(inputAlgorithm)}`
                );
        }

        return new FormGroup(
            {
                [Scope.GENERAL]: this.getGeneralFormPrototype(
                    parameter.get(Scope.GENERAL)
                ),
                [Scope.OPTIMIZER]: this.getBaseOptimizerFormPrototype(
                    parameter.get(Scope.OPTIMIZER)
                ),
                [Scope.TRAINING_SETUP]:
                    this.getAdversarialTrainerTrainingSetupFormPrototype(
                        parameter.get(Scope.TRAINING_SETUP)
                    ),
                [Scope.SPECIFIC]:
                    this.getAdversarialTrainerSpecificFormPrototype(
                        parameter.get(Scope.SPECIFIC)
                    ),
                [Scope.ENVIRONMENT]: this.getEnvironmentFormPrototype(
                    parameter.get(Scope.ENVIRONMENT)
                ),
                ...learnerGroup.controls,
            },
            {
                validators: [
                    AlgorithmValidator.validateFirstControlNotGreaterThanSecondControl(
                        {
                            firstControl: {
                                parameter: 'genTrainTimesteps',
                                scope: Scope.TRAINING_SETUP,
                            },
                            secondControl: {
                                parameter: 'totalTimesteps',
                                scope: Scope.TRAINING_SETUP,
                            },
                        }
                    ),
                    AlgorithmValidator.validateFirstControlNotGreaterThanSecondControl(
                        {
                            firstControl: {
                                parameter: 'demoMinibatchSize',
                                scope: Scope.SPECIFIC,
                            },
                            secondControl: {
                                parameter: 'batchSize',
                                scope: Scope.GENERAL,
                            },
                        }
                    ),
                ],
            }
        );
    }

    /** Builds the adversarial trainer's training-setup group. */
    private getAdversarialTrainerTrainingSetupFormPrototype(
        trainingSetupParameters: Map<string, any>
    ): FormGroup {
        return new FormGroup({
            totalTimesteps: this.requiredTextControl(
                'totalTimesteps',
                trainingSetupParameters,
                AlgorithmValidator.validatePositiveIntegerGreaterOrEqualThanCustom(
                    2
                )
            ),
            genTrainTimesteps: this.textControl(
                'genTrainTimesteps',
                trainingSetupParameters,
                AlgorithmValidator.validatePositiveInteger
            ),
        });
    }

    /** Builds the adversarial trainer's algorithm-specific group. */
    private getAdversarialTrainerSpecificFormPrototype(
        specificParameters: Map<string, any>
    ): FormGroup {
        return new FormGroup({
            nDiscUpdatesPerRound: this.requiredTextControl(
                'nDiscUpdatesPerRound',
                specificParameters,
                AlgorithmValidator.validatePositiveInteger
            ),
            demoMinibatchSize: this.requiredTextControl(
                'demoMinibatchSize',
                specificParameters,
                AlgorithmValidator.validatePowerOfTwo,
                AlgorithmValidator.validateInteger
            ),
            genReplayBufferCapacity: this.textControl(
                'genReplayBufferCapacity',
                specificParameters,
                AlgorithmValidator.validatePositiveInteger
            ),
            allowVariableHorizon: this.valueControl<boolean>(
                'allowVariableHorizon',
                specificParameters,
                Validators.required
            ),
            debugUseGroundTruth: this.valueControl<boolean>(
                'debugUseGroundTruth',
                specificParameters,
                Validators.required
            ),
            inputAlgorithm: this.valueControl<AlgorithmType>(
                'inputAlgorithm',
                specificParameters,
                Validators.required
            ),
        });
    }

    /**
     * Assembles the top-level form of an agent algorithm (PPO or SAC) from
     * the shared general/optimizer groups and the supplied algorithm groups.
     *
     * @param isLearner selects the learner scopes and drops the environment
     *     group, which only exists for stand-alone algorithms.
     */
    private composeAgentForm(
        parameter: Map<string, Map<string, any>>,
        isLearner: boolean,
        trainingSetupGroup: FormGroup,
        specificGroup: FormGroup
    ): FormGroup {
        const scopes = this.getScopes(isLearner);
        const groups: { [scope: string]: FormGroup } = {
            [scopes.generalScope]: this.getGeneralFormPrototype(
                parameter.get(scopes.generalScope)
            ),
            [scopes.optimizerScope]: this.getBaseOptimizerFormPrototype(
                parameter.get(scopes.optimizerScope)
            ),
            [scopes.trainingSetupScope]: trainingSetupGroup,
            [scopes.specificScope]: specificGroup,
        };
        if (!isLearner) {
            groups[Scope.ENVIRONMENT] = this.getEnvironmentFormPrototype(
                parameter.get(Scope.ENVIRONMENT)
            );
        }
        return new FormGroup(groups);
    }

    /** Creates a text control pre-filled with the named parameter. */
    private textControl(
        controlName: string,
        params: Map<string, unknown> | undefined,
        ...validators: ControlValidator[]
    ) {
        return new FormControl<string>(
            this.asString(controlName, params),
            validators
        );
    }

    /** Same as `textControl`, with `Validators.required` applied first. */
    private requiredTextControl(
        controlName: string,
        params: Map<string, unknown> | undefined,
        ...validators: ControlValidator[]
    ) {
        return this.textControl(
            controlName,
            params,
            Validators.required,
            ...validators
        );
    }

    /**
     * Creates a control holding the named parameter's raw (non-text) value,
     * e.g. a boolean flag or an enum member.
     */
    private valueControl<T>(
        controlName: string,
        params: Map<string, unknown> | undefined,
        ...validators: ControlValidator[]
    ) {
        return new FormControl<T>(params?.get(controlName) as T, validators);
    }

    /**
     * Returns the named parameter as text, or `''` when the scope or the
     * parameter is missing or nil.
     *
     * A parameter named `epsilon` is converted to its base-10 logarithm
     * before being rendered.
     */
    private asString(
        controlName: string,
        params: Map<string, unknown> | undefined
    ): string {
        const value = params?.get(controlName);
        // Check for nil first: log10 of a missing value would otherwise
        // surface as "NaN" / "-Infinity" in the control.
        if (!nonNil(value)) {
            return '';
        }
        if (controlName === 'epsilon') {
            return Math.log10(Number(value)).toString();
        }
        return String(value);
    }

    /**
     * Resolves which scopes hold the parameters of an agent algorithm,
     * depending on whether it runs stand-alone or as a learner.
     */
    private getScopes(isLearner: boolean): {
        generalScope: Scope;
        trainingSetupScope: Scope;
        optimizerScope: Scope;
        specificScope: Scope;
    } {
        return isLearner
            ? {
                  generalScope: Scope.LEARNER_GENERAL,
                  trainingSetupScope: Scope.LEARNER_TRAINING_SETUP,
                  optimizerScope: Scope.LEARNER_OPTIMIZER,
                  specificScope: Scope.LEARNER_SPECIFIC,
              }
            : {
                  generalScope: Scope.GENERAL,
                  trainingSetupScope: Scope.TRAINING_SETUP,
                  optimizerScope: Scope.OPTIMIZER,
                  specificScope: Scope.SPECIFIC,
              };
    }
}
