/*
 * Decompiled with CFR 0.152.
 */
package org.apache.uniffle.shuffle.manager;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.spark.MapOutputTracker;
import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkException;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.ShuffleManager;
import org.apache.spark.shuffle.SparkVersionUtils;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerInterface;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class RssShuffleManagerBase
implements RssShuffleManagerInterface,
ShuffleManager {
    private static final Logger LOG = LoggerFactory.getLogger(RssShuffleManagerBase.class);
    private AtomicBoolean isInitialized = new AtomicBoolean(false);
    private Method unregisterAllMapOutputMethod;
    private Method registerShuffleMethod;

    @Override
    public void unregisterAllMapOutput(int shuffleId) throws SparkException {
        if (!RssSparkShuffleUtils.isStageResubmitSupported()) {
            return;
        }
        MapOutputTrackerMaster tracker = RssShuffleManagerBase.getMapOutputTrackerMaster();
        if (this.isInitialized.compareAndSet(false, true)) {
            this.unregisterAllMapOutputMethod = RssShuffleManagerBase.getUnregisterAllMapOutputMethod(tracker);
            this.registerShuffleMethod = RssShuffleManagerBase.getRegisterShuffleMethod(tracker);
        }
        if (this.unregisterAllMapOutputMethod != null) {
            try {
                this.unregisterAllMapOutputMethod.invoke((Object)tracker, shuffleId);
            }
            catch (IllegalAccessException | InvocationTargetException e) {
                throw new RssException("Invoke unregisterAllMapOutput method failed", e);
            }
        } else {
            int numMaps = this.getNumMaps(shuffleId);
            int numReduces = this.getPartitionNum(shuffleId);
            RssShuffleManagerBase.defaultUnregisterAllMapOutput(tracker, this.registerShuffleMethod, shuffleId, numMaps, numReduces);
        }
    }

    private static void defaultUnregisterAllMapOutput(MapOutputTrackerMaster tracker, Method registerShuffle, int shuffleId, int numMaps, int numReduces) throws SparkException {
        if (tracker != null && registerShuffle != null) {
            tracker.unregisterShuffle(shuffleId);
            try {
                if (SparkVersionUtils.MAJOR_VERSION > 3 || SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION >= 2) {
                    registerShuffle.invoke((Object)tracker, shuffleId, numMaps, numReduces);
                }
                registerShuffle.invoke((Object)tracker, shuffleId, numMaps);
            }
            catch (IllegalAccessException | InvocationTargetException e) {
                throw new RssException("Invoke registerShuffle method failed", e);
            }
        } else {
            throw new SparkException("default unregisterAllMapOutput should only be called on the driver side");
        }
        tracker.incrementEpoch();
    }

    private static Method getUnregisterAllMapOutputMethod(MapOutputTrackerMaster tracker) {
        if (tracker != null) {
            Class<?> klass = tracker.getClass();
            Method m = null;
            try {
                if (SparkVersionUtils.isSpark2() && SparkVersionUtils.MINOR_VERSION <= 3) {
                    LOG.warn("Spark version <= 2.3, fallback to default method");
                } else if (SparkVersionUtils.isSpark2()) {
                    m = klass.getDeclaredMethod("unregisterAllMapOutput", Integer.TYPE);
                } else if (SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION <= 1) {
                    m = klass.getDeclaredMethod("unregisterAllMapOutput", Integer.TYPE);
                } else if (SparkVersionUtils.isSpark3()) {
                    m = klass.getDeclaredMethod("unregisterAllMapAndMergeOutput", Integer.TYPE);
                } else {
                    LOG.warn("Unknown spark version({}), fallback to default method", (Object)SparkVersionUtils.SPARK_VERSION);
                }
            }
            catch (NoSuchMethodException e) {
                LOG.warn("Got no such method error when get unregisterAllMapOutput method for spark version({})", (Object)SparkVersionUtils.SPARK_VERSION);
            }
            return m;
        }
        return null;
    }

    private static Method getRegisterShuffleMethod(MapOutputTrackerMaster tracker) {
        if (tracker != null) {
            Class<?> klass = tracker.getClass();
            Method m = null;
            try {
                m = SparkVersionUtils.MAJOR_VERSION > 3 || SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION >= 2 ? klass.getDeclaredMethod("registerShuffle", Integer.TYPE, Integer.TYPE, Integer.TYPE) : klass.getDeclaredMethod("registerShuffle", Integer.TYPE, Integer.TYPE);
            }
            catch (NoSuchMethodException e) {
                LOG.warn("Got no such method error when get registerShuffle method for spark version({})", (Object)SparkVersionUtils.SPARK_VERSION);
            }
            return m;
        }
        return null;
    }

    private static MapOutputTrackerMaster getMapOutputTrackerMaster() {
        MapOutputTracker tracker = Optional.ofNullable(SparkEnv.get()).map(SparkEnv::mapOutputTracker).orElse(null);
        return tracker instanceof MapOutputTrackerMaster ? (MapOutputTrackerMaster)tracker : null;
    }
}

