/*
 * Decompiled with CFR 0.152.
 */
package io.hops.devices;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import io.hops.GPUManagementLibrary;
import io.hops.GPUManagementLibraryLoader;
import io.hops.devices.Device;
import io.hops.devices.GPU;
import io.hops.exceptions.GPUManagementLibraryException;
import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.TreeSet;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.server.nodemanager.util.NodeManagerHardwareUtils;

/**
 * Keeps track of the GPUs on this node and which of them are currently assigned to containers.
 *
 * <p>Device discovery is delegated to a {@link GPUManagementLibrary} selected through
 * {@code yarn.nodemanager.gpu.management-impl}. Two providers are supported: the NVIDIA
 * library (the default) and the AMD library.
 *
 * <p>{@link #allocate}, {@link #release} and {@link #recoverAllocation} are
 * {@code synchronized}. {@link #initialize} and the getters are not. The set getters
 * ({@link #getMandatoryDrivers}, {@link #getConfiguredAvailableGPUs}, {@link #getTotalGPUs})
 * return the live internal sets.
 */
public class GPUAllocator {
    static final Log LOG = LogFactory.getLog(GPUAllocator.class);

    private static final String NVIDIA_MANAGEMENT_LIBRARY =
        "io.hops.management.nvidia.NvidiaManagementLibrary";
    private static final String AMD_MANAGEMENT_LIBRARY =
        "io.hops.management.amd.AMDManagementLibrary";
    /** Major device number treated as a GPU when the NVIDIA provider is configured. */
    private static final int NVIDIA_GPU_MAJOR_DEVICE_NUMBER = 195;
    /** Major device number treated as a GPU when the AMD provider is configured. */
    private static final int AMD_GPU_MAJOR_DEVICE_NUMBER = 226;

    /**
     * Matches whitespace-separated triples such as {@code c 195:0 rwm}. The input is presumably
     * a cgroup devices whitelist; confirm against the caller. Group 2 holds the
     * {@code major:minor} pair.
     *
     * <p>The pattern is written without nested quantifiers so that it cannot backtrack
     * exponentially. It accepts the same strings, with the same groups, as the earlier
     * {@code ([^\s]+)+\s([\d+:\d]+)+\s([^\s]+)}.
     */
    private static final Pattern DEVICES_LIST_FORMAT =
        Pattern.compile("(\\S+)\\s([\\d+:]+)\\s(\\S+)");

    private static final GPUAllocator gpuAllocator = new GPUAllocator();

    // Set by initialize(); shared by every instance.
    private static String GPU_MANAGEMENT_LIBRARY_CLASSNAME;
    private static int GPU_MAJOR_DEVICE_NUMBER;

    /** GPUs not currently assigned to any container. */
    private final HashSet<GPU> configuredAvailableGPUs = new HashSet<>();
    /** Snapshot of every schedulable GPU discovered at initialization. */
    private HashSet<GPU> totalGPUs = new HashSet<>();
    /** Container name/id to the GPUs assigned to it. */
    private final HashMap<String, HashSet<GPU>> containerGPUAllocationMapping = new HashMap<>();
    private final HashSet<Device> mandatoryDrivers = new HashSet<>();
    private GPUManagementLibrary gpuManagementLibrary;
    private boolean initialized = false;

    /** Returns the process-wide allocator instance (not initialized until {@link #initialize}). */
    public static GPUAllocator getInstance() {
        return gpuAllocator;
    }

    private GPUAllocator() {
    }

    /** Creates an allocator around an injected management library and initializes it. */
    @VisibleForTesting
    public GPUAllocator(GPUManagementLibrary gpuManagementLibrary, Configuration conf) {
        this.gpuManagementLibrary = gpuManagementLibrary;
        this.initialize(conf);
    }

    /** Returns a shallow copy of the container-to-GPUs mapping. */
    @VisibleForTesting
    public HashMap<String, HashSet<GPU>> getAllocations() {
        return new HashMap<>(this.containerGPUAllocationMapping);
    }

