in preprocessing/src/main/scala/com/facebook/spark/rl/MultiStepTimeline.scala [116:278]
/**
 * Builds the multi-step timeline rows for one date range and writes them to the output table.
 *
 * The method:
 *  1. reads the data type of the input table's `action` column and asserts it is either
 *     `string` or `map<bigint,double>`;
 *  2. calls `MultiStepTimeline.createTrainingTable` and `MultiStepTimeline.registerUDFs`
 *     (presumably these create the output table and register the `UDF_*` functions the query
 *     references -- confirm against their definitions);
 *  3. executes a single `INSERT OVERWRITE` statement into partition `ds = config.endDs`.
 *
 * Query outline (one CTE per stage):
 *  - `deduped`: keeps one row per (mdp_id, sequence_number) for `ds` between `config.startDs`
 *    and `config.endDs`, taking `FIRST(...)` of every other column.
 *  - `ordinal`: numbers the rows of each mdp_id by sequence_number and computes the offset from
 *    the first sequence_number of that mdp_id.
 *  - `ordinal_join`: pairs each row with the next 1..`config.steps` rows of the same mdp_id.
 *    When a row has no successor (only possible with the LEFT OUTER join), `time_diff` falls
 *    back to the row's own sequence_number.
 *  - `ordinal_join_time_diff`: wraps each "next" value in a single-entry map keyed by time_diff.
 *  - `sarsa_unshuffled_multi_step`: groups the pairs back to one row per
 *    (mdp_id, sequence_number, sequence_number_ordinal), collecting the "next" values into
 *    lists that are passed through the sort UDFs.
 *
 * @param sqlContext Spark SQL context used to inspect the input table and run the query
 * @param config     job settings; this method reads `inputTableName`, `outputTableName`,
 *                   `startDs`, `endDs`, `steps` and `addTerminalStateRow`
 */
