/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.math.decompositions;

import org.apache.log4j.Logger;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.Matrices;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.decompositions.ALS;
import org.apache.mahout.math.drm.CheckpointedDrm;
import org.apache.mahout.math.drm.DrmLike;
import org.apache.mahout.math.drm.RLikeDrmOps;
import org.apache.mahout.math.drm.RLikeDrmOps$;
import org.apache.mahout.math.scalabindings.RLikeOps$;
import scala.Enumeration;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Iterable;
import scala.collection.Seq;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.math.package$;
import scala.reflect.ClassTag;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/**
 * Compiled singleton ("module") class for the Scala object {@code ALS}. A single instance is
 * created in the static initializer and published through {@link #MODULE$} by the private
 * constructor. Holds the distributed ALS routine {@link #dals} and the accessors for its
 * default arguments.
 *
 * <p>NOTE(review): this file is CFR decompiler output, not hand-written Java. Several constructs
 * are not valid Java source and should not be expected to compile as shown:
 * <ul>
 *   <li>anonymous {@code new Serializable(k){...}} classes cast to {@code Function0}/{@code Function1};</li>
 *   <li>the {@code this.k$1 = k$1} initializer.</li>
 * </ul>
 * Treat it as a readable reference for the bytecode.
 */
public final class ALS$ {
    /** The singleton instance; assigned in the private constructor. */
    public static final ALS$ MODULE$;
    /** log4j logger for this class; used to warn when RMSE increases between iterations. */
    private final Logger log;

    static {
        // Constructing the instance is what assigns MODULE$ (see the constructor).
        new ALS$();
    }

    private Logger log() {
        return this.log;
    }