    /**
     * Loads and initializes the configured GPU management library. On success, it then
     * discovers the mandatory driver devices and the GPUs available for scheduling.
     * Calling this again after a successful call has no effect.
     *
     * @param conf configuration that supplies the management library class name and, via
     *     {@code NodeManagerHardwareUtils.getNodeGPUs}, the number of GPUs to query for
     * @return {@code true} if the management library initialized successfully. Device-discovery
     *     failures are logged and do not change the result. Returns {@code false} if the
     *     library could not be loaded or failed to initialize.
     * @throws IllegalArgumentException if the configured class is not a supported provider
     */
    public boolean initialize(Configuration conf) {
        if (this.initialized) {
            return true;
        }
        GPU_MANAGEMENT_LIBRARY_CLASSNAME =
            conf.get("yarn.nodemanager.gpu.management-impl", NVIDIA_MANAGEMENT_LIBRARY);
        LOG.info("Initializing GPUAllocator for " + GPU_MANAGEMENT_LIBRARY_CLASSNAME);
        // Reject unsupported providers before any (possibly native) library is loaded or
        // initialized, so we never end up "initialized" while throwing.
        GPU_MAJOR_DEVICE_NUMBER = majorDeviceNumberFor(GPU_MANAGEMENT_LIBRARY_CLASSNAME);

        if (this.gpuManagementLibrary == null) {
            try {
                this.gpuManagementLibrary =
                    GPUManagementLibraryLoader.load(GPU_MANAGEMENT_LIBRARY_CLASSNAME);
            } catch (GPUManagementLibraryException | NoClassDefFoundError | UnsatisfiedLinkError e) {
                LOG.error("Could not load GPU management library using provider "
                    + GPU_MANAGEMENT_LIBRARY_CLASSNAME, e);
            }
            if (this.gpuManagementLibrary == null) {
                // Without a library there is nothing to initialize; previously this fell
                // through to a NullPointerException.
                return false;
            }
        }

        this.initialized = this.gpuManagementLibrary.initialize();
        if (!this.initialized) {
            // Querying devices from a library that failed to initialize is not meaningful.
            LOG.error("GPU management library " + GPU_MANAGEMENT_LIBRARY_CLASSNAME
                + " failed to initialize");
            return false;
        }

        int numGPUs = NodeManagerHardwareUtils.getNodeGPUs(conf);
        try {
            this.initMandatoryDrivers();
            LOG.info("Found mandatory drivers: " + this.getMandatoryDrivers().toString());
            this.initConfiguredGPUs(numGPUs);
            this.initTotalGPUs();
        } catch (IOException ioe) {
            // Deliberately non-fatal: the allocator stays initialized, possibly with no GPUs.
            LOG.error("Could not initialize GPUAllocator", ioe);
        }
        return true;
    }

    /** Maps a supported management-library class name to the GPU major device number. */
    private static int majorDeviceNumberFor(String libraryClassName) {
        if (NVIDIA_MANAGEMENT_LIBRARY.equals(libraryClassName)) {
            return NVIDIA_GPU_MAJOR_DEVICE_NUMBER;
        }
        if (AMD_MANAGEMENT_LIBRARY.equals(libraryClassName)) {
            return AMD_GPU_MAJOR_DEVICE_NUMBER;
        }
        throw new IllegalArgumentException(libraryClassName + " not recognized by GPUAllocator");
    }

    public boolean isInitialized() {
        return this.initialized;
    }

    /**
     * Shuts down the management library if this allocator was initialized.
     *
     * @return the library's shutdown result, or {@code false} if never initialized
     */
    public boolean shutDown() {
        if (this.initialized) {
            return this.gpuManagementLibrary.shutDown();
        }
        return false;
    }

    /**
     * Populates {@link #mandatoryDrivers} from the library's whitespace-separated list of
     * {@code major:minor} entries. Malformed entries are logged and skipped.
     *
     * @throws IOException if the library reports no mandatory devices at all
     */
    private void initMandatoryDrivers() throws IOException {
        String mandatoryDeviceIds = this.gpuManagementLibrary.queryMandatoryDevices();
        LOG.info("GPU Management Library drivers: " + mandatoryDeviceIds);
        if (isBlank(mandatoryDeviceIds)) {
            throw new IOException("Failed to discover device numbers for GPU drivers using provider: "
                + GPU_MANAGEMENT_LIBRARY_CLASSNAME);
        }
        for (String entry : splitEntries(mandatoryDeviceIds)) {
            try {
                Device mandatoryDevice = parseDevice(entry);
                this.mandatoryDrivers.add(mandatoryDevice);
                LOG.info("Found mandatory GPU driver " + mandatoryDevice.toString());
            } catch (NumberFormatException e) {
                // Log the raw entry: indexing the split pieces here used to throw
                // ArrayIndexOutOfBoundsException for entries without a ':'.
                LOG.error("Unexpected format for major:minor device numbers: " + entry);
            }
        }
    }

