package com.logicalclocks.hsfs.engine;

import com.google.common.base.Strings;
import com.logicalclocks.hsfs.EntityEndpointType;
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.Storage;
import com.logicalclocks.hsfs.TrainingDataset;
import com.logicalclocks.hsfs.TrainingDatasetFeature;
import com.logicalclocks.hsfs.constructor.ServingPreparedStatement;
import com.logicalclocks.hsfs.metadata.HopsworksClient;
import com.logicalclocks.hsfs.metadata.StorageConnectorApi;
import com.logicalclocks.hsfs.metadata.TagsApi;
import com.logicalclocks.hsfs.metadata.TrainingDatasetApi;
import com.logicalclocks.hsfs.util.Constants;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import java.util.stream.Collectors;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.io.BinaryDecoder;
import org.apache.avro.io.DatumReader;
import org.apache.avro.io.DecoderFactory;
import org.apache.hadoop.fs.Path;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/logicalclocks/hsfs/engine/TrainingDatasetEngine.class */
/**
 * Coordinates training dataset operations: registering metadata through {@link TrainingDatasetApi},
 * writing and reading data through {@link SparkEngine} and the dataset's storage connector, managing
 * tags, and serving feature vectors over JDBC prepared statements.
 */
public class TrainingDatasetEngine {
    private static final Logger LOGGER = LoggerFactory.getLogger(TrainingDatasetEngine.class);

    private TrainingDatasetApi trainingDatasetApi = new TrainingDatasetApi();
    private TagsApi tagsApi = new TagsApi(EntityEndpointType.TRAINING_DATASET);
    private StorageConnectorApi storageConnectorApi = new StorageConnectorApi();
    private Utils utils = new Utils();
    // Handed back to the DecoderFactory on every decode so the decoder instance is reused.
    private BinaryDecoder binaryDecoder = DecoderFactory.get().binaryDecoder(new byte[0], (BinaryDecoder) null);

    /**
     * Registers the training dataset and writes the given data to its location.
     *
     * <p>The feature list is derived from the schema of {@code dataset}; every name in {@code list}
     * is flagged as a label. After the metadata has been created, the location, version, id and
     * storage connector returned by the API are copied onto {@code trainingDataset} and the data is
     * written with {@link SaveMode#Overwrite}.
     *
     * @param trainingDataset metadata object to register; updated in place
     * @param dataset data to write
     * @param map user supplied write options, merged by {@code SparkEngine.getWriteOptions}
     * @param list names of the label features; may be {@code null} or empty
     * @return the same {@code trainingDataset} instance, updated
     * @throws FeatureStoreException if a label name does not match any feature
     * @throws IOException propagated from the API or write calls
     */
    public TrainingDataset save(TrainingDataset trainingDataset, Dataset<Row> dataset, Map<String, String> map, List<String> list) throws FeatureStoreException, IOException {
        trainingDataset.setFeatures(this.utils.parseTrainingDatasetSchema(dataset));
        if (list != null) {
            for (String label : list) {
                markAsLabel(trainingDataset, label);
            }
        }

        TrainingDataset created = this.trainingDatasetApi.createTrainingDataset(trainingDataset);
        if (trainingDataset.getVersion() == null) {
            LOGGER.info("VersionWarning: No version provided for creating training dataset `{}`, incremented version to `{}`.",
                    trainingDataset.getName(), created.getVersion());
        }

        trainingDataset.setLocation(created.getLocation());
        trainingDataset.setVersion(created.getVersion());
        trainingDataset.setId(created.getId());
        trainingDataset.setStorageConnector(created.getStorageConnector());

        SparkEngine sparkEngine = SparkEngine.getInstance();
        sparkEngine.write(trainingDataset, dataset, sparkEngine.getWriteOptions(map, trainingDataset.getDataFormat()), SaveMode.Overwrite);
        return trainingDataset;
    }

