package io.hops.hopsworks.common.serving.tf;

import com.google.common.io.Files;
import io.hops.hopsworks.common.dao.project.Project;
import io.hops.hopsworks.common.dao.serving.TfServing;
import io.hops.hopsworks.common.dao.serving.TfServingFacade;
import io.hops.hopsworks.common.dao.user.Users;
import io.hops.hopsworks.common.exception.KafkaException;
import io.hops.hopsworks.common.exception.ProjectException;
import io.hops.hopsworks.common.exception.RESTCodes;
import io.hops.hopsworks.common.exception.ServiceException;
import io.hops.hopsworks.common.exception.UserException;
import io.hops.hopsworks.common.security.CertificateMaterializer;
import io.hops.hopsworks.common.serving.KafkaServingHelper;
import io.hops.hopsworks.common.util.Settings;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.ejb.EJB;
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;
import javax.enterprise.inject.Alternative;

@Alternative
@TransactionAttribute(TransactionAttributeType.NOT_SUPPORTED)
@Stateless
/* loaded from: input_file:io/hops/hopsworks/common/serving/tf/LocalhostTfServingController.class */
public class LocalhostTfServingController implements TfServingController {
    public static final String SERVING_DIRS = "/serving/";

    @EJB
    private TfServingFacade tfServingFacade;

    @EJB
    private Settings settings;

    @EJB
    private CertificateMaterializer certificateMaterializer;

    @EJB
    private KafkaServingHelper kafkaServingHelper;
    private static final Logger logger = Logger.getLogger(LocalhostTfServingController.class.getName());
    public static final Integer PID_STOPPED = -2;

    /**
     * Returns every serving registered for the given project, each wrapped together
     * with its derived status, replica count and Kafka topic information.
     *
     * @param project project whose servings are listed
     * @return one wrapper per serving found by the facade; empty if there are none
     * @throws TfServingException declared by the controller interface
     */
    @Override // io.hops.hopsworks.common.serving.tf.TfServingController
    public List<TfServingWrapper> getTfServings(Project project) throws TfServingException {
        List<TfServing> servings = this.tfServingFacade.findForProject(project);
        // Typed (non-raw) list, presized to the number of servings found.
        List<TfServingWrapper> wrappers = new ArrayList<>(servings.size());
        for (TfServing serving : servings) {
            wrappers.add(getTfServingInternal(serving));
        }
        return wrappers;
    }

    /**
     * Looks up a single serving of the project by id.
     *
     * @param project project the serving belongs to
     * @param num     id of the serving
     * @return the wrapped serving, or {@code null} when the facade finds no match
     * @throws TfServingException declared by the controller interface
     */
    @Override // io.hops.hopsworks.common.serving.tf.TfServingController
    public TfServingWrapper getTfServing(Project project, Integer num) throws TfServingException {
        TfServing serving = this.tfServingFacade.findByProjectAndId(project, num);
        return serving != null ? getTfServingInternal(serving) : null;
    }

    /**
     * Deletes every serving of the project, killing the local process of each
     * serving that has one.
     *
     * <p>The status is derived from the entity returned by {@code acquireLock}, the
     * same way {@link #deleteTfServing(Project, Integer)} does. Using the entity
     * fetched before the lock was taken would classify a stopped serving as STOPPED
     * instead of STARTING, and the kill script would then be invoked with the
     * {@code PID_STOPPED} placeholder as its pid.
     *
     * @param project project whose servings are removed
     * @throws TfServingException if killing one of the instances fails
     */
    @Override // io.hops.hopsworks.common.serving.tf.TfServingController
    public void deleteTfServings(Project project) throws TfServingException {
        for (TfServing tfServing : this.tfServingFacade.findForProject(project)) {
            TfServing lockedServing = this.tfServingFacade.acquireLock(project, tfServing.getId());
            // With the lock held, STARTING means "no pid recorded": nothing to kill.
            if (!getTfServingStatus(lockedServing).equals(TfServingStatusEnum.STARTING)) {
                killTfServingInstance(project, lockedServing, false);
            }
            this.tfServingFacade.delete(lockedServing);
        }
    }

    /**
     * Deletes one serving of the project, killing its local process first when a
     * pid is recorded for it.
     *
     * @param project project the serving belongs to
     * @param num     id of the serving to delete
     * @throws TfServingException if killing the instance fails
     */
    @Override // io.hops.hopsworks.common.serving.tf.TfServingController
    public void deleteTfServing(Project project, Integer num) throws TfServingException {
        TfServing lockedServing = this.tfServingFacade.acquireLock(project, num);
        TfServingStatusEnum status = getTfServingStatus(lockedServing);
        // With the lock held, STARTING means "no pid recorded": nothing to kill.
        boolean hasRecordedProcess = status != TfServingStatusEnum.STARTING;
        if (hasRecordedProcess) {
            killTfServingInstance(project, lockedServing, false);
        }
        this.tfServingFacade.delete(lockedServing);
    }

