/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.loss;

import breeze.linalg.DenseVector;
import breeze.linalg.Vector;
import java.io.File;
import java.io.Serializable;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkFunSuite;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregatorSuite;
import org.apache.spark.ml.optim.loss.L2Regularization;
import org.apache.spark.ml.optim.loss.RDDLossFunction;
import org.apache.spark.ml.optim.loss.RDDLossFunction$;
import org.apache.spark.ml.util.TempDirectory;
import org.apache.spark.ml.util.TestingUtils$;
import org.apache.spark.mllib.util.MLlibTestSparkContext;
import org.apache.spark.mllib.util.MLlibTestSparkContext$testImplicits$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.SparkSession;
import org.scalactic.Bool;
import org.scalactic.Bool$;
import org.scalactic.Equality$;
import org.scalactic.Prettifier$;
import org.scalactic.TripleEqualsSupport;
import org.scalactic.source.Position;
import org.scalatest.Assertions$;
import org.scalatest.Tag;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0001\u00113AAB\u0004\u0001)!)\u0011\u0005\u0001C\u0001E!IQ\u0005\u0001a\u0001\u0002\u0004%\tA\n\u0005\ng\u0001\u0001\r\u00111A\u0005\u0002QB\u0011\"\u0010\u0001A\u0002\u0003\u0005\u000b\u0015B\u0014\t\u000b\t\u0003A\u0011I\"\u0003)I#E\tT8tg\u001a+hn\u0019;j_:\u001cV/\u001b;f\u0015\tA\u0011\"\u0001\u0003m_N\u001c(B\u0001\u0006\f\u0003\u0015y\u0007\u000f^5n\u0015\taQ\"\u0001\u0002nY*\u0011abD\u0001\u0006gB\f'o\u001b\u0006\u0003!E\ta!\u00199bG\",'\"\u0001\n\u0002\u0007=\u0014xm\u0001\u0001\u0014\u0007\u0001)\u0012\u0004\u0005\u0002\u0017/5\tQ\"\u0003\u0002\u0019\u001b\ti1\u000b]1sW\u001a+hnU;ji\u0016\u0004\"AG\u0010\u000e\u0003mQ!\u0001H\u000f\u0002\tU$\u0018\u000e\u001c\u0006\u0003=5\tQ!\u001c7mS\nL!\u0001I\u000e\u0003+5cE.\u001b2UKN$8\u000b]1sW\u000e{g\u000e^3yi\u00061A(\u001b8jiz\"\u0012a\t\t\u0003I\u0001i\u0011aB\u0001\nS:\u001cH/\u00198dKN,\u0012a\n\t\u0004Q-jS\"A\u0015\u000b\u0005)j\u0011a\u0001:eI&\u0011A&\u000b\u0002\u0004%\u0012#\u0005C\u0001\u00182\u001b\u0005y#B\u0001\u0019\f\u0003\u001d1W-\u0019;ve\u0016L!AM\u0018\u0003\u0011%s7\u000f^1oG\u0016\fQ\"\u001b8ti\u0006t7-Z:`I\u0015\fHCA\u001b<!\t1\u0014(D\u00018\u0015\u0005A\u0014!B:dC2\f\u0017B\u0001\u001e8\u0005\u0011)f.\u001b;\t\u000fq\u001a\u0011\u0011!a\u0001O\u0005\u0019\u0001\u0010J\u0019\u0002\u0015%t7\u000f^1oG\u0016\u001c\b\u0005\u000b\u0002\u0005\u007fA\u0011a\u0007Q\u0005\u0003\u0003^\u0012\u0011\u0002\u001e:b]NLWM\u001c;\u0002\u0013\t,gm\u001c:f\u00032dG#A\u001b")
/**
 * Test suite for {@code RDDLossFunction}, decompiled by CFR from the Scala source
 * {@code RDDLossFunctionSuite.scala}.
 *
 * <p>The class extends {@code SparkFunSuite} and mixes in {@code MLlibTestSparkContext}. The
 * constructor also runs {@code TempDirectory.$init$}. The {@code ...$$super$beforeAll} /
 * {@code ...$$super$afterAll} accessors below show the hook chain
 * MLlibTestSparkContext -> TempDirectory -> SparkFunSuite. Apart from {@link #beforeAll()} and the
 * constructor, most members are compiler-generated forwarders and backing fields for those traits.
 *
 * <p>Tests registered by the constructor:
 * <ul>
 *   <li>"regularization": the loss and gradient computed with an {@code L2Regularization} should
 *       equal the unregularized loss and gradient plus the regularizer's own loss and gradient.</li>
 *   <li>"empty RDD": computing the loss over an empty RDD should throw
 *       {@code IllegalArgumentException}.</li>
 *   <li>"versus aggregating on an iterable": the RDD-based loss and gradient should match a
 *       {@code TestAggregator} fed the collected instances on the driver.</li>
 * </ul>
 *
 * <p>NOTE(review): this is decompiler output. It contains constructs that are not valid Java
 * source (for example, intersection types in local variable declarations). Make changes in the
 * Scala source rather than here.
 */