    /**
     * Flags the feature named {@code label} as a label.
     *
     * @throws FeatureStoreException if no feature has that name; the message lists the available names
     */
    private void markAsLabel(TrainingDataset trainingDataset, String label) throws FeatureStoreException {
        Optional<TrainingDatasetFeature> match = trainingDataset.getFeatures().stream()
                .filter(feature -> feature.getName().equals(label))
                .findFirst();
        if (!match.isPresent()) {
            // Collect the names explicitly: concatenating the Stream itself would only print its identity.
            String featureNames = trainingDataset.getFeatures().stream()
                    .map(TrainingDatasetFeature::getName)
                    .collect(Collectors.joining(", "));
            throw new FeatureStoreException("The specified label `" + label + "` could not be found among the features: " + featureNames + ".");
        }
        match.get().setLabel(true);
    }

    /**
     * Writes additional data to an existing training dataset.
     *
     * @param trainingDataset target training dataset
     * @param dataset data to write; its schema is checked against the dataset's features first
     * @param map user supplied write options, merged by {@code SparkEngine.getWriteOptions}
     * @param saveMode Spark save mode used for the write
     * @throws FeatureStoreException if the training dataset has transformation functions attached
     * @throws IOException propagated from the API or write calls
     */
    public void insert(TrainingDataset trainingDataset, Dataset<Row> dataset, Map<String, String> map, SaveMode saveMode) throws FeatureStoreException, IOException {
        this.utils.trainingDatasetSchemaMatch(dataset, trainingDataset.getFeatures());
        if (this.trainingDatasetApi.getTransformationFunctions(trainingDataset).size() > 0) {
            throw new FeatureStoreException("This training dataset has transformation functions attached and insert operation must be performed from a PySpark application");
        }
        SparkEngine sparkEngine = SparkEngine.getInstance();
        sparkEngine.write(trainingDataset, dataset, sparkEngine.getWriteOptions(map, trainingDataset.getDataFormat()), saveMode);
    }

    /**
     * Reads the training dataset through its storage connector.
     *
     * @param trainingDataset training dataset to read
     * @param str optional sub-path appended to the dataset location; ignored when null or empty
     * @param map user supplied read options, merged by {@code SparkEngine.getReadOptions}
     * @return the data found at the resolved location
     */
    public Dataset<Row> read(TrainingDataset trainingDataset, String str, Map<String, String> map) throws FeatureStoreException, IOException {
        Map<String, String> readOptions = SparkEngine.getInstance().getReadOptions(map, trainingDataset.getDataFormat());
        String location = Strings.isNullOrEmpty(str)
                ? trainingDataset.getLocation()
                : new Path(trainingDataset.getLocation(), str).toString();
        return trainingDataset.getStorageConnector().read(null, trainingDataset.getDataFormat().toString(), readOptions, location);
    }

    /** Attaches the tag {@code str} with value {@code obj} to the training dataset. */
    public void addTag(TrainingDataset trainingDataset, String str, Object obj) throws FeatureStoreException, IOException {
        this.tagsApi.add(trainingDataset, str, obj);
    }

    /** Returns all tags of the training dataset as returned by the tags API. */
    public Map<String, Object> getTags(TrainingDataset trainingDataset) throws FeatureStoreException, IOException {
        return this.tagsApi.get(trainingDataset);
    }

    /** Returns the value of the tag named {@code str}. */
    public Object getTag(TrainingDataset trainingDataset, String str) throws FeatureStoreException, IOException {
        return this.tagsApi.get(trainingDataset, str);
    }

    /** Deletes the tag named {@code str} from the training dataset. */
    public void deleteTag(TrainingDataset trainingDataset, String str) throws FeatureStoreException, IOException {
        this.tagsApi.deleteTag(trainingDataset, str);
    }

    /**
     * Fetches the query behind the training dataset and renders it for the given storage.
     *
     * @param storage storage the query string is rendered for
     * @param z forwarded unchanged to {@code TrainingDatasetApi.getQuery}
     *          (NOTE(review): presumably toggles inclusion of label features - confirm against the API)
     */
    public String getQuery(TrainingDataset trainingDataset, Storage storage, boolean z) throws FeatureStoreException, IOException {
        return this.trainingDatasetApi.getQuery(trainingDataset, z).getStorageQuery(storage);
    }

    /**
     * Sends the statistics configuration to the backend and copies the returned correlations and
     * histograms settings back onto {@code trainingDataset}.
     */
    public void updateStatisticsConfig(TrainingDataset trainingDataset) throws FeatureStoreException, IOException {
        TrainingDataset updated = this.trainingDatasetApi.updateMetadata(trainingDataset, "updateStatsConfig");
        trainingDataset.getStatisticsConfig().setCorrelations(updated.getStatisticsConfig().getCorrelations());
        trainingDataset.getStatisticsConfig().setHistograms(updated.getStatisticsConfig().getHistograms());
    }

