/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.tokenize;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.transform.tokenize.Tokenizer;
import org.apache.sysds.runtime.transform.tokenize.TokenizerPost;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

/**
 * Tokenizer post-processing step that maps every token of a document to a hash bucket
 * and emits per-document bucket counts.
 *
 * <p>Each token's text is hashed with {@code hashCode()} and reduced to a bucket in
 * {@code [0, num_features)}. The counts are then written in one of two layouts:
 * <ul>
 *   <li><b>long</b> ({@code wideFormat == false}): one row per non-empty bucket of a document,
 *       in ascending bucket order and at most {@code maxTokens} rows per document. Columns are
 *       the document keys, followed by the 1-based bucket id and the bucket count.</li>
 *   <li><b>wide</b> ({@code wideFormat == true}): exactly one row per document. Columns are
 *       the document keys, followed by the counts of buckets {@code 0 .. maxTokens-1}.</li>
 * </ul>
 */
public class TokenizerPostHash
implements TokenizerPost {
    private static final long serialVersionUID = 4763889041868044668L;
    private final Params params;
    private final int numIdCols;
    private final int maxTokens;
    private final boolean wideFormat;

    /**
     * @param params     optional JSON parameters (may be {@code null}); the key {@code num_features} is read
     * @param numIdCols  number of leading id (key) columns in the output schema
     * @param maxTokens  long format: maximum rows emitted per document; wide format: number of count columns
     * @param wideFormat {@code true} for the wide layout, {@code false} for the long layout
     * @throws JSONException            if reading {@code num_features} from {@code params} fails
     * @throws IllegalArgumentException if {@code num_features} is not positive
     */
    public TokenizerPostHash(JSONObject params, int numIdCols, int maxTokens, boolean wideFormat) throws JSONException {
        this.params = new Params(params);
        this.numIdCols = numIdCols;
        this.maxTokens = maxTokens;
        this.wideFormat = wideFormat;
    }

    /**
     * Hashes the tokens of every document and appends the resulting count rows to {@code out}.
     *
     * @param tl  tokenized documents (keys plus token list)
     * @param out frame the rows are appended to
     * @return {@code out}, for chaining
     */
    @Override
    public FrameBlock tokenizePost(List<Tokenizer.DocumentToTokens> tl, FrameBlock out) {
        for (Tokenizer.DocumentToTokens docToToken : tl) {
            // TreeMap keeps buckets in ascending order, which determines the row order of the long format.
            TreeMap<Integer, Long> sortedHashes = new TreeMap<>();
            for (Tokenizer.Token token : docToToken.tokens) {
                sortedHashes.merge(this.bucketOf(token), 1L, Long::sum);
            }
            if (this.wideFormat) {
                this.appendTokensWide(docToToken.keys, sortedHashes, out);
            } else {
                this.appendTokensLong(docToToken.keys, sortedHashes, out);
            }
        }
        return out;
    }

    /** Maps a token to its hash bucket in {@code [0, num_features)}. */
    private int bucketOf(Tokenizer.Token token) {
        // floorMod instead of %: hashCode() may be negative, and % would then yield a negative bucket.
        return Math.floorMod(token.textToken.hashCode(), this.params.num_features);
    }

    /** Appends one row per non-empty bucket (at most maxTokens rows): keys, 1-based bucket id, count. */
    private void appendTokensLong(List<Object> keys, Map<Integer, Long> sortedHashes, FrameBlock out) {
        int numTokens = 0;
        for (Map.Entry<Integer, Long> hashCount : sortedHashes.entrySet()) {
            if (numTokens >= this.maxTokens) {
                break;
            }
            // Bucket id is emitted 1-based and as a long, matching the INT64 column in getOutSchemaLong.
            long bucketId = hashCount.getKey() + 1L;
            List<Object> rowList = new ArrayList<>(keys);
            rowList.add(bucketId);
            rowList.add(hashCount.getValue());
            out.appendRow(rowList.toArray(new Object[0]));
            ++numTokens;
        }
    }

    /**
     * Appends a single row: keys, followed by the counts of buckets 0..maxTokens-1 (0 if absent).
     * Buckets at or above maxTokens are not represented in this layout.
     * NOTE(review): buckets are 0-based here but 1-based in the long layout — confirm this is intended.
     */
    private void appendTokensWide(List<Object> keys, Map<Integer, Long> sortedHashes, FrameBlock out) {
        List<Object> rowList = new ArrayList<>(keys);
        for (int tokenPos = 0; tokenPos < this.maxTokens; ++tokenPos) {
            rowList.add(sortedHashes.getOrDefault(tokenPos, 0L));
        }
        out.appendRow(rowList.toArray(new Object[0]));
    }

    /** Returns the output schema for the configured layout (wide or long). */
    @Override
    public Types.ValueType[] getOutSchema() {
        if (this.wideFormat) {
            return TokenizerPostHash.getOutSchemaWide(this.numIdCols, this.maxTokens);
        }
        return TokenizerPostHash.getOutSchemaLong(this.numIdCols);
    }

    /** Wide schema: numIdCols STRING key columns followed by maxTokens INT64 count columns. */
    private static Types.ValueType[] getOutSchemaWide(int numIdCols, int maxTokens) {
        Types.ValueType[] schema = new Types.ValueType[numIdCols + maxTokens];
        for (int i = 0; i < schema.length; ++i) {
            schema[i] = (i < numIdCols) ? Types.ValueType.STRING : Types.ValueType.INT64;
        }
        return schema;
    }

    /** Long schema: numIdCols STRING key columns followed by INT64 bucket id and INT64 count. */
    private static Types.ValueType[] getOutSchemaLong(int numIdCols) {
        Types.ValueType[] schema = UtilFunctions.nCopies(numIdCols + 2, Types.ValueType.STRING);
        schema[numIdCols] = Types.ValueType.INT64;
        schema[numIdCols + 1] = Types.ValueType.INT64;
        return schema;
    }

    /**
     * Returns the number of output rows for {@code inRows} input rows. For the long layout this is an
     * upper bound (maxTokens rows per document), since documents with fewer distinct buckets emit fewer rows.
     * NOTE(review): assumes one input row per document — confirm against the caller.
     */
    @Override
    public long getNumRows(long inRows) {
        if (this.wideFormat) {
            return inRows;
        }
        return inRows * (long) this.maxTokens;
    }

    /** Returns the number of output columns, i.e. the length of {@link #getOutSchema()}. */
    @Override
    public long getNumCols() {
        return this.getOutSchema().length;
    }

    /** JSON-configurable parameters of the hash post-processing step. */
    static class Params
    implements Serializable {
        private static final long serialVersionUID = -256069061414241795L;
        /** Number of hash buckets; defaults to 0x100000 (2^20) unless overridden by {@code num_features}. */
        public int num_features = 0x100000;

        public Params(JSONObject json) throws JSONException {
            if (json != null && json.has("num_features")) {
                this.num_features = json.getInt("num_features");
            }
            // Fail fast: 0 would throw ArithmeticException on the first token, negatives give invalid buckets.
            if (this.num_features <= 0) {
                throw new IllegalArgumentException("num_features must be positive, but was " + this.num_features);
            }
        }
    }
}

