/*
 * Decompiled with CFR 0.152.
 */
package com.logicalclocks.hsfs.engine;

import com.google.common.base.Strings;
import com.logicalclocks.hsfs.EntityEndpointType;
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.Storage;
import com.logicalclocks.hsfs.StorageConnector;
import com.logicalclocks.hsfs.TrainingDataset;
import com.logicalclocks.hsfs.TrainingDatasetFeature;
import com.logicalclocks.hsfs.constructor.ServingPreparedStatement;
import com.logicalclocks.hsfs.engine.SparkEngine;
import com.logicalclocks.hsfs.engine.Utils;
import com.logicalclocks.hsfs.metadata.HopsworksClient;
import com.logicalclocks.hsfs.metadata.StorageConnectorApi;
import com.logicalclocks.hsfs.metadata.TagsApi;
import com.logicalclocks.hsfs.metadata.TrainingDatasetApi;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import java.util.stream.Collectors;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.io.BinaryDecoder;
import org.apache.avro.io.DatumReader;
import org.apache.avro.io.Decoder;
import org.apache.avro.io.DecoderFactory;
import org.apache.hadoop.fs.Path;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Implements training dataset operations.
 *
 * <p>It creates and writes training datasets through {@link SparkEngine} and reads them back
 * through the dataset's storage connector. It delegates tags, queries, metadata updates and
 * deletion to the Hopsworks REST API wrappers. It also builds serving vectors by querying the
 * online feature store over JDBC with prepared statements.
 */
public class TrainingDatasetEngine {
    private TrainingDatasetApi trainingDatasetApi = new TrainingDatasetApi();
    private TagsApi tagsApi = new TagsApi(EntityEndpointType.TRAINING_DATASET);
    private StorageConnectorApi storageConnectorApi = new StorageConnectorApi();
    private Utils utils = new Utils();
    // Reused by deserializeComplexFeature so that decoding a value does not allocate a new decoder.
    // NOTE(review): this mutable field presumably makes one engine instance unsafe to use for
    // concurrent serving lookups — confirm how instances are shared.
    private BinaryDecoder binaryDecoder = DecoderFactory.get().binaryDecoder(new byte[0], null);
    private static final Logger LOGGER = LoggerFactory.getLogger(TrainingDatasetEngine.class);

    /**
     * Creates the training dataset in Hopsworks and writes {@code dataset} to its location.
     *
     * @param trainingDataset  training dataset metadata; its features are derived from
     *                         {@code dataset}, and its location, version, id and storage connector
     *                         are overwritten with the values returned by the backend
     * @param dataset          the data to write (always written with {@link SaveMode#Overwrite})
     * @param userWriteOptions user-supplied write options, combined with the data format defaults
     * @param label            names of the features to flag as labels; may be null or empty
     * @return the same {@code trainingDataset} instance, updated in place
     * @throws FeatureStoreException if a label name is not among the dataset's features; other
     *                               failures are propagated from the API and Spark calls
     */
    public TrainingDataset save(TrainingDataset trainingDataset, Dataset<Row> dataset, Map<String, String> userWriteOptions, List<String> label) throws FeatureStoreException, IOException {
        trainingDataset.setFeatures(this.utils.parseTrainingDatasetSchema(dataset));
        if (label != null) {
            for (String labelName : label) {
                Optional<TrainingDatasetFeature> feature = trainingDataset.getFeatures().stream()
                    .filter(f -> f.getName().equals(labelName))
                    .findFirst();
                if (!feature.isPresent()) {
                    // Collect the names so the message lists them instead of printing a Stream object.
                    List<String> featureNames = trainingDataset.getFeatures().stream()
                        .map(TrainingDatasetFeature::getName)
                        .collect(Collectors.toList());
                    throw new FeatureStoreException("The specified label `" + labelName
                        + "` could not be found among the features: " + featureNames + ".");
                }
                feature.get().setLabel(true);
            }
        }
        TrainingDataset apiTD = this.trainingDatasetApi.createTrainingDataset(trainingDataset);
        if (trainingDataset.getVersion() == null) {
            LOGGER.info("VersionWarning: No version provided for creating training dataset `{}`, incremented version to `{}`.",
                trainingDataset.getName(), apiTD.getVersion());
        }
        // Copy the backend-assigned identity and location back onto the caller's object.
        trainingDataset.setLocation(apiTD.getLocation());
        trainingDataset.setVersion(apiTD.getVersion());
        trainingDataset.setId(apiTD.getId());
        trainingDataset.setStorageConnector(apiTD.getStorageConnector());
        Map<String, String> writeOptions = SparkEngine.getInstance().getWriteOptions(userWriteOptions, trainingDataset.getDataFormat());
        SparkEngine.getInstance().write(trainingDataset, dataset, writeOptions, SaveMode.Overwrite);
        return trainingDataset;
    }

