def validateInputDataFrame()

in scala-spark-sdk/src/main/scala/software/amazon/sagemaker/featurestore/sparksdk/validators/InputDataSchemaValidator.scala [45:96]


  /**
   * Validates that the input DataFrame's schema and values conform to the FeatureGroup definition.
   *
   * Validates column names against the feature definitions, then checks per row that every value
   * is representable in the registered feature data type (e.g. numeric bounds for Integral /
   * Fractional features, no NaN). If any row fails, up to 20 offending rows are printed and a
   * ValidationError is thrown; otherwise the DataFrame is returned with each column transformed
   * to the data type registered in the FeatureGroup.
   *
   * @param dataFrame        input DataFrame to validate
   * @param describeResponse DescribeFeatureGroup response for the target FeatureGroup
   * @return a DataFrame whose columns are transformed to the registered feature data types
   * @throws ValidationError if any record contains an unregistered data type or an invalid value
   */
  def validateInputDataFrame(
      dataFrame: DataFrame,
      describeResponse: DescribeFeatureGroupResponse
  ): DataFrame = {
    val recordIdentifierName = describeResponse.recordIdentifierFeatureName()
    val eventTimeFeatureName = describeResponse.eventTimeFeatureName()
    // Materialize once and reuse below; the original issued the same
    // describeResponse accessor / conversion calls twice.
    val featureDefinitions = describeResponse.featureDefinitions().asScala.toList

    validateSchemaNames(dataFrame.schema.names, describeResponse, recordIdentifierName, eventTimeFeatureName)

    // Numeric data types validation - For example, verify that only numeric values that are within bounds
    // of the Integer and Double data types are present in the corresponding fields.
    // This should be caught by the type conversion check as well
    val schemaDataTypeValidatorMap = getSchemaDataTypeValidatorMap(
      dataFrame = dataFrame,
      featureDefinitions = featureDefinitions,
      eventTimeFeatureName
    )
    val schemaDataTypeValidatorColumn = getSchemaDataTypeValidatorColumn(
      schemaDataTypeValidatorMap,
      recordIdentifierName,
      eventTimeFeatureName
    )

    // A row is invalid when any per-column validator emitted a "... not valid" marker.
    val invalidRows = dataFrame
      .withColumn(
        "dataTypeValidationErrors",
        concat_ws(",", schemaDataTypeValidatorColumn: _*)
      )
      .filter(col("dataTypeValidationErrors").like("%not valid"))

    if (!invalidRows.isEmpty) {
      // Surface a sample of the offending rows so the caller can diagnose the failure.
      invalidRows
        .select(col(recordIdentifierName), col("dataTypeValidationErrors"))
        .show(numRows = 20, truncate = false)

      throw ValidationError(
        "Cannot proceed. Some records contain columns with data types that are not registered in the FeatureGroup " +
          "or records values equal to NaN."
      )
    }

    val dataTypeTransformationMap = getSchemaDataTypeTransformationMap(
      schemaDataTypeValidatorMap,
      featureDefinitions,
      eventTimeFeatureName
    )

    // Apply the registered per-column transformation. The lambda parameter is named
    // `columnName` (a String) so it does not shadow Spark's `col` function.
    dataFrame.select(dataFrame.columns.map { columnName =>
      dataTypeTransformationMap(columnName).apply(columnName)
    }: _*)
  }