/*
 * Decompiled with CFR 0.152.
 */
package io.hops.hopsworks.common.tensorflow;

import com.google.common.base.Strings;
import io.hops.hopsworks.common.dao.project.Project;
import io.hops.hopsworks.common.dao.python.CondaCommandFacade;
import io.hops.hopsworks.common.dao.python.CondaCommands;
import io.hops.hopsworks.common.dao.tensorflow.TfLibMapping;
import io.hops.hopsworks.common.dao.tensorflow.TfLibMappingFacade;
import io.hops.hopsworks.common.python.environment.EnvironmentController;
import io.hops.hopsworks.common.util.Settings;
import java.io.File;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.ejb.EJB;
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;

@Stateless
@TransactionAttribute(value=TransactionAttributeType.NEVER)
public class TfLibMappingUtil {
    private static final String LIB_PATH = "/usr/local";
    private static final String CUDA_BASE_PATH = "/usr/local/cuda-";
    private static final String CUDNN_BASE_PATH = "/usr/local/cudnn-";
    private static final String NCCL_BASE_PATH = "/usr/local/nccl";
    private static final String ROCM_RCCL_PATH = "/site-packages/tensorflow/include/external/local_config_rocm/rocm/rocm/lib";

    /**
     * Patterns that extract a pinned TensorFlow version from an environment YAML, in lookup priority
     * order: CPU package first, then GPU, then ROCm. Group 1 captures the {@code X.Y.Z} version.
     * Dots are escaped so only a real dotted version matches. The look-behind keeps packages that
     * merely end in the same name (e.g. {@code mesh-tensorflow==...}) from being treated as TensorFlow.
     */
    private static final Pattern[] TF_YML_VERSION_PATTERNS = {
        Pattern.compile("(?<![\\w.-])tensorflow==(\\d+\\.\\d+\\.\\d+)"),
        Pattern.compile("(?<![\\w.-])tensorflow-gpu==(\\d+\\.\\d+\\.\\d+)"),
        Pattern.compile("(?<![\\w.-])tensorflow-rocm==(\\d+\\.\\d+\\.\\d+)")
    };

    @EJB
    private TfLibMappingFacade tfLibMappingFacade;
    @EJB
    private Settings settings;
    @EJB
    private EnvironmentController environmentController;

    /**
     * Builds the library search path for a TensorFlow/library mapping.
     *
     * <p>Entries are emitted in this order, and only when the mapping has the relevant version set:
     * <ul>
     *   <li>the cuDNN {@code lib64} directory;</li>
     *   <li>the CUDA {@code lib64} and {@code extras/CUPTI/lib64} directories;</li>
     *   <li>the NCCL {@code lib} directory.</li>
     * </ul>
     * The project's Anaconda ROCm/RCCL library directory is always appended last. Every entry,
     * including the last, is followed by {@link File#pathSeparator}.
     *
     * @param tfLibMapping mapping holding the CUDA/cuDNN/NCCL versions
     * @param project      project whose Anaconda directory and Python version locate the ROCm libs
     * @return the separator-terminated path string
     */
    private String buildTfLdLibraryPath(TfLibMapping tfLibMapping, Project project) {
        String sep = File.pathSeparator;
        StringBuilder ldPath = new StringBuilder();

        String cudnnVersion = tfLibMapping.getCudnnVersion();
        if (!Strings.isNullOrEmpty(cudnnVersion)) {
            ldPath.append(CUDNN_BASE_PATH).append(cudnnVersion).append("/lib64").append(sep);
        }

        String cudaVersion = tfLibMapping.getCudaVersion();
        if (!Strings.isNullOrEmpty(cudaVersion)) {
            ldPath.append(CUDA_BASE_PATH).append(cudaVersion).append("/lib64").append(sep);
            ldPath.append(CUDA_BASE_PATH).append(cudaVersion).append("/extras/CUPTI/lib64").append(sep);
        }

        String ncclVersion = tfLibMapping.getNcclVersion();
        if (!Strings.isNullOrEmpty(ncclVersion)) {
            ldPath.append(NCCL_BASE_PATH).append(ncclVersion).append("/lib").append(sep);
        }

        ldPath.append(this.settings.getAnacondaProjectDir(project))
            .append("/lib/python")
            .append(project.getPythonVersion())
            .append(ROCM_RCCL_PATH)
            .append(sep);
        return ldPath.toString();
    }

