/*
 * Decompiled with CFR 0.152.
 */
package io.hops.hopsworks.common.featurestore.query.pit;

import com.google.common.base.Strings;
import io.hops.hopsworks.common.featurestore.query.ConstructorController;
import io.hops.hopsworks.common.featurestore.query.Feature;
import io.hops.hopsworks.common.featurestore.query.Query;
import io.hops.hopsworks.common.featurestore.query.QueryDTO;
import io.hops.hopsworks.common.featurestore.query.SqlCondition;
import io.hops.hopsworks.common.featurestore.query.filter.Filter;
import io.hops.hopsworks.common.featurestore.query.filter.FilterController;
import io.hops.hopsworks.common.featurestore.query.join.Join;
import io.hops.hopsworks.common.featurestore.query.join.JoinController;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import javax.ejb.EJB;
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;
import org.apache.calcite.sql.JoinConditionType;
import org.apache.calcite.sql.JoinType;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlWindow;
import org.apache.calcite.sql.SqlWith;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;

@Stateless
@TransactionAttribute(value=TransactionAttributeType.NEVER)
public class PitJoinController {
    @EJB
    private ConstructorController constructorController;
    @EJB
    private FilterController filterController;
    @EJB
    private JoinController joinController;
    private static final String PIT_JOIN_RANK = "pit_rank_hopsworks";
    private static final String ALL_FEATURES = "*";
    private static final String FG_SUBQUERY = "right_fg";
    private static final String HIVE_ALIAS_PLACEHOLDER = "NA";
    private static final String HIVE_AS = " AS";
    private static final String PK_JOIN_PREFIX = "join_pk_";
    private static final String EVT_JOIN_PREFIX = "join_evt_";

    public PitJoinController() {
    }

    public PitJoinController(ConstructorController constructorController, FilterController filterController, JoinController joinController) {
        this.constructorController = constructorController;
        this.filterController = filterController;
        this.joinController = joinController;
    }

    public boolean isPitEnabled(QueryDTO queryDTO) {
        if (queryDTO.getJoins() == null || queryDTO.getJoins().isEmpty()) {
            return false;
        }
        boolean eventTimeEnabled = queryDTO.getJoins().stream().allMatch(j -> j.getQuery().getLeftFeatureGroup().getEventTime() != null);
        return queryDTO.getLeftFeatureGroup().getEventTime() != null && eventTimeEnabled;
    }

    public boolean isPitEnabled(Query query) {
        if (query.getJoins() == null || query.getJoins().isEmpty()) {
            return false;
        }
        boolean eventTimeEnabled = query.getJoins().stream().allMatch(j -> j.getRightQuery().getFeaturegroup().getEventTime() != null);
        return query.getFeaturegroup().getEventTime() != null && eventTimeEnabled;
    }

    public List<SqlCall> generateSubQueries(Query baseQuery, Query query, boolean isTrainingDataset) {
        ArrayList<SqlCall> subQueries = new ArrayList<SqlCall>();
        List<Feature> additionalPkFeatures = query.getAvailableFeatures().stream().filter(Feature::isPrimary).map(f -> new Feature(f.getName(), f.getFgAlias(), f.getType(), f.isPrimary(), f.getDefaultValue(), PK_JOIN_PREFIX)).collect(Collectors.toList());
        additionalPkFeatures.add(new Feature(query.getFeaturegroup().getEventTime(), query.getAs(), (String)null, null, EVT_JOIN_PREFIX));
        additionalPkFeatures.forEach(f -> f.setFeatureGroup(query.getFeaturegroup()));
        for (Join join : query.getJoins()) {
            List<Feature> newLeftOn = this.addEventTimeOn(join.getLeftOn(), baseQuery.getFeaturegroup(), baseQuery.getAs());
            List<Feature> newRightOn = this.addEventTimeOn(join.getRightOn(), join.getRightQuery().getFeaturegroup(), join.getRightQuery().getAs());
            List<SqlCondition> newJoinOperator = this.addEventTimeCondition(join.getJoinOperator(), SqlCondition.GREATER_THAN_OR_EQUAL);
            List<Join> newJoins = Collections.singletonList(new Join(baseQuery, join.getRightQuery(), newLeftOn, newRightOn, join.getJoinType(), join.getPrefix(), newJoinOperator));
            baseQuery.setJoins(newJoins);
            if (isTrainingDataset) {
                baseQuery.setFeatures(this.dropIrrelevantSubqueryFeatures(query, join.getRightQuery()));
            }
            baseQuery.getFeatures().addAll(additionalPkFeatures);
            SqlSelect subQuery = this.constructorController.generateSQL(baseQuery, false);
            subQuery.getSelectList().add(this.rankOverAs(newLeftOn, new Feature(join.getRightQuery().getFeaturegroup().getEventTime(), join.getRightQuery().getAs(), false)));
            subQueries.add(SqlStdOperatorTable.AS.createCall(SqlParserPos.ZERO, new SqlNode[]{subQuery, new SqlIdentifier(HIVE_ALIAS_PLACEHOLDER, SqlParserPos.ZERO)}));
            baseQuery.setFeatures(new ArrayList<Feature>(query.getFeatures()));
        }
        return subQueries;
    }