    /**
     * Writes {@code dataset} into an existing training dataset using {@code saveMode}.
     *
     * <p>The dataset schema is first checked against the training dataset features by
     * {@code Utils.trainingDatasetSchemaMatch} (presumably throwing on mismatch — confirm).
     *
     * @throws FeatureStoreException if the training dataset has transformation functions attached
     */
    public void insert(TrainingDataset trainingDataset, Dataset<Row> dataset, Map<String, String> providedOptions, SaveMode saveMode) throws FeatureStoreException, IOException {
        this.utils.trainingDatasetSchemaMatch(dataset, trainingDataset.getFeatures());
        if (this.trainingDatasetApi.getTransformationFunctions(trainingDataset).size() > 0) {
            throw new FeatureStoreException("This training dataset has transformation functions attached and insert operation must be performed from a PySpark application");
        }
        Map<String, String> writeOptions = SparkEngine.getInstance().getWriteOptions(providedOptions, trainingDataset.getDataFormat());
        SparkEngine.getInstance().write(trainingDataset, dataset, writeOptions, saveMode);
    }

    /**
     * Reads the training dataset, or one split of it, through its storage connector.
     *
     * @param split           name of the split sub-directory to read; null or empty reads the
     *                        whole training dataset location
     * @param providedOptions user-supplied read options, combined with the data format defaults
     */
    public Dataset<Row> read(TrainingDataset trainingDataset, String split, Map<String, String> providedOptions) throws FeatureStoreException, IOException {
        Map<String, String> readOptions = SparkEngine.getInstance().getReadOptions(providedOptions, trainingDataset.getDataFormat());
        String path = trainingDataset.getLocation();
        if (!Strings.isNullOrEmpty(split)) {
            path = new Path(trainingDataset.getLocation(), split).toString();
        }
        return trainingDataset.getStorageConnector().read(null, trainingDataset.getDataFormat().toString(), readOptions, path);
    }

    /** Attaches tag {@code name} with {@code value} to the training dataset. */
    public void addTag(TrainingDataset trainingDataset, String name, Object value) throws FeatureStoreException, IOException {
        this.tagsApi.add(trainingDataset, name, value);
    }

    /** Returns all tags attached to the training dataset, keyed by tag name. */
    public Map<String, Object> getTags(TrainingDataset trainingDataset) throws FeatureStoreException, IOException {
        return this.tagsApi.get(trainingDataset);
    }

    /** Returns the value of tag {@code name} on the training dataset. */
    public Object getTag(TrainingDataset trainingDataset, String name) throws FeatureStoreException, IOException {
        return this.tagsApi.get(trainingDataset, name);
    }

    /** Removes tag {@code name} from the training dataset. */
    public void deleteTag(TrainingDataset trainingDataset, String name) throws FeatureStoreException, IOException {
        this.tagsApi.deleteTag(trainingDataset, name);
    }

    /**
     * Returns the query that generated the training dataset, in the SQL flavour for {@code storage}.
     *
     * @param withLabel whether the label columns should be included in the query
     */
    public String getQuery(TrainingDataset trainingDataset, Storage storage, boolean withLabel) throws FeatureStoreException, IOException {
        return this.trainingDatasetApi.getQuery(trainingDataset, withLabel).getStorageQuery(storage);
    }

