package org.apache.flink.test.checkpointing;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.StreamSupport;
import org.apache.commons.collections.CollectionUtils;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.HighAvailabilityOptions;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.runtime.checkpoint.CheckpointsCleaner;
import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
import org.apache.flink.runtime.checkpoint.PerJobCheckpointRecoveryFactory;
import org.apache.flink.runtime.checkpoint.StandaloneCompletedCheckpointStore;
import org.apache.flink.runtime.highavailability.HighAvailabilityServices;
import org.apache.flink.runtime.highavailability.HighAvailabilityServicesFactory;
import org.apache.flink.runtime.highavailability.nonha.embedded.EmbeddedHaServicesWithLeadershipControl;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.streaming.api.CheckpointingMode;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStreamUtils;
import org.apache.flink.streaming.api.environment.CheckpointConfig;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.test.util.TestUtils;
import org.apache.flink.util.TestLogger;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/flink/test/checkpointing/RegionFailoverITCase.class */
public class RegionFailoverITCase extends TestLogger {
    // ---------------------------------------------------------------------------------------
    // Job shape. Several of these values also appear as plain literals (3, 10000) further
    // down in this file; presumably the compiler inlined the constants there -- TODO confirm.
    // ---------------------------------------------------------------------------------------

    /** FailingMapperFunction starts failing once a value exceeds FAIL_BASE * (failures so far + 1). */
    private static final int FAIL_BASE = 1000;

    private static final int NUM_OF_REGIONS = 3;

    /** Max parallelism used for the environment and for the sources' key pre-partitioning. */
    private static final int MAX_PARALLELISM = 6;

    private static final int NUM_OF_RESTARTS = 3;

    private static final int NUM_ELEMENTS = 10000;

    /** StringGeneratingSourceFunction checks its task name for this value to pick the expected union state. */
    private static final String SINGLE_REGION_SOURCE_NAME = "single-source";

    private static final String MULTI_REGION_SOURCE_NAME = "multi-source";

    /** Cluster created in setup() and torn down in shutDownExistingCluster(). */
    private static MiniClusterWithClientResource cluster;

    /** Subtask indices expected in the restored union state of the multi-region source: {0, 1, 2}. */
    private static final Set<Integer> EXPECTED_INDICES_MULTI_REGION =
            IntStream.range(0, NUM_OF_REGIONS).boxed().collect(Collectors.toSet());

    /** Subtask indices expected in the restored union state of the single-region source: {0}. */
    private static final Set<Integer> EXPECTED_INDICES_SINGLE_REGION = Collections.singleton(0);

    // ---------------------------------------------------------------------------------------
    // Static state shared between the test method and the user functions of the job.
    // ---------------------------------------------------------------------------------------

    /** Id of the checkpoint most recently passed to TestingCompletedCheckpointStore#addCheckpoint. */
    private static AtomicLong lastCompletedCheckpointId = new AtomicLong(0);

    /** Number of checkpoints added to the store; the sources throttle until it reaches 3. */
    private static AtomicInteger numCompletedCheckpoints = new AtomicInteger(0);

    /** Number of failures injected so far by FailingMapperFunction. */
    private static AtomicInteger jobFailedCnt = new AtomicInteger(0);

    /** Checkpoint id -> source index snapshotted by subtask 2 for that checkpoint. */
    private static Map<Long, Integer> snapshotIndicesOfSubTask = new HashMap<>();

    /** Set to true as soon as any source instance is initialized from restored state. */
    private static boolean restoredState = false;

