in src/main/java/com/amazonaws/services/neptune/profiles/neptune_ml/v2/config/TrainingDataWriterConfigV2.java [66:179]
/**
 * Builds one {@link TrainingDataWriterConfigV2} from a single training-job config JSON object.
 *
 * <p>Keys read from {@code json} (all optional):
 * <ul>
 *   <li>{@code "name"} – config name; when absent, {@code DEFAULT_NAME_V2} is used, suffixed
 *       with {@code "-<index>"} when {@code index > 1}.</li>
 *   <li>{@code "feature_encoding"} – name of a {@link FeatureEncodingFlag} constant; a missing,
 *       non-string or unrecognised value falls back to {@code FeatureEncodingFlag.auto}.</li>
 *   <li>{@code "targets"} – a single label object or an array of them, parsed by
 *       {@link ParseLabelsV2} into node- and edge-classification labels.</li>
 *   <li>{@code "features"} – a single feature object or an array of them, parsed by
 *       {@link ParseFeaturesV2} into per-type node and edge feature configs plus overrides.</li>
 * </ul>
 * Default split rates are parsed from {@code json} by {@link ParseSplitRate}, with
 * {@code DEFAULT_SPLIT_RATES_V2} supplied as the default.
 *
 * @param json      the JSON object describing one training job config
 * @param index     position of this config; only used to derive a default name
 * @param dataModel source data model handed to the label parser
 * @return the assembled writer config, with separate node and edge {@link ElementConfig}s
 */
