package org.apache.flink.table.runtime.hashtable;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.disk.ChannelReaderInputViewIterator;
import org.apache.flink.runtime.io.disk.iomanager.HeaderlessChannelReaderInputView;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.operators.util.BitSet;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.data.binary.BinaryRowDataUtil;
import org.apache.flink.table.runtime.generated.JoinCondition;
import org.apache.flink.table.runtime.generated.Projection;
import org.apache.flink.table.runtime.hashtable.BinaryHashPartition;
import org.apache.flink.table.runtime.io.BinaryRowChannelInputViewIterator;
import org.apache.flink.table.runtime.io.ChannelWithMeta;
import org.apache.flink.table.runtime.operators.join.HashJoinType;
import org.apache.flink.table.runtime.operators.join.NullAwareJoinHelper;
import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer;
import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
import org.apache.flink.table.runtime.util.FileChannelUtil;
import org.apache.flink.table.runtime.util.RowIterator;
import org.apache.flink.util.MathUtils;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/table/runtime/hashtable/BinaryHashTable.class */
public class BinaryHashTable extends BaseHybridHashTable {
    final BinaryRowDataSerializer binaryBuildSideSerializer;
    private final AbstractRowDataSerializer originBuildSideSerializer;
    private final BinaryRowDataSerializer binaryProbeSideSerializer;
    private final AbstractRowDataSerializer originProbeSideSerializer;
    private final Projection<RowData, BinaryRowData> buildSideProjection;
    private final Projection<RowData, BinaryRowData> probeSideProjection;
    final int bucketsPerSegment;
    final int bucketsPerSegmentMask;
    final int bucketsPerSegmentBits;
    final boolean useBloomFilters;
    final ArrayList<BinaryHashPartition> partitionsBeingBuilt;
    final BitSet probedSet;
    private final ArrayList<BinaryHashPartition> partitionsPending;
    private final JoinCondition condFunc;
    private final boolean reverseJoin;
    private final int[] nullFilterKeys;
    private final boolean nullSafe;
    private final boolean filterAllNulls;
    LookupBucketIterator bucketIterator;
    private ProbeIterator probeIterator;
    final HashJoinType type;
    private RowIterator<BinaryRowData> buildIterator;
    private boolean probeMatchedPhase;
    private boolean buildIterVisited;
    private BinaryRowData probeKey;
    private RowData probeRow;
    BinaryRowData reuseBuildRow;

    public BinaryHashTable(Configuration configuration, Object obj, AbstractRowDataSerializer abstractRowDataSerializer, AbstractRowDataSerializer abstractRowDataSerializer2, Projection<RowData, BinaryRowData> projection, Projection<RowData, BinaryRowData> projection2, MemoryManager memoryManager, long j, IOManager iOManager, int i, long j2, boolean z, HashJoinType hashJoinType, JoinCondition joinCondition, boolean z2, boolean[] zArr, boolean z3) {
        super(configuration, obj, memoryManager, j, iOManager, i, j2, !hashJoinType.buildLeftSemiOrAnti() && z3);
        this.probedSet = new BitSet(2);
        this.probeMatchedPhase = true;
        this.buildIterVisited = false;
        this.originBuildSideSerializer = abstractRowDataSerializer;
        this.binaryBuildSideSerializer = new BinaryRowDataSerializer(abstractRowDataSerializer.getArity());
        this.reuseBuildRow = this.binaryBuildSideSerializer.m6245createInstance();
        this.originProbeSideSerializer = abstractRowDataSerializer2;
        this.binaryProbeSideSerializer = new BinaryRowDataSerializer(this.originProbeSideSerializer.getArity());
        this.buildSideProjection = projection;
        this.probeSideProjection = projection2;
        this.useBloomFilters = z;
        this.type = hashJoinType;
        this.condFunc = joinCondition;
        this.reverseJoin = z2;
        this.nullFilterKeys = NullAwareJoinHelper.getNullFilterKeys(zArr);
        this.nullSafe = this.nullFilterKeys.length == 0;
        this.filterAllNulls = this.nullFilterKeys.length == zArr.length;
        this.bucketsPerSegment = this.segmentSize >> 7;
        Preconditions.checkArgument(this.bucketsPerSegment != 0, "Hash Table requires buffers of at least 128 bytes.");
        this.bucketsPerSegmentMask = this.bucketsPerSegment - 1;
        this.bucketsPerSegmentBits = MathUtils.log2strict(this.bucketsPerSegment);
        this.partitionsBeingBuilt = new ArrayList<>();
        this.partitionsPending = new ArrayList<>();
        createPartitions(this.initPartitionFanOut, 0);
    }