public class RDDLossFunctionSuite
extends SparkFunSuite
implements MLlibTestSparkContext {
    // Three-row dataset built in beforeAll(). Shared by the "regularization" and
    // "versus aggregating on an iterable" tests.
    private transient RDD<Instance> instances;
    // Backing fields for the state exposed by MLlibTestSparkContext (see the getter/`_$eq`
    // setter pairs below).
    private transient SparkSession spark;
    private transient SparkContext sc;
    private transient String checkpointDir;
    // Lazily created by testImplicits(). It is volatile because testImplicits() reads it without
    // holding the lock before falling back to the synchronized initializer.
    private volatile MLlibTestSparkContext$testImplicits$ testImplicits$module;
    // Backing field for TempDirectory's temporary-directory state.
    private File org$apache$spark$ml$util$TempDirectory$$_tempDir;

    // Super-call accessor used by MLlibTestSparkContext: its "super.beforeAll" resolves to
    // TempDirectory's implementation.
    @Override
    public /* synthetic */ void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$beforeAll() {
        TempDirectory.beforeAll$(this);
    }

    // Super-call accessor used by MLlibTestSparkContext: its "super.afterAll" resolves to
    // TempDirectory's implementation.
    @Override
    public /* synthetic */ void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$afterAll() {
        TempDirectory.afterAll$(this);
    }

    // Suite teardown: delegates entirely to MLlibTestSparkContext's afterAll.
    @Override
    public void afterAll() {
        MLlibTestSparkContext.afterAll$(this);
    }

    // Forwarder to MLlibTestSparkContext's standardize helper.
    @Override
    public Instance[] standardize(Instance[] instances) {
        return MLlibTestSparkContext.standardize$(this, instances);
    }

    // Super-call accessor used by TempDirectory: its "super.beforeAll" is SparkFunSuite's.
    @Override
    public /* synthetic */ void org$apache$spark$ml$util$TempDirectory$$super$beforeAll() {
        super.beforeAll();
    }

    // Super-call accessor used by TempDirectory: its "super.afterAll" is SparkFunSuite's.
    @Override
    public /* synthetic */ void org$apache$spark$ml$util$TempDirectory$$super$afterAll() {
        super.afterAll();
    }

    // Forwarder to TempDirectory's tempDir accessor.
    @Override
    public File tempDir() {
        return TempDirectory.tempDir$(this);
    }

    // Accessor/mutator pairs backing MLlibTestSparkContext's spark, sc and checkpointDir members.
    @Override
    public SparkSession spark() {
        return this.spark;
    }

    @Override
    public void spark_$eq(SparkSession x$1) {
        this.spark = x$1;
    }

    @Override
    public SparkContext sc() {
        return this.sc;
    }

    @Override
    public void sc_$eq(SparkContext x$1) {
        this.sc = x$1;
    }

    @Override
    public String checkpointDir() {
        return this.checkpointDir;
    }

    @Override
    public void checkpointDir_$eq(String x$1) {
        this.checkpointDir = x$1;
    }

    // Returns the testImplicits instance, creating it on first use. The fast path is an unlocked
    // read of the volatile field; creation happens under the lock in testImplicits$lzycompute$1().
    @Override
    public MLlibTestSparkContext$testImplicits$ testImplicits() {
        if (this.testImplicits$module == null) {
            this.testImplicits$lzycompute$1();
        }
        return this.testImplicits$module;
    }

    // Accessor/mutator pair backing TempDirectory's _tempDir member.
    @Override
    public File org$apache$spark$ml$util$TempDirectory$$_tempDir() {
        return this.org$apache$spark$ml$util$TempDirectory$$_tempDir;
    }

    @Override
    public void org$apache$spark$ml$util$TempDirectory$$_tempDir_$eq(File x$1) {
        this.org$apache$spark$ml$util$TempDirectory$$_tempDir = x$1;
    }

    // Accessor/mutator for the shared test dataset.
    public RDD<Instance> instances() {
        return this.instances;
    }

    public void instances_$eq(RDD<Instance> x$1) {
        this.instances = x$1;
    }

    /**
     * Runs MLlibTestSparkContext's setup, then parallelizes the shared three-row dataset into
     * {@link #instances()} using the SparkContext's default number of slices.
     *
     * <p>Rows, written as Instance constructor arguments (presumably label, weight, features;
     * confirm against {@code Instance}):
     * (0.0, 0.1, [1.0, 2.0]), (1.0, 0.5, [1.5, 1.0]), (2.0, 0.3, [4.0, 0.5]).
     */
    @Override
    public void beforeAll() {
        MLlibTestSparkContext.beforeAll$(this);
        SparkContext qual$1 = this.sc();
        Seq x$1 = (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Instance[]{new Instance(0.0, 0.1, Vectors$.MODULE$.dense(1.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{2.0}))), new Instance(1.0, 0.5, Vectors$.MODULE$.dense(1.5, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{1.0}))), new Instance(2.0, 0.3, Vectors$.MODULE$.dense(4.0, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{0.5})))}));
        // Default partition count, as chosen by parallelize's default argument.
        int x$2 = qual$1.parallelize$default$2();
        this.instances_$eq((RDD<Instance>)qual$1.parallelize(x$1, x$2, ClassTag$.MODULE$.apply(Instance.class)));
    }

    // Slow path of testImplicits(): takes the suite's monitor and re-checks the field, so the
    // instance is created at most once even when several threads race here.
    private final void testImplicits$lzycompute$1() {
        RDDLossFunctionSuite rDDLossFunctionSuite = this;
        synchronized (rDDLossFunctionSuite) {
            if (this.testImplicits$module == null) {
                this.testImplicits$module = new MLlibTestSparkContext$testImplicits$(this);
            }
        }
    }

    /**
     * Initializes the mixed-in traits (TempDirectory, then MLlibTestSparkContext) and registers
     * the suite's three tests. The Position objects passed to test/assert calls point at lines of
     * RDDLossFunctionSuite.scala, not of this decompiled file.
     */
    public RDDLossFunctionSuite() {
        TempDirectory.$init$(this);
        MLlibTestSparkContext.$init$(this);
        // Test "regularization": the regularized loss function should equal the unregularized
        // one plus the regularizer's own contribution, for both the loss and the gradient.
        this.test("regularization", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tag[0]), (Function0 & Serializable & scala.Serializable)() -> {
            // Evaluation point for every loss in this test: [0.5, -0.1].
            org.apache.spark.ml.linalg.Vector coefficients = Vectors$.MODULE$.dense(0.5, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{-0.1}));
            // L2 regularizer built with 0.1 (presumably the regularization parameter), an index
            // predicate that accepts every index, and None as the third argument.
            // NOTE(review): argument meanings are inferred; confirm against L2Regularization.
            L2Regularization regLossFun = new L2Regularization(0.1, (Function1)(JFunction1.mcZI.sp & Serializable & scala.Serializable)x$1 -> true, (Option)None$.MODULE$);
            // Aggregator factory. bvec is presumably a broadcast of the coefficients; its value()
            // is cast to a Vector and handed to a TestAggregator with first argument 2.
            Function1 & Serializable & scala.Serializable getAgg = (Function1 & Serializable & scala.Serializable)bvec -> new DifferentiableLossAggregatorSuite.TestAggregator(2, (org.apache.spark.ml.linalg.Vector)bvec.value());
            // Same data and aggregator, without (None) and with (Some) the regularizer. The fourth
            // constructor argument is RDDLossFunction's default.
            RDDLossFunction lossNoReg = new RDDLossFunction(this.instances(), (Function1)getAgg, (Option)None$.MODULE$, RDDLossFunction$.MODULE$.$lessinit$greater$default$4(), ClassTag$.MODULE$.apply(Instance.class), ClassTag$.MODULE$.apply(DifferentiableLossAggregatorSuite.TestAggregator.class));
            RDDLossFunction lossWithReg = new RDDLossFunction(this.instances(), (Function1)getAgg, (Option)new Some((Object)regLossFun), RDDLossFunction$.MODULE$.$lessinit$greater$default$4(), ClassTag$.MODULE$.apply(Instance.class), ClassTag$.MODULE$.apply(DifferentiableLossAggregatorSuite.TestAggregator.class));
            // Unregularized (loss, gradient). The null check / MatchError and the re-wrapped tuple
            // that follow appear to be residue of a Scala tuple pattern binding.
            Tuple2 tuple2 = lossNoReg.calculate(coefficients.asBreeze().toDenseVector$mcD$sp(ClassTag$.MODULE$.Double()));
            if (tuple2 == null) {
                throw new MatchError((Object)tuple2);
            }
            double loss1 = tuple2._1$mcD$sp();
            DenseVector grad1 = (DenseVector)tuple2._2();
            Tuple2 tuple22 = new Tuple2((Object)BoxesRunTime.boxToDouble((double)loss1), (Object)grad1);
            Tuple2 tuple23 = tuple22;
            double loss12 = tuple23._1$mcD$sp();
            DenseVector grad12 = (DenseVector)tuple23._2();
            // The regularizer's own (loss, gradient) at the same coefficients.
            Tuple2 tuple24 = regLossFun.calculate(coefficients);
            if (tuple24 == null) {
                throw new MatchError((Object)tuple24);
            }
            double regLoss = tuple24._1$mcD$sp();
            org.apache.spark.ml.linalg.Vector regGrad = (org.apache.spark.ml.linalg.Vector)tuple24._2();
            Tuple2 tuple25 = new Tuple2((Object)BoxesRunTime.boxToDouble((double)regLoss), (Object)regGrad);
            Tuple2 tuple26 = tuple25;
            double regLoss2 = tuple26._1$mcD$sp();
            org.apache.spark.ml.linalg.Vector regGrad2 = (org.apache.spark.ml.linalg.Vector)tuple26._2();
            // Regularized (loss, gradient) computed by RDDLossFunction itself.
            Tuple2 tuple27 = lossWithReg.calculate(coefficients.asBreeze().toDenseVector$mcD$sp(ClassTag$.MODULE$.Double()));
            if (tuple27 == null) {
                throw new MatchError((Object)tuple27);
            }
            double loss2 = tuple27._1$mcD$sp();
            DenseVector grad2 = (DenseVector)tuple27._2();
            Tuple2 tuple28 = new Tuple2((Object)BoxesRunTime.boxToDouble((double)loss2), (Object)grad2);
            Tuple2 tuple29 = tuple28;
            double loss22 = tuple29._1$mcD$sp();
            DenseVector grad22 = (DenseVector)tuple29._2();
            // axpy(a, x, y) computes y += a * x in place. This turns regGrad2 into
            // regGrad + grad1, the expected combined gradient.
            BLAS$.MODULE$.axpy(1.0, Vectors$.MODULE$.fromBreeze((Vector)grad12), regGrad2);
            // Assert: the combined gradient matches the regularized gradient within relTol 1e-5.
            Bool $org_scalatest_assert_macro_expr = Bool$.MODULE$.simpleMacroBool(TestingUtils$.MODULE$.VectorWithAlmostEquals(regGrad2).$tilde$eq$eq(TestingUtils$.MODULE$.VectorWithAlmostEquals(Vectors$.MODULE$.fromBreeze((Vector)grad22)).relTol(1.0E-5)), "org.apache.spark.ml.util.TestingUtils.VectorWithAlmostEquals(regGrad).~==(org.apache.spark.ml.util.TestingUtils.VectorWithAlmostEquals(org.apache.spark.ml.linalg.Vectors.fromBreeze(grad2)).relTol(1.0E-5))", Prettifier$.MODULE$.default());
            Assertions$.MODULE$.assertionsHelper().macroAssert($org_scalatest_assert_macro_expr, (Object)"", Prettifier$.MODULE$.default(), new Position("RDDLossFunctionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 53));
            // Assert: loss1 + regLoss === loss2, using the default Equality.
            // NOTE(review): this compares doubles with no tolerance, so it relies on the regularized
            // loss being computed as exactly loss1 + regLoss. Consider a tolerance if it ever flakes.
            TripleEqualsSupport.Equalizer $org_scalatest_assert_macro_left = this.convertToEqualizer(BoxesRunTime.boxToDouble((double)(loss12 + regLoss2)));
            double $org_scalatest_assert_macro_right = loss22;
            Bool $org_scalatest_assert_macro_expr2 = Bool$.MODULE$.binaryMacroBool((Object)$org_scalatest_assert_macro_left, "===", (Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right), $org_scalatest_assert_macro_left.$eq$eq$eq((Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right), Equality$.MODULE$.default()), Prettifier$.MODULE$.default());
            return Assertions$.MODULE$.assertionsHelper().macroAssert($org_scalatest_assert_macro_expr2, (Object)"", Prettifier$.MODULE$.default(), new Position("RDDLossFunctionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 54));
        }, new Position("RDDLossFunctionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 41));
        // Test "empty RDD": calculate() over an RDD with no instances must throw
        // IllegalArgumentException (checked with intercept, wrapped in a withClue message).
        this.test("empty RDD", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tag[0]), (Function0 & Serializable & scala.Serializable)() -> {
            // Empty Instance RDD with the default number of slices.
            SparkContext qual$1 = this.sc();
            Seq x$1 = (Seq)Seq$.MODULE$.empty();
            int x$2 = qual$1.parallelize$default$2();
            RDD rdd = qual$1.parallelize(x$1, x$2, ClassTag$.MODULE$.apply(Instance.class));
            org.apache.spark.ml.linalg.Vector coefficients = Vectors$.MODULE$.dense(0.5, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{-0.1}));
            // Same aggregator factory as in the other tests; bv is presumably a broadcast of the
            // coefficients.
            Function1 & Serializable & scala.Serializable getAgg = (Function1 & Serializable & scala.Serializable)bv -> new DifferentiableLossAggregatorSuite.TestAggregator(2, (org.apache.spark.ml.linalg.Vector)bv.value());
            RDDLossFunction lossFun = new RDDLossFunction(rdd, (Function1)getAgg, (Option)None$.MODULE$, RDDLossFunction$.MODULE$.$lessinit$greater$default$4(), ClassTag$.MODULE$.apply(Instance.class), ClassTag$.MODULE$.apply(DifferentiableLossAggregatorSuite.TestAggregator.class));
            return (IllegalArgumentException)this.withClue("cannot calculate cost for empty dataset", (Function0 & Serializable & scala.Serializable)() -> (IllegalArgumentException)this.intercept((Function0 & Serializable & scala.Serializable)() -> lossFun.calculate(coefficients.asBreeze().toDenseVector$mcD$sp(ClassTag$.MODULE$.Double())), ClassTag$.MODULE$.apply(IllegalArgumentException.class), new Position("RDDLossFunctionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 63)));
        }, new Position("RDDLossFunctionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 57));
        // Test "versus aggregating on an iterable": the distributed (RDD) loss and gradient
        // should equal those from a single TestAggregator fed every instance locally.
        this.test("versus aggregating on an iterable", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tag[0]), (Function0 & Serializable & scala.Serializable)() -> {
            org.apache.spark.ml.linalg.Vector coefficients = Vectors$.MODULE$.dense(0.5, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{-0.1}));
            Function1 & Serializable & scala.Serializable getAgg = (Function1 & Serializable & scala.Serializable)bv -> new DifferentiableLossAggregatorSuite.TestAggregator(2, (org.apache.spark.ml.linalg.Vector)bv.value());
            RDDLossFunction lossFun = new RDDLossFunction(this.instances(), (Function1)getAgg, (Option)None$.MODULE$, RDDLossFunction$.MODULE$.$lessinit$greater$default$4(), ClassTag$.MODULE$.apply(Instance.class), ClassTag$.MODULE$.apply(DifferentiableLossAggregatorSuite.TestAggregator.class));
            // (loss, gradient) from RDDLossFunction. The null check / MatchError and re-wrapped
            // tuple appear to be residue of a Scala tuple pattern binding.
            Tuple2 tuple2 = lossFun.calculate(coefficients.asBreeze().toDenseVector$mcD$sp(ClassTag$.MODULE$.Double()));
            if (tuple2 == null) {
                throw new MatchError((Object)tuple2);
            }
            double loss = tuple2._1$mcD$sp();
            DenseVector grad = (DenseVector)tuple2._2();
            Tuple2 tuple22 = new Tuple2((Object)BoxesRunTime.boxToDouble((double)loss), (Object)grad);
            Tuple2 tuple23 = tuple22;
            double loss2 = tuple23._1$mcD$sp();
            DenseVector grad2 = (DenseVector)tuple23._2();
            // Reference result: collect the instances to the driver and add each one to a fresh
            // TestAggregator built directly from the coefficients.
            DifferentiableLossAggregatorSuite.TestAggregator agg = new DifferentiableLossAggregatorSuite.TestAggregator(2, coefficients);
            new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])this.instances().collect())).foreach((Function1 & Serializable & scala.Serializable)instance -> agg.add((Instance)instance));
            // Assert: loss === agg.loss() (default Equality, no tolerance).
            TripleEqualsSupport.Equalizer $org_scalatest_assert_macro_left = this.convertToEqualizer(BoxesRunTime.boxToDouble((double)loss2));
            double $org_scalatest_assert_macro_right = agg.loss();
            Bool $org_scalatest_assert_macro_expr = Bool$.MODULE$.binaryMacroBool((Object)$org_scalatest_assert_macro_left, "===", (Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right), $org_scalatest_assert_macro_left.$eq$eq$eq((Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right), Equality$.MODULE$.default()), Prettifier$.MODULE$.default());
            Assertions$.MODULE$.assertionsHelper().macroAssert($org_scalatest_assert_macro_expr, (Object)"", Prettifier$.MODULE$.default(), new Position("RDDLossFunctionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 79));
            // Assert: gradient (converted from Breeze) === agg.gradient() (default Equality).
            TripleEqualsSupport.Equalizer $org_scalatest_assert_macro_left2 = this.convertToEqualizer(Vectors$.MODULE$.fromBreeze((Vector)grad2));
            org.apache.spark.ml.linalg.Vector $org_scalatest_assert_macro_right2 = agg.gradient();
            Bool $org_scalatest_assert_macro_expr2 = Bool$.MODULE$.binaryMacroBool((Object)$org_scalatest_assert_macro_left2, "===", (Object)$org_scalatest_assert_macro_right2, $org_scalatest_assert_macro_left2.$eq$eq$eq((Object)$org_scalatest_assert_macro_right2, Equality$.MODULE$.default()), Prettifier$.MODULE$.default());
            return Assertions$.MODULE$.assertionsHelper().macroAssert($org_scalatest_assert_macro_expr2, (Object)"", Prettifier$.MODULE$.default(), new Position("RDDLossFunctionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 80));
        }, new Position("RDDLossFunctionSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 69));
    }
}

