package io.hops.hopsworks.common.featurestore.query.pit;

import com.google.common.base.Strings;
import io.hops.hopsworks.common.featurestore.query.ConstructorController;
import io.hops.hopsworks.common.featurestore.query.Feature;
import io.hops.hopsworks.common.featurestore.query.Query;
import io.hops.hopsworks.common.featurestore.query.QueryDTO;
import io.hops.hopsworks.common.featurestore.query.filter.Filter;
import io.hops.hopsworks.common.featurestore.query.filter.FilterController;
import io.hops.hopsworks.common.featurestore.query.join.Join;
import io.hops.hopsworks.common.featurestore.query.join.JoinController;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup;
import io.hops.hopsworks.persistence.entity.featurestore.trainingdataset.SqlCondition;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import javax.ejb.EJB;
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;
import org.apache.calcite.sql.JoinConditionType;
import org.apache.calcite.sql.JoinType;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlWindow;
import org.apache.calcite.sql.SqlWith;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;

/**
 * Builds the SQL (as Calcite {@link SqlNode} trees) for point-in-time ("PIT") joins.
 *
 * <p>The statement produced by {@link #generateSQL(Query, boolean)} has this shape:
 * <ol>
 *   <li>for every join of the query, a sub-query that joins the base feature group with the
 *       join's right feature group, with the event-time features appended to the join keys and
 *       a {@code GREATER_THAN_OR_EQUAL} operator appended for that pair;</li>
 *   <li>each sub-query additionally selects {@code RANK() OVER (PARTITION BY ... ORDER BY ... DESC)}
 *       under the alias {@code pit_rank_hopsworks};</li>
 *   <li>each sub-query is wrapped in {@code SELECT * FROM ... WHERE pit_rank_hopsworks = 1} and
 *       registered as {@code right_fg<i>} in a WITH list;</li>
 *   <li>the WITH body selects the collected features and inner-joins {@code right_fg0} with every
 *       further {@code right_fg<i>} on the re-selected, prefixed key and event-time columns.</li>
 * </ol>
 */
@TransactionAttribute(TransactionAttributeType.NEVER)
@Stateless
public class PitJoinController {

    @EJB
    private ConstructorController constructorController;

    @EJB
    private FilterController filterController;

    @EJB
    private JoinController joinController;

    /** Alias of the rank column added to every sub-query and filtered on ({@code = 1}) afterwards. */
    private static final String PIT_JOIN_RANK = "pit_rank_hopsworks";
    /** Select-list wildcard used by the wrapping selects. */
    private static final String ALL_FEATURES = "*";
    /** Name prefix of the WITH entries; the entry index is appended ({@code right_fg0}, {@code right_fg1}, ...). */
    private static final String FG_SUBQUERY = "right_fg";
    /** Name of the first WITH entry, which acts as the left side of every final join. */
    private static final String BASE_SUBQUERY = FG_SUBQUERY + 0;
    /**
     * Alias attached to every generated sub-query.
     * NOTE(review): presumably only a placeholder required by the target SQL dialect - confirm.
     */
    private static final String HIVE_ALIAS_PLACEHOLDER = "NA";
    /**
     * Suffix appended to the name identifier of each WITH entry.
     * NOTE(review): presumably makes the rendered SQL read {@code right_fg<i> AS (...)} - confirm
     * against the SQL rendering.
     */
    private static final String HIVE_AS = " AS";
    /** Prefix of the re-selected primary-key columns used by the final join. */
    private static final String PK_JOIN_PREFIX = "join_pk_";
    /** Prefix of the re-selected event-time column used by the final join. */
    private static final String EVT_JOIN_PREFIX = "join_evt_";

    /** No-arg constructor; the collaborators are then supplied through the {@code @EJB} fields. */
    public PitJoinController() {
    }

    /**
     * Creates a controller with explicitly supplied collaborators.
     *
     * @param constructorController used to generate the per-join sub-queries and to collect features
     * @param filterController      used to generate the {@code rank = 1} filter node
     * @param joinController        used to generate the ON conditions of the final joins
     */
    public PitJoinController(ConstructorController constructorController, FilterController filterController,
                             JoinController joinController) {
        this.constructorController = constructorController;
        this.filterController = filterController;
        this.joinController = joinController;
    }