    public void putBuildRow(RowData rowData) throws IOException {
        insertIntoTable(this.originBuildSideSerializer.toBinaryRow(rowData), hash(this.buildSideProjection.apply(rowData).hashCode(), 0));
    }

    public void endBuild() throws IOException {
        int i = 0;
        Iterator<BinaryHashPartition> it = this.partitionsBeingBuilt.iterator();
        while (it.hasNext()) {
            i += it.next().finalizeBuildPhase(this.ioManager, this.currentEnumerator);
        }
        this.buildSpillRetBufferNumbers += i;
        this.probeIterator = new ProbeIterator(this.binaryProbeSideSerializer.m6245createInstance());
        this.bucketIterator = new LookupBucketIterator(this);
    }

    public boolean tryProbe(RowData rowData) throws IOException {
        if (!this.probeIterator.hasSource()) {
            this.probeIterator.setInstance(rowData);
        }
        BinaryRowData apply = this.probeSideProjection.apply(rowData);
        int hash = hash(apply.hashCode(), this.currentRecursionDepth);
        BinaryHashPartition binaryHashPartition = this.partitionsBeingBuilt.get(hash % this.partitionsBeingBuilt.size());
        if (binaryHashPartition.isInMemory()) {
            this.probeKey = apply;
            this.probeRow = rowData;
            binaryHashPartition.bucketArea.startLookup(hash);
            return true;
        }
        if (!binaryHashPartition.testHashBloomFilter(hash)) {
            return false;
        }
        binaryHashPartition.insertIntoProbeBuffer(this.originProbeSideSerializer.toBinaryRow(rowData));
        return false;
    }

    public boolean nextMatching() throws IOException {
        return this.type.needSetProbed() ? processProbeIter() || processBuildIter() || prepareNextPartition() : processProbeIter() || prepareNextPartition();
    }

    public RowData getCurrentProbeRow() {
        if (this.probeMatchedPhase) {
            return this.probeIterator.current();
        }
        return null;
    }

    public RowIterator<BinaryRowData> getBuildSideIterator() {
        return this.probeMatchedPhase ? this.bucketIterator : this.buildIterator;
    }

    @VisibleForTesting
    static int getNumWriteBehindBuffers(int i) {
        int log = (int) ((Math.log(i) / Math.log(4.0d)) - 1.5d);
        if (log > 6) {
            return 6;
        }
        return log;
    }

    private boolean processProbeIter() throws IOException {
        if (!this.probeIterator.hasSource()) {
            return false;
        }
        ProbeIterator probeIterator = this.probeIterator;
        if (!this.probeMatchedPhase) {
            return false;
        }
        while (true) {
            BinaryRowData next = probeIterator.next();
            if (next == null) {
                return false;
            }
            BinaryRowData apply = this.probeSideProjection.apply(next);
            int hash = hash(apply.hashCode(), this.currentRecursionDepth);
            BinaryHashPartition binaryHashPartition = this.partitionsBeingBuilt.get(hash % this.partitionsBeingBuilt.size());
            if (binaryHashPartition.isInMemory()) {
                this.probeKey = apply;
                this.probeRow = next;
                binaryHashPartition.bucketArea.startLookup(hash);
                return true;
            }
            binaryHashPartition.insertIntoProbeBuffer(next);
        }
    }