    /**
     * Factors {@code drmA} into two distributed factor matrices {@code U} and {@code V}, each with
     * {@code k} columns, such that {@code A} is approximated by {@code U %*% V.t}.
     *
     * <p>It uses alternating least squares. Each iteration first solves for {@code V} with
     * {@code U} fixed, then for {@code U} with {@code V} fixed.
     *
     * <p>Scala default arguments (see the {@code dals$default$N} accessors below):
     * <ul>
     *   <li>{@code k = 50}</li>
     *   <li>{@code lambda = 0.0}</li>
     *   <li>{@code maxIterations = 10}</li>
     *   <li>{@code convergenceThreshold = 0.1}</li>
     * </ul>
     *
     * @param drmA                 distributed matrix to factor. Its row-key class tag is reused for
     *                             {@code U}, which is built by {@code mapBlock} over {@code drmA}.
     * @param k                    number of columns (latent factors) of {@code U} and {@code V}.
     * @param lambda               value placed on the diagonal of a {@code k x k} matrix. That matrix
     *                             is combined with the Gram matrix ({@code U'U} or {@code V'V}) via
     *                             Mahout's {@code -:} operator before {@code solve}.
     *                             NOTE(review): {@code -:} is right-associative in Scala. Whether this
     *                             yields {@code Gram + lambda*I} (the usual ridge-regularized ALS) or
     *                             {@code Gram - lambda*I} depends on {@code MatrixOps.-:}; confirm.
     * @param maxIterations        maximum number of ALS iterations; must be {@code >= 1}.
     * @param convergenceThreshold must be {@code < 1.0}.
     *                             <ul>
     *                               <li>If {@code > 0}: RMSE is computed after every iteration. The
     *                               loop stops once the relative RMSE decrease falls below this value,
     *                               or when RMSE increases.</li>
     *                               <li>If {@code <= 0}: no RMSE is computed, and exactly
     *                               {@code maxIterations} iterations run.</li>
     *                             </ul>
     * @return a result holding the final {@code U}, the final {@code V}, and the per-iteration RMSE
     *         values. The RMSE list is empty when the convergence test is disabled.
     * @throws AssertionError from Scala {@code Predef.assert} when a precondition above is violated.
     */
    public <K> ALS.Result<K> dals(DrmLike<K> drmA, int k, double lambda, int maxIterations, double convergenceThreshold) {
        // Preconditions; the Function0 supplies the lazily built assertion message.
        Predef$.MODULE$.assert(convergenceThreshold < 1.0, (Function0)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final String apply() {
                return "convergenceThreshold";
            }
        });
        Predef$.MODULE$.assert(maxIterations >= 1, (Function0)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final String apply() {
                return "maxIterations";
            }
        });
        ClassTag<K> ktag = drmA.keyClassTag();
        // A' is computed once here and reused in every iteration.
        DrmLike<Object> drmAt = RLikeDrmOps$.MODULE$.drm2RLikeOps(drmA).t();
        RLikeDrmOps<K> qual$1 = RLikeDrmOps$.MODULE$.drm2RLikeOps(drmA);
        int x$3 = k;
        boolean x$4 = qual$1.mapBlock$default$2();
        // Initial U is built block-by-block over A: each block keeps its row keys.
        // Its values are replaced by an (nrow x k) Matrices.symmetricUniformView, scaled by 0.01,
        // with a seed drawn from RandomUtils.getRandom().nextInt().
        // NOTE(review): presumably uniform values in a range symmetric around 0; confirm.
        Serializable x$5 = new Serializable(k){
            public static final long serialVersionUID = 0L;
            private final int k$1;

            public final Tuple2<Object, Matrix> apply(Tuple2<Object, Matrix> x0$1) {
                Tuple2<Object, Matrix> tuple2 = x0$1;
                if (tuple2 != null) {
                    Object keys = tuple2._1();
                    Matrix block = (Matrix)tuple2._2();
                    RandomWrapper rnd = RandomUtils.getRandom();
                    Matrix uBlock = RLikeOps$.MODULE$.m2mOps(Matrices.symmetricUniformView(RLikeOps$.MODULE$.m2mOps(block).nrow(), this.k$1, rnd.nextInt())).$times(0.01);
                    Tuple2 tuple22 = Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(keys), (Object)uBlock);
                    return tuple22;
                }
                // Decompiled form of a Scala pattern match with no matching case.
                throw new MatchError(tuple2);
            }
            {
                this.k$1 = k$1;
            }
        };
        DrmLike<K> drmU = qual$1.mapBlock(x$3, x$4, (Function1<Tuple2<Object, Matrix>, Tuple2<Object, Matrix>>)x$5, ktag);
        // V does not exist until the first iteration computes it.
        CheckpointedDrm<Object> drmV = null;
        // RMSE after each iteration; only appended to when the convergence test is enabled.
        Nil$ rmseIterations = Nil$.MODULE$;
        boolean stop = false;
        for (int i = 0; !stop && i < maxIterations; ++i) {
            // Release the previous iteration's V. This is a no-op on the first pass.
            // `object` is an unused decompiler artifact of the Scala `if` expression.
            Object object = drmV == null ? BoxedUnit.UNIT : org.apache.mahout.math.drm.package$.MODULE$.drm2Checkpointed(drmV).uncache();
            // U' %*% U; collected in-core below for the solve.
            DrmLike drmLike = RLikeDrmOps$.MODULE$.drmInt2RLikeOps(RLikeDrmOps$.MODULE$.drm2RLikeOps(drmU).t()).$percent$times$percent(drmU);
            // V = A' %*% U %*% solve(U'U combined with diag(lambda, k)).
            DrmLike qual$2 = RLikeDrmOps$.MODULE$.drmInt2RLikeOps(RLikeDrmOps$.MODULE$.drmInt2RLikeOps(drmAt).$percent$times$percent(drmU)).$percent$times$percent(org.apache.mahout.math.scalabindings.package$.MODULE$.solve(RLikeOps$.MODULE$.m2mOps(org.apache.mahout.math.scalabindings.package$.MODULE$.diag(lambda, k)).$minus$colon(org.apache.mahout.math.drm.package$.MODULE$.drm2InCore(drmLike))));
            Enumeration.Value x$6 = qual$2.checkpoint$default$1();
            drmV = qual$2.checkpoint(x$6);
            // The old U is no longer needed once the new V is checkpointed.
            org.apache.mahout.math.drm.package$.MODULE$.drm2Checkpointed(drmU).uncache();
            // V' %*% V; collected in-core below for the solve.
            DrmLike drmLike2 = RLikeDrmOps$.MODULE$.drmInt2RLikeOps(RLikeDrmOps$.MODULE$.drmInt2RLikeOps(drmV).t()).$percent$times$percent(drmV);
            // U = A %*% V %*% solve(V'V combined with diag(lambda, k)).
            DrmLike<K> qual$3 = RLikeDrmOps$.MODULE$.drm2RLikeOps(RLikeDrmOps$.MODULE$.drm2RLikeOps(drmA).$percent$times$percent(drmV)).$percent$times$percent(org.apache.mahout.math.scalabindings.package$.MODULE$.solve(RLikeOps$.MODULE$.m2mOps(org.apache.mahout.math.scalabindings.package$.MODULE$.diag(lambda, k)).$minus$colon(org.apache.mahout.math.drm.package$.MODULE$.drm2InCore(drmLike2))));
            Enumeration.Value x$7 = qual$3.checkpoint$default$1();
            drmU = qual$3.checkpoint(x$7);
            // The convergence test runs only when it is enabled (threshold > 0).
            if (!(convergenceThreshold > 0.0)) continue;
            // RMSE = norm(A - U %*% V') / sqrt(ncol * nrow).
            // NOTE(review): assumes norm() is the Frobenius norm; confirm.
            double rmse = RLikeDrmOps$.MODULE$.drm2cpops(RLikeDrmOps$.MODULE$.drm2RLikeOps(drmA).$minus(RLikeDrmOps$.MODULE$.drm2RLikeOps(drmU).$percent$times$percent(RLikeDrmOps$.MODULE$.drmInt2RLikeOps(drmV).t()))).norm() / package$.MODULE$.sqrt((double)((long)drmA.ncol() * drmA.nrow()));
            // There is no previous RMSE to compare against on the first iteration.
            if (i > 0) {
                double rmsePrev = BoxesRunTime.unboxToDouble((Object)rmseIterations.last());
                // Relative RMSE decrease since the previous iteration.
                // NOTE(review): if rmsePrev == 0 this is NaN (rmse == 0) or -Infinity (rmse > 0).
                // A NaN result meets neither stop condition below, so iteration continues.
                double convergence = (rmsePrev - rmse) / rmsePrev;
                if (convergence < 0.0) {
                    // RMSE went up: warn and stop iterating.
                    this.log().warn((Object)new StringOps(Predef$.MODULE$.augmentString("Rmse increase of %f. Should not happen.")).format((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)convergence)})));
                    stop = true;
                } else if (convergence < convergenceThreshold) {
                    // The improvement is below the threshold: treat as converged.
                    stop = true;
                }
            }
            // Record this iteration's RMSE. Appending to an immutable List is O(n), which is
            // acceptable for small maxIterations.
            rmseIterations = (List)rmseIterations.$colon$plus((Object)BoxesRunTime.boxToDouble((double)rmse), List$.MODULE$.canBuildFrom());
        }
        return new ALS.Result<K>(drmU, drmV, (Iterable<Object>)rmseIterations);
    }

    // Scala default-argument accessor for dals parameter 2 (k).
    public <K> int dals$default$2() {
        return 50;
    }

    // Scala default-argument accessor for dals parameter 3 (lambda).
    public <K> double dals$default$3() {
        return 0.0;
    }

    // Scala default-argument accessor for dals parameter 4 (maxIterations).
    public <K> int dals$default$4() {
        return 10;
    }

    // Scala default-argument accessor for dals parameter 5 (convergenceThreshold).
    public <K> double dals$default$5() {
        return 0.1;
    }

    /** Publishes this instance as the singleton and creates the logger for this class. */
    private ALS$() {
        MODULE$ = this;
        this.log = Logger.getLogger(this.getClass());
    }
}

