in c3r-cli-spark/src/main/java/com/amazonaws/c3r/spark/action/SparkUnmarshaller.java [90:137]
/**
 * Unmarshals (decrypts) every string cell of {@code rawInputData}; cells of any other data type are passed through unchanged.
 *
 * <p>For each string cell the transformer is picked by inspecting the cell's bytes:
 * <ul>
 *     <li>if the {@code SEALED} transformer recognizes the descriptor, the value is unmarshalled by it and the result
 *     is then run through {@code ValueConverter.String.decode};</li>
 *     <li>otherwise, if the {@code FINGERPRINT} transformer recognizes the descriptor, that transformer unmarshals it;</li>
 *     <li>otherwise the {@code CLEARTEXT} transformer is used.</li>
 * </ul>
 * The returned dataset is encoded with the input schema.
 *
 * @param rawInputData  dataset whose string columns may contain marshalled (sealed/fingerprint) values
 * @param decryptConfig source of the secret key, salt and fingerprint-column policy
 * @return a dataset with the same schema as {@code rawInputData} holding the unmarshalled values
 * @throws C3rRuntimeException if an exception is raised while the mapped dataset is being built
 */
static Dataset<Row> unmarshalData(final Dataset<Row> rawInputData, final SparkDecryptConfig decryptConfig) {
    // Copy out plain values so the lambda below only captures simple, serializable state.
    final String salt = decryptConfig.getSalt();
    final String base64EncodedKey = Base64.getEncoder().encodeToString(decryptConfig.getSecretKey().getEncoded());
    final boolean failOnFingerprintColumns = decryptConfig.isFailOnFingerprintColumns();
    final ExpressionEncoder<Row> rowEncoder = ExpressionEncoder.apply(rawInputData.schema());
    final StructField[] fields = rawInputData.schema().fields();
    // NOTE(review): Dataset.map is a lazy transformation, so this catch only covers building the plan; exceptions
    // thrown while decrypting individual rows surface later, when an action runs, and are not wrapped here.
    try {
        return rawInputData.map((MapFunction<Row, Row>) row -> {
            // Build the transformers inside the task, limiting captured state to keys/salts/settings POJOs.
            // NOTE(review): this (and key parsing) happens once per row; if initTransformers is costly,
            // consider mapPartitions so it runs once per partition instead.
            final Map<ColumnType, Transformer> transformers = Transformer.initTransformers(
                    KeyUtil.sharedSecretKeyFromString(base64EncodedKey),
                    salt,
                    null, // Not relevant to decryption.
                    failOnFingerprintColumns);
            // Resolve the transformers once per row instead of once per cell.
            final Transformer cleartextTransformer = transformers.get(ColumnType.CLEARTEXT);
            final Transformer sealedTransformer = transformers.get(ColumnType.SEALED);
            final Transformer fingerprintTransformer = transformers.get(ColumnType.FINGERPRINT);
            // Sealed values unmarshal to a ValueConverter-encoded String; turn that back into plain UTF-8 bytes.
            final Function<byte[], byte[]> decodeSealed = x -> {
                final String str = ValueConverter.String.decode(x);
                return str == null ? null : str.getBytes(StandardCharsets.UTF_8);
            };

            final List<Object> unmarshalledValues = new ArrayList<>(row.size());
            for (int i = 0; i < row.size(); i++) {
                // Pass through non-String data types (equals() instead of reference identity on the DataType).
                if (!DataTypes.StringType.equals(fields[i].dataType())) {
                    unmarshalledValues.add(row.get(i));
                    continue;
                }
                final String data = row.getString(i);
                final byte[] dataBytes = data == null ? null : data.getBytes(StandardCharsets.UTF_8);
                final byte[] unmarshalledBytes;
                if (Transformer.hasDescriptor(sealedTransformer, dataBytes)) {
                    unmarshalledBytes = decodeSealed.apply(sealedTransformer.unmarshal(dataBytes));
                } else if (Transformer.hasDescriptor(fingerprintTransformer, dataBytes)) {
                    unmarshalledBytes = fingerprintTransformer.unmarshal(dataBytes);
                } else {
                    // Default to pass through.
                    unmarshalledBytes = cleartextTransformer.unmarshal(dataBytes);
                }
                unmarshalledValues.add(unmarshalledBytes == null ? null : new String(unmarshalledBytes, StandardCharsets.UTF_8));
            }
            return Row.fromSeq(
                    CollectionConverters.IteratorHasAsScala(unmarshalledValues.iterator()).asScala().toSeq());
        }, rowEncoder);
    } catch (Exception e) {
        throw new C3rRuntimeException("Unknown exception when decrypting data.", e);
    }
}