    /**
     * Rejects a serving whose model name is already used by a different serving of
     * the same project.
     *
     * @param project          project to search in
     * @param tfServingWrapper wrapper holding the serving being created or updated
     * @throws TfServingException with code DUPLICATEDENTRY when another serving
     *                            (different id) already uses the model name
     */
    @Override // io.hops.hopsworks.common.serving.tf.TfServingController
    public void checkDuplicates(Project project, TfServingWrapper tfServingWrapper) throws TfServingException {
        TfServing candidate = tfServingWrapper.getTfServing();
        TfServing sameName = this.tfServingFacade.findByProjectModelName(project, candidate.getModelName());
        if (sameName == null) {
            return;
        }
        // The match is the serving itself (an update), not a duplicate.
        if (sameName.getId().equals(candidate.getId())) {
            return;
        }
        throw new TfServingException(RESTCodes.TfServingErrorCode.DUPLICATEDENTRY, Level.FINE);
    }

    /**
     * Creates a new serving (when the wrapped entity has no id) or updates an
     * existing one.
     *
     * <p>On update the serving is locked, its Kafka topic and database row are
     * updated, and, if the instance has a recorded pid, the running process is
     * either handed the new model version or restarted. A restart is needed when
     * the model name, model path or batching flag changed, or when the version was
     * lowered.
     *
     * <p>The lock is released here on every path that does not reach
     * {@code updateModelVersion} / {@code restartTfServingInstance}; past that point
     * those helpers are responsible for it.
     *
     * @param project          project the serving belongs to
     * @param users            user performing the operation
     * @param tfServingWrapper wrapper holding the serving and its Kafka settings
     */
    @Override // io.hops.hopsworks.common.serving.tf.TfServingController
    public void createOrUpdate(Project project, Users users, TfServingWrapper tfServingWrapper) throws KafkaException, UserException, ProjectException, ServiceException, TfServingException {
        TfServing tfServing = tfServingWrapper.getTfServing();
        if (tfServing.getId() == null) {
            // New serving: fill in the bookkeeping fields and persist it in the stopped state.
            tfServing.setCreated(new Date());
            tfServing.setCreator(users);
            tfServing.setProject(project);
            tfServing.setLocalDir(UUID.randomUUID().toString());
            tfServing.setLocalPid(PID_STOPPED);
            tfServing.setInstances(1);
            this.kafkaServingHelper.setupKafkaServingTopic(project, tfServingWrapper, tfServing, null);
            this.tfServingFacade.merge(tfServing);
            return;
        }

        TfServing lockedServing = this.tfServingFacade.acquireLock(project, tfServing.getId());
        boolean lockHandedOff = false;
        try {
            TfServingStatusEnum status = getTfServingStatus(lockedServing);
            this.kafkaServingHelper.setupKafkaServingTopic(project, tfServingWrapper, tfServing, lockedServing);
            TfServing updatedServing = this.tfServingFacade.updateDbObject(tfServing, project);
            if (status != TfServingStatusEnum.RUNNING && status != TfServingStatusEnum.UPDATING) {
                // No process to touch; the finally block releases the lock.
                return;
            }

            // isBatchingEnabled() returns a boxed Boolean, so compare by value, not by reference.
            boolean versionOnlyChange = lockedServing.getModelName().equals(updatedServing.getModelName())
                    && lockedServing.getModelPath().equals(updatedServing.getModelPath())
                    && java.util.Objects.equals(lockedServing.isBatchingEnabled(), updatedServing.isBatchingEnabled())
                    && lockedServing.getVersion().intValue() <= updatedServing.getVersion().intValue();

            // From here on the helpers own the lock.
            lockHandedOff = true;
            if (versionOnlyChange) {
                updateModelVersion(project, users, updatedServing);
            } else {
                restartTfServingInstance(project, users, lockedServing, updatedServing);
            }
        } finally {
            if (!lockHandedOff) {
                this.tfServingFacade.releaseLock(project, tfServing.getId());
            }
        }
    }

