package io.hops.hopsworks.common.tensorflow;

import com.google.common.base.Strings;
import io.hops.hopsworks.common.dao.kafka.KafkaConst;
import io.hops.hopsworks.common.dao.project.Project;
import io.hops.hopsworks.common.dao.python.CondaCommandFacade;
import io.hops.hopsworks.common.dao.python.CondaCommands;
import io.hops.hopsworks.common.dao.tensorflow.TfLibMapping;
import io.hops.hopsworks.common.dao.tensorflow.TfLibMappingFacade;
import io.hops.hopsworks.common.python.environment.EnvironmentController;
import io.hops.hopsworks.common.util.Settings;
import java.io.File;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.ejb.EJB;
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;

/**
 * Resolves the {@link TfLibMapping} matching the TensorFlow version a project uses and turns it
 * into a native-library search path: directory entries each terminated by {@link File#pathSeparator}.
 * NOTE(review): the method names suggest the result is used as {@code LD_LIBRARY_PATH} — confirm
 * with callers.
 */
@TransactionAttribute(TransactionAttributeType.NEVER)
@Stateless
public class TfLibMappingUtil {
    private static final String LIB_PATH = "/usr/local";
    private static final String CUDA_BASE_PATH = LIB_PATH + "/cuda-";
    private static final String CUDNN_BASE_PATH = LIB_PATH + "/cudnn-";
    private static final String NCCL_BASE_PATH = LIB_PATH + "/nccl";
    private static final String ROCM_RCCL_PATH = "/site-packages/tensorflow/include/external/local_config_rocm/rocm/rocm/lib";

    /** Package names that provide TensorFlow, in the order they are checked in an environment yml. */
    private static final String[] TF_PACKAGE_NAMES = {"tensorflow", "tensorflow-gpu", "tensorflow-rocm"};

    /**
     * One precompiled pattern per entry of {@link #TF_PACKAGE_NAMES}, matching {@code <name>==<version>}.
     * Group 1 captures up to three dot-separated numeric components (e.g. {@code 2.1.0}); dots are
     * escaped so that arbitrary characters (spaces, letters) are not swallowed into the version.
     */
    private static final Pattern[] TF_YML_VERSION_PATTERNS = compileVersionPatterns(TF_PACKAGE_NAMES);

    @EJB
    private TfLibMappingFacade tfLibMappingFacade;

    @EJB
    private Settings settings;

    @EJB
    private EnvironmentController environmentController;

    /** Builds one {@code <name>==<version>} pattern per package name. */
    private static Pattern[] compileVersionPatterns(String[] packageNames) {
        Pattern[] patterns = new Pattern[packageNames.length];
        for (int i = 0; i < packageNames.length; i++) {
            patterns[i] = Pattern.compile(Pattern.quote(packageNames[i]) + "==(\\d+(?:\\.\\d+){0,2})");
        }
        return patterns;
    }

    /**
     * Builds the library path for the given mapping: cuDNN, CUDA (+ CUPTI) and NCCL directories
     * for each version that is set, followed by the ROCm/RCCL directory inside the project's
     * Anaconda environment (always appended).
     *
     * @param tfLibMapping mapping whose library versions select the directories
     * @param project project whose Anaconda dir and Python version locate the ROCm directory
     * @return path entries, each terminated by {@link File#pathSeparator}
     */
    private String buildTfLdLibraryPath(TfLibMapping tfLibMapping, Project project) {
        StringBuilder sb = new StringBuilder();
        String cudnnVersion = tfLibMapping.getCudnnVersion();
        if (!Strings.isNullOrEmpty(cudnnVersion)) {
            appendPathEntry(sb, CUDNN_BASE_PATH + cudnnVersion + "/lib64");
        }
        String cudaVersion = tfLibMapping.getCudaVersion();
        if (!Strings.isNullOrEmpty(cudaVersion)) {
            appendPathEntry(sb, CUDA_BASE_PATH + cudaVersion + "/lib64");
            appendPathEntry(sb, CUDA_BASE_PATH + cudaVersion + "/extras/CUPTI/lib64");
        }
        String ncclVersion = tfLibMapping.getNcclVersion();
        if (!Strings.isNullOrEmpty(ncclVersion)) {
            appendPathEntry(sb, NCCL_BASE_PATH + ncclVersion + "/lib");
        }
        appendPathEntry(sb, this.settings.getAnacondaProjectDir(project) + "/lib/python"
            + project.getPythonVersion() + ROCM_RCCL_PATH);
        return sb.toString();
    }