    /**
     * Opens a JDBC connection to the online storage connector and prepares the serving statements.
     *
     * <p>On success the connection, the serving keys, the parameter positions per statement and the
     * prepared statements (ordered by statement index) are stored on {@code trainingDataset}. If
     * anything fails after the connection has been opened, the connection is closed again and the
     * training dataset is left untouched, so a later retry starts from a clean state.
     *
     * @param trainingDataset training dataset to initialise for serving
     * @param z when {@code true}, the numeric host in the JDBC URL is replaced with the host of the
     *          current {@link HopsworksClient}
     * @throws FeatureStoreException if the training dataset has transformation functions attached
     * @throws SQLException if the connection cannot be opened or a statement cannot be prepared
     */
    public void initPreparedStatement(TrainingDataset trainingDataset, boolean z) throws FeatureStoreException, IOException, SQLException {
        if (this.trainingDatasetApi.getTransformationFunctions(trainingDataset).size() > 0) {
            throw new FeatureStoreException("This training dataset has transformation functions attached and serving must performed from a Python application");
        }

        Map<String, String> sparkOptions = this.storageConnectorApi.getOnlineStorageConnector(trainingDataset.getFeatureStore()).sparkOptions();
        String jdbcUrl = sparkOptions.get(Constants.JDBC_URL);
        if (z) {
            jdbcUrl = jdbcUrl.replaceAll("/[0-9.]+:", "/" + HopsworksClient.getInstance().getHost() + ":");
        }

        Connection connection = DriverManager.getConnection(jdbcUrl, sparkOptions.get(Constants.JDBC_USER), sparkOptions.get(Constants.JDBC_PWD));
        Map<Integer, Map<String, Integer>> parameterIndexesByStatement = new HashMap<>();
        TreeMap<Integer, PreparedStatement> statementsByIndex = new TreeMap<>();
        HashSet<String> servingKeys = new HashSet<>();

        boolean initialised = false;
        try {
            connection.setAutoCommit(false);
            for (ServingPreparedStatement servingStatement : this.trainingDatasetApi.getServingPreparedStatement(trainingDataset)) {
                statementsByIndex.put(servingStatement.getPreparedStatementIndex(), connection.prepareStatement(servingStatement.getQueryOnline()));
                Map<String, Integer> parameterIndexes = new HashMap<>();
                servingStatement.getPreparedStatementParameters().forEach(parameter -> {
                    servingKeys.add(parameter.getName());
                    parameterIndexes.put(parameter.getName(), parameter.getIndex());
                });
                parameterIndexesByStatement.put(servingStatement.getPreparedStatementIndex(), parameterIndexes);
            }
            initialised = true;
        } finally {
            if (!initialised) {
                // Do not leak the connection (and the statements created on it) when setup fails.
                try {
                    connection.close();
                } catch (SQLException closeFailure) {
                    LOGGER.warn("Could not close the online feature store connection after a failed initialisation", closeFailure);
                }
            }
        }

        trainingDataset.setPreparedStatementConnection(connection);
        trainingDataset.setServingKeys(servingKeys);
        trainingDataset.setPreparedStatementParameters(parameterIndexesByStatement);
        trainingDataset.setPreparedStatements(statementsByIndex);
    }