    /**
     * Tells whether a query DTO qualifies for a point-in-time join.
     *
     * @param queryDTO the query description to inspect
     * @return {@code true} only if the DTO has at least one join and both its left feature group and
     *         the left feature group of every joined query define an event time
     */
    public boolean isPitEnabled(QueryDTO queryDTO) {
        if (queryDTO.getJoins() == null || queryDTO.getJoins().isEmpty()) {
            return false;
        }
        return queryDTO.getLeftFeatureGroup().getEventTime() != null
            && queryDTO.getJoins().stream()
                .allMatch(joinDTO -> joinDTO.getQuery().getLeftFeatureGroup().getEventTime() != null);
    }

    /**
     * Tells whether a query qualifies for a point-in-time join.
     *
     * @param query the query to inspect
     * @return {@code true} only if the query has at least one join and both its feature group and the
     *         feature group of every right query define an event time
     */
    public boolean isPitEnabled(Query query) {
        if (query.getJoins() == null || query.getJoins().isEmpty()) {
            return false;
        }
        return query.getFeaturegroup().getEventTime() != null
            && query.getJoins().stream()
                .allMatch(join -> join.getRightQuery().getFeaturegroup().getEventTime() != null);
    }

    /**
     * Generates one ranked sub-query per join of {@code query2}.
     *
     * <p>{@code query} is used as a scratch object: for every join its joins and features are
     * overwritten before SQL generation, and its features are reset to a copy of
     * {@code query2.getFeatures()} at the end of each iteration.
     *
     * @param query  base query that is mutated and handed to the SQL generator for every join
     * @param query2 query providing the joins, the available features and the feature list that
     *               {@code query} is reset to
     * @param z      when {@code true}, the features of {@code query} are restricted, per join, to those
     *               belonging to the base or the right feature group.
     *               NOTE(review): presumably "the query backs a training dataset" - confirm with callers.
     * @return one {@code <sub-query> AS NA} call per join, in join order
     */
    public List<SqlCall> generateSubQueries(Query query, Query query2, boolean z) {
        List<SqlCall> subQueries = new ArrayList<>();

        // Primary keys and event time of query2 are re-selected under dedicated prefixes in every
        // sub-query; generateSQL joins the sub-queries on exactly these columns.
        List<Feature> joinKeyFeatures = query2.getAvailableFeatures().stream()
            .filter(Feature::isPrimary)
            .map(pk -> new Feature(pk.getName(), pk.getFgAlias(), pk.getType(), pk.isPrimary(),
                pk.getDefaultValue(), PK_JOIN_PREFIX))
            .collect(Collectors.toList());
        joinKeyFeatures.add(new Feature(query2.getFeaturegroup().getEventTime(), query2.getAs(), (String) null,
            (String) null, EVT_JOIN_PREFIX));
        joinKeyFeatures.forEach(keyFeature -> keyFeature.setFeatureGroup(query2.getFeaturegroup()));

        for (Join join : query2.getJoins()) {
            Query rightQuery = join.getRightQuery();

            // Extend the join keys with the event-time features and the operators with the
            // event-time comparison.
            List<Feature> leftOn = addEventTimeOn(join.getLeftOn(), query.getFeaturegroup(), query.getAs());
            List<Feature> rightOn =
                addEventTimeOn(join.getRightOn(), rightQuery.getFeaturegroup(), rightQuery.getAs());
            List<SqlCondition> joinOperators =
                addEventTimeCondition(join.getJoinOperator(), SqlCondition.GREATER_THAN_OR_EQUAL);

            // Every sub-query joins the base query with exactly one right query.
            query.setJoins(Collections.singletonList(
                new Join(query, rightQuery, leftOn, rightOn, join.getJoinType(), join.getPrefix(), joinOperators)));

            if (z) {
                query.setFeatures(dropIrrelevantSubqueryFeatures(query2, rightQuery));
            }
            query.getFeatures().addAll(joinKeyFeatures);

            // The select list is extended in place, so the generated node must be a SqlSelect.
            SqlSelect subQuery = (SqlSelect) this.constructorController.generateSQL(query, false);
            subQuery.getSelectList().add(rankOverAs(leftOn,
                new Feature(rightQuery.getFeaturegroup().getEventTime(), rightQuery.getAs(), false)));

            subQueries.add(SqlStdOperatorTable.AS.createCall(SqlParserPos.ZERO,
                new SqlNode[]{subQuery, new SqlIdentifier(HIVE_ALIAS_PLACEHOLDER, SqlParserPos.ZERO)}));

            // Undo the feature changes so the next join starts from query2's feature list again.
            query.setFeatures(new ArrayList<>(query2.getFeatures()));
        }
        return subQueries;
    }