    /** Appends {@code dir} followed by the platform path separator. */
    private static void appendPathEntry(StringBuilder sb, String dir) {
        sb.append(dir).append(File.pathSeparator);
    }

    /**
     * Returns the library path for the project's TensorFlow version.
     *
     * @param project the project to inspect
     * @return the path built by {@link #buildTfLdLibraryPath}, or
     *     {@code KafkaConst.KAFKA_ENDPOINT_IDENTIFICATION_ALGORITHM} when no mapping is found
     */
    public String getTfLdLibraryPath(Project project) {
        TfLibMapping tfLibMapping = findTfMappingForProject(project);
        if (tfLibMapping == null) {
            // NOTE(review): a Kafka constant used as the "no path" value looks like a
            // decompiler-substituted literal (presumably ""); confirm and replace with a literal.
            return KafkaConst.KAFKA_ENDPOINT_IDENTIFICATION_ALGORITHM;
        }
        return buildTfLdLibraryPath(tfLibMapping, project);
    }

    /**
     * Finds the {@link TfLibMapping} for the TensorFlow version the project uses.
     * <ul>
     *   <li>Project without a conda env: the platform default version from {@link Settings}.</li>
     *   <li>No ongoing env operation: the version of the first installed tensorflow,
     *       tensorflow-gpu or tensorflow-rocm dependency.</li>
     *   <li>Ongoing CREATE: the platform default version.</li>
     *   <li>Ongoing YML import: the version pinned in the yml (tensorflow, then tensorflow-gpu,
     *       then tensorflow-rocm).</li>
     * </ul>
     *
     * @param project the project to inspect
     * @return the mapping, or {@code null} if none can be determined
     */
    public TfLibMapping findTfMappingForProject(Project project) {
        if (!project.getCondaEnv().booleanValue()) {
            return this.tfLibMappingFacade.findByTfVersion(this.settings.getTensorflowVersion());
        }
        CondaCommands ongoingEnvCreation = this.environmentController.getOngoingEnvCreation(project);
        if (ongoingEnvCreation == null) {
            return findMappingFromInstalledDeps(project);
        }
        if (ongoingEnvCreation.getOp() == CondaCommandFacade.CondaOp.CREATE) {
            return this.tfLibMappingFacade.findByTfVersion(this.settings.getTensorflowVersion());
        }
        if (ongoingEnvCreation.getOp() != CondaCommandFacade.CondaOp.YML) {
            return null;
        }
        String tfVersion = findTfVersionInYml(ongoingEnvCreation.getEnvironmentYml());
        return tfVersion == null ? null : this.tfLibMappingFacade.findByTfVersion(tfVersion);
    }

    /** Looks up the mapping for the first TensorFlow package among the project's dependencies. */
    private TfLibMapping findMappingFromInstalledDeps(Project project) {
        return project.getPythonDepCollection().stream()
            .filter(dep -> isTfPackage(dep.getDependency()))
            .findFirst()
            .map(dep -> this.tfLibMappingFacade.findByTfVersion(dep.getVersion()))
            .orElse(null);
    }

    /** Null-safe check whether {@code name} is one of {@link #TF_PACKAGE_NAMES}. */
    private static boolean isTfPackage(String name) {
        for (String tfName : TF_PACKAGE_NAMES) {
            if (tfName.equals(name)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Extracts the pinned TensorFlow version from an environment yml, trying the package names in
     * {@link #TF_PACKAGE_NAMES} order.
     *
     * @param environmentYml yml text; may be {@code null}
     * @return the captured version, or {@code null} if the yml is null or pins no TensorFlow package
     */
    private static String findTfVersionInYml(String environmentYml) {
        if (environmentYml == null) {
            return null;
        }
        for (Pattern pattern : TF_YML_VERSION_PATTERNS) {
            Matcher matcher = pattern.matcher(environmentYml);
            if (matcher.find()) {
                // Previously the rocm branch read the (failed) tensorflow-gpu matcher and always
                // threw IllegalStateException; each pattern now reads its own match.
                return matcher.group(1);
            }
        }
        return null;
    }
}
