package org.apache.spark.sql.catalyst.optimizer;

import org.apache.spark.sql.catalyst.expressions.Alias;
import org.apache.spark.sql.catalyst.expressions.ArrayTransform;
import org.apache.spark.sql.catalyst.expressions.CaseWhen;
import org.apache.spark.sql.catalyst.expressions.Coalesce;
import org.apache.spark.sql.catalyst.expressions.CreateArray;
import org.apache.spark.sql.catalyst.expressions.CreateMap;
import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.GetStructField;
import org.apache.spark.sql.catalyst.expressions.GetStructField$;
import org.apache.spark.sql.catalyst.expressions.If;
import org.apache.spark.sql.catalyst.expressions.IsNull;
import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized;
import org.apache.spark.sql.catalyst.expressions.LambdaFunction;
import org.apache.spark.sql.catalyst.expressions.LambdaFunction$;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.catalyst.expressions.Literal$;
import org.apache.spark.sql.catalyst.expressions.NamedLambdaVariable;
import org.apache.spark.sql.catalyst.expressions.NamedLambdaVariable$;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.rules.Rule;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/* compiled from: NormalizeFloatingNumbers.scala */
/* loaded from: input_file:org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers$.class */
/**
 * Singleton holder (Scala-module style, see {@link #MODULE$}) for a {@code Rule<LogicalPlan>}
 * that wraps floating point expressions so that NaN and signed-zero values are normalized.
 *
 * <p>The class exposes:
 * <ul>
 *   <li>{@link #apply(LogicalPlan)} - the rule entry point;</li>
 *   <li>{@link #normalize(Expression)} - rewrites a single expression;</li>
 *   <li>{@link #FLOAT_NORMALIZER()} / {@link #DOUBLE_NORMALIZER()} - value-level normalizers.</li>
 * </ul>
 */
public final class NormalizeFloatingNumbers$ extends Rule<LogicalPlan> {
    /** The single instance; assigned by the private constructor during class initialisation. */
    public static NormalizeFloatingNumbers$ MODULE$;
    private final Function1<Object, Object> FLOAT_NORMALIZER;
    private final Function1<Object, Object> DOUBLE_NORMALIZER;

    static {
        // The constructor publishes the instance through MODULE$.
        new NormalizeFloatingNumbers$();
    }

    /**
     * Applies the rule by transforming the plan with {@code NormalizeFloatingNumbers$$anonfun$apply$1}.
     * That partial function is defined outside this file, so which plan nodes it rewrites is not
     * visible here.
     *
     * @param logicalPlan the plan to transform
     * @return the transformed plan
     */
    @Override // org.apache.spark.sql.catalyst.rules.Rule
    public LogicalPlan apply(LogicalPlan logicalPlan) {
        return (LogicalPlan) logicalPlan.transform(new NormalizeFloatingNumbers$$anonfun$apply$1());
    }

    /**
     * Tells whether an expression still has to be normalized.
     *
     * @param expression the expression to inspect
     * @return {@code false} when the expression is already a {@link KnownFloatingPointNormalized}
     *         wrapper, otherwise whether its data type contains float/double values
     * @throws IllegalStateException if the data type (after unwrapping arrays) is a map type
     */
    public boolean org$apache$spark$sql$catalyst$optimizer$NormalizeFloatingNumbers$$needNormalize(Expression expression) {
        if (expression instanceof KnownFloatingPointNormalized) {
            return false;
        }
        return needNormalize(expression.dataType());
    }

