package io.hops.hopsworks.common.featurestore.embedding;

import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import io.hops.hopsworks.common.dao.kafka.KafkaConst;
import io.hops.hopsworks.common.featurestore.featuregroup.EmbeddingDTO;
import io.hops.hopsworks.common.featurestore.featuregroup.FeaturegroupController;
import io.hops.hopsworks.common.models.ModelFacade;
import io.hops.hopsworks.common.models.version.ModelVersionFacade;
import io.hops.hopsworks.common.util.Settings;
import io.hops.hopsworks.exceptions.FeaturestoreException;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Embedding;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup;
import io.hops.hopsworks.persistence.entity.models.version.ModelVersion;
import io.hops.hopsworks.persistence.entity.project.Project;
import io.hops.hopsworks.restutils.RESTCodes;
import io.hops.hopsworks.vectordb.Index;
import io.hops.hopsworks.vectordb.VectorDatabaseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.Set;
import java.util.logging.Level;
import java.util.stream.Collectors;
import javax.ejb.EJB;
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;

/**
 * Manages the vector-database (OpenSearch) side of feature-group embeddings: building the
 * {@link Embedding} entity from an {@link EmbeddingDTO}, creating the backing k-NN index, and
 * removing an embedding's index or documents when a feature group or a project is dropped.
 */
@TransactionAttribute(TransactionAttributeType.NEVER)
@Stateless
public class EmbeddingController {
    /** Source of randomness used to pick one of the project's default vector db indices. */
    private static final Random RANDOM = new Random();

    @EJB
    private Settings settings;

    @EJB
    private VectorDatabaseClient vectorDatabaseClient;

    @EJB
    private FeaturegroupController featuregroupController;

    @EJB
    private ModelVersionFacade modelVersionFacade;

    @EJB
    private ModelFacade modelFacade;

