in wayang-ml4all/src/main/java/org/apache/wayang/ml4all/abstraction/plan/ML4allPlan.java [117:188]
/**
 * Builds and runs the ML4all iterative training plan over a text file, returning the final model.
 *
 * <p>The plan reads {@code inputFileUrl} as text and transforms it per partition with
 * {@code transformOp}. It then runs a Wayang {@code doWhile} loop that starts from a model
 * initialised by {@code localStageOp.staging(...)}. Each iteration:
 * <ul>
 *   <li>optionally samples the transformed data (when {@code hasSample()});</li>
 *   <li>computes and aggregates with {@code computeOp};</li>
 *   <li>updates the model either locally via {@code updateLocalOp} (when {@code isUpdateLocal()},
 *       e.g. for GD) or per key via {@code updateOp} (e.g. for k-means);</li>
 *   <li>produces a convergence value through {@code loopOp}.</li>
 * </ul>
 * The current model is broadcast as {@code "model"} to every UDF that needs it.
 *
 * <p><b>Side effect:</b> sets {@code wayang.core.optimizer.reoptimize} to {@code false} on the
 * configuration of the supplied {@code wayangContext}. The setting stays in effect after this
 * method returns.
 *
 * @param inputFileUrl  URL of the text file used as the data source
 * @param wayangContext context used to build and execute the plan; its configuration is modified
 * @return the single model collected at the end of the loop
 * @throws IllegalStateException if an iteration yields an empty convergence dataset, so the loop
 *                               condition cannot be evaluated
 */
public ML4allModel execute(String inputFileUrl, WayangContext wayangContext) {
    // NOTE: mutates the caller's configuration; this is not reverted after the job completes.
    wayangContext.getConfiguration().setProperty("wayang.core.optimizer.reoptimize", "false");
    JavaPlanBuilder javaPlanBuilder = new JavaPlanBuilder(wayangContext)
            .withUdfJar(ReflectionUtils.getDeclaringJar(ML4allModel.class))
            .withUdfJar(ReflectionUtils.getDeclaringJar(JavaPlatform.class))
            .withJobName("ML4all plan");

    // Initialise the model on the driver, then wrap it in a one-element collection so it can
    // enter the plan as the loop's initial data quanta.
    ML4allModel vars = new ML4allModel();
    localStageOp.staging(vars);
    ArrayList<ML4allModel> broadcastModel = new ArrayList<>(1);
    broadcastModel.add(vars);
    final DataQuantaBuilder<?, ML4allModel> modelBuilder =
            javaPlanBuilder.loadCollection(broadcastModel).withName("init model");

    // Source + per-partition transform are defined once, outside the loop body, and every
    // iteration of the body below builds on this same builder.
    final DataQuantaBuilder transformBuilder = javaPlanBuilder
            .readTextFile(inputFileUrl).withName("source")
            .mapPartitions(new TransformPerPartitionWrapper(transformOp)).withName("transform");

    Collection<ML4allModel> results =
            modelBuilder.doWhile((PredicateDescriptor.SerializablePredicate<Collection<Double>>) collection -> {
                // Only the first convergence value is inspected. Guard against an empty dataset,
                // which would otherwise fail with an uninformative NoSuchElementException.
                if (collection.isEmpty()) {
                    throw new IllegalStateException(
                            "ML4all loop produced an empty convergence dataset; cannot evaluate the loop condition");
                }
                return new LoopCheckWrapper<>(loopOp).apply(collection.iterator().next());
            }, model -> {
                DataQuantaBuilder convergenceDataset;
                DataQuantaBuilder<?, ML4allModel> newModel;
                DataQuantaBuilder sampledData;
                if (hasSample()) {
                    // Sample data first; the sampler also receives the current model as a broadcast.
                    sampledData = transformBuilder
                            .sample(sampleOp.sampleSize())
                            .withSampleMethod(sampleOp.sampleMethod())
                            .withDatasetSize(datasetsize)
                            .withBroadcast(model, "model")
                            .withName("sample");
                } else {
                    // Without sampling, every iteration uses the entire transformed dataset.
                    sampledData = transformBuilder;
                }
                if (isUpdateLocal()) { // e.g., for GD
                    // One global aggregate, then a local update of the weights.
                    DataQuantaBuilder newWeights = sampledData
                            .map(new ComputeWrapper<>(computeOp)).withBroadcast(model, "model").withName("compute")
                            .reduce(new AggregateWrapper<>(computeOp)).withName("reduce")
                            .map(new UpdateLocalWrapper(updateLocalOp)).withBroadcast(model, "model").withName("update");
                    newModel = newWeights
                            .map(new AssignWrapperLocal(updateLocalOp)).withName("assign")
                            .withBroadcast(model, "model");
                    convergenceDataset = newWeights
                            .map(new LoopConvergenceWrapper(loopOp)).withName("converge")
                            .withBroadcast(model, "model");
                } else { // e.g., for k-means
                    // Aggregate per key (Tuple2.field0) and update each key. Each updated tuple is
                    // then wrapped in a one-element list so ReduceWrapper can merge them into a
                    // single list (presumably a concatenation — see ReduceWrapper).
                    DataQuantaBuilder listDataset = sampledData
                            .map(new ComputeWrapper<>(computeOp)).withBroadcast(model, "model").withName("compute")
                            .reduceByKey(pair -> ((Tuple2) pair).field0, new AggregateWrapper<>(computeOp)).withName("reduce")
                            .map(new UpdateWrapper(updateOp)).withBroadcast(model, "model").withName("update")
                            .map(t -> {
                                ArrayList<Tuple2> list = new ArrayList<>(1);
                                list.add((Tuple2) t);
                                return list;
                            })
                            .reduce(new ReduceWrapper<>()).withName("global reduce");
                    newModel = listDataset
                            .map(new AssignWrapper(updateOp)).withName("assign")
                            .withBroadcast(model, "model");
                    convergenceDataset = listDataset
                            .map(new LoopConvergenceWrapper(loopOp)).withName("converge")
                            .withBroadcast(model, "model");
                }
                // (model for the next iteration, convergence data fed to the loop predicate)
                return new Tuple<>(newModel, convergenceDataset);
            }).collect();
    return WayangCollections.getSingle(results);
}