def predict()

in tzrec/models/dc2vr.py [0:0]


    def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
        """Forward the model.

        Pipeline: build the grouped input features, run the shared bottom
        (optional ``bottom_mlp``, then optional ``mmoe``), apply each task
        tower's MLP when one exists, pass towers configured with
        ``low_rank_dim`` through their intervention module, and finally run
        every tower through its ``task_outputs`` head.

        Args:
            batch (Batch): input batch data.

        Return:
            predictions (dict): a dict of predicted result.

        Raises:
            KeyError: if a tower lists an intervention source tower that is
                not defined before it in the task tower config list.
        """
        grouped_features = self.build_input(batch)

        net = grouped_features[self.group_name]
        if self.bottom_mlp is not None:
            net = self.bottom_mlp(net)

        # Either one MMoE-produced input per task, or the same shared
        # representation reused by every task.
        if self.mmoe is not None:
            task_input_list = self.mmoe(net)
        else:
            task_input_list = [net] * len(self._task_tower_cfgs)

        task_net = {}
        for i, task_tower_cfg in enumerate(self._task_tower_cfgs):
            tower_name = task_tower_cfg.tower_name
            if tower_name in self.task_mlps:
                task_net[tower_name] = self.task_mlps[tower_name](task_input_list[i])
            else:
                task_net[tower_name] = task_input_list[i]

        # Towers are processed in config order, so an intervened tower can
        # only use outputs of towers listed *before* it as its sources.
        intervention = {}
        for task_tower_cfg in self._task_tower_cfgs:
            tower_name = task_tower_cfg.tower_name
            if not task_tower_cfg.HasField("low_rank_dim"):
                intervention[tower_name] = task_net[tower_name]
                continue

            intervention_source = []
            for source_name in task_tower_cfg.intervention_tower_names:
                if source_name not in intervention:
                    raise KeyError(
                        f"intervention source tower '{source_name}' of tower "
                        f"'{tower_name}' must be defined before it in the task "
                        "tower configs."
                    )
                intervention_source.append(intervention[source_name])
            # Source tower outputs are concatenated along the last dimension.
            intervention[tower_name] = self.intervention[tower_name](
                task_net[tower_name], torch.cat(intervention_source, dim=-1)
            )

        tower_outputs = {
            task_tower_cfg.tower_name: self.task_outputs[i](
                intervention[task_tower_cfg.tower_name]
            )
            for i, task_tower_cfg in enumerate(self._task_tower_cfgs)
        }

        return self._multi_task_output_to_prediction(tower_outputs)