/*
 * Decompiled with CFR 0.152.
 */
package io.hops.devices;

import com.google.common.annotations.VisibleForTesting;
import io.hops.GPUManagementLibrary;
import io.hops.GPUManagementLibraryLoader;
import io.hops.devices.Device;
import io.hops.exceptions.GPUManagementLibraryException;
import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.server.nodemanager.util.NodeManagerHardwareUtils;

/**
 * Tracks the GPU devices of this node and assigns them to containers.
 *
 * <p>Devices are identified by major:minor device numbers as reported by the
 * {@link GPUManagementLibrary}. The allocator keeps:
 * <ul>
 *   <li>{@code mandatoryDrivers} - driver devices reported by the library; they are not part of
 *       the allocatable pool and are ignored when recovering allocations;</li>
 *   <li>{@code totalGPUs} - every GPU discovered during initialization;</li>
 *   <li>{@code configuredAvailableGPUs} - the GPUs not currently assigned to a container.</li>
 * </ul>
 * {@link #allocate}, {@link #release} and {@link #recoverAllocation} are {@code synchronized} on
 * this instance.
 */
public class GPUAllocator {
    static final Log LOG = LogFactory.getLog(GPUAllocator.class);
    private static final String GPU_MANAGEMENT_LIBRARY_CLASSNAME = "io.hops.management.nvidia.NvidiaManagementLibrary";
    /** Major device number that identifies a GPU when recovering allocations from a whitelist. */
    private static final int NVIDIA_GPU_MAJOR_DEVICE_NUMBER = 195;
    /**
     * Matches one {@code <token> <major>:<minor> <token>} entry of a devices whitelist
     * (e.g. {@code "c 195:0 rwm"}). Groups 2 and 3 are the major and minor numbers. The pattern
     * deliberately avoids nested quantifiers so long whitespace-free input cannot cause
     * catastrophic backtracking.
     */
    private static final Pattern DEVICES_LIST_FORMAT = Pattern.compile("(\\S+)\\s(\\d+):(\\d+)\\s(\\S+)");
    // Created last among the statics: the private constructor logs via LOG on failure.
    private static final GPUAllocator gpuAllocator = new GPUAllocator();

    /** GPUs not currently allocated to any container. */
    private final HashSet<Device> configuredAvailableGPUs = new HashSet<>();
    /** Every GPU discovered during initialization, allocated or not. */
    private HashSet<Device> totalGPUs = new HashSet<>();
    /** Container name -> GPUs currently allocated to that container. */
    private final HashMap<String, HashSet<Device>> containerGPUAllocationMapping = new HashMap<>();
    /** Driver devices reported by the management library; never allocated or recovered. */
    private final HashSet<Device> mandatoryDrivers = new HashSet<>();
    /** May be {@code null} if the native library could not be loaded by the private constructor. */
    private GPUManagementLibrary gpuManagementLibrary;
    private boolean initialized = false;

    /** Returns the process-wide allocator instance. */
    public static GPUAllocator getInstance() {
        return gpuAllocator;
    }

    /**
     * Loads the NVIDIA management library. Load failures are logged and leave the library
     * {@code null}; {@link #initialize} then reports failure instead of throwing.
     */
    private GPUAllocator() {
        try {
            this.gpuManagementLibrary = GPUManagementLibraryLoader.load(GPU_MANAGEMENT_LIBRARY_CLASSNAME);
        } catch (GPUManagementLibraryException | UnsatisfiedLinkError e) {
            LOG.error("Could not load GPU management library. Is this NodeManager supposed to offer its GPUs as a resource? If yes, check installation and make sure hopsnvml-1.0 is present in java.library.path", e);
        }
    }

    /**
     * Creates an allocator backed by the given library and initializes it immediately.
     *
     * @param gpuManagementLibrary library used to discover devices
     * @param conf configuration from which the number of node GPUs is read
     */
    @VisibleForTesting
    public GPUAllocator(GPUManagementLibrary gpuManagementLibrary, Configuration conf) {
        this.gpuManagementLibrary = gpuManagementLibrary;
        this.initialize(conf);
    }

    /**
     * Initializes the management library and discovers driver and GPU devices. Does nothing if
     * already initialized.
     *
     * <p>NOTE(review): failures while discovering devices are logged but do not reset the
     * initialized flag (unchanged from the original behavior) - confirm this is intended.
     *
     * @param conf configuration from which the number of node GPUs is read
     * @return whether the management library is initialized
     */
    public boolean initialize(Configuration conf) {
        if (this.initialized) {
            return true;
        }
        if (this.gpuManagementLibrary == null) {
            LOG.error("Could not initialize GPUAllocator: GPU management library is not loaded");
            return false;
        }
        this.initialized = this.gpuManagementLibrary.initialize();
        if (!this.initialized) {
            // No point querying devices through a library that failed to start.
            LOG.error("Could not initialize GPUAllocator: GPU management library failed to initialize");
            return false;
        }
        int numGPUs = NodeManagerHardwareUtils.getNodeGPUs(conf);
        try {
            this.initMandatoryDrivers();
            this.initConfiguredGPUs(numGPUs);
            this.initTotalGPUs();
        } catch (IOException ioe) {
            LOG.error("Could not initialize GPUAllocator", ioe);
        }
        return this.initialized;
    }

