/*
 * Decompiled with CFR 0.152.
 */
package io.hops.hopsworks.common.featurestore.embedding;

import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import io.hops.hopsworks.common.featurestore.embedding.VectorDatabaseClient;
import io.hops.hopsworks.common.featurestore.feature.FeatureGroupFeatureDTO;
import io.hops.hopsworks.common.featurestore.featuregroup.EmbeddingDTO;
import io.hops.hopsworks.common.featurestore.featuregroup.FeaturegroupController;
import io.hops.hopsworks.common.models.ModelFacade;
import io.hops.hopsworks.common.models.version.ModelVersionFacade;
import io.hops.hopsworks.common.util.Settings;
import io.hops.hopsworks.exceptions.FeaturestoreException;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Embedding;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup;
import io.hops.hopsworks.persistence.entity.models.Model;
import io.hops.hopsworks.persistence.entity.models.version.ModelVersion;
import io.hops.hopsworks.persistence.entity.project.Project;
import io.hops.hopsworks.restutils.RESTCodes;
import io.hops.hopsworks.vectordb.Field;
import io.hops.hopsworks.vectordb.Index;
import io.hops.hopsworks.vectordb.OpensearchVectorDatabase;
import io.hops.hopsworks.vectordb.VectorDatabaseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.logging.Level;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.ejb.EJB;
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;

/**
 * Manages the OpenSearch vector-database indices that back feature-group embeddings.
 *
 * <p>Index names are namespaced per project with the prefix {@code <projectId>__embedding}
 * (see {@link #getVectorDbIndexPrefix}). A project also has a set of "default" indices
 * (see {@link #getAllDefaultVectorDbIndex}). When a feature group uses one of these, its fields
 * are written with a per-feature-group column prefix ({@code <featureGroupId>_}).
 */
@Stateless
@TransactionAttribute(value = TransactionAttributeType.NEVER)
public class EmbeddingController {
    private static final Random RANDOM = new Random();
    private static final String EMBEDDING_INDEX_IDENTIFIER = "__embedding";
    // Compiled once instead of on every isEmbeddingIndex call (String.matches recompiles each time).
    private static final Pattern EMBEDDING_INDEX_PATTERN = Pattern.compile("^\\d+__embedding.*");

    @EJB
    private Settings settings;
    @EJB
    private VectorDatabaseClient vectorDatabaseClient;
    @EJB
    private FeaturegroupController featuregroupController;
    @EJB
    private ModelVersionFacade modelVersionFacade;
    @EJB
    private ModelFacade modelFacade;

    public EmbeddingController() {
    }

    /**
     * Package-private constructor that wires collaborators directly, without the EJB container
     * (presumably for unit tests).
     */
    EmbeddingController(VectorDatabaseClient vectorDatabaseClient, Settings settings) {
        this.vectorDatabaseClient = vectorDatabaseClient;
        this.settings = settings;
    }

