package io.hops.hopsworks.common.featurestore.trainingdatasets;

import io.hops.hopsworks.common.featurestore.FeaturestoreConstants;
import io.hops.hopsworks.common.featurestore.feature.TrainingDatasetFeatureDTO;
import io.hops.hopsworks.common.featurestore.query.Feature;
import io.hops.hopsworks.common.featurestore.query.Query;
import io.hops.hopsworks.common.featurestore.query.join.Join;
import io.hops.hopsworks.common.featurestore.statistics.columns.StatisticColumnController;
import io.hops.hopsworks.common.featurestore.storageconnectors.FeaturestoreConnectorFacade;
import io.hops.hopsworks.common.featurestore.storageconnectors.FeaturestoreStorageConnectorDTO;
import io.hops.hopsworks.common.featurestore.trainingdatasets.split.TrainingDatasetSplitDTO;
import io.hops.hopsworks.common.featurestore.utils.FeaturestoreInputValidation;
import io.hops.hopsworks.exceptions.FeaturestoreException;
import io.hops.hopsworks.persistence.entity.featurestore.storageconnector.FeaturestoreConnector;
import io.hops.hopsworks.persistence.entity.featurestore.storageconnector.FeaturestoreConnectorType;
import io.hops.hopsworks.persistence.entity.featurestore.trainingdataset.TrainingDatasetType;
import io.hops.hopsworks.restutils.RESTCodes;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.ejb.EJB;
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;
import joptsimple.internal.Strings;
import org.apache.commons.lang3.StringUtils;

/**
 * Validates user-supplied training dataset definitions.
 *
 * <p>Generic feature-store checks are delegated to {@link FeaturestoreInputValidation}. This class
 * additionally checks:
 * <ul>
 *   <li>the training dataset type, version and data format;</li>
 *   <li>the splits and the train split;</li>
 *   <li>label and transformation-function features against the query;</li>
 *   <li>join prefixes;</li>
 *   <li>the storage connector.</li>
 * </ul>
 * Runs with {@code TransactionAttributeType.NEVER}.
 */
@TransactionAttribute(TransactionAttributeType.NEVER)
@Stateless
public class TrainingDatasetInputValidation {

    @EJB
    private FeaturestoreInputValidation featurestoreInputValidation;

    @EJB
    private StatisticColumnController statisticColumnController;

    @EJB
    private FeaturestoreConnectorFacade connectorFacade;

    /**
     * Runs the generic feature-store input validation on the DTO. It then validates the names of
     * the explicitly listed features, but only when no query is attached and a feature list is
     * present.
     *
     * @param trainingDatasetDTO the training dataset definition supplied by the user
     * @throws FeaturestoreException if any validation fails
     */
    public void verifyUserInput(TrainingDatasetDTO trainingDatasetDTO) throws FeaturestoreException {
        featurestoreInputValidation.verifyUserInput(trainingDatasetDTO);
        if (trainingDatasetDTO.getQueryDTO() == null && trainingDatasetDTO.getFeatures() != null) {
            verifyTrainingDatasetFeatureList(trainingDatasetDTO.getFeatures());
        }
    }

    /**
     * Applies the feature-store name validation to every feature in the list.
     *
     * @param features features whose names are validated; must not be null
     * @throws FeaturestoreException if a feature name is invalid
     */
    public void verifyTrainingDatasetFeatureList(List<TrainingDatasetFeatureDTO> features)
            throws FeaturestoreException {
        for (TrainingDatasetFeatureDTO feature : features) {
            featurestoreInputValidation.nameValidation(feature.getName());
        }
    }

    /**
     * Runs all training dataset validations in order. The first failing check aborts the
     * validation.
     *
     * @param trainingDatasetDTO the training dataset definition supplied by the user
     * @param query the resolved query the dataset is built from; may be null
     * @throws FeaturestoreException if any validation fails
     * @throws IllegalArgumentException if no version is provided
     */
    public void validate(TrainingDatasetDTO trainingDatasetDTO, Query query) throws FeaturestoreException {
        verifyUserInput(trainingDatasetDTO);
        statisticColumnController.verifyStatisticColumnsExist(trainingDatasetDTO, query);
        validateType(trainingDatasetDTO.getTrainingDatasetType());
        validateVersion(trainingDatasetDTO.getVersion());
        validateDataFormat(trainingDatasetDTO.getDataFormat());
        validateSplits(trainingDatasetDTO.getSplits());
        validateFeatures(query, trainingDatasetDTO.getFeatures());
        validateStorageConnector(trainingDatasetDTO.getStorageConnector());
        validateTrainSplit(trainingDatasetDTO.getTrainSplit(), trainingDatasetDTO.getSplits());
    }