    private boolean processBuildIter() throws IOException {
        if (this.buildIterVisited) {
            return false;
        }
        this.probeMatchedPhase = false;
        this.buildIterator = new BuildSideIterator(this.binaryBuildSideSerializer, this.reuseBuildRow, this.partitionsBeingBuilt, this.probedSet, this.type.equals(HashJoinType.BUILD_LEFT_SEMI));
        this.buildIterVisited = true;
        return true;
    }

    private boolean prepareNextPartition() throws IOException {
        Iterator<BinaryHashPartition> it = this.partitionsBeingBuilt.iterator();
        while (it.hasNext()) {
            it.next().finalizeProbePhase(this.internalPool, this.partitionsPending, this.type.needSetProbed());
        }
        this.partitionsBeingBuilt.clear();
        if (this.currentSpilledBuildSide != null) {
            this.currentSpilledBuildSide.getChannel().closeAndDelete();
            this.currentSpilledBuildSide = null;
        }
        if (this.currentSpilledProbeSide != null) {
            this.currentSpilledProbeSide.getChannel().closeAndDelete();
            this.currentSpilledProbeSide = null;
        }
        if (this.partitionsPending.isEmpty()) {
            return false;
        }
        BinaryHashPartition binaryHashPartition = this.partitionsPending.get(0);
        LOG.info(String.format("Begin to process spilled partition [%d]", Integer.valueOf(binaryHashPartition.getPartitionNumber())));
        if (binaryHashPartition.probeSideRecordCounter == 0) {
            this.currentSpilledBuildSide = createInputView(binaryHashPartition.getBuildSideChannel().getChannelID(), binaryHashPartition.getBuildSideBlockCount(), binaryHashPartition.getLastSegmentLimit());
            this.buildIterator = new WrappedRowIterator(new BinaryRowChannelInputViewIterator(this.currentSpilledBuildSide, this.binaryBuildSideSerializer), this.binaryBuildSideSerializer.m6245createInstance());
            this.partitionsPending.remove(0);
            return true;
        }
        this.probeMatchedPhase = true;
        this.buildIterVisited = false;
        buildTableFromSpilledPartition(binaryHashPartition);
        this.currentSpilledProbeSide = FileChannelUtil.createInputView(this.ioManager, new ChannelWithMeta(binaryHashPartition.probeSideBuffer.getChannel().getChannelID(), binaryHashPartition.probeSideBuffer.getBlockCount(), binaryHashPartition.probeNumBytesInLastSeg), new ArrayList(), this.compressionEnable, this.compressionCodecFactory, this.compressionBlockSize, this.segmentSize);
        this.probeIterator.set(new ChannelReaderInputViewIterator<>(this.currentSpilledProbeSide, new ArrayList(), this.binaryProbeSideSerializer));
        this.probeIterator.setReuse(this.binaryProbeSideSerializer.m6245createInstance());
        this.partitionsPending.remove(0);
        this.currentRecursionDepth = binaryHashPartition.getRecursionLevel() + 1;
        return nextMatching();
    }