    /**
     * Tells whether a data type contains float or double values, looking through array element
     * types and struct fields.
     *
     * @param dataType the type to inspect
     * @return {@code true} for float/double, for arrays of such types, and for structs with at
     *         least one such field; {@code false} for every other non-map type
     * @throws IllegalStateException if a map type is encountered
     */
    private boolean needNormalize(DataType dataType) {
        // Arrays are decided by their element type, so strip any number of array layers first.
        DataType current = dataType;
        while (current instanceof ArrayType) {
            current = ((ArrayType) current).elementType();
        }
        if (FloatType$.MODULE$.equals(current) || DoubleType$.MODULE$.equals(current)) {
            return true;
        }
        if (current instanceof StructType) {
            return new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(((StructType) current).fields())).exists(field -> {
                return BoxesRunTime.boxToBoolean($anonfun$needNormalize$1((StructField) field));
            });
        }
        if (current instanceof MapType) {
            throw new IllegalStateException("grouping/join/window partition keys cannot be map type.");
        }
        // Every remaining type carries no floating point data. Returning here (instead of looping
        // again on the same type) is what guarantees termination.
        return false;
    }

    /**
     * Rewrites an expression so that the floating point values it produces are normalized.
     *
     * <p>Cases, checked in this order:
     * <ol>
     *   <li>nothing to normalize - the expression is returned unchanged;</li>
     *   <li>{@link Alias}, {@link CreateNamedStruct}, {@link CreateArray}, {@link CreateMap} -
     *       rebuilt with normalized children;</li>
     *   <li>float/double typed expression - wrapped in {@code NormalizeNaNAndZero} and marked
     *       with {@link KnownFloatingPointNormalized};</li>
     *   <li>{@link If}, {@link CaseWhen}, {@link Coalesce} - value branches are normalized,
     *       predicates are left as they are;</li>
     *   <li>struct typed expression - rebuilt field by field;</li>
     *   <li>array typed expression - elements are normalized through an {@link ArrayTransform}.</li>
     * </ol>
     *
     * @param expression the expression to rewrite
     * @return the rewritten expression, or {@code expression} itself when nothing has to change
     * @throws IllegalStateException if none of the cases above applies, or if a map type is met
     *         while checking whether normalization is needed
     */
    public Expression normalize(Expression expression) {
        if (!org$apache$spark$sql$catalyst$optimizer$NormalizeFloatingNumbers$$needNormalize(expression)) {
            return expression;
        }
        if (expression instanceof Alias) {
            Alias alias = (Alias) expression;
            // NOTE(review): `mo436child` looks like a decompiler-generated name for Alias.child() - confirm.
            return (Expression) alias.withNewChildren((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Expression[]{normalize(alias.mo436child())})));
        }
        if (expression instanceof CreateNamedStruct) {
            CreateNamedStruct namedStruct = (CreateNamedStruct) expression;
            return new CreateNamedStruct((Seq) namedStruct.children().map(child -> {
                return MODULE$.normalize(child);
            }, Seq$.MODULE$.canBuildFrom()));
        }
        if (expression instanceof CreateArray) {
            CreateArray createArray = (CreateArray) expression;
            return new CreateArray((Seq) createArray.children().map(child -> {
                return MODULE$.normalize(child);
            }, Seq$.MODULE$.canBuildFrom()), createArray.useStringTypeWhenEmpty());
        }
        if (expression instanceof CreateMap) {
            CreateMap createMap = (CreateMap) expression;
            return new CreateMap((Seq) createMap.children().map(child -> {
                return MODULE$.normalize(child);
            }, Seq$.MODULE$.canBuildFrom()), createMap.useStringTypeWhenEmpty());
        }

        DataType dataType = expression.dataType();
        if (FloatType$.MODULE$.equals(dataType) || DoubleType$.MODULE$.equals(dataType)) {
            // Leaf case: only plain float/double expressions get the NaN/zero normalizer. This
            // must return here so the composite cases below keep their own results.
            return new KnownFloatingPointNormalized(new NormalizeNaNAndZero(expression));
        }
        if (expression instanceof If) {
            If conditional = (If) expression;
            return new If(conditional.predicate(), normalize(conditional.trueValue()), normalize(conditional.falseValue()));
        }
        if (expression instanceof CaseWhen) {
            CaseWhen caseWhen = (CaseWhen) expression;
            return new CaseWhen((Seq) caseWhen.branches().map(branch -> {
                // Keep the branch condition, normalize only the branch value.
                return new Tuple2(branch._1(), MODULE$.normalize((Expression) branch._2()));
            }, Seq$.MODULE$.canBuildFrom()), caseWhen.elseValue().map(elseValue -> {
                return MODULE$.normalize(elseValue);
            }));
        }
        if (expression instanceof Coalesce) {
            Coalesce coalesce = (Coalesce) expression;
            return new Coalesce((Seq) coalesce.children().map(child -> {
                return MODULE$.normalize(child);
            }, Seq$.MODULE$.canBuildFrom()));
        }
        if (dataType instanceof StructType) {
            return normalizeStruct(expression, (StructType) dataType);
        }
        if (dataType instanceof ArrayType) {
            return normalizeArray(expression, (ArrayType) dataType);
        }
        throw new IllegalStateException("fail to normalize " + expression);
    }

    /** Rebuilds a struct typed expression from its normalized fields, yielding a typed null when the input is null. */
    private Expression normalizeStruct(Expression expression, StructType structType) {
        String[] fieldNames = structType.fieldNames();
        // CreateNamedStruct takes a flat sequence: name0, value0, name1, value1, ...
        Expression[] namesAndValues = new Expression[fieldNames.length * 2];
        for (int ordinal = 0; ordinal < fieldNames.length; ordinal++) {
            namesAndValues[2 * ordinal] = Literal$.MODULE$.apply(fieldNames[ordinal]);
            namesAndValues[2 * ordinal + 1] = normalize(new GetStructField(expression, ordinal, GetStructField$.MODULE$.apply$default$3()));
        }
        CreateNamedStruct rebuilt = new CreateNamedStruct((Seq) Predef$.MODULE$.wrapRefArray(namesAndValues));
        // The explicit null check returns a null literal of the rebuilt type for a null input struct.
        return new KnownFloatingPointNormalized(new If(new IsNull(expression), new Literal(null, rebuilt.dataType()), rebuilt));
    }

    /** Normalizes every element of an array typed expression through an ArrayTransform lambda. */
    private Expression normalizeArray(Expression expression, ArrayType arrayType) {
        NamedLambdaVariable element = new NamedLambdaVariable("arg", arrayType.elementType(), arrayType.containsNull(), NamedLambdaVariable$.MODULE$.apply$default$4(), NamedLambdaVariable$.MODULE$.apply$default$5());
        LambdaFunction elementNormalizer = new LambdaFunction(normalize(element), (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new NamedLambdaVariable[]{element})), LambdaFunction$.MODULE$.apply$default$3());
        return new KnownFloatingPointNormalized(new ArrayTransform(expression, elementNormalizer));
    }

    /**
     * @return the function that maps a boxed float to its normalized form: any NaN becomes
     *         {@code Float.NaN}, a value equal to zero becomes {@code 0.0f}, and every other
     *         value is returned as is
     */
    public Function1<Object, Object> FLOAT_NORMALIZER() {
        return this.FLOAT_NORMALIZER;
    }

    /**
     * @return the function that maps a boxed double to its normalized form: any NaN becomes
     *         {@code Double.NaN}, a value equal to zero becomes {@code 0.0d}, and every other
     *         value is returned as is
     */
    public Function1<Object, Object> DOUBLE_NORMALIZER() {
        return this.DOUBLE_NORMALIZER;
    }

    /** Static bridge used by the struct-field predicate in {@code needNormalize(DataType)}. */
    public static final /* synthetic */ boolean $anonfun$needNormalize$1(StructField structField) {
        return MODULE$.needNormalize(structField.dataType());
    }

    /** Publishes the singleton and builds the two value-level normalizers. */
    private NormalizeFloatingNumbers$() {
        MODULE$ = this;
        this.FLOAT_NORMALIZER = boxed -> {
            float value = BoxesRunTime.unboxToFloat(boxed);
            if (Float.isNaN(value)) {
                return BoxesRunTime.boxToFloat(Float.NaN);
            }
            // `==` treats -0.0f and 0.0f as equal, so both collapse to positive zero.
            if (value == -0.0f) {
                return BoxesRunTime.boxToFloat(0.0f);
            }
            return BoxesRunTime.boxToFloat(value);
        };
        this.DOUBLE_NORMALIZER = boxed -> {
            double value = BoxesRunTime.unboxToDouble(boxed);
            if (Double.isNaN(value)) {
                return BoxesRunTime.boxToDouble(Double.NaN);
            }
            // `==` treats -0.0d and 0.0d as equal, so both collapse to positive zero.
            if (value == -0.0d) {
                return BoxesRunTime.boxToDouble(0.0d);
            }
            return BoxesRunTime.boxToDouble(value);
        };
    }
}