    private SqlNodeList partitionBy(List<Feature> partitionFeatures) {
        SqlNodeList partitionBy = new SqlNodeList(SqlParserPos.ZERO);
        partitionFeatures.forEach(joinFeature -> partitionBy.add((SqlNode)new SqlIdentifier(Arrays.asList("`" + joinFeature.getFgAlias() + "`", "`" + joinFeature.getName() + "`"), SqlParserPos.ZERO)));
        return partitionBy;
    }

    private SqlNodeList orderByDesc(Feature feature) {
        return SqlNodeList.of((SqlNode)SqlStdOperatorTable.DESC.createCall(SqlParserPos.ZERO, new SqlNode[]{new SqlIdentifier(Arrays.asList("`" + feature.getFgAlias(false) + "`", "`" + feature.getName() + "`"), SqlParserPos.ZERO)}));
    }

    public SqlNode rankOverAs(List<Feature> partitionByFeatures, Feature orderByFeature) {
        SqlCall rank = SqlStdOperatorTable.RANK.createCall(SqlParserPos.ZERO, new SqlNode[0]);
        SqlNodeList partitionBy = this.partitionBy(partitionByFeatures);
        SqlNodeList orderList = this.orderByDesc(orderByFeature);
        SqlWindow win = SqlWindow.create(null, null, (SqlNodeList)partitionBy, (SqlNodeList)orderList, null, null, null, null, (SqlParserPos)SqlParserPos.ZERO);
        SqlCall over = SqlStdOperatorTable.OVER.createCall(SqlParserPos.ZERO, new SqlNode[]{rank, win});
        return SqlStdOperatorTable.AS.createCall(SqlParserPos.ZERO, new SqlNode[]{over, new SqlIdentifier(PIT_JOIN_RANK, SqlParserPos.ZERO)});
    }

    public List<SqlSelect> wrapSubQueries(List<SqlCall> sqlSelects) {
        ArrayList<SqlSelect> newSubQueries = new ArrayList<SqlSelect>();
        for (SqlCall select : sqlSelects) {
            SqlNode whereRank = this.filterController.generateFilterNode(new Filter(new Feature(PIT_JOIN_RANK, null, "int", null, null), SqlCondition.EQUALS, "1"), false);
            SqlNodeList selectList = SqlNodeList.of((SqlNode)new SqlIdentifier(ALL_FEATURES, SqlParserPos.ZERO));
            newSubQueries.add(new SqlSelect(SqlParserPos.ZERO, null, selectList, (SqlNode)select, whereRank, null, null, null, null, null, null));
        }
        return newSubQueries;
    }

