in vw/src/main/scala/com/microsoft/azure/synapse/ml/vw/VowpalWabbitBase.scala [337:422]
/** Trains a VW model on every partition of `df` and collects the per-partition results on the driver.
  *
  * Each partition opens its own native VW instance (seeded from the `initialModel` param when it is
  * set), feeds its rows through `trainRow`, ends the pass (running the remaining passes when
  * `numPasses > 1`), and emits exactly one [[TrainingResult]]. Only partition 0 carries the
  * serialized model; every partition reports its [[TrainingStats]].
  *
  * @param df          training data; its schema is forwarded to `trainRow`
  * @param vwArgs      VW command-line arguments, forwarded to `buildCommandLineArguments`
  * @param contextArgs extra arguments (by-name), forwarded to `buildCommandLineArguments` inside each task
  * @return one [[TrainingResult]] per partition
  */
private def trainInternal(df: DataFrame, vwArgs: String, contextArgs: => String = ""): Array[TrainingResult] = {
  val schema = df.schema

  // Executed per partition: trains on the partition's rows and returns a single-element iterator.
  def trainIteration(inputRows: Iterator[Row],
                     localInitialModel: Option[Array[Byte]]): Iterator[TrainingResult] = {
    // construct command line arguments
    val args = buildCommandLineArguments(vwArgs, contextArgs)

    // Opens the native learner, optionally seeded with the initial model bytes.
    def openNative(): VowpalWabbitNative = localInitialModel match {
      case Some(initialModelBytes) => new VowpalWabbitNative(args.result, initialModelBytes)
      case None => new VowpalWabbitNative(args.result)
    }

    // NOTE(review): inputRows is a one-shot Iterator; if retryWithTimeout re-runs this block, the
    // retry only sees rows the failed attempt did not consume — confirm this is acceptable.
    FaultToleranceUtils.retryWithTimeout() {
      try {
        val totalTime = new StopWatch
        // NOTE(review): nativeIngestTime and learnTime are never started in this method (nor passed to
        // trainRow), so the elapsed() values reported below appear to always be zero — confirm.
        val nativeIngestTime = new StopWatch
        val learnTime = new StopWatch
        val multipassTime = new StopWatch

        val (model, stats) = StreamUtilities.using(openNative()) { vw =>
          val trainContext = new TrainContext(vw)
          // The example handle is opened and closed around training; it is not used directly here.
          val result = StreamUtilities.using(vw.createExample()) { _ =>
            // pass data to VW native part
            totalTime.measure {
              trainRow(schema, inputRows, trainContext)
              multipassTime.measure {
                vw.endPass()
                if (getNumPasses > 1)
                  vw.performRemainingPasses()
              }
            }
          }
          // If the inner using block failed, rethrow here.
          result match {
            case Failure(exception) => throw exception
            case Success(_) => ()
          }

          val partitionId = TaskContext.get.partitionId
          val perfStats = vw.getPerformanceStatistics
          // Named distinctly from the outer `args` builder, which the error message below still uses.
          val learnerArgs = vw.getArguments
          // only export the model on the first partition
          val partitionModel = if (partitionId == 0) Some(vw.getModel) else None

          (partitionModel,
            TrainingStats(
              partitionId,
              learnerArgs.getArgs,
              learnerArgs.getLearningRate,
              learnerArgs.getPowerT,
              learnerArgs.getHashSeed,
              learnerArgs.getNumBits,
              perfStats.getNumberOfExamplesPerPass,
              perfStats.getWeightedExampleSum,
              perfStats.getWeightedLabelSum,
              perfStats.getAverageLoss,
              perfStats.getBestConstant,
              perfStats.getBestConstantLoss,
              perfStats.getTotalNumberOfFeatures,
              totalTime.elapsed(),
              nativeIngestTime.elapsed(),
              learnTime.elapsed(),
              multipassTime.elapsed(),
              trainContext.contextualBanditMetrics.getIpsEstimate,
              trainContext.contextualBanditMetrics.getSnipsEstimate))
        }.get // this will throw if there was an exception
        Iterator.single(TrainingResult(model, stats))
      } catch {
        case e: java.lang.Exception =>
          throw new Exception(s"VW failed with args: ${args.result}", e)
      }
    }
  }

  // Kryo encoder for the Dataset path; the barrier path goes through the RDD API instead.
  val encoder = Encoders.kryo[TrainingResult]
  // Resolved once on the driver and captured by the partition closure for every task.
  val localInitialModel = if (isDefined(initialModel)) Some(getInitialModel) else None
  // Dispatch to executors and collect every partition's result; only partition 0 carries the model
  // (all partitions are expected to end with the same model).
  // important to trigger collect() here so that the spanning tree is still up
  if (getUseBarrierExecutionMode)
    df.rdd.barrier().mapPartitions(inputRows => trainIteration(inputRows, localInitialModel)).collect()
  else
    df.mapPartitions(inputRows => trainIteration(inputRows, localInitialModel))(encoder).collect()
}