    private void buildTableFromSpilledPartition(BinaryHashPartition binaryHashPartition) throws IOException {
        int recursionLevel = binaryHashPartition.getRecursionLevel() + 1;
        if (recursionLevel == 2) {
            LOG.info("Recursive hash join: partition number is " + binaryHashPartition.getPartitionNumber());
        } else if (recursionLevel > 3) {
            throw new RuntimeException("Hash join exceeded maximum number of recursions, without reducing partitions enough to be memory resident. Probably cause: Too many duplicate keys.");
        }
        if (binaryHashPartition.getBuildSideBlockCount() > binaryHashPartition.getProbeSideBlockCount()) {
            LOG.info(String.format("Hash join: Partition(%d) build side block [%d] more than probe side block [%d]", Integer.valueOf(binaryHashPartition.getPartitionNumber()), Integer.valueOf(binaryHashPartition.getBuildSideBlockCount()), Integer.valueOf(binaryHashPartition.getProbeSideBlockCount())));
        }
        int freePages = this.internalPool.freePages() + this.buildSpillRetBufferNumbers;
        if (freePages != this.totalNumBuffers) {
            throw new RuntimeException(String.format("Hash Join bug in memory management: Memory buffers leaked. availableMemory(%s), buildSpillRetBufferNumbers(%s), reservedNumBuffers(%s)", Integer.valueOf(this.internalPool.freePages()), Integer.valueOf(this.buildSpillRetBufferNumbers), Integer.valueOf(this.totalNumBuffers)));
        }
        int max = Math.max((int) (2 * (((binaryHashPartition.getBuildSideRecordCount() / 15) + 1) / (this.bucketsPerSegmentMask + 1))), 1);
        long buildSideBlockCount = max + binaryHashPartition.getBuildSideBlockCount() + 2;
        if (buildSideBlockCount < freePages) {
            LOG.info(String.format("Build in memory hash table from spilled partition [%d]", Integer.valueOf(binaryHashPartition.getPartitionNumber())));
            List<MemorySegment> readAllBuffers = readAllBuffers(binaryHashPartition.getBuildSideChannel().getChannelID(), binaryHashPartition.getBuildSideBlockCount());
            BinaryHashBucketArea binaryHashBucketArea = new BinaryHashBucketArea(this, (int) binaryHashPartition.getBuildSideRecordCount(), max, false);
            BinaryHashPartition binaryHashPartition2 = new BinaryHashPartition(binaryHashBucketArea, this.binaryBuildSideSerializer, this.binaryProbeSideSerializer, 0, recursionLevel, readAllBuffers, binaryHashPartition.getBuildSideRecordCount(), this.segmentSize, binaryHashPartition.getLastSegmentLimit());
            binaryHashBucketArea.setPartition(binaryHashPartition2);
            this.partitionsBeingBuilt.add(binaryHashPartition2);
            BinaryHashPartition.PartitionIterator newPartitionIterator = binaryHashPartition2.newPartitionIterator();
            while (newPartitionIterator.advanceNext()) {
                binaryHashBucketArea.insertToBucket(hash(this.buildSideProjection.apply(newPartitionIterator.getRow()).hashCode(), recursionLevel), (int) newPartitionIterator.getPointer(), true);
            }
            return;
        }
        createPartitions(Math.min(Math.min(10 * (((int) (buildSideBlockCount / freePages)) + 1), 127), maxNumPartition()), recursionLevel);
        LOG.info(String.format("Build hybrid hash table from spilled partition [%d] with recursion level [%d]", Integer.valueOf(binaryHashPartition.getPartitionNumber()), Integer.valueOf(recursionLevel)));
        HeaderlessChannelReaderInputView createInputView = createInputView(binaryHashPartition.getBuildSideChannel().getChannelID(), binaryHashPartition.getBuildSideBlockCount(), binaryHashPartition.getLastSegmentLimit());
        BinaryRowChannelInputViewIterator binaryRowChannelInputViewIterator = new BinaryRowChannelInputViewIterator(createInputView, this.binaryBuildSideSerializer);
        BinaryRowData m6245createInstance = this.binaryBuildSideSerializer.m6245createInstance();
        while (true) {
            BinaryRowData next = binaryRowChannelInputViewIterator.next(m6245createInstance);
            m6245createInstance = next;
            if (next == null) {
                break;
            } else {
                insertIntoTable(m6245createInstance, hash(this.buildSideProjection.apply(m6245createInstance).hashCode(), recursionLevel));
            }
        }
        createInputView.getChannel().closeAndDelete();
        int i = 0;
        Iterator<BinaryHashPartition> it = this.partitionsBeingBuilt.iterator();
        while (it.hasNext()) {
            i += it.next().finalizeBuildPhase(this.ioManager, this.currentEnumerator);
        }
        this.buildSpillRetBufferNumbers += i;
    }

