package org.apache.hudi.org.apache.hadoop.hive.ql.optimizer.correlation;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Stack;
import jodd.util.StringPool;
import org.apache.hadoop.hive.ql.exec.CommonMergeJoinOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.ForwardWalker;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hudi.org.apache.hadoop.hive.ql.optimizer.Transform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/hudi/org/apache/hadoop/hive/ql/optimizer/correlation/ReduceSinkJoinDeDuplication.class */
public class ReduceSinkJoinDeDuplication extends Transform {
    protected static final Logger LOG = LoggerFactory.getLogger(ReduceSinkJoinDeDuplication.class);
    protected ParseContext pGraphContext;

    /**
     * Fallback processor used for operators that match no rule; it does nothing and returns
     * {@code null}.
     */
    public static class DefaultProc implements NodeProcessor {
        DefaultProc() {
        }

        @Override
        public Object process(Node node, Stack<Node> stack, NodeProcessorCtx nodeProcessorCtx, Object... objArr) throws SemanticException {
            return null;
        }
    }

    /**
     * Walker context for this transform; it only forwards the {@link ParseContext} to
     * {@link AbstractCorrelationProcCtx}.
     */
    public class ReduceSinkJoinDeDuplicateProcCtx extends AbstractCorrelationProcCtx {
        public ReduceSinkJoinDeDuplicateProcCtx(ParseContext parseContext) {
            super(parseContext);
        }
    }

    /** Creates the node processors registered by {@link #transform(ParseContext)}. */
    static class ReduceSinkJoinDeDuplicateProcFactory {
        ReduceSinkJoinDeDuplicateProcFactory() {
        }

        /** Returns the processor applied to ReduceSink operators. */
        public static NodeProcessor getReducerMapJoinProc() {
            return new ReducerProc();
        }

        /** Returns the no-op processor used for all other operators. */
        public static NodeProcessor getDefaultProc() {
            return new DefaultProc();
        }
    }

    /**
     * Marks a child ReduceSink ("cRS") to forward data when
     * {@link ReduceSinkDeDuplicationUtils#strictMerge} accepts it against the closest upstream
     * ReduceSink ("pRS"), or against every ReduceSink input of the join fed by that pRS.
     * On success, the involved ReduceSinks are pinned to FIXED parallelism equal to the largest
     * reducer count seen among them.
     */
    public static class ReducerProc implements NodeProcessor {
        ReducerProc() {
        }

        /**
         * Tries to turn the visited ReduceSink into a forwarding one.
         *
         * @param node the visited operator; assumed to be a {@link ReduceSinkOperator} (rule R1)
         * @param nodeProcessorCtx must be a {@link ReduceSinkJoinDeDuplicateProcCtx}
         * @return {@code Boolean.TRUE} if the ReduceSink was set to forward, {@code Boolean.FALSE} otherwise
         * @throws SemanticException propagated from the merge check or the parallelism propagation
         */
        @Override
        public Object process(Node node, Stack<Node> stack, NodeProcessorCtx nodeProcessorCtx, Object... objArr) throws SemanticException {
            ReduceSinkJoinDeDuplicateProcCtx dedupCtx = (ReduceSinkJoinDeDuplicateProcCtx) nodeProcessorCtx;
            ReduceSinkOperator cRS = (ReduceSinkOperator) node;
            ReduceSinkDesc cRSConf = cRS.getConf();
            if (cRSConf.isForwarding() || cRSConf.getKeyCols().isEmpty()) {
                // Already forwarding, or no key columns: nothing to do.
                return false;
            }

            Operator<? extends OperatorDesc> cRSChild = cRS.getChildOperators().get(0);
            boolean childIsJoin = cRSChild instanceof MapJoinOperator || cRSChild instanceof CommonMergeJoinOperator;
            if (childIsJoin && !allParentsAreReduceSinks(cRSChild)) {
                // Only joins whose inputs are all ReduceSinks are supported.
                return false;
            }
            // For a map-join child the upstream search does not need to preserve sort order.
            boolean onlyPartitioning = cRSChild instanceof MapJoinOperator;

            int maxNumReducers = cRSConf.getNumReducers();
            ReduceSinkOperator pRS = onlyPartitioning
                    ? (ReduceSinkOperator) CorrelationUtilities.findFirstPossibleParent(cRS, ReduceSinkOperator.class, dedupCtx.trustScript())
                    : (ReduceSinkOperator) CorrelationUtilities.findFirstPossibleParentPreserveSortOrder(cRS, ReduceSinkOperator.class, dedupCtx.trustScript());
            if (pRS == null) {
                return false;
            }

            Operator<? extends OperatorDesc> pRSChild = pRS.getChildOperators().get(0);
            boolean merged;
            if (pRSChild instanceof MapJoinOperator || pRSChild instanceof CommonMergeJoinOperator) {
                if (pRSChild instanceof MapJoinOperator
                        && (!isDynamicPartitionHashJoin(pRSChild) || !isDynamicPartitionHashJoin(cRSChild))) {
                    // An upstream map join is only usable when both joins are dynamic partition hash joins.
                    return false;
                }
                if (!allParentsAreReduceSinks(pRSChild)) {
                    // Bail out before any mutation instead of failing on a cast later.
                    return false;
                }
                ImmutableList<ReduceSinkOperator> pRSs = reduceSinkParents(pRSChild);
                for (ReduceSinkOperator parentRS : pRSs) {
                    maxNumReducers = Math.max(maxNumReducers, parentRS.getConf().getNumReducers());
                }
                merged = ReduceSinkDeDuplicationUtils.strictMerge(cRS, pRSs);
            } else {
                // The upstream ReduceSink does not feed a join: compare against it alone.
                maxNumReducers = Math.max(maxNumReducers, pRS.getConf().getNumReducers());
                merged = ReduceSinkDeDuplicationUtils.strictMerge(cRS, pRS);
            }
            if (!merged) {
                return false;
            }
            LOG.debug("Set {} to forward data", cRS);
            cRS.getConf().setForwarding(true);
            propagateMaxNumReducers(dedupCtx, cRS, maxNumReducers);
            return true;
        }

