/*
 * Decompiled with CFR 0.152.
 */
package com.logicalclocks.hsfs.engine;

import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.StorageConnector;
import com.logicalclocks.hsfs.StorageConnectorType;
import com.logicalclocks.hsfs.TrainingDatasetFeature;
import com.logicalclocks.hsfs.TrainingDatasetType;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/**
 * Helpers for deriving, labelling and validating training dataset feature schemas
 * from Spark {@link Dataset}s.
 */
public class TrainingDatasetUtils {

    /**
     * Marks every feature named in {@code labels} as a label feature.
     *
     * <p>Names are compared with {@link String#equals} (exact, case-sensitive). For each label
     * the first feature with that name gets {@code setLabel(true)}. Labels are processed in
     * order, so features matched before a missing label stay marked when the exception is thrown.
     *
     * @param features the training dataset features, updated in place
     * @param labels   names of the label features; {@code null} or empty means "no labels"
     * @throws FeatureStoreException if a label does not match the name of any feature
     */
    public static void setLabelFeature(List<TrainingDatasetFeature> features, List<String> labels) throws FeatureStoreException {
        if (labels == null || labels.isEmpty()) {
            return;
        }
        for (String label : labels) {
            Optional<TrainingDatasetFeature> match = features.stream()
                .filter(f -> f.getName().equals(label))
                .findFirst();
            if (!match.isPresent()) {
                // Collect the names first: concatenating a Stream directly would print the
                // stream object's identity instead of the feature names.
                List<String> featureNames = features.stream()
                    .map(TrainingDatasetFeature::getName)
                    .collect(Collectors.toList());
                throw new FeatureStoreException("The specified label `" + label
                    + "` could not be found among the features: " + featureNames + ".");
            }
            match.get().setLabel(true);
        }
    }

    /**
     * Builds the training dataset feature list from the schema of {@code dataset}.
     *
     * <p>Each column becomes one feature. Its name is the column name lower-cased, its type is
     * the column's Spark {@code catalogString()}, and its index is the column's 0-based position
     * in the schema.
     *
     * @param dataset the dataframe whose schema describes the training dataset
     * @return a mutable list of features in schema order
     * @throws FeatureStoreException declared for API compatibility; not thrown by this body
     */
    public List<TrainingDatasetFeature> parseTrainingDatasetSchema(Dataset<Row> dataset) throws FeatureStoreException {
        StructField[] fields = dataset.schema().fields();
        List<TrainingDatasetFeature> features = new ArrayList<>(fields.length);
        for (int index = 0; index < fields.length; index++) {
            StructField field = fields[index];
            // Locale.ROOT keeps lower-casing locale-independent (e.g. "ID" must not become "ıd"
            // when the JVM default locale is Turkish).
            String name = field.name().toLowerCase(Locale.ROOT);
            features.add(new TrainingDatasetFeature(name, field.dataType().catalogString(), index));
        }
        return features;
    }

    /**
     * Checks that {@code dataset}'s schema equals the schema described by {@code features}.
     *
     * <p>The expected schema is built by ordering the features by {@code getIndex()}. Each
     * feature's type string is parsed with {@link CatalystSqlParser}, and every field is marked
     * nullable with empty metadata.
     *
     * <p>NOTE(review): this presumably relies on {@code StructType.equals} also comparing field
     * nullability and metadata, so non-nullable dataframe columns would fail the check. Names are
     * compared as-is, while {@link #parseTrainingDatasetSchema} lower-cases them. Confirm both
     * are intended.
     *
     * @param dataset  the dataframe to validate
     * @param features the training dataset features that define the expected schema
     * @throws FeatureStoreException if the two schemas are not equal
     */
    public void trainingDatasetSchemaMatch(Dataset<Row> dataset, List<TrainingDatasetFeature> features) throws FeatureStoreException {
        // A single parser is reused for every feature type rather than one per feature.
        CatalystSqlParser parser = new CatalystSqlParser();
        StructField[] expectedFields = features.stream()
            .sorted(Comparator.comparingInt(TrainingDatasetFeature::getIndex))
            .map(f -> new StructField(f.getName(), parser.parseDataType(f.getType()), true, Metadata.empty()))
            .toArray(StructField[]::new);
        StructType tdStructType = new StructType(expectedFields);
        if (!dataset.schema().equals(tdStructType)) {
            throw new FeatureStoreException("The Dataframe schema: " + dataset.schema() + " does not match the training dataset schema: " + tdStructType);
        }
    }

    /**
     * Chooses the training dataset type for the given storage connector.
     *
     * @param storageConnector the connector the training dataset is stored with; may be {@code null}
     * @return {@code HOPSFS_TRAINING_DATASET} when the connector is {@code null} or of type
     *         {@code HOPSFS}, otherwise {@code EXTERNAL_TRAINING_DATASET}
     */
    public TrainingDatasetType getTrainingDatasetType(StorageConnector storageConnector) {
        boolean storedOnHopsFs = storageConnector == null
            || storageConnector.getStorageConnectorType() == StorageConnectorType.HOPSFS;
        return storedOnHopsFs
            ? TrainingDatasetType.HOPSFS_TRAINING_DATASET
            : TrainingDatasetType.EXTERNAL_TRAINING_DATASET;
    }
}

