package io.hops.devices;

import com.google.common.annotations.VisibleForTesting;
import io.hops.GPUManagementLibrary;
import io.hops.exceptions.GPUManagementLibraryException;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.server.nodemanager.util.NodeManagerHardwareUtils;

/* loaded from: input_file:io/hops/devices/GPUAllocator.class */
/**
 * Keeps track of the GPU devices on this node and which of them are assigned to containers.
 *
 * <p>Devices are discovered through a {@link GPUManagementLibrary} during {@link #initialize}.
 * {@link #allocate}, {@link #release} and {@link #recoverAllocation} are synchronized on this
 * instance; the getters return the live internal sets without synchronization or copying.
 */
public class GPUAllocator {
    private GPUManagementLibrary gpuManagementLibrary;
    private static final String GPU_MANAGEMENT_LIBRARY_CLASSNAME = "io.hops.management.nvidia.NvidiaManagementLibrary";
    private static final int NVIDIA_GPU_MAJOR_DEVICE_NUMBER = 195;
    static final Log LOG = LogFactory.getLog(GPUAllocator.class);
    // Must stay below LOG: the private constructor logs through LOG if the library fails to load.
    private static final GPUAllocator gpuAllocator = new GPUAllocator();
    // Matches "<token> <major>:<minor> <token>" entries such as "c 195:0 rwm" (presumably
    // cgroup device-whitelist lines — confirm against callers). Only single, non-nested
    // quantifiers are used: a nested form like ([^\\s]+)+ backtracks exponentially on long tokens.
    private static final Pattern DEVICES_LIST_FORMAT = Pattern.compile("(\\S+)\\s(\\d+):(\\d+)\\s(\\S+)");
    private boolean initialized = false;
    /** GPUs that were discovered and are not currently assigned to any container. */
    private final HashSet<Device> configuredAvailableGPUs = new HashSet<>();
    /** Copy of every GPU discovered during {@link #initialize}, allocated or not. */
    private HashSet<Device> totalGPUs = new HashSet<>();
    /** Container id to the GPUs currently assigned to that container. */
    private final HashMap<String, HashSet<Device>> containerGPUAllocationMapping = new HashMap<>();
    /** Devices reported by {@code queryMandatoryDevices()}; excluded when recovering from a whitelist. */
    private final HashSet<Device> mandatoryDrivers = new HashSet<>();

    /**
     * Returns the process-wide allocator. Its management library is loaded eagerly, but no
     * devices are known until {@link #initialize} has been called.
     */
    public static GPUAllocator getInstance() {
        return gpuAllocator;
    }

    private GPUAllocator() {
        try {
            this.gpuManagementLibrary = io.hops.GPUManagementLibraryLoader.load(GPU_MANAGEMENT_LIBRARY_CLASSNAME);
        } catch (GPUManagementLibraryException | UnsatisfiedLinkError e) {
            // Leave the library null; initialize() reports failure instead of crashing.
            LOG.error("Could not load GPU management library. Is this NodeManager supposed to offer its GPUs as a resource? If yes, check installation and make sure hopsnvml-1.0 is present in java.library.path", e);
        }
    }

    /**
     * Creates an allocator around the given library and initializes it immediately.
     *
     * @param gPUManagementLibrary library used for device discovery
     * @param configuration NodeManager configuration passed to {@link #initialize}
     */
    @VisibleForTesting
    public GPUAllocator(GPUManagementLibrary gPUManagementLibrary, Configuration configuration) {
        this.gpuManagementLibrary = gPUManagementLibrary;
        initialize(configuration);
    }

    /**
     * Initializes the management library and discovers driver and GPU devices. Once the library
     * has reported successful initialization, further calls do nothing and return {@code true}.
     *
     * <p>Discovery failures are logged rather than thrown, so the return value reflects only the
     * library's own initialization result.
     *
     * @param configuration NodeManager configuration; the number of GPUs to expose is read via
     *     {@link NodeManagerHardwareUtils#getNodeGPUs}
     * @return the library's initialization result, or {@code false} if no library is loaded
     */
    public boolean initialize(Configuration configuration) {
        if (this.initialized) {
            return true;
        }
        if (this.gpuManagementLibrary == null) {
            // e.g. loading failed in the private constructor; report failure instead of an NPE.
            LOG.error("Could not initialize GPUAllocator: GPU management library is not loaded");
            return false;
        }
        this.initialized = this.gpuManagementLibrary.initialize();
        int nodeGPUs = NodeManagerHardwareUtils.getNodeGPUs(configuration);
        try {
            initMandatoryDrivers();
            initConfiguredGPUs(nodeGPUs);
            initTotalGPUs();
        } catch (IOException e) {
            LOG.error("Could not initialize GPUAllocator", e);
        }
        return this.initialized;
    }

