/*
 * Decompiled with CFR 0.152.
 */
package io.hops.hopsworks.common.featurestore.trainingdatasets;

import com.google.common.base.Joiner;
import io.hops.hopsworks.common.featurestore.FeaturestoreConstants;
import io.hops.hopsworks.common.featurestore.feature.TrainingDatasetFeatureDTO;
import io.hops.hopsworks.common.featurestore.query.Feature;
import io.hops.hopsworks.common.featurestore.query.Query;
import io.hops.hopsworks.common.featurestore.query.QueryController;
import io.hops.hopsworks.common.featurestore.query.filter.FilterLogicDTO;
import io.hops.hopsworks.common.featurestore.query.join.Join;
import io.hops.hopsworks.common.featurestore.statistics.columns.StatisticColumnController;
import io.hops.hopsworks.common.featurestore.storageconnectors.FeaturestoreConnectorFacade;
import io.hops.hopsworks.common.featurestore.storageconnectors.FeaturestoreStorageConnectorDTO;
import io.hops.hopsworks.common.featurestore.trainingdatasets.TrainingDatasetDTO;
import io.hops.hopsworks.common.featurestore.trainingdatasets.split.TrainingDatasetSplitDTO;
import io.hops.hopsworks.common.featurestore.utils.FeaturestoreInputValidation;
import io.hops.hopsworks.exceptions.FeaturestoreException;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup;
import io.hops.hopsworks.persistence.entity.featurestore.storageconnector.FeaturestoreConnector;
import io.hops.hopsworks.persistence.entity.featurestore.storageconnector.FeaturestoreConnectorType;
import io.hops.hopsworks.persistence.entity.featurestore.trainingdataset.TrainingDatasetType;
import io.hops.hopsworks.persistence.entity.featurestore.trainingdataset.split.SplitName;
import io.hops.hopsworks.persistence.entity.featurestore.trainingdataset.split.SplitType;
import io.hops.hopsworks.restutils.RESTCodes;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.ejb.EJB;
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;
import joptsimple.internal.Strings;
import org.apache.commons.lang3.StringUtils;

/**
 * Stateless bean that validates the user-supplied parts of a {@link TrainingDatasetDTO}
 * (type, version, data format, splits, features, storage connector, train split and extra filter).
 *
 * <p>Validation failures are reported as {@link FeaturestoreException}s at {@link Level#FINE};
 * the only exception is a missing version, which raises an {@link IllegalArgumentException}.
 */
@Stateless
@TransactionAttribute(value=TransactionAttributeType.NEVER)
public class TrainingDatasetInputValidation {
    @EJB
    private FeaturestoreInputValidation featurestoreInputValidation;
    @EJB
    private StatisticColumnController statisticColumnController;
    @EJB
    private FeaturestoreConnectorFacade connectorFacade;
    @EJB
    private QueryController queryController;

    /**
     * Runs the generic feature store input validation on the DTO and, when the DTO carries an
     * explicit feature list instead of a query, validates the name of every listed feature.
     *
     * @param trainingDatasetDTO the DTO to validate
     * @throws FeaturestoreException if any delegated validation fails
     */
    public void verifyUserInput(TrainingDatasetDTO trainingDatasetDTO) throws FeaturestoreException {
        featurestoreInputValidation.verifyUserInput(trainingDatasetDTO);
        if (trainingDatasetDTO.getQueryDTO() == null && trainingDatasetDTO.getFeatures() != null) {
            verifyTrainingDatasetFeatureList(trainingDatasetDTO.getFeatures());
        }
    }

    /**
     * Validates the name of every feature in the list.
     *
     * @param trainingDatasetFeatureDTOS the features to check; {@code null} is treated as "nothing to validate"
     * @throws FeaturestoreException if a feature name is rejected by the name validation
     */
    public void verifyTrainingDatasetFeatureList(List<TrainingDatasetFeatureDTO> trainingDatasetFeatureDTOS) throws FeaturestoreException {
        if (trainingDatasetFeatureDTOS == null) {
            return;
        }
        for (TrainingDatasetFeatureDTO featureDTO : trainingDatasetFeatureDTOS) {
            featurestoreInputValidation.nameValidation(featureDTO.getName());
        }
    }

