package org.apache.flink.runtime.operators;

import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.operators.util.JoinHashMap;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.common.typeutils.TypePairComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.iterative.concurrent.SolutionSetBroker;
import org.apache.flink.runtime.iterative.task.AbstractIterativeTask;
import org.apache.flink.runtime.operators.hash.AbstractHashTableProber;
import org.apache.flink.runtime.operators.hash.CompactingHashTable;
import org.apache.flink.runtime.util.EmptyIterator;
import org.apache.flink.runtime.util.NonReusingKeyGroupedIterator;
import org.apache.flink.runtime.util.ReusingKeyGroupedIterator;
import org.apache.flink.runtime.util.SingleElementIterator;
import org.apache.flink.util.Collector;

/* loaded from: input_file:org/apache/flink/runtime/operators/CoGroupWithSolutionSetSecondDriver.class */
/**
 * Co-group driver whose second input is the solution set of an iteration.
 *
 * <p>The driver has a single physical input (index 0): the probe side, which is grouped by key
 * using the driver comparator. The solution set is not read from an input; instead,
 * {@link #initialize()} fetches it from the {@link SolutionSetBroker} under the iterative task's
 * broker key. There it is either a {@link CompactingHashTable} or a {@link JoinHashMap}.
 *
 * <p>For every probe-side key group, the solution set is probed for the matching record. The user's
 * {@link CoGroupFunction} is then invoked with:
 * <ul>
 *   <li>the group's values, and
 *   <li>a single-element iterable holding that match, or an empty iterable if none was found.
 * </ul>
 *
 * @param <IT1> type of the probe-side (first) input records
 * @param <IT2> type of the solution-set (second) input records
 * @param <OT> type of the records emitted by the user function
 */
public class CoGroupWithSolutionSetSecondDriver<IT1, IT2, OT> implements ResettableDriver<CoGroupFunction<IT1, IT2, OT>, OT> {

    private TaskContext<CoGroupFunction<IT1, IT2, OT>, OT> taskContext;

    /** Solution-set index when the broker handed out a {@link CompactingHashTable}; otherwise null. */
    private CompactingHashTable<IT2> hashTable;

    /** Solution-set index when the broker handed out a {@link JoinHashMap}; otherwise null. */
    private JoinHashMap<IT2> objectMap;

    private TypeSerializer<IT1> probeSideSerializer;
    private TypeComparator<IT1> probeSideComparator;
    private TypeSerializer<IT2> solutionSetSerializer;
    private TypePairComparator<IT1, IT2> pairComparator;

    /** Reuse target for hash-table lookups; only created when object reuse is enabled. */
    private IT2 solutionSideRecord;

    /** Cleared by {@link #cancel()}; checked before each key group in {@link #run()}. */
    protected volatile boolean running;

    private boolean objectReuseEnabled = false;

    @Override // org.apache.flink.runtime.operators.Driver
    public void setup(TaskContext<CoGroupFunction<IT1, IT2, OT>, OT> taskContext) {
        this.taskContext = taskContext;
        this.running = true;
    }

    /** Only the probe side is a physical input; the solution set comes from the broker. */
    @Override // org.apache.flink.runtime.operators.Driver
    public int getNumberOfInputs() {
        return 1;
    }

    @Override // org.apache.flink.runtime.operators.Driver
    public Class<CoGroupFunction<IT1, IT2, OT>> getStubType() {
        // A class literal is always raw (Class<CoGroupFunction>), so it has to be widened to
        // Class<?> before it can be narrowed to the parameterized type.
        @SuppressWarnings("unchecked")
        final Class<CoGroupFunction<IT1, IT2, OT>> stubType =
                (Class<CoGroupFunction<IT1, IT2, OT>>) (Class<?>) CoGroupFunction.class;
        return stubType;
    }

    /** One comparator: the key comparator for the probe-side input. */
    @Override // org.apache.flink.runtime.operators.Driver
    public int getNumberOfDriverComparators() {
        return 1;
    }

