/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sedona.common.raster;

import it.geosolutions.jaiext.jiffle.JiffleBuilder;
import it.geosolutions.jaiext.jiffle.runtime.JiffleDirectRuntime;
import java.awt.image.BufferedImage;
import java.awt.image.ColorModel;
import java.awt.image.DataBuffer;
import java.awt.image.Raster;
import java.awt.image.RenderedImage;
import java.awt.image.SampleModel;
import java.awt.image.WritableRaster;
import java.awt.image.WritableRenderedImage;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.media.jai.PlanarImage;
import javax.media.jai.RasterFactory;
import org.apache.sedona.common.utils.RasterUtils;
import org.geotools.coverage.grid.GridCoverage2D;

public class MapAlgebra {
    /** Name of the Jiffle destination variable that map-algebra scripts write to. */
    private static final String DESTINATION_VAR = "out";

    /**
     * Per-thread cache of the most recently compiled Jiffle runtime. It lets a thread that runs
     * the same script repeatedly skip recompilation and only rebind images. The key combines the
     * source variable names with the script text, because a compiled runtime is bound to the
     * variable names it was built with.
     */
    private static final ThreadLocal<String> previousRuntimeKey = new ThreadLocal<>();
    private static final ThreadLocal<JiffleDirectRuntime> previousRuntime = new ThreadLocal<>();

    /**
     * Reads one band of a raster into a flat array of doubles.
     *
     * @param rasterGeom the raster to read from
     * @param bandIndex 1-based index of the band to read
     * @return the band's samples ({@code width * height} values), or {@code null} when
     *     {@code bandIndex} is outside {@code [1, numBands]}
     */
    public static double[] bandAsArray(GridCoverage2D rasterGeom, int bandIndex) {
        int numBands = rasterGeom.getNumSampleDimensions();
        if (bandIndex < 1 || bandIndex > numBands) {
            return null;
        }
        Raster raster = RasterUtils.getRaster(rasterGeom.getRenderedImage());
        int width = raster.getWidth();
        int height = raster.getHeight();
        double[] bandValues = new double[width * height];
        // Raster.getSamples addresses pixels in the raster's own coordinate space, which does not
        // necessarily start at (0, 0); reading from the min corner covers the whole raster.
        return raster.getSamples(raster.getMinX(), raster.getMinY(), width, height, bandIndex - 1, bandValues);
    }