def run(sqlContext: SQLContext, config: MultiStepTimelineConfiguration): Unit = {
  // An empty prefix leaves a plain (inner) JOIN, which drops rows that have no successor;
  // "LEFT OUTER" keeps them with NULL next_* columns.
  val terminalJoin = if (config.addTerminalStateRow) "LEFT OUTER" else ""

  val actionDataType =
    Helper.getDataTypes(sqlContext, config.inputTableName, List("action"))("action")
  log.info(s"action column data type:${actionDataType}")
  assert(
    Set("string", "map<bigint,double>").contains(actionDataType),
    s"Unsupported action column data type: ${actionDataType}"
  )

  // The sort UDFs applied to next_action / possible_next_actions depend on how actions are
  // stored: plain string ids versus map<bigint,double>.
  val (sortActionMethod, sortPossibleActionMethod) =
    if (actionDataType == "string") ("UDF_SORT_ID", "UDF_SORT_ARRAY_ID")
    else ("UDF_SORT_MAP", "UDF_SORT_ARRAY_MAP")

  MultiStepTimeline.createTrainingTable(sqlContext, config.outputTableName, actionDataType)
  MultiStepTimeline.registerUDFs(sqlContext)

  // The `|` margins are removed by stripMargin, so the statement text handed to Spark (and
  // logged below) carries no leading indentation.
  val sqlCommand = s"""
    |WITH deduped as (
    |SELECT
    |mdp_id as mdp_id,
    |FIRST(state_features) as state_features,
    |FIRST(action) as action,
    |FIRST(action_probability) as action_probability,
    |FIRST(reward) as reward,
    |FIRST(possible_actions) as possible_actions,
    |FIRST(metrics) as metrics,
    |FIRST(ds) as ds,
    |sequence_number as sequence_number
    |FROM (
    |SELECT * FROM ${config.inputTableName}
    |WHERE ds BETWEEN '${config.startDs}' AND '${config.endDs}'
    |) dummy
    |GROUP BY mdp_id, sequence_number
    |),
    |ordinal as (
    |SELECT
    |mdp_id as mdp_id,
    |state_features as state_features,
    |action as action,
    |action_probability as action_probability,
    |reward as reward,
    |possible_actions as possible_actions,
    |metrics as metrics,
    |ds as ds,
    |sequence_number as sequence_number,
    |row_number() over (partition by mdp_id order by mdp_id, sequence_number) as sequence_number_ordinal,
    |sequence_number - FIRST(sequence_number) OVER (
    |PARTITION BY mdp_id ORDER BY mdp_id, sequence_number
    |) AS time_since_first
    |FROM deduped
    |),
    |ordinal_join AS (
    |SELECT
    |first_sa.mdp_id AS mdp_id,
    |first_sa.state_features AS state_features,
    |first_sa.action AS action,
    |first_sa.action_probability as action_probability,
    |first_sa.reward AS reward,
    |second_sa.reward AS next_reward,
    |second_sa.state_features AS next_state_features,
    |second_sa.action AS next_action,
    |first_sa.sequence_number AS sequence_number,
    |first_sa.sequence_number_ordinal AS sequence_number_ordinal,
    |COALESCE(
    |CAST(second_sa.sequence_number - first_sa.sequence_number AS BIGINT),
    |first_sa.sequence_number
    |) AS time_diff,
    |first_sa.time_since_first AS time_since_first,
    |first_sa.possible_actions AS possible_actions,
    |second_sa.possible_actions AS possible_next_actions,
    |first_sa.metrics AS metrics,
    |second_sa.metrics AS next_metrics
    |FROM
    |ordinal first_sa
    |${terminalJoin} JOIN ordinal second_sa
    |ON first_sa.mdp_id = second_sa.mdp_id
    |AND (first_sa.sequence_number_ordinal + 1) <= second_sa.sequence_number_ordinal
    |AND (first_sa.sequence_number_ordinal + ${config.steps}) >= second_sa.sequence_number_ordinal
    |),
    |ordinal_join_time_diff AS (
    |SELECT
    |mdp_id AS mdp_id,
    |state_features AS state_features,
    |action AS action,
    |action_probability as action_probability,
    |reward AS reward,
    |MAP(time_diff, next_reward) AS next_reward,
    |MAP(time_diff, next_state_features) AS next_state_features,
    |MAP(time_diff, next_action) AS next_action,
    |sequence_number AS sequence_number,
    |sequence_number_ordinal AS sequence_number_ordinal,
    |time_diff AS time_diff,
    |time_since_first AS time_since_first,
    |possible_actions AS possible_actions,
    |MAP(time_diff, possible_next_actions) AS possible_next_actions,
    |metrics AS metrics,
    |MAP(time_diff, next_metrics) AS next_metrics
    |FROM
    |ordinal_join
    |),
    |sarsa_unshuffled_multi_step AS (
    |SELECT
    |mdp_id AS mdp_id,
    |FIRST(state_features) AS state_features,
    |FIRST(action) AS action,
    |FIRST(action_probability) as action_probability,
    |UDF_PREPEND_DOUBLE(
    |FIRST(reward),
    |UDF_DROP_LAST_DOUBLE(UDF_SORT_DOUBLE(COLLECT_LIST(next_reward)))
    |) AS reward,
    |UDF_SORT_MAP(
    |COLLECT_LIST(next_state_features)
    |) AS next_state_features,
    |${sortActionMethod}(
    |COLLECT_LIST(next_action)
    |) AS next_action,
    |sequence_number AS sequence_number,
    |sequence_number_ordinal AS sequence_number_ordinal,
    |SORT_ARRAY(
    |COLLECT_LIST(time_diff)
    |) AS time_diff,
    |FIRST(time_since_first) AS time_since_first,
    |FIRST(possible_actions) AS possible_actions,
    |${sortPossibleActionMethod}(
    |COLLECT_LIST(possible_next_actions)
    |) AS possible_next_actions,
    |UDF_PREPEND_MAP_STRING(
    |FIRST(metrics),
    |UDF_DROP_LAST_MAP_STRING(UDF_SORT_MAP_STRING(COLLECT_LIST(next_metrics)))
    |) AS metrics
    |FROM
    |ordinal_join_time_diff
    |GROUP BY mdp_id, sequence_number, sequence_number_ordinal
    |)
    |INSERT OVERWRITE TABLE ${config.outputTableName} PARTITION(ds='${config.endDs}')
    |SELECT
    |mdp_id,
    |state_features,
    |action,
    |action_probability,
    |reward,
    |next_state_features,
    |next_action,
    |sequence_number,
    |sequence_number_ordinal,
    |time_diff,
    |time_since_first,
    |possible_actions,
    |possible_next_actions,
    |metrics
    |FROM
    |sarsa_unshuffled_multi_step
    |CLUSTER BY HASH(mdp_id, sequence_number)
    |""".stripMargin

  log.info("Executing query: ")
  log.info(sqlCommand)
  sqlContext.sql(sqlCommand)
}