    /**
     * Populates {@link #configuredAvailableGPUs} with up to {@code configuredGPUs} devices
     * reported by the library.
     *
     * <p>NVIDIA entries have the form {@code major:minor}. Entries for the other supported
     * provider (AMD) have the form {@code major:minor&renderMajor:renderMinor}. Malformed
     * entries are logged and skipped.
     *
     * @throws IOException if the library reports no devices
     */
    private void initConfiguredGPUs(int configuredGPUs) throws IOException {
        LOG.info("Querying GPU Management Library for " + configuredGPUs + " GPUs");
        String configuredGPUDeviceIds = this.gpuManagementLibrary.queryAvailableDevices(configuredGPUs);
        LOG.info("GPU Management Library response: " + configuredGPUDeviceIds);
        if (isBlank(configuredGPUDeviceIds)) {
            throw new IOException("Could not discover GPU device numbers using provider "
                + GPU_MANAGEMENT_LIBRARY_CLASSNAME);
        }
        boolean nvidia = NVIDIA_MANAGEMENT_LIBRARY.equals(GPU_MANAGEMENT_LIBRARY_CLASSNAME);
        for (String gpuEntry : splitEntries(configuredGPUDeviceIds)) {
            try {
                GPU gpu = parseGPUEntry(gpuEntry, nvidia);
                this.configuredAvailableGPUs.add(gpu);
                LOG.info("Found available GPU device " + gpu.toString() + " for scheduling");
            } catch (NumberFormatException e) {
                LOG.error("Unexpected format for GPU device entry: " + gpuEntry);
            }
        }
    }

    /**
     * Parses one entry of the available-devices response.
     *
     * @throws NumberFormatException if the entry does not have the expected shape
     */
    private static GPU parseGPUEntry(String gpuEntry, boolean nvidia) {
        if (nvidia) {
            return new GPU(parseDevice(gpuEntry), null);
        }
        String[] gpuRenderNodePair = gpuEntry.split("&");
        if (gpuRenderNodePair.length < 2) {
            throw new NumberFormatException("Expected gpu&renderNode but got '" + gpuEntry + "'");
        }
        return new GPU(parseDevice(gpuRenderNodePair[0]), parseDevice(gpuRenderNodePair[1]));
    }

    /**
     * Parses {@code major:minor} into two ints. Any text after a second ':' is ignored.
     *
     * @throws NumberFormatException if either number is missing or not an integer
     */
    private static int[] parseMajorMinor(String majorMinor) {
        String[] pair = majorMinor.split(":");
        if (pair.length < 2) {
            throw new NumberFormatException("Expected major:minor but got '" + majorMinor + "'");
        }
        return new int[] {Integer.parseInt(pair[0]), Integer.parseInt(pair[1])};
    }

    /** Parses {@code major:minor} into a {@link Device}; see {@link #parseMajorMinor}. */
    private static Device parseDevice(String majorMinor) {
        int[] numbers = parseMajorMinor(majorMinor);
        return new Device(numbers[0], numbers[1]);
    }

    /** Splits a library response on runs of whitespace, ignoring leading/trailing whitespace. */
    private static String[] splitEntries(String response) {
        return response.trim().split("\\s+");
    }

    private static boolean isBlank(String value) {
        return Strings.isNullOrEmpty(value) || value.trim().isEmpty();
    }

    private void initTotalGPUs() {
        this.totalGPUs = new HashSet<>(this.configuredAvailableGPUs);
    }

    public HashSet<Device> getMandatoryDrivers() {
        return this.mandatoryDrivers;
    }

    public HashSet<GPU> getConfiguredAvailableGPUs() {
        return this.configuredAvailableGPUs;
    }

    public HashSet<GPU> getTotalGPUs() {
        return this.totalGPUs;
    }

    /** Returns the union of all GPUs currently assigned to containers. */
    private HashSet<GPU> getAllocatedGPUs() {
        HashSet<GPU> allocatedGPUs = new HashSet<>();
        Collection<HashSet<GPU>> gpuSets = this.containerGPUAllocationMapping.values();
        for (HashSet<GPU> allocatedGpuSet : gpuSets) {
            allocatedGPUs.addAll(allocatedGpuSet);
        }
        return allocatedGPUs;
    }

    /**
     * Assigns {@code gpus} free GPUs to {@code containerName}, taking the lowest minor device
     * numbers first.
     *
     * @return every known GPU except the ones just assigned, i.e. the set the container should
     *     be denied. When {@code gpus <= 0}, all known GPUs are returned and nothing is recorded.
     * @throws IOException if fewer than {@code gpus} GPUs are currently free
     */
    public synchronized HashSet<GPU> allocate(String containerName, int gpus) throws IOException {
        if (this.configuredAvailableGPUs.size() < gpus) {
            throw new IOException("Container " + containerName + " requested " + gpus
                + " GPUs when only " + this.configuredAvailableGPUs.size() + " available");
        }
        HashSet<GPU> gpusToDeny = new HashSet<>(this.getTotalGPUs());
        LOG.info("Trying to allocate " + gpus + " GPUs");
        if (gpus > 0) {
            LOG.info("Currently unallocated GPUs: " + this.configuredAvailableGPUs.toString());
            HashSet<GPU> currentlyAllocatedGPUs = this.getAllocatedGPUs();
            LOG.info("Currently allocated GPUs: " + currentlyAllocatedGPUs);
            HashSet<GPU> gpuAllocation = this.selectGPUsToAllocate(gpus);
            this.configuredAvailableGPUs.removeAll(gpuAllocation);
            LOG.info("GPUs to allocate for " + containerName + " = " + gpuAllocation);
            // NOTE(review): a second allocate() for the same containerName overwrites the
            // previous entry without returning those GPUs to the pool. Confirm that callers
            // never do this.
            this.containerGPUAllocationMapping.put(containerName, gpuAllocation);
            gpusToDeny.removeAll(gpuAllocation);
        }
        LOG.info("GPUs to deny for " + containerName + " = " + gpusToDeny);
        return gpusToDeny;
    }