    /**
     * Runs every training dataset validation in sequence; the first failing check aborts the run.
     *
     * @param trainingDatasetDTO the DTO to validate
     * @param query the query backing the training dataset; may be {@code null}, in which case no
     *              event time column is available for time series splits
     * @throws FeaturestoreException if any check fails
     * @throws IllegalArgumentException if the DTO has no version
     */
    public void validate(TrainingDatasetDTO trainingDatasetDTO, Query query) throws FeaturestoreException {
        verifyUserInput(trainingDatasetDTO);
        statisticColumnController.verifyStatisticColumnsExist(trainingDatasetDTO, query);
        validateType(trainingDatasetDTO.getTrainingDatasetType());
        validateVersion(trainingDatasetDTO.getVersion());
        validateDataFormat(trainingDatasetDTO.getDataFormat());
        String eventTimeFieldName = query != null ? query.getFeaturegroup().getEventTime() : null;
        validateSplits(trainingDatasetDTO.getSplits(), eventTimeFieldName);
        validateFeatures(query, trainingDatasetDTO.getFeatures());
        validateStorageConnector(trainingDatasetDTO.getStorageConnector());
        validateTrainSplit(trainingDatasetDTO.getTrainSplit(), trainingDatasetDTO.getSplits());
        validateExtraFilter(trainingDatasetDTO, query);
    }

    /**
     * Validates the DTO's extra filter, if any, against the ids of the feature groups returned by
     * the query controller for {@code query}.
     */
    private void validateExtraFilter(TrainingDatasetDTO trainingDatasetDTO, Query query) throws FeaturestoreException {
        FilterLogicDTO extraFilter = trainingDatasetDTO.getExtraFilter();
        if (extraFilter == null) {
            return;
        }
        Set<Integer> featureGroupIds = queryController.getFeatureGroups(query).stream()
            .map(Featuregroup::getId)
            .collect(Collectors.toSet());
        validateExtraFilter(extraFilter, featureGroupIds);
    }

    /**
     * Recursively checks that every filter in the filter tree references a feature group id
     * contained in {@code allFgs}.
     *
     * @throws FeaturestoreException naming the offending feature when its feature group id is not in {@code allFgs}
     */
    private void validateExtraFilter(FilterLogicDTO filterLogicDTO, Set<Integer> allFgs) throws FeaturestoreException {
        // Each side is checked on its own so the error always describes the filter that actually
        // failed (and a null left filter cannot cause a NullPointerException while reporting a
        // failing right filter).
        if (filterLogicDTO.getLeftFilter() != null
            && !allFgs.contains(filterLogicDTO.getLeftFilter().getFeature().getFeatureGroupId())) {
            throw featureGroupNotInQuery(
                filterLogicDTO.getLeftFilter().getFeature().getName(),
                filterLogicDTO.getLeftFilter().getFeature().getFeatureGroupId(),
                allFgs);
        }
        if (filterLogicDTO.getRightFilter() != null
            && !allFgs.contains(filterLogicDTO.getRightFilter().getFeature().getFeatureGroupId())) {
            throw featureGroupNotInQuery(
                filterLogicDTO.getRightFilter().getFeature().getName(),
                filterLogicDTO.getRightFilter().getFeature().getFeatureGroupId(),
                allFgs);
        }
        if (filterLogicDTO.getLeftLogic() != null) {
            validateExtraFilter(filterLogicDTO.getLeftLogic(), allFgs);
        }
        if (filterLogicDTO.getRightLogic() != null) {
            validateExtraFilter(filterLogicDTO.getRightLogic(), allFgs);
        }
    }