    /**
     * Starts or stops the local instance of a serving.
     *
     * <p>The lock is acquired first, so the status seen here is STARTING for a
     * serving without a recorded pid and UPDATING for one with a pid. Any other
     * combination of status and command is rejected after releasing the lock.
     *
     * @param project           project the serving belongs to
     * @param users             user performing the operation
     * @param num               id of the serving
     * @param tfServingCommands requested command
     * @throws TfServingException with code LIFECYCLEERROR when the command does not
     *                            apply to the current status, or when the start/kill fails
     */
    @Override // io.hops.hopsworks.common.serving.tf.TfServingController
    public void startOrStop(Project project, Users users, Integer num, TfServingCommands tfServingCommands) throws TfServingException {
        TfServing lockedServing = this.tfServingFacade.acquireLock(project, num);
        TfServingStatusEnum status = getTfServingStatus(lockedServing);
        boolean startRequested = tfServingCommands == TfServingCommands.START;
        boolean stopRequested = tfServingCommands == TfServingCommands.STOP;

        if (startRequested && status == TfServingStatusEnum.STARTING) {
            startTfServingInstance(project, users, lockedServing);
            return;
        }
        if (stopRequested && status == TfServingStatusEnum.UPDATING) {
            killTfServingInstance(project, lockedServing, true);
            return;
        }

        // Nothing to do for this command: give the lock back before reporting the error.
        this.tfServingFacade.releaseLock(project, num);
        String currentState = startRequested ? "started" : "stopped";
        throw new TfServingException(RESTCodes.TfServingErrorCode.LIFECYCLEERROR, Level.FINE, "Instance is already " + currentState);
    }

    /**
     * Returns the maximum number of instances this controller reports per serving.
     *
     * @return always {@code 1}
     */
    @Override // io.hops.hopsworks.common.serving.tf.TfServingController
    public int getMaxNumInstances() {
        final int singleLocalInstance = 1;
        return singleLocalInstance;
    }

    /**
     * Returns the fully qualified name of this controller implementation.
     *
     * @return the name of {@code LocalhostTfServingController}, taken from the class literal
     */
    @Override // io.hops.hopsworks.common.serving.tf.TfServingController
    public String getClassName() {
        Class<LocalhostTfServingController> implementation = LocalhostTfServingController.class;
        return implementation.getName();
    }

    /**
     * Wraps a serving entity with its derived status, available replica count,
     * node port (only when running) and Kafka topic description.
     *
     * @param tfServing serving entity to wrap
     * @return the populated wrapper
     */
    private TfServingWrapper getTfServingInternal(TfServing tfServing) {
        TfServingStatusEnum status = getTfServingStatus(tfServing);
        TfServingWrapper wrapper = new TfServingWrapper(tfServing);
        wrapper.setStatus(status);

        if (status == TfServingStatusEnum.RUNNING) {
            // Only a running instance exposes a replica and a port.
            wrapper.setAvailableReplicas(1);
            wrapper.setNodePort(tfServing.getLocalPort());
        } else if (status == TfServingStatusEnum.STOPPED
                || status == TfServingStatusEnum.STARTING
                || status == TfServingStatusEnum.UPDATING) {
            wrapper.setAvailableReplicas(0);
        }

        wrapper.setKafkaTopicDTO(this.kafkaServingHelper.buildTopicDTO(tfServing));
        return wrapper;
    }

    /**
     * Derives the status of a serving from its recorded pid and its lock owner.
     *
     * <ul>
     *   <li>pid is {@code PID_STOPPED}, no lock: STOPPED</li>
     *   <li>pid is {@code PID_STOPPED}, lock held: STARTING</li>
     *   <li>pid recorded, lock held: UPDATING</li>
     *   <li>pid recorded, no lock: RUNNING</li>
     * </ul>
     *
     * @param tfServing serving entity to inspect
     * @return the derived status, never {@code null}
     */
    private TfServingStatusEnum getTfServingStatus(TfServing tfServing) {
        boolean pidStopped = tfServing.getLocalPid().equals(PID_STOPPED);
        boolean lockHeld = tfServing.getLockIP() != null;
        if (pidStopped) {
            return lockHeld ? TfServingStatusEnum.STARTING : TfServingStatusEnum.STOPPED;
        }
        return lockHeld ? TfServingStatusEnum.UPDATING : TfServingStatusEnum.RUNNING;
    }