    /** Returns whether the management library was successfully initialized. */
    public boolean isInitialized() {
        return this.initialized;
    }

    /**
     * Shuts down the management library if it was initialized.
     *
     * @return the library's shutdown result, or {@code false} if it was never initialized
     */
    public boolean shutDown() {
        if (this.initialized) {
            return this.gpuManagementLibrary.shutDown();
        }
        return false;
    }

    /** Populates {@code mandatoryDrivers} from the library; throws if none are reported. */
    private void initMandatoryDrivers() throws IOException {
        String mandatoryDeviceIds = this.gpuManagementLibrary.queryMandatoryDevices();
        if (isBlank(mandatoryDeviceIds)) {
            throw new IOException("Could not discover device numbers for GPU drivers, check driver installation and make sure libnvidia-ml.so.1 is present on LD_LIBRARY_PATH, or disable GPUs in the configuration");
        }
        for (String deviceId : splitDeviceIds(mandatoryDeviceIds)) {
            Device mandatoryDevice = parseDevice(deviceId);
            if (mandatoryDevice != null) {
                this.mandatoryDrivers.add(mandatoryDevice);
                LOG.info("Found mandatory GPU driver " + mandatoryDevice);
            }
        }
    }

    /** Populates {@code configuredAvailableGPUs} with up to {@code configuredGPUs} devices. */
    private void initConfiguredGPUs(int configuredGPUs) throws IOException {
        String configuredGPUDeviceIds = this.gpuManagementLibrary.queryAvailableDevices(configuredGPUs);
        if (isBlank(configuredGPUDeviceIds)) {
            throw new IOException("Could not discover GPU device numbers, either you enabled GPUs but set NM_GPUS to 0, or there is a problem with the installation, check that libnvidia-ml.so.1 is present on LD_LIBRARY_PATH");
        }
        for (String deviceId : splitDeviceIds(configuredGPUDeviceIds)) {
            Device gpu = parseDevice(deviceId);
            if (gpu != null) {
                this.configuredAvailableGPUs.add(gpu);
                LOG.info("Found available GPU device " + gpu + " for scheduling");
            }
        }
    }

    /** Snapshots the discovered GPUs as the node's total GPU set. */
    private void initTotalGPUs() {
        this.totalGPUs = new HashSet<>(this.configuredAvailableGPUs);
    }

    /** Returns {@code true} for {@code null}, empty, or whitespace-only strings. */
    private static boolean isBlank(String s) {
        return s == null || s.trim().isEmpty();
    }

    /**
     * Splits a whitespace-separated list of {@code major:minor} ids. Leading, trailing and
     * repeated whitespace produce no empty tokens. Callers must pass a non-blank string.
     */
    private static String[] splitDeviceIds(String deviceIds) {
        return deviceIds.trim().split("\\s+");
    }

    /**
     * Parses one {@code major:minor} token.
     *
     * @return the device, or {@code null} (after logging an error) if the token is malformed
     */
    private static Device parseDevice(String majorMinor) {
        String[] majorMinorPair = majorMinor.split(":");
        if (majorMinorPair.length == 2) {
            try {
                return new Device(Integer.parseInt(majorMinorPair[0]), Integer.parseInt(majorMinorPair[1]));
            } catch (NumberFormatException e) {
                // Reported by the shared error log below.
            }
        }
        LOG.error("Unexpected format for major:minor device numbers: " + majorMinor);
        return null;
    }

    /** Returns the internal (mutable) set of mandatory driver devices. */
    public HashSet<Device> getMandatoryDrivers() {
        return this.mandatoryDrivers;
    }

    /** Returns the internal (mutable) set of currently unallocated GPUs. */
    public HashSet<Device> getConfiguredAvailableGPUs() {
        return this.configuredAvailableGPUs;
    }

    /** Returns the internal (mutable) set of all GPUs discovered at initialization. */
    public HashSet<Device> getTotalGPUs() {
        return this.totalGPUs;
    }

    /** Returns the union of all GPUs currently allocated to containers. */
    private HashSet<Device> getAllocatedGPUs() {
        HashSet<Device> allocatedGPUs = new HashSet<>();
        Collection<HashSet<Device>> gpuSets = this.containerGPUAllocationMapping.values();
        for (HashSet<Device> allocatedGpuSet : gpuSets) {
            allocatedGPUs.addAll(allocatedGpuSet);
        }
        return allocatedGPUs;
    }