    /**
     * Returns the library search path for the TensorFlow version used by {@code project}.
     *
     * @param project the project to inspect
     * @return the path string, or {@code ""} when no TensorFlow mapping could be resolved
     */
    public String getTfLdLibraryPath(Project project) {
        TfLibMapping tfLibMapping = this.findTfMappingForProject(project);
        if (tfLibMapping == null) {
            return "";
        }
        return this.buildTfLdLibraryPath(tfLibMapping, project);
    }

    /**
     * Resolves the {@link TfLibMapping} for the TensorFlow version a project uses (or is about to use).
     *
     * <p>Resolution rules:
     * <ul>
     *   <li>{@code project.getCondaEnv()} is false: the mapping for {@code settings.getTensorflowVersion()}.</li>
     *   <li>No ongoing environment operation: the mapping for the first installed {@code tensorflow},
     *       {@code tensorflow-gpu} or {@code tensorflow-rocm} dependency.</li>
     *   <li>Ongoing {@code CREATE} operation: the mapping for {@code settings.getTensorflowVersion()}.</li>
     *   <li>Ongoing {@code YML} operation: the mapping for the version pinned in the environment YAML.
     *       {@code tensorflow} takes priority, then {@code tensorflow-gpu}, then {@code tensorflow-rocm}.</li>
     * </ul>
     *
     * @param project the project to inspect
     * @return the matching mapping, or {@code null} if none could be determined
     */
    public TfLibMapping findTfMappingForProject(Project project) {
        if (!project.getCondaEnv().booleanValue()) {
            return this.tfLibMappingFacade.findByTfVersion(this.settings.getTensorflowVersion());
        }
        CondaCommands command = this.environmentController.getOngoingEnvCreation(project);
        if (command == null) {
            return this.findTfMappingFromInstalledDeps(project);
        }
        // CondaOp is an enum, so identity comparison is the idiomatic (and null-safe) check.
        if (command.getOp() == CondaCommandFacade.CondaOp.CREATE) {
            return this.tfLibMappingFacade.findByTfVersion(this.settings.getTensorflowVersion());
        }
        if (command.getOp() == CondaCommandFacade.CondaOp.YML) {
            return this.findTfMappingFromYml(command.getEnvironmentYml());
        }
        return null;
    }

    /** Looks up the mapping for the first TensorFlow package found among the project's Python deps. */
    private TfLibMapping findTfMappingFromInstalledDeps(Project project) {
        return project.getPythonDepCollection().stream()
            .filter(dep -> isTensorflowPackage(dep.getDependency()))
            .findAny()
            .map(tfDep -> this.tfLibMappingFacade.findByTfVersion(tfDep.getVersion()))
            .orElse(null);
    }

    /** Returns true if {@code name} is one of the TensorFlow distribution package names (null-safe). */
    private static boolean isTensorflowPackage(String name) {
        return "tensorflow".equals(name)
            || "tensorflow-gpu".equals(name)
            || "tensorflow-rocm".equals(name);
    }

    /**
     * Looks up the mapping for the first TensorFlow version pinned in an environment YAML.
     * Patterns are tried in {@link #TF_YML_VERSION_PATTERNS} order. The first pattern that matches
     * decides the result, even if the facade has no mapping for that version.
     *
     * @param envYml the environment YAML text; may be null or empty
     * @return the mapping for the pinned version, or {@code null} if nothing is pinned
     */
    private TfLibMapping findTfMappingFromYml(String envYml) {
        if (Strings.isNullOrEmpty(envYml)) {
            return null;
        }
        for (Pattern pattern : TF_YML_VERSION_PATTERNS) {
            // Use this pattern's own matcher: the previous code read group(0) from the GPU matcher in
            // the ROCm branch, which always threw IllegalStateException because that matcher had failed.
            Matcher matcher = pattern.matcher(envYml);
            if (matcher.find()) {
                return this.tfLibMappingFacade.findByTfVersion(matcher.group(1));
            }
        }
        return null;
    }
}

