private def trainInternal()

in vw/src/main/scala/com/microsoft/azure/synapse/ml/vw/VowpalWabbitBase.scala [337:422]


  /** Trains Vowpal Wabbit on every partition of `df` and collects one result per partition.
    *
    * Each partition creates its own native VW instance from the command line built out of `vwArgs` and
    * `contextArgs` (seeded with the initial model bytes when the `initialModel` param is set). It passes its
    * rows to `trainRow`, ends the pass, and runs the remaining passes when `getNumPasses > 1`. Each partition
    * reports [[TrainingStats]], and only partition 0 exports the model bytes.
    *
    * @param df          data to train on; every Spark partition is trained by its own VW instance.
    * @param vwArgs      VW arguments forwarded to `buildCommandLineArguments`.
    * @param contextArgs extra arguments forwarded (by name, evaluated on the executor) to `buildCommandLineArguments`.
    * @return one [[TrainingResult]] per partition; only the result of partition 0 carries `Some(model)`.
    */
  private def trainInternal(df: DataFrame, vwArgs: String, contextArgs: => String = ""): Array[TrainingResult] = {

    val schema = df.schema

    // Trains a single partition. Any exception is wrapped together with the VW arguments that were used and
    // surfaced through FaultToleranceUtils.retryWithTimeout.
    // NOTE(review): `inputRows` is a one-shot iterator. If retryWithTimeout re-runs this body after a failure
    // part-way through the partition, the retry only sees the rows not yet consumed. Confirm this is intended.
    def trainIteration(inputRows: Iterator[Row],
                       localInitialModel: Option[Array[Byte]]): Iterator[TrainingResult] = {
      // construct command line arguments
      val args = buildCommandLineArguments(vwArgs, contextArgs)
      FaultToleranceUtils.retryWithTimeout() {
        try {
          val totalTime = new StopWatch
          // NOTE(review): nativeIngestTime and learnTime are never passed to `measure` in this method, so unless
          // something else drives them they are reported with their initial elapsed value. Verify.
          val nativeIngestTime = new StopWatch
          val learnTime = new StopWatch
          val multipassTime = new StopWatch

          // The resource expression stays inline so `using` owns both construction and cleanup of the native VW.
          val (model, stats) = StreamUtilities.using(
            localInitialModel.fold(new VowpalWabbitNative(args.result)) { initialModelBytes =>
              new VowpalWabbitNative(args.result, initialModelBytes)
            }) { vw =>
            val trainContext = new TrainContext(vw)

            // An example object is created and kept open by `using` for the duration of this partition's training.
            val passResult = StreamUtilities.using(vw.createExample()) { _ =>
              totalTime.measure {
                // pass data to VW native part
                trainRow(schema, inputRows, trainContext)

                multipassTime.measure {
                  vw.endPass()

                  if (getNumPasses > 1)
                    vw.performRemainingPasses()
                }
              }
            }

            // `using` reports failures as a Failure; `get` rethrows the original exception.
            passResult.get

            val partitionId = TaskContext.get.partitionId
            val perfStats = vw.getPerformanceStatistics
            // Named differently from the outer `args` builder, which the catch block below still refers to.
            val vwArguments = vw.getArguments

            // Only export the model on the first partition (every partition ends up with the same model).
            val exportedModel = if (partitionId == 0) Some(vw.getModel) else None

            (exportedModel,
              TrainingStats(
                partitionId,
                vwArguments.getArgs,
                vwArguments.getLearningRate,
                vwArguments.getPowerT,
                vwArguments.getHashSeed,
                vwArguments.getNumBits,
                perfStats.getNumberOfExamplesPerPass,
                perfStats.getWeightedExampleSum,
                perfStats.getWeightedLabelSum,
                perfStats.getAverageLoss,
                perfStats.getBestConstant,
                perfStats.getBestConstantLoss,
                perfStats.getTotalNumberOfFeatures,
                totalTime.elapsed(),
                nativeIngestTime.elapsed(),
                learnTime.elapsed(),
                multipassTime.elapsed(),
                trainContext.contextualBanditMetrics.getIpsEstimate,
                trainContext.contextualBanditMetrics.getSnipsEstimate))
          }.get // rethrows if creating/using/closing the native VW instance failed

          Seq(TrainingResult(model, stats)).iterator
        } catch {
          case e: java.lang.Exception =>
            throw new Exception(s"VW failed with args: ${args.result}", e)
        }
      }
    }

    val encoder = Encoders.kryo[TrainingResult]

    // The initial model (if configured) is captured in the closure shipped to every partition.
    val localInitialModel = if (isDefined(initialModel)) Some(getInitialModel) else None

    // Dispatch to executors and collect the results of all partitions (only partition 0 carries the model;
    // everybody has the same model at the end anyway).
    // It is important to trigger collect() here so that the spanning tree is still up.
    if (getUseBarrierExecutionMode)
      df.rdd.barrier().mapPartitions(inputRows => trainIteration(inputRows, localInitialModel)).collect()
    else
      df.mapPartitions(inputRows => trainIteration(inputRows, localInitialModel))(encoder).collect()
  }