    /**
     * Looks up a feature vector in the online feature store.
     *
     * <p>Prepared statements are initialised lazily on first use. Every statement is bound with the
     * matching entries of {@code map} and executed in ascending statement index order; the column
     * values of all returned rows are appended to the result in that order. Columns belonging to
     * complex features are Avro-decoded. The connection is committed after all statements succeeded.
     *
     * @param trainingDataset training dataset to serve from
     * @param map serving key name to value; its key set must equal the dataset's serving keys
     * @param z forwarded to {@link #initPreparedStatement(TrainingDataset, boolean)} on first use
     * @return the concatenated column values
     * @throws IllegalArgumentException if the keys of {@code map} do not match the serving keys
     * @throws FeatureStoreException if a statement returns no rows
     */
    public List<Object> getServingVector(TrainingDataset trainingDataset, Map<String, Object> map, boolean z) throws SQLException, FeatureStoreException, IOException {
        if (trainingDataset.getPreparedStatements() == null) {
            initPreparedStatement(trainingDataset, z);
        }
        if (!trainingDataset.getServingKeys().equals(map.keySet())) {
            throw new IllegalArgumentException("Provided primary key map doesn't correspond to serving_keys");
        }

        Map<Integer, Map<String, Integer>> parameterIndexesByStatement = trainingDataset.getPreparedStatementParameters();
        TreeMap<Integer, PreparedStatement> statements = trainingDataset.getPreparedStatements();
        Map<String, DatumReader<Object>> complexFeatureSchemas = getComplexFeatureSchemas(trainingDataset);

        // Bind every provided key to each statement that declares a parameter with that name.
        for (Map.Entry<Integer, PreparedStatement> statementEntry : statements.entrySet()) {
            Map<String, Integer> parameterIndexes = parameterIndexesByStatement.get(statementEntry.getKey());
            for (Map.Entry<String, Object> servingKey : map.entrySet()) {
                Integer position = parameterIndexes.get(servingKey.getKey());
                if (position != null) {
                    statementEntry.getValue().setObject(position, servingKey.getValue());
                }
            }
        }

        List<Object> servingVector = new ArrayList<>();
        // TreeMap.values() iterates in ascending key order, i.e. by statement index.
        for (PreparedStatement statement : statements.values()) {
            // try-with-resources closes the result set on the error paths as well.
            try (ResultSet resultSet = statement.executeQuery()) {
                if (!resultSet.isBeforeFirst()) {
                    throw new FeatureStoreException("No data was retrieved from online feature store using input " + map);
                }
                int columnCount = resultSet.getMetaData().getColumnCount();
                while (resultSet.next()) {
                    for (int column = 1; column <= columnCount; column++) {
                        if (complexFeatureSchemas.containsKey(resultSet.getMetaData().getColumnName(column))) {
                            servingVector.add(deserializeComplexFeature(complexFeatureSchemas, resultSet, column));
                        } else {
                            servingVector.add(resultSet.getObject(column));
                        }
                    }
                }
            }
        }
        trainingDataset.getPreparedStatementConnection().commit();
        return servingVector;
    }

    /**
     * Avro-decodes the bytes stored in column {@code i} of the current row using the reader
     * registered under that column's name.
     *
     * @return the decoded value, or {@code null} when the column holds SQL NULL
     */
    private Object deserializeComplexFeature(Map<String, DatumReader<Object>> map, ResultSet resultSet, int i) throws SQLException, IOException {
        byte[] encoded = resultSet.getBytes(i);
        if (encoded == null) {
            // SQL NULL: there is nothing to decode.
            return null;
        }
        DatumReader<Object> reader = map.get(resultSet.getMetaData().getColumnName(i));
        return reader.read(null, DecoderFactory.get().binaryDecoder(encoded, this.binaryDecoder));
    }

    /**
     * Builds an Avro reader for every complex feature of the training dataset, keyed by feature name.
     *
     * <p>A fresh {@link Schema.Parser} is used per invocation: a parser remembers the named types it
     * has seen and refuses to parse the same named type twice, so sharing one parser across
     * invocations would fail on repeated lookups for schemas that declare named types.
     */
    private Map<String, DatumReader<Object>> getComplexFeatureSchemas(TrainingDataset trainingDataset) throws FeatureStoreException, IOException {
        Map<String, DatumReader<Object>> readersByFeature = new HashMap<>();
        Schema.Parser schemaParser = new Schema.Parser();
        for (TrainingDatasetFeature feature : trainingDataset.getFeatures()) {
            if (feature.isComplex()) {
                Schema schema = schemaParser.parse(feature.getFeaturegroup().getFeatureAvroSchema(feature.getName()));
                readersByFeature.put(feature.getName(), new GenericDatumReader<>(schema));
            }
        }
        return readersByFeature;
    }

    /** Deletes the training dataset through the training dataset API. */
    public void delete(TrainingDataset trainingDataset) throws FeatureStoreException, IOException {
        this.trainingDatasetApi.delete(trainingDataset);
    }
}