private static TrainingDataWriterConfigV2 getTrainingJobWriterConfig(JsonNode json, int index, NeptuneMLSourceDataModel dataModel) {
    Collection<Double> defaultSplitRates = new ParseSplitRate(json, DEFAULT_SPLIT_RATES_V2, new ParsingContext("config")).parseSplitRates();

    Collection<LabelConfigV2> nodeClassLabels = new ArrayList<>();
    Collection<LabelConfigV2> edgeClassLabels = new ArrayList<>();

    Collection<NoneFeatureConfig> noneNodeFeatures = new ArrayList<>();
    Collection<TfIdfConfigV2> tfIdfNodeFeatures = new ArrayList<>();
    Collection<DatetimeConfigV2> datetimeNodeFeatures = new ArrayList<>();
    Collection<Word2VecConfig> word2VecNodeFeatures = new ArrayList<>();
    Collection<FastTextConfig> fastTextNodeFeatures = new ArrayList<>();
    Collection<SbertConfig> sbertNodeFeatures = new ArrayList<>();
    Collection<NumericalBucketFeatureConfigV2> numericalBucketNodeFeatures = new ArrayList<>();

    Collection<NoneFeatureConfig> noneEdgeFeatures = new ArrayList<>();
    Collection<TfIdfConfigV2> tfIdfEdgeFeatures = new ArrayList<>();
    Collection<DatetimeConfigV2> datetimeEdgeFeatures = new ArrayList<>();
    Collection<Word2VecConfig> word2VecEdgeFeatures = new ArrayList<>();
    Collection<FastTextConfig> fastTextEdgeFeatures = new ArrayList<>();
    Collection<SbertConfig> sbertEdgeFeatures = new ArrayList<>();
    Collection<NumericalBucketFeatureConfigV2> numericalBucketEdgeFeatures = new ArrayList<>();

    Collection<FeatureOverrideConfigV2> nodeFeatureOverrides = new ArrayList<>();
    Collection<FeatureOverrideConfigV2> edgeFeatureOverrides = new ArrayList<>();

    // NOTE(review): if "name" is present but not a JSON string, textValue() yields null here;
    // confirm whether downstream code tolerates a null name.
    String name = json.has("name") ?
            json.get("name").textValue() :
            index > 1 ? String.format("%s-%s", DEFAULT_NAME_V2, index) : DEFAULT_NAME_V2;

    FeatureEncodingFlag featureEncodingFlag = readFeatureEncodingFlag(json);

    if (json.has("targets")) {
        ParseLabelsV2 parseLabels = new ParseLabelsV2(asNodeCollection(json.path("targets")), defaultSplitRates, dataModel);
        parseLabels.validate();
        nodeClassLabels.addAll(parseLabels.parseNodeClassLabels());
        edgeClassLabels.addAll(parseLabels.parseEdgeClassLabels());
    }

    if (json.has("features")) {
        ParseFeaturesV2 parseFeatures = new ParseFeaturesV2(asNodeCollection(json.path("features")));
        parseFeatures.validate();

        noneNodeFeatures.addAll(parseFeatures.parseNoneFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));
        tfIdfNodeFeatures.addAll(parseFeatures.parseTfIdfFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));
        datetimeNodeFeatures.addAll(parseFeatures.parseDatetimeFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));
        word2VecNodeFeatures.addAll(parseFeatures.parseWord2VecFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));
        fastTextNodeFeatures.addAll(parseFeatures.parseFastTextFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));
        sbertNodeFeatures.addAll(parseFeatures.parseSbertFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));
        numericalBucketNodeFeatures.addAll(parseFeatures.parseNumericalBucketFeatures(ParseFeaturesV2.NodeFeatureFilter, ParseFeaturesV2.NodeLabelSupplier));

        noneEdgeFeatures.addAll(parseFeatures.parseNoneFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));
        tfIdfEdgeFeatures.addAll(parseFeatures.parseTfIdfFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));
        datetimeEdgeFeatures.addAll(parseFeatures.parseDatetimeFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));
        word2VecEdgeFeatures.addAll(parseFeatures.parseWord2VecFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));
        fastTextEdgeFeatures.addAll(parseFeatures.parseFastTextFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));
        sbertEdgeFeatures.addAll(parseFeatures.parseSbertFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));
        numericalBucketEdgeFeatures.addAll(parseFeatures.parseNumericalBucketFeatures(ParseFeaturesV2.EdgeFeatureFilter, ParseFeaturesV2.EdgeLabelSupplier));

        nodeFeatureOverrides.addAll(parseFeatures.parseNodeFeatureOverrides());
        edgeFeatureOverrides.addAll(parseFeatures.parseEdgeFeatureOverrides());
    }

    ElementConfig nodeConfig = new ElementConfig(
            nodeClassLabels,
            noneNodeFeatures,
            tfIdfNodeFeatures,
            datetimeNodeFeatures,
            word2VecNodeFeatures,
            fastTextNodeFeatures,
            sbertNodeFeatures,
            numericalBucketNodeFeatures,
            nodeFeatureOverrides);

    ElementConfig edgeConfig = new ElementConfig(
            edgeClassLabels,
            noneEdgeFeatures,
            tfIdfEdgeFeatures,
            datetimeEdgeFeatures,
            word2VecEdgeFeatures,
            fastTextEdgeFeatures,
            sbertEdgeFeatures,
            numericalBucketEdgeFeatures,
            edgeFeatureOverrides);

    return new TrainingDataWriterConfigV2(name,
            featureEncodingFlag,
            defaultSplitRates,
            nodeConfig,
            edgeConfig);
}

/**
 * Reads the optional {@code "feature_encoding"} entry of {@code json} as a {@link FeatureEncodingFlag}.
 *
 * @param json the training-job config object
 * @return the named flag, or {@code FeatureEncodingFlag.auto} when the entry is missing,
 *         is not a JSON string, or does not name a known constant
 */
private static FeatureEncodingFlag readFeatureEncodingFlag(JsonNode json) {
    JsonNode value = json.path("feature_encoding");
    if (!value.isTextual()) {
        // Missing or non-string values have a null textValue(); Enum.valueOf(null) would throw
        // NullPointerException (not IllegalArgumentException), so handle them up front.
        return FeatureEncodingFlag.auto;
    }
    try {
        return FeatureEncodingFlag.valueOf(value.textValue());
    } catch (IllegalArgumentException ignored) {
        // Unrecognised encoding name: deliberately fall back to the default of auto.
        return FeatureEncodingFlag.auto;
    }
}

/**
 * Normalises a config entry that may be either a single JSON node or an array of nodes.
 *
 * @param node a single node, or an array node
 * @return the array's elements in order, or a one-element collection containing {@code node}
 */
private static Collection<JsonNode> asNodeCollection(JsonNode node) {
    Collection<JsonNode> nodes = new ArrayList<>();
    if (node.isArray()) {
        node.forEach(nodes::add);
    } else {
        nodes.add(node);
    }
    return nodes;
}