    /**
     * Creates the vector db index backing the embedding of the given feature group. When the index
     * is one of the project's default indices, the feature group's embedding fields are additionally
     * added to the index mapping.
     *
     * @param project project owning the feature group
     * @param featuregroup feature group whose embedding index should be created
     * @throws FeaturestoreException if the vector database request fails
     */
    public void createVectorDbIndex(Project project, Featuregroup featuregroup) throws FeaturestoreException {
        Embedding embedding = featuregroup.getEmbedding();
        Index index = new Index(embedding.getVectorDbIndexName());
        try {
            // NOTE(review): the meaning of the trailing `true` flag is defined by the vector db client — confirm.
            this.vectorDatabaseClient.getClient().createIndex(index,
                createIndex(embedding.getColPrefix(), embedding.getEmbeddingFeatures()), true);
            if (isDefaultVectorDbIndex(project, index.getName())) {
                // Default indices are shared by several feature groups, so this group's (prefixed)
                // fields are pushed into the mapping explicitly.
                this.vectorDatabaseClient.getClient().addFields(index,
                    createMapping(embedding.getColPrefix(), embedding.getEmbeddingFeatures()));
            }
        } catch (VectorDatabaseException e) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_CREATE_FEATUREGROUP, Level.FINE,
                "Cannot create opensearch vectordb index: " + index.getName(), e.getMessage(), e);
        }
    }

    /**
     * Looks up a model version by registry, model name and version number.
     *
     * @param modelRegistryId id passed to {@code ModelFacade.findByProjectIdAndName}
     * @param modelName name of the model
     * @param modelVersion version of the model
     * @return the matching model version as returned by {@code ModelVersionFacade.findByProjectAndMlId}
     */
    private ModelVersion getModel(Integer modelRegistryId, String modelName, Integer modelVersion) {
        // NOTE(review): findByProjectIdAndName's result is dereferenced without a null check — confirm it cannot be null.
        return this.modelVersionFacade.findByProjectAndMlId(
            this.modelFacade.findByProjectIdAndName(modelRegistryId, modelName).getId(), modelVersion);
    }

    /**
     * Returns the vector db field name of a feature, i.e. the feature name prefixed with the
     * embedding's column prefix (a null prefix means no prefix).
     *
     * @param embedding embedding whose column prefix is applied
     * @param str feature name
     * @return the (possibly prefixed) field name
     */
    public String getFieldName(Embedding embedding, String str) {
        return embedding.getColPrefix() == null ? str : embedding.getColPrefix() + str;
    }

    /**
     * Builds (but does not persist) the {@link Embedding} entity of a feature group.
     *
     * <p>Without an index name one of the project's default indices is chosen and fields are
     * prefixed with the feature group id. A user-supplied name is scoped to the project by
     * prepending the project prefix unless it already carries it.
     *
     * @param project project owning the feature group
     * @param embeddingDTO requested embedding configuration
     * @param featuregroup feature group the embedding belongs to
     * @return the embedding entity with its embedding features attached
     * @throws FeaturestoreException if no default index is configured or can be derived
     */
    public Embedding getEmbedding(Project project, EmbeddingDTO embeddingDTO, Featuregroup featuregroup)
            throws FeaturestoreException {
        Embedding embedding = new Embedding();
        embedding.setFeaturegroup(featuregroup);
        String requestedIndexName = embeddingDTO.getIndexName();
        if (Strings.isNullOrEmpty(requestedIndexName)) {
            embedding.setVectorDbIndexName(getDefaultVectorDbIndex(project));
            embedding.setColPrefix(getVectorDbColPrefix(featuregroup));
        } else {
            String vectorDbIndexPrefix = getVectorDbIndexPrefix(project);
            if (requestedIndexName.startsWith(vectorDbIndexPrefix)) {
                // Already scoped to this project: use it as-is (the index name used to be left unset here).
                embedding.setVectorDbIndexName(requestedIndexName);
            } else {
                // Scope the index to the project; fields of a dedicated index are not prefixed.
                embedding.setVectorDbIndexName(vectorDbIndexPrefix + "_" + requestedIndexName);
                embedding.setColPrefix("");
            }
            // NOTE(review): this checks the name as requested, while createVectorDbIndex/dropEmbedding check
            // the stored (possibly prefixed) name; they differ if a configured default lacks the project prefix.
            if (isDefaultVectorDbIndex(project, requestedIndexName)) {
                embedding.setColPrefix(getVectorDbColPrefix(featuregroup));
            }
        }
        embedding.setEmbeddingFeatures(embeddingDTO.getFeatures().stream()
            .map(featureDTO -> featureDTO.getModel() == null
                ? new EmbeddingFeature(embedding, featureDTO.getName(), featureDTO.getDimension(),
                    featureDTO.getSimilarityFunctionType())
                : new EmbeddingFeature(embedding, featureDTO.getName(), featureDTO.getDimension(),
                    featureDTO.getSimilarityFunctionType(),
                    getModel(featureDTO.getModel().getModelRegistryId(), featureDTO.getModel().getModelName(),
                        featureDTO.getModel().getModelVersion())))
            .collect(Collectors.toList()));
        return embedding;
    }

    /**
     * Deletes every vector db index whose name starts with the project's index prefix.
     *
     * @param project project whose indices are deleted
     * @throws FeaturestoreException if listing or deleting an index fails
     */
    public void dropEmbeddingForProject(Project project) throws FeaturestoreException {
        String vectorDbIndexPrefix = getVectorDbIndexPrefix(project);
        try {
            // Snapshot the project's indices first, then delete them one by one.
            Set<Index> projectIndices = this.vectorDatabaseClient.getClient().getAllIndices().stream()
                .filter(index -> index.getName().startsWith(vectorDbIndexPrefix))
                .collect(Collectors.toSet());
            for (Index index : projectIndices) {
                this.vectorDatabaseClient.getClient().deleteIndex(index);
            }
        } catch (VectorDatabaseException e) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_DELETE_VECTOR_DB_INDEX,
                Level.FINE, "Cannot delete index from vectordb for project: " + project.getName(),
                e.getMessage(), e);
        }
    }

    /**
     * Removes the embedding data of a feature group. For a default index, or an index that also holds
     * fields of other feature groups, only this group's documents are deleted; otherwise the whole
     * index is dropped.
     *
     * @param project project owning the feature group
     * @param featuregroup feature group whose embedding is dropped
     * @throws FeaturestoreException if the vector database request fails
     */
    public void dropEmbedding(Project project, Featuregroup featuregroup) throws FeaturestoreException {
        Embedding embedding = featuregroup.getEmbedding();
        Index index = new Index(embedding.getVectorDbIndexName());
        try {
            if (isDefaultVectorDbIndex(project, embedding.getVectorDbIndexName())
                    || isPreviousDefaultVectorDbIndex(embedding)) {
                removeDocuments(featuregroup);
            } else {
                this.vectorDatabaseClient.getClient().deleteIndex(index);
            }
        } catch (VectorDatabaseException e) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_DELETE_FEATUREGROUP, Level.FINE,
                "Cannot delete documents from vectordb for feature group: " + featuregroup.getName(),
                e.getMessage(), e);
        }
    }

    /**
     * Tells whether the embedding's index contains fields that do not carry this embedding's column
     * prefix, i.e. it is shared with other feature groups. Always false when there is no prefix.
     */
    private boolean isPreviousDefaultVectorDbIndex(Embedding embedding)
            throws FeaturestoreException, VectorDatabaseException {
        String colPrefix = embedding.getColPrefix();
        if (Strings.isNullOrEmpty(colPrefix)) {
            return false;
        }
        return !this.vectorDatabaseClient.getClient().getSchema(new Index(embedding.getVectorDbIndexName()))
            .stream()
            .allMatch(field -> field.getName().startsWith(colPrefix));
    }

    /**
     * Deletes the documents of a feature group from its index, selecting them with a {@code <field>:*}
     * query on the (prefixed) field of the group's first embedding feature.
     */
    private void removeDocuments(Featuregroup featuregroup) throws FeaturestoreException, VectorDatabaseException {
        Embedding embedding = featuregroup.getEmbedding();
        // NOTE(review): Optional.get() throws NoSuchElementException if the embedding has no features.
        EmbeddingFeature firstFeature =
            (EmbeddingFeature) embedding.getEmbeddingFeatures().stream().findFirst().get();
        String fieldName = getFieldName(firstFeature.getEmbedding(), firstFeature.getName());
        this.vectorDatabaseClient.getClient().deleteByQuery(new Index(embedding.getVectorDbIndexName()),
            String.format("%s:*", fieldName));
    }

    /**
     * Builds the JSON mapping declaring one {@code knn_vector} property per embedding feature.
     *
     * @param colPrefix prefix prepended to every feature name; null is treated as no prefix
     * @param features embedding features to map
     * @return the mapping JSON
     */
    protected String createMapping(String colPrefix, Collection<EmbeddingFeature> features) {
        // Treat null like getFieldName does; plain concatenation used to yield "null<feature>" field names.
        String prefix = colPrefix == null ? "" : colPrefix;
        String properties = features.stream()
            .map(feature -> String.format(
                "        \"%s\": {\n          \"type\": \"knn_vector\",\n          \"dimension\": %d\n        }",
                prefix + feature.getName(), feature.getDimension()))
            .collect(Collectors.joining(",\n"));
        return String.format("{\n    \"properties\": {\n%s\n    }\n  }", properties);
    }

    /**
     * Builds the JSON body creating a k-NN enabled index with the mapping from {@link #createMapping}.
     *
     * @param colPrefix prefix prepended to every feature name; null is treated as no prefix
     * @param features embedding features to map
     * @return the index creation JSON
     */
    protected String createIndex(String colPrefix, Collection<EmbeddingFeature> features) {
        return String.format("{\n  \"settings\": {\n    \"index\": {\n      \"knn\": \"true\",\n"
            + "      \"knn.algo_param.ef_search\": 512\n    }\n  },\n  \"mappings\": %s\n}",
            createMapping(colPrefix, features));
    }

    /** Picks one of the project's default vector db indices uniformly at random. */
    private String getDefaultVectorDbIndex(Project project) throws FeaturestoreException {
        // Direct random indexing replaces a sort keyed by a fresh random int per comparison, which
        // violated the Comparator contract and did not pick uniformly.
        ArrayList<String> candidates = new ArrayList<>(getAllDefaultVectorDbIndex(project));
        return candidates.get(RANDOM.nextInt(candidates.size()));
    }

    /** Tells whether {@code str} is one of the project's default vector db index names. */
    private boolean isDefaultVectorDbIndex(Project project, String str) throws FeaturestoreException {
        return getAllDefaultVectorDbIndex(project).contains(str);
    }

    /**
     * Returns the default vector db index names: the comma-separated names configured in settings
     * (trimmed, blanks ignored) or, when none are configured, the configured number of generated
     * {@code <projectPrefix>_default_project_embedding_<i>} names.
     *
     * @throws FeaturestoreException if this yields no index name
     */
    private Set<String> getAllDefaultVectorDbIndex(Project project) throws FeaturestoreException {
        String configuredNames = this.settings.getOpensearchDefaultEmbeddingIndexName();
        Set<String> indexNames;
        if (Strings.isNullOrEmpty(configuredNames)) {
            indexNames = new HashSet<>();
            String vectorDbIndexPrefix = getVectorDbIndexPrefix(project);
            int numDefaultIndices = this.settings.getOpensearchNumDefaultEmbeddingIndex().intValue();
            for (int i = 0; i < numDefaultIndices; i++) {
                indexNames.add(vectorDbIndexPrefix + "_default_project_embedding_" + i);
            }
        } else {
            indexNames = Arrays.stream(configuredNames.split(","))
                .map(String::trim)
                .filter(name -> !name.isEmpty())
                .collect(Collectors.toSet());
        }
        if (indexNames.isEmpty()) {
            throw new FeaturestoreException(
                RESTCodes.FeaturestoreErrorCode.OPENSEARCH_DEFAULT_EMBEDDING_INDEX_SUFFIX_NOT_DEFINED, Level.FINE,
                "Default vector db index is not defined.");
        }
        return indexNames;
    }

    /** Returns the prefix every vector db index of the project starts with: {@code <projectId>__embedding}. */
    private String getVectorDbIndexPrefix(Project project) {
        return project.getId() + "__embedding";
    }

    /** Returns the field prefix used for a feature group in a shared index: {@code <featuregroupId>_}. */
    private String getVectorDbColPrefix(Featuregroup featuregroup) {
        return featuregroup.getId() + "_";
    }
}
