/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.iteration.operator.perround;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MetricOptions;
import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.AbstractWrapperOperator;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.iteration.operator.OperatorUtils;
import org.apache.flink.iteration.proxy.state.ProxyStateSnapshotContext;
import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext;
import org.apache.flink.iteration.utils.ReflectionUtils;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.OperatorMetricGroup;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
import org.apache.flink.runtime.operators.coordination.OperatorEventDispatcher;
import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.OperatorStateBackend;
import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StatePartitionStreamProvider;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
import org.apache.flink.streaming.api.operators.KeyContext;
import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
import org.apache.flink.streaming.api.operators.Output;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler;
import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.streaming.util.LatencyStats;
import org.apache.flink.util.CloseableIterable;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.BiConsumerWithException;
import org.rocksdb.RocksDB;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for wrapper operators that keep one independent instance of the wrapped operator per
 * round (epoch) of an iteration.
 *
 * <p>Instances are created lazily the first time a round is accessed. Each instance receives a
 * state context whose states are namespaced with the prefix {@code "r<round>-"}. When the epoch
 * watermark reaches a round, that round's instance is finished and closed, and its operator and
 * keyed states are removed from the state backends.
 *
 * @param <T> the output type of the wrapped operator.
 * @param <S> the type of the wrapped operator.
 */
