# tzrec/modules/variational_dropout.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from tzrec.utils.logging_util import logger


@torch.fx.wrap
def _feature_tile(
    feature_p: torch.Tensor,
    feature: torch.Tensor,
) -> Tensor:
    """Tile per-feature dropout probabilities across the batch dimension."""
    return feature_p.tile([feature.size(0), 1])


@torch.fx.wrap
def _update_dict_tensor(
    group_name: str, features: Dict[str, torch.Tensor], new_feature: torch.Tensor
) -> Dict[str, torch.Tensor]:
    """Set the tensor of a feature group in the features dict."""
    features[group_name] = new_feature
    return features
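
# Both helpers are wrapped with @torch.fx.wrap, presumably so that torch.fx
# symbolic tracing treats them as leaf calls instead of tracing through the
# batch-size lookup and the in-place dict update.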


class VariationalDropout(nn.Module):
    """Rank features by variational dropout.

    Args:
        features_dimension (dict): embedding dimension of each feature
            in the group.
        name (str): feature group name.
        regularization_lambda (float): weight of the variational dropout
            regularization loss.
    """

    def __init__(
        self,
        features_dimension: Dict[str, int],
        name: str,
        regularization_lambda: Optional[float] = 0.01,
        **kwargs: Optional[Dict[str, Any]],
    ) -> None:
        super().__init__()
        self.group_name = name
        self.features_dimension = features_dimension
        self._regularization_lambda = regularization_lambda
        # one learnable logit per feature; its sigmoid is the dropout probability.
        self.feature_p = nn.parameter.Parameter(
            torch.randn(len(features_dimension), requires_grad=True)
        )
        # embedding dimension of each feature, used to expand the per-feature
        # mask to the width of the concatenated group embedding.
        self.feature_dim_repeat = nn.parameter.Parameter(
            torch.tensor(list(features_dimension.values()), dtype=torch.int),
            requires_grad=False,
        )
        logger.info(
            f"group name: {name} uses VariationalDropout. "
            f"feature number: {len(features_dimension)}, "
            f"features: {list(features_dimension.keys())}"
        )

    def concrete_dropout_neuron(
        self, dropout_p: torch.Tensor, temp: float = 0.1
    ) -> Tensor:
        """Draw a relaxed, differentiable Bernoulli sample from the dropout probability."""
        EPSILON = torch.finfo(torch.float32).eps
        unif_noise = torch.rand_like(dropout_p)
        # logit of dropout_p plus logistic noise derived from uniform noise.
        approx = (
            torch.log(dropout_p + EPSILON)
            - torch.log(1.0 - dropout_p + EPSILON)
            + torch.log(unif_noise + EPSILON)
            - torch.log(1.0 - unif_noise + EPSILON)
        )
        approx_output = F.sigmoid(approx / temp)
        return approx_output
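
    # Background note: the expression above is the binary Concrete
    # (Gumbel-sigmoid) relaxation used by Concrete Dropout:
    #   z = sigmoid((logit(p) + log u - log(1 - u)) / temp), u ~ Uniform(0, 1),
    # which approaches a hard Bernoulli(p) sample as temp -> 0 while remaining
    # differentiable w.r.t. p, so the dropout probabilities can be learned.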

    def sample_noisy_input(self, feature: Tensor) -> Tensor:
        """Apply the (relaxed) dropout mask to the grouped feature embedding."""
        if self.training:
            # training: sample a relaxed Bernoulli mask per feature and batch row.
            dropout_p = self.feature_p.sigmoid()
            dropout_p = torch.unsqueeze(dropout_p, dim=0)
            dropout_p = _feature_tile(dropout_p, feature)
            bern_val = self.concrete_dropout_neuron(dropout_p)
            bern_val = torch.repeat_interleave(
                bern_val, self.feature_dim_repeat, dim=-1
            )
            noisy_input = feature * (1 - bern_val)
        else:
            # inference: no sampling; scale each feature by its expected
            # keep probability (1 - dropout_p).
            dropout_p = self.feature_p.sigmoid()
            dropout_p = torch.unsqueeze(dropout_p, dim=0)
            dropout_p = _feature_tile(dropout_p, feature)
            dropout_p = torch.repeat_interleave(
                dropout_p, self.feature_dim_repeat, dim=-1
            )
            noisy_input = feature * (1 - dropout_p)
        return noisy_input

    def forward(self, feature: Tensor) -> Tuple[Tensor, Tensor]:
        """Apply variational dropout and compute its regularization loss."""
        noisy_input = self.sample_noisy_input(feature)
        dropout_p = self.feature_p.sigmoid()
        # penalty on the keep probability (1 - p) of every feature.
        variational_dropout_penalty = 1.0 - dropout_p
        sample_num = feature.size(0)
        # pyre-ignore [58]
        variational_dropout_penalty_lambda = self._regularization_lambda / sample_num
        variational_dropout_loss_sum = variational_dropout_penalty_lambda * torch.sum(
            variational_dropout_penalty
        )
        return noisy_input, variational_dropout_loss_sum
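

# A minimal usage sketch, assuming a grouped embedding of shape
# [batch_size, sum of feature dims]; the feature names and dims below are
# illustrative assumptions, not part of this module.
if __name__ == "__main__":
    features_dimension = {"user_id": 8, "item_id": 8, "price": 4}
    module = VariationalDropout(features_dimension, name="deep")
    module.train()
    grouped_emb = torch.randn(32, sum(features_dimension.values()))
    noisy_emb, vd_loss = module(grouped_emb)
    # vd_loss is added to the training loss; after training, lower values of
    # module.feature_p.sigmoid() (the learned dropout probabilities) mark the
    # more important features.
    print(noisy_emb.shape, float(vd_loss))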