    /**
     * Reports that no input is resettable.
     *
     * @throws IndexOutOfBoundsException if {@code i} is outside {@code [0, 1]}
     */
    @Override // org.apache.flink.runtime.operators.ResettableDriver
    public boolean isInputResettable(int i) {
        // NOTE(review): getNumberOfInputs() returns 1, yet index 1 is accepted here as well.
        // The wider range is kept so that existing callers querying index 1 keep working.
        if (i < 0 || i > 1) {
            throw new IndexOutOfBoundsException();
        }
        return false;
    }

    /**
     * Looks up the solution-set index from the broker and builds the serializers and comparators
     * needed by {@link #run()}.
     *
     * @throws Exception if the task context is not an {@link AbstractIterativeTask}
     * @throws RuntimeException if the broker returned neither a {@link CompactingHashTable} nor a
     *     {@link JoinHashMap}
     */
    @Override // org.apache.flink.runtime.operators.ResettableDriver
    public void initialize() throws Exception {
        if (!(this.taskContext instanceof AbstractIterativeTask)) {
            throw new Exception("The task context of this driver is no iterative task context.");
        }

        final String brokerKey = ((AbstractIterativeTask<?, ?>) this.taskContext).brokerKey();
        final Object solutionSetIndex = SolutionSetBroker.instance().get(brokerKey);

        final TypeComparator<IT2> solutionSetComparator;
        if (solutionSetIndex instanceof CompactingHashTable) {
            @SuppressWarnings("unchecked")
            final CompactingHashTable<IT2> table = (CompactingHashTable<IT2>) solutionSetIndex;
            this.hashTable = table;
            this.solutionSetSerializer = table.getBuildSideSerializer();
            // Duplicate so this driver does not share comparator state with the index.
            solutionSetComparator = table.getBuildSideComparator().duplicate();
        } else if (solutionSetIndex instanceof JoinHashMap) {
            @SuppressWarnings("unchecked")
            final JoinHashMap<IT2> map = (JoinHashMap<IT2>) solutionSetIndex;
            this.objectMap = map;
            this.solutionSetSerializer = map.getBuildSerializer();
            solutionSetComparator = map.getBuildComparator().duplicate();
        } else {
            throw new RuntimeException("Unrecognized solution set index: " + solutionSetIndex);
        }

        final TypeComparatorFactory<IT1> probeSideComparatorFactory =
                this.taskContext.getTaskConfig().getDriverComparator(0, this.taskContext.getUserCodeClassLoader());
        this.probeSideSerializer = this.taskContext.<IT1>getInputSerializer(0).getSerializer();
        this.probeSideComparator = probeSideComparatorFactory.createComparator();

        this.objectReuseEnabled = this.taskContext.getExecutionConfig().isObjectReuseEnabled();
        if (this.objectReuseEnabled) {
            this.solutionSideRecord = this.solutionSetSerializer.createInstance2();
        }

        this.pairComparator = this.taskContext.getTaskConfig()
                .<IT1, IT2>getPairComparatorFactory(this.taskContext.getUserCodeClassLoader())
                .createComparator12(this.probeSideComparator, solutionSetComparator);
    }

    @Override // org.apache.flink.runtime.operators.Driver
    public void prepare() {
    }

