package org.tensorflow.spark.datasources.tfrecords.serde;

import java.nio.charset.StandardCharsets;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.SQLDataTypes$;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.util.ArrayData$;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.BinaryType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.LongType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.tensorflow.example.Example;
import org.tensorflow.example.Feature;
import org.tensorflow.example.FeatureList;
import org.tensorflow.example.FeatureLists;
import org.tensorflow.example.Features;
import org.tensorflow.example.SequenceExample;
import scala.Array$;
import scala.Predef$;
import scala.Predef$DummyImplicit$;
import scala.StringContext;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ClassTag$;
import scala.runtime.ScalaRunTime$;

/* compiled from: DefaultTfRecordRowEncoder.scala */
/* loaded from: input_file:org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowEncoder$.class */
/**
 * Default {@link TfRecordRowEncoder}: converts Spark SQL {@link Row}s into TensorFlow
 * {@link Example} and {@link SequenceExample} protocol buffers.
 *
 * <p>This is the Java form of a Scala {@code object}: {@link #MODULE$} holds the singleton, and
 * {@code org$tensorflow$spark$datasources$tfrecords$serde$DefaultTfRecordRowEncoder$$encodeFeature}
 * keeps its Scala-mangled public name because it is presumably called from the generated
 * {@code $$anonfun} closure classes. Do not rename either.
 *
 * <p>Per-field {@link Feature} encoding by Spark type:
 * <ul>
 *   <li>{@code IntegerType}, {@code LongType} and arrays of them: int64 list.</li>
 *   <li>{@code FloatType}, {@code DoubleType}, {@code DecimalType}, arrays of them, and ML
 *       {@code VectorType}: float list. Scalar doubles and decimals are narrowed to
 *       {@code float}.</li>
 *   <li>{@code StringType} (encoded as UTF-8), {@code BinaryType} and arrays of them: bytes
 *       list.</li>
 * </ul>
 * {@link #encodeFeatureList} handles arrays of arrays of the same element types.
 */
public final class DefaultTfRecordRowEncoder$ implements TfRecordRowEncoder {
    /** The singleton instance (Scala {@code object} module field). */
    public static final DefaultTfRecordRowEncoder$ MODULE$ = new DefaultTfRecordRowEncoder$();

    /** Message prefix for columns whose type has no {@link Feature} mapping. */
    private static final String UNSUPPORTED_TYPE_PREFIX =
            "Cannot convert field to unsupported data type ";

    /**
     * Encodes every field of {@code row} into a single {@link Example}.
     *
     * @param row the row to encode; its schema drives the per-field encoding
     * @return an {@code Example} whose {@code Features} are filled while iterating the row's
     *     schema fields together with their indices
     */
    @Override // org.tensorflow.spark.datasources.tfrecords.serde.TfRecordRowEncoder
    public Example encodeExample(Row row) {
        Features.Builder features = Features.newBuilder();
        // The generated closure receives the row and the builder for each (field, index) pair;
        // presumably it adds one Feature per field via encodeFeature.
        ((IterableLike) row.schema().zipWithIndex(Seq$.MODULE$.canBuildFrom()))
                .foreach(new DefaultTfRecordRowEncoder$$anonfun$encodeExample$1(row, features));
        return Example.newBuilder().setFeatures(features.build()).build();
    }

    /**
     * Encodes every field of {@code row} into a {@link SequenceExample}.
     *
     * @param row the row to encode; its schema drives the per-field encoding
     * @return a {@code SequenceExample} whose context features and feature lists are filled
     *     while iterating the row's schema fields together with their indices
     */
    @Override // org.tensorflow.spark.datasources.tfrecords.serde.TfRecordRowEncoder
    public SequenceExample encodeSequenceExample(Row row) {
        Features.Builder context = Features.newBuilder();
        FeatureLists.Builder featureLists = FeatureLists.newBuilder();
        // The generated closure is given both builders; presumably it routes each field to the
        // context or to the feature lists. TODO confirm the routing rule in the Scala source.
        ((IterableLike) row.schema().zipWithIndex(Seq$.MODULE$.canBuildFrom()))
                .foreach(new DefaultTfRecordRowEncoder$$anonfun$encodeSequenceExample$1(
                        row, context, featureLists));
        return SequenceExample.newBuilder()
                .setContext(context.build())
                .setFeatureLists(featureLists.build())
                .build();
    }