    public SqlNode generateSQL(Query query, boolean isTrainingDataset) {
        Query baseQuery = new Query(query.getFeatureStore(), query.getProject(), query.getFeaturegroup(), query.getAs(), new ArrayList<Feature>(query.getFeatures()), query.getAvailableFeatures(), query.getHiveEngine(), query.getFilter());
        List<Feature> finalSelectList = this.constructorController.collectFeatures(baseQuery);
        List<SqlSelect> withSelects = this.wrapSubQueries(this.generateSubQueries(baseQuery, query, isTrainingDataset));
        finalSelectList.forEach(f -> f.setPitFgAlias("right_fg0"));
        SqlNodeList selectAsses = new SqlNodeList(SqlParserPos.ZERO);
        ArrayList<Join> newJoins = new ArrayList<Join>();
        for (int i = 0; i < withSelects.size(); ++i) {
            selectAsses.add((SqlNode)SqlStdOperatorTable.AS.createCall(SqlNodeList.of((SqlNode)new SqlIdentifier(FG_SUBQUERY + i + HIVE_AS, SqlParserPos.ZERO), (SqlNode)((SqlNode)withSelects.get(i)))));
            String pitAlias = FG_SUBQUERY + i;
            if (isTrainingDataset) {
                int finalI = i;
                finalSelectList.stream().filter(f -> f.getFeatureGroup() == query.getJoins().get(finalI).getRightQuery().getFeaturegroup()).forEach(f -> f.setPitFgAlias(pitAlias));
            } else {
                List<Feature> features = this.constructorController.collectFeatures(query.getJoins().get(i).getRightQuery());
                features.forEach(f -> f.setPitFgAlias(pitAlias));
                finalSelectList.addAll(features);
            }
            List<Feature> primaryKey = baseQuery.getAvailableFeatures().stream().filter(Feature::isPrimary).collect(Collectors.toList());
            List<Feature> newLeftOn = this.addEventTimeOn(primaryKey, baseQuery.getFeaturegroup(), baseQuery.getAs());
            this.renameJoinFeatures(newLeftOn);
            List<Feature> newRightOn = this.addEventTimeOn(primaryKey, baseQuery.getFeaturegroup(), baseQuery.getAs());
            this.renameJoinFeatures(newRightOn);
            List<SqlCondition> newJoinOperator = newLeftOn.stream().map(f -> SqlCondition.EQUALS).collect(Collectors.toList());
            newLeftOn.forEach(f -> f.setPitFgAlias("right_fg0"));
            newRightOn.forEach(f -> f.setPitFgAlias(pitAlias));
            newJoins.add(new Join(null, null, newLeftOn, newRightOn, JoinType.INNER, null, newJoinOperator));
        }
        if (isTrainingDataset) {
            finalSelectList = finalSelectList.stream().sorted(Comparator.comparing(Feature::getIdx)).collect(Collectors.toList());
        }
        SqlNodeList selectList = new SqlNodeList(SqlParserPos.ZERO);
        for (Feature f2 : finalSelectList) {
            String featurePrefixed = !Strings.isNullOrEmpty((String)f2.getPrefix()) ? f2.getPrefix() + f2.getName() : f2.getName();
            selectList.add((SqlNode)new SqlIdentifier(Arrays.asList("`" + f2.getFgAlias(true) + "`", "`" + featurePrefixed + "`"), SqlParserPos.ZERO));
        }
        SqlSelect body = new SqlSelect(SqlParserPos.ZERO, null, selectList, this.buildWithJoin(newJoins, newJoins.size() - 1), null, null, null, null, null, null, null);
        return new SqlWith(SqlParserPos.ZERO, selectAsses, (SqlNode)body);
    }

    public boolean isTrainingDataset(List<Feature> selectList) {
        return selectList.stream().allMatch(f -> f.getIdx() != null && f.getFeatureGroup() != null);
    }

    public void renameJoinFeatures(List<Feature> joinFeatures) {
        joinFeatures.forEach(f -> {
            String prefixName = f.isPrimary() ? PK_JOIN_PREFIX + f.getName() : EVT_JOIN_PREFIX + f.getName();
            f.setName(prefixName);
        });
    }

    public List<Feature> addEventTimeOn(List<Feature> on, Featuregroup featureGroup, String fgAlias) {
        List<Feature> newOn = on.stream().map(f -> new Feature(f.getName(), f.getFgAlias(), f.isPrimary())).collect(Collectors.toList());
        newOn.add(new Feature(featureGroup.getEventTime(), fgAlias));
        return newOn;
    }

    public List<SqlCondition> addEventTimeCondition(List<SqlCondition> joinCondition, SqlCondition operator) {
        ArrayList<SqlCondition> newJoinCondition = new ArrayList<SqlCondition>(joinCondition);
        newJoinCondition.add(operator);
        return newJoinCondition;
    }

    public List<Feature> dropIrrelevantSubqueryFeatures(Query query, Query rightQuery) {
        return query.getFeatures().stream().filter(f -> f.getFeatureGroup() == query.getFeaturegroup() || f.getFeatureGroup() == rightQuery.getFeaturegroup()).collect(Collectors.toList());
    }

    private SqlNode buildWithJoin(List<Join> joins, int i) {
        if (i > 0) {
            return new SqlJoin(SqlParserPos.ZERO, this.buildWithJoin(joins, i - 1), SqlLiteral.createBoolean((boolean)false, (SqlParserPos)SqlParserPos.ZERO), SqlLiteral.createSymbol((Enum)JoinType.INNER, (SqlParserPos)SqlParserPos.ZERO), (SqlNode)new SqlIdentifier(FG_SUBQUERY + i, SqlParserPos.ZERO), SqlLiteral.createSymbol((Enum)JoinConditionType.ON, (SqlParserPos)SqlParserPos.ZERO), this.joinController.getLeftRightCondition("right_fg0", FG_SUBQUERY + i, joins.get(i).getLeftOn(), joins.get(i).getRightOn(), joins.get(i).getJoinOperator(), true));
        }
        return new SqlIdentifier("right_fg0", SqlParserPos.ZERO);
    }
}