    @ClassRule
    public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/test/checkpointing/RegionFailoverITCase$FailingMapperFunction.class */
    public static class FailingMapperFunction extends RichMapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
        private final int restartTimes;
        private ValueState<Integer> valueState;

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.valueState = getRuntimeContext().getState(new ValueStateDescriptor("value", Integer.class));
        }

        FailingMapperFunction(int i) {
            this.restartTimes = i;
        }

        public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> tuple2) throws Exception {
            int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
            if (((Integer) tuple2.f1).intValue() > RegionFailoverITCase.FAIL_BASE * (RegionFailoverITCase.jobFailedCnt.get() + 1)) {
                if (RegionFailoverITCase.jobFailedCnt.get() < 1 && indexOfThisSubtask == 0) {
                    RegionFailoverITCase.jobFailedCnt.incrementAndGet();
                    throw new TestException();
                }
                if (RegionFailoverITCase.jobFailedCnt.get() < this.restartTimes && indexOfThisSubtask == 2) {
                    RegionFailoverITCase.jobFailedCnt.incrementAndGet();
                    throw new TestException();
                }
            }
            Integer num = (Integer) this.valueState.value();
            if (num != null) {
                return Tuple2.of(tuple2.f0, Integer.valueOf(num.intValue() + ((Integer) tuple2.f1).intValue()));
            }
            this.valueState.update(tuple2.f1);
            return tuple2;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/test/checkpointing/RegionFailoverITCase$StringGeneratingSourceFunction.class */
    public static class StringGeneratingSourceFunction extends RichParallelSourceFunction<Tuple2<Integer, Integer>> implements CheckpointedFunction {
        private static final long serialVersionUID = 1;
        private final long numElements;
        private final long checkpointLatestAt;
        private int index = -1;
        private int lastRegionIndex = -1;
        private volatile boolean isRunning = true;
        private ListState<Integer> listState;
        private ListState<Integer> unionListState;
        private static final ListStateDescriptor<Integer> stateDescriptor = new ListStateDescriptor<>("list-1", Integer.class);
        private static final ListStateDescriptor<Integer> unionStateDescriptor = new ListStateDescriptor<>("list-2", Integer.class);

        StringGeneratingSourceFunction(long j, long j2) {
            this.numElements = j;
            this.checkpointLatestAt = j2;
        }

        public void run(SourceFunction.SourceContext<Tuple2<Integer, Integer>> sourceContext) throws Exception {
            if (this.index < 0) {
                this.index = 0;
            }
            int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
            while (this.isRunning && this.index < this.numElements) {
                synchronized (sourceContext.getCheckpointLock()) {
                    int i = this.index / 2;
                    if (KeyGroupRangeAssignment.assignKeyToParallelOperator(Integer.valueOf(i), RegionFailoverITCase.MAX_PARALLELISM, 3) == indexOfThisSubtask) {
                        sourceContext.collect(Tuple2.of(Integer.valueOf(i), Integer.valueOf(this.index)));
                    }
                    this.index++;
                }
                if (RegionFailoverITCase.numCompletedCheckpoints.get() < 3) {
                    if (this.index < this.checkpointLatestAt) {
                        Thread.sleep(1L);
                    } else {
                        while (this.isRunning && RegionFailoverITCase.numCompletedCheckpoints.get() < 3) {
                            Thread.sleep(300L);
                        }
                    }
                }
                if (RegionFailoverITCase.jobFailedCnt.get() < 3) {
                    Thread.sleep(1L);
                }
            }
        }

        public void cancel() {
            this.isRunning = false;
        }

        public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
            int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
            if (indexOfThisSubtask != 0) {
                this.listState.clear();
                this.listState.add(Integer.valueOf(this.index));
                if (indexOfThisSubtask == 2) {
                    this.lastRegionIndex = this.index;
                    RegionFailoverITCase.snapshotIndicesOfSubTask.put(Long.valueOf(functionSnapshotContext.getCheckpointId()), Integer.valueOf(this.lastRegionIndex));
                }
            }
            this.unionListState.clear();
            this.unionListState.add(Integer.valueOf(indexOfThisSubtask));
        }

        public void initializeState(FunctionInitializationContext functionInitializationContext) throws Exception {
            int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
            if (!functionInitializationContext.isRestored()) {
                this.unionListState = functionInitializationContext.getOperatorStateStore().getUnionListState(unionStateDescriptor);
                if (indexOfThisSubtask != 0) {
                    this.listState = functionInitializationContext.getOperatorStateStore().getListState(stateDescriptor);
                    return;
                }
                return;
            }
            boolean unused = RegionFailoverITCase.restoredState = true;
            this.unionListState = functionInitializationContext.getOperatorStateStore().getUnionListState(unionStateDescriptor);
            Set set = (Set) StreamSupport.stream(((Iterable) this.unionListState.get()).spliterator(), false).collect(Collectors.toSet());
            if (getRuntimeContext().getTaskName().contains(RegionFailoverITCase.SINGLE_REGION_SOURCE_NAME)) {
                Assert.assertTrue(CollectionUtils.isEqualCollection(RegionFailoverITCase.EXPECTED_INDICES_SINGLE_REGION, set));
            } else {
                Assert.assertTrue(CollectionUtils.isEqualCollection(RegionFailoverITCase.EXPECTED_INDICES_MULTI_REGION, set));
            }
            if (indexOfThisSubtask == 0) {
                this.listState = functionInitializationContext.getOperatorStateStore().getListState(stateDescriptor);
                Assert.assertTrue("list state should be empty for subtask-0", ((List) this.listState.get()).isEmpty());
                return;
            }
            this.listState = functionInitializationContext.getOperatorStateStore().getListState(stateDescriptor);
            Assert.assertTrue("list state should not be empty for subtask-" + indexOfThisSubtask, ((List) this.listState.get()).size() > 0);
            if (indexOfThisSubtask == 2) {
                this.index = ((Integer) ((Iterable) this.listState.get()).iterator().next()).intValue();
                if (this.index != ((Integer) RegionFailoverITCase.snapshotIndicesOfSubTask.get(Long.valueOf(RegionFailoverITCase.lastCompletedCheckpointId.get()))).intValue()) {
                    throw new RuntimeException("Test failed due to unexpected recovered index: " + this.index + ", while last completed checkpoint record index: " + RegionFailoverITCase.snapshotIndicesOfSubTask.get(Long.valueOf(RegionFailoverITCase.lastCompletedCheckpointId.get())));
                }
            }
        }
    }

    /**
     * Exception thrown by {@link FailingMapperFunction} to fail a task on purpose.
     *
     * <p>The constructor is private, so only code inside this test class can create it.
     */
    public static class TestException extends IOException {
        private static final long serialVersionUID = 1L;

        private TestException() {
            super();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/test/checkpointing/RegionFailoverITCase$TestingCompletedCheckpointStore.class */
    public static class TestingCompletedCheckpointStore extends StandaloneCompletedCheckpointStore {
        TestingCompletedCheckpointStore() {
            super(1);
        }

        public void addCheckpoint(CompletedCheckpoint completedCheckpoint, CheckpointsCleaner checkpointsCleaner, Runnable runnable) throws Exception {
            super.addCheckpoint(completedCheckpoint, checkpointsCleaner, runnable);
            RegionFailoverITCase.lastCompletedCheckpointId.set(completedCheckpoint.getCheckpointID());
            RegionFailoverITCase.numCompletedCheckpoints.incrementAndGet();
        }
    }

    /* loaded from: input_file:org/apache/flink/test/checkpointing/RegionFailoverITCase$TestingHAFactory.class */
    public static class TestingHAFactory implements HighAvailabilityServicesFactory {
        public HighAvailabilityServices createHAServices(Configuration configuration, Executor executor) {
            return new EmbeddedHaServicesWithLeadershipControl(executor, PerJobCheckpointRecoveryFactory.withoutCheckpointStoreRecovery(i -> {
                return new TestingCompletedCheckpointStore();
            }));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/test/checkpointing/RegionFailoverITCase$ValidatingSink.class */
    public static class ValidatingSink extends RichSinkFunction<Tuple2<Integer, Integer>> implements ListCheckpointed<HashMap<Integer, Integer>> {
        private static Map<Integer, Integer>[] maps = new Map[3];
        private HashMap<Integer, Integer> counts;

        private ValidatingSink() {
            this.counts = new HashMap<>();
        }

        /* JADX WARN: Multi-variable type inference failed */
        public void invoke(Tuple2<Integer, Integer> tuple2) {
            this.counts.merge(tuple2.f0, tuple2.f1, (v0, v1) -> {
                return Math.max(v0, v1);
            });
        }

        public void close() throws Exception {
            maps[getRuntimeContext().getIndexOfThisSubtask()] = this.counts;
        }

        public List<HashMap<Integer, Integer>> snapshotState(long j, long j2) throws Exception {
            return Collections.singletonList(this.counts);
        }

        public void restoreState(List<HashMap<Integer, Integer>> list) throws Exception {
            if (list.size() != 1) {
                throw new RuntimeException("Test failed due to unexpected recovered state size " + list.size());
            }
            this.counts.putAll(list.get(0));
        }
    }

    /**
     * Starts a fresh mini cluster with region failover and the testing HA services, and resets
     * the static state shared with the job's user functions.
     *
     * <p>This runs before every test while the teardown is only {@code @AfterClass}, so a
     * cluster left over from a previous test method is shut down first instead of being leaked.
     *
     * @throws Exception if the cluster cannot be started
     */
    @Before
    public void setup() throws Exception {
        // No-op on the first invocation, prevents a leak on any later one.
        shutDownExistingCluster();

        Configuration configuration = new Configuration();
        configuration.setString(JobManagerOptions.EXECUTION_FAILOVER_STRATEGY, "region");
        configuration.setString(HighAvailabilityOptions.HA_MODE, TestingHAFactory.class.getName());

        cluster =
                new MiniClusterWithClientResource(
                        new MiniClusterResourceConfiguration.Builder()
                                .setConfiguration(configuration)
                                .setNumberTaskManagers(2)
                                .setNumberSlotsPerTaskManager(2)
                                .build());
        cluster.before();

        // Reset everything the job communicates through, so that no test can observe values
        // written by an earlier run.
        jobFailedCnt.set(0);
        numCompletedCheckpoints.set(0);
        lastCompletedCheckpointId.set(0);
        snapshotIndicesOfSubTask.clear();
        restoredState = false;
    }

    /** Stops the mini cluster, if one is running, and clears the static reference to it. */
    @AfterClass
    public static void shutDownExistingCluster() {
        if (cluster == null) {
            // Nothing was started (or it has already been stopped).
            return;
        }
        cluster.after();
        cluster = null;
    }

    /**
     * Runs the job built by {@code createJobGraph()} to completion and then validates the sink
     * output and that state was restored at least once.
     *
     * <p>Exceptions are propagated to JUnit rather than being converted with
     * {@code Assert.fail(e.getMessage())}: that conversion dropped the cause and stack trace
     * from the test report and produced an empty failure for exceptions without a message.
     *
     * @throws Exception if job submission or execution fails
     */
    @Test(timeout = 60000)
    public void testMultiRegionFailover() throws Exception {
        TestUtils.submitJobAndWaitForResult(
                cluster.getClusterClient(), createJobGraph(), getClass().getClassLoader());
        verifyAfterJobExecuted();
    }

    /**
     * Checks the results published by the {@code ValidatingSink} subtasks.
     *
     * <p>The sources emit {@code (k, 2k)} and {@code (k, 2k + 1)} for every key {@code k}. The
     * mapper forwards the first value and emits {@code first + second} for the second one, and
     * the sink keeps the maximum, so every key must end up with {@code 4k + 1}. As each key
     * covers two source indices, {@code NUM_ELEMENTS / 2} keys are expected in total.
     */
    private void verifyAfterJobExecuted() {
        Assert.assertTrue("The test multi-region job has never ever restored state.", restoredState);

        int keyCount = 0;
        for (Map<Integer, Integer> subtaskCounts : ValidatingSink.maps) {
            // A slot stays null when the corresponding sink subtask never reached close().
            Assert.assertNotNull("A sink subtask did not publish its results.", subtaskCounts);
            for (Map.Entry<Integer, Integer> entry : subtaskCounts.entrySet()) {
                Assert.assertEquals(4 * entry.getKey() + 1, entry.getValue().intValue());
                keyCount++;
            }
        }
        Assert.assertEquals(NUM_ELEMENTS / 2, keyCount);
    }

    /**
     * Builds the job under test, consisting of two disconnected pipelines:
     * <ul>
     *   <li>"multi-source" (parallelism 3) -> {@link FailingMapperFunction} -> {@link ValidatingSink}.
     *       The source output is reinterpreted as a keyed stream, i.e. no shuffle is inserted
     *       between source and mapper.</li>
     *   <li>"single-source" (parallelism 1) -> identity map.</li>
     * </ul>
     * Checkpointing runs every 200 ms in exactly-once mode and operator chaining is disabled.
     *
     * @return the job graph generated from the stream graph
     */
    private JobGraph createJobGraph() {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(NUM_OF_REGIONS);
        env.setMaxParallelism(MAX_PARALLELISM);
        env.enableCheckpointing(200L, CheckpointingMode.EXACTLY_ONCE);
        env.getCheckpointConfig()
                .setExternalizedCheckpointCleanup(
                        CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION);
        env.disableOperatorChaining();

        // 10000 / 3 == 3333: the index from which the sources wait for checkpoints to complete.
        final long checkpointLatestAt = NUM_ELEMENTS / NUM_OF_RESTARTS;

        DataStreamUtils.reinterpretAsKeyedStream(
                        env.addSource(new StringGeneratingSourceFunction(NUM_ELEMENTS, checkpointLatestAt))
                                .name(MULTI_REGION_SOURCE_NAME)
                                .setParallelism(NUM_OF_REGIONS),
                        value -> value.f0,
                        TypeInformation.of(Integer.class))
                .map(new FailingMapperFunction(NUM_OF_RESTARTS))
                .setParallelism(NUM_OF_REGIONS)
                .addSink(new ValidatingSink())
                .setParallelism(NUM_OF_REGIONS);

        // Second pipeline, not connected to the first one.
        // The output type is pinned to Object: an untyped `map(v -> v)` would be inferred as
        // returning Tuple2, whose type parameters Flink cannot recover from a lambda, which
        // makes job graph generation fail with an InvalidTypesException.
        env.addSource(new StringGeneratingSourceFunction(NUM_ELEMENTS, checkpointLatestAt))
                .name(SINGLE_REGION_SOURCE_NAME)
                .setParallelism(1)
                .<Object>map(value -> value)
                .setParallelism(1);

        return env.getStreamGraph().getJobGraph();
    }

    /*
     * Maps a SerializedLambda back to one of the two lambdas used in createJobGraph(): the
     * implementation method name selects the candidate, the remaining SerializedLambda
     * properties are checked, and anything that does not match ends in an
     * IllegalArgumentException.
     *
     * NOTE(review): the "synthetic" marker and the method name suggest this is the decompiler's
     * rendering of the hook that javac generates for serializable lambdas, not hand-written
     * code -- confirm. As written it is not valid Java (`boolean z = -1`, `switch` on a boolean,
     * lambdas returned as Object without a target type). It should presumably be dropped rather
     * than repaired if this file is ever compiled from source -- TODO confirm.
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        // `z` acts as a three-state selector here: -1 = no match, false / true = first / second lambda.
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case 1759234963:
                if (implMethodName.equals("lambda$createJobGraph$31691507$1")) {
                    z = true;
                    break;
                }
                break;
            case 1873978407:
                if (implMethodName.equals("lambda$createJobGraph$dadbaf66$1")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                // KeySelector lambda (Tuple2 -> Integer).
                // NOTE(review): MAX_PARALLELISM (6) is compared with a method-handle kind here; 6 is also
                // the value of MethodHandleInfo.REF_invokeStatic, so this looks like the decompiler
                // substituting an unrelated constant with the same value -- confirm.
                if (serializedLambda.getImplMethodKind() == MAX_PARALLELISM && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/java/functions/KeySelector") && serializedLambda.getFunctionalInterfaceMethodName().equals("getKey") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/test/checkpointing/RegionFailoverITCase") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/flink/api/java/tuple/Tuple2;)Ljava/lang/Integer;")) {
                    return tuple2 -> {
                        return (Integer) tuple2.f0;
                    };
                }
                break;
            case true:
                // MapFunction lambda; note that its implementation signature returns Object, not Tuple2.
                if (serializedLambda.getImplMethodKind() == MAX_PARALLELISM && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/test/checkpointing/RegionFailoverITCase") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/flink/api/java/tuple/Tuple2;)Ljava/lang/Object;")) {
                    return tuple22 -> {
                        return tuple22;
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