    /** Wraps {@code value} in backticks. */
    private static String quote(String value) {
        return "`" + value + "`";
    }

    /** Builds the two-part identifier {@code `alias`.`name`}. */
    private static SqlIdentifier quotedIdentifier(String alias, String name) {
        return new SqlIdentifier(Arrays.asList(quote(alias), quote(name)), SqlParserPos.ZERO);
    }

    /** Builds the PARTITION BY list: one backtick-quoted {@code alias.name} identifier per feature. */
    private SqlNodeList partitionBy(List<Feature> partitionFeatures) {
        SqlNodeList partitionList = new SqlNodeList(SqlParserPos.ZERO);
        partitionFeatures.forEach(feature ->
            partitionList.add(quotedIdentifier(feature.getFgAlias(), feature.getName())));
        return partitionList;
    }

    /** Builds the ORDER BY list: the given feature, backtick-quoted, in descending order. */
    private SqlNodeList orderByDesc(Feature orderFeature) {
        return SqlNodeList.of(SqlStdOperatorTable.DESC.createCall(SqlParserPos.ZERO,
            new SqlNode[]{quotedIdentifier(orderFeature.getFgAlias(false), orderFeature.getName())}));
    }

    /**
     * Builds {@code RANK() OVER (PARTITION BY <list> ORDER BY <feature> DESC) AS pit_rank_hopsworks}.
     *
     * @param list    features to partition by
     * @param feature feature to order by, descending
     * @return the aliased window-function call
     */
    public SqlNode rankOverAs(List<Feature> list, Feature feature) {
        SqlNode rank = SqlStdOperatorTable.RANK.createCall(SqlParserPos.ZERO, new SqlNode[0]);
        SqlNode window = SqlWindow.create((SqlIdentifier) null, (SqlIdentifier) null, partitionBy(list),
            orderByDesc(feature), (SqlLiteral) null, (SqlNode) null, (SqlNode) null, (SqlLiteral) null,
            SqlParserPos.ZERO);
        SqlNode rankOverWindow =
            SqlStdOperatorTable.OVER.createCall(SqlParserPos.ZERO, new SqlNode[]{rank, window});
        return SqlStdOperatorTable.AS.createCall(SqlParserPos.ZERO,
            new SqlNode[]{rankOverWindow, new SqlIdentifier(PIT_JOIN_RANK, SqlParserPos.ZERO)});
    }

    /**
     * Wraps every sub-query in {@code SELECT * FROM <sub-query> WHERE pit_rank_hopsworks = 1}.
     *
     * @param list the aliased sub-queries, e.g. as returned by {@link #generateSubQueries}
     * @return one wrapping select per input element, in the same order
     */
    public List<SqlSelect> wrapSubQueries(List<SqlCall> list) {
        List<SqlSelect> wrapped = new ArrayList<>();
        for (SqlCall subQuery : list) {
            // A fresh filter node is generated per select so that no node is shared between selects.
            List<Feature> rankFeature =
                Arrays.asList(new Feature(PIT_JOIN_RANK, (String) null, "int", (String) null, (String) null));
            SqlNode rankIsOne = this.filterController.generateFilterNode(
                new Filter(rankFeature, SqlCondition.EQUALS, "1"), false);
            wrapped.add(new SqlSelect(SqlParserPos.ZERO, (SqlNodeList) null,
                SqlNodeList.of(new SqlIdentifier(ALL_FEATURES, SqlParserPos.ZERO)), subQuery, rankIsOne,
                (SqlNodeList) null, (SqlNode) null, (SqlNodeList) null, (SqlNodeList) null, (SqlNode) null,
                (SqlNode) null, (SqlNodeList) null));
        }
        return wrapped;
    }

