/*
 * Decompiled with CFR 0.152.
 */
package org.apache.uniffle.shuffle.manager;

import io.grpc.stub.StreamObserver;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.proto.ShuffleManagerGrpc;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerInterface;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * gRPC endpoint through which shuffle readers report fetch failures.
 *
 * <p>For every shuffle id a {@link RssShuffleStatus} counts the fetch failures reported per
 * partition for the latest stage attempt seen so far. Once a partition's count reaches
 * {@code shuffleManager.getMaxFetchFailures()}, the reply asks the caller to resubmit the whole
 * stage.
 */
public class ShuffleManagerGrpcService
extends ShuffleManagerGrpc.ShuffleManagerImplBase {
    private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerGrpcService.class);
    /** Failure bookkeeping keyed by shuffle id; entries are removed by {@link #unregisterShuffle}. */
    private final Map<Integer, RssShuffleStatus> shuffleStatus = JavaUtils.newConcurrentMap();
    private final RssShuffleManagerInterface shuffleManager;

    public ShuffleManagerGrpcService(RssShuffleManagerInterface shuffleManager) {
        this.shuffleManager = shuffleManager;
    }

    /**
     * Records one shuffle fetch failure and tells the caller whether to resubmit the whole stage.
     *
     * <p>The reply status is {@code INVALID_REQUEST} in three cases:
     * <ul>
     *   <li>the app id does not match this manager's app id;</li>
     *   <li>the partition id is outside {@code [0, partitionNum)};</li>
     *   <li>the report belongs to a stage attempt older than the one already tracked.</li>
     * </ul>
     * Otherwise the status is {@code SUCCESS}, and {@code reSubmitWholeStage} is set once the
     * partition's failure count reaches the configured maximum.
     *
     * @param request the failure report (app id, shuffle id, stage attempt id, partition id)
     * @param responseObserver receives exactly one reply, followed by {@code onCompleted()}
     */
    @Override
    public void reportShuffleFetchFailure(RssProtos.ReportShuffleFetchFailureRequest request, StreamObserver<RssProtos.ReportShuffleFetchFailureResponse> responseObserver) {
        String appId = request.getAppId();
        int stageAttempt = request.getStageAttemptId();
        int partitionId = request.getPartitionId();
        RssProtos.StatusCode code;
        boolean reSubmitWholeStage = false;
        String msg;
        if (!appId.equals(this.shuffleManager.getAppId())) {
            msg = String.format("got a wrong shuffle fetch failure report from appId: %s, expected appId: %s", appId, this.shuffleManager.getAppId());
            LOG.warn(msg);
            code = RssProtos.StatusCode.INVALID_REQUEST;
        } else {
            RssShuffleStatus status = this.shuffleStatus.computeIfAbsent(request.getShuffleId(),
                    key -> new RssShuffleStatus(this.shuffleManager.getPartitionNum(key), stageAttempt));
            if (!status.isValidPartition(partitionId)) {
                // The partition id arrives from the wire. Reject it before touching the counters:
                // an out-of-range index would otherwise throw inside the handler, and the caller
                // would never receive a proper reply.
                msg = String.format("got a shuffle fetch failure report for an invalid partition %d, expected a partition in [0, %d)", partitionId, status.getPartitionNum());
                LOG.warn(msg);
                code = RssProtos.StatusCode.INVALID_REQUEST;
            } else {
                // Reset-if-newer, increment and read happen under one lock, so a concurrent report
                // for a newer stage attempt cannot make this report's increment disappear.
                int fetchFailureNum = status.recordFetchFailure(stageAttempt, partitionId);
                if (fetchFailureNum < 0) {
                    msg = String.format("got an old stage(%d vs %d) shuffle fetch failure report, which should be impossible.", status.getStageAttempt(), stageAttempt);
                    LOG.warn(msg);
                    code = RssProtos.StatusCode.INVALID_REQUEST;
                } else {
                    code = RssProtos.StatusCode.SUCCESS;
                    int maxFetchFailures = this.shuffleManager.getMaxFetchFailures();
                    if (fetchFailureNum >= maxFetchFailures) {
                        reSubmitWholeStage = true;
                        msg = String.format("report shuffle fetch failure as maximum number(%d) of shuffle fetch is occurred", maxFetchFailures);
                    } else {
                        msg = "don't report shuffle fetch failure";
                    }
                }
            }
        }
        RssProtos.ReportShuffleFetchFailureResponse reply = RssProtos.ReportShuffleFetchFailureResponse.newBuilder()
                .setStatus(code)
                .setReSubmitWholeStage(reSubmitWholeStage)
                .setMsg(msg)
                .build();
        responseObserver.onNext(reply);
        responseObserver.onCompleted();
    }

    /**
     * Drops the fetch-failure bookkeeping for a shuffle.
     *
     * @param shuffleId id of the shuffle being unregistered
     */
    public void unregisterShuffle(int shuffleId) {
        this.shuffleStatus.remove(shuffleId);
    }

    /**
     * Per-shuffle fetch-failure counters for the most recent stage attempt seen.
     *
     * <p>{@code stageAttempt} and the contents of {@code partitions} are read and written only while
     * holding {@code lock}. The array length never changes after construction.
     */
    private static class RssShuffleStatus {
        private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
        private final ReentrantReadWriteLock.ReadLock readLock = this.lock.readLock();
        private final ReentrantReadWriteLock.WriteLock writeLock = this.lock.writeLock();
        /** Failure count per partition id, valid for {@link #stageAttempt}. */
        private final int[] partitions;
        private int stageAttempt;

        private RssShuffleStatus(int partitionNum, int stageAttempt) {
            this.stageAttempt = stageAttempt;
            this.partitions = new int[partitionNum];
        }

        /** Runs {@code fn} while holding the read lock and returns its result. */
        private <T> T withReadLock(Supplier<T> fn) {
            this.readLock.lock();
            try {
                return fn.get();
            } finally {
                this.readLock.unlock();
            }
        }

        /** Runs {@code fn} while holding the write lock and returns its result. */
        private <T> T withWriteLock(Supplier<T> fn) {
            this.writeLock.lock();
            try {
                return fn.get();
            } finally {
                this.writeLock.unlock();
            }
        }

        /** Returns the stage attempt whose failures are currently being counted. */
        public int getStageAttempt() {
            return this.withReadLock(() -> this.stageAttempt);
        }

        /** Returns the number of partitions tracked. It is fixed at construction, so no lock is needed. */
        public int getPartitionNum() {
            return this.partitions.length;
        }

        /** Returns whether {@code partition} is a valid index into the counters. */
        public boolean isValidPartition(int partition) {
            return partition >= 0 && partition < this.partitions.length;
        }

        /**
         * Atomically records one fetch failure for {@code partition}.
         *
         * <p>If {@code stageAttempt} is newer than the tracked attempt, all counters are cleared and
         * the tracked attempt is advanced first. If it is older, nothing is recorded.
         *
         * @param stageAttempt stage attempt the failure belongs to
         * @param partition partition index; must satisfy {@link #isValidPartition(int)}
         * @return the partition's updated failure count, or {@code -1} if {@code stageAttempt} is
         *     older than the tracked attempt
         */
        public int recordFetchFailure(int stageAttempt, int partition) {
            return this.withWriteLock(() -> {
                if (this.stageAttempt > stageAttempt) {
                    return -1;
                }
                if (this.stageAttempt < stageAttempt) {
                    Arrays.fill(this.partitions, 0);
                    this.stageAttempt = stageAttempt;
                }
                return ++this.partitions[partition];
            });
        }
    }
}

