package io.hops.hadoop.shaded.org.apache.zookeeper.server.quorum;

import io.hops.hadoop.shaded.org.apache.zookeeper.PortAssignment;
import io.hops.hadoop.shaded.org.apache.zookeeper.common.BaseX509ParameterizedTestCase;
import io.hops.hadoop.shaded.org.apache.zookeeper.common.ClientX509Util;
import io.hops.hadoop.shaded.org.apache.zookeeper.common.KeyStoreFileType;
import io.hops.hadoop.shaded.org.apache.zookeeper.common.X509Exception;
import io.hops.hadoop.shaded.org.apache.zookeeper.common.X509KeyType;
import io.hops.hadoop.shaded.org.apache.zookeeper.common.X509TestContext;
import io.hops.hadoop.shaded.org.apache.zookeeper.common.X509Util;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.net.ConnectException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.HandshakeCompletedEvent;
import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.SSLSocket;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Tests {@link UnifiedServerSocket} against TLS and plaintext clients.
 *
 * <p>A "non-strict" server (constructed with {@code true}) is expected to serve both TLS and plaintext clients;
 * a "strict" server (constructed with {@code false}) is expected to serve TLS clients only and hang up on
 * plaintext ones. The class is parameterized over every combination of trust store key type, key store key
 * type and hostname verification produced by {@link #params()}.
 */
@RunWith(Parameterized.class)
public class UnifiedServerSocketTest extends BaseX509ParameterizedTestCase {
    /** Number of connection attempts made by the connect helpers before giving up. */
    private static final int MAX_RETRIES = 5;
    /** Socket timeout, handshake wait, retry back-off and shutdown wait, in milliseconds. */
    private static final int TIMEOUT = 1000;
    private static final byte[] DATA_TO_CLIENT = "hello client".getBytes();
    private static final byte[] DATA_FROM_CLIENT = "hello server".getBytes();

    private X509Util x509Util;
    private InetSocketAddress localServerAddress;

    /** Guards {@link #handshakeCompleted}; notified by the client-side handshake listener. */
    private final Object handshakeCompletedLock = new Object();
    /** Set to true by the handshake listener registered in {@link #connectWithSSL()}. */
    private boolean handshakeCompleted = false;

    /**
     * Background thread that accepts connections on a {@link UnifiedServerSocket} and serves each one on a
     * worker pool thread.
     *
     * <p>Each worker reads once (up to {@value #READ_BUFFER_SIZE} bytes) from the client, records any bytes
     * received, replies with {@code dataToClient} and closes the connection.
     */
    private static final class UnifiedServerThread extends Thread {
        private static final int READ_BUFFER_SIZE = 1024;

        private final byte[] dataToClient;
        /** Payloads received from clients, in the order workers recorded them. Guarded by itself. */
        private final List<byte[]> dataFromClients = new ArrayList<>();
        private final ExecutorService workerPool = Executors.newCachedThreadPool();
        private final UnifiedServerSocket serverSocket;

        /**
         * Creates the server socket and binds it; the accept loop starts when the thread is started.
         *
         * @param x509Util TLS configuration handed to the server socket
         * @param bindAddress address to bind to
         * @param allowInsecureConnection {@code true} for a non-strict server, {@code false} for a strict one
         * @param dataToClient reply sent to every client
         * @throws IOException if creating or binding the socket fails (the socket is closed in that case)
         */
        UnifiedServerThread(X509Util x509Util, InetSocketAddress bindAddress, boolean allowInsecureConnection,
                            byte[] dataToClient) throws IOException {
            this.dataToClient = dataToClient;
            UnifiedServerSocket socket = new UnifiedServerSocket(x509Util, allowInsecureConnection);
            try {
                socket.bind(bindAddress);
            } catch (IOException e) {
                forceClose(socket);
                throw e;
            }
            this.serverSocket = socket;
        }

        /** Accept loop; runs until {@link #shutdown(long)} closes the server socket. */
        @Override
        public void run() {
            try {
                Random random = new Random();
                while (true) {
                    final Socket clientSocket = serverSocket.accept();
                    // Apply randomized options right after accept; the worker re-checks them after its first
                    // read (presumably after any TLS upgrade) to verify they were preserved.
                    final boolean tcpNoDelay = random.nextBoolean();
                    clientSocket.setTcpNoDelay(tcpNoDelay);
                    clientSocket.setSoTimeout(TIMEOUT);
                    final boolean keepAlive = random.nextBoolean();
                    clientSocket.setKeepAlive(keepAlive);
                    workerPool.submit(() -> serveClient(clientSocket, tcpNoDelay, keepAlive));
                }
            } catch (IOException e) {
                // shutdown() closes the server socket to break out of accept(); that is the normal exit path
                // and should not surface as an uncaught exception.
                if (!serverSocket.isClosed()) {
                    throw new RuntimeException(e);
                }
            } finally {
                forceClose(serverSocket);
                workerPool.shutdown();
            }
        }

        /**
         * Serves a single client: read, record, reply, close.
         *
         * <p>NOTE: this runs on a worker thread whose {@code Future} is discarded, so assertion failures or
         * exceptions here do not fail the test by themselves.
         */
        private void serveClient(Socket clientSocket, boolean tcpNoDelay, boolean keepAlive) {
            try {
                byte[] buffer = new byte[READ_BUFFER_SIZE];
                int bytesRead = clientSocket.getInputStream().read(buffer, 0, buffer.length);
                Assert.assertEquals(tcpNoDelay, clientSocket.getTcpNoDelay());
                Assert.assertEquals(TIMEOUT, clientSocket.getSoTimeout());
                Assert.assertEquals(keepAlive, clientSocket.getKeepAlive());
                if (bytesRead > 0) {
                    byte[] received = new byte[bytesRead];
                    System.arraycopy(buffer, 0, received, 0, bytesRead);
                    synchronized (dataFromClients) {
                        dataFromClients.add(received);
                    }
                }
                clientSocket.getOutputStream().write(dataToClient);
                clientSocket.getOutputStream().flush();
            } catch (IOException e) {
                throw new RuntimeException(e);
            } finally {
                forceClose(clientSocket);
            }
        }

        /**
         * Closes the server socket (ending the accept loop) and waits for workers and this thread to finish.
         *
         * @param timeoutMillis maximum wait for the worker pool, and then separately for this thread
         */
        public void shutdown(long timeoutMillis) throws InterruptedException {
            forceClose(serverSocket);
            // run() shuts the pool down once accept() fails; awaitTermination waits for that and for any
            // in-flight workers.
            workerPool.awaitTermination(timeoutMillis, TimeUnit.MILLISECONDS);
            join(timeoutMillis);
        }

        /** Returns the {@code index}-th payload recorded from a client. */
        byte[] getDataFromClient(int index) {
            // Must use the same monitor as the writer in serveClient().
            synchronized (dataFromClients) {
                return dataFromClients.get(index);
            }
        }

        /** Returns whether any worker recorded a non-empty payload from a client. */
        boolean receivedAnyDataFromClient() {
            synchronized (dataFromClients) {
                return !dataFromClients.isEmpty();
            }
        }
    }

    /**
     * Produces one parameter set per (trust store key type, key store key type, hostname verification)
     * combination, each tagged with a running index.
     */
    @Parameterized.Parameters
    public static Collection<Object[]> params() {
        List<Object[]> result = new ArrayList<>();
        int paramIndex = 0;
        for (X509KeyType caKeyType : X509KeyType.values()) {
            for (X509KeyType certKeyType : X509KeyType.values()) {
                for (Boolean hostnameVerification : new Boolean[]{true, false}) {
                    result.add(new Object[]{caKeyType, certKeyType, hostnameVerification, paramIndex++});
                }
            }
        }
        return result;
    }

    /**
     * @param caKeyType key type of the trust store
     * @param certKeyType key type of the key store
     * @param hostnameVerification whether the test context enables hostname verification
     * @param paramIndex running index of this parameter set
     */
    public UnifiedServerSocketTest(X509KeyType caKeyType, X509KeyType certKeyType, Boolean hostnameVerification,
                                   Integer paramIndex) {
        super(paramIndex, () -> {
            try {
                return X509TestContext.newBuilder()
                        .setTempDir(tempDir)
                        .setKeyStoreKeyType(certKeyType)
                        .setTrustStoreKeyType(caKeyType)
                        .setHostnameVerification(hostnameVerification)
                        .build();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        });
    }

    @Before
    public void setUp() throws Exception {
        localServerAddress = new InetSocketAddress(InetAddress.getLoopbackAddress(), PortAssignment.unique());
        x509Util = new ClientX509Util();
        x509TestContext.setSystemProperties(x509Util, KeyStoreFileType.JKS, KeyStoreFileType.JKS);
    }

    @After
    public void tearDown() throws Exception {
        try {
            x509TestContext.clearSystemProperties(x509Util);
        } finally {
            x509Util.close();
        }
    }

    /** Closes {@code socket} if it is non-null and open, ignoring any {@link IOException}. */
    private static void forceClose(Socket socket) {
        if (socket == null || socket.isClosed()) {
            return;
        }
        try {
            socket.close();
        } catch (IOException ignored) {
            // Best-effort cleanup in tests; nothing useful to do on failure.
        }
    }

    /** Closes {@code serverSocket} if it is non-null and open, ignoring any {@link IOException}. */
    private static void forceClose(ServerSocket serverSocket) {
        if (serverSocket == null || serverSocket.isClosed()) {
            return;
        }
        try {
            serverSocket.close();
        } catch (IOException ignored) {
            // Best-effort cleanup in tests; nothing useful to do on failure.
        }
    }

    /**
     * Opens a TLS client connection to the test server, retrying up to {@link #MAX_RETRIES} times on
     * {@link ConnectException}. A listener that sets {@link #handshakeCompleted} is registered on the socket.
     *
     * @return the connected socket; the test fails if no attempt succeeded
     */
    private SSLSocket connectWithSSL() throws IOException, X509Exception, InterruptedException {
        SSLSocket sslSocket = null;
        for (int attempt = 0; attempt < MAX_RETRIES; attempt++) {
            try {
                sslSocket = x509Util.createSSLSocket();
                sslSocket.addHandshakeCompletedListener(event -> {
                    synchronized (handshakeCompletedLock) {
                        handshakeCompleted = true;
                        handshakeCompletedLock.notifyAll();
                    }
                });
                sslSocket.setSoTimeout(TIMEOUT);
                sslSocket.connect(localServerAddress, TIMEOUT);
                break;
            } catch (ConnectException e) {
                e.printStackTrace();
                forceClose(sslSocket);
                sslSocket = null;
                Thread.sleep(TIMEOUT);
            }
        }
        Assert.assertNotNull("Failed to connect to server with SSL", sslSocket);
        return sslSocket;
    }

    /**
     * Opens a plaintext client connection to the test server, retrying up to {@link #MAX_RETRIES} times on
     * {@link ConnectException}.
     *
     * @return the connected socket; the test fails if no attempt succeeded
     */
    private Socket connectWithoutSSL() throws IOException, InterruptedException {
        Socket socket = null;
        for (int attempt = 0; attempt < MAX_RETRIES; attempt++) {
            try {
                socket = new Socket();
                socket.setSoTimeout(TIMEOUT);
                socket.connect(localServerAddress, TIMEOUT);
                break;
            } catch (ConnectException e) {
                e.printStackTrace();
                forceClose(socket);
                socket = null;
                Thread.sleep(TIMEOUT);
            }
        }
        Assert.assertNotNull("Failed to connect to server without SSL", socket);
        return socket;
    }

    /** Sends {@link #DATA_FROM_CLIENT} and asserts the server answers with exactly {@link #DATA_TO_CLIENT}. */
    private static void exchangeData(Socket socket) throws IOException {
        socket.getOutputStream().write(DATA_FROM_CLIENT);
        socket.getOutputStream().flush();
        assertReceivedDataToClient(socket);
    }

    /** Performs a single read and asserts it returned exactly {@link #DATA_TO_CLIENT}. */
    private static void assertReceivedDataToClient(Socket socket) throws IOException {
        byte[] buffer = new byte[DATA_TO_CLIENT.length];
        Assert.assertEquals(buffer.length, socket.getInputStream().read(buffer, 0, buffer.length));
        Assert.assertArrayEquals(DATA_TO_CLIENT, buffer);
    }

    /** Waits up to {@link #TIMEOUT} ms for a client TLS handshake to complete, then asserts that it did. */
    private void assertHandshakeCompleted() throws InterruptedException {
        long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(TIMEOUT);
        synchronized (handshakeCompletedLock) {
            // Loop to tolerate spurious wakeups; never call wait(0), which would block indefinitely.
            while (!handshakeCompleted) {
                long remainingMillis = TimeUnit.NANOSECONDS.toMillis(deadline - System.nanoTime());
                if (remainingMillis <= 0) {
                    break;
                }
                handshakeCompletedLock.wait(remainingMillis);
            }
            Assert.assertTrue(handshakeCompleted);
        }
    }

    /** Asserts that no client TLS handshake has completed so far. */
    private void assertHandshakeNotCompleted() {
        synchronized (handshakeCompletedLock) {
            Assert.assertFalse(handshakeCompleted);
        }
    }

    @Test
    public void testConnectWithSSLToNonStrictServer() throws Exception {
        UnifiedServerThread serverThread = new UnifiedServerThread(x509Util, localServerAddress, true, DATA_TO_CLIENT);
        serverThread.start();

        SSLSocket sslSocket = null;
        try {
            sslSocket = connectWithSSL();
            exchangeData(sslSocket);
            assertHandshakeCompleted();
            Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
        } finally {
            forceClose(sslSocket);
            serverThread.shutdown(TIMEOUT);
        }
    }

    @Test
    public void testConnectWithSSLToStrictServer() throws Exception {
        UnifiedServerThread serverThread = new UnifiedServerThread(x509Util, localServerAddress, false, DATA_TO_CLIENT);
        serverThread.start();

        SSLSocket sslSocket = null;
        try {
            sslSocket = connectWithSSL();
            exchangeData(sslSocket);
            assertHandshakeCompleted();
            Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
        } finally {
            forceClose(sslSocket);
            serverThread.shutdown(TIMEOUT);
        }
    }

    @Test
    public void testConnectWithoutSSLToNonStrictServer() throws Exception {
        UnifiedServerThread serverThread = new UnifiedServerThread(x509Util, localServerAddress, true, DATA_TO_CLIENT);
        serverThread.start();

        Socket plainSocket = null;
        try {
            plainSocket = connectWithoutSSL();
            exchangeData(plainSocket);
            Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
        } finally {
            forceClose(plainSocket);
            serverThread.shutdown(TIMEOUT);
        }
    }

    @Test
    public void testConnectWithoutSSLToNonStrictServerPartialWrite() throws Exception {
        UnifiedServerThread serverThread = new UnifiedServerThread(x509Util, localServerAddress, true, DATA_TO_CLIENT);
        serverThread.start();

        Socket plainSocket = null;
        try {
            plainSocket = connectWithoutSSL();
            // Send only the first two bytes, pause, then send the rest, so the server sees a split first message.
            plainSocket.getOutputStream().write(DATA_FROM_CLIENT, 0, 2);
            plainSocket.getOutputStream().flush();
            Thread.sleep(TIMEOUT / 2);
            plainSocket.getOutputStream().write(DATA_FROM_CLIENT, 2, DATA_FROM_CLIENT.length - 2);
            plainSocket.getOutputStream().flush();
            assertReceivedDataToClient(plainSocket);
            Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
        } finally {
            forceClose(plainSocket);
            serverThread.shutdown(TIMEOUT);
        }
    }

    @Test
    public void testConnectWithoutSSLToStrictServer() throws Exception {
        UnifiedServerThread serverThread = new UnifiedServerThread(x509Util, localServerAddress, false, DATA_TO_CLIENT);
        serverThread.start();

        Socket plainSocket = null;
        try {
            plainSocket = connectWithoutSSL();
            plainSocket.getOutputStream().write(DATA_FROM_CLIENT);
            plainSocket.getOutputStream().flush();
            byte[] buffer = new byte[DATA_TO_CLIENT.length];
            try {
                int bytesRead = plainSocket.getInputStream().read(buffer, 0, buffer.length);
                // End-of-stream (-1) means the server hung up, which is the expected outcome.
                if (bytesRead != -1) {
                    Assert.fail("Expected server to hang up the connection. Read from server succeeded unexpectedly.");
                }
            } catch (SocketException e) {
                // Also accepted as a hang-up: the read failed with a SocketException (e.g. connection reset).
            }
        } finally {
            forceClose(plainSocket);
            serverThread.shutdown(TIMEOUT);
            // Whatever the client observed, the strict server must not have recorded any plaintext data.
            Assert.assertFalse("The strict server accepted connection without SSL.",
                    serverThread.receivedAnyDataFromClient());
        }
    }

    @Test
    public void testTLSDetectionNonBlockingNonStrictServerIdleClient() throws Exception {
        UnifiedServerThread serverThread = new UnifiedServerThread(x509Util, localServerAddress, true, DATA_TO_CLIENT);
        serverThread.start();

        Socket idleSocket = null;
        Socket plainSocket = null;
        SSLSocket sslSocket = null;
        try {
            // An idle client that never sends must not block the server from serving the others.
            idleSocket = connectWithoutSSL();
            plainSocket = connectWithoutSSL();
            exchangeData(plainSocket);
            Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
            assertHandshakeNotCompleted();

            sslSocket = connectWithSSL();
            exchangeData(sslSocket);
            Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(1));
            assertHandshakeCompleted();
        } finally {
            forceClose(idleSocket);
            forceClose(plainSocket);
            forceClose(sslSocket);
            serverThread.shutdown(TIMEOUT);
        }
    }

    @Test
    public void testTLSDetectionNonBlockingStrictServerIdleClient() throws Exception {
        UnifiedServerThread serverThread = new UnifiedServerThread(x509Util, localServerAddress, false, DATA_TO_CLIENT);
        serverThread.start();

        Socket idleSocket = null;
        SSLSocket sslSocket = null;
        try {
            idleSocket = connectWithoutSSL();
            sslSocket = connectWithSSL();
            exchangeData(sslSocket);
            assertHandshakeCompleted();
            Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
        } finally {
            forceClose(idleSocket);
            forceClose(sslSocket);
            serverThread.shutdown(TIMEOUT);
        }
    }

    @Test
    public void testTLSDetectionNonBlockingNonStrictServerDisconnectedClient() throws Exception {
        UnifiedServerThread serverThread = new UnifiedServerThread(x509Util, localServerAddress, true, DATA_TO_CLIENT);
        serverThread.start();

        Socket plainSocket = null;
        SSLSocket sslSocket = null;
        try {
            // A client that connects and immediately disconnects must not disturb later clients.
            forceClose(connectWithoutSSL());
            plainSocket = connectWithoutSSL();
            exchangeData(plainSocket);
            Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
            assertHandshakeNotCompleted();

            sslSocket = connectWithSSL();
            exchangeData(sslSocket);
            Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(1));
            assertHandshakeCompleted();
        } finally {
            forceClose(plainSocket);
            forceClose(sslSocket);
            serverThread.shutdown(TIMEOUT);
        }
    }

    @Test
    public void testTLSDetectionNonBlockingStrictServerDisconnectedClient() throws Exception {
        UnifiedServerThread serverThread = new UnifiedServerThread(x509Util, localServerAddress, false, DATA_TO_CLIENT);
        serverThread.start();

        SSLSocket sslSocket = null;
        try {
            forceClose(connectWithoutSSL());
            sslSocket = connectWithSSL();
            exchangeData(sslSocket);
            assertHandshakeCompleted();
            Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0));
        } finally {
            forceClose(sslSocket);
            serverThread.shutdown(TIMEOUT);
        }
    }
}