    /** Returns whether the management library reported successful initialization. */
    public boolean isInitialized() {
        return this.initialized;
    }

    /**
     * Shuts the management library down if it was initialized.
     *
     * @return the library's shutdown result, or {@code false} if it was never initialized
     */
    public boolean shutDown() {
        if (!this.initialized) {
            return false;
        }
        return this.gpuManagementLibrary.shutDown();
    }

    /** Records the driver devices reported by {@code queryMandatoryDevices()}. */
    private void initMandatoryDrivers() throws IOException {
        String[] tokens = splitDeviceList(this.gpuManagementLibrary.queryMandatoryDevices(),
                "Could not discover device numbers for GPU drivers, check driver installation and make sure libnvidia-ml.so.1 is present on LD_LIBRARY_PATH, or disable GPUs in the configuration");
        for (String token : tokens) {
            Device device = parseDevice(token);
            if (device != null) {
                this.mandatoryDrivers.add(device);
                LOG.info("Found mandatory GPU driver " + device.toString());
            }
        }
    }

    /**
     * Records the GPUs reported by {@code queryAvailableDevices(nodeGPUs)} as available.
     *
     * @param nodeGPUs number of GPUs configured for this node, forwarded to the library
     */
    private void initConfiguredGPUs(int nodeGPUs) throws IOException {
        String[] tokens = splitDeviceList(this.gpuManagementLibrary.queryAvailableDevices(nodeGPUs),
                "Could not discover GPU device numbers, either you enabled GPUs but set NM_GPUS to 0, or there is a problem with the installation, check that libnvidia-ml.so.1 is present on LD_LIBRARY_PATH");
        for (String token : tokens) {
            Device device = parseDevice(token);
            if (device != null) {
                this.configuredAvailableGPUs.add(device);
                LOG.info("Found available GPU device " + device.toString() + " for scheduling");
            }
        }
    }

    /**
     * Splits a whitespace-separated device list into non-empty tokens.
     *
     * @param deviceList raw library output; may be {@code null}
     * @param emptyMessage message of the exception thrown when the list is null or blank
     * @throws IOException if {@code deviceList} is null or contains only whitespace
     */
    private static String[] splitDeviceList(String deviceList, String emptyMessage) throws IOException {
        if (deviceList == null || deviceList.trim().isEmpty()) {
            throw new IOException(emptyMessage);
        }
        // trim + \s+ avoids the empty tokens that split(" ") yields on repeated or leading spaces.
        return deviceList.trim().split("\\s+");
    }

