public ML4allModel execute()

in wayang-ml4all/src/main/java/org/apache/wayang/ml4all/abstraction/plan/ML4allPlan.java [117:188]


    /**
     * Assembles the iterative ML4all plan on the given Wayang context, runs it, and returns the
     * single model produced by the loop.
     *
     * <p>The plan loads a locally staged initial model, reads and transforms the input text file,
     * and then repeats a compute/update/assign body inside a do-while loop. The loop condition is
     * evaluated by a {@code LoopCheckWrapper} around {@code loopOp} on the first element of the
     * convergence dataset.
     *
     * <p>Side effect: sets {@code wayang.core.optimizer.reoptimize} to {@code "false"} on the
     * configuration of the supplied context.
     *
     * @param inputFileUrl  URL of the text file that is read as the plan's source
     * @param wayangContext context the plan is built on; its configuration is modified
     * @return the single element of the collected loop output
     */
    public ML4allModel execute(String inputFileUrl, WayangContext wayangContext) {

        // Switch re-optimization off on the caller's configuration before building the plan.
        wayangContext.getConfiguration().setProperty("wayang.core.optimizer.reoptimize", "false");

        JavaPlanBuilder planBuilder = new JavaPlanBuilder(wayangContext)
                .withUdfJar(ReflectionUtils.getDeclaringJar(ML4allModel.class))
                .withUdfJar(ReflectionUtils.getDeclaringJar(JavaPlatform.class))
                .withJobName("ML4all plan");

        // Stage a fresh model locally and hand it to the plan as a one-element collection.
        ML4allModel initialModel = new ML4allModel();
        localStageOp.staging(initialModel);
        ArrayList<ML4allModel> modelSeed = new ArrayList<>(1);
        modelSeed.add(initialModel);
        final DataQuantaBuilder<?, ML4allModel> initialModelBuilder =
                planBuilder.loadCollection(modelSeed).withName("init model");

        // Source data: read the text file and transform it partition by partition.
        final DataQuantaBuilder transformedInput = planBuilder
                .readTextFile(inputFileUrl).withName("source")
                .mapPartitions(new TransformPerPartitionWrapper(transformOp)).withName("transform");

        // Loop condition: apply the loop check to the first value of the convergence dataset.
        final PredicateDescriptor.SerializablePredicate<Collection<Double>> loopCondition =
                convergenceValues -> new LoopCheckWrapper<>(loopOp).apply(convergenceValues.iterator().next());

        Collection<ML4allModel> finalModels = initialModelBuilder.doWhile(loopCondition, currentModel -> {

            // Input of this iteration: either a sample of the transformed data or all of it.
            DataQuantaBuilder iterationInput;
            if (hasSample()) {
                iterationInput = transformedInput
                        .sample(sampleOp.sampleSize())
                        .withSampleMethod(sampleOp.sampleMethod())
                        .withDatasetSize(datasetsize)
                        .withBroadcast(currentModel, "model");
            } else {
                iterationInput = transformedInput;
            }

            DataQuantaBuilder<?, ML4allModel> nextModel;
            DataQuantaBuilder convergenceData;

            if (isUpdateLocal()) {
                // Local update variant (e.g. GD): a single global reduce feeds the update.
                DataQuantaBuilder updatedWeights = iterationInput
                        .map(new ComputeWrapper<>(computeOp))
                        .withBroadcast(currentModel, "model")
                        .withName("compute")
                        .reduce(new AggregateWrapper<>(computeOp))
                        .withName("reduce")
                        .map(new UpdateLocalWrapper(updateLocalOp))
                        .withBroadcast(currentModel, "model")
                        .withName("update");

                nextModel = updatedWeights
                        .map(new AssignWrapperLocal(updateLocalOp))
                        .withName("assign")
                        .withBroadcast(currentModel, "model");

                convergenceData = updatedWeights
                        .map(new LoopConvergenceWrapper(loopOp))
                        .withName("converge")
                        .withBroadcast(currentModel, "model");
            } else {
                // Keyed update variant (e.g. k-means): reduce per key, update, then merge the
                // per-key results into one list with a global reduce.
                DataQuantaBuilder perKeyUpdates = iterationInput
                        .map(new ComputeWrapper<>(computeOp))
                        .withBroadcast(currentModel, "model")
                        .withName("compute")
                        .reduceByKey(keyed -> ((Tuple2) keyed).field0, new AggregateWrapper<>(computeOp))
                        .withName("reduce")
                        .map(new UpdateWrapper(updateOp))
                        .withBroadcast(currentModel, "model")
                        .withName("update")
                        .map(updated -> {
                            // Wrap each tuple in a one-element list so the global reduce can combine them.
                            ArrayList<Tuple2> singleton = new ArrayList<>(1);
                            singleton.add((Tuple2) updated);
                            return singleton;
                        })
                        .reduce(new ReduceWrapper<>())
                        .withName("global reduce");

                nextModel = perKeyUpdates
                        .map(new AssignWrapper(updateOp))
                        .withName("assign")
                        .withBroadcast(currentModel, "model");

                convergenceData = perKeyUpdates
                        .map(new LoopConvergenceWrapper(loopOp))
                        .withName("converge")
                        .withBroadcast(currentModel, "model");
            }

            // The loop body yields the model for the next round plus the convergence dataset.
            return new Tuple<>(nextModel, convergenceData);
        }).collect();

        return WayangCollections.getSingle(finalModels);
    }