    /**
     * Runs the {@code tfserving.sh update} script so the serving picks up the model
     * version currently stored on the entity.
     *
     * <p>When Hops RPC TLS is enabled, the user's certificates are materialized
     * before the script runs and removed afterwards, whether or not it succeeded.
     * The serving lock is held for the whole operation and released exactly once,
     * in the outer {@code finally}; releasing it before the script has run would
     * let another request take the lock while the update is still in progress.
     *
     * @param project   project the serving belongs to
     * @param users     user on whose behalf the script runs
     * @param tfServing serving whose model version is updated
     * @throws TfServingException LIFECYCLEERRORINT if the certificates cannot be
     *                            materialized, UPDATEERROR if the script cannot be
     *                            started or the wait is interrupted
     */
    private void updateModelVersion(Project project, Users users, TfServing tfServing) throws TfServingException {
        String[] command = {
            "/usr/bin/sudo",
            this.settings.getHopsworksDomainDir() + "/bin/tfserving.sh",
            "update",
            tfServing.getModelName(),
            Paths.get(tfServing.getModelPath(), tfServing.getVersion().toString()).toString(),
            Paths.get(this.settings.getStagingDir(), "/serving/", tfServing.getLocalDir()).toString(),
            project.getName() + "__" + users.getUsername()
        };
        logger.log(Level.INFO, Arrays.toString(command));

        boolean tlsEnabled = this.settings.getHopsRpcTls();
        try {
            if (tlsEnabled) {
                try {
                    this.certificateMaterializer.materializeCertificatesLocal(users.getUsername(), project.getName());
                } catch (IOException e) {
                    throw new TfServingException(RESTCodes.TfServingErrorCode.LIFECYCLEERRORINT, Level.SEVERE, null, e.getMessage(), e);
                }
            }
            try {
                new ProcessBuilder(command).start().waitFor();
            } catch (IOException | InterruptedException e) {
                if (e instanceof InterruptedException) {
                    // Preserve the interrupt for callers further up the stack.
                    Thread.currentThread().interrupt();
                }
                throw new TfServingException(RESTCodes.TfServingErrorCode.UPDATEERROR, Level.SEVERE, "tfServing id: " + tfServing.getId(), e.getMessage(), e);
            } finally {
                if (tlsEnabled) {
                    this.certificateMaterializer.removeCertificatesLocal(users.getUsername(), project.getName());
                }
            }
        } finally {
            this.tfServingFacade.releaseLock(project, tfServing.getId());
        }
    }