    /**
     * Accepts only HopsFS, external and in-memory training dataset types. Anything else,
     * including {@code null}, is rejected.
     */
    private void validateType(TrainingDatasetType trainingDatasetType) throws FeaturestoreException {
        if (trainingDatasetType != TrainingDatasetType.HOPSFS_TRAINING_DATASET
                && trainingDatasetType != TrainingDatasetType.EXTERNAL_TRAINING_DATASET
                && trainingDatasetType != TrainingDatasetType.IN_MEMORY_TRAINING_DATASET) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_TYPE,
                Level.FINE, ", Recognized Training Dataset types are: " + TrainingDatasetType.HOPSFS_TRAINING_DATASET
                + ", and: " + TrainingDatasetType.EXTERNAL_TRAINING_DATASET + ", and: "
                + TrainingDatasetType.IN_MEMORY_TRAINING_DATASET
                + ". The provided training dataset type was not recognized: " + trainingDatasetType);
        }
    }

    /**
     * Requires a strictly positive version.
     *
     * @throws IllegalArgumentException if the version is missing
     * @throws FeaturestoreException if the version is zero or negative
     */
    private void validateVersion(Integer version) throws FeaturestoreException {
        if (version == null) {
            throw new IllegalArgumentException(
                RESTCodes.FeaturestoreErrorCode.TRAINING_DATASET_VERSION_NOT_PROVIDED.getMessage());
        }
        if (version <= 0) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_VERSION,
                Level.FINE, " version cannot be negative or zero");
        }
    }

    /**
     * Requires the data format to be one of {@code FeaturestoreConstants.TRAINING_DATASET_DATA_FORMATS}.
     */
    private void validateDataFormat(String dataFormat) throws FeaturestoreException {
        // Check null explicitly: some collection implementations throw on contains(null).
        if (dataFormat == null || !FeaturestoreConstants.TRAINING_DATASET_DATA_FORMATS.contains(dataFormat)) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_DATA_FORMAT,
                Level.FINE, ", the recognized training dataset formats are: "
                + StringUtils.join(FeaturestoreConstants.TRAINING_DATASET_DATA_FORMATS, ", ")
                + ". The provided data format:" + dataFormat + " was not recognized.");
        }
    }

    /**
     * Validates each split.
     * <ul>
     *   <li>The name must be non-null and match {@code FEATURESTORE_REGEX}.</li>
     *   <li>The percentage must be present and strictly positive.</li>
     *   <li>Names must be unique.</li>
     * </ul>
     * A null or empty split list is accepted.
     */
    private void validateSplits(List<TrainingDatasetSplitDTO> splits) throws FeaturestoreException {
        if (splits == null || splits.isEmpty()) {
            return;
        }
        Pattern namePattern = FeaturestoreConstants.FEATURESTORE_REGEX;
        Set<String> seenNames = new HashSet<>();
        for (TrainingDatasetSplitDTO split : splits) {
            String name = split.getName();
            // A null name gets the same error as a malformed name, instead of an NPE in matcher().
            if (name == null || !namePattern.matcher(name).matches()) {
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_SPLIT_NAME,
                    Level.FINE, ", the provided training dataset split name " + name + " is invalid. Split names "
                    + "can only contain lower case characters, numbers and underscores and cannot be longer than "
                    + "63 characters or empty.");
            }
            if (split.getPercentage() == null) {
                throw new FeaturestoreException(
                    RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_SPLIT_PERCENTAGE, Level.FINE,
                    ", the provided training dataset split percentage is invalid. Percentages can only be numeric. "
                    + "Weights will be normalized if they don’t sum up to 1.0.");
            }
            float percentage = split.getPercentage().floatValue();
            // NaN compares false against everything, so "<= 0" alone would let it through.
            if (Float.isNaN(percentage) || percentage <= 0.0f) {
                throw new FeaturestoreException(
                    RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_SPLIT_PERCENTAGE, Level.FINE,
                    ", the provided training dataset split percentage is invalid. Weights must be greater than 0.");
            }
            if (!seenNames.add(name)) {
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.TRAINING_DATASET_DUPLICATE_SPLIT_NAMES,
                    Level.FINE, " The split names must be unique");
            }
        }
    }

    /**
     * Checks the features and join prefixes against the query.
     * <ul>
     *   <li>Every label feature must exist among the query features, including features of
     *       joined queries.</li>
     *   <li>Every feature with a transformation function must reference an existing feature
     *       group feature.</li>
     *   <li>Join prefixes, when set, must match {@code FEATURESTORE_REGEX}.</li>
     * </ul>
     * Nothing is validated when either argument is null. Labels are checked before
     * transformation features.
     *
     * @param query the query the training dataset is built from; may be null
     * @param features the training dataset features; may be null
     * @throws FeaturestoreException if a label or transformed feature is missing or a prefix is invalid
     */
    public void validateFeatures(Query query, List<TrainingDatasetFeatureDTO> features) throws FeaturestoreException {
        if (query == null || features == null) {
            return;
        }
        Set<String> availableNames = collectFeatures(query).stream()
            .map(Feature::getName)
            .collect(Collectors.toSet());

        for (TrainingDatasetFeatureDTO feature : features) {
            // getLabel() may return a boxed Boolean; compare with TRUE so null means "not a label".
            if (Boolean.TRUE.equals(feature.getLabel()) && !availableNames.contains(feature.getName())) {
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.LABEL_NOT_FOUND, Level.FINE,
                    "Label: " + feature.getName() + " is missing");
            }
        }
        for (TrainingDatasetFeatureDTO feature : features) {
            if (feature.getTransformationFunction() != null
                    && !availableNames.contains(feature.getFeatureGroupFeatureName())) {
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.FEATURE_WITH_TRANSFORMATION_NOT_FOUND,
                    Level.FINE, "feature: " + feature.getName()
                    + " is missing and transformation function can't be attached");
            }
        }

        if (query.getJoins() == null) {
            return;
        }
        for (Join join : query.getJoins()) {
            String prefix = join.getPrefix();
            if (prefix != null && !FeaturestoreConstants.FEATURESTORE_REGEX.matcher(prefix).matches()) {
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_PREFIX_NAME, Level.FINE,
                    ", the provided prefix name " + prefix + " is invalid. Prefix names can only contain lower case "
                    + "characters, numbers and underscores and cannot be longer than 63 characters or empty.");
            }
        }
    }

    /**
     * Returns the query's own features followed by the features of every joined right-hand
     * query, collected recursively.
     */
    private List<Feature> collectFeatures(Query query) {
        List<Feature> collected = new ArrayList<>(query.getFeatures());
        if (query.getJoins() != null) {
            for (Join join : query.getJoins()) {
                collected.addAll(collectFeatures(join.getRightQuery()));
            }
        }
        return collected;
    }

    /**
     * Checks a referenced storage connector.
     * <ul>
     *   <li>The connector must exist.</li>
     *   <li>It must be of type HopsFS, S3, ADLS or GCS.</li>
     * </ul>
     * A missing connector DTO, or one without an id, is accepted.
     */
    private void validateStorageConnector(FeaturestoreStorageConnectorDTO connectorDTO) throws FeaturestoreException {
        if (connectorDTO == null || connectorDTO.getId() == null) {
            return;
        }
        FeaturestoreConnector connector = connectorFacade.findById(connectorDTO.getId())
            .orElseThrow(() -> new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.CONNECTOR_NOT_FOUND,
                Level.FINE, "Connector ID: " + connectorDTO.getId()));
        FeaturestoreConnectorType type = connector.getConnectorType();
        if (type != FeaturestoreConnectorType.HOPSFS && type != FeaturestoreConnectorType.S3
                && type != FeaturestoreConnectorType.ADLS && type != FeaturestoreConnectorType.GCS) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_STORAGE_CONNECTOR_TYPE, Level.FINE,
                "Only HopsFS, S3, ADLS and GCS storage connectors can be used to create training datasets");
        }
    }

    /**
     * Validates the train split against the splits.
     * <ul>
     *   <li>Without splits, no train split name may be given.</li>
     *   <li>With splits, the train split name must equal one of the split names. A null train
     *       split is therefore rejected unless a split also has a null name.</li>
     * </ul>
     */
    void validateTrainSplit(String trainSplit, List<TrainingDatasetSplitDTO> splits) throws FeaturestoreException {
        if (splits == null || splits.isEmpty()) {
            if (!StringUtils.isEmpty(trainSplit)) {
                throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_SPLIT_NAME,
                    Level.FINE, "Training data split name provided without splitting the dataset.");
            }
            return;
        }
        boolean found = splits.stream()
            .map(TrainingDatasetSplitDTO::getName)
            .anyMatch(name -> StringUtils.equals(name, trainSplit));
        if (!found) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_TRAINING_DATASET_SPLIT_NAME,
                Level.FINE, "The provided training data split name `" + trainSplit
                + "` could not be found among the provided splits.");
        }
    }
}