    /**
     * Groups the probe-side input by key and calls the user function once per key group.
     *
     * <p>There are four execution modes: object reuse on or off, combined with a hash-table or an
     * object-map index. Whenever the match is handed to the user function as a copy, it is produced
     * with the solution-set serializer. A copy is made in every mode except object reuse combined
     * with the hash table, which instead passes {@link #solutionSideRecord} as the lookup target.
     * Processing stops early once {@link #cancel()} clears {@link #running}.
     */
    @Override // org.apache.flink.runtime.operators.Driver
    public void run() throws Exception {
        final CoGroupFunction<IT1, IT2, OT> coGroupStub = this.taskContext.getStub();
        final Collector<OT> collector = this.taskContext.getOutputCollector();

        // Holders for the solution side, reused across all key groups.
        final SingleElementIterator<IT2> matchIterable = new SingleElementIterator<>();
        final EmptyIterator<IT2> noMatchIterable = EmptyIterator.get();

        if (this.objectReuseEnabled) {
            final ReusingKeyGroupedIterator<IT1> probeSide = new ReusingKeyGroupedIterator<>(
                    this.taskContext.<IT1>getInput(0), this.probeSideSerializer, this.probeSideComparator);

            if (this.hashTable != null) {
                final AbstractHashTableProber<IT1, IT2> prober =
                        this.hashTable.getProber(this.probeSideComparator, this.pairComparator);
                // Passed to the prober as the reuse target for every lookup.
                final IT2 reuseTarget = this.solutionSideRecord;
                while (this.running && probeSide.nextKey()) {
                    final IT2 match = prober.getMatchFor(probeSide.getCurrent(), reuseTarget);
                    coGroupWithMatch(coGroupStub, probeSide.getValues(), match, matchIterable, noMatchIterable, collector);
                }
            } else {
                final JoinHashMap<IT2>.Prober<IT1> prober =
                        this.objectMap.createProber(this.probeSideComparator, this.pairComparator);
                final TypeSerializer<IT2> buildSerializer = this.objectMap.getBuildSerializer();
                while (this.running && probeSide.nextKey()) {
                    final IT2 found = prober.lookupMatch(probeSide.getCurrent());
                    // The copy presumably keeps the user function from mutating the stored
                    // solution-set entry.
                    final IT2 match = found == null ? null : buildSerializer.copy(found);
                    coGroupWithMatch(coGroupStub, probeSide.getValues(), match, matchIterable, noMatchIterable, collector);
                }
            }
        } else {
            final NonReusingKeyGroupedIterator<IT1> probeSide = new NonReusingKeyGroupedIterator<>(
                    this.taskContext.<IT1>getInput(0), this.probeSideComparator);

            if (this.hashTable != null) {
                final AbstractHashTableProber<IT1, IT2> prober =
                        this.hashTable.getProber(this.probeSideComparator, this.pairComparator);
                while (this.running && probeSide.nextKey()) {
                    final IT2 found = prober.getMatchFor(probeSide.getCurrent());
                    final IT2 match = found == null ? null : this.solutionSetSerializer.copy(found);
                    coGroupWithMatch(coGroupStub, probeSide.getValues(), match, matchIterable, noMatchIterable, collector);
                }
            } else {
                final JoinHashMap<IT2>.Prober<IT1> prober =
                        this.objectMap.createProber(this.probeSideComparator, this.pairComparator);
                final TypeSerializer<IT2> buildSerializer = this.objectMap.getBuildSerializer();
                while (this.running && probeSide.nextKey()) {
                    final IT2 found = prober.lookupMatch(probeSide.getCurrent());
                    final IT2 match = found == null ? null : buildSerializer.copy(found);
                    coGroupWithMatch(coGroupStub, probeSide.getValues(), match, matchIterable, noMatchIterable, collector);
                }
            }
        }
    }

    /**
     * Invokes the user function for one probe-side key group.
     *
     * <p>If {@code match} is non-null, it is placed in {@code matchIterable} and passed as the
     * solution side. Otherwise {@code noMatchIterable} is passed as the solution side.
     */
    private void coGroupWithMatch(
            CoGroupFunction<IT1, IT2, OT> coGroupStub,
            Iterable<IT1> probeSideGroup,
            IT2 match,
            SingleElementIterator<IT2> matchIterable,
            Iterable<IT2> noMatchIterable,
            Collector<OT> collector) throws Exception {
        if (match != null) {
            matchIterable.set(match);
            coGroupStub.coGroup(probeSideGroup, matchIterable, collector);
        } else {
            coGroupStub.coGroup(probeSideGroup, noMatchIterable, collector);
        }
    }

    @Override // org.apache.flink.runtime.operators.Driver
    public void cleanup() {
    }

    @Override // org.apache.flink.runtime.operators.ResettableDriver
    public void reset() {
    }

    @Override // org.apache.flink.runtime.operators.ResettableDriver
    public void teardown() {
    }

    /** Requests that {@link #run()} stop before processing the next key group. */
    @Override // org.apache.flink.runtime.operators.Driver
    public void cancel() {
        this.running = false;
    }
}