    /**
     * Parses one {@code major:minor} token.
     *
     * @return the device, or {@code null} (after logging) if the token is malformed
     */
    private static Device parseDevice(String token) {
        String[] parts = token.split(":");
        if (parts.length == 2) {
            try {
                return new Device(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]));
            } catch (NumberFormatException e) {
                // Fall through to the shared error log below.
            }
        }
        LOG.error("Unexpected format for major:minor device numbers: " + token);
        return null;
    }

    private void initTotalGPUs() {
        this.totalGPUs = new HashSet<>(this.configuredAvailableGPUs);
    }

    /** Returns the live set of mandatory driver devices (not a copy). */
    public HashSet<Device> getMandatoryDrivers() {
        return this.mandatoryDrivers;
    }

    /** Returns the live set of currently unallocated GPUs (not a copy). */
    public HashSet<Device> getConfiguredAvailableGPUs() {
        return this.configuredAvailableGPUs;
    }

    /** Returns the set of all GPUs discovered during initialization (not a copy). */
    public HashSet<Device> getTotalGPUs() {
        return this.totalGPUs;
    }

    /** Returns the union of all GPUs currently mapped to containers. */
    private HashSet<Device> getAllocatedGPUs() {
        HashSet<Device> allocated = new HashSet<>();
        for (HashSet<Device> containerGPUs : this.containerGPUAllocationMapping.values()) {
            allocated.addAll(containerGPUs);
        }
        return allocated;
    }

    /**
     * Assigns {@code gpuCount} unallocated GPUs to a container, lowest minor numbers first.
     *
     * <p>NOTE(review): a second call for the same container id overwrites the earlier mapping,
     * so the GPUs from the first call are never returned to the available pool.
     *
     * @param containerId container identifier used as the allocation key
     * @param gpuCount number of GPUs requested; values {@code <= 0} allocate nothing
     * @return every discovered GPU that was NOT assigned to this container (the GPUs to deny)
     * @throws IOException if fewer than {@code gpuCount} GPUs are currently available
     */
    public synchronized HashSet<Device> allocate(String containerId, int gpuCount) throws IOException {
        HashSet<Device> denied = new HashSet<>(getTotalGPUs());
        if (this.configuredAvailableGPUs.size() < gpuCount) {
            throw new IOException("Container " + containerId + " requested " + gpuCount + " GPUs when only " + this.configuredAvailableGPUs.size() + " available");
        }
        LOG.info("Trying to allocate " + gpuCount + " GPUs");
        if (gpuCount > 0) {
            LOG.info("Currently unallocated GPUs: " + this.configuredAvailableGPUs.toString());
            LOG.info("Currently allocated GPUs: " + getAllocatedGPUs());
            HashSet<Device> selected = selectGPUsToAllocate(gpuCount);
            this.configuredAvailableGPUs.removeAll(selected);
            LOG.info("GPUs to allocate for " + containerId + " = " + selected);
            this.containerGPUAllocationMapping.put(containerId, selected);
            denied.removeAll(selected);
        }
        LOG.info("GPUs to deny for " + containerId + " = " + denied);
        return denied;
    }

    /**
     * Picks up to {@code gpuCount} available devices, ordered by ascending minor number.
     * The actual available {@link Device} instances are returned (not rebuilt from the minor
     * number), so the caller's {@code removeAll} always removes exactly what was selected.
     */
    private synchronized HashSet<Device> selectGPUsToAllocate(int gpuCount) {
        TreeMap<Integer, Device> byMinor = new TreeMap<>();
        for (Device device : this.configuredAvailableGPUs) {
            byMinor.put(device.getMinorDeviceNumber(), device);
        }
        HashSet<Device> selected = new HashSet<>();
        for (Device device : byMinor.values()) {
            if (selected.size() >= gpuCount) {
                break;
            }
            selected.add(device);
        }
        return selected;
    }

    /**
     * Returns a container's GPUs to the available pool. Unknown container ids are ignored.
     *
     * @param containerId container whose allocation should be released
     */
    public synchronized void release(String containerId) {
        HashSet<Device> released = this.containerGPUAllocationMapping.remove(containerId);
        if (released == null) {
            return;
        }
        this.configuredAvailableGPUs.addAll(released);
        LOG.info("Releasing GPUs " + released + " for container " + containerId);
    }

    /**
     * Re-establishes a container's allocation from its device whitelist, e.g. after a restart.
     * Nothing is recorded if the whitelist contains no non-driver GPU entries.
     *
     * @param containerId container to recover
     * @param devicesWhitelist text containing {@code "<token> <major>:<minor> <token>"} entries
     */
    public synchronized void recoverAllocation(String containerId, String devicesWhitelist) {
        HashSet<Device> recovered = findGPUsInWhitelist(devicesWhitelist);
        if (recovered.isEmpty()) {
            return;
        }
        this.configuredAvailableGPUs.removeAll(recovered);
        this.containerGPUAllocationMapping.put(containerId, recovered);
        LOG.info("Recovering GPUs " + recovered + " for container " + containerId);
        LOG.info("Available GPUs after container " + containerId + " recovery " + this.configuredAvailableGPUs);
        LOG.info("So far " + this.containerGPUAllocationMapping.size() + " recovered containers");
    }

    /**
     * Extracts GPU devices (major number {@value #NVIDIA_GPU_MAJOR_DEVICE_NUMBER}) from a
     * whitelist, skipping mandatory driver devices and numbers that do not fit in an int.
     *
     * @param whitelist whitelist text; {@code null} yields an empty set
     * @return the GPU devices found
     */
    private HashSet<Device> findGPUsInWhitelist(String whitelist) {
        HashSet<Device> gpus = new HashSet<>();
        if (whitelist == null) {
            return gpus;
        }
        Matcher matcher = DEVICES_LIST_FORMAT.matcher(whitelist);
        while (matcher.find()) {
            int major;
            int minor;
            try {
                major = Integer.parseInt(matcher.group(2));
                minor = Integer.parseInt(matcher.group(3));
            } catch (NumberFormatException e) {
                // The pattern guarantees ASCII digits, so this only happens on int overflow.
                LOG.error("Unexpected format for major:minor device numbers: " + matcher.group(2) + ":" + matcher.group(3));
                continue;
            }
            if (major != NVIDIA_GPU_MAJOR_DEVICE_NUMBER) {
                continue;
            }
            Device device = new Device(major, minor);
            if (!getMandatoryDrivers().contains(device)) {
                gpus.add(device);
            }
        }
        return gpus;
    }
}
