static Dataset marshalData()

in c3r-cli-spark/src/main/java/com/amazonaws/c3r/spark/action/SparkMarshaller.java [262:335]


    /**
     * Builds the dataset in which every row of the input has been run through the configured transformers.
     *
     * <p>Each output row holds one value per entry of {@code columnInsights}, in list order. Cleartext columns are
     * copied through as-is; every other column must be a Spark {@code StringType} column and is handed to the
     * {@link Transformer} registered for its {@link ColumnType}.
     *
     * <p>NOTE(review): the row function presumably runs when Spark executes the plan rather than inside the
     * try block below, so failures raised per row may surface elsewhere — confirm against callers.
     *
     * @param rawInputData   Rows to transform; its schema is also used to build the encoder for the result
     * @param encryptConfig  Supplies the client settings, the salt and the secret key
     * @param columnInsights One entry per output column giving its type, source position and target header
     * @return The dataset produced by mapping the row function over {@code rawInputData}
     * @throws C3rRuntimeException If setting up the mapping fails; other exception types are wrapped in this one
     */
    static Dataset<Row> marshalData(final Dataset<Row> rawInputData, final SparkEncryptConfig encryptConfig,
                                    final List<ColumnInsight> columnInsights) {
        // Pull plain values out of the config up front; these are what the row function captures
        final ClientSettings settings = encryptConfig.getSettings();
        final String salt = encryptConfig.getSalt();
        final String base64EncodedKey = Base64.getEncoder().encodeToString(encryptConfig.getSecretKey().getEncoded());

        final ExpressionEncoder<Row> outputEncoder = ExpressionEncoder.apply(rawInputData.schema());
        final StructField[] sourceFields = rawInputData.schema().fields();

        final MapFunction<Row, Row> marshalRow = row -> {
            // One nonce is shared by every column of the row
            final Nonce rowNonce = Nonce.nextNonce();
            // Transformers are rebuilt from the key/salt/settings for each row
            final Map<ColumnType, Transformer> transformersByType = Transformer.initTransformers(
                    KeyUtil.sharedSecretKeyFromString(base64EncodedKey),
                    salt,
                    settings,
                    false); // Not relevant to encryption
            return Row.fromSeq(CollectionConverters.IteratorHasAsScala(
                    columnInsights.stream()
                            .map(insight -> marshalColumnValue(row, insight, sourceFields, transformersByType, settings,
                                    rowNonce))
                            .iterator())
                    .asScala()
                    .toSeq());
        };

        try {
            return rawInputData.map(marshalRow, outputEncoder);
        } catch (C3rRuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new C3rRuntimeException("Unknown exception when encrypting data.", e);
        }
    }

    /**
     * Produces the output value of a single column for a single row.
     *
     * @param row          Row currently being transformed
     * @param column       Insight describing the column to produce
     * @param fields       Fields of the input schema, indexed by source column position
     * @param transformers Transformers for this row, keyed by column type
     * @param settings     Client settings consulted for null handling
     * @param nonce        Nonce of the current row
     * @return The untouched source value for cleartext columns, otherwise the marshalled bytes read as a UTF-8
     *         string, or {@code null} when the transformer returned {@code null}
     * @throws C3rRuntimeException If a column marked for transformation is not a {@code StringType} column
     */
    private static Object marshalColumnValue(final Row row, final ColumnInsight column, final StructField[] fields,
                                             final Map<ColumnType, Transformer> transformers,
                                             final ClientSettings settings, final Nonce nonce) {
        final ColumnType type = column.getType();
        final int position = column.getSourceColumnPosition();
        if (type == ColumnType.CLEARTEXT) {
            return row.get(position);
        }
        // Identity comparison against the shared DataTypes.StringType instance
        if (fields[position].dataType() != DataTypes.StringType) {
            throw new C3rRuntimeException("Encrypting non-String values is not supported. Non-String column marked" +
                    " for encryption: `" + column.getTargetHeader() + "`");
        }

        final Transformer transformer = transformers.get(type);
        final String text = row.getString(position);
        final byte[] plainBytes = bytesForTransformer(text, type, settings);
        final EncryptionContext context = new EncryptionContext(column, nonce, ClientDataType.STRING);
        final byte[] marshalled = transformer.marshal(plainBytes, context);
        if (marshalled == null) {
            return null;
        }
        return new String(marshalled, StandardCharsets.UTF_8);
    }

    /**
     * Converts a string cell into the bytes handed to the transformer for the given column type.
     *
     * <p>NOTE: This is essentially a custom in-place version of ValueConverter.getBytesForColumn since the Spark
     * client can't use the Value infrastructure.
     *
     * @param data     Cell content, possibly {@code null}
     * @param type     Type of the column the cell belongs to
     * @param settings Client settings; only {@code isPreserveNulls()} is consulted
     * @return Bytes for the transformer, or {@code null} when a null cell is passed through as null
     */
    private static byte[] bytesForTransformer(final String data, final ColumnType type, final ClientSettings settings) {
        if (type == ColumnType.SEALED) {
            final boolean passNullThrough = data == null && settings.isPreserveNulls();
            return passNullThrough ? null : ValueConverter.String.encode(data);
        }
        if (type != ColumnType.FINGERPRINT) {
            return data == null ? null : data.getBytes(StandardCharsets.UTF_8);
        }

        // Fingerprint with a value: UTF-8 bytes followed by the encoded type information
        if (data != null) {
            final byte[] utf8Bytes = data.getBytes(StandardCharsets.UTF_8);
            final ByteBuffer buffer = ByteBuffer.allocate(utf8Bytes.length + ClientDataInfo.BYTE_LENGTH);
            buffer.put(utf8Bytes);
            buffer.put(ClientDataInfo.builder().type(ClientDataType.STRING).isNull(false).build().encode());
            return buffer.array();
        }

        // Fingerprint without a value
        if (settings.isPreserveNulls()) {
            return null;
        }
        // Random filler, with the slot at NULL_RANDOM_BYTE_SIZE overwritten by the null-flagged type information
        final byte[] nullPlaceholder = new byte[NULL_RANDOM_BYTE_SIZE + ClientDataInfo.BYTE_LENGTH];
        RANDOM.nextBytes(nullPlaceholder);
        nullPlaceholder[NULL_RANDOM_BYTE_SIZE] = ClientDataInfo.builder()
                .type(ClientDataType.STRING)
                .isNull(true)
                .build()
                .encode();
        return nullPlaceholder;
    }