def run()

in preprocessing/src/main/scala/com/facebook/spark/rl/Timeline.scala [127:344]


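  // Builds one ds partition of the RL training table. The method reads logged
  // (mdp_id, sequence_number, state_features, action, reward, ...) rows between
  // config.startDs and config.endDs, uses window functions over
  // (mdp_id, sequence_number) to attach each step's successor state and action,
  // optionally drops outlier-length episodes, validates the staged result, and
  // INSERT OVERWRITEs the ds = endDs partition of config.outputTableName.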
  def run(
      sqlContext: SQLContext,
      config: TimelineConfiguration
  ): Unit = {
    val filterTerminal =
      if (config.addTerminalStateRow) "" else "WHERE next_state_features IS NOT NULL"
    val filterTimeLimit = config.timeWindowLimit
      .map { limit =>
        if (filterTerminal.isEmpty) s"WHERE time_since_first <= ${limit}"
        else s" AND time_since_first <= ${limit}"
      }
      .getOrElse("")

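    // Look up column types from the input table: the action type, reward column
    // types, and timeline-join column types drive the output table schema and the
    // null placeholders chosen near the end of this method.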
    val actionDataType =
      Helper.getDataTypes(sqlContext, config.inputTableName, List("action"))("action")
    log.info("action column data type:" + s"${actionDataType}")

    val timelineJoinColumns =
      if (config.includePossibleActions) "possible_actions" :: config.extraFeatureColumns
      else config.extraFeatureColumns

    val rewardColumnDataTypes =
      Helper.getDataTypes(sqlContext, config.inputTableName, config.rewardColumns)
    log.info("reward columns:" + s"${config.rewardColumns}")
    log.info("reward column types:" + s"${rewardColumnDataTypes}")

    val timelineJoinColumnDataTypes =
      Helper.getDataTypes(sqlContext, config.inputTableName, timelineJoinColumns)
    log.info("timeline join column columns:" + s"${timelineJoinColumns}")
    log.info("timeline join column types:" + s"${timelineJoinColumnDataTypes}")

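    // Create the output table with a schema derived from the discovered types
    // (createTrainingTable is defined elsewhere on the Timeline object).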
    Timeline.createTrainingTable(
      sqlContext,
      config.outputTableName,
      actionDataType,
      rewardColumnDataTypes,
      timelineJoinColumnDataTypes
    )

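    // When outlier filtering is configured, register a temp view of per-MDP
    // episode lengths; mdpLengthThreshold presumably reads it to turn the
    // configured percentile into a concrete length cutoff (None when disabled).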
    config.outlierEpisodeLengthPercentile.foreach { _ =>
      sqlContext.sql(s"""
          SELECT mdp_id, COUNT(mdp_id) AS mdp_length
          FROM ${config.inputTableName}
          WHERE ds BETWEEN '${config.startDs}' AND '${config.endDs}'
          GROUP BY mdp_id
      """).createOrReplaceTempView("episode_length")
    }

    val lengthThreshold = Timeline.mdpLengthThreshold(sqlContext, config)

    val mdpFilter = lengthThreshold
      .map { threshold =>
        s"mdp_filter AS (SELECT mdp_id FROM episode_length WHERE mdp_length <= ${threshold}),"
      }
      .getOrElse("")

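    // The join clause either restricts source rows to MDPs passing the length
    // filter or degenerates to a bare WHERE; either way, the date-range predicate
    // appended below completes it.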
    val joinClause = lengthThreshold
      .map { _ =>
        """
        JOIN mdp_filter
        WHERE a.mdp_id = mdp_filter.mdp_id AND
        """.stripMargin
      }
      .getOrElse("WHERE")

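    // Each fold step prepends ", a.<col>", so these fragments splice directly
    // after a preceding column in the SELECT list below.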
    val rewardSourceColumns = rewardColumnDataTypes.foldLeft("") {
      case (acc, (k, _)) => s"${acc}, a.${k}"
    }
    val timelineSourceColumns = timelineJoinColumnDataTypes.foldLeft("") {
      case (acc, (k, _)) => s"${acc}, a.${k}"
    }

    val sourceTable = s"""
    WITH ${mdpFilter}
        source_table AS (
            SELECT
                a.mdp_id,
                a.state_features,
                a.action_probability,
                a.action
                ${rewardSourceColumns},
                a.sequence_number
                ${timelineSourceColumns}
            FROM ${config.inputTableName} a
            ${joinClause}
            a.ds BETWEEN '${config.startDs}' AND '${config.endDs}'
        )
    """.stripMargin

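    // Build the SELECT-list fragments for reward columns (passed through as-is)
    // and timeline-join columns (passed through plus a LEAD(...) window giving
    // the column's value at the next step of the same mdp_id).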
    val rewardColumnsQuery = rewardColumnDataTypes.foldLeft("") {
      case (acc, (k, _)) => s"${acc}, ${k}"
    }
    val timelineJoinColumnsQuery = timelineJoinColumnDataTypes.foldLeft("") {
      case (acc, (k, _)) =>
        s"""
        ${acc},
        ${k},
        LEAD(${k}) OVER (
            PARTITION BY
                mdp_id
            ORDER BY
                mdp_id,
                sequence_number
        ) AS ${Helper.next_step_col_name(k)}
        """
    }

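    // Main query: LEAD / ROW_NUMBER / FIRST windows over
    // (PARTITION BY mdp_id ORDER BY mdp_id, sequence_number) derive next_action,
    // next_state_features, a dense per-episode ordinal, time_diff to the next
    // step (0 for the last step), and time_since_first; terminal rows and
    // over-long tails are filtered at the very end.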
    val sqlCommand = s"""
    ${sourceTable},
    joined_table AS (
      SELECT
          mdp_id,
          state_features,
          action,
          LEAD(action) OVER (
              PARTITION BY
                  mdp_id
              ORDER BY
                  mdp_id,
                  sequence_number
          ) AS next_action,
          action_probability
          ${rewardColumnsQuery},
          LEAD(state_features) OVER (
              PARTITION BY
                  mdp_id
              ORDER BY
                  mdp_id,
                  sequence_number
          ) AS next_state_features,
          sequence_number,
          ROW_NUMBER() OVER (
              PARTITION BY
                  mdp_id
              ORDER BY
                  mdp_id,
                  sequence_number
          ) AS sequence_number_ordinal,
          COALESCE(LEAD(sequence_number) OVER (
              PARTITION BY
                  mdp_id
              ORDER BY
                  mdp_id,
                  sequence_number
          ), sequence_number) - sequence_number AS time_diff,
          sequence_number - FIRST(sequence_number) OVER (
              PARTITION BY
                  mdp_id
              ORDER BY
                  mdp_id,
                  sequence_number
          ) AS time_since_first
          ${timelineJoinColumnsQuery}
      FROM source_table
      CLUSTER BY HASH(mdp_id, sequence_number)
    )
    SELECT
      *
    FROM joined_table
    ${filterTerminal}
    ${filterTimeLimit}
    """.stripMargin
    log.info("Executing query: ")
    log.info(sqlCommand)
    var df = sqlContext.sql(sqlCommand)
    log.info("Done with query")

    // Terminal states have no successor row, so the LEAD(...) columns above are
    // NULL there; replace each with an empty placeholder of the matching type.
    val handle_cols = timelineJoinColumnDataTypes ++ Map(
      "action" -> actionDataType,
      "state_features" -> "map<bigint,double>"
    )
    for ((col_name, col_type) <- handle_cols) {
      val next_col_name = Helper.next_step_col_name(col_name)
      val empty_placeholder = col_type match {
        case "string"                                => Udfs.emptyStr()
        case "array<string>"                         => Udfs.emptyArrOfStr()
        case "map<bigint,double>"                    => Udfs.emptyMap()
        case "array<map<bigint,double>>"             => Udfs.emptyArrOfMap()
        case "array<bigint>"                         => Udfs.emptyArrOfLong()
        case "map<bigint,array<bigint>>"             => Udfs.emptyMapOfIds()
        case "map<bigint,map<bigint,double>>"        => Udfs.emptyMapOfMap()
        case "map<bigint,array<map<bigint,double>>>" => Udfs.emptyMapOfArrOfMap()
        // Fail fast with a descriptive error instead of a bare MatchError
        case _ =>
          throw new IllegalArgumentException(
            s"Unsupported column type ${col_type} for column ${col_name}"
          )
      }
      df = df
        .withColumn(next_col_name, coalesce(df(next_col_name), empty_placeholder))
    }

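    // Stage the result in a temp view so the optional validation query can
    // inspect it before anything is written to the real output table.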
    val stagingTable = "stagingTable_" + config.outputTableName
    if (sqlContext.tableNames.contains(stagingTable)) {
      log.warn("RL ValidationSql staging table name collision occurred, name: " + stagingTable)
    }
    df.createOrReplaceTempView(stagingTable)

    val maybeError = config.validationSql.flatMap { query =>
      Helper.validateTimeline(
        sqlContext,
        query
          .replace("{config.outputTableName}", stagingTable)
          .replace("{config.inputTableName}", config.inputTableName)
      )
    }

    assert(maybeError.isEmpty, "validationSql validation failure: " + maybeError)

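    // Validation passed: overwrite the endDs partition of the output table with
    // the staged rows.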
    val insertCommandOutput = s"""
      INSERT OVERWRITE TABLE ${config.outputTableName} PARTITION(ds='${config.endDs}')
      SELECT * FROM ${stagingTable}
    """.stripMargin
    sqlContext.sql(insertCommandOutput)
  }
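
Usage sketch (not from the source): the TimelineConfiguration field names below are
taken from their uses in run() above; the case-class shape, field order, and the
literal table names are assumptions for illustration only.

  // Hypothetical invocation, assuming TimelineConfiguration is a case class with
  // named fields matching the accesses in run().
  val config = TimelineConfiguration(
    startDs = "2024-01-01",               // assumed date-string format
    endDs = "2024-01-07",
    addTerminalStateRow = true,           // keep rows whose next state is NULL
    inputTableName = "rl_logged_actions", // hypothetical input table
    outputTableName = "rl_timeline",      // hypothetical output table
    includePossibleActions = true,
    rewardColumns = List("reward"),
    extraFeatureColumns = List.empty,
    timeWindowLimit = None,               // Option: no cap on time_since_first
    outlierEpisodeLengthPercentile = None, // Option: no episode-length filtering
    validationSql = None                  // Option: skip post-staging validation
  )
  Timeline.run(sqlContext, config)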