private static TrainingDataWriterConfigV2 getTrainingJobWriterConfig()

in src/main/java/com/amazonaws/services/neptune/profiles/neptune_ml/v2/config/TrainingDataWriterConfigV2.java [66:179]


    /**
     * Builds the Neptune ML v2 training data writer configuration for a single training job
     * from its JSON definition.
     *
     * <p>The optional fields read from {@code json} are {@code name}, {@code feature_encoding},
     * {@code targets} and {@code features}; default split rates are parsed from the same object
     * via {@code ParseSplitRate} with {@code DEFAULT_SPLIT_RATES_V2} as the fallback.
     * {@code targets} and {@code features} may each be a single object or an array of objects.
     *
     * @param json      JSON object describing one training job
     * @param index     position of this job; only used to derive a default name, where jobs with
     *                  {@code index > 1} are named {@code DEFAULT_NAME_V2-<index>}
     * @param dataModel source data model, passed through to the label parser
     * @return the writer configuration holding the job name, feature encoding flag, default
     *         split rates, and separate node and edge element configs
     */
    private static TrainingDataWriterConfigV2 getTrainingJobWriterConfig(JsonNode json, int index, NeptuneMLSourceDataModel dataModel) {

        Collection<Double> defaultSplitRates = new ParseSplitRate(json, DEFAULT_SPLIT_RATES_V2, new ParsingContext("config")).parseSplitRates();
        Collection<LabelConfigV2> nodeClassLabels = new ArrayList<>();
        Collection<LabelConfigV2> edgeClassLabels = new ArrayList<>();
        Collection<NoneFeatureConfig> noneNodeFeatures = new ArrayList<>();
        Collection<TfIdfConfigV2> tfIdfNodeFeatures = new ArrayList<>();
        Collection<DatetimeConfigV2> datetimeNodeFeatures = new ArrayList<>();
        Collection<Word2VecConfig> word2VecNodeFeatures = new ArrayList<>();
        Collection<FastTextConfig> fastTextNodeFeatures = new ArrayList<>();
        Collection<SbertConfig> sbertNodeFeatures = new ArrayList<>();
        Collection<NumericalBucketFeatureConfigV2> numericalBucketNodeFeatures = new ArrayList<>();
        Collection<NoneFeatureConfig> noneEdgeFeatures = new ArrayList<>();
        Collection<TfIdfConfigV2> tfIdfEdgeFeatures = new ArrayList<>();
        Collection<DatetimeConfigV2> datetimeEdgeFeatures = new ArrayList<>();
        Collection<Word2VecConfig> word2VecEdgeFeatures = new ArrayList<>();
        Collection<FastTextConfig> fastTextEdgeFeatures = new ArrayList<>();
        Collection<SbertConfig> sbertEdgeFeatures = new ArrayList<>();
        Collection<NumericalBucketFeatureConfigV2> numericalBucketEdgeFeatures = new ArrayList<>();
        Collection<FeatureOverrideConfigV2> nodeFeatureOverrides = new ArrayList<>();
        Collection<FeatureOverrideConfigV2> edgeFeatureOverrides = new ArrayList<>();

        // NOTE(review): if "name" is present but is JSON null or not a string, textValue() yields a
        // null name here rather than the default - confirm whether that should fall back instead.
        String name = json.has("name") ?
                json.get("name").textValue() :
                index > 1 ? String.format("%s-%s", DEFAULT_NAME_V2, index) : DEFAULT_NAME_V2;

        FeatureEncodingFlag featureEncodingFlag = FeatureEncodingFlag.auto;

        // textValue() is null when the field is absent, JSON null, or not a string. Enum.valueOf(null)
        // throws NullPointerException (not IllegalArgumentException), so guard before parsing.
        String featureEncoding = json.path("feature_encoding").textValue();
        if (featureEncoding != null) {
            try {
                featureEncodingFlag = FeatureEncodingFlag.valueOf(featureEncoding);
            } catch (IllegalArgumentException ignored) {
                // Unrecognised value: keep the default of auto
            }
        }

        if (json.has("targets")) {
            ParseLabelsV2 parseLabels = new ParseLabelsV2(asJsonNodeCollection(json.path("targets")), defaultSplitRates, dataModel);
            parseLabels.validate();
            nodeClassLabels.addAll(parseLabels.parseNodeClassLabels());
            edgeClassLabels.addAll(parseLabels.parseEdgeClassLabels());
        }

        if (json.has("features")) {

            ParseFeaturesV2 parseFeatures = new ParseFeaturesV2(asJsonNodeCollection(json.path("features")));

            parseFeatures.validate();

            noneNodeFeatures.addAll(parseFeatures.parseNoneFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));
            tfIdfNodeFeatures.addAll(parseFeatures.parseTfIdfFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));
            datetimeNodeFeatures.addAll(parseFeatures.parseDatetimeFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));
            word2VecNodeFeatures.addAll(parseFeatures.parseWord2VecFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));
            fastTextNodeFeatures.addAll(parseFeatures.parseFastTextFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));
            sbertNodeFeatures.addAll(parseFeatures.parseSbertFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));
            numericalBucketNodeFeatures.addAll(parseFeatures.parseNumericalBucketFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));

            noneEdgeFeatures.addAll(parseFeatures.parseNoneFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));
            tfIdfEdgeFeatures.addAll(parseFeatures.parseTfIdfFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));
            datetimeEdgeFeatures.addAll(parseFeatures.parseDatetimeFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));
            word2VecEdgeFeatures.addAll(parseFeatures.parseWord2VecFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));
            fastTextEdgeFeatures.addAll(parseFeatures.parseFastTextFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));
            sbertEdgeFeatures.addAll(parseFeatures.parseSbertFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));
            numericalBucketEdgeFeatures.addAll(parseFeatures.parseNumericalBucketFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));

            nodeFeatureOverrides.addAll(parseFeatures.parseNodeFeatureOverrides());
            edgeFeatureOverrides.addAll(parseFeatures.parseEdgeFeatureOverrides());

        }

        ElementConfig nodeConfig = new ElementConfig(
                nodeClassLabels,
                noneNodeFeatures,
                tfIdfNodeFeatures,
                datetimeNodeFeatures,
                word2VecNodeFeatures,
                fastTextNodeFeatures,
                sbertNodeFeatures,
                numericalBucketNodeFeatures,
                nodeFeatureOverrides);

        ElementConfig edgeConfig = new ElementConfig(
                edgeClassLabels,
                noneEdgeFeatures,
                tfIdfEdgeFeatures,
                datetimeEdgeFeatures,
                word2VecEdgeFeatures,
                fastTextEdgeFeatures,
                sbertEdgeFeatures,
                numericalBucketEdgeFeatures,
                edgeFeatureOverrides);

        return new TrainingDataWriterConfigV2(name,
                featureEncodingFlag,
                defaultSplitRates,
                nodeConfig,
                edgeConfig);
    }

    /**
     * Normalises a JSON value that may be either a single node or an array of nodes into a
     * mutable collection: array elements are added in iteration order, any other node is added
     * as the sole element.
     *
     * @param node the JSON value to normalise
     * @return a new collection containing the node(s)
     */
    private static Collection<JsonNode> asJsonNodeCollection(JsonNode node) {
        Collection<JsonNode> nodes = new ArrayList<>();
        if (node.isArray()) {
            node.forEach(nodes::add);
        } else {
            nodes.add(node);
        }
        return nodes;
    }