package io.hops.hopsworks.common.featurestore.query;

import io.hops.hopsworks.common.featurestore.FeaturestoreFacade;
import io.hops.hopsworks.common.featurestore.feature.FeatureGroupFeatureDTO;
import io.hops.hopsworks.common.featurestore.featuregroup.FeaturegroupController;
import io.hops.hopsworks.common.featurestore.featuregroup.FeaturegroupDTO;
import io.hops.hopsworks.common.featurestore.featuregroup.FeaturegroupFacade;
import io.hops.hopsworks.common.featurestore.online.OnlineFeaturestoreController;
import io.hops.hopsworks.exceptions.FeaturestoreException;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup;
import io.hops.hopsworks.restutils.RESTCodes;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Level;
import java.util.stream.Collectors;
import javax.ejb.EJB;
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;
import org.apache.calcite.sql.JoinConditionType;
import org.apache.calcite.sql.JoinType;
import org.apache.calcite.sql.SqlDialect;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.dialect.SparkSqlDialect;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;

@TransactionAttribute(TransactionAttributeType.NEVER)
@Stateless
public class ConstructorController {

    @EJB
    private FeaturegroupFacade featuregroupFacade;

    @EJB
    private FeaturestoreFacade featurestoreFacade;

    @EJB
    private FeaturegroupController featuregroupController;

    @EJB
    private OnlineFeaturestoreController onlineFeaturestoreController;

    /** Requested-feature name that selects every feature of the feature group. */
    private static final String ALL_FEATURES = "*";

    public ConstructorController() {
    }

    /**
     * Creates a controller with explicitly supplied collaborators instead of relying on EJB injection.
     */
    protected ConstructorController(FeaturegroupController featuregroupController,
                                    FeaturestoreFacade featurestoreFacade,
                                    FeaturegroupFacade featuregroupFacade,
                                    OnlineFeaturestoreController onlineFeaturestoreController) {
        this.featuregroupController = featuregroupController;
        this.featurestoreFacade = featurestoreFacade;
        this.featuregroupFacade = featuregroupFacade;
        this.onlineFeaturestoreController = onlineFeaturestoreController;
    }

    /**
     * Converts a query DTO into a {@link Query} (root feature group aliased {@code fg0}) and renders it.
     *
     * @param queryDTO the query as received from the client
     * @return the offline and online SQL for the query
     * @throws FeaturestoreException if a feature group or requested/join feature cannot be found
     */
    public FsQueryDTO construct(QueryDTO queryDTO) throws FeaturestoreException {
        return construct(convertQueryDTO(queryDTO, 0));
    }

    /**
     * Renders a query both for the offline store and for the online store.
     *
     * @param query the query to render
     * @return a DTO holding the offline SQL ({@code setQuery}) and online SQL ({@code setQueryOnline})
     */
    public FsQueryDTO construct(Query query) {
        FsQueryDTO fsQueryDTO = new FsQueryDTO();
        fsQueryDTO.setQuery(generateSQL(query, false));
        fsQueryDTO.setQueryOnline(generateSQL(query, true));
        return fsQueryDTO;
    }

    /**
     * Resolves a query DTO (and, recursively, its joins) into a {@link Query}.
     *
     * @param queryDTO the query description
     * @param fgId     index used to build the table alias ({@code "fg" + fgId}) of the left feature group;
     *                 joined sub-queries are numbered from {@code fgId + 1}
     * @return the resolved query; when joins are present, the join keys are removed from the right-hand selects
     * @throws FeaturestoreException    if a feature group or a requested/join feature does not exist
     * @throws IllegalArgumentException if the feature group, the requested features or a join sub-query is missing
     */
    public Query convertQueryDTO(QueryDTO queryDTO, int fgId) throws FeaturestoreException {
        Featuregroup featuregroup = validateFeaturegroupDTO(queryDTO.getLeftFeatureGroup());
        String fgAs = generateAs(fgId);

        String featureStoreDb = this.featurestoreFacade.getHiveDbName(featuregroup.getFeaturestore().getHiveDbId());
        String onlineDb = this.onlineFeaturestoreController
            .getOnlineFeaturestoreDbName(featuregroup.getFeaturestore().getProject());

        // A null primary flag is treated as "not primary" instead of failing with an NPE on unboxing.
        List<Feature> availableFeatures = this.featuregroupController.getFeatures(featuregroup).stream()
            .map(dto -> new Feature(dto.getName(), featuregroup.getName(), fgAs, dto.getType(),
                Boolean.TRUE.equals(dto.getPrimary())))
            .collect(Collectors.toList());

        List<Feature> requestedFeatures = validateFeatures(featuregroup, queryDTO.getLeftFeatures(), availableFeatures);
        Query query = new Query(featureStoreDb, onlineDb, featuregroup, fgAs, requestedFeatures, availableFeatures);

        if (queryDTO.getJoins() != null && !queryDTO.getJoins().isEmpty()) {
            query.setJoins(convertJoins(query, queryDTO.getJoins(), fgId + 1));
            removeDuplicateColumns(query);
        }
        return query;
    }

