tzrec/modules/capsule.py (117 lines of code) (raw):

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Tuple

import torch
from torch import nn

from tzrec.protos.module_pb2 import B2ICapsule


def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor:
    """Create a boolean mask from sequence lengths.

    Args:
        lengths (Tensor): 1-D tensor containing actual sequence lengths.
        max_len (int, optional): max length. If None, max_len is the maximum
            value in lengths.

    Returns:
        mask (Tensor): boolean mask with shape [len(lengths), max_len],
            the first lengths[i] elements of i-th row are True,
            and the rest are False.
    """
    if max_len is None:
        max_len = int(lengths.max().item())
    positions = torch.arange(0, max_len, device=lengths.device)  # [max_len]
    # symbolic tracing trick: expand positions to an explicit
    # [len(lengths), max_len] tensor before comparing against lengths.
    zeros_padding = torch.zeros_like(lengths).unsqueeze(1).tile(1, max_len)
    positions = positions + zeros_padding  # broadcasting
    return positions < lengths.unsqueeze(1)


@torch.fx.wrap
def _init_routing_logits(x: torch.Tensor, k: int) -> torch.Tensor:
    """Draw standard-normal routing logits of shape ``x.shape[:-1] + (k,)``.

    Kept out of fx symbolic tracing (``torch.fx.wrap``) because the output
    shape is derived from ``x.size()``. Matches ``x``'s device and dtype.
    """
    return torch.randn(
        x.size()[:-1] + torch.Size([k]),  # pyre-ignore [58]
        device=x.device,
        dtype=x.dtype,
    )