        /** Returns true if every parent of {@code op} is a {@link ReduceSinkOperator}. */
        private static boolean allParentsAreReduceSinks(Operator<? extends OperatorDesc> op) {
            for (Operator<? extends OperatorDesc> parent : op.getParentOperators()) {
                if (!(parent instanceof ReduceSinkOperator)) {
                    return false;
                }
            }
            return true;
        }

        /** Returns the parents of {@code op}; callers must first check {@link #allParentsAreReduceSinks}. */
        private static ImmutableList<ReduceSinkOperator> reduceSinkParents(Operator<? extends OperatorDesc> op) {
            ImmutableList.Builder<ReduceSinkOperator> builder = ImmutableList.builder();
            for (Operator<? extends OperatorDesc> parent : op.getParentOperators()) {
                builder.add((ReduceSinkOperator) parent);
            }
            return builder.build();
        }

        /** Returns true if {@code op} is a map join configured as a dynamic partition hash join. */
        private static boolean isDynamicPartitionHashJoin(Operator<? extends OperatorDesc> op) {
            return op instanceof MapJoinOperator && ((MapJoinOperator) op).getConf().isDynamicPartitionHashJoin();
        }

        /**
         * Pins {@code reduceSinkOperator} to FIXED parallelism {@code i}. If its child is a join,
         * all of that join's ReduceSink inputs are pinned instead. Propagation continues upstream
         * through every pinned ReduceSink that is itself forwarding.
         *
         * @param reduceSinkOperator the ReduceSink to start from; {@code null} ends the recursion
         * @param i the reducer count to apply
         */
        private static void propagateMaxNumReducers(ReduceSinkJoinDeDuplicateProcCtx reduceSinkJoinDeDuplicateProcCtx, ReduceSinkOperator reduceSinkOperator, int i) throws SemanticException {
            if (reduceSinkOperator == null) {
                return;
            }
            Operator<? extends OperatorDesc> child = reduceSinkOperator.getChildOperators().get(0);
            if (child instanceof MapJoinOperator || child instanceof CommonMergeJoinOperator) {
                for (Operator<? extends OperatorDesc> parent : child.getParentOperators()) {
                    fixParallelism(reduceSinkJoinDeDuplicateProcCtx, (ReduceSinkOperator) parent, i);
                }
            } else {
                fixParallelism(reduceSinkJoinDeDuplicateProcCtx, reduceSinkOperator, i);
            }
        }

        /** Sets one ReduceSink to FIXED parallelism and recurses upstream if it forwards data. */
        private static void fixParallelism(ReduceSinkJoinDeDuplicateProcCtx dedupCtx, ReduceSinkOperator rsOp, int numReducers) throws SemanticException {
            ReduceSinkDesc conf = rsOp.getConf();
            conf.setReducerTraits(EnumSet.of(ReduceSinkDesc.ReducerTraits.FIXED));
            conf.setNumReducers(numReducers);
            LOG.debug("Set {} to FIXED parallelism: {}", rsOp, numReducers);
            if (conf.isForwarding()) {
                propagateMaxNumReducers(dedupCtx, (ReduceSinkOperator) CorrelationUtilities.findFirstPossibleParent(rsOp, ReduceSinkOperator.class, dedupCtx.trustScript()), numReducers);
            }
        }
    }

    /**
     * Walks the operator graph forward from the table scans and applies {@link ReducerProc}
     * to operators matched by rule R1 (ReduceSink operator name followed by {@code %}).
     *
     * @param parseContext the plan to optimize; it is stored in {@link #pGraphContext}
     * @return the same parse context, with ReduceSink descriptors possibly updated in place
     * @throws SemanticException propagated from the node processors
     */
    @Override
    public ParseContext transform(ParseContext parseContext) throws SemanticException {
        this.pGraphContext = parseContext;
        ReduceSinkJoinDeDuplicateProcCtx dedupCtx = new ReduceSinkJoinDeDuplicateProcCtx(this.pGraphContext);
        LinkedHashMap rules = new LinkedHashMap();
        // Plain "%" literal: the jodd StringPool.PERCENT reference was only a decompiler artifact.
        rules.put(new RuleRegExp("R1", ReduceSinkOperator.getOperatorName() + "%"), ReduceSinkJoinDeDuplicateProcFactory.getReducerMapJoinProc());
        ForwardWalker walker = new ForwardWalker(new DefaultRuleDispatcher(ReduceSinkJoinDeDuplicateProcFactory.getDefaultProc(), rules, dedupCtx));
        List<Node> topNodes = new ArrayList<>(this.pGraphContext.getTopOps().values());
        walker.startWalking(topNodes, null);
        return this.pGraphContext;
    }
}