    /** Builds the SQL table alias for the feature group with the given index. */
    private String generateAs(int fgId) {
        return "fg" + fgId;
    }

    /**
     * Looks up the feature group referenced by the DTO.
     *
     * @throws IllegalArgumentException if no feature group DTO was given
     * @throws FeaturestoreException    ({@code FEATUREGROUP_NOT_FOUND}) if no feature group has the DTO's id
     */
    private Featuregroup validateFeaturegroupDTO(FeaturegroupDTO featuregroupDTO) throws FeaturestoreException {
        if (featuregroupDTO == null) {
            throw new IllegalArgumentException("Feature group not specified");
        }
        return this.featuregroupFacade.findById(featuregroupDTO.getId())
            .orElseThrow(() -> new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.FEATUREGROUP_NOT_FOUND,
                Level.FINE, "Could not find feature group with ID " + featuregroupDTO.getId()));
    }

    /**
     * Maps the requested feature names onto the feature group's available features.
     *
     * @param featuregroup      the feature group (used only for the error message)
     * @param requestedFeatures requested features; a single entry named {@code "*"} selects all available features
     * @param availableFeatures all features of the feature group
     * @return a new list with the selected features, in request order
     * @throws IllegalArgumentException if {@code requestedFeatures} is null or empty
     * @throws FeaturestoreException    ({@code FEATURE_DOES_NOT_EXIST}) if a requested name is not available
     */
    protected List<Feature> validateFeatures(Featuregroup featuregroup, List<FeatureGroupFeatureDTO> requestedFeatures,
                                             List<Feature> availableFeatures) throws FeaturestoreException {
        if (requestedFeatures == null || requestedFeatures.isEmpty()) {
            throw new IllegalArgumentException("Invalid requested features");
        }
        if (requestedFeatures.size() == 1 && ALL_FEATURES.equals(requestedFeatures.get(0).getName())) {
            return new ArrayList<>(availableFeatures);
        }

        List<Feature> selected = new ArrayList<>();
        for (FeatureGroupFeatureDTO requested : requestedFeatures) {
            selected.add(availableFeatures.stream()
                .filter(feature -> feature.getName().equals(requested.getName()))
                .findFirst()
                .orElseThrow(() -> new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.FEATURE_DOES_NOT_EXIST,
                    Level.FINE, "Feature: " + requested.getName() + " not found in feature group: "
                    + featuregroup.getName())));
        }
        return selected;
    }

    /**
     * Resolves each join DTO into a {@link Join} against {@code leftQuery}.
     *
     * <p>Join keys are taken from {@code on} if present, otherwise from {@code leftOn}/{@code rightOn}; when
     * neither is given, the matching primary keys of both sides are used.
     *
     * @param fgId alias index for the first joined sub-query; incremented once per join
     */
    private List<Join> convertJoins(Query leftQuery, List<JoinDTO> joinDTOs, int fgId) throws FeaturestoreException {
        List<Join> joins = new ArrayList<>();
        for (JoinDTO joinDTO : joinDTOs) {
            if (joinDTO.getQuery() == null) {
                throw new IllegalArgumentException("Subquery not specified");
            }
            // NOTE(review): a sub-query that has joins of its own numbers them from fgId + 1, which the next
            // sibling join also uses, so aliases can collide for nested joins — confirm nested joins are unsupported.
            Query rightQuery = convertQueryDTO(joinDTO.getQuery(), fgId++);

            if (joinDTO.getOn() != null && !joinDTO.getOn().isEmpty()) {
                List<Feature> on = joinDTO.getOn().stream()
                    .map(dto -> new Feature(dto.getName()))
                    .collect(Collectors.toList());
                joins.add(extractOn(leftQuery, rightQuery, on, joinDTO.getType()));
            } else if (joinDTO.getLeftOn() != null && !joinDTO.getLeftOn().isEmpty()) {
                List<Feature> leftOn = joinDTO.getLeftOn().stream()
                    .map(dto -> new Feature(dto.getName()))
                    .collect(Collectors.toList());
                // A missing rightOn is treated as empty so extractLeftRightOn reports the size mismatch
                // (LEFT_RIGHT_ON_DIFF_SIZES) instead of failing with a NullPointerException.
                List<Feature> rightOn = new ArrayList<>();
                if (joinDTO.getRightOn() != null) {
                    joinDTO.getRightOn().forEach(dto -> rightOn.add(new Feature(dto.getName())));
                }
                joins.add(extractLeftRightOn(leftQuery, rightQuery, leftOn, rightOn, joinDTO.getType()));
            } else {
                // NOTE(review): a rightOn given without leftOn is ignored here and a primary-key join is built.
                joins.add(extractPrimaryKeysJoin(leftQuery, rightQuery, joinDTO.getType()));
            }
        }
        return joins;
    }

    /**
     * Builds a join on features that share the same name on both sides.
     *
     * @throws FeaturestoreException if a join feature is missing from either query
     */
    protected Join extractOn(Query leftQuery, Query rightQuery, List<Feature> on, JoinType joinType)
        throws FeaturestoreException {
        for (Feature feature : on) {
            checkFeatureExists(leftQuery, feature);
            checkFeatureExists(rightQuery, feature);
        }
        return new Join(leftQuery, rightQuery, on, joinType);
    }

    /**
     * Builds a join whose left and right key lists are given separately.
     *
     * @throws FeaturestoreException ({@code LEFT_RIGHT_ON_DIFF_SIZES}) if the two lists differ in size, or
     *                               ({@code FEATURE_DOES_NOT_EXIST}) if a key is missing from its query
     */
    protected Join extractLeftRightOn(Query leftQuery, Query rightQuery, List<Feature> leftOn, List<Feature> rightOn,
                                      JoinType joinType) throws FeaturestoreException {
        if (leftOn.size() != rightOn.size()) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.LEFT_RIGHT_ON_DIFF_SIZES, Level.FINE);
        }
        for (Feature feature : leftOn) {
            checkFeatureExists(leftQuery, feature);
        }
        for (Feature feature : rightOn) {
            checkFeatureExists(rightQuery, feature);
        }
        return new Join(leftQuery, rightQuery, leftOn, rightOn, joinType);
    }

    /**
     * Builds a join on the primary keys that both queries have in common, matched by name.
     *
     * @throws FeaturestoreException ({@code NO_PK_JOINING_KEYS}) if no primary key name appears on both sides
     */
    protected Join extractPrimaryKeysJoin(Query leftQuery, Query rightQuery, JoinType joinType)
        throws FeaturestoreException {
        // For every left primary key, pick the right primary key(s) with the same name.
        List<Feature> joinKeys = leftQuery.getAvailableFeatures().stream()
            .filter(Feature::isPrimary)
            .flatMap(leftKey -> rightQuery.getAvailableFeatures().stream()
                .filter(rightKey -> rightKey.getName().equals(leftKey.getName()) && rightKey.isPrimary()))
            .collect(Collectors.toList());

        if (joinKeys.isEmpty()) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.NO_PK_JOINING_KEYS, Level.FINE,
                leftQuery.getFeaturegroup().getName() + " and: " + rightQuery.getFeaturegroup().getName());
        }
        return new Join(leftQuery, rightQuery, joinKeys, joinType);
    }

    /**
     * Fails if no available feature of {@code query} has the name of {@code feature}.
     *
     * @throws FeaturestoreException ({@code FEATURE_DOES_NOT_EXIST}) if the feature is not available
     */
    private void checkFeatureExists(Query query, Feature feature) throws FeaturestoreException {
        boolean missing = query.getAvailableFeatures().stream()
            .noneMatch(available -> available.getName().equals(feature.getName()));
        if (missing) {
            throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.FEATURE_DOES_NOT_EXIST, Level.FINE,
                "Could not find Join feature " + feature.getName() + " in feature group: "
                + query.getFeaturegroup().getName());
        }
    }

    /**
     * For joins on same-named keys ({@code rightOn} null or empty), removes the join keys from the right-hand
     * query's selected features so each key column is selected only once (from the left side).
     */
    private void removeDuplicateColumns(Query query) {
        for (Join join : query.getJoins()) {
            if (join.getRightOn() != null && !join.getRightOn().isEmpty()) {
                continue;
            }
            List<String> joinKeyNames = join.getOn().stream()
                .map(Feature::getName)
                .collect(Collectors.toList());
            Query rightQuery = join.getRightQuery();
            rightQuery.setFeatures(rightQuery.getFeatures().stream()
                .filter(feature -> !joinKeyNames.contains(feature.getName()))
                .collect(Collectors.toList()));
        }
    }

    /**
     * Renders the query as a Spark SQL {@code SELECT} statement.
     *
     * <p>Columns are selected as {@code `alias`.`feature`}. The FROM clause is the root table, or a left-deep
     * chain of the root query's joins. NOTE(review): joins nested inside a right-hand sub-query are not rendered
     * (only its table is), although {@link #collectFeatures} still selects their features.
     *
     * @param query  the query to render
     * @param online if true, tables are qualified with {@code query.getProject()}, otherwise with
     *               {@code query.getFeatureStore()}
     * @return the SQL string
     */
    public String generateSQL(Query query, boolean online) {
        SqlNodeList selectList = new SqlNodeList(SqlParserPos.ZERO);
        for (Feature feature : collectFeatures(query)) {
            selectList.add(new SqlIdentifier(
                Arrays.asList("`" + feature.getFgAlias() + "`", "`" + feature.getName() + "`"), SqlParserPos.ZERO));
        }

        SqlNode from = (query.getJoins() == null || query.getJoins().isEmpty())
            ? generateTableNode(query, online)
            : buildJoinNode(query, query.getJoins().size() - 1, online);

        SqlSelect select = new SqlSelect(SqlParserPos.ZERO, (SqlNodeList) null, selectList, from, (SqlNode) null,
            (SqlNodeList) null, (SqlNode) null, (SqlNodeList) null, (SqlNodeList) null, (SqlNode) null,
            (SqlNode) null);
        return select.toSqlString(new SparkSqlDialect(SqlDialect.EMPTY_CONTEXT)).getSql();
    }

    /**
     * Recursively builds the left-deep join tree for joins {@code 0..index} of {@code query}; for a negative
     * index, returns the root table node.
     */
    private SqlNode buildJoinNode(Query query, int index, boolean online) {
        if (index < 0) {
            return generateTableNode(query, online);
        }
        Join join = query.getJoins().get(index);
        return new SqlJoin(SqlParserPos.ZERO,
            buildJoinNode(query, index - 1, online),
            SqlLiteral.createBoolean(false, SqlParserPos.ZERO),
            SqlLiteral.createSymbol(join.getJoinType(), SqlParserPos.ZERO),
            generateTableNode(join.getRightQuery(), online),
            SqlLiteral.createSymbol(JoinConditionType.ON, SqlParserPos.ZERO),
            join.getCondition());
    }

    /**
     * Collects the selected features of the query followed, depth-first, by those of every joined sub-query
     * that has a non-null feature list.
     */
    protected List<Feature> collectFeatures(Query query) {
        List<Feature> features = new ArrayList<>(query.getFeatures());
        if (query.getJoins() != null) {
            for (Join join : query.getJoins()) {
                if (join.getRightQuery() != null && join.getRightQuery().getFeatures() != null) {
                    features.addAll(collectFeatures(join.getRightQuery()));
                }
            }
        }
        return features;
    }

    /**
     * Builds {@code `db`.`<fgName>_<fgVersion>` AS `alias`} for the query's feature group.
     *
     * @param online selects the database name: {@code query.getProject()} if true, {@code query.getFeatureStore()}
     *               otherwise
     */
    private SqlNode generateTableNode(Query query, boolean online) {
        List<String> tableIdentifier = new ArrayList<>();
        tableIdentifier.add("`" + (online ? query.getProject() : query.getFeatureStore()) + "`");
        tableIdentifier.add("`" + query.getFeaturegroup().getName() + "_" + query.getFeaturegroup().getVersion() + "`");
        return SqlStdOperatorTable.AS.createCall(new SqlNodeList(Arrays.asList(
            new SqlIdentifier(tableIdentifier, SqlParserPos.ZERO),
            new SqlIdentifier("`" + query.getAs() + "`", SqlParserPos.ZERO)), SqlParserPos.ZERO));
    }
}