    /**
     * Generates the complete WITH statement for the given query.
     *
     * <p>{@code query} itself is not handed to the SQL generator; a copy without joins is used as the
     * mutable base query instead. Feature objects returned by the feature collector do get their
     * PIT alias set.
     *
     * @param query the query whose joins are turned into ranked sub-queries
     * @param z     when {@code true}, only the features collected from the base query are selected,
     *              re-pointed to the sub-query of their feature group and sorted by
     *              {@code Feature.getIdx()}; when {@code false}, the features collected from every
     *              right query are appended instead.
     *              NOTE(review): presumably "the query backs a training dataset" - confirm with callers.
     * @return a {@link SqlWith} node holding the sub-queries and the joining select
     */
    public SqlNode generateSQL(Query query, boolean z) {
        Query baseQuery = new Query(query.getFeatureStore(), query.getProject(), query.getFeaturegroup(),
            query.getAs(), new ArrayList<>(query.getFeatures()), query.getAvailableFeatures(),
            query.getHiveEngine(), query.getFilter());

        // Collect before generating the sub-queries: generateSubQueries mutates baseQuery.
        List<Feature> selectedFeatures = this.constructorController.collectFeatures(baseQuery);
        List<SqlSelect> subQueries = wrapSubQueries(generateSubQueries(baseQuery, query, z));
        selectedFeatures.forEach(feature -> feature.setPitFgAlias(BASE_SUBQUERY));

        // The key features are the same for every sub-query, so they are determined once.
        List<Feature> primaryKeys = baseQuery.getAvailableFeatures().stream()
            .filter(Feature::isPrimary)
            .collect(Collectors.toList());

        SqlNodeList withList = new SqlNodeList(SqlParserPos.ZERO);
        List<Join> withJoins = new ArrayList<>();
        for (int i = 0; i < subQueries.size(); i++) {
            String alias = FG_SUBQUERY + i;
            withList.add(SqlStdOperatorTable.AS.createCall(
                SqlNodeList.of(new SqlIdentifier(alias + HIVE_AS, SqlParserPos.ZERO), subQueries.get(i))));

            Query rightQuery = query.getJoins().get(i).getRightQuery();
            if (z) {
                // Re-point the already collected features of this join's feature group.
                Featuregroup rightFeaturegroup = rightQuery.getFeaturegroup();
                selectedFeatures.stream()
                    .filter(feature -> feature.getFeatureGroup() == rightFeaturegroup)
                    .forEach(feature -> feature.setPitFgAlias(alias));
            } else {
                List<Feature> rightFeatures = this.constructorController.collectFeatures(rightQuery);
                rightFeatures.forEach(feature -> feature.setPitFgAlias(alias));
                selectedFeatures.addAll(rightFeatures);
            }

            List<Feature> leftOn = subQueryJoinKeys(primaryKeys, baseQuery, BASE_SUBQUERY);
            List<Feature> rightOn = subQueryJoinKeys(primaryKeys, baseQuery, alias);
            List<SqlCondition> joinOperators = leftOn.stream()
                .map(keyFeature -> SqlCondition.EQUALS)
                .collect(Collectors.toList());
            withJoins.add(new Join(null, null, leftOn, rightOn, JoinType.INNER, null, joinOperators));
        }

        List<Feature> orderedFeatures = selectedFeatures;
        if (z) {
            orderedFeatures = selectedFeatures.stream()
                .sorted(Comparator.comparing(Feature::getIdx))
                .collect(Collectors.toList());
        }

        SqlNodeList selectList = new SqlNodeList(SqlParserPos.ZERO);
        for (Feature feature : orderedFeatures) {
            String fgAlias = feature.getFgAlias(true);
            String columnName = Strings.isNullOrEmpty(feature.getPrefix())
                ? feature.getName()
                : feature.getPrefix() + feature.getName();
            selectList.add(quotedIdentifier(fgAlias, columnName));
        }

        SqlSelect withBody = new SqlSelect(SqlParserPos.ZERO, (SqlNodeList) null, selectList,
            buildWithJoin(withJoins, withJoins.size() - 1), (SqlNode) null, (SqlNodeList) null, (SqlNode) null,
            (SqlNodeList) null, (SqlNodeList) null, (SqlNode) null, (SqlNode) null, (SqlNodeList) null);
        return new SqlWith(SqlParserPos.ZERO, withList, withBody);
    }

    /**
     * Builds a fresh list of final-join key features (primary keys plus the base query's event time),
     * renamed with the join prefixes and pointed at the given sub-query alias.
     */
    private List<Feature> subQueryJoinKeys(List<Feature> primaryKeys, Query baseQuery, String subQueryAlias) {
        List<Feature> joinKeys = addEventTimeOn(primaryKeys, baseQuery.getFeaturegroup(), baseQuery.getAs());
        renameJoinFeatures(joinKeys);
        joinKeys.forEach(keyFeature -> keyFeature.setPitFgAlias(subQueryAlias));
        return joinKeys;
    }