    /**
     * Encodes column {@code i} of {@code row} as a single TensorFlow {@link Feature}.
     *
     * <p>The Scala-mangled name is kept for binary compatibility with the generated closures.
     *
     * @param row the row holding the value
     * @param structField schema of column {@code i}; its data type selects the encoding
     * @param i column index within {@code row}
     * @return the encoded feature
     * @throws RuntimeException if the column type (or array element type) is unsupported, or if
     *     a {@code VectorType} column holds neither a {@link DenseVector} nor a
     *     {@link SparseVector}
     */
    @SuppressWarnings("unchecked")
    public Feature org$tensorflow$spark$datasources$tfrecords$serde$DefaultTfRecordRowEncoder$$encodeFeature(Row row, StructField structField, int i) {
        DataType dataType = structField.dataType();
        if (IntegerType$.MODULE$.equals(dataType)) {
            // int is widened to long for the int64 list.
            return Int64ListFeatureEncoder$.MODULE$.encode((Seq<Object>) Seq$.MODULE$.apply(
                    Predef$.MODULE$.wrapLongArray(new long[]{row.getInt(i)})));
        }
        if (LongType$.MODULE$.equals(dataType)) {
            return Int64ListFeatureEncoder$.MODULE$.encode((Seq<Object>) Seq$.MODULE$.apply(
                    Predef$.MODULE$.wrapLongArray(new long[]{row.getLong(i)})));
        }
        if (FloatType$.MODULE$.equals(dataType)) {
            return FloatListFeatureEncoder$.MODULE$.encode((Seq<Object>) Seq$.MODULE$.apply(
                    Predef$.MODULE$.wrapFloatArray(new float[]{row.getFloat(i)})));
        }
        if (DoubleType$.MODULE$.equals(dataType)) {
            // Narrowed to float (precision loss) because the target is a float list.
            return FloatListFeatureEncoder$.MODULE$.encode((Seq<Object>) Seq$.MODULE$.apply(
                    Predef$.MODULE$.wrapFloatArray(new float[]{(float) row.getDouble(i)})));
        }
        if (DecimalType$.MODULE$.unapply(dataType)) {
            return FloatListFeatureEncoder$.MODULE$.encode((Seq<Object>) Seq$.MODULE$.apply(
                    Predef$.MODULE$.wrapFloatArray(
                            new float[]{((Decimal) row.getAs(i)).toFloat()})));
        }
        if (StringType$.MODULE$.equals(dataType)) {
            // Explicit UTF-8 so the bytes do not depend on the JVM's default charset.
            return BytesListFeatureEncoder$.MODULE$.encode((Seq<byte[]>) Seq$.MODULE$.apply(
                    Predef$.MODULE$.wrapRefArray(
                            new byte[][]{row.getString(i).getBytes(StandardCharsets.UTF_8)})));
        }
        if (BinaryType$.MODULE$.equals(dataType)) {
            return BytesListFeatureEncoder$.MODULE$.encode((Seq<byte[]>) Seq$.MODULE$.apply(
                    Predef$.MODULE$.wrapRefArray(new byte[][]{(byte[]) row.getAs(i)})));
        }
        if (dataType instanceof ArrayType) {
            // Every array branch returns directly, so array<int> columns are encoded instead of
            // falling through to the unsupported-type error below.
            return encodeArrayFeature(row, structField, (ArrayType) dataType, i);
        }
        DataType vectorType = SQLDataTypes$.MODULE$.VectorType();
        boolean isVector = vectorType == null ? dataType == null : vectorType.equals(dataType);
        if (!isVector) {
            throw new RuntimeException(UNSUPPORTED_TYPE_PREFIX + structField.dataType());
        }
        return encodeVectorFeature(row.get(i));
    }

