in reagent/training/ranking/seq2slate_sim_trainer.py [0:0]
def _simulated_training_input(self, training_input: rlt.PreprocessedRankingInput):
    device = training_input.state.float_features.device
    # Precision error may occasionally produce invalid actions, so
    # resample until the sampled slate is valid.
    while True:
        rank_output = self.seq2slate_net(
            training_input,
            mode=Seq2SlateMode.RANK_MODE,
            tgt_seq_len=self.seq2slate_net.max_tgt_seq_len,
            greedy=False,
        )
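        # With greedy=False, RANK_MODE samples a slate from the current
        # policy; ranked_per_seq_probs is the probability of the sampled
        # permutation and serves as the propensity below.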
        model_propensities = rank_output.ranked_per_seq_probs
        model_actions_with_offset = rank_output.ranked_tgt_out_idx
        model_actions = model_actions_with_offset - 2
        if torch.all(model_actions >= 0):
            break
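    # ranked_tgt_out_idx is offset by 2 because seq2slate reserves index 0
    # for padding and index 1 for the decoder-start symbol; a negative
    # index after subtracting the offset signals an invalid sample and
    # triggers a resample above.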
    batch_size = model_actions_with_offset.shape[0]
    simulated_slate_features = gather(
        training_input.src_seq.float_features,
        model_actions,
    )
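    # simulated_slate_features now holds each example's candidate features
    # reordered into the sampled slate order (a batched index-select).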
    if not self.reward_name_and_net:
        # Compare the device type: device == torch.device("cuda") is
        # False for concrete devices like "cuda:0".
        use_gpu = device.type == "cuda"
        self.reward_name_and_net = nn.ModuleDict(
            _load_reward_net(self.sim_param.reward_name_path, use_gpu)
        )
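    # Accumulate the simulated slate reward as a weighted combination of
    # the (frozen) reward networks' predictions.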
    sim_slate_reward = torch.zeros(batch_size, 1, device=device)
    for name, reward_net in self.reward_name_and_net.items():
        weight = self.sim_param.reward_name_weight[name]
        power = self.sim_param.reward_name_power[name]
        sr = reward_net(
            training_input.state.float_features,
            training_input.src_seq.float_features,
            simulated_slate_features,
            training_input.src_src_mask,
            model_actions_with_offset,
        ).detach()
        assert sr.ndim == 2, f"Slate reward {name} output should be a 2-D tensor"
        sim_slate_reward += weight * (sr ** power)
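    # sim_slate_reward is now sum_name(weight * reward ** power); the
    # predictions are detached so no gradient flows into the reward nets.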
    # Guard-rail the reward prediction range.
    reward_clamp = self.sim_param.reward_clamp
    if reward_clamp is not None:
        sim_slate_reward = torch.clamp(
            sim_slate_reward, min=reward_clamp.clamp_min, max=reward_clamp.clamp_max
        )
    # Guard-rail sequence similarity.
    distance_penalty = self.sim_param.distance_penalty
    if distance_penalty is not None:
        sim_distance = (
            torch.tensor(
                [swap_dist(x) for x in model_actions.tolist()],
                device=device,
            )
            .unsqueeze(1)
            .float()
        )
        sim_slate_reward += distance_penalty * (self.MAX_DISTANCE - sim_distance)
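    # swap_dist appears to count the swaps needed to sort the sampled
    # action sequence, i.e. its distance from the logged candidate order,
    # so slates closer to the logged order receive a larger bonus
    # (MAX_DISTANCE - sim_distance).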
    assert (
        sim_slate_reward.ndim == 2 and sim_slate_reward.shape[1] == 1
    ), f"{sim_slate_reward.shape}"
    on_policy_input = rlt.PreprocessedRankingInput.from_input(
        state=training_input.state.float_features,
        candidates=training_input.src_seq.float_features,
        device=device,
        action=model_actions,
        slate_reward=sim_slate_reward,
        logged_propensities=model_propensities,
    )
    return on_policy_input
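# Usage sketch (hypothetical; the caller is outside this excerpt): a train
# step would swap the logged batch for the simulated, on-policy batch
# before delegating to the underlying Seq2Slate trainer, e.g.:
#
#     sim_batch = self._simulated_training_input(training_batch)
#     loss = self.trainer.train(sim_batch)  # `self.trainer` is an assumption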