    /**
     * Picks at most {@code gpus} free GPUs in ascending minor-device-number order. The count
     * is checked on every addition, so GPUs sharing a minor number cannot cause the selection
     * to exceed the request.
     */
    private synchronized HashSet<GPU> selectGPUsToAllocate(int gpus) {
        HashSet<GPU> gpuAllocation = new HashSet<>();
        if (gpus <= 0) {
            return gpuAllocation;
        }
        TreeSet<Integer> minorNumbers = new TreeSet<>();
        for (GPU gpu : this.configuredAvailableGPUs) {
            minorNumbers.add(gpu.getGpuDevice().getMinorDeviceNumber());
        }
        for (int minorNumber : minorNumbers) {
            for (GPU gpu : this.configuredAvailableGPUs) {
                if (gpuAllocation.size() >= gpus) {
                    return gpuAllocation;
                }
                if (gpu.getGpuDevice().getMinorDeviceNumber() == minorNumber) {
                    gpuAllocation.add(gpu);
                }
            }
        }
        return gpuAllocation;
    }

    /** Returns the GPUs held by {@code containerName}, if any, to the free pool. */
    public synchronized void release(String containerName) {
        HashSet<GPU> gpuAllocation = this.containerGPUAllocationMapping.remove(containerName);
        if (gpuAllocation != null) {
            this.configuredAvailableGPUs.addAll(gpuAllocation);
            LOG.info("Releasing GPUs " + gpuAllocation + " for container " + containerName);
        }
    }

    /**
     * Re-establishes a container's GPU assignment from its device whitelist string. This is
     * presumably used after a NodeManager restart; confirm with the caller. Does nothing if
     * the whitelist names no free GPU.
     */
    public synchronized void recoverAllocation(String containerId, String devicesWhitelistStr) {
        HashSet<GPU> allocatedGPUsForContainer = this.findGPUsInWhitelist(devicesWhitelistStr);
        if (allocatedGPUsForContainer.isEmpty()) {
            return;
        }
        this.configuredAvailableGPUs.removeAll(allocatedGPUsForContainer);
        this.containerGPUAllocationMapping.put(containerId, allocatedGPUsForContainer);
        LOG.info("Recovering GPUs " + allocatedGPUsForContainer + " for container " + containerId);
        LOG.info("Available GPUs after container " + containerId + " recovery "
            + this.configuredAvailableGPUs);
        LOG.info("So far " + this.containerGPUAllocationMapping.size() + " recovered containers");
    }

    /**
     * Returns the free GPUs named in the whitelist. An entry is used only when its major number
     * is the GPU major number and its device is not a mandatory driver. Unparseable entries
     * are skipped.
     */
    private HashSet<GPU> findGPUsInWhitelist(String devicesWhitelistStr) {
        HashSet<GPU> recoveredGPUs = new HashSet<>();
        if (Strings.isNullOrEmpty(devicesWhitelistStr)) {
            return recoveredGPUs;
        }
        Matcher m = DEVICES_LIST_FORMAT.matcher(devicesWhitelistStr);
        while (m.find()) {
            int[] majorMinor;
            try {
                majorMinor = parseMajorMinor(m.group(2));
            } catch (NumberFormatException e) {
                LOG.warn("Skipping unparseable device entry: " + m.group(0));
                continue;
            }
            if (majorMinor[0] != GPU_MAJOR_DEVICE_NUMBER) {
                continue;
            }
            GPU gpu = this.recoverGPU(majorMinor[1]);
            if (gpu == null || this.getMandatoryDrivers().contains(gpu.getGpuDevice())) {
                continue;
            }
            recoveredGPUs.add(gpu);
        }
        return recoveredGPUs;
    }

    /** Returns the free GPU with the given minor device number, or {@code null} if none. */
    private GPU recoverGPU(int minorDeviceNumber) {
        for (GPU gpu : this.configuredAvailableGPUs) {
            if (gpu.getGpuDevice().getMinorDeviceNumber() == minorDeviceNumber) {
                return gpu;
            }
        }
        return null;
    }
}

