override def transformSchema()

in vw/src/main/scala/com/microsoft/azure/synapse/ml/vw/VowpalWabbitContextualBandit.scala [152:209]


  /** Validates the configured arguments and input columns for contextual bandit training.
    *
    * No columns are added or removed: on success the incoming schema is returned as is.
    *
    * Checks, in order:
    *  - the VW argument string does not select `--cb`, `--cb_explore` or `--cb_adf`
    *  - every action feature column is an array of vectors
    *  - every shared feature column is a vector
    *  - the chosen action column is an integer
    *  - the label column is an integer, float or double
    *  - the probability column is an integer or double
    *
    * @param schema schema of the dataset about to be used
    * @return the same `schema` instance, unchanged
    * @throws NotImplementedError if the arguments select a contextual bandit mode other than
    *                             `--cb_explore_adf`
    * @throws AssertionError      if a column has an unsupported data type
    */
  override def transformSchema(schema: StructType): StructType = {
    val allActionFeatureColumns = Seq(getFeaturesCol) ++ getAdditionalFeatures
    val allSharedFeatureColumns = Seq(getSharedCol) ++ getAdditionalSharedFeatures
    val actionCol = getChosenActionCol
    val labelCol = getLabelCol
    val probCol = getProbabilityCol

    // Validate args.
    // The trailing "( |$)" requires the flag to end right there, so "--cb_explore_adf" itself is
    // not rejected even though it starts with "--cb_explore".
    val allArgs = getArgs
    if (allArgs.matches("^.*--(cb_explore|cb|cb_adf)( |$).*$")) {
      throw new NotImplementedError("VowpalWabbitContextualBandit is only compatible with contextual bandit problems" +
        " with action dependent features which produce a probability distribution. These are problems which are " +
        "used with VowpalWabbit with the '--cb_explore_adf' flag.")
    }

    // Validate action columns: one vector per candidate action.
    allActionFeatureColumns.foreach { colName =>
      val dt = schema(colName).dataType
      val isVectorList = dt match {
        case ArrayType(VectorType, _) => true
        case _ => false
      }
      assert(isVectorList,
        s"$colName must be a list of sparse vectors of features. Found: $dt. Each item in the list corresponds to a" +
          s" specific action and the overall list is the namespace.")
    }

    // Validate shared columns: a single vector per row.
    allSharedFeatureColumns.foreach { colName =>
      val dt = schema(colName).dataType
      val isVector = dt match {
        case VectorType => true
        case _ => false
      }
      assert(isVector, s"$colName must be a sparse vector of features. Found: $dt")
    }

    // Validate the chosen action column.
    val actionDt = schema(actionCol).dataType
    val actionOk = actionDt match {
      case IntegerType => true
      case _ => false
    }
    assert(actionOk, s"$actionCol must be an integer. Found: $actionDt")

    // Validate the label column; the message lists every accepted type.
    val labelDt = schema(labelCol).dataType
    val labelOk = labelDt match {
      case IntegerType | DoubleType | FloatType => true
      case _ => false
    }
    assert(labelOk, s"$labelCol must be an integer, float or double. Found: $labelDt")

    // Validate the probability column.
    // NOTE(review): FloatType is accepted for the label but not here - confirm this is intended.
    val probDt = schema(probCol).dataType
    val probOk = probDt match {
      case IntegerType | DoubleType => true
      case _ => false
    }
    assert(probOk, s"$probCol must be an integer or double. Found: $probDt")

    schema
  }