    /**
     * Allocates {@code gpus} unallocated GPUs to a container, lowest minor numbers first.
     *
     * <p>NOTE(review): a second allocation for the same {@code containerName} replaces its
     * mapping, so the earlier GPUs are never returned to the pool - confirm callers never do this.
     *
     * @param containerName key under which the allocation is recorded
     * @param gpus number of GPUs requested; values {@code <= 0} allocate nothing
     * @return the GPUs the container must be denied: all GPUs minus those allocated to it
     * @throws IOException if fewer than {@code gpus} GPUs are currently unallocated
     */
    public synchronized HashSet<Device> allocate(String containerName, int gpus) throws IOException {
        int availableCount = this.configuredAvailableGPUs.size();
        if (availableCount < gpus) {
            throw new IOException("Container " + containerName + " requested " + gpus + " GPUs when only " + availableCount + " available");
        }
        HashSet<Device> gpusToDeny = new HashSet<>(this.getTotalGPUs());
        LOG.info("Trying to allocate " + gpus + " GPUs");
        if (gpus > 0) {
            LOG.info("Currently unallocated GPUs: " + this.configuredAvailableGPUs);
            LOG.info("Currently allocated GPUs: " + this.getAllocatedGPUs());
            HashSet<Device> gpuAllocation = this.selectGPUsToAllocate(gpus);
            this.configuredAvailableGPUs.removeAll(gpuAllocation);
            LOG.info("GPUs to allocate for " + containerName + " = " + gpuAllocation);
            this.containerGPUAllocationMapping.put(containerName, gpuAllocation);
            gpusToDeny.removeAll(gpuAllocation);
        }
        LOG.info("GPUs to deny for " + containerName + " = " + gpusToDeny);
        return gpusToDeny;
    }

    /**
     * Picks up to {@code gpus} of the unallocated GPUs with the lowest minor device numbers.
     * Returns the pooled {@link Device} instances themselves, so the caller's {@code removeAll}
     * always takes exactly the selected devices out of the pool.
     */
    private synchronized HashSet<Device> selectGPUsToAllocate(int gpus) {
        TreeMap<Integer, Device> availableByMinor = new TreeMap<>();
        for (Device gpu : this.configuredAvailableGPUs) {
            availableByMinor.put(gpu.getMinorDeviceNumber(), gpu);
        }
        HashSet<Device> gpuAllocation = new HashSet<>();
        for (Device gpu : availableByMinor.values()) {
            if (gpuAllocation.size() >= gpus) {
                break;
            }
            gpuAllocation.add(gpu);
        }
        return gpuAllocation;
    }

    /**
     * Returns a container's GPUs to the unallocated pool. Unknown container names are ignored.
     *
     * @param containerName name the GPUs were allocated or recovered under
     */
    public synchronized void release(String containerName) {
        HashSet<Device> gpuAllocation = this.containerGPUAllocationMapping.remove(containerName);
        if (gpuAllocation == null) {
            return;
        }
        this.configuredAvailableGPUs.addAll(gpuAllocation);
        LOG.info("Releasing GPUs " + gpuAllocation + " for container " + containerName);
    }

    /**
     * Re-registers the GPUs listed in an existing container's devices whitelist as allocated to
     * it. Does nothing if the whitelist contains no GPUs.
     *
     * @param containerId name to record the allocation under
     * @param devicesWhitelistStr whitelist text containing {@code <token> major:minor <token>} entries
     */
    public synchronized void recoverAllocation(String containerId, String devicesWhitelistStr) {
        HashSet<Device> allocatedGPUsForContainer = this.findGPUsInWhitelist(devicesWhitelistStr);
        if (allocatedGPUsForContainer.isEmpty()) {
            return;
        }
        this.configuredAvailableGPUs.removeAll(allocatedGPUsForContainer);
        this.containerGPUAllocationMapping.put(containerId, allocatedGPUsForContainer);
        LOG.info("Recovering GPUs " + allocatedGPUsForContainer + " for container " + containerId);
        LOG.info("Available GPUs after container " + containerId + " recovery " + this.configuredAvailableGPUs);
        LOG.info("So far " + this.containerGPUAllocationMapping.size() + " recovered containers");
    }

    /**
     * Extracts the GPU devices (major {@value #NVIDIA_GPU_MAJOR_DEVICE_NUMBER}, excluding
     * mandatory drivers) from a devices whitelist. Entries whose numbers overflow {@code int}
     * are logged and skipped.
     */
    private HashSet<Device> findGPUsInWhitelist(String devicesWhitelistStr) {
        HashSet<Device> recoveredGPUs = new HashSet<>();
        if (devicesWhitelistStr == null) {
            return recoveredGPUs;
        }
        Matcher m = DEVICES_LIST_FORMAT.matcher(devicesWhitelistStr);
        while (m.find()) {
            int majorDeviceNumber;
            int minorDeviceNumber;
            try {
                majorDeviceNumber = Integer.parseInt(m.group(2));
                minorDeviceNumber = Integer.parseInt(m.group(3));
            } catch (NumberFormatException e) {
                LOG.error("Unexpected format for major:minor device numbers: " + m.group(2) + ":" + m.group(3));
                continue;
            }
            if (majorDeviceNumber != NVIDIA_GPU_MAJOR_DEVICE_NUMBER) {
                continue;
            }
            Device device = new Device(majorDeviceNumber, minorDeviceNumber);
            if (!this.getMandatoryDrivers().contains(device)) {
                recoveredGPUs.add(device);
            }
        }
        return recoveredGPUs;
    }
}

