/*
 * Decompiled with CFR 0.152.
 */
package com.aparapi.internal.kernel;

import com.aparapi.Config;
import com.aparapi.Kernel;
import com.aparapi.device.Device;
import com.aparapi.device.JavaDevice;
import com.aparapi.device.OpenCLDevice;
import com.aparapi.internal.kernel.KernelDeviceProfile;
import com.aparapi.internal.kernel.KernelPreferences;
import com.aparapi.internal.kernel.KernelProfile;
import com.aparapi.internal.kernel.PreferencesWrapper;
import com.aparapi.internal.util.Reflection;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Registry of per-kernel-class device preferences, kernel profiles and shared kernel instances,
 * plus the default device-ordering policy.
 *
 * <p>A process-wide instance is available through {@link #instance()} and can be replaced with
 * {@link #setKernelManager(KernelManager)}.
 *
 * <p>Each of the three registries is guarded by synchronizing on the map itself. The reporting
 * methods iterate over snapshots copied under that lock, so they never observe a map mid-update
 * and never call into other objects while holding the lock.
 */
public class KernelManager {
    /** Total width, in characters, of each kernel separator line in {@link #reportProfilingSummary}. */
    private static final int SUMMARY_LINE_WIDTH = 132;

    private static KernelManager INSTANCE = new KernelManager();

    // Registries keyed by kernel class; each is guarded by its own monitor.
    private final LinkedHashMap<Class<? extends Kernel>, PreferencesWrapper> preferences = new LinkedHashMap<>();
    private final LinkedHashMap<Class<? extends Kernel>, KernelProfile> profiles = new LinkedHashMap<>();
    private final LinkedHashMap<Class<? extends Kernel>, Kernel> sharedInstances = new LinkedHashMap<>();
    private KernelPreferences defaultPreferences;

    /**
     * Creates a manager and immediately calls {@link #setup()}.
     *
     * <p>Note: {@code setup()} is overridable and runs from the constructor, so an override must not
     * depend on subclass fields that have not been initialised yet.
     */
    protected KernelManager() {
        this.setup();
    }

    /** Initialises the default preferences via {@link #createDefaultPreferences()}. */
    protected void setup() {
        this.defaultPreferences = this.createDefaultPreferences();
    }

    /** Returns the current process-wide manager. */
    public static KernelManager instance() {
        return INSTANCE;
    }

    /**
     * Replaces the process-wide manager returned by {@link #instance()}.
     *
     * @param manager the manager to install
     */
    public static void setKernelManager(KernelManager manager) {
        INSTANCE = manager;
    }

    /**
     * Returns the shared instance of {@code kernelClass} held by the current manager, creating it on
     * first use.
     *
     * @param kernelClass kernel subclass with a no-argument constructor
     * @return the shared instance
     * @throws RuntimeException if the instance cannot be created reflectively
     */
    public static <T extends Kernel> T sharedKernelInstance(Class<T> kernelClass) {
        return KernelManager.instance().getSharedKernelInstance(kernelClass);
    }

    /**
     * Appends, for every kernel class with registered preferences, its preferred device and any
     * failed devices. Optionally also appends the per-device average timing table.
     *
     * @param builder           destination for the report
     * @param withProfilingInfo when {@code true}, include the profile table of each kernel class
     *                          that has one
     */
    public void reportDeviceUsage(StringBuilder builder, boolean withProfilingInfo) {
        builder.append("Device Usage by Kernel Subclass");
        if (withProfilingInfo) {
            builder.append(" (showing mean elapsed times in milliseconds)");
        }
        builder.append("\n\n");
        for (PreferencesWrapper wrapper : this.snapshotPreferences()) {
            KernelPreferences kernelPreferences = wrapper.getPreferences();
            Class<? extends Kernel> klass = wrapper.getKernelClass();
            KernelProfile profile = withProfilingInfo ? this.findProfile(klass) : null;
            builder.append(klass.getName())
                    .append(":\n\tusing ")
                    .append(kernelPreferences.getPreferredDevice(null).getShortDescription());
            List<Device> failedDevices = kernelPreferences.getFailedDevices();
            if (!failedDevices.isEmpty()) {
                builder.append(", failed devices = ");
                for (int i = 0; i < failedDevices.size(); ++i) {
                    if (i > 0) {
                        builder.append(" | ");
                    }
                    builder.append(failedDevices.get(i).getShortDescription());
                }
            }
            if (profile != null) {
                builder.append("\n");
                boolean headerWritten = false;
                for (KernelDeviceProfile deviceProfile : profile.getDeviceProfiles()) {
                    // The header is only emitted when at least one device row follows.
                    if (!headerWritten) {
                        builder.append(KernelDeviceProfile.getTableHeader()).append("\n");
                        headerWritten = true;
                    }
                    builder.append(deviceProfile.getAverageAsTableRow()).append("\n");
                }
            }
            builder.append("\n");
        }
    }

    /**
     * Appends a summary table of per-device average timings for every profiled kernel class.
     * This method is also invoked from the shutdown hook when {@code Config.dumpProfilesOnExit} is set.
     *
     * @param builder destination for the report
     */
    public void reportProfilingSummary(StringBuilder builder) {
        builder.append("\nProfiles by Kernel Subclass (mean elapsed times in milliseconds)\n\n");
        builder.append(KernelDeviceProfile.getTableHeader()).append("\n");
        Map<Class<? extends Kernel>, KernelProfile> snapshot;
        synchronized (this.profiles) {
            snapshot = new LinkedHashMap<>(this.profiles);
        }
        for (Map.Entry<Class<? extends Kernel>, KernelProfile> entry : snapshot.entrySet()) {
            String kernelName = "----------------- [[ " + Reflection.getSimpleName(entry.getKey()) + " ]] ";
            builder.append(kernelName);
            // Pad the separator line with dashes up to the fixed table width.
            for (int column = kernelName.length(); column < SUMMARY_LINE_WIDTH; ++column) {
                builder.append('-');
            }
            builder.append("\n");
            for (KernelDeviceProfile deviceProfile : entry.getValue().getDeviceProfiles()) {
                builder.append(deviceProfile.getAverageAsTableRow()).append("\n");
            }
        }
    }

    /**
     * Returns the preferences for the kernel's runtime class, creating and registering them on
     * first request.
     *
     * @param kernel kernel whose class identifies the preferences
     * @return the (possibly newly created) preferences for {@code kernel.getClass()}
     */
    public KernelPreferences getPreferences(Kernel kernel) {
        Class<? extends Kernel> kernelClass = kernel.getClass();
        synchronized (this.preferences) {
            PreferencesWrapper wrapper = this.preferences.get(kernelClass);
            if (wrapper != null) {
                return wrapper.getPreferences();
            }
            KernelPreferences kernelPreferences = new KernelPreferences(this, kernelClass);
            this.preferences.put(kernelClass, new PreferencesWrapper(kernelClass, kernelPreferences));
            return kernelPreferences;
        }
    }

    /**
     * Sets the ordered preferred devices for the kernel's class.
     *
     * @param _kernel  kernel whose class identifies the preferences
     * @param _devices devices in order of preference
     */
    public void setPreferredDevices(Kernel _kernel, LinkedHashSet<Device> _devices) {
        KernelPreferences kernelPreferences = this.getPreferences(_kernel);
        kernelPreferences.setPreferredDevices(_devices);
    }

    /** Returns the preferences created by {@link #createDefaultPreferences()} during {@link #setup()}. */
    public KernelPreferences getDefaultPreferences() {
        return this.defaultPreferences;
    }

    /**
     * Replaces the ordered preferred devices of the default preferences.
     *
     * @param _devices devices in order of preference
     */
    public void setDefaultPreferredDevices(LinkedHashSet<Device> _devices) {
        this.defaultPreferences.setPreferredDevices(_devices);
    }

    /**
     * Builds the default preferences: a {@link KernelPreferences} with a {@code null} kernel class,
     * whose device list comes from {@link #createDefaultPreferredDevices()}.
     */
    protected KernelPreferences createDefaultPreferences() {
        KernelPreferences kernelPreferences = new KernelPreferences(this, null);
        kernelPreferences.setPreferredDevices(this.createDefaultPreferredDevices());
        return kernelPreferences;
    }

    /**
     * Returns the cached instance of {@code kernelClass}, creating it reflectively through its
     * no-argument constructor on first use.
     *
     * <p>The constructor is looked up with {@code getDeclaredConstructor()}, so non-public no-arg
     * constructors are accepted as well. {@code setAccessible(true)} is what makes them callable.
     */
    private <T extends Kernel> T getSharedKernelInstance(Class<T> kernelClass) {
        synchronized (this.sharedInstances) {
            Kernel shared = this.sharedInstances.get(kernelClass);
            if (shared == null) {
                try {
                    Constructor<T> constructor = kernelClass.getDeclaredConstructor();
                    constructor.setAccessible(true);
                    shared = constructor.newInstance();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
                this.sharedInstances.put(kernelClass, shared);
            }
            return kernelClass.cast(shared);
        }
    }

    /**
     * Builds the default ordered device set by walking {@link #getPreferredDeviceTypes()}.
     *
     * <p>Accelerators and GPUs are sorted with {@link #getDefaultAcceleratorComparator()} and
     * {@link #getDefaultGPUComparator()}. CPUs keep the order returned by
     * {@code OpenCLDevice.listDevices}. The Java devices are added as singletons.
     *
     * @throws AssertionError if the preferred types include {@code UNKNOWN}
     */
    protected LinkedHashSet<Device> createDefaultPreferredDevices() {
        LinkedHashSet<Device> devices = new LinkedHashSet<>();
        List<OpenCLDevice> accelerators = OpenCLDevice.listDevices(Device.TYPE.ACC);
        List<OpenCLDevice> gpus = OpenCLDevice.listDevices(Device.TYPE.GPU);
        List<OpenCLDevice> cpus = OpenCLDevice.listDevices(Device.TYPE.CPU);
        Collections.sort(accelerators, this.getDefaultAcceleratorComparator());
        Collections.sort(gpus, this.getDefaultGPUComparator());
        for (Device.TYPE type : this.getPreferredDeviceTypes()) {
            switch (type) {
                case UNKNOWN:
                    throw new AssertionError("UNKNOWN device type not supported");
                case GPU:
                    devices.addAll(gpus);
                    break;
                case CPU:
                    devices.addAll(cpus);
                    break;
                case JTP:
                    devices.add(JavaDevice.THREAD_POOL);
                    break;
                case SEQ:
                    devices.add(JavaDevice.SEQUENTIAL);
                    break;
                case ACC:
                    devices.addAll(accelerators);
                    break;
                case ALT:
                    devices.add(JavaDevice.ALTERNATIVE_ALGORITHM);
                    break;
                default:
                    // Other device types are not part of the default ordering.
                    break;
            }
        }
        return devices;
    }

    /** Returns the device types, in priority order, used to build the default device set. */
    protected List<Device.TYPE> getPreferredDeviceTypes() {
        return Arrays.asList(Device.TYPE.ACC, Device.TYPE.GPU, Device.TYPE.CPU, Device.TYPE.ALT, Device.TYPE.JTP);
    }

    /** Orders accelerators by descending max compute units. */
    protected Comparator<OpenCLDevice> getDefaultAcceleratorComparator() {
        return new Comparator<OpenCLDevice>() {
            @Override
            public int compare(OpenCLDevice left, OpenCLDevice right) {
                // Integer.compare avoids the overflow risk of subtracting the two counts.
                return Integer.compare(right.getMaxComputeUnits(), left.getMaxComputeUnits());
            }
        };
    }

    /**
     * Orders GPUs so that devices preferred by {@link #selectLhs} come first.
     *
     * <p>Returns 0 when neither device is preferred over the other, as the {@link Comparator}
     * contract requires. A comparator that never returns 0 can make {@code Collections.sort} throw.
     */
    protected Comparator<OpenCLDevice> getDefaultGPUComparator() {
        return new Comparator<OpenCLDevice>() {
            @Override
            public int compare(OpenCLDevice left, OpenCLDevice right) {
                if (KernelManager.selectLhs(left, right)) {
                    return -1;
                }
                if (KernelManager.selectLhs(right, left)) {
                    return 1;
                }
                return 0;
            }
        };
    }

    /** Returns the device reported by the default preferences' {@code getPreferredDevice(null)}. */
    public Device bestDevice() {
        return this.getDefaultPreferences().getPreferredDevice(null);
    }

    /**
     * Decides whether {@code _deviceLhs} is preferred over {@code _deviceRhs}.
     *
     * <p>If either device's platform vendor contains "nvidia" (case-insensitive), the decision is
     * delegated to {@link #selectLhsIfCUDA}. Otherwise the device with strictly more compute units
     * wins.
     */
    protected static boolean selectLhs(OpenCLDevice _deviceLhs, OpenCLDevice _deviceRhs) {
        if (isNvidia(_deviceLhs) || isNvidia(_deviceRhs)) {
            return KernelManager.selectLhsIfCUDA(_deviceLhs, _deviceRhs);
        }
        return _deviceLhs.getMaxComputeUnits() > _deviceRhs.getMaxComputeUnits();
    }

    /**
     * Preference rule used when an NVIDIA platform is involved.
     *
     * <p>Devices of different types are ordered by {@code Device.TYPE.rank}, lower rank first. For
     * the same type, the larger max work-group size wins, and ties go to the larger global memory
     * size.
     */
    protected static boolean selectLhsIfCUDA(OpenCLDevice _deviceLhs, OpenCLDevice _deviceRhs) {
        if (_deviceLhs.getType() != _deviceRhs.getType()) {
            return KernelManager.selectLhsByType(_deviceLhs.getType(), _deviceRhs.getType());
        }
        if (_deviceLhs.getMaxWorkGroupSize() == _deviceRhs.getMaxWorkGroupSize()) {
            return _deviceLhs.getGlobalMemSize() > _deviceRhs.getGlobalMemSize();
        }
        return _deviceLhs.getMaxWorkGroupSize() > _deviceRhs.getMaxWorkGroupSize();
    }

    /** Returns true when {@code lhs} has a strictly lower rank than {@code rhs}. */
    private static boolean selectLhsByType(Device.TYPE lhs, Device.TYPE rhs) {
        return lhs.rank < rhs.rank;
    }

    /**
     * True if the device's platform vendor string contains "nvidia", ignoring case.
     * {@code Locale.ROOT} is required so the match does not break in, for example, a Turkish default
     * locale, where "I" lowercases to a dotless "ı".
     */
    private static boolean isNvidia(OpenCLDevice device) {
        return device.getOpenCLPlatform().getVendor().toLowerCase(Locale.ROOT).contains("nvidia");
    }

    /**
     * Returns the profile for {@code kernelClass}, creating and registering an empty one on first
     * request.
     */
    public KernelProfile getProfile(Class<? extends Kernel> kernelClass) {
        synchronized (this.profiles) {
            KernelProfile profile = this.profiles.get(kernelClass);
            if (profile == null) {
                profile = new KernelProfile(kernelClass);
                this.profiles.put(kernelClass, profile);
            }
            return profile;
        }
    }

    /** Looks up an existing profile without creating one; returns {@code null} when absent. */
    private KernelProfile findProfile(Class<? extends Kernel> kernelClass) {
        synchronized (this.profiles) {
            return this.profiles.get(kernelClass);
        }
    }

    /** Copies the registered preference wrappers, in insertion order, under the registry lock. */
    private List<PreferencesWrapper> snapshotPreferences() {
        synchronized (this.preferences) {
            return new ArrayList<>(this.preferences.values());
        }
    }

    // When Config.dumpProfilesOnExit is set, print the profiling summary to stdout at JVM shutdown.
    static {
        if (Config.dumpProfilesOnExit) {
            Runtime.getRuntime().addShutdownHook(new Thread() {
                @Override
                public void run() {
                    StringBuilder builder = new StringBuilder(2048);
                    KernelManager.instance().reportProfilingSummary(builder);
                    System.out.println(builder);
                }
            });
        }
    }

    /** Legacy device-lookup helpers kept for backward compatibility. */
    public static class DeprecatedMethods {
        /**
         * Returns the first device of the given type in the default preferences'
         * {@code getPreferredDevices(null)} list, or {@code null} if there is none.
         *
         * @deprecated retained for backward compatibility only.
         */
        @Deprecated
        public static Device firstDevice(Device.TYPE _type) {
            List<Device> devices = KernelManager.instance().getDefaultPreferences().getPreferredDevices(null);
            for (Device device : devices) {
                if (device.getType() == _type) {
                    return device;
                }
            }
            return null;
        }

        /**
         * Returns the first GPU in the default preference order, or {@code null}.
         *
         * @deprecated retained for backward compatibility only.
         */
        @Deprecated
        public static Device bestGPU() {
            return DeprecatedMethods.firstDevice(Device.TYPE.GPU);
        }

        /**
         * Returns the first accelerator in the default preference order, or {@code null}.
         *
         * @deprecated retained for backward compatibility only.
         */
        @Deprecated
        public static Device bestACC() {
            return DeprecatedMethods.firstDevice(Device.TYPE.ACC);
        }
    }
}

