in vw/src/main/scala/com/microsoft/azure/synapse/ml/vw/VowpalWabbitContextualBandit.scala [152:209]
/**
 * Validates that the configured VW arguments and the input schema are usable by this
 * contextual-bandit estimator, and returns the schema unchanged (no output columns are added here).
 *
 * Checks performed:
 *  - the VW argument string must not select `--cb`, `--cb_adf` or `--cb_explore`; only problems
 *    run with `--cb_explore_adf` (which produce a probability distribution) are supported;
 *  - every action feature column (the features column plus any additional features columns) must
 *    be an array of vectors, one entry per action;
 *  - every shared feature column (the shared column plus any additional shared columns) must be a
 *    vector;
 *  - the chosen-action column must be an integer;
 *  - the label column must be an integer, float or double;
 *  - the probability column must be an integer or double.
 *
 * Columns are looked up with `schema(name)`, so a missing column fails with Spark's lookup error.
 *
 * @param schema input schema to validate
 * @return `schema`, unmodified
 * @throws NotImplementedError if the arguments request an unsupported contextual bandit reduction
 * @throws AssertionError if a column has an unsupported data type
 */
override def transformSchema(schema: StructType): StructType = {
  val allActionFeatureColumns = Seq(getFeaturesCol) ++ getAdditionalFeatures
  val allSharedFeatureColumns = Seq(getSharedCol) ++ getAdditionalSharedFeatures
  val actionCol = getChosenActionCol
  val labelCol = getLabelCol
  val probCol = getProbabilityCol

  // Validate args. Search anywhere in the string and treat any whitespace (space, tab, newline)
  // or end-of-input as the flag terminator, so e.g. "--passes 2\n--cb 4" or "--cb\t4" cannot slip
  // past the check. The lookahead keeps "--cb_explore_adf" (and flags such as "--cbify") allowed.
  val unsupportedCbFlag = """--(?:cb_explore|cb|cb_adf)(?=\s|$)""".r
  if (unsupportedCbFlag.findFirstIn(getArgs).isDefined) {
    throw new NotImplementedError("VowpalWabbitContextualBandit is only compatible with contextual bandit problems" +
      " with action dependent features which produce a probability distribution. These are problems which are " +
      "used with VowpalWabbit with the '--cb_explore_adf' flag.")
  }

  // Validate action columns: each row holds one feature vector per action.
  for (colName <- allActionFeatureColumns) {
    val dt = schema(colName).dataType
    assert(dt match {
      case ArrayType(VectorType, _) => true
      case _ => false
    }, s"$colName must be a list of sparse vectors of features. Found: $dt. Each item in the list corresponds to a" +
      s" specific action and the overall list is the namespace.")
  }

  // Validate shared columns: one feature vector per row, shared across all actions.
  for (colName <- allSharedFeatureColumns) {
    val dt = schema(colName).dataType
    assert(dt match {
      case VectorType => true
      case _ => false
    }, s"$colName must be a sparse vector of features. Found: $dt")
  }

  val actionDt = schema(actionCol).dataType
  assert(actionDt match {
    case IntegerType => true
    case _ => false
  }, s"$actionCol must be an integer. Found: $actionDt")

  val labelDt = schema(labelCol).dataType
  assert(labelDt match {
    case IntegerType | DoubleType | FloatType => true
    case _ => false
  }, s"$labelCol must be an integer, float or double. Found: $labelDt")

  // NOTE(review): unlike the label, FloatType is not accepted here; kept as-is because the code
  // that reads this column is not visible from this method — confirm before widening.
  val probDt = schema(probCol).dataType
  assert(probDt match {
    case IntegerType | DoubleType => true
    case _ => false
  }, s"$probCol must be an integer or double. Found: $probDt")

  schema
}