public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperator<T>>
        extends AbstractWrapperOperator<T>
        implements StreamOperatorStateHandler.CheckpointedStreamOperator {

    private static final Logger LOG =
            LoggerFactory.getLogger(AbstractPerRoundWrapperOperator.class);

    /**
     * Fully qualified names of the keyed state backends whose internals are cleaned up
     * reflectively. They are matched by exact class name, so subclasses are not cleaned up.
     */
    private static final String HEAP_KEYED_STATE_NAME =
            "org.apache.flink.runtime.state.heap.HeapKeyedStateBackend";

    private static final String ROCKSDB_KEYED_STATE_NAME =
            "org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend";

    /** Wrapped operator instances that are still open, keyed by the round they belong to. */
    private final Map<Integer, S> wrappedOperators = new HashMap<>();

    protected final LatencyStats latencyStats = initializeLatencyStats();

    private transient StreamOperatorStateContext streamOperatorStateContext;

    private transient StreamOperatorStateHandler stateHandler;

    private transient InternalTimeServiceManager<?> timeServiceManager;

    private transient KeySelector<?, ?> stateKeySelector1;

    private transient KeySelector<?, ?> stateKeySelector2;

    /** The largest epoch watermark received so far; -1 before any has been received. */
    private int latestEpochWatermark = -1;

    /** Union list state holding the parallelism; only subtask 0 writes it. */
    private ListState<Integer> parallelismState;

    private ListState<Integer> latestEpochWatermarkState;

    /** Rounds whose wrapped operators were still open at snapshot time, in ascending order. */
    private ListState<Integer> pendingEpochState;

    /** For each raw operator state partition, in order, the round whose operator wrote it. */
    private ListState<Integer> rawStateEpochState;

    public AbstractPerRoundWrapperOperator(
            StreamOperatorParameters<IterationRecord<T>> parameters,
            StreamOperatorFactory<T> operatorFactory) {
        super(parameters, operatorFactory);
    }

    /**
     * Returns the wrapped operator for {@code round}.
     *
     * <p>If no operator exists for that round yet, a fresh instance is created, initialized
     * (without raw operator state) and opened.
     */
    protected S getWrappedOperator(int round) {
        return getWrappedOperator(
                round, CloseableIterable.<StatePartitionStreamProvider>empty().iterator(), 0);
    }

    /**
     * Returns the wrapped operator for {@code round}, creating it on first access.
     *
     * <p>A new instance is created from a clone of the operator factory. It is initialized with
     * the next {@code count} raw operator state partitions taken from {@code rawOperatorStates},
     * then opened and registered for the round.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private S getWrappedOperator(
            int round, Iterator<StatePartitionStreamProvider> rawOperatorStates, int count) {
        S wrappedOperator = wrappedOperators.get(round);
        if (wrappedOperator != null) {
            return wrappedOperator;
        }

        try {
            ClassLoader userCodeClassLoader = containingTask.getUserCodeClassLoader();
            // Each round gets its own copy of the factory, so factories that hold an operator
            // instance still yield independent operators.
            StreamOperatorFactory<T> clonedOperatorFactory =
                    (StreamOperatorFactory<T>)
                            InstantiationUtil.clone(
                                    (Serializable) operatorFactory, userCodeClassLoader);
            wrappedOperator =
                    (S)
                            StreamOperatorFactoryUtil.createOperator(
                                            clonedOperatorFactory,
                                            (StreamTask) parameters.getContainingTask(),
                                            OperatorUtils.createWrappedOperatorConfig(
                                                    parameters.getStreamConfig(),
                                                    userCodeClassLoader),
                                            (Output) proxyOutput,
                                            parameters.getOperatorEventDispatcher())
                                    .f0;
            initializeStreamOperator(wrappedOperator, round, rawOperatorStates, count);
            wrappedOperators.put(round, wrappedOperator);
            return wrappedOperator;
        } catch (Exception e) {
            ExceptionUtils.rethrow(e);
            // Not reached: rethrow never returns normally. This only satisfies the compiler.
            return wrappedOperator;
        }
    }

    /**
     * Signals end of input to {@code operator} and emits the max watermark through it. The
     * signalling is done before the operator is finished and closed.
     */
    protected abstract void endInputAndEmitMaxWatermark(S var1, int var2, int var3)
            throws Exception;

    /**
     * Finishes and closes the wrapped operator of round {@code epoch}, then drops its states.
     *
     * @param operator the wrapped operator of round {@code epoch}.
     * @param epoch the round the operator belongs to; used as the iteration context round and to
     *     select the states to remove.
     * @param epochWatermark the epoch watermark forwarded to the operator's listeners and to
     *     {@link #endInputAndEmitMaxWatermark}.
     */
    protected void closeStreamOperator(S operator, int epoch, int epochWatermark)
            throws Exception {
        setIterationContextRound(epoch);
        OperatorUtils.processOperatorOrUdfIfSatisfy(
                operator,
                IterationListener.class,
                listener ->
                        notifyEpochWatermarkIncrement(
                                (IterationListener<?>) listener, epochWatermark));
        endInputAndEmitMaxWatermark(operator, epoch, epochWatermark);
        operator.finish();
        operator.close();
        setIterationContextRound(null);

        cleanupOperatorStates(epoch);
        if (stateHandler.getKeyedStateBackend() != null) {
            cleanupKeyedStates(epoch);
        }
    }

    /**
     * Closes wrapped operators that the new epoch watermark completes, then delegates to the
     * parent class.
     *
     * <p>For a regular watermark only the operator of exactly that round is closed. For {@code
     * Integer.MAX_VALUE}, every remaining operator is closed in ascending round order. Watermarks
     * that do not exceed the latest one seen only reach the parent class.
     */
    @Override
    public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
        Preconditions.checkState(
                epochWatermark >= 0, "The epoch watermark should be non-negative.");
        if (epochWatermark > latestEpochWatermark) {
            latestEpochWatermark = epochWatermark;
            try {
                if (epochWatermark < Integer.MAX_VALUE) {
                    S wrappedOperator = wrappedOperators.remove(epochWatermark);
                    if (wrappedOperator != null) {
                        closeStreamOperator(wrappedOperator, epochWatermark, epochWatermark);
                    }
                } else {
                    List<Integer> sortedEpochs = new ArrayList<>(wrappedOperators.keySet());
                    Collections.sort(sortedEpochs);
                    for (Integer epoch : sortedEpochs) {
                        closeStreamOperator(wrappedOperators.remove(epoch), epoch, epochWatermark);
                    }
                }
            } catch (Exception exception) {
                ExceptionUtils.rethrow(exception);
            }
        }
        super.onEpochWatermarkIncrement(epochWatermark);
    }

    /** Applies {@code consumer} to each (round, operator) pair of the open wrapped operators. */
    protected void processForEachWrappedOperator(
            BiConsumerWithException<Integer, S, Exception> consumer) throws Exception {
        for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
            consumer.accept(entry.getKey(), entry.getValue());
        }
    }

    /** No-op: wrapped operators are opened individually when their round is first accessed. */
    @Override
    public void open() throws Exception {}

    /**
     * Creates the state context and state handler shared by all per-round operators and
     * restores this wrapper's own operator state. It also reads the key selectors for both
     * inputs.
     */
    @Override
    public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
            throws Exception {
        ClassLoader userCodeClassLoader = containingTask.getUserCodeClassLoader();
        TypeSerializer<?> keySerializer = streamConfig.getStateKeySerializer(userCodeClassLoader);
        streamOperatorStateContext =
                streamTaskStateManager.streamOperatorStateContext(
                        getOperatorID(),
                        getClass().getSimpleName(),
                        parameters.getProcessingTimeService(),
                        this,
                        keySerializer,
                        containingTask.getCancelables(),
                        metrics,
                        streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
                                ManagedMemoryUseCase.STATE_BACKEND,
                                containingTask
                                        .getEnvironment()
                                        .getTaskManagerInfo()
                                        .getConfiguration(),
                                userCodeClassLoader),
                        isUsingCustomRawKeyedState());

        stateHandler =
                new StreamOperatorStateHandler(
                        streamOperatorStateContext,
                        containingTask.getExecutionConfig(),
                        containingTask.getCancelables());
        stateHandler.initializeOperatorState(this);
        timeServiceManager = streamOperatorStateContext.internalTimerServiceManager();

        stateKeySelector1 = streamConfig.getStatePartitioner(0, userCodeClassLoader);
        stateKeySelector2 = streamConfig.getStatePartitioner(1, userCodeClassLoader);
    }

    /**
     * Restores this wrapper's bookkeeping state and re-creates the wrapped operator of every
     * round that was pending at snapshot time.
     *
     * <p>Raw operator state partitions are assigned to rounds in order, using the round recorded
     * for each partition in {@code rawStateEpoch}.
     *
     * @throws IllegalStateException if the parallelism changed since the snapshot, or if the raw
     *     state round indices are inconsistent with the pending rounds.
     */
    @Override
    public void initializeState(StateInitializationContext context) throws Exception {
        parallelismState =
                context.getOperatorStateStore()
                        .getUnionListState(
                                new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
        int currentParallelism =
                containingTask.getEnvironment().getTaskInfo().getNumberOfParallelSubtasks();
        OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
                .ifPresent(
                        oldParallelism ->
                                Preconditions.checkState(
                                        oldParallelism == currentParallelism,
                                        "The per-round wrapper operator is recovered with parallelism changed from %s to %s",
                                        oldParallelism,
                                        currentParallelism));

        latestEpochWatermarkState =
                context.getOperatorStateStore()
                        .getListState(
                                new ListStateDescriptor<>("latestEpoch", IntSerializer.INSTANCE));
        OperatorStateUtils.getUniqueElement(latestEpochWatermarkState, "latestEpoch")
                .ifPresent(oldLatestEpochWatermark -> latestEpochWatermark = oldLatestEpochWatermark);

        rawStateEpochState =
                context.getOperatorStateStore()
                        .getListState(
                                new ListStateDescriptor<>(
                                        "rawStateEpoch", IntSerializer.INSTANCE));
        List<Integer> rawStateEpochs = copyToList(rawStateEpochState);

        pendingEpochState =
                context.getOperatorStateStore()
                        .getListState(
                                new ListStateDescriptor<>(
                                        "pendingEpochs", IntSerializer.INSTANCE));
        List<Integer> pendingEpochs = copyToList(pendingEpochState);

        Iterator<StatePartitionStreamProvider> rawStates =
                context.getRawOperatorStateInputs().iterator();
        int nextRawStateEntryIndex = 0;
        for (int epoch : pendingEpochs) {
            // Every raw partition must belong to a pending round, and both lists are ascending.
            Preconditions.checkState(
                    nextRawStateEntryIndex == rawStateEpochs.size()
                            || rawStateEpochs.get(nextRawStateEntryIndex) >= epoch,
                    "Unexpected raw state indices %s and epochs %s",
                    rawStateEpochs,
                    pendingEpochs);

            int numberOfStateEntries = 0;
            while (nextRawStateEntryIndex < rawStateEpochs.size()
                    && rawStateEpochs.get(nextRawStateEntryIndex) == epoch) {
                numberOfStateEntries++;
                nextRawStateEntryIndex++;
            }

            getWrappedOperator(epoch, rawStates, numberOfStateEntries);
        }
    }

    /** Copies all values of an integer list state into a new mutable list. */
    private static List<Integer> copyToList(ListState<Integer> state) throws Exception {
        List<Integer> values = new ArrayList<>();
        for (Integer value : state.get()) {
            values.add(value);
        }
        return values;
    }

    /** Whether the wrapped operators use custom raw keyed state; {@code false} by default. */
    @Internal
    protected boolean isUsingCustomRawKeyedState() {
        return false;
    }

    /**
     * Verifies that every wrapped operator has been closed.
     *
     * @throws IllegalStateException if some per-round operators are still open.
     */
    @Override
    public void finish() throws Exception {
        Preconditions.checkState(
                wrappedOperators.isEmpty(),
                "Some wrapped operators are still not closed yet: %s",
                wrappedOperators.keySet());
    }

    /** Disposes the shared state handler, if it was created. */
    @Override
    public void close() throws Exception {
        if (stateHandler != null) {
            stateHandler.dispose();
        }
    }

    @Override
    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
        for (S wrappedOperator : wrappedOperators.values()) {
            wrappedOperator.prepareSnapshotPreBarrier(checkpointId);
        }
    }

    @Override
    public OperatorSnapshotFutures snapshotState(
            long checkpointId,
            long timestamp,
            CheckpointOptions checkpointOptions,
            CheckpointStreamFactory factory)
            throws Exception {
        return stateHandler.snapshotState(
                this,
                Optional.ofNullable(timeServiceManager),
                streamConfig.getOperatorName(),
                checkpointId,
                timestamp,
                checkpointOptions,
                factory,
                isUsingCustomRawKeyedState());
    }

    /**
     * Snapshots the state of every checkpointed wrapped operator, in ascending round order. It
     * then records this wrapper's bookkeeping state.
     *
     * <p>After each operator's snapshot, every newly written raw operator state partition is
     * tagged with that operator's round. {@link #initializeState(StateInitializationContext)}
     * uses these tags to hand the partitions back to the right round.
     */
    @Override
    public void snapshotState(StateSnapshotContext context) throws Exception {
        OperatorStateCheckpointOutputStream rawOperatorStateOutputStream =
                context.getRawOperatorStateOutput();
        List<Integer> operatorStateEpoch = new ArrayList<>();

        List<Integer> sortedEpochs = new ArrayList<>(wrappedOperators.keySet());
        Collections.sort(sortedEpochs);
        for (int epoch : sortedEpochs) {
            S wrappedOperator = wrappedOperators.get(epoch);
            if (wrappedOperator instanceof StreamOperatorStateHandler.CheckpointedStreamOperator) {
                ((StreamOperatorStateHandler.CheckpointedStreamOperator) wrappedOperator)
                        .snapshotState(new ProxyStateSnapshotContext(context));
                int numberOfPartitions = rawOperatorStateOutputStream.getNumberOfPartitions();
                while (operatorStateEpoch.size() < numberOfPartitions) {
                    operatorStateEpoch.add(epoch);
                }
            }
        }

        parallelismState.clear();
        if (containingTask.getEnvironment().getTaskInfo().getIndexOfThisSubtask() == 0) {
            parallelismState.update(
                    Collections.singletonList(
                            containingTask
                                    .getEnvironment()
                                    .getTaskInfo()
                                    .getNumberOfParallelSubtasks()));
        }

        latestEpochWatermarkState.update(Collections.singletonList(latestEpochWatermark));
        rawStateEpochState.update(operatorStateEpoch);
        pendingEpochState.update(sortedEpochs);
    }

    @Override
    public void setKeyContextElement1(StreamRecord<?> record) throws Exception {
        setKeyContextElement(record, stateKeySelector1);
    }

    @Override
    public void setKeyContextElement2(StreamRecord<?> record) throws Exception {
        setKeyContextElement(record, stateKeySelector2);
    }

    /**
     * Sets the current key from {@code record} when a key selector is configured and the
     * iteration record is a data record. Records of other types leave the key unchanged.
     */
    @SuppressWarnings("unchecked")
    private void setKeyContextElement(StreamRecord<?> record, KeySelector<?, ?> selector)
            throws Exception {
        if (selector != null
                && ((IterationRecord<?>) record.getValue()).getType()
                        == IterationRecord.Type.RECORD) {
            Object key = ((KeySelector<Object, ?>) selector).getKey(record.getValue());
            setCurrentKey(key);
        }
    }

    @Override
    public OperatorMetricGroup getMetricGroup() {
        return metrics;
    }

    @Override
    public OperatorID getOperatorID() {
        return streamConfig.getOperatorID();
    }

    @Override
    public void notifyCheckpointComplete(long checkpointId) throws Exception {
        for (S wrappedOperator : wrappedOperators.values()) {
            wrappedOperator.notifyCheckpointComplete(checkpointId);
        }
    }

    @Override
    public void notifyCheckpointAborted(long checkpointId) throws Exception {
        for (S wrappedOperator : wrappedOperators.values()) {
            wrappedOperator.notifyCheckpointAborted(checkpointId);
        }
    }

    @Override
    public void setCurrentKey(Object key) {
        stateHandler.setCurrentKey(key);
    }

    /** Returns the current key, or {@code null} before the state handler has been created. */
    @Override
    public Object getCurrentKey() {
        return stateHandler == null ? null : stateHandler.getCurrentKey();
    }

    /** Records the latency of {@code marker} and forwards it downstream. */
    protected void reportOrForwardLatencyMarker(LatencyMarker marker) {
        latencyStats.reportLatency(marker);
        output.emitLatencyMarker(marker);
    }

    /**
     * Builds the latency statistics from the task manager configuration.
     *
     * <p>Invalid history sizes or granularities fall back to defaults with a warning. On any
     * other failure, statistics backed by an unregistered metric group are returned instead.
     */
    private LatencyStats initializeLatencyStats() {
        try {
            Configuration taskManagerConfig =
                    containingTask.getEnvironment().getTaskManagerInfo().getConfiguration();
            int historySize = taskManagerConfig.getInteger(MetricOptions.LATENCY_HISTORY_SIZE);
            if (historySize <= 0) {
                LOG.warn(
                        "{} has been set to a value equal or below 0: {}. Using default.",
                        MetricOptions.LATENCY_HISTORY_SIZE,
                        historySize);
                historySize = MetricOptions.LATENCY_HISTORY_SIZE.defaultValue();
            }

            String configuredGranularity =
                    taskManagerConfig.getString(MetricOptions.LATENCY_SOURCE_GRANULARITY);
            LatencyStats.Granularity granularity;
            try {
                granularity =
                        LatencyStats.Granularity.valueOf(
                                configuredGranularity.toUpperCase(Locale.ROOT));
            } catch (IllegalArgumentException iae) {
                granularity = LatencyStats.Granularity.OPERATOR;
                LOG.warn(
                        "Configured value {} option for {} is invalid. Defaulting to {}.",
                        configuredGranularity,
                        MetricOptions.LATENCY_SOURCE_GRANULARITY.key(),
                        granularity);
            }

            MetricGroup jobMetricGroup = metrics.getJobMetricGroup();
            return new LatencyStats(
                    jobMetricGroup.addGroup("latency"),
                    historySize,
                    containingTask.getIndexInSubtaskGroup(),
                    getOperatorID(),
                    granularity);
        } catch (Exception e) {
            LOG.warn("An error occurred while instantiating latency metrics.", e);
            return new LatencyStats(
                    UnregisteredMetricGroups.createUnregisteredTaskManagerJobMetricGroup()
                            .addGroup("latency"),
                    1,
                    0,
                    new OperatorID(),
                    LatencyStats.Granularity.SINGLE);
        }
    }

    /**
     * Initializes {@code operator} with a proxy state context for round {@code round}, then
     * opens it. The proxy context prefixes state names with the round prefix and exposes
     * {@code count} raw state partitions from {@code rawOperatorStates}.
     */
    private void initializeStreamOperator(
            S operator,
            int round,
            Iterator<StatePartitionStreamProvider> rawOperatorStates,
            int count)
            throws Exception {
        operator.initializeState(
                (operatorID,
                        operatorClassName,
                        processingTimeService,
                        keyContext,
                        keySerializer,
                        streamTaskCloseableRegistry,
                        metricGroup,
                        managedMemoryFraction,
                        isUsingCustomRawKeyedState) ->
                        new ProxyStreamOperatorStateContext(
                                streamOperatorStateContext,
                                getRoundStatePrefix(round),
                                rawOperatorStates,
                                count));
        operator.open();
    }

    /**
     * Removes the operator and broadcast states registered under the prefix of {@code round}.
     * This works by reflection and only for {@link DefaultOperatorStateBackend}; other backends
     * only produce a warning.
     */
    private void cleanupOperatorStates(int round) {
        String roundPrefix = getRoundStatePrefix(round);
        OperatorStateBackend operatorStateBackend = stateHandler.getOperatorStateBackend();
        if (operatorStateBackend instanceof DefaultOperatorStateBackend) {
            for (String fieldName :
                    new String[] {
                        "registeredOperatorStates",
                        "registeredBroadcastStates",
                        "accessedStatesByName",
                        "accessedBroadcastStatesByName"
                    }) {
                removeEntriesWithPrefix(
                        operatorStateBackend,
                        DefaultOperatorStateBackend.class,
                        fieldName,
                        roundPrefix);
            }
        } else {
            LOG.warn("Unable to cleanup the operator state {}", operatorStateBackend);
        }
    }

    /**
     * Removes the keyed states registered under the prefix of {@code round}.
     *
     * <p>This works by reflection and only for the heap and RocksDB keyed state backends; other
     * backends only produce a warning. For RocksDB, the column family of each matching state is
     * dropped before its bookkeeping entries are removed. A failed drop is logged and does not
     * stop the cleanup.
     */
    private void cleanupKeyedStates(int round) {
        String roundPrefix = getRoundStatePrefix(round);
        KeyedStateBackend<?> keyedStateBackend = stateHandler.getKeyedStateBackend();
        String backendClassName = keyedStateBackend.getClass().getName();

        if (HEAP_KEYED_STATE_NAME.equals(backendClassName)) {
            removeEntriesWithPrefix(
                    keyedStateBackend, HeapKeyedStateBackend.class, "registeredKVStates", roundPrefix);
            removeEntriesWithPrefix(
                    keyedStateBackend, HeapKeyedStateBackend.class, "createdKVStates", roundPrefix);
            removeEntriesWithPrefix(
                    keyedStateBackend,
                    AbstractKeyedStateBackend.class,
                    "keyValueStatesByName",
                    roundPrefix);
        } else if (ROCKSDB_KEYED_STATE_NAME.equals(backendClassName)) {
            RocksDB db =
                    (RocksDB)
                            ReflectionUtils.getFieldValue(
                                    keyedStateBackend, RocksDBKeyedStateBackend.class, "db");
            @SuppressWarnings("unchecked")
            Map<String, RocksDBKeyedStateBackend.RocksDbKvStateInfo> kvStateInformation =
                    (Map<String, RocksDBKeyedStateBackend.RocksDbKvStateInfo>)
                            ReflectionUtils.getFieldValue(
                                    keyedStateBackend,
                                    RocksDBKeyedStateBackend.class,
                                    "kvStateInformation");

            for (Map.Entry<String, RocksDBKeyedStateBackend.RocksDbKvStateInfo> entry :
                    kvStateInformation.entrySet()) {
                if (!entry.getKey().startsWith(roundPrefix)) {
                    continue;
                }
                try {
                    db.dropColumnFamily(entry.getValue().columnFamilyHandle);
                } catch (Exception e) {
                    // Best effort: keep the cause in the log and continue with the other states.
                    LOG.error("Failed to drop state {} for round {}", entry.getKey(), round, e);
                }
                // NOTE(review): the dropped ColumnFamilyHandle itself is not closed here;
                // confirm whether this leaks the native handle.
            }

            kvStateInformation.keySet().removeIf(name -> name.startsWith(roundPrefix));
            removeEntriesWithPrefix(
                    keyedStateBackend, RocksDBKeyedStateBackend.class, "createdKVStates", roundPrefix);
            removeEntriesWithPrefix(
                    keyedStateBackend,
                    AbstractKeyedStateBackend.class,
                    "keyValueStatesByName",
                    roundPrefix);
        } else {
            LOG.warn("Unable to cleanup the keyed state {}", keyedStateBackend);
        }
    }

    /**
     * Reads the {@code Map} held in field {@code fieldName} of {@code owner} by reflection. It
     * then removes every entry whose key starts with {@code prefix}.
     *
     * <p>Assumes the field holds a {@code Map} with {@code String} keys, as all call sites
     * expect.
     */
    @SuppressWarnings("unchecked")
    private static void removeEntriesWithPrefix(
            Object owner, Class<?> declaringClass, String fieldName, String prefix) {
        Map<String, ?> map =
                (Map<String, ?>) ReflectionUtils.getFieldValue(owner, declaringClass, fieldName);
        map.keySet().removeIf(key -> key.startsWith(prefix));
    }

    /** Returns the state-name prefix used for round {@code round}, e.g. {@code "r3-"}. */
    private String getRoundStatePrefix(int round) {
        return "r" + round + "-";
    }

    int getLatestEpochWatermark() {
        return latestEpochWatermark;
    }

    /** Returns the internal map of open wrapped operators (the live map, not a copy). */
    public Map<Integer, S> getWrappedOperators() {
        return wrappedOperators;
    }
}