    /** Builds the error raised when a filter feature belongs to a feature group outside the query. */
    private static FeaturestoreException featureGroupNotInQuery(String featureName, Integer featureGroupId, Set<Integer> allFgs) {
        return new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.FEATUREGROUP_NOT_FOUND, Level.FINE, String.format("Feature '%s' is from feature group with id '%d' which is not available in the feature view's `Query`. All available feature group id in this query are [%s].", featureName, featureGroupId, Joiner.on(", ").join(allFgs)));
    }

    /**
     * Accepts only the HopsFS, external and in-memory training dataset types.
     *
     * @throws FeaturestoreException if the type is {@code null} or any other value
     */
    private void validateType(TrainingDatasetType trainingDatasetType) throws FeaturestoreException {
        boolean recognized = trainingDatasetType == TrainingDatasetType.HOPSFS_TRAINING_DATASET
            || trainingDatasetType == TrainingDatasetType.EXTERNAL_TRAINING_DATASET
            || trainingDatasetType == TrainingDatasetType.IN_MEMORY_TRAINING_DATASET;
        if (!recognized) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_TYPE, Level.FINE, ", Recognized Training Dataset types are: " + TrainingDatasetType.HOPSFS_TRAINING_DATASET + ", and: " + TrainingDatasetType.EXTERNAL_TRAINING_DATASET + ", and: " + TrainingDatasetType.IN_MEMORY_TRAINING_DATASET + ". The provided training dataset type was not recognized: " + trainingDatasetType);
        }
    }

    /**
     * Requires a strictly positive version.
     *
     * @throws IllegalArgumentException if {@code version} is {@code null}
     * @throws FeaturestoreException if {@code version} is zero or negative
     */
    private void validateVersion(Integer version) throws FeaturestoreException {
        if (version == null) {
            throw new IllegalArgumentException(RESTCodes.FeaturestoreErrorCode.TRAINING_DATASET_VERSION_NOT_PROVIDED.getMessage());
        }
        if (version <= 0) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_VERSION, Level.FINE, " version cannot be negative or zero");
        }
    }

    /**
     * Requires the data format to be one of {@link FeaturestoreConstants#TRAINING_DATASET_DATA_FORMATS}.
     *
     * @throws FeaturestoreException if the format is not in that collection
     */
    private void validateDataFormat(String dataFormat) throws FeaturestoreException {
        if (!FeaturestoreConstants.TRAINING_DATASET_DATA_FORMATS.contains(dataFormat)) {
            // The collection is joined as a single element, so the message shows its toString() form.
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_DATA_FORMAT, Level.FINE, ", the recognized training dataset formats are: " + StringUtils.join(new Object[]{FeaturestoreConstants.TRAINING_DATASET_DATA_FORMATS}) + ". The provided data format:" + dataFormat + " was not recognized.");
        }
    }

    /**
     * Validates the split definitions of a training dataset.
     *
     * <p>For every split: the name must match {@link FeaturestoreConstants#FEATURESTORE_REGEX} and be
     * unique; a random split needs a percentage greater than zero; a time series split needs an
     * event time column and, when named train/validation/test, both a start and an end time.
     * When at least one time series split is present, train and test boundaries are required and
     * the train, validation (optional) and test ranges are checked for ordering and overlap.
     *
     * @param trainingDatasetSplitDTOs the splits to validate; {@code null} or empty means nothing to validate
     * @param eventTimeFieldName name of the event time column, required for time series splits
     * @throws FeaturestoreException on the first violated rule
     */
    void validateSplits(List<TrainingDatasetSplitDTO> trainingDatasetSplitDTOs, String eventTimeFieldName) throws FeaturestoreException {
        if (trainingDatasetSplitDTOs == null || trainingDatasetSplitDTOs.isEmpty()) {
            return;
        }
        Pattern namePattern = FeaturestoreConstants.FEATURESTORE_REGEX;
        Set<String> splitNames = new HashSet<>();
        boolean isTimeSplit = false;
        Date trainStart = null;
        Date trainEnd = null;
        Date validationStart = null;
        Date validationEnd = null;
        Date testStart = null;
        Date testEnd = null;
        for (TrainingDatasetSplitDTO split : trainingDatasetSplitDTOs) {
            String splitName = split.getName();
            // A null name is rejected here instead of failing inside Pattern.matcher.
            if (splitName == null || !namePattern.matcher(splitName).matches()) {
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_SPLIT_NAME, Level.FINE, "The provided training dataset split name " + splitName + " is invalid. Split names can only contain lower case characters, numbers and underscores and cannot be longer than 63 characters or empty.");
            }
            if (SplitType.RANDOM_SPLIT.equals(split.getSplitType())) {
                if (split.getPercentage() == null) {
                    throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_SPLIT_PERCENTAGE, Level.FINE, "The provided training dataset split percentage is invalid. Percentages can only be numeric. Weights will be normalized if they don\u2019t sum up to 1.0.");
                }
                if (split.getPercentage().floatValue() <= 0.0f) {
                    throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_SPLIT_PERCENTAGE, Level.FINE, "The provided training dataset split percentage is invalid. Weights must be greater than 0.");
                }
            }
            if (!splitNames.add(splitName)) {
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.TRAINING_DATASET_DUPLICATE_SPLIT_NAMES, Level.FINE, " The split names must be unique.");
            }
            if (!SplitType.TIME_SERIES_SPLIT.equals(split.getSplitType())) {
                continue;
            }
            isTimeSplit = true;
            if (Strings.isNullOrEmpty(eventTimeFieldName)) {
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.EVENT_TIME_FEATURE_NOT_FOUND, Level.FINE, "Failed to define time series split because event time column is not available in one or more feature groups.");
            }
            if (SplitName.TRAIN.getName().equals(splitName)) {
                trainStart = split.getStartTime();
                trainEnd = split.getEndTime();
                requireSplitBounds("train", trainStart, trainEnd);
            }
            if (SplitName.VALIDATION.getName().equals(splitName)) {
                validationStart = split.getStartTime();
                validationEnd = split.getEndTime();
                requireSplitBounds("validation", validationStart, validationEnd);
            }
            if (SplitName.TEST.getName().equals(splitName)) {
                testStart = split.getStartTime();
                testEnd = split.getEndTime();
                requireSplitBounds("test", testStart, testEnd);
            }
        }
        if (!isTimeSplit) {
            return;
        }
        // From here on a non-null start implies a non-null end (enforced inside the loop).
        if (trainStart != null && trainStart.getTime() > trainEnd.getTime()) {
            throw timeSeriesSplitError("End time of the train split should be greater than or equal to the start time.");
        }
        if (validationStart != null && validationStart.getTime() > validationEnd.getTime()) {
            throw timeSeriesSplitError("End time of the validation split should be greater than or equal to the start time.");
        }
        if (testStart != null && testStart.getTime() > testEnd.getTime()) {
            throw timeSeriesSplitError("End time of the test split should be greater than or equal to the start time.");
        }
        // The overlap checks below dereference the train and test boundaries, so report their
        // absence as a validation error rather than letting a NullPointerException escape.
        if (trainStart == null || testStart == null) {
            throw timeSeriesSplitError("Time series split requires both a train and a test split with start and end time.");
        }
        if (validationStart != null) {
            if (startsBeforeEndOf(validationStart, trainStart, trainEnd)) {
                throw timeSeriesSplitError("Start time of the validation split should be greater than the start/end time of train split.");
            }
            if (startsBeforeEndOf(testStart, validationStart, validationEnd)) {
                throw timeSeriesSplitError("Start time of the test split should be greater than the start/end time of validation split.");
            }
        }
        if (startsBeforeEndOf(testStart, trainStart, trainEnd)) {
            throw timeSeriesSplitError("Start time of the test split should be greater than the start/end time of train split.");
        }
    }

    /** Throws when either boundary of the named time series split is missing. */
    private static void requireSplitBounds(String splitLabel, Date start, Date end) throws FeaturestoreException {
        if (start == null || end == null) {
            throw timeSeriesSplitError("Start/end time of " + splitLabel + " split is/are not provided.");
        }
    }

    /**
     * Returns {@code true} when {@code laterStart} equals {@code earlierStart} or lies strictly
     * before {@code earlierEnd}; a start equal to {@code earlierEnd} is accepted.
     */
    private static boolean startsBeforeEndOf(Date laterStart, Date earlierStart, Date earlierEnd) {
        return laterStart.getTime() == earlierStart.getTime() || laterStart.getTime() < earlierEnd.getTime();
    }

    /** Builds an {@code ILLEGAL_TRAINING_DATASET_TIME_SERIES_SPLIT} error with the given message. */
    private static FeaturestoreException timeSeriesSplitError(String message) {
        return new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_TIME_SERIES_SPLIT, Level.FINE, message);
    }

    /**
     * Checks that every label, helper column and transformed feature named in {@code featuresDTOs}
     * exists among the features of {@code query} (including the right-hand queries of its joins),
     * and that every join prefix matches {@link FeaturestoreConstants#FEATURESTORE_REGEX}.
     *
     * @param query the query to look features up in; {@code null} skips all checks
     * @param featuresDTOs the training dataset features; {@code null} skips all checks
     * @throws FeaturestoreException if a referenced feature is missing or a join prefix is invalid
     */
    public void validateFeatures(Query query, List<TrainingDatasetFeatureDTO> featuresDTOs) throws FeaturestoreException {
        if (query == null || featuresDTOs == null) {
            return;
        }
        List<TrainingDatasetFeatureDTO> labels = featuresDTOs.stream().filter(TrainingDatasetFeatureDTO::getLabel).collect(Collectors.toList());
        List<TrainingDatasetFeatureDTO> featuresWithTransformation = featuresDTOs.stream().filter(f -> f.getTransformationFunction() != null).collect(Collectors.toList());
        List<TrainingDatasetFeatureDTO> inferenceHelperColumns = featuresDTOs.stream().filter(TrainingDatasetFeatureDTO::getInferenceHelperColumn).collect(Collectors.toList());
        List<TrainingDatasetFeatureDTO> trainingHelperColumns = featuresDTOs.stream().filter(TrainingDatasetFeatureDTO::getTrainingHelperColumn).collect(Collectors.toList());
        List<Feature> features = collectFeatures(query);
        for (TrainingDatasetFeatureDTO label : labels) {
            if (features.stream().noneMatch(f -> f.getName().equals(label.getName()))) {
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.LABEL_NOT_FOUND, Level.FINE, "Label: " + label.getName() + " is missing");
            }
        }
        for (TrainingDatasetFeatureDTO inferenceHelperColumn : inferenceHelperColumns) {
            if (features.stream().noneMatch(f -> f.getName().equals(inferenceHelperColumn.getName()))) {
                // Reported with the helper-column code, consistent with the training helper check below.
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.HELPER_COL_NOT_FOUND, Level.FINE, "Inference helper column: " + inferenceHelperColumn.getName() + " is missing");
            }
        }
        for (TrainingDatasetFeatureDTO trainingHelperColumn : trainingHelperColumns) {
            if (features.stream().noneMatch(f -> f.getName().equals(trainingHelperColumn.getName()))) {
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.HELPER_COL_NOT_FOUND, Level.FINE, "Training helper column: " + trainingHelperColumn.getName() + " is missing");
            }
        }
        for (TrainingDatasetFeatureDTO featureWithTransformation : featuresWithTransformation) {
            // Transformed features are matched on their feature group feature name, not the DTO name.
            if (features.stream().noneMatch(f -> f.getName().equals(featureWithTransformation.getFeatureGroupFeatureName()))) {
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.FEATURE_WITH_TRANSFORMATION_NOT_FOUND, Level.FINE, "feature: " + featureWithTransformation.getName() + " is missing and transformation function can't be attached");
            }
        }
        if (query.getJoins() != null) {
            Pattern namePattern = FeaturestoreConstants.FEATURESTORE_REGEX;
            for (Join join : query.getJoins()) {
                if (join.getPrefix() != null && !namePattern.matcher(join.getPrefix()).matches()) {
                    throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_PREFIX_NAME, Level.FINE, ", the provided prefix name " + join.getPrefix() + " is invalid. Prefix names can only contain lower case characters, numbers and underscores and cannot be longer than 63 characters or empty.");
                }
            }
        }
    }

    /** Collects the features of {@code query} followed, recursively, by those of each join's right query. */
    private List<Feature> collectFeatures(Query query) {
        List<Feature> features = new ArrayList<>(query.getFeatures());
        if (query.getJoins() != null) {
            for (Join join : query.getJoins()) {
                features.addAll(collectFeatures(join.getRightQuery()));
            }
        }
        return features;
    }

    /**
     * Checks that the referenced storage connector exists and is of type HopsFS, S3, ADLS or GCS.
     * A {@code null} DTO or a DTO without id is accepted without lookup.
     *
     * @throws FeaturestoreException if the connector is not found or has another type
     */
    private void validateStorageConnector(FeaturestoreStorageConnectorDTO connectorDTO) throws FeaturestoreException {
        if (connectorDTO == null || connectorDTO.getId() == null) {
            return;
        }
        FeaturestoreConnector connector = connectorFacade.findById(connectorDTO.getId()).orElseThrow(() -> new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.CONNECTOR_NOT_FOUND, Level.FINE, "Connector ID: " + connectorDTO.getId()));
        FeaturestoreConnectorType connectorType = connector.getConnectorType();
        boolean supported = connectorType == FeaturestoreConnectorType.HOPSFS
            || connectorType == FeaturestoreConnectorType.S3
            || connectorType == FeaturestoreConnectorType.ADLS
            || connectorType == FeaturestoreConnectorType.GCS;
        if (!supported) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_STORAGE_CONNECTOR_TYPE, Level.FINE, "Only HopsFS, S3, ADLS and GCS storage connectors can be used to create training datasets");
        }
    }

    /**
     * Checks the train split name against the declared splits: without splits no name may be
     * given; with splits the name must be one of the split names.
     *
     * @param trainSplit name of the split used for training; may be {@code null} or empty when there are no splits
     * @param splits the declared splits; may be {@code null} or empty
     * @throws FeaturestoreException if the name is given without splits or is not among the splits
     */
    void validateTrainSplit(String trainSplit, List<TrainingDatasetSplitDTO> splits) throws FeaturestoreException {
        boolean hasSplits = splits != null && !splits.isEmpty();
        if (!hasSplits) {
            if (!Strings.isNullOrEmpty(trainSplit)) {
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_SPLIT_NAME, Level.FINE, "Training data split name provided without splitting the dataset.");
            }
            return;
        }
        List<String> declaredNames = splits.stream().map(TrainingDatasetSplitDTO::getName).collect(Collectors.toList());
        if (!declaredNames.contains(trainSplit)) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_SPLIT_NAME, Level.FINE, "The provided training data split name `" + trainSplit + "` could not be found among the provided splits.");
        }
    }
}

