in preprocessing/src/main/scala/com/facebook/spark/rl/Timeline.scala [127:344]
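// Builds the timeline (training) table: each (mdp_id, sequence_number) row of the input
// table is paired with its successor row via LEAD window functions, outlier-length
// episodes are optionally dropped, the nulls produced for terminal states are filled with
// empty placeholders, the staged result is validated, and the endDs partition of the
// output table is overwritten.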
def run(
sqlContext: SQLContext,
config: TimelineConfiguration
): Unit = {
// Unless terminal state rows are requested, drop rows that have no next state.
var filterTerminal = "WHERE next_state_features IS NOT NULL"
if (config.addTerminalStateRow) {
filterTerminal = ""
}
// Optionally cap how far past the first step of its episode a row may extend.
var filterTimeLimit = ""
if (config.timeWindowLimit.isDefined) {
if (filterTerminal == "") {
filterTimeLimit = s"WHERE time_since_first <= ${config.timeWindowLimit.get}"
} else {
filterTimeLimit = s" AND time_since_first <= ${config.timeWindowLimit.get}"
}
}
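// Look up the Hive column types of the action, reward, and extra timeline-join columns so
// the output table can be created with a matching schema.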
val actionDataType =
Helper.getDataTypes(sqlContext, config.inputTableName, List("action"))("action")
log.info(s"action column data type: ${actionDataType}")
var timelineJoinColumns = config.extraFeatureColumns
if (config.includePossibleActions) {
timelineJoinColumns = "possible_actions" :: timelineJoinColumns
}
val rewardColumnDataTypes =
Helper.getDataTypes(sqlContext, config.inputTableName, config.rewardColumns)
log.info(s"reward columns: ${config.rewardColumns}")
log.info(s"reward column types: ${rewardColumnDataTypes}")
val timelineJoinColumnDataTypes =
Helper.getDataTypes(sqlContext, config.inputTableName, timelineJoinColumns)
log.info(s"timeline join columns: ${timelineJoinColumns}")
log.info(s"timeline join column types: ${timelineJoinColumnDataTypes}")
Timeline.createTrainingTable(
sqlContext,
config.outputTableName,
actionDataType,
rewardColumnDataTypes,
timelineJoinColumnDataTypes
)
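// When outlierEpisodeLengthPercentile is set, materialize per-episode lengths; the
// threshold derived from this view is used below to drop unusually long episodes.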
config.outlierEpisodeLengthPercentile.foreach { percentile =>
sqlContext.sql(s"""
SELECT mdp_id, COUNT(mdp_id) AS mdp_length
FROM ${config.inputTableName}
WHERE ds BETWEEN '${config.startDs}' AND '${config.endDs}'
GROUP BY mdp_id
""").createOrReplaceTempView("episode_length")
}
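// Turn the optional length threshold into an extra CTE and join clause so that only
// episodes at or below the threshold feed the training table.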
val lengthThreshold = Timeline.mdpLengthThreshold(sqlContext, config)
val mdpFilter = lengthThreshold
.map { threshold =>
s"mdp_filter AS (SELECT mdp_id FROM episode_length WHERE mdp_length <= ${threshold}),"
}
.getOrElse("")
val joinClause = lengthThreshold
.map { threshold =>
s"""
JOIN mdp_filter
WHERE a.mdp_id = mdp_filter.mdp_id AND
""".stripMargin
}
.getOrElse("WHERE")
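// Fold the reward and timeline-join column names into comma-prefixed SELECT fragments
// ("" when empty, otherwise ", a.col1, a.col2, ...") for splicing into the queries below.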
val rewardSourceColumns = rewardColumnDataTypes.foldLeft("") {
case (acc, (k, v)) => s"${acc}, a.${k}"
}
val timelineSourceColumns = timelineJoinColumnDataTypes.foldLeft("") {
case (acc, (k, v)) => s"${acc}, a.${k}"
}
val sourceTable = s"""
WITH ${mdpFilter}
source_table AS (
SELECT
a.mdp_id,
a.state_features,
a.action_probability,
a.action
${rewardSourceColumns},
a.sequence_number
${timelineSourceColumns}
FROM ${config.inputTableName} a
${joinClause}
a.ds BETWEEN '${config.startDs}' AND '${config.endDs}'
)
""".stripMargin
val rewardColumnsQuery = rewardColumnDataTypes.foldLeft("") {
case (acc, (k, v)) => s"${acc}, ${k}"
}
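// Each timeline-join column is selected together with a LEAD() "next step" counterpart,
// mirroring next_state_features / next_action below. For example, with possible_actions
// included, the fold below contributes roughly:
//   , possible_actions,
//   LEAD(possible_actions) OVER (PARTITION BY mdp_id ORDER BY mdp_id, sequence_number)
//     AS <Helper.next_step_col_name("possible_actions")>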
val timelineJoinColumnsQuery = timelineJoinColumnDataTypes.foldLeft("") {
case (acc, (k, v)) =>
s"""
${acc},
${k},
LEAD(${k}) OVER (
PARTITION BY
mdp_id
ORDER BY
mdp_id,
sequence_number
) AS ${Helper.next_step_col_name(k)}
"""
}
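// The main query: joined_table pairs every step with the following one. All window
// functions share PARTITION BY mdp_id ORDER BY mdp_id, sequence_number; LEAD() yields the
// next action / state (NULL on the last step of an episode), ROW_NUMBER() re-numbers steps
// consecutively, FIRST(sequence_number) anchors time_since_first at the episode start, and
// COALESCE makes time_diff fall back to 0 on terminal rows.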
val sqlCommand = s"""
${sourceTable},
joined_table AS (
SELECT
mdp_id,
state_features,
action,
LEAD(action) OVER (
PARTITION BY
mdp_id
ORDER BY
mdp_id,
sequence_number
) AS next_action,
action_probability
${rewardColumnsQuery},
LEAD(state_features) OVER (
PARTITION BY
mdp_id
ORDER BY
mdp_id,
sequence_number
) AS next_state_features,
sequence_number,
ROW_NUMBER() OVER (
PARTITION BY
mdp_id
ORDER BY
mdp_id,
sequence_number
) AS sequence_number_ordinal,
COALESCE(LEAD(sequence_number) OVER (
PARTITION BY
mdp_id
ORDER BY
mdp_id,
sequence_number
), sequence_number) - sequence_number AS time_diff,
sequence_number - FIRST(sequence_number) OVER (
PARTITION BY
mdp_id
ORDER BY
mdp_id,
sequence_number
) AS time_since_first
${timelineJoinColumnsQuery}
FROM source_table
CLUSTER BY HASH(mdp_id, sequence_number)
)
SELECT
*
FROM joined_table
${filterTerminal}
${filterTimeLimit}
""".stripMargin
log.info("Executing query: ")
log.info(sqlCommand)
var df = sqlContext.sql(sqlCommand)
log.info("Done with query")
// When terminal state rows are kept, the LEAD() columns of those rows are NULL; replace
// them with empty placeholders of the matching type.
val handle_cols = timelineJoinColumnDataTypes ++ Map(
"action" -> actionDataType,
"state_features" -> "map<bigint,double>"
)
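// Map each column type to the UDF producing an empty value of that type; an unsupported
// type fails fast with a MatchError.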
for ((col_name, col_type) <- handle_cols) {
val next_col_name = Helper.next_step_col_name(col_name)
val empty_placeholder = col_type match {
case "string" => Udfs.emptyStr()
case "array<string>" => Udfs.emptyArrOfStr()
case "map<bigint,double>" => Udfs.emptyMap()
case "array<map<bigint,double>>" => Udfs.emptyArrOfMap()
case "array<bigint>" => Udfs.emptyArrOfLong()
case "map<bigint,array<bigint>>" => Udfs.emptyMapOfIds()
case "map<bigint,map<bigint,double>>" => Udfs.emptyMapOfMap()
case "map<bigint,array<map<bigint,double>>>" => Udfs.emptyMapOfArrOfMap()
}
df = df
.withColumn(next_col_name, coalesce(df(next_col_name), empty_placeholder))
}
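// Stage the result in a temp view so the optional validationSql can inspect it before
// anything is written to the real output table.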
val stagingTable = "stagingTable_" + config.outputTableName
if (sqlContext.tableNames.contains(stagingTable)) {
log.warn("RL ValidationSql staging table name collision occurred, name: " + stagingTable)
}
df.createOrReplaceTempView(stagingTable)
val maybeError = config.validationSql.flatMap { query =>
Helper.validateTimeline(
sqlContext,
query
.replace("{config.outputTableName}", stagingTable)
.replace("{config.inputTableName}", config.inputTableName)
)
}
assert(maybeError.isEmpty, "validationSql validation failure: " + maybeError)
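// Validation passed: overwrite the endDs partition of the output table with the staged rows.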
val insertCommandOutput = s"""
INSERT OVERWRITE TABLE ${config.outputTableName} PARTITION(ds='${config.endDs}')
SELECT * FROM ${stagingTable}
""".stripMargin
sqlContext.sql(insertCommandOutput)
}