    /**
     * Encodes an {@code ArrayType} column as a list feature chosen by its element type.
     *
     * @throws RuntimeException with the unsupported-type message if the element type has no
     *     mapping
     */
    @SuppressWarnings("unchecked")
    private static Feature encodeArrayFeature(
            Row row, StructField structField, ArrayType arrayType, int i) {
        DataType elementType = arrayType.elementType();
        if (IntegerType$.MODULE$.equals(elementType)) {
            // The generated closure presumably widens each int to long for the int64 list.
            return Int64ListFeatureEncoder$.MODULE$.encode((Seq<Object>) Predef$.MODULE$
                    .intArrayOps(ArrayData$.MODULE$.toArrayData(row.get(i)).toIntArray())
                    .map(new DefaultTfRecordRowEncoder$$anonfun$1(),
                            Array$.MODULE$.fallbackCanBuildFrom(
                                    Predef$DummyImplicit$.MODULE$.dummyImplicit())));
        }
        if (LongType$.MODULE$.equals(elementType)) {
            return Int64ListFeatureEncoder$.MODULE$.encode((Seq<Object>) Predef$.MODULE$
                    .wrapLongArray(ArrayData$.MODULE$.toArrayData(row.get(i)).toLongArray()));
        }
        if (FloatType$.MODULE$.equals(elementType)) {
            return FloatListFeatureEncoder$.MODULE$.encode((Seq<Object>) Predef$.MODULE$
                    .wrapFloatArray(ArrayData$.MODULE$.toArrayData(row.get(i)).toFloatArray()));
        }
        if (DoubleType$.MODULE$.equals(elementType)) {
            return FloatListFeatureEncoder$.MODULE$.encode((Seq<Object>) Predef$.MODULE$
                    .doubleArrayOps(ArrayData$.MODULE$.toArrayData(row.get(i)).toDoubleArray())
                    .map(new DefaultTfRecordRowEncoder$$anonfun$2(),
                            Array$.MODULE$.fallbackCanBuildFrom(
                                    Predef$DummyImplicit$.MODULE$.dummyImplicit())));
        }
        if (DecimalType$.MODULE$.unapply(elementType)) {
            // Read with the column's own DecimalType so precision/scale match how the values
            // are stored, rather than a default DecimalType.
            Decimal[] decimals = (Decimal[]) ArrayData$.MODULE$.toArrayData(row.get(i))
                    .toArray(elementType, ClassTag$.MODULE$.apply(Decimal.class));
            return FloatListFeatureEncoder$.MODULE$.encode((Seq<Object>) Predef$.MODULE$
                    .refArrayOps(decimals)
                    .map(new DefaultTfRecordRowEncoder$$anonfun$5(),
                            Array$.MODULE$.fallbackCanBuildFrom(
                                    Predef$DummyImplicit$.MODULE$.dummyImplicit())));
        }
        if (StringType$.MODULE$.equals(elementType)) {
            Object[] strings = (Object[]) ArrayData$.MODULE$.toArrayData(row.get(i))
                    .toArray(StringType$.MODULE$, ClassTag$.MODULE$.apply(String.class));
            // Explicit UTF-8, matching the scalar StringType path.
            byte[][] utf8 = new byte[strings.length][];
            for (int k = 0; k < strings.length; k++) {
                utf8[k] = ((String) strings[k]).getBytes(StandardCharsets.UTF_8);
            }
            return BytesListFeatureEncoder$.MODULE$.encode(Predef$.MODULE$.wrapRefArray(utf8));
        }
        if (BinaryType$.MODULE$.equals(elementType)) {
            // The byte[] ClassTag makes toArray allocate a byte[][], so the cast is exact.
            byte[][] blobs = (byte[][]) ArrayData$.MODULE$.toArrayData(row.get(i)).toArray(
                    BinaryType$.MODULE$,
                    ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Byte.TYPE)));
            return BytesListFeatureEncoder$.MODULE$.encode(Predef$.MODULE$.wrapRefArray(blobs));
        }
        throw new RuntimeException(UNSUPPORTED_TYPE_PREFIX + structField.dataType());
    }

    /**
     * Encodes an ML vector value as a float-list feature.
     *
     * @throws RuntimeException if {@code value} is neither a {@link SparseVector} nor a
     *     {@link DenseVector}
     */
    @SuppressWarnings("unchecked")
    private static Feature encodeVectorFeature(Object value) {
        if (value instanceof SparseVector) {
            // Densify first so every position, including implicit zeros, is emitted.
            double[] dense = ((SparseVector) value).toDense().toArray();
            return FloatListFeatureEncoder$.MODULE$.encode((Seq<Object>) Predef$.MODULE$
                    .doubleArrayOps(dense)
                    .map(new DefaultTfRecordRowEncoder$$anonfun$3(),
                            Array$.MODULE$.fallbackCanBuildFrom(
                                    Predef$DummyImplicit$.MODULE$.dummyImplicit())));
        }
        if (value instanceof DenseVector) {
            return FloatListFeatureEncoder$.MODULE$.encode((Seq<Object>) Predef$.MODULE$
                    .doubleArrayOps(((DenseVector) value).toArray())
                    .map(new DefaultTfRecordRowEncoder$$anonfun$4(),
                            Array$.MODULE$.fallbackCanBuildFrom(
                                    Predef$DummyImplicit$.MODULE$.dummyImplicit())));
        }
        throw new RuntimeException("Cannot convert " + value + " to vector");
    }

    /**
     * Encodes column {@code i} of {@code row}, an array of arrays, as a {@link FeatureList}.
     * Each inner array becomes one feature of the list.
     *
     * @param row the row holding the value
     * @param structField schema of column {@code i}; must be {@code array<array<T>>} where
     *     {@code T} is Integer, Long, Float, Double, Decimal, String or Binary
     * @param i column index within {@code row}
     * @return the encoded feature list
     * @throws RuntimeException if the column type is not a supported array of arrays
     */
    @SuppressWarnings("unchecked")
    public FeatureList encodeFeatureList(Row row, StructField structField, int i) {
        DataType dataType = structField.dataType();
        if (dataType instanceof ArrayType) {
            DataType outerElement = ((ArrayType) dataType).elementType();
            if (outerElement instanceof ArrayType) {
                DataType inner = ((ArrayType) outerElement).elementType();
                if (IntegerType$.MODULE$.equals(inner)) {
                    return Int64FeatureListEncoder$.MODULE$.encode((Seq<Seq<Object>>) Predef$.MODULE$.wrapRefArray((Seq[]) Predef$.MODULE$.genericArrayOps(ArrayData$.MODULE$.toArrayData(row.get(i)).array()).map(new DefaultTfRecordRowEncoder$$anonfun$7(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Seq.class)))));
                }
                if (LongType$.MODULE$.equals(inner)) {
                    return Int64FeatureListEncoder$.MODULE$.encode((Seq<Seq<Object>>) Predef$.MODULE$.wrapRefArray((Seq[]) Predef$.MODULE$.genericArrayOps(ArrayData$.MODULE$.toArrayData(row.get(i)).array()).map(new DefaultTfRecordRowEncoder$$anonfun$8(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Seq.class)))));
                }
                if (FloatType$.MODULE$.equals(inner)) {
                    return FloatFeatureListEncoder$.MODULE$.encode((Seq<Seq<Object>>) Predef$.MODULE$.wrapRefArray((Seq[]) Predef$.MODULE$.genericArrayOps(ArrayData$.MODULE$.toArrayData(row.get(i)).array()).map(new DefaultTfRecordRowEncoder$$anonfun$9(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Seq.class)))));
                }
                if (DoubleType$.MODULE$.equals(inner)) {
                    return FloatFeatureListEncoder$.MODULE$.encode((Seq<Seq<Object>>) Predef$.MODULE$.wrapRefArray((Seq[]) Predef$.MODULE$.genericArrayOps(ArrayData$.MODULE$.toArrayData(row.get(i)).array()).map(new DefaultTfRecordRowEncoder$$anonfun$10(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Seq.class)))));
                }
                if (DecimalType$.MODULE$.unapply(inner)) {
                    return FloatFeatureListEncoder$.MODULE$.encode((Seq<Seq<Object>>) Predef$.MODULE$.wrapRefArray((Seq[]) Predef$.MODULE$.genericArrayOps(ArrayData$.MODULE$.toArrayData(row.get(i)).array()).map(new DefaultTfRecordRowEncoder$$anonfun$11(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Seq.class)))));
                }
                if (StringType$.MODULE$.equals(inner)) {
                    // NOTE(review): string-to-bytes conversion happens in the generated closure;
                    // confirm it uses UTF-8 like encodeFeature does.
                    return BytesFeatureListEncoder$.MODULE$.encode((Seq<Seq<byte[]>>) Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.genericArrayOps(ArrayData$.MODULE$.toArrayData(row.get(i)).array()).map(new DefaultTfRecordRowEncoder$$anonfun$12(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Seq.class)))).toSeq());
                }
                if (BinaryType$.MODULE$.equals(inner)) {
                    return BytesFeatureListEncoder$.MODULE$.encode((Seq<Seq<byte[]>>) Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.genericArrayOps(ArrayData$.MODULE$.toArrayData(row.get(i)).array()).map(new DefaultTfRecordRowEncoder$$anonfun$13(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Seq.class)))).toSeq());
                }
            }
        }
        throw new RuntimeException("Cannot convert row element " + row.get(i) + " to FeatureList.");
    }

    /** Singleton; use {@link #MODULE$}. */
    private DefaultTfRecordRowEncoder$() {
    }
}
