/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.ann;

import java.io.File;
import java.io.Serializable;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkFunSuite;
import org.apache.spark.ml.ann.FeedForwardModel$;
import org.apache.spark.ml.ann.FeedForwardTopology;
import org.apache.spark.ml.ann.FeedForwardTopology$;
import org.apache.spark.ml.ann.FeedForwardTrainer;
import org.apache.spark.ml.ann.Topology;
import org.apache.spark.ml.ann.TopologyModel;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.util.TempDirectory;
import org.apache.spark.ml.util.TestingUtils$;
import org.apache.spark.mllib.util.MLlibTestSparkContext;
import org.apache.spark.mllib.util.MLlibTestSparkContext$testImplicits$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.SparkSession;
import org.scalactic.Bool;
import org.scalactic.Bool$;
import org.scalactic.Equality$;
import org.scalactic.Prettifier$;
import org.scalactic.TripleEqualsSupport;
import org.scalactic.source.Position;
import org.scalatest.Assertions$;
import org.scalatest.Tag;
import org.scalatest.compatible.Assertion;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.Seq;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction0;

/**
 * Test suite for the feed-forward artificial neural network (ANN) training code in
 * {@code org.apache.spark.ml.ann}.
 *
 * <p>This class is decompiled Scala bytecode (see the CFR header). The
 * {@code ScalaSignature} annotation carries the compiler-pickled Scala signature and must
 * not be edited by hand. Many members below are compiler-generated trait forwarders and
 * field accessors, not hand-written logic.
 *
 * <p>The constructor registers two tests. Each trains a small network on the four-point
 * XOR truth table and asserts that the trained model's predictions match the expected
 * outputs.
 */