    private void insertIntoTable(BinaryRowData binaryRowData, int i) throws IOException {
        BinaryHashPartition binaryHashPartition = this.partitionsBeingBuilt.get(i % this.partitionsBeingBuilt.size());
        if (!binaryHashPartition.isInMemory()) {
            binaryHashPartition.insertIntoBuildBuffer(binaryRowData);
            binaryHashPartition.addHashBloomFilter(i);
        } else {
            if (binaryHashPartition.bucketArea.appendRecordAndInsert(binaryRowData, i)) {
                return;
            }
            binaryHashPartition.addHashBloomFilter(i);
        }
    }

    private void createPartitions(int i, int i2) {
        ensureNumBuffersReturned(i);
        this.currentEnumerator = this.ioManager.createChannelEnumerator();
        this.partitionsBeingBuilt.clear();
        double d = this.buildRowCount / i;
        int maxInitBufferOfBucketArea = maxInitBufferOfBucketArea(i);
        for (int i3 = 0; i3 < i; i3++) {
            BinaryHashBucketArea binaryHashBucketArea = new BinaryHashBucketArea(this, d, maxInitBufferOfBucketArea);
            BinaryHashPartition binaryHashPartition = new BinaryHashPartition(binaryHashBucketArea, this.binaryBuildSideSerializer, this.binaryProbeSideSerializer, i3, i2, getNotNullNextBuffer(), this, this.segmentSize, this.compressionEnable, this.compressionCodecFactory, this.compressionBlockSize);
            binaryHashBucketArea.setPartition(binaryHashPartition);
            this.partitionsBeingBuilt.add(binaryHashPartition);
        }
    }