    /**
     * Runs the {@code tfserving.sh kill} script for the serving's recorded pid and
     * port, then marks the serving as stopped ({@code PID_STOPPED}, port -1) in the
     * database.
     *
     * <p>When {@code releaseLockAfterKill} is set, the lock is released in a
     * {@code finally} block, so a failed kill no longer leaves the serving locked.
     *
     * @param project              project the serving belongs to
     * @param tfServing            serving whose process is killed
     * @param releaseLockAfterKill whether this method must release the serving lock
     *                             before returning; callers passing {@code false}
     *                             keep ownership of the lock
     * @throws TfServingException LIFECYCLEERROR if the script cannot be started or
     *                            the wait is interrupted
     */
    private void killTfServingInstance(Project project, TfServing tfServing, boolean releaseLockAfterKill) throws TfServingException {
        String[] command = {
            "/usr/bin/sudo",
            this.settings.getHopsworksDomainDir() + "/bin/tfserving.sh",
            "kill",
            String.valueOf(tfServing.getLocalPid()),
            String.valueOf(tfServing.getLocalPort()),
            Paths.get(this.settings.getStagingDir(), "/serving/" + tfServing.getLocalDir()).toString()
        };
        logger.log(Level.INFO, Arrays.toString(command));
        try {
            new ProcessBuilder(command).start().waitFor();
            tfServing.setLocalPid(PID_STOPPED);
            tfServing.setLocalPort(-1);
            this.tfServingFacade.updateDbObject(tfServing, project);
        } catch (IOException | InterruptedException e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
            }
            throw new TfServingException(RESTCodes.TfServingErrorCode.LIFECYCLEERROR, Level.SEVERE, "tfServing id: " + tfServing.getId(), e.getMessage(), e);
        } finally {
            if (releaseLockAfterKill) {
                this.tfServingFacade.releaseLock(project, tfServing.getId());
            }
        }
    }

    /**
     * Runs the {@code tfserving.sh start} script for the serving and, on success,
     * records the pid written by the script to {@code tfserving.pid} and the port
     * the instance was given.
     *
     * <p>Two distinct ports are drawn at random from [40000, 59999); nothing here
     * checks whether they are free. The second port is the one stored on the entity
     * as its local port. NOTE(review): the role of the first port is defined by the
     * script and is not visible here.
     *
     * <p>When Hops RPC TLS is enabled, the user's certificates are materialized
     * before the script runs and removed afterwards, whether or not it succeeded.
     * The serving lock is held for the whole operation and released exactly once,
     * in the outer {@code finally}, so the serving keeps reporting STARTING until
     * the start attempt has finished.
     *
     * @param project   project the serving belongs to
     * @param users     user on whose behalf the script runs
     * @param tfServing serving to start
     * @throws TfServingException LIFECYCLEERRORINT if the certificates cannot be
     *                            materialized, the script fails or exits non-zero,
     *                            or the pid file cannot be read
     */
    private void startTfServingInstance(Project project, Users users, TfServing tfServing) throws TfServingException {
        String script = this.settings.getHopsworksDomainDir() + "/bin/tfserving.sh";

        int auxiliaryPort = ThreadLocalRandom.current().nextInt(40000, 59999);
        int localPort;
        do {
            // Two independent draws can coincide; make sure the script gets distinct ports.
            localPort = ThreadLocalRandom.current().nextInt(40000, 59999);
        } while (localPort == auxiliaryPort);

        Path servingDir = Paths.get(this.settings.getStagingDir(), "/serving/" + tfServing.getLocalDir());
        String[] command = {
            "/usr/bin/sudo",
            script,
            "start",
            tfServing.getModelName(),
            Paths.get(tfServing.getModelPath(), tfServing.getVersion().toString()).toString(),
            String.valueOf(auxiliaryPort),
            String.valueOf(localPort),
            servingDir.toString(),
            project.getName() + "__" + users.getUsername(),
            tfServing.isBatchingEnabled().booleanValue() ? "1" : "0",
            project.getName()
        };
        logger.log(Level.INFO, Arrays.toString(command));

        boolean tlsEnabled = this.settings.getHopsRpcTls();
        try {
            if (tlsEnabled) {
                try {
                    this.certificateMaterializer.materializeCertificatesLocal(users.getUsername(), project.getName());
                } catch (IOException e) {
                    throw new TfServingException(RESTCodes.TfServingErrorCode.LIFECYCLEERRORINT, Level.SEVERE, null, e.getMessage(), e);
                }
            }

            Process process = null;
            try {
                ProcessBuilder processBuilder = new ProcessBuilder(command);
                // Merge stderr into stdout.
                processBuilder.redirectErrorStream(true);
                process = processBuilder.start();
                process.waitFor();
                if (process.exitValue() != 0) {
                    tfServing.setLocalPid(PID_STOPPED);
                    this.tfServingFacade.updateDbObject(tfServing, project);
                    // NOTE(review): this exception is caught and re-wrapped at SEVERE by the
                    // catch (Exception) below, as it was before this change.
                    throw new TfServingException(RESTCodes.TfServingErrorCode.LIFECYCLEERRORINT, Level.INFO);
                }
                File pidFile = Paths.get(servingDir.toString(), "tfserving.pid").toFile();
                tfServing.setLocalPid(Integer.valueOf(Files.readFirstLine(pidFile, Charset.defaultCharset())));
                tfServing.setLocalPort(localPort);
                this.tfServingFacade.updateDbObject(tfServing, project);
            } catch (Exception e) {
                if (e instanceof InterruptedException) {
                    // Preserve the interrupt for callers further up the stack.
                    Thread.currentThread().interrupt();
                }
                if (process != null) {
                    // Do not leave a half-started script behind (e.g. when the wait was interrupted).
                    process.destroyForcibly();
                }
                tfServing.setLocalPid(PID_STOPPED);
                this.tfServingFacade.updateDbObject(tfServing, project);
                throw new TfServingException(RESTCodes.TfServingErrorCode.LIFECYCLEERRORINT, Level.SEVERE, null, e.getMessage(), e);
            } finally {
                if (tlsEnabled) {
                    this.certificateMaterializer.removeCertificatesLocal(users.getUsername(), project.getName());
                }
            }
        } finally {
            this.tfServingFacade.releaseLock(project, tfServing.getId());
        }
    }

    /**
     * Restarts a serving: kills the process recorded on the current entity, then
     * starts a new one from the updated entity.
     *
     * <p>The kill is asked not to release the lock; the subsequent start releases it.
     * NOTE(review): if the kill throws, the start never runs and nothing in this
     * method releases the lock.
     *
     * @param project        project the serving belongs to
     * @param users          user on whose behalf the scripts run
     * @param currentServing entity describing the running instance to kill
     * @param updatedServing entity describing the instance to start
     * @throws TfServingException if either the kill or the start fails
     */
    private void restartTfServingInstance(Project project, Users users, TfServing currentServing, TfServing updatedServing) throws TfServingException {
        final boolean keepLock = false;
        killTfServingInstance(project, currentServing, keepLock);
        startTfServingInstance(project, users, updatedServing);
    }
}