class CapsuleLayer(nn.Module):
    """Capsule layer.

    Extracts up to ``max_k`` high-level (interest) capsules from a sequence of
    low-level capsules via dynamic routing.

    Args:
        capsule_config (B2ICapsule): capsule config.
        input_dim (int): input dimension.
    """

    def __init__(
        self, capsule_config: B2ICapsule, input_dim: int, *args: Any, **kwargs: Any
    ) -> None:
        """Capsule layer."""
        super().__init__(*args, **kwargs)
        # max_seq_len: max behaviour sequence length (history length)
        self._max_seq_len = capsule_config.max_seq_len
        # max_k: max high capsule number
        self._max_k = capsule_config.max_k
        # high_dim: high capsule vector dimension
        self._high_dim = capsule_config.high_dim
        # low_dim: low capsule vector dimension
        self._low_dim = input_dim
        # number of dynamic routing iterations
        self._num_iters = capsule_config.num_iters
        # multiplier applied to routing logits before each softmax
        self._routing_logits_scale = capsule_config.routing_logits_scale
        # stddev of the random initial routing logits
        self._routing_logits_stddev = capsule_config.routing_logits_stddev
        # exponent used by squash()
        self._squash_pow = capsule_config.squash_pow
        # if set, every user gets exactly max_k capsules
        self._const_caps_num = capsule_config.const_caps_num
        self.bilinear_matrix = nn.Parameter(
            torch.randn(self._low_dim, self._high_dim)
        )  # [ld, hd]

    def squash(self, inputs: torch.Tensor) -> torch.Tensor:
        """Squash inputs over the last dimension.

        Computes ``(|v|^2 / (1 + |v|^2)) ** squash_pow / |v| * v`` where the
        norm ``|v|`` is lower-bounded by 1e-7.

        Args:
            inputs: Tensor, shape: [batch_size, max_k, high_dim]

        Return:
            Tensor, shape: [batch_size, max_k, high_dim]
        """
        input_norm = torch.linalg.norm(inputs, dim=-1, keepdim=True)
        # lower-bound the norm so all-zero vectors do not divide by zero
        input_norm_eps = torch.clamp(input_norm, min=1e-7)
        squared_norm = torch.square(input_norm_eps)
        scale_factor = (
            torch.pow(squared_norm / (1 + squared_norm), self._squash_pow)
            / input_norm_eps
        )
        return scale_factor * inputs

    def dynamic_routing(
        self,
        inputs: torch.Tensor,
        seq_mask: torch.Tensor,
        capsule_mask: torch.Tensor,
        num_iters: int,
    ) -> torch.Tensor:
        """Dynamic routing algorithm.

        Args:
            inputs: Tensor, shape: [batch_size, max_seq_len, low_dim]
            seq_mask: Tensor, shape: [batch_size, max_seq_len]
            capsule_mask: Tensor, shape: [batch_size, max_k]
            num_iters: int, number of iterations

        Return:
            [batch_size, max_k, high_dim]
        """
        assert num_iters > 0, "num_iters should be greater than 0"

        # random initial routing logits, scaled by the configured stddev
        routing_logits = _init_routing_logits(inputs, self._max_k)  # [b, s, k]
        routing_logits = routing_logits.detach() * self._routing_logits_stddev

        capsule_mask = capsule_mask.unsqueeze(1)  # [b, 1, k]
        # +1e32 for valid capsules (torch.minimum leaves logits untouched),
        # -1e32 for masked capsules (their softmax weight becomes ~0)
        capsule_mask_thresh = (capsule_mask.float() * 2 - 1) * 1e32
        # loop-invariant weight that zeroes padded sequence positions
        seq_weight = seq_mask.unsqueeze(2).float()  # [b, s, 1]

        # project low capsules into the high capsule space
        low_capsule_vec = torch.einsum(
            "bsl, lh -> bsh", inputs, self.bilinear_matrix
        )  # [b, s, hd]
        # intermediate iterations use a detached copy, so gradients only
        # flow through the final iteration
        low_capsule_vec_detach = low_capsule_vec.detach()
        low_capsule_vec_detach_norm = torch.nn.functional.normalize(
            low_capsule_vec_detach, p=2.0, dim=-1
        )

        # placeholder binding; always overwritten because num_iters > 0
        high_capsule_vec = torch.Tensor([0])
        for iteration in range(num_iters):
            routing_logits = torch.minimum(routing_logits, capsule_mask_thresh)
            routing_logits = torch.nn.functional.softmax(
                routing_logits * self._routing_logits_scale, dim=2
            )  # [b, s, k]
            routing_logits = routing_logits * seq_weight
            if iteration + 1 < num_iters:
                high_capsule_vec = torch.einsum(
                    "bsh,bsk->bkh", low_capsule_vec_detach, routing_logits
                )
                # NOTE: the agreement term is added to the softmax-ed routing
                # weights (not to the pre-softmax logits) before the next pass.
                routing_logits = routing_logits + torch.einsum(
                    "bkh, bsh -> bsk", high_capsule_vec, low_capsule_vec_detach_norm
                )
            else:
                high_capsule_vec = torch.einsum(
                    "bsh,bsk->bkh", low_capsule_vec, routing_logits
                )
        # only the final (non-detached) high capsules are squashed
        return self.squash(high_capsule_vec)

    def forward(
        self, inputs: torch.Tensor, seq_len: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward method.

        Args:
            inputs: [batch_size, seq_len, low_dim]
            seq_len: [batch_size]

        Return:
            A tuple ``(user_interests, capsule_mask)``:
                user_interests: [batch_size, max_k, high_dim]; vectors of
                    capsules outside capsule_mask are zero.
                capsule_mask: [batch_size, max_k], bool mask of the valid
                    capsules of each sample.
        """
        _, s, _ = inputs.shape
        device = inputs.device

        # truncate or zero-pad the sequence to exactly max_seq_len. This
        # avoids an if-else statement, since symbolically traced values are
        # not allowed in control flow; padding_tensor[:, s:, :] is empty
        # whenever s >= max_seq_len.
        padding_tensor = torch.zeros_like(inputs[:, 0:1, :]).tile(
            1, self._max_seq_len, 1
        )  # [b, max_seq_len, ld]
        inputs = inputs[:, : self._max_seq_len, :]
        inputs = torch.cat([inputs, padding_tensor[:, s:, :]], dim=1)

        seq_mask = sequence_mask(seq_len, self._max_seq_len).to(device)
        inputs = inputs * seq_mask.unsqueeze(-1).float()

        if self._const_caps_num:
            n_high_capsules = (
                torch.zeros_like(seq_len, dtype=torch.float32) + self._max_k
            ).to(device)  # [bs,]
        else:
            # max(1, min(max_k, log2(seq_len))) capsules per sample
            n_high_capsules = (
                torch.log2(seq_len.float())
                .clamp(max=self._max_k)
                .clamp(min=1)
                .to(device)
            )  # [bs,]
        capsule_mask = sequence_mask(n_high_capsules, self._max_k).to(device)

        user_interests = self.dynamic_routing(
            inputs, seq_mask, capsule_mask, self._num_iters
        )
        # zero out capsules beyond each sample's capsule count
        user_interests = user_interests * capsule_mask.unsqueeze(-1).float()
        return user_interests, capsule_mask