    @Override // org.apache.flink.table.runtime.hashtable.BaseHybridHashTable
    public void clearPartitions() {
        this.bucketIterator = null;
        this.probeIterator = null;
        for (int size = this.partitionsBeingBuilt.size() - 1; size >= 0; size--) {
            try {
                this.partitionsBeingBuilt.get(size).clearAllMemory(this.internalPool);
            } catch (Exception e) {
                LOG.error("Error during partition cleanup.", e);
            }
        }
        this.partitionsBeingBuilt.clear();
        Iterator<BinaryHashPartition> it = this.partitionsPending.iterator();
        while (it.hasNext()) {
            it.next().clearAllMemory(this.internalPool);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    @Override // org.apache.flink.table.runtime.hashtable.BaseHybridHashTable
    public int spillPartition() throws IOException {
        MemorySegment poll;
        int i = 0;
        int i2 = -1;
        for (int i3 = 0; i3 < this.partitionsBeingBuilt.size(); i3++) {
            BinaryHashPartition binaryHashPartition = this.partitionsBeingBuilt.get(i3);
            if (binaryHashPartition.isInMemory() && binaryHashPartition.getNumOccupiedMemorySegments() > i) {
                i = binaryHashPartition.getNumOccupiedMemorySegments();
                i2 = i3;
            }
        }
        BinaryHashPartition binaryHashPartition2 = this.partitionsBeingBuilt.get(i2);
        int spillPartition = binaryHashPartition2.spillPartition(this.ioManager, this.currentEnumerator.next(), this.buildSpillReturnBuffers);
        this.buildSpillRetBufferNumbers += spillPartition;
        LOG.info(String.format("Grace hash join: Ran out memory, choosing partition [%d] to spill, %d memory segments being freed", Integer.valueOf(i2), Integer.valueOf(spillPartition)));
        while (this.buildSpillRetBufferNumbers > 0 && (poll = this.buildSpillReturnBuffers.poll()) != null) {
            returnPage(poll);
            this.buildSpillRetBufferNumbers--;
        }
        this.numSpillFiles++;
        this.spillInBytes += spillPartition * this.segmentSize;
        binaryHashPartition2.buildBloomFilterAndFreeBucket();
        return i2;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* JADX WARN: Code restructure failed: missing block: B:14:0x0068, code lost:
    
        r0 = true;
     */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    public boolean applyCondition(org.apache.flink.table.data.binary.BinaryRowData r5) {
        /*
            r4 = this;
            r0 = r4
            org.apache.flink.table.runtime.generated.Projection<org.apache.flink.table.data.RowData, org.apache.flink.table.data.binary.BinaryRowData> r0 = r0.buildSideProjection
            r1 = r5
            org.apache.flink.table.data.RowData r0 = r0.apply(r1)
            org.apache.flink.table.data.binary.BinaryRowData r0 = (org.apache.flink.table.data.binary.BinaryRowData) r0
            r6 = r0
            r0 = r6
            int r0 = r0.getSizeInBytes()
            r1 = r4
            org.apache.flink.table.data.binary.BinaryRowData r1 = r1.probeKey
            int r1 = r1.getSizeInBytes()
            if (r0 != r1) goto L3f
            r0 = r6
            org.apache.flink.core.memory.MemorySegment[] r0 = r0.getSegments()
            r1 = 0
            r0 = r0[r1]
            byte[] r0 = r0.getHeapMemory()
            r1 = r4
            org.apache.flink.table.data.binary.BinaryRowData r1 = r1.probeKey
            org.apache.flink.core.memory.MemorySegment[] r1 = r1.getSegments()
            r2 = 0
            r1 = r1[r2]
            byte[] r1 = r1.getHeapMemory()
            r2 = r6
            int r2 = r2.getSizeInBytes()
            boolean r0 = org.apache.flink.table.data.binary.BinaryRowDataUtil.byteArrayEquals(r0, r1, r2)
            if (r0 == 0) goto L3f
            r0 = 1
            goto L40
        L3f:
            r0 = 0
        L40:
            r7 = r0
            r0 = r4
            boolean r0 = r0.nullSafe
            if (r0 != 0) goto L6e
            r0 = r7
            if (r0 == 0) goto L6c
            r0 = r4
            boolean r0 = r0.filterAllNulls
            if (r0 == 0) goto L5d
            r0 = r6
            boolean r0 = r0.anyNull()
            if (r0 == 0) goto L68
            goto L6c
        L5d:
            r0 = r6
            r1 = r4
            int[] r1 = r1.nullFilterKeys
            boolean r0 = r0.anyNull(r1)
            if (r0 != 0) goto L6c
        L68:
            r0 = 1
            goto L6d
        L6c:
            r0 = 0
        L6d:
            r7 = r0
        L6e:
            r0 = r4
            org.apache.flink.table.runtime.generated.JoinCondition r0 = r0.condFunc
            if (r0 != 0) goto L79
            r0 = r7
            goto Lae
        L79:
            r0 = r7
            if (r0 == 0) goto Lad
            r0 = r4
            boolean r0 = r0.reverseJoin
            if (r0 == 0) goto L98
            r0 = r4
            org.apache.flink.table.runtime.generated.JoinCondition r0 = r0.condFunc
            r1 = r4
            org.apache.flink.table.data.RowData r1 = r1.probeRow
            r2 = r5
            boolean r0 = r0.apply(r1, r2)
            if (r0 == 0) goto Lad
            goto La9
        L98:
            r0 = r4
            org.apache.flink.table.runtime.generated.JoinCondition r0 = r0.condFunc
            r1 = r5
            r2 = r4
            org.apache.flink.table.data.RowData r2 = r2.probeRow
            boolean r0 = r0.apply(r1, r2)
            if (r0 == 0) goto Lad
        La9:
            r0 = 1
            goto Lae
        Lad:
            r0 = 0
        Lae:
            return r0
        */
        throw new UnsupportedOperationException("Method not decompiled: org.apache.flink.table.runtime.hashtable.BinaryHashTable.applyCondition(org.apache.flink.table.data.binary.BinaryRowData):boolean");
    }
}