@ScalaSignature(bytes="\u0006\u0001}1AAA\u0002\u0001\u001d!)1\u0004\u0001C\u00019\tA\u0011I\u0014(Tk&$XM\u0003\u0002\u0005\u000b\u0005\u0019\u0011M\u001c8\u000b\u0005\u00199\u0011AA7m\u0015\tA\u0011\"A\u0003ta\u0006\u00148N\u0003\u0002\u000b\u0017\u00051\u0011\r]1dQ\u0016T\u0011\u0001D\u0001\u0004_J<7\u0001A\n\u0004\u0001=\u0019\u0002C\u0001\t\u0012\u001b\u00059\u0011B\u0001\n\b\u00055\u0019\u0006/\u0019:l\rVt7+^5uKB\u0011A#G\u0007\u0002+)\u0011acF\u0001\u0005kRLGN\u0003\u0002\u0019\u000f\u0005)Q\u000e\u001c7jE&\u0011!$\u0006\u0002\u0016\u001b2c\u0017N\u0019+fgR\u001c\u0006/\u0019:l\u0007>tG/\u001a=u\u0003\u0019a\u0014N\\5u}Q\tQ\u0004\u0005\u0002\u001f\u00015\t1\u0001")
public class ANNSuite
extends SparkFunSuite
implements MLlibTestSparkContext {
    // Backing fields for state declared by the mixed-in traits (MLlibTestSparkContext /
    // TempDirectory). They are read and written only through the generated
    // getter/setter pairs below (Scala setters are named `x_$eq`).
    private transient SparkSession spark;
    private transient SparkContext sc;
    private transient String checkpointDir;
    // Holder for the lazily created `testImplicits` object; volatile so the unsynchronized
    // null check in testImplicits() sees a fully published instance.
    private volatile MLlibTestSparkContext$testImplicits$ testImplicits$module;
    private File org$apache$spark$ml$util$TempDirectory$$_tempDir;

    // Super-accessor generated for MLlibTestSparkContext: its `super.beforeAll()` resolves to
    // TempDirectory's implementation in this class's trait linearization.
    @Override
    public /* synthetic */ void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$beforeAll() {
        TempDirectory.beforeAll$(this);
    }

    // Super-accessor generated for MLlibTestSparkContext: its `super.afterAll()` resolves to
    // TempDirectory's implementation.
    @Override
    public /* synthetic */ void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$afterAll() {
        TempDirectory.afterAll$(this);
    }

    /** Suite-level setup; delegates to the MLlibTestSparkContext trait implementation. */
    @Override
    public void beforeAll() {
        MLlibTestSparkContext.beforeAll$(this);
    }

    /** Suite-level teardown; delegates to the MLlibTestSparkContext trait implementation. */
    @Override
    public void afterAll() {
        MLlibTestSparkContext.afterAll$(this);
    }

    /** Forwarder to the MLlibTestSparkContext trait's {@code standardize} implementation. */
    @Override
    public Instance[] standardize(Instance[] instances) {
        return MLlibTestSparkContext.standardize$(this, instances);
    }

    // Super-accessor generated for TempDirectory: its `super.beforeAll()` resolves to the
    // SparkFunSuite superclass method.
    @Override
    public /* synthetic */ void org$apache$spark$ml$util$TempDirectory$$super$beforeAll() {
        super.beforeAll();
    }

    // Super-accessor generated for TempDirectory: its `super.afterAll()` resolves to the
    // SparkFunSuite superclass method.
    @Override
    public /* synthetic */ void org$apache$spark$ml$util$TempDirectory$$super$afterAll() {
        super.afterAll();
    }

    /** Forwarder to the TempDirectory trait's {@code tempDir} implementation. */
    @Override
    public File tempDir() {
        return TempDirectory.tempDir$(this);
    }

    @Override
    public SparkSession spark() {
        return this.spark;
    }

    @Override
    public void spark_$eq(SparkSession x$1) {
        this.spark = x$1;
    }

    @Override
    public SparkContext sc() {
        return this.sc;
    }

    @Override
    public void sc_$eq(SparkContext x$1) {
        this.sc = x$1;
    }

    @Override
    public String checkpointDir() {
        return this.checkpointDir;
    }

    @Override
    public void checkpointDir_$eq(String x$1) {
        this.checkpointDir = x$1;
    }

    /**
     * Returns the lazily created {@code testImplicits} object.
     *
     * <p>The fast path is an unsynchronized null check on the volatile holder. The slow path
     * ({@code testImplicits$lzycompute$1}) creates the instance under a lock.
     */
    @Override
    public MLlibTestSparkContext$testImplicits$ testImplicits() {
        if (this.testImplicits$module == null) {
            this.testImplicits$lzycompute$1();
        }
        return this.testImplicits$module;
    }

    @Override
    public File org$apache$spark$ml$util$TempDirectory$$_tempDir() {
        return this.org$apache$spark$ml$util$TempDirectory$$_tempDir;
    }

    @Override
    public void org$apache$spark$ml$util$TempDirectory$$_tempDir_$eq(File x$1) {
        this.org$apache$spark$ml$util$TempDirectory$$_tempDir = x$1;
    }

    // Slow path of the lazy `testImplicits` initialization. It locks on this suite
    // instance and re-checks the holder, so at most one instance is ever created
    // (double-checked locking).
    private final void testImplicits$lzycompute$1() {
        ANNSuite aNNSuite = this;
        synchronized (aNNSuite) {
            if (this.testImplicits$module == null) {
                this.testImplicits$module = new MLlibTestSparkContext$testImplicits$(this);
            }
        }
    }

    /**
     * Initializes the mixed-in traits and registers the suite's test cases.
     *
     * <p>The trailing {@code Position} arguments record source locations in
     * {@code ANNSuite.scala} that are used when reporting failures.
     */
    public ANNSuite() {
        TempDirectory.$init$(this);
        MLlibTestSparkContext.$init$(this);
        // Test 1: 2-input / 1-output XOR with a 5-unit hidden layer, trained with LBFGS.
        // Every rounded prediction must equal its label exactly.
        this.test("ANN with Sigmoid learns XOR function with LBFGS optimizer", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tag[0]), (Function0)(JFunction0.mcV.sp & Serializable & scala.Serializable)() -> {
            // Decompiler-introduced local: assigned from the training result below, never read afterwards.
            TopologyModel model;
            // XOR truth table: the four input pairs and their XOR results.
            double[][] inputs = (double[][])((Object[])new double[][]{{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}});
            double[] outputs = new double[]{0.0, 1.0, 1.0, 0.0};
            // Zip inputs with outputs and turn each pair into
            // (dense feature vector, single-element dense label vector).
            Tuple2[] data = (Tuple2[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])inputs)).zip((GenIterable)Predef$.MODULE$.wrapDoubleArray(outputs), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).map((Function1 & Serializable & scala.Serializable)x0$1 -> {
                Tuple2 tuple2 = x0$1;
                if (tuple2 == null) {
                    throw new MatchError((Object)tuple2);
                }
                double[] features = (double[])tuple2._1();
                double label = tuple2._2$mcD$sp();
                Tuple2 tuple22 = new Tuple2((Object)Vectors$.MODULE$.dense(features), (Object)Vectors$.MODULE$.dense(label, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[0])));
                return tuple22;
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));
            // Distribute the four samples as an RDD with a single partition (numSlices = 1).
            RDD rddData = this.sc().parallelize((Seq)Predef$.MODULE$.wrapRefArray((Object[])data), 1, ClassTag$.MODULE$.apply(Tuple2.class));
            int[] hiddenLayersTopology = new int[]{5};
            Tuple2 dataSample = (Tuple2)rddData.first();
            int n = ((Vector)dataSample._1()).size();
            // layerSizes = inputSize +: hiddenLayersTopology :+ labelSize, i.e. [2, 5, 1] here.
            int[] layerSizes = (int[])new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps((int[])new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(hiddenLayersTopology)).$plus$colon((Object)BoxesRunTime.boxToInteger((int)n), ClassTag$.MODULE$.Int()))).$colon$plus((Object)BoxesRunTime.boxToInteger((int)((Vector)dataSample._2()).size()), ClassTag$.MODULE$.Int());
            FeedForwardTopology topology = FeedForwardTopology$.MODULE$.multiLayerPerceptron(layerSizes, false);
            // Fixed 23124L (presumably a random seed) so the initial weights, and thus the test, are reproducible.
            Vector initialWeights = FeedForwardModel$.MODULE$.apply(topology, 23124L).weights();
            // NOTE(review): 2 and 1 match the input and label vector sizes here; presumably
            // (inputSize, outputSize) - confirm against FeedForwardTrainer.
            FeedForwardTrainer trainer = new FeedForwardTrainer((Topology)topology, 2, 1);
            trainer.setWeights(initialWeights);
            trainer.LBFGSOptimizer().setNumIterations(20);
            // train() returns a pair; only its first element (the trained model) is used.
            Tuple2 tuple2 = trainer.train(rddData);
            if (tuple2 == null) {
                throw new MatchError((Object)tuple2);
            }
            TopologyModel topologyModel = model = (TopologyModel)tuple2._1();
            TopologyModel model2 = topologyModel;
            // Collect (first prediction component, first label component) for every sample.
            Tuple2[] predictionAndLabels = (Tuple2[])rddData.map((Function1 & Serializable & scala.Serializable)x0$2 -> {
                Tuple2 tuple2 = x0$2;
                if (tuple2 == null) {
                    throw new MatchError((Object)tuple2);
                }
                Vector input = (Vector)tuple2._1();
                Vector label = (Vector)tuple2._2();
                Tuple2.mcDD.sp sp2 = new Tuple2.mcDD.sp(model2.predict(input).apply(0), label.apply(0));
                return sp2;
            }, ClassTag$.MODULE$.apply(Tuple2.class)).collect();
            // Assert round(prediction) === label for each sample (ANNSuite.scala line 53).
            new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])predictionAndLabels)).foreach((Function1 & Serializable & scala.Serializable)x0$3 -> {
                Tuple2 tuple2 = x0$3;
                if (tuple2 == null) {
                    throw new MatchError((Object)tuple2);
                }
                double p = tuple2._1$mcD$sp();
                double l = tuple2._2$mcD$sp();
                TripleEqualsSupport.Equalizer $org_scalatest_assert_macro_left = this.convertToEqualizer(BoxesRunTime.boxToLong((long)package$.MODULE$.round(p)));
                double $org_scalatest_assert_macro_right = l;
                Bool $org_scalatest_assert_macro_expr = Bool$.MODULE$.binaryMacroBool((Object)$org_scalatest_assert_macro_left, "===", (Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right), $org_scalatest_assert_macro_left.$eq$eq$eq((Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right), Equality$.MODULE$.default()), Prettifier$.MODULE$.default());
                Assertion assertion = Assertions$.MODULE$.assertionsHelper().macroAssert($org_scalatest_assert_macro_expr, (Object)"", Prettifier$.MODULE$.default(), new Position("ANNSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 53));
                return assertion;
            });
        }, new Position("ANNSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 28));
        // Test 2: XOR with a 2-component output label ([1,0] for XOR=0, [0,1] for XOR=1) and
        // a 5-unit hidden layer. Each prediction vector must lie within absolute tolerance 0.5 of its label.
        // NOTE(review): the test name mentions "SoftMax" and "batch GD optimizer", but the body
        // passes `false` to multiLayerPerceptron and configures LBFGSOptimizer(). Confirm the
        // name still reflects what is being tested.
        this.test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tag[0]), (Function0)(JFunction0.mcV.sp & Serializable & scala.Serializable)() -> {
            // Decompiler-introduced local: assigned from the training result below, never read afterwards.
            TopologyModel model;
            double[][] inputs = (double[][])((Object[])new double[][]{{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}});
            double[][] outputs = (double[][])((Object[])new double[][]{{1.0, 0.0}, {0.0, 1.0}, {0.0, 1.0}, {1.0, 0.0}});
            // Zip inputs with 2-component outputs into (dense feature vector, dense label vector) pairs.
            Tuple2[] data = (Tuple2[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])inputs)).zip((GenIterable)Predef$.MODULE$.wrapRefArray((Object[])outputs), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).map((Function1 & Serializable & scala.Serializable)x0$4 -> {
                Tuple2 tuple2 = x0$4;
                if (tuple2 == null) {
                    throw new MatchError((Object)tuple2);
                }
                double[] features = (double[])tuple2._1();
                double[] label = (double[])tuple2._2();
                Tuple2 tuple22 = new Tuple2((Object)Vectors$.MODULE$.dense(features), (Object)Vectors$.MODULE$.dense(label));
                return tuple22;
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));
            // Distribute the four samples as an RDD with a single partition (numSlices = 1).
            RDD rddData = this.sc().parallelize((Seq)Predef$.MODULE$.wrapRefArray((Object[])data), 1, ClassTag$.MODULE$.apply(Tuple2.class));
            int[] hiddenLayersTopology = new int[]{5};
            Tuple2 dataSample = (Tuple2)rddData.first();
            int n = ((Vector)dataSample._1()).size();
            // layerSizes = inputSize +: hiddenLayersTopology :+ labelSize, i.e. [2, 5, 2] here.
            int[] layerSizes = (int[])new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps((int[])new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(hiddenLayersTopology)).$plus$colon((Object)BoxesRunTime.boxToInteger((int)n), ClassTag$.MODULE$.Int()))).$colon$plus((Object)BoxesRunTime.boxToInteger((int)((Vector)dataSample._2()).size()), ClassTag$.MODULE$.Int());
            FeedForwardTopology topology = FeedForwardTopology$.MODULE$.multiLayerPerceptron(layerSizes, false);
            // Fixed 23124L (presumably a random seed) so the initial weights are reproducible.
            Vector initialWeights = FeedForwardModel$.MODULE$.apply(topology, 23124L).weights();
            // NOTE(review): 2 and 2 match the input and label vector sizes here; presumably
            // (inputSize, outputSize) - confirm against FeedForwardTrainer.
            FeedForwardTrainer trainer = new FeedForwardTrainer((Topology)topology, 2, 2);
            trainer.LBFGSOptimizer().setConvergenceTol(1.0E-4).setNumIterations(20);
            trainer.setWeights(initialWeights).setStackSize(1);
            // train() returns a pair; only its first element (the trained model) is used.
            Tuple2 tuple2 = trainer.train(rddData);
            if (tuple2 == null) {
                throw new MatchError((Object)tuple2);
            }
            TopologyModel topologyModel = model = (TopologyModel)tuple2._1();
            TopologyModel model2 = topologyModel;
            // Collect (full prediction vector, full label vector) for every sample.
            Tuple2[] predictionAndLabels = (Tuple2[])rddData.map((Function1 & Serializable & scala.Serializable)x0$5 -> {
                Tuple2 tuple2 = x0$5;
                if (tuple2 == null) {
                    throw new MatchError((Object)tuple2);
                }
                Vector input = (Vector)tuple2._1();
                Vector label = (Vector)tuple2._2();
                Tuple2 tuple22 = new Tuple2((Object)model2.predict(input), (Object)label);
                return tuple22;
            }, ClassTag$.MODULE$.apply(Tuple2.class)).collect();
            // Assert p ~== l absTol 0.5 for each sample (ANNSuite.scala line 88).
            new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])predictionAndLabels)).foreach((Function1 & Serializable & scala.Serializable)x0$6 -> {
                Tuple2 tuple2 = x0$6;
                if (tuple2 == null) {
                    throw new MatchError((Object)tuple2);
                }
                Vector p = (Vector)tuple2._1();
                Vector l = (Vector)tuple2._2();
                Bool $org_scalatest_assert_macro_expr = Bool$.MODULE$.simpleMacroBool(TestingUtils$.MODULE$.VectorWithAlmostEquals(p).$tilde$eq$eq(TestingUtils$.MODULE$.VectorWithAlmostEquals(l).absTol(0.5)), "org.apache.spark.ml.util.TestingUtils.VectorWithAlmostEquals(p).~==(org.apache.spark.ml.util.TestingUtils.VectorWithAlmostEquals(l).absTol(0.5))", Prettifier$.MODULE$.default());
                Assertion assertion = Assertions$.MODULE$.assertionsHelper().macroAssert($org_scalatest_assert_macro_expr, (Object)"", Prettifier$.MODULE$.default(), new Position("ANNSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 88));
                return assertion;
            });
        }, new Position("ANNSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 57));
    }
}

