def run()

in preprocessing/src/main/scala/com/facebook/spark/rl/MultiStepTimeline.scala [116:278]


  /**
   * Builds the multi-step timeline and writes it into `config.outputTableName`.
   *
   * The generated query:
   *  1. reads `config.inputTableName` for `ds` between `config.startDs` and
   *     `config.endDs` and keeps one row per `(mdp_id, sequence_number)`;
   *  2. numbers the rows of each `mdp_id` by `sequence_number`;
   *  3. self-joins every row with the rows of the same `mdp_id` that are 1 to
   *     `config.steps` ordinals ahead of it;
   *  4. groups the joined rows back per starting row, collecting the "next"
   *     columns into lists that are passed through the sort UDFs;
   *  5. overwrites partition `ds = config.endDs` of the output table.
   *
   * @param sqlContext context used to inspect the input schema, register UDFs and
   *                   execute the query
   * @param config     job configuration; this method reads `inputTableName`,
   *                   `outputTableName`, `startDs`, `endDs`, `steps` and
   *                   `addTerminalStateRow`
   * @throws AssertionError if the `action` column of the input table is neither
   *                        `string` nor `map<bigint,double>`
   */
  def run(sqlContext: SQLContext, config: MultiStepTimelineConfiguration): Unit = {
    // Spliced directly in front of JOIN: an empty prefix gives a plain (inner)
    // join, which drops rows without a successor; LEFT OUTER keeps those rows.
    val terminalJoin = if (config.addTerminalStateRow) "LEFT OUTER" else ""

    val actionDataType =
      Helper.getDataTypes(sqlContext, config.inputTableName, List("action"))("action")
    log.info(s"action column data type:${actionDataType}")
    assert(
      Set("string", "map<bigint,double>").contains(actionDataType),
      s"Unsupported action column data type: ${actionDataType}"
    )

    // The sort UDFs applied to the collected next-action columns depend on how
    // the action column is encoded: *_ID for string actions, *_MAP otherwise.
    val (sortActionMethod, sortPossibleActionMethod) =
      if (actionDataType == "string") ("UDF_SORT_ID", "UDF_SORT_ARRAY_ID")
      else ("UDF_SORT_MAP", "UDF_SORT_ARRAY_MAP")

    MultiStepTimeline.createTrainingTable(sqlContext, config.outputTableName, actionDataType)
    // NOTE(review): presumably registers the UDF_* functions referenced by the
    // query below -- confirm against registerUDFs.
    MultiStepTimeline.registerUDFs(sqlContext)

    val sqlCommand = s"""
      WITH deduped as (
          SELECT
              mdp_id as mdp_id,
              FIRST(state_features) as state_features,
              FIRST(action) as action,
              FIRST(action_probability) as action_probability,
              FIRST(reward) as reward,
              FIRST(possible_actions) as possible_actions,
              FIRST(metrics) as metrics,
              FIRST(ds) as ds,
              sequence_number as sequence_number
              FROM (
                  SELECT * FROM ${config.inputTableName}
                  WHERE ds BETWEEN '${config.startDs}' AND '${config.endDs}'
              ) dummy
              GROUP BY mdp_id, sequence_number
      ),
      ordinal as (
          SELECT
              mdp_id as mdp_id,
              state_features as state_features,
              action as action,
              action_probability as action_probability,
              reward as reward,
              possible_actions as possible_actions,
              metrics as metrics,
              ds as ds,
              sequence_number as sequence_number,
              row_number() over (partition by mdp_id order by mdp_id, sequence_number) as sequence_number_ordinal,
              sequence_number - FIRST(sequence_number) OVER (
                  PARTITION BY mdp_id ORDER BY mdp_id, sequence_number
              ) AS time_since_first
              FROM deduped
      ),
      ordinal_join AS (
          SELECT
              first_sa.mdp_id AS mdp_id,
              first_sa.state_features AS state_features,
              first_sa.action AS action,
              first_sa.action_probability as action_probability,
              first_sa.reward AS reward,
              second_sa.reward AS next_reward,
              second_sa.state_features AS next_state_features,
              second_sa.action AS next_action,
              first_sa.sequence_number AS sequence_number,
              first_sa.sequence_number_ordinal AS sequence_number_ordinal,
              COALESCE(
                CAST(second_sa.sequence_number - first_sa.sequence_number AS BIGINT),
                first_sa.sequence_number
              ) AS time_diff,
              first_sa.time_since_first AS time_since_first,
              first_sa.possible_actions AS possible_actions,
              second_sa.possible_actions AS possible_next_actions,
              first_sa.metrics AS metrics,
              second_sa.metrics AS next_metrics
              FROM
                  ordinal first_sa
                  ${terminalJoin} JOIN ordinal second_sa
                  ON first_sa.mdp_id = second_sa.mdp_id
                  AND (first_sa.sequence_number_ordinal + 1) <= second_sa.sequence_number_ordinal
                  AND (first_sa.sequence_number_ordinal + ${config.steps}) >= second_sa.sequence_number_ordinal
      ),
      ordinal_join_time_diff AS (
          SELECT
              mdp_id AS mdp_id,
              state_features AS state_features,
              action AS action,
              action_probability as action_probability,
              reward AS reward,
              MAP(time_diff, next_reward) AS next_reward,
              MAP(time_diff, next_state_features) AS next_state_features,
              MAP(time_diff, next_action) AS next_action,
              sequence_number AS sequence_number,
              sequence_number_ordinal AS sequence_number_ordinal,
              time_diff AS time_diff,
              time_since_first AS time_since_first,
              possible_actions AS possible_actions,
              MAP(time_diff, possible_next_actions) AS possible_next_actions,
              metrics AS metrics,
              MAP(time_diff, next_metrics) AS next_metrics
              FROM
                  ordinal_join
      ),
      sarsa_unshuffled_multi_step AS (
          SELECT
              mdp_id AS mdp_id,
              FIRST(state_features) AS state_features,
              FIRST(action) AS action,
              FIRST(action_probability) as action_probability,
              UDF_PREPEND_DOUBLE(
                FIRST(reward),
                UDF_DROP_LAST_DOUBLE(UDF_SORT_DOUBLE(COLLECT_LIST(next_reward)))
              ) AS reward,
              UDF_SORT_MAP(
                COLLECT_LIST(next_state_features)
              ) AS next_state_features,
              ${sortActionMethod}(
                COLLECT_LIST(next_action)
              ) AS next_action,
              sequence_number AS sequence_number,
              sequence_number_ordinal AS sequence_number_ordinal,
              SORT_ARRAY(
                COLLECT_LIST(time_diff)
              ) AS time_diff,
              FIRST(time_since_first) AS time_since_first,
              FIRST(possible_actions) AS possible_actions,
              ${sortPossibleActionMethod}(
                COLLECT_LIST(possible_next_actions)
              ) AS possible_next_actions,
              UDF_PREPEND_MAP_STRING(
                FIRST(metrics),
                UDF_DROP_LAST_MAP_STRING(UDF_SORT_MAP_STRING(COLLECT_LIST(next_metrics)))
              ) AS metrics
          FROM
              ordinal_join_time_diff
          GROUP BY mdp_id, sequence_number, sequence_number_ordinal
      )
      INSERT OVERWRITE TABLE ${config.outputTableName} PARTITION(ds='${config.endDs}')
      SELECT
          mdp_id,
          state_features,
          action,
          action_probability,
          reward,
          next_state_features,
          next_action,
          sequence_number,
          sequence_number_ordinal,
          time_diff,
          time_since_first,
          possible_actions,
          possible_next_actions,
          metrics
      FROM
          sarsa_unshuffled_multi_step
          CLUSTER BY HASH(mdp_id, sequence_number)
    """.stripMargin
    // Log the fully interpolated statement before running it so the exact SQL
    // that was executed can be recovered from the job logs.
    log.info("Executing query: ")
    log.info(sqlCommand)
    sqlContext.sql(sqlCommand)
  }