# tzrec/modules/extraction_net.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn

from tzrec.modules.mlp import MLP


class ExtractionNet(nn.Module):
    """Multi-gate Mixture-of-Experts extraction module.

    Builds task-specific experts and shared experts, then mixes them with
    per-task gates (and a shared gate when another ExtractionNet follows).

    Args:
        in_extraction_networks (list): input dimension of each task's experts.
        in_shared_expert (int): input dimension of the shared experts.
        network_name (str): name of this ExtractionNet (informational only).
        share_num (int): number of shared experts.
        expert_num_per_task (int): number of experts per task.
        share_expert_net (dict): MLP config of the shared experts.
        task_expert_net (dict): MLP config of each task's experts.
        final_flag (bool): whether this is the last ExtractionNet; if True, the
            shared gate and shared output are omitted.
    """
    def __init__(
        self,
        in_extraction_networks: List[int],
        in_shared_expert: int,
        network_name: str,
        share_num: int,
        expert_num_per_task: int,
        share_expert_net: Dict[str, Any],
        task_expert_net: Dict[str, Any],
        final_flag: bool = False,
    ) -> None:
        super().__init__()
        self.name = network_name
        self._final_flag = final_flag
        self._shared_layers = nn.ModuleList()
        self._shared_gate = None
        self._output_dims = []
        share_net_num = share_num
        per_task_num = expert_num_per_task
        share_output_dim = share_expert_net["hidden_units"][-1]
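        # Build the shared experts; each one is an MLP over the shared input.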
        for _ in range(share_net_num):
            self._shared_layers.append(
                MLP(
                    in_shared_expert,
                    **share_expert_net,
                )
            )
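        # The shared gate mixes all task experts plus the shared experts; it is
        # only needed when another ExtractionNet follows this one.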
        share_gate_output = len(in_extraction_networks) * per_task_num + share_net_num
        if not self._final_flag:
            self._shared_gate = nn.Linear(in_shared_expert, share_gate_output)
        self._task_layers = nn.ModuleList()
        self._task_gates = nn.ModuleList()
        task_gate_output = per_task_num + share_net_num
        task_output_dim = task_expert_net["hidden_units"][-1]
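        # Build the per-task experts and one gate per task; each task gate mixes
        # the task's own experts with the shared experts.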
        for in_feature in in_extraction_networks:
            task_model_list = nn.ModuleList()
            for _ in range(per_task_num):
                task_model_list.append(
                    MLP(
                        in_feature,
                        **task_expert_net,
                    )
                )
            self._task_layers.append(task_model_list)
            self._task_gates.append(nn.Linear(in_feature, task_gate_output))
            self._output_dims.append(task_output_dim)
        self._output_dims.append(share_output_dim)

    def output_dim(self) -> List[int]:
        """Output dimensions of the task experts and the shared experts."""
        return self._output_dims

    def _experts_layer_forward(
        self, deep_fea: torch.Tensor, layers: nn.ModuleList
    ) -> List[torch.Tensor]:
        """Run the input through every expert in `layers`."""
        tower_outputs = []
        for layer in layers:
            output = layer(deep_fea)
            tower_outputs.append(output)
        return tower_outputs

    def _gate_forward(
        self,
        selector_fea: torch.Tensor,
        vec_feas: List[torch.Tensor],
        gate_layer: nn.Module,
    ) -> torch.Tensor:
        """Combine expert outputs using gate weights from `selector_fea`."""
        # [B, num_experts, D]
        vec = torch.stack(vec_feas, dim=1)
        # [B, num_experts] -> softmax -> [B, 1, num_experts]
        gate = gate_layer(selector_fea)
        gate = torch.softmax(gate, dim=1)
        gate = torch.unsqueeze(gate, dim=1)
        # [B, 1, num_experts] @ [B, num_experts, D] -> [B, D]
        output = torch.matmul(gate, vec).squeeze(1)
        return output

    def forward(
        self,
        extraction_network_fea: List[torch.Tensor],
        shared_expert_fea: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], Optional[torch.Tensor]]:
        """Forward the module."""
        shared_expert = self._experts_layer_forward(
            shared_expert_fea, self._shared_layers
        )
        all_task_experts = []
        cgc_layer_outs = []
        for i, task_layers in enumerate(self._task_layers):
            task_experts = self._experts_layer_forward(
                extraction_network_fea[i], task_layers
            )
            cgc_task_out = self._gate_forward(
                extraction_network_fea[i],
                task_experts + shared_expert,
                self._task_gates[i],
            )
            all_task_experts.extend(task_experts)
            cgc_layer_outs.append(cgc_task_out)
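        # The shared output feeds the next ExtractionNet; the final layer has no
        # shared gate, so None is returned for it.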
        shared_layer_out = None
        if self._shared_gate is not None:
            shared_layer_out = self._gate_forward(
                shared_expert_fea, all_task_experts + shared_expert, self._shared_gate
            )
        return cgc_layer_outs, shared_layer_out
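

# A minimal usage sketch (illustrative only, not part of the module). The expert
# configs and dimensions below are assumptions; any config accepted by
# tzrec.modules.mlp.MLP (e.g. "hidden_units") can be used.
if __name__ == "__main__":
    batch_size = 4
    net = ExtractionNet(
        in_extraction_networks=[16, 16],
        in_shared_expert=16,
        network_name="extraction_layer_1",
        share_num=2,
        expert_num_per_task=2,
        share_expert_net={"hidden_units": [32, 8]},
        task_expert_net={"hidden_units": [32, 8]},
        final_flag=False,
    )
    task_feas = [torch.randn(batch_size, 16) for _ in range(2)]
    shared_fea = torch.randn(batch_size, 16)
    task_outs, shared_out = net(task_feas, shared_fea)
    # Each task output is [batch_size, 8]; the shared output is [batch_size, 8].
    print([t.shape for t in task_outs], shared_out.shape)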