    /**
     * Sends the training dataset's statistics configuration to the backend.
     *
     * <p>The correlations, histograms and exact-uniqueness flags returned by the backend are then
     * copied back onto the local configuration.
     */
    public void updateStatisticsConfig(TrainingDataset trainingDataset) throws FeatureStoreException, IOException {
        TrainingDataset apiTD = this.trainingDatasetApi.updateMetadata(trainingDataset, "updateStatsConfig");
        trainingDataset.getStatisticsConfig().setCorrelations(apiTD.getStatisticsConfig().getCorrelations());
        trainingDataset.getStatisticsConfig().setHistograms(apiTD.getStatisticsConfig().getHistograms());
        trainingDataset.getStatisticsConfig().setExactUniqueness(apiTD.getStatisticsConfig().getExactUniqueness());
    }

    /**
     * Opens a JDBC connection to the online feature store and prepares the serving statements.
     *
     * <p>Stores on {@code trainingDataset}:
     * <ul>
     *   <li>the connection, opened with auto-commit disabled;</li>
     *   <li>the prepared statements, keyed by prepared-statement index;</li>
     *   <li>each statement's parameter-name to parameter-index map;</li>
     *   <li>the union of all parameter names as the serving keys.</li>
     * </ul>
     *
     * @param external when true, the numeric host in the JDBC URL is replaced by the Hopsworks
     *                 client host
     * @throws FeatureStoreException  if the training dataset has transformation functions attached
     * @throws ClassNotFoundException if the MySQL JDBC driver is not on the classpath
     */
    public void initPreparedStatement(TrainingDataset trainingDataset, boolean external) throws FeatureStoreException, IOException, SQLException, ClassNotFoundException {
        // Load the MySQL driver class before asking DriverManager for a connection.
        Class.forName("com.mysql.jdbc.Driver");
        if (this.trainingDatasetApi.getTransformationFunctions(trainingDataset).size() > 0) {
            throw new FeatureStoreException("This training dataset has transformation functions attached and serving must performed from a Python application");
        }
        StorageConnector.JdbcConnector storageConnector = this.storageConnectorApi.getOnlineStorageConnector(trainingDataset.getFeatureStore());
        Map<String, String> jdbcOptions = ((StorageConnector) storageConnector).sparkOptions();
        String url = jdbcOptions.get("url");
        if (external) {
            // Swap the IP-address host for the Hopsworks host; presumably the configured address is
            // only reachable from inside the cluster.
            url = url.replaceAll("/[0-9.]+:", "/" + HopsworksClient.getInstance().getHost() + ":");
        }
        Connection jdbcConnection = DriverManager.getConnection(url, jdbcOptions.get("user"), jdbcOptions.get("password"));
        jdbcConnection.setAutoCommit(false);
        trainingDataset.setPreparedStatementConnection(jdbcConnection);
        List<ServingPreparedStatement> servingPreparedStatements = this.trainingDatasetApi.getServingPreparedStatement(trainingDataset);
        HashMap<Integer, Map<String, Integer>> preparedStatementParameters = new HashMap<>();
        // TreeMap keeps statements ordered by index, which fixes the column order of serving vectors.
        TreeMap<Integer, PreparedStatement> preparedStatements = new TreeMap<>();
        HashSet<String> servingVectorKeys = new HashSet<>();
        for (ServingPreparedStatement servingPreparedStatement : servingPreparedStatements) {
            preparedStatements.put(servingPreparedStatement.getPreparedStatementIndex(),
                jdbcConnection.prepareStatement(servingPreparedStatement.getQueryOnline()));
            Map<String, Integer> parameterIndices = new HashMap<>();
            servingPreparedStatement.getPreparedStatementParameters().forEach(preparedStatementParameter -> {
                servingVectorKeys.add(preparedStatementParameter.getName());
                parameterIndices.put(preparedStatementParameter.getName(), preparedStatementParameter.getIndex());
            });
            preparedStatementParameters.put(servingPreparedStatement.getPreparedStatementIndex(), parameterIndices);
        }
        trainingDataset.setServingKeys(servingVectorKeys);
        trainingDataset.setPreparedStatementParameters(preparedStatementParameters);
        trainingDataset.setPreparedStatements(preparedStatements);
    }