    /**
     * Checks whether every feature carries both an index and a feature group.
     *
     * @param list the features to inspect
     * @return {@code true} if no feature has a {@code null} index or a {@code null} feature group
     *         (therefore also {@code true} for an empty list)
     */
    public boolean isTrainingDataset(List<Feature> list) {
        return list.stream()
            .allMatch(feature -> feature.getIdx() != null && feature.getFeatureGroup() != null);
    }

    /**
     * Prefixes, in place, the name of every feature: {@code join_pk_} for primary features and
     * {@code join_evt_} for all others.
     *
     * @param list the features to rename; the contained objects are modified
     */
    public void renameJoinFeatures(List<Feature> list) {
        list.forEach(feature -> {
            String prefix = feature.isPrimary() ? PK_JOIN_PREFIX : EVT_JOIN_PREFIX;
            feature.setName(prefix + feature.getName());
        });
    }

    /**
     * Copies the given join features and appends the event-time feature of a feature group.
     *
     * @param list         join features; each is copied (name, feature-group alias, primary flag)
     * @param featuregroup feature group whose event-time name is appended as a feature
     * @param str          feature-group alias given to the appended event-time feature
     * @return a new list; neither the input list nor its elements are modified
     */
    public List<Feature> addEventTimeOn(List<Feature> list, Featuregroup featuregroup, String str) {
        List<Feature> joinFeatures = list.stream()
            .map(feature -> new Feature(feature.getName(), feature.getFgAlias(), feature.isPrimary()))
            .collect(Collectors.toList());
        joinFeatures.add(new Feature(featuregroup.getEventTime(), str));
        return joinFeatures;
    }

    /**
     * Returns a copy of the join operators with one more operator appended.
     *
     * @param list         the existing join operators; not modified
     * @param sqlCondition the operator to append
     * @return a new list ending with {@code sqlCondition}
     */
    public List<SqlCondition> addEventTimeCondition(List<SqlCondition> list, SqlCondition sqlCondition) {
        List<SqlCondition> joinOperators = new ArrayList<>(list);
        joinOperators.add(sqlCondition);
        return joinOperators;
    }

    /**
     * Selects the features of {@code query} that belong to the feature group of {@code query} or to
     * the feature group of {@code query2}. Feature groups are compared by reference.
     *
     * @param query  query providing the candidate features and the first accepted feature group
     * @param query2 query providing the second accepted feature group
     * @return a new list with the matching features
     */
    public List<Feature> dropIrrelevantSubqueryFeatures(Query query, Query query2) {
        return query.getFeatures().stream()
            .filter(feature -> feature.getFeatureGroup() == query.getFeaturegroup()
                || feature.getFeatureGroup() == query2.getFeaturegroup())
            .collect(Collectors.toList());
    }

    /**
     * Recursively builds the FROM clause of the WITH body: {@code right_fg0} inner-joined with
     * {@code right_fg1 .. right_fg<index>}, nested to the left. The join at position 0 is never
     * read, because {@code right_fg0} is the left-most operand.
     *
     * @param joins the final joins, one per WITH entry
     * @param index position of the last join to include; values {@code <= 0} yield just {@code right_fg0}
     */
    private SqlNode buildWithJoin(List<Join> joins, int index) {
        if (index <= 0) {
            return new SqlIdentifier(BASE_SUBQUERY, SqlParserPos.ZERO);
        }
        Join join = joins.get(index);
        String rightAlias = FG_SUBQUERY + index;
        return new SqlJoin(SqlParserPos.ZERO,
            buildWithJoin(joins, index - 1),
            SqlLiteral.createBoolean(false, SqlParserPos.ZERO),
            SqlLiteral.createSymbol(JoinType.INNER, SqlParserPos.ZERO),
            new SqlIdentifier(rightAlias, SqlParserPos.ZERO),
            SqlLiteral.createSymbol(JoinConditionType.ON, SqlParserPos.ZERO),
            this.joinController.getLeftRightCondition(BASE_SUBQUERY, rightAlias, join.getLeftOn(),
                join.getRightOn(), join.getJoinOperator(), true));
    }
}