    /**
     * Creates the vector-db index referenced by the embedding of {@code featureGroup}.
     *
     * <p>If the index is one of the project's default indices, it is created with the flag
     * {@code true} and the feature group's fields are then added to its mapping. Otherwise it is
     * created with the flag {@code false}.
     * NOTE(review): the flag presumably means "tolerate an already existing index" — confirm
     * against the vectordb client.
     *
     * @param project      owning project
     * @param featureGroup feature group whose embedding describes the index
     * @param features     all features of the feature group, used to build the field mapping
     * @throws FeaturestoreException if the vector database rejects the request
     */
    public void createVectorDbIndex(Project project, Featuregroup featureGroup, List<FeatureGroupFeatureDTO> features)
            throws FeaturestoreException {
        Embedding embedding = featureGroup.getEmbedding();
        Index index = new Index(embedding.getVectorDbIndexName());
        String colPrefix = embedding.getColPrefix();
        Collection<EmbeddingFeature> embeddingFeatures = embedding.getEmbeddingFeatures();
        try {
            boolean defaultIndex = isDefaultVectorDbIndex(project, index.getName());
            String indexDefinition = createIndex(colPrefix, embeddingFeatures, features);
            vectorDatabaseClient.getClient().createIndex(index, indexDefinition, defaultIndex);
            if (defaultIndex) {
                vectorDatabaseClient.getClient().addFields(index, createMapping(colPrefix, embeddingFeatures, features));
            }
        } catch (VectorDatabaseException e) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_GET_VECTOR_DB_INDEX, Level.FINE,
                String.format("Cannot create opensearch vectordb index: %s. Reason: %s", index.getName(), e.getMessage()),
                e.getMessage(), e);
        }
    }

    /**
     * Checks that {@code numFeatures} more fields still fit into the mapping of the given index.
     *
     * <p>The limit is {@code Settings#getOpensearchDefaultIndexMappingLimit()}. If the index
     * already exists, the size of its current schema is subtracted first.
     *
     * @throws FeaturestoreException if the limit would be exceeded or the schema cannot be read
     */
    public void validateWithinMappingLimit(Project project, Index index, Integer numFeatures)
            throws FeaturestoreException {
        String indexName = getProjectIndexName(project, index.getName());
        try {
            boolean exists = indexExist(indexName);
            int remainingMappingSize = settings.getOpensearchDefaultIndexMappingLimit();
            if (exists) {
                remainingMappingSize -= vectorDatabaseClient.getClient().getSchema(new Index(indexName)).stream()
                    .mapToInt(field -> countMappingSizeIncludingSubFields(field.getType()))
                    .sum();
            }
            // NOTE(review): a null numFeatures unboxes to a NullPointerException here.
            if (numFeatures > remainingMappingSize) {
                throw new FeaturestoreException(
                    RESTCodes.FeaturestoreErrorCode.VECTOR_DATABASE_INDEX_MAPPING_LIMIT_EXCEEDED, Level.FINE,
                    String.format("Number of features exceeds the limit of the index '%s'. Maximum number of features "
                        + "can be added/created is %d. Reduce the number of features or use a different embedding index.",
                        index.getName(), remainingMappingSize));
            }
        } catch (VectorDatabaseException e) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_GET_VECTOR_DB_INDEX, Level.FINE,
                "Cannot get opensearch vectordb index: " + indexName, e.getMessage(), e);
        }
    }

    /**
     * Returns how many mapping slots a field occupies: 1, or 2 when its type description is a map
     * containing a {@code "fields"} key (presumably OpenSearch multi-fields).
     */
    private int countMappingSizeIncludingSubFields(Object type) {
        boolean hasSubFields = type instanceof Map && ((Map<?, ?>) type).containsKey("fields");
        return hasSubFields ? 2 : 1;
    }

    /**
     * Returns whether an index with exactly this name exists in the vector database.
     *
     * @throws FeaturestoreException if the lookup fails
     */
    public boolean indexExist(String name) throws FeaturestoreException {
        try {
            return vectorDatabaseClient.getClient().getIndex(name).isPresent();
        } catch (VectorDatabaseException e) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_GET_VECTOR_DB_INDEX, Level.FINE,
                "Cannot get opensearch vectordb index: " + name, e.getMessage(), e);
        }
    }

    /**
     * Rejects a user-supplied index name if it already exists and is not a default index.
     *
     * <p>Null or empty names are accepted; a default index will be used for them.
     *
     * @throws FeaturestoreException if the (project-prefixed) index already exists
     */
    public void verifyIndexName(Project project, String name) throws FeaturestoreException {
        if (Strings.isNullOrEmpty(name)) {
            return;
        }
        String projectIndexName = getProjectIndexName(project, name);
        if (indexExist(projectIndexName) && !isDefaultVectorDbIndex(project, projectIndexName)) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.EMBEDDING_INDEX_EXISTED, Level.FINE,
                String.format("Provided embedding index `%s` already exists in the vector database.", projectIndexName));
        }
    }

    /**
     * Maps a user-facing index name to the project-namespaced name.
     *
     * <p>An empty name resolves to a randomly chosen default index. A name that lacks the project
     * prefix gets {@code <prefix>_} prepended. Any other name is returned unchanged.
     */
    String getProjectIndexName(Project project, String name) throws FeaturestoreException {
        if (Strings.isNullOrEmpty(name)) {
            return getDefaultVectorDbIndex(project);
        }
        String vectorDbIndexPrefix = getVectorDbIndexPrefix(project);
        return name.startsWith(vectorDbIndexPrefix) ? name : vectorDbIndexPrefix + "_" + name;
    }

    /**
     * Looks up a model version by project id, model name and version.
     * NOTE(review): throws NullPointerException if no model with that name exists in the project.
     */
    private ModelVersion getModel(Integer projectId, String modelName, Integer modelVersion) {
        Model model = modelFacade.findByProjectIdAndName(projectId, modelName);
        return modelVersionFacade.findByProjectAndMlId(model.getId(), modelVersion);
    }

    /**
     * Returns the vector-db field name of {@code featureName}: the embedding's column prefix plus
     * the name, or the bare name when the prefix is null.
     */
    public String getFieldName(Embedding embedding, String featureName) {
        return embedding.getColPrefix() == null ? featureName : embedding.getColPrefix() + featureName;
    }

    /**
     * Builds (without persisting) the {@link Embedding} entity for a feature group from its DTO.
     *
     * <p>The column prefix is chosen as follows:
     * <ul>
     *   <li>no index name: {@code <featureGroupId>_};</li>
     *   <li>a name without the project prefix: {@code ""};</li>
     *   <li>a name that is a default index: {@code <featureGroupId>_};</li>
     *   <li>otherwise: left unset.</li>
     * </ul>
     */
    public Embedding getEmbedding(Project project, EmbeddingDTO embeddingDTO, Featuregroup featuregroup)
            throws FeaturestoreException {
        Embedding embedding = new Embedding();
        embedding.setFeaturegroup(featuregroup);
        embedding.setVectorDbIndexName(getProjectIndexName(project, embeddingDTO.getIndexName()));

        String requestedIndexName = embeddingDTO.getIndexName();
        if (Strings.isNullOrEmpty(requestedIndexName)) {
            embedding.setColPrefix(getVectorDbColPrefix(featuregroup));
        } else {
            if (!requestedIndexName.startsWith(getVectorDbIndexPrefix(project))) {
                embedding.setColPrefix("");
            }
            // NOTE(review): this checks the raw requested name, not the project-prefixed one stored
            // above — confirm that is intended.
            if (isDefaultVectorDbIndex(project, requestedIndexName)) {
                embedding.setColPrefix(getVectorDbColPrefix(featuregroup));
            }
        }

        embedding.setEmbeddingFeatures(embeddingDTO.getFeatures().stream().map(mapping -> {
            if (mapping.getModel() != null) {
                return new EmbeddingFeature(embedding, mapping.getName(), mapping.getDimension(),
                    mapping.getSimilarityFunctionType(), getModel(mapping.getModel().getModelRegistryId(),
                    mapping.getModel().getModelName(), mapping.getModel().getModelVersion()));
            }
            return new EmbeddingFeature(embedding, mapping.getName(), mapping.getDimension(),
                mapping.getSimilarityFunctionType());
        }).collect(Collectors.toList()));
        return embedding;
    }

    /**
     * Deletes every vector-db index whose name starts with the project's index prefix.
     *
     * @throws FeaturestoreException if listing or deleting an index fails
     */
    public void dropEmbeddingForProject(Project project) throws FeaturestoreException {
        String projectPrefix = getVectorDbIndexPrefix(project);
        try {
            Set<Index> projectIndices = vectorDatabaseClient.getClient().getAllIndices().stream()
                .filter(index -> index.getName().startsWith(projectPrefix))
                .collect(Collectors.toSet());
            for (Index index : projectIndices) {
                vectorDatabaseClient.getClient().deleteIndex(index);
            }
        } catch (VectorDatabaseException e) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_DELETE_VECTOR_DB_INDEX, Level.FINE,
                "Cannot delete index from vectordb for project: " + project.getName(), e.getMessage(), e);
        }
    }

    /** Returns whether {@code indexName} has the form {@code <digits>__embedding...}. */
    public Boolean isEmbeddingIndex(String indexName) {
        return EMBEDDING_INDEX_PATTERN.matcher(indexName).matches();
    }

    /**
     * Extracts the project id from an embedding index name ({@code <projectId>__embedding...}).
     *
     * @throws NumberFormatException if the part before {@code __embedding} is not an integer
     */
    public Integer getProjectId(String indexName) {
        return Integer.valueOf(indexName.split(EMBEDDING_INDEX_IDENTIFIER)[0]);
    }

    /**
     * Removes a feature group's embedding data from the vector database.
     *
     * <p>For a default index, or an index that also holds fields without this feature group's
     * prefix, only the feature group's documents are deleted. Otherwise the whole index is dropped.
     *
     * @throws FeaturestoreException if the vector database operation fails
     */
    public void dropEmbedding(Project project, Featuregroup featureGroup) throws FeaturestoreException {
        Embedding embedding = featureGroup.getEmbedding();
        Index index = new Index(embedding.getVectorDbIndexName());
        try {
            // Short-circuit keeps the original order: the schema is only read for non-default indices.
            if (isDefaultVectorDbIndex(project, embedding.getVectorDbIndexName())
                    || isPreviousDefaultVectorDbIndex(embedding)) {
                removeDocuments(featureGroup);
            } else {
                vectorDatabaseClient.getClient().deleteIndex(index);
            }
        } catch (VectorDatabaseException e) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_DELETE_FEATUREGROUP, Level.FINE,
                "Cannot delete index from vectordb for feature group: " + featureGroup.getName(), e.getMessage(), e);
        }
    }

    /**
     * Returns true when the embedding has a non-empty column prefix and its index contains at
     * least one field that does not start with that prefix. Such an index is presumably shared
     * with other feature groups.
     */
    private boolean isPreviousDefaultVectorDbIndex(Embedding embedding)
            throws FeaturestoreException, VectorDatabaseException {
        String colPrefix = embedding.getColPrefix();
        if (Strings.isNullOrEmpty(colPrefix)) {
            return false;
        }
        return !vectorDatabaseClient.getClient().getSchema(new Index(embedding.getVectorDbIndexName())).stream()
            .allMatch(field -> field.getName().startsWith(colPrefix));
    }

    /**
     * Deletes the feature group's documents from its index.
     *
     * <p>Documents are matched with a {@code <field>:*} query on the first embedding field that is
     * present in the index schema. Nothing is deleted if no such field exists.
     */
    private void removeDocuments(Featuregroup featureGroup) throws FeaturestoreException, VectorDatabaseException {
        Index index = new Index(featureGroup.getEmbedding().getVectorDbIndexName());
        Set<String> fields = vectorDatabaseClient.getClient().getSchema(index).stream()
            .map(Field::getName)
            .collect(Collectors.toSet());
        Optional<String> embeddingFieldName = featureGroup.getEmbedding().getEmbeddingFeatures().stream()
            .map(feature -> getFieldName(feature.getEmbedding(), feature.getName()))
            .filter(fields::contains)
            .findFirst();
        if (embeddingFieldName.isPresent()) {
            vectorDatabaseClient.getClient().deleteByQuery(index, String.format("%s:*", embeddingFieldName.get()));
        }
    }

    /**
     * Builds the JSON {@code properties} mapping for an index.
     *
     * <p>Each embedding feature becomes a {@code knn_vector} field (HNSW via nmslib). Every other
     * feature is mapped to the OpenSearch type returned by
     * {@code OpensearchVectorDatabase.getDataType}; features for which it returns null are skipped.
     * All field names get {@code prefix}; a null prefix is treated as empty, matching
     * {@link #getFieldName}.
     */
    protected String createMapping(String prefix, Collection<EmbeddingFeature> embeddingFeatures,
            List<FeatureGroupFeatureDTO> features) {
        // A null colPrefix would otherwise produce field names like "nullmyfeature".
        String fieldPrefix = prefix == null ? "" : prefix;
        Set<String> embeddingFeatureNames = embeddingFeatures.stream()
            .map(EmbeddingFeature::getName)
            .collect(Collectors.toSet());
        String mappingString = "{\n    \"properties\": {\n%s\n    }\n  }";
        String embeddingFieldString = "        \"%s\": {\n          \"type\": \"knn_vector\",\n          \"dimension\": %d,\n          \"method\": {\n            \"name\": \"hnsw\",\n            \"space_type\": \"%s\",\n            \"engine\": \"nmslib\"\n            }\n        }";
        String fieldString = "        \"%s\": {\n          \"type\": \"%s\"\n        }";

        List<String> fieldMapping = Lists.newArrayList();
        for (EmbeddingFeature embeddingFeature : embeddingFeatures) {
            fieldMapping.add(String.format(embeddingFieldString, fieldPrefix + embeddingFeature.getName(),
                embeddingFeature.getDimension(), embeddingFeature.getSimilarityFunctionType().getOpensearchFunction()));
        }
        for (FeatureGroupFeatureDTO feature : features) {
            if (embeddingFeatureNames.contains(feature.getName())) {
                continue; // already mapped as a knn_vector above
            }
            String type = OpensearchVectorDatabase.getDataType((String) feature.getType());
            if (type != null) {
                fieldMapping.add(String.format(fieldString, fieldPrefix + feature.getName(), type));
            }
        }
        return String.format(mappingString, String.join(",\n", fieldMapping));
    }

    /**
     * Builds the full JSON index definition: k-NN enabled settings ({@code ef_search} 512) plus the
     * mapping from {@link #createMapping}.
     */
    protected String createIndex(String prefix, Collection<EmbeddingFeature> embeddingFeatures,
            List<FeatureGroupFeatureDTO> features) {
        String jsonString = "{\n  \"settings\": {\n    \"index\": {\n      \"knn\": \"true\",\n      \"knn.algo_param.ef_search\": 512\n    }\n  },\n  \"mappings\": %s\n}";
        return String.format(jsonString, createMapping(prefix, embeddingFeatures, features));
    }

    /**
     * Returns one of the project's default indices, chosen uniformly at random.
     *
     * @throws FeaturestoreException if no default index is configured
     */
    String getDefaultVectorDbIndex(Project project) throws FeaturestoreException {
        // Pick a random element directly. Sorting with a comparator that returns random values
        // violates the Comparator contract (TimSort may throw IllegalArgumentException).
        List<String> candidates = new ArrayList<>(getAllDefaultVectorDbIndex(project));
        return candidates.get(RANDOM.nextInt(candidates.size()));
    }

    /** Returns whether {@code index} is one of the project's default indices. */
    boolean isDefaultVectorDbIndex(Project project, String index) throws FeaturestoreException {
        return getAllDefaultVectorDbIndex(project).contains(index);
    }

    /**
     * Returns the project's default index names.
     *
     * <p>If the comma-separated setting {@code getOpensearchDefaultEmbeddingIndexName()} is set,
     * those names are used (trimmed, empty entries ignored). Otherwise
     * {@code getOpensearchNumDefaultEmbeddingIndex()} names of the form
     * {@code <prefix>_default_project_embedding_<i>} are generated.
     *
     * @throws FeaturestoreException if the resulting set is empty
     */
    private Set<String> getAllDefaultVectorDbIndex(Project project) throws FeaturestoreException {
        Set<String> indices;
        String configuredIndices = settings.getOpensearchDefaultEmbeddingIndexName();
        if (!Strings.isNullOrEmpty(configuredIndices)) {
            // OpenSearch index names cannot contain whitespace, so trimming only drops formatting noise.
            indices = Arrays.stream(configuredIndices.split(","))
                .map(String::trim)
                .filter(name -> !name.isEmpty())
                .collect(Collectors.toSet());
        } else {
            indices = Sets.newHashSet();
            String indexPrefix = getVectorDbIndexPrefix(project);
            for (int i = 0; i < settings.getOpensearchNumDefaultEmbeddingIndex(); i++) {
                indices.add(indexPrefix + "_default_project_embedding_" + i);
            }
        }
        if (indices.isEmpty()) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.OPENSEARCH_DEFAULT_EMBEDDING_INDEX_SUFFIX_NOT_DEFINED,
                Level.FINE, "Default vector db index is not defined.");
        }
        return indices;
    }

    /** Returns the project's index-name prefix, {@code <projectId>__embedding}. */
    String getVectorDbIndexPrefix(Project project) {
        return project.getId() + EMBEDDING_INDEX_IDENTIFIER;
    }

    /** Returns the per-feature-group column prefix, {@code <featureGroupId>_}. */
    private String getVectorDbColPrefix(Featuregroup featuregroup) {
        return featuregroup.getId() + "_";
    }
}