    /**
     * Returns a copy of {@code rasterGeom} in which {@code bandValues} is either appended as a new
     * band ({@code bandIndex == numBands + 1}) or replaces the band at {@code bandIndex}.
     *
     * @param rasterGeom source raster (not modified)
     * @param bandValues sample values for the band
     * @param bandIndex 1-based target band index, in {@code [1, numBands + 1]}
     * @param noDataValue no-data value to attach to the band (may be {@code null})
     * @return the new raster
     * @throws IllegalArgumentException if {@code bandIndex} is out of range
     */
    public static GridCoverage2D addBandFromArray(GridCoverage2D rasterGeom, double[] bandValues, int bandIndex, Double noDataValue) {
        int numBands = checkAddBandIndex(rasterGeom, bandIndex);
        Number[] bandValuesClass = boxBandValues(bandValues);
        if (bandIndex == numBands + 1) {
            return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValuesClass, noDataValue);
        }
        return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValuesClass, noDataValue, true);
    }

    /**
     * Same as {@link #addBandFromArray(GridCoverage2D, double[], int, Double)} but without an
     * explicit no-data value; delegates to the corresponding {@code RasterUtils} overloads.
     *
     * @throws IllegalArgumentException if {@code bandIndex} is outside {@code [1, numBands + 1]}
     */
    public static GridCoverage2D addBandFromArray(GridCoverage2D rasterGeom, double[] bandValues, int bandIndex) {
        int numBands = checkAddBandIndex(rasterGeom, bandIndex);
        Number[] bandValuesClass = boxBandValues(bandValues);
        if (bandIndex == numBands + 1) {
            return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValuesClass);
        }
        return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValuesClass);
    }

    /** Appends {@code bandValues} as a new last band of a copy of {@code rasterGeom}. */
    public static GridCoverage2D addBandFromArray(GridCoverage2D rasterGeom, double[] bandValues) {
        return MapAlgebra.addBandFromArray(rasterGeom, bandValues, rasterGeom.getNumSampleDimensions() + 1);
    }

    /** Validates a 1-based band index for add/replace and returns the raster's current band count. */
    private static int checkAddBandIndex(GridCoverage2D rasterGeom, int bandIndex) {
        int numBands = rasterGeom.getNumSampleDimensions();
        if (bandIndex < 1 || bandIndex > numBands + 1) {
            throw new IllegalArgumentException("Band index is out of bounds. Must be between 1 and " + (numBands + 1) + ")");
        }
        return numBands;
    }

    /** Boxes primitive band values into the {@code Number[]} form expected by RasterUtils. */
    private static Number[] boxBandValues(double[] bandValues) {
        return Arrays.stream(bandValues).boxed().toArray(Double[]::new);
    }

    /**
     * Evaluates a Jiffle script over a single raster, exposed to the script as {@code rast}.
     * The script must write its result to {@code out}.
     *
     * @param gridCoverage2D input raster; also supplies georeferencing for the result
     * @param pixelType name of the output pixel type, or {@code null} to keep the input's type
     * @param script Jiffle script text
     * @param noDataValue no-data value for the result band (may be {@code null})
     * @return a single-band raster, or {@code null} if the raster or script is {@code null}
     * @throws RuntimeException wrapping any failure while compiling or evaluating the script
     */
    public static GridCoverage2D mapAlgebra(GridCoverage2D gridCoverage2D, String pixelType, String script, Double noDataValue) {
        if (gridCoverage2D == null || script == null) {
            return null;
        }
        RenderedImage renderedImage = gridCoverage2D.getRenderedImage();
        Map<String, RenderedImage> sources = new LinkedHashMap<>();
        sources.put("rast", renderedImage);
        return runMapAlgebra(gridCoverage2D, renderedImage, sources, pixelType, script, noDataValue);
    }

    /** {@link #mapAlgebra(GridCoverage2D, String, String, Double)} without a no-data value. */
    public static GridCoverage2D mapAlgebra(GridCoverage2D gridCoverage2D, String pixelType, String script) {
        return MapAlgebra.mapAlgebra(gridCoverage2D, pixelType, script, null);
    }

    /**
     * Returns {@code originalColorModel} when it can describe {@code resultRaster}; otherwise
     * (including when the source image has no color model) derives one from the raster's
     * sample model.
     */
    private static ColorModel fetchColorModel(ColorModel originalColorModel, WritableRaster resultRaster) {
        if (originalColorModel != null && originalColorModel.isCompatibleRaster(resultRaster)) {
            return originalColorModel;
        }
        return PlanarImage.createColorModel(resultRaster.getSampleModel());
    }

    /**
     * Evaluates a Jiffle script over two rasters, exposed to the script as {@code rast0} and
     * {@code rast1}. The script must write its result to {@code out}. Output dimensions, default
     * pixel type, color model and georeferencing come from {@code rast0}.
     *
     * @return a single-band raster, or {@code null} if any raster or the script is {@code null}
     * @throws RuntimeException wrapping any failure while compiling or evaluating the script
     */
    public static GridCoverage2D mapAlgebra(GridCoverage2D rast0, GridCoverage2D rast1, String pixelType, String script, Double noDataValue) {
        if (rast0 == null || rast1 == null || script == null) {
            return null;
        }
        // NOTE(review): any return value is ignored here; this presumably throws when the shapes
        // differ — confirm in RasterUtils.
        RasterUtils.isRasterSameShape(rast0, rast1);
        RenderedImage renderedImageRast0 = rast0.getRenderedImage();
        Map<String, RenderedImage> sources = new LinkedHashMap<>();
        sources.put("rast0", renderedImageRast0);
        sources.put("rast1", rast1.getRenderedImage());
        return runMapAlgebra(rast0, renderedImageRast0, sources, pixelType, script, noDataValue);
    }

    /**
     * Shared map-algebra pipeline: allocate a single-band double destination, run the script,
     * optionally convert to the requested pixel type, and wrap the result using
     * {@code referenceRaster}'s metadata.
     */
    private static GridCoverage2D runMapAlgebra(GridCoverage2D referenceRaster, RenderedImage referenceImage, Map<String, RenderedImage> sources, String pixelType, String script, Double noDataValue) {
        int rasterDataType = pixelType != null ? RasterUtils.getDataTypeCode(pixelType) : referenceImage.getSampleModel().getDataType();
        int width = referenceImage.getWidth();
        int height = referenceImage.getHeight();
        // The script always writes into a one-band double raster; it is converted below if a
        // different output type was requested (or inherited from the source).
        WritableRaster resultRaster = RasterFactory.createBandedRaster(DataBuffer.TYPE_DOUBLE, width, height, 1, null);
        ColorModel cm = MapAlgebra.fetchColorModel(referenceImage.getColorModel(), resultRaster);
        BufferedImage resultImage = new BufferedImage(cm, resultRaster, false, null);
        try {
            JiffleDirectRuntime runtime = prepareRuntime(script, sources, resultImage);
            runtime.evaluateAll(null);
            if (rasterDataType != resultImage.getSampleModel().getDataType()) {
                WritableRaster convertedRaster = RasterFactory.createBandedRaster(rasterDataType, width, height, 1, null);
                double[] samples = resultRaster.getSamples(0, 0, width, height, 0, (double[]) null);
                convertedRaster.setSamples(0, 0, width, height, 0, samples);
                return RasterUtils.clone(convertedRaster, null, referenceRaster, noDataValue, false);
            }
            return RasterUtils.clone(resultImage, null, referenceRaster, noDataValue, false);
        } catch (Exception e) {
            throw new RuntimeException("Failed to run map algebra", e);
        }
    }

    /**
     * Returns a Jiffle runtime bound to {@code sources} and {@code destination}. It reuses this
     * thread's cached runtime when that runtime was compiled from the same script with the same
     * source variable names; otherwise it compiles a new one and caches it.
     *
     * @throws Exception whatever the Jiffle compiler/runtime raises; callers wrap it
     */
    private static JiffleDirectRuntime prepareRuntime(String script, Map<String, RenderedImage> sources, BufferedImage destination) throws Exception {
        // Keying on the script alone would let a runtime compiled for "rast" be reused for
        // "rast0"/"rast1" (or vice versa), leaving some of its sources never rebound.
        String cacheKey = String.join(",", sources.keySet()) + "\n" + script;
        JiffleDirectRuntime runtime = previousRuntime.get();
        if (runtime != null && cacheKey.equals(previousRuntimeKey.get())) {
            for (Map.Entry<String, RenderedImage> source : sources.entrySet()) {
                runtime.setSourceImage(source.getKey(), source.getValue());
            }
            runtime.setDestinationImage(DESTINATION_VAR, (WritableRenderedImage) destination);
            runtime.setDefaultBounds();
            return runtime;
        }
        JiffleBuilder builder = new JiffleBuilder().script(script);
        for (Map.Entry<String, RenderedImage> source : sources.entrySet()) {
            builder = builder.source(source.getKey(), source.getValue());
        }
        runtime = builder.dest(DESTINATION_VAR, (WritableRenderedImage) destination).getRuntime();
        // Only cache after a successful compile so a bad script never poisons the cache.
        previousRuntimeKey.set(cacheKey);
        previousRuntime.set(runtime);
        return runtime;
    }

    /** Element-wise {@code band1 + band2}. @throws IllegalArgumentException on length mismatch */
    public static double[] add(double[] band1, double[] band2) {
        MapAlgebra.ensureBandShape(band1.length, band2.length);
        double[] result = new double[band1.length];
        for (int i = 0; i < band1.length; ++i) {
            result[i] = band1[i] + band2[i];
        }
        return result;
    }

    /**
     * Element-wise difference. Note the operand order: each element is {@code band2 - band1}.
     *
     * @throws IllegalArgumentException on length mismatch
     */
    public static double[] subtract(double[] band1, double[] band2) {
        MapAlgebra.ensureBandShape(band1.length, band2.length);
        double[] result = new double[band1.length];
        for (int i = 0; i < band1.length; ++i) {
            result[i] = band2[i] - band1[i];
        }
        return result;
    }

    /** Element-wise {@code band1 * band2}. @throws IllegalArgumentException on length mismatch */
    public static double[] multiply(double[] band1, double[] band2) {
        MapAlgebra.ensureBandShape(band1.length, band2.length);
        double[] result = new double[band1.length];
        for (int i = 0; i < band1.length; ++i) {
            result[i] = band1[i] * band2[i];
        }
        return result;
    }

    /**
     * Element-wise {@code band1 / band2}, rounded half-up to two decimal places via
     * {@link Math#round(double)}.
     *
     * @throws IllegalArgumentException on length mismatch
     */
    public static double[] divide(double[] band1, double[] band2) {
        MapAlgebra.ensureBandShape(band1.length, band2.length);
        double[] result = new double[band1.length];
        for (int i = 0; i < band1.length; ++i) {
            result[i] = (double) Math.round(band1[i] / band2[i] * 100.0) / 100.0;
        }
        return result;
    }

    /** Multiplies every element of {@code band} by {@code factor}. */
    public static double[] multiplyFactor(double[] band, double factor) {
        double[] result = new double[band.length];
        for (int i = 0; i < band.length; ++i) {
            result[i] = band[i] * factor;
        }
        return result;
    }

    /**
     * Element-wise remainder {@code band[i] % dividend}. Despite its name, {@code dividend} is
     * the divisor.
     */
    public static double[] modulo(double[] band, double dividend) {
        double[] result = new double[band.length];
        for (int i = 0; i < band.length; ++i) {
            result[i] = band[i] % dividend;
        }
        return result;
    }

    /** Element-wise square root, rounded to two decimal places via {@link Math#round(double)}. */
    public static double[] squareRoot(double[] band) {
        double[] result = new double[band.length];
        for (int i = 0; i < band.length; ++i) {
            result[i] = (double) Math.round(Math.sqrt(band[i]) * 100.0) / 100.0;
        }
        return result;
    }

    /** Element-wise bitwise AND of the values truncated to {@code int}. */
    public static double[] bitwiseAnd(double[] band1, double[] band2) {
        MapAlgebra.ensureBandShape(band1.length, band2.length);
        double[] result = new double[band1.length];
        for (int i = 0; i < band1.length; ++i) {
            result[i] = (int) band1[i] & (int) band2[i];
        }
        return result;
    }

    /** Element-wise bitwise OR of the values truncated to {@code int}. */
    public static double[] bitwiseOr(double[] band1, double[] band2) {
        MapAlgebra.ensureBandShape(band1.length, band2.length);
        double[] result = new double[band1.length];
        for (int i = 0; i < band1.length; ++i) {
            result[i] = (int) band1[i] | (int) band2[i];
        }
        return result;
    }

    /** Keeps {@code band1}'s value where the bands differ, otherwise 0. */
    public static double[] logicalDifference(double[] band1, double[] band2) {
        MapAlgebra.ensureBandShape(band1.length, band2.length);
        double[] result = new double[band1.length];
        for (int i = 0; i < band1.length; ++i) {
            result[i] = band1[i] != band2[i] ? band1[i] : 0.0;
        }
        return result;
    }

    /** Takes {@code band1}'s value where it is non-zero, otherwise {@code band2}'s. */
    public static double[] logicalOver(double[] band1, double[] band2) {
        MapAlgebra.ensureBandShape(band1.length, band2.length);
        double[] result = new double[band1.length];
        for (int i = 0; i < band1.length; ++i) {
            result[i] = band1[i] != 0.0 ? band1[i] : band2[i];
        }
        return result;
    }

    /**
     * Linearly rescales the values to {@code [0, 255]} <em>in place</em> and returns the same
     * array. A constant band (min == max) is filled with zeros.
     */
    public static double[] normalize(double[] bandValues) {
        double minValue = Arrays.stream(bandValues).min().orElse(Double.NaN);
        double maxValue = Arrays.stream(bandValues).max().orElse(Double.NaN);
        if (Double.compare(maxValue, minValue) == 0) {
            Arrays.fill(bandValues, 0.0);
        } else {
            double range = maxValue - minValue;
            for (int i = 0; i < bandValues.length; ++i) {
                bandValues[i] = (bandValues[i] - minValue) * 255.0 / range;
            }
        }
        return bandValues;
    }

    /**
     * Element-wise {@code (band2 - band1) / (band2 + band1)}, rounded to two decimal places.
     * Zero inputs are treated as -1 for the computation; the caller's arrays are not modified.
     *
     * @throws IllegalArgumentException on length mismatch
     */
    public static double[] normalizedDifference(double[] band1, double[] band2) {
        MapAlgebra.ensureBandShape(band1.length, band2.length);
        double[] result = new double[band1.length];
        for (int i = 0; i < band1.length; ++i) {
            double value1 = band1[i] == 0.0 ? -1.0 : band1[i];
            double value2 = band2[i] == 0.0 ? -1.0 : band2[i];
            result[i] = (double) Math.round((value2 - value1) / (value2 + value1) * 100.0) / 100.0;
        }
        return result;
    }

    /**
     * Arithmetic mean of the band ({@code NaN} for an empty band). The trailing
     * {@code * 100.0 / 100.0} does not round; it is kept so results stay bit-for-bit identical
     * to earlier releases.
     */
    public static double mean(double[] band) {
        return Arrays.stream(band).sum() / (double) band.length * 100.0 / 100.0;
    }

    /**
     * Extracts the inclusive index rectangle {@code coordinates = [i0, j0, i1, j1]} from a flat
     * band, reading element {@code i * dimension[0] + j}.
     * NOTE(review): which of i/j is row vs. column, and what {@code dimension} holds, is not
     * established here — confirm against callers.
     */
    public static double[] fetchRegion(double[] band, int[] coordinates, int[] dimension) {
        double[] result = new double[(coordinates[2] - coordinates[0] + 1) * (coordinates[3] - coordinates[1] + 1)];
        int k = 0;
        for (int i = coordinates[0]; i <= coordinates[2]; ++i) {
            for (int j = coordinates[1]; j <= coordinates[3]; ++j) {
                result[k++] = band[i * dimension[0] + j];
            }
        }
        return result;
    }

    /**
     * Returns a one-element array with the most frequent value. If every value occurs exactly
     * once, returns {@code band} itself. An empty band yields an empty array.
     */
    public static double[] mode(double[] band) {
        Map<Double, Long> frequency = Arrays.stream(band).boxed()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        if (frequency.values().stream().max(Long::compare).orElse(0L) == 1L) {
            return band;
        }
        // Only an empty band reaches the fallback (no entries); previously this unboxed null.
        return frequency.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(entry -> new double[]{entry.getKey()})
                .orElse(band);
    }

    /** Mask: 1.0 where {@code band[i] > target}, else 0.0. */
    public static double[] greaterThan(double[] band, double target) {
        double[] result = new double[band.length];
        for (int i = 0; i < band.length; ++i) {
            result[i] = band[i] > target ? 1.0 : 0.0;
        }
        return result;
    }

    /** Mask: 1.0 where {@code band[i] >= target}, else 0.0. */
    public static double[] greaterThanEqual(double[] band, double target) {
        double[] result = new double[band.length];
        for (int i = 0; i < band.length; ++i) {
            result[i] = band[i] >= target ? 1.0 : 0.0;
        }
        return result;
    }

    /** Mask: 1.0 where {@code band[i] < target}, else 0.0. */
    public static double[] lessThan(double[] band, double target) {
        double[] result = new double[band.length];
        for (int i = 0; i < band.length; ++i) {
            result[i] = band[i] < target ? 1.0 : 0.0;
        }
        return result;
    }

    /** Mask: 1.0 where {@code band[i] <= target}, else 0.0. */
    public static double[] lessThanEqual(double[] band, double target) {
        double[] result = new double[band.length];
        for (int i = 0; i < band.length; ++i) {
            result[i] = band[i] <= target ? 1.0 : 0.0;
        }
        return result;
    }

    /** Counts elements exactly equal ({@code ==}) to {@code target}. */
    public static int countValue(double[] band, double target) {
        return (int) Arrays.stream(band).filter(x -> x == target).count();
    }

    /** Throws if two bands have different lengths. */
    private static void ensureBandShape(int band1, int band2) {
        if (band1 != band2) {
            throw new IllegalArgumentException("The shape of the provided bands is not same. Please check your inputs, it should be same.");
        }
    }
}

