# Source: tzrec/models/dc2vr.py [0:0]
def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
    """Forward the model.

    Pipeline: build grouped input features, apply the optional shared
    ``bottom_mlp``, split into one input per task (via ``mmoe`` if present,
    otherwise the shared representation is reused), run each task's tower
    MLP, apply the per-task intervention, then each task's output layer.

    Args:
        batch (Batch): input batch data.

    Return:
        predictions (dict): a dict of predicted result.

    Raises:
        KeyError: if a tower with ``low_rank_dim`` set lists an
            ``intervention_tower_names`` entry that is not a tower configured
            before it. Intervention outputs are produced in tower-config
            order, so only earlier towers can serve as intervention sources.
    """
    grouped_features = self.build_input(batch)

    net = grouped_features[self.group_name]
    if self.bottom_mlp is not None:
        net = self.bottom_mlp(net)

    # One input per task, indexed by tower position. mmoe is presumably
    # returning a per-task sequence (it is indexed by i below); without it,
    # every tower consumes the same shared tensor.
    if self.mmoe is not None:
        task_input_list = self.mmoe(net)
    else:
        task_input_list = [net] * len(self._task_tower_cfgs)

    # Towers without an entry in task_mlps pass their input through unchanged.
    task_net = {}
    for i, task_tower_cfg in enumerate(self._task_tower_cfgs):
        tower_name = task_tower_cfg.tower_name
        if tower_name in self.task_mlps:
            task_net[tower_name] = self.task_mlps[tower_name](task_input_list[i])
        else:
            task_net[tower_name] = task_input_list[i]

    # Towers with low_rank_dim set are intervened on using the (already
    # intervened) outputs of their source towers; other towers pass through.
    intervention = {}
    for task_tower_cfg in self._task_tower_cfgs:
        tower_name = task_tower_cfg.tower_name
        if not task_tower_cfg.HasField("low_rank_dim"):
            intervention[tower_name] = task_net[tower_name]
            continue
        source_names = list(task_tower_cfg.intervention_tower_names)
        # Sources must already be in `intervention`, i.e. configured earlier.
        missing = [name for name in source_names if name not in intervention]
        if missing:
            raise KeyError(
                f"task tower [{tower_name}] has intervention_tower_names "
                f"{missing} that are not towers configured before it."
            )
        # NOTE(review): torch.cat raises if source_names is empty; presumably
        # config validation guarantees at least one source — confirm.
        intervention_source = torch.cat(
            [intervention[name] for name in source_names], dim=-1
        )
        intervention[tower_name] = self.intervention[tower_name](
            task_net[tower_name], intervention_source
        )

    # task_outputs is indexed by tower position, matching _task_tower_cfgs.
    tower_outputs = {}
    for i, task_tower_cfg in enumerate(self._task_tower_cfgs):
        tower_name = task_tower_cfg.tower_name
        tower_outputs[tower_name] = self.task_outputs[i](intervention[tower_name])
    return self._multi_task_output_to_prediction(tower_outputs)