in src/main/java/com/amazonaws/services/neptune/profiles/neptune_ml/v2/RdfTrainingDataConfigWriter.java [69:187]
/**
 * Writes the {@code "rdfs"} array of the RDF training data config.
 *
 * <p>When no classification specifications are configured, every file in {@code filenames} gets a single
 * {@code link_prediction} entry whose only target carries the default split rates. Otherwise, for each
 * {@link RdfTaskTypeV2} that has at least one matching specification, every file gets one entry whose
 * {@code targets} are those specifications. Task types with no matching specification produce no entries.
 *
 * <p>Warnings about missing optional target fields are recorded only while writing the first file, so that
 * they are not repeated once per file.
 *
 * @throws IOException if the underlying JSON generator fails to write
 */
private void writeRdfs() throws IOException {
    generator.writeArrayFieldStart("rdfs");
    Collection<LabelConfigV2> classificationSpecifications = config.nodeConfig().getAllClassificationSpecifications();
    if (classificationSpecifications.isEmpty()) {
        // No explicit specs: default to a link-prediction task over all edges in every file.
        for (String filename : filenames) {
            writeRdfEntryStart(filename, RdfTaskTypeV2.link_prediction);
            generator.writeStartObject();
            writeRdfSplitRates(config.defaultSplitRates());
            generator.writeEndObject();
            writeRdfEntryEnd();
        }
    } else {
        for (RdfTaskTypeV2 taskType : RdfTaskTypeV2.values()) {
            // Null-safe comparison: a spec without a task type simply matches nothing.
            List<LabelConfigV2> taskSpecificConfigs = classificationSpecifications.stream()
                    .filter(c -> taskType.name().equals(c.taskType()))
                    .collect(Collectors.toList());
            if (taskSpecificConfigs.isEmpty()) {
                continue;
            }
            // Warnings depend only on the specs, not on the file, so record them for the first file only.
            boolean firstFile = true;
            for (String filename : filenames) {
                writeRdfEntryStart(filename, taskType);
                for (LabelConfigV2 taskSpecificConfig : taskSpecificConfigs) {
                    if (taskType == RdfTaskTypeV2.link_prediction) {
                        writeRdfLinkPredictionTarget(taskSpecificConfig, firstFile);
                    } else {
                        writeRdfNodeTarget(taskType, taskSpecificConfig, firstFile);
                    }
                }
                writeRdfEntryEnd();
                firstFile = false;
            }
        }
    }
    generator.writeEndArray();
}

/**
 * Opens one rdfs entry: writes {@code "file_name"}, opens the {@code "label"} object, writes
 * {@code "task_type"} and opens the {@code "targets"} array. Must be paired with {@link #writeRdfEntryEnd()}.
 */
private void writeRdfEntryStart(String filename, RdfTaskTypeV2 taskType) throws IOException {
    generator.writeStartObject();
    generator.writeStringField("file_name", filename);
    generator.writeObjectFieldStart("label");
    generator.writeStringField("task_type", taskType.name());
    generator.writeArrayFieldStart("targets");
}

/** Closes the {@code "targets"} array, the {@code "label"} object and the entry object opened by {@link #writeRdfEntryStart}. */
private void writeRdfEntryEnd() throws IOException {
    generator.writeEndArray();
    generator.writeEndObject();
    generator.writeEndObject();
}

/** Writes a {@code "split_rate"} array field containing the given rates in iteration order. */
private void writeRdfSplitRates(Iterable<Double> splitRates) throws IOException {
    generator.writeArrayFieldStart("split_rate");
    for (Double splitRate : splitRates) {
        generator.writeNumber(splitRate);
    }
    generator.writeEndArray();
}

/**
 * Writes one link-prediction target object. Each of subject / predicate / object is written only when
 * non-empty; when missing, a warning is recorded if {@code recordWarnings} is true.
 */
private void writeRdfLinkPredictionTarget(LabelConfigV2 target, boolean recordWarnings) throws IOException {
    generator.writeStartObject();
    writeOptionalRdfField("subject", target.subject(), recordWarnings,
            "'subject' field is missing for link_prediction task, so all edges will be treated as the training target.");
    writeOptionalRdfField("predicate", target.property(), recordWarnings,
            "'predicate' field is missing for link_prediction task, so all edges will be treated as the training target.");
    writeOptionalRdfField("object", target.object(), recordWarnings,
            "'object' field is missing for link_prediction task, so all edges will be treated as the training target.");
    writeRdfSplitRates(target.splitRates());
    generator.writeEndObject();
}

/**
 * Writes one node-task target object with the spec's labels as {@code "node"}, an optional
 * {@code "predicate"} and its split rates. A warning is recorded for a missing predicate if
 * {@code recordWarnings} is true.
 */
private void writeRdfNodeTarget(RdfTaskTypeV2 taskType, LabelConfigV2 target, boolean recordWarnings) throws IOException {
    generator.writeStartObject();
    generator.writeStringField("node", target.label().labelsAsString());
    writeOptionalRdfField("predicate", target.property(), recordWarnings,
            String.format("'predicate' field is missing for %s task. If the target nodes have more than one predicate defining the target node feature, the training task will fail with an error.", taskType));
    writeRdfSplitRates(target.splitRates());
    generator.writeEndObject();
}

/** Writes {@code fieldName: value} when {@code value} is non-empty; otherwise optionally records {@code warning}. */
private void writeOptionalRdfField(String fieldName, String value, boolean recordWarning, String warning) throws IOException {
    if (StringUtils.isNotEmpty(value)) {
        generator.writeStringField(fieldName, value);
    } else if (recordWarning) {
        warnings.add(warning);
    }
}