in scala-spark-sdk/src/main/scala/software/amazon/sagemaker/featurestore/sparksdk/validators/InputDataSchemaValidator.scala [45:96]
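/**
 * Validates an input DataFrame against the schema of a feature group and returns a copy of the
 * DataFrame with each column transformed to its registered feature type.
 *
 * Rows that fail type validation are printed (up to 20) before a [[ValidationError]] is thrown.
 *
 * A minimal usage sketch; the client variable, request, and feature group name below are
 * illustrative assumptions, not part of this file:
 * {{{
 * val describeResponse = sageMakerClient.describeFeatureGroup(
 *   DescribeFeatureGroupRequest.builder().featureGroupName("my-feature-group").build()
 * )
 * val typedDf = validateInputDataFrame(rawDf, describeResponse)
 * }}}
 */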
def validateInputDataFrame(
dataFrame: DataFrame,
describeResponse: DescribeFeatureGroupResponse
): DataFrame = {
val recordIdentifierName = describeResponse.recordIdentifierFeatureName()
val eventTimeFeatureName = describeResponse.eventTimeFeatureName()
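// Ensure the DataFrame's column names line up with the feature definitions registered in the
// feature group, including the record identifier and event time features.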
validateSchemaNames(dataFrame.schema.names, describeResponse, recordIdentifierName, eventTimeFeatureName)
// Numeric data type validation: verify that numeric fields only contain values within the
// bounds of the Integer and Double data types. Out-of-bounds values would also be caught
// by the type conversion check.
val schemaDataTypeValidatorMap = getSchemaDataTypeValidatorMap(
dataFrame = dataFrame,
featureDefinitions = describeResponse.featureDefinitions().asScala.toList,
eventTimeFeatureName
)
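// Build one validation Column per field; each expression is assumed to emit an error marker
// ending in "not valid" when a value fails its type check (see the filter below).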
val schemaDataTypeValidatorColumn = getSchemaDataTypeValidatorColumn(
schemaDataTypeValidatorMap,
recordIdentifierName,
eventTimeFeatureName
)
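// Concatenate the per-column validation messages into a single diagnostic column, then filter
// down to the rows flagged as not valid.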
val invalidRows = dataFrame
.withColumn(
"dataTypeValidationErrors",
concat_ws(",", schemaDataTypeValidatorColumn: _*)
)
.filter(col("dataTypeValidationErrors").like("%not valid"))
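// Surface up to 20 offending rows (record identifier plus diagnostics) before aborting.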
if (!invalidRows.isEmpty) {
invalidRows
.select(col(recordIdentifierName), col("dataTypeValidationErrors"))
.show(numRows = 20, truncate = false)
throw ValidationError(
"Cannot proceed. Some records contain columns with data types that are not registered in the FeatureGroup " +
"or records values equal to NaN."
)
}
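// Map each column name to a transformation, presumably casting the column to the feature type
// registered in the feature group definition.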
val dataTypeTransformationMap = getSchemaDataTypeTransformationMap(
schemaDataTypeValidatorMap,
describeResponse.featureDefinitions().asScala.toList,
eventTimeFeatureName
)
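// Apply the per-column transformations while preserving the original column order.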
dataFrame.select(dataFrame.columns.map { columnName =>
dataTypeTransformationMap(columnName).apply(columnName)
}: _*)
}