    /**
     * Looks up one feature vector from the online feature store.
     *
     * <p>Prepared statements are initialised lazily on first use. Complex (Avro-encoded) features
     * are deserialized.
     *
     * @param entry    serving key name to value; its key set must equal the serving keys
     * @param external passed to {@link #initPreparedStatement} when initialisation is needed
     * @return feature values in prepared-statement-index order, then column order within each
     *         statement; a complex feature stored as SQL NULL yields {@code null}
     * @throws IllegalArgumentException if {@code entry}'s keys differ from the serving keys
     * @throws FeatureStoreException    if any statement returns no rows
     */
    public List<Object> getServingVector(TrainingDataset trainingDataset, Map<String, Object> entry, boolean external) throws SQLException, FeatureStoreException, IOException, ClassNotFoundException {
        if (trainingDataset.getPreparedStatements() == null) {
            this.initPreparedStatement(trainingDataset, external);
        }
        if (!trainingDataset.getServingKeys().equals(entry.keySet())) {
            throw new IllegalArgumentException("Provided primary key map doesn't correspond to serving_keys");
        }
        Map<Integer, Map<String, Integer>> preparedStatementParameters = trainingDataset.getPreparedStatementParameters();
        TreeMap<Integer, PreparedStatement> preparedStatements = trainingDataset.getPreparedStatements();
        Map<String, DatumReader<Object>> complexFeatureSchemas = this.getComplexFeatureSchemas(trainingDataset);

        // Bind each key value to every statement that declares a parameter with that name.
        for (Map.Entry<Integer, PreparedStatement> statement : preparedStatements.entrySet()) {
            Map<String, Integer> parameterIndexInStatement = preparedStatementParameters.get(statement.getKey());
            for (Map.Entry<String, Object> keyValue : entry.entrySet()) {
                if (parameterIndexInStatement.containsKey(keyValue.getKey())) {
                    statement.getValue().setObject(parameterIndexInStatement.get(keyValue.getKey()), keyValue.getValue());
                }
            }
        }

        List<Object> servingVector = new ArrayList<>();
        for (PreparedStatement statement : preparedStatements.values()) {
            // try-with-resources closes the ResultSet even when we throw below.
            try (ResultSet results = statement.executeQuery()) {
                if (!results.isBeforeFirst()) {
                    throw new FeatureStoreException("No data was retrieved from online feature store using input " + entry);
                }
                ResultSetMetaData metaData = results.getMetaData();
                int columnCount = metaData.getColumnCount();
                while (results.next()) {
                    for (int index = 1; index <= columnCount; ++index) {
                        DatumReader<Object> complexReader = complexFeatureSchemas.get(metaData.getColumnName(index));
                        if (complexReader != null) {
                            servingVector.add(this.deserializeComplexFeature(complexReader, results.getBytes(index)));
                        } else {
                            servingVector.add(results.getObject(index));
                        }
                    }
                }
            }
        }
        // Auto-commit is off, so end the read transaction explicitly. Presumably this lets the next
        // lookup see fresh data rather than an old snapshot — confirm for the online store's isolation level.
        trainingDataset.getPreparedStatementConnection().commit();
        return servingVector;
    }

    /**
     * Decodes one Avro binary-encoded value with {@code datumReader}.
     * Returns {@code null} for a SQL NULL column (null {@code serialized}).
     */
    private Object deserializeComplexFeature(DatumReader<Object> datumReader, byte[] serialized) throws IOException {
        if (serialized == null) {
            return null;
        }
        BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(serialized, this.binaryDecoder);
        return datumReader.read(null, decoder);
    }

    /** Builds an Avro datum reader for every complex feature, keyed by feature name. */
    private Map<String, DatumReader<Object>> getComplexFeatureSchemas(TrainingDataset trainingDataset) throws FeatureStoreException, IOException {
        Map<String, DatumReader<Object>> featureSchemaMap = new HashMap<>();
        for (TrainingDatasetFeature feature : trainingDataset.getFeatures()) {
            if (!feature.isComplex()) {
                continue;
            }
            // Use a fresh parser each time: Schema.Parser remembers every named type it has parsed
            // and rejects redefinitions. A shared parser would fail on the second serving lookup.
            Schema schema = new Schema.Parser().parse(feature.getFeaturegroup().getFeatureAvroSchema(feature.getName()));
            featureSchemaMap.put(feature.getName(), new GenericDatumReader<Object>(schema));
        }
        return featureSchemaMap;
    }

    /** Deletes the training dataset through the Hopsworks API. */
    public void delete(TrainingDataset trainingDataset) throws FeatureStoreException, IOException {
        this.trainingDatasetApi.delete(trainingDataset);
    }
}

