# tzrec/modules/intervention.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
class RotateLayer(nn.Module):
    """Applies an orthogonal low-rank transformation.

    Projects inputs from the base space down to a low-rank space via a
    learnable weight matrix that is initialized to be (semi-)orthogonal.
    The orthogonality is only guaranteed at initialization; callers that
    need it preserved during training should wrap this module with
    ``torch.nn.utils.parametrizations.orthogonal``.

    Args:
        base_dim (int): The dimension of the original space.
        low_rank_dim (int): The dimension of the low-rank space.

    Raises:
        ValueError: If ``low_rank_dim`` is not strictly smaller than
            ``base_dim``.
    """

    def __init__(self, base_dim: int, low_rank_dim: int) -> None:
        super().__init__()
        # Validate with an explicit exception instead of `assert`:
        # asserts are stripped when Python runs with -O.
        if base_dim <= low_rank_dim:
            raise ValueError(
                f"low_rank_dim ({low_rank_dim}) should be lower than "
                f"base_dim ({base_dim})"
            )
        self.weight = torch.nn.Parameter(
            torch.empty(base_dim, low_rank_dim), requires_grad=True
        )
        # Semi-orthogonal init: columns of `weight` are orthonormal.
        torch.nn.init.orthogonal_(self.weight)

    def forward(self, base: torch.Tensor) -> torch.Tensor:
        """Project ``base`` into the low-rank space.

        Args:
            base (torch.Tensor): Input with last dimension ``base_dim``;
                cast to the weight's dtype before the matmul.

        Returns:
            torch.Tensor: Projection with last dimension ``low_rank_dim``.
        """
        return torch.matmul(base.to(self.weight.dtype), self.weight)
class Intervention(nn.Module):
    """Deduces the influence of the source on the base representations.

    Base and source are each rotated into a shared low-rank subspace by
    rotation matrices kept orthogonal throughout training via
    ``torch.nn.utils.parametrizations.orthogonal``. The difference of the
    two rotated representations is mapped back into the base space with the
    transpose of the base rotation and added to ``base`` as an edit.

    NOTE(review): the edit direction here is ``rotated_base -
    rotated_source``; ReFT-style interventions often use the opposite sign
    (source minus base) — confirm this sign is intentional.

    Args:
        base_dim (int): The dimension of the base space.
        source_dim (int): The dimension of the source space.
        low_rank_dim (int): The dimension of the low-rank space
            (shared space for the base and source).
        dropout_ratio (float): Dropout rate for the intervened output.
    """

    def __init__(
        self,
        base_dim: int,
        source_dim: int,
        low_rank_dim: int,
        dropout_ratio: float = 0.0,
    ) -> None:
        super().__init__()
        self.base_dim = base_dim
        # The parametrization re-orthogonalizes each rotation after every
        # optimizer step, not just at initialization.
        orthogonalize = torch.nn.utils.parametrizations.orthogonal
        self.base_rotate_layer = orthogonalize(
            RotateLayer(base_dim, low_rank_dim)
        )
        self.source_rotate_layer = orthogonalize(
            RotateLayer(source_dim, low_rank_dim)
        )
        self.dropout = torch.nn.Dropout(dropout_ratio)

    def forward(self, base: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
        """Intervene on ``base`` using ``source``.

        Args:
            base (torch.Tensor): Base representations, last dim ``base_dim``.
            source (torch.Tensor): Source representations, last dim
                ``source_dim``; detached so no gradient flows back into
                the source branch.

        Returns:
            torch.Tensor: Edited representations with the same shape and
                dtype as ``base``.
        """
        rotated_base = self.base_rotate_layer(base)
        rotated_source = self.source_rotate_layer(source.detach())
        delta = rotated_base - rotated_source
        # Map the low-rank edit back to the base space with the transposed
        # base rotation and apply it additively.
        edited = torch.matmul(delta, self.base_rotate_layer.weight.T) + base
        return self.dropout(edited.to(base.dtype))

    def output_dim(self) -> int:
        """Output dimension of the module."""
        return self.base_dim