tinynn/graph/masker.py (70 lines of code) (raw):

from copy import deepcopy

import torch


class Masker:
    """Manage the masks attached to one module.

    A masker keeps a ``{tensor_name: mask}`` mapping. When the wrapped object
    is a ``torch.nn.Module`` the masker installs itself as a forward pre-hook
    and, while enabled, multiplies each named tensor of the module by its mask
    in place right before every forward call.
    """

    def __init__(self, module, unique_name):
        """Attach a new masker to ``module``.

        Args:
            module: Object to be masked. Only ``torch.nn.Module`` instances
                receive a forward pre-hook and a ``masker`` back-reference.
            unique_name: Identifier returned as the first item of
                :meth:`serialization`.
        """
        self.module = module
        self.unique_name = unique_name
        self.masks = {}
        self.enabled = False

        # Drop the hook of a masker that was attached earlier. The lookup is
        # guarded so that a foreign ``masker`` attribute without a ``hook``
        # does not abort construction with an AttributeError.
        previous_masker = getattr(self.module, "masker", None)
        previous_hook = getattr(previous_masker, "hook", None)
        if previous_hook is not None:
            previous_hook.remove()

        # Install the new hook and expose the masker on the module.
        if isinstance(self.module, torch.nn.Module):
            self.hook = self.module.register_forward_pre_hook(self)
            setattr(module, "masker", self)

    def __call__(self, module, inputs):
        """Forward pre-hook: apply every registered mask when enabled.

        Returns ``None`` in all cases, so the forward inputs are passed
        through unmodified.
        """
        if not self.enabled:
            return

        for tensor_name, mask in self.masks.items():
            target = getattr(module, tensor_name)
            # In-place multiply on ``.data`` so the tensor object stays the same.
            target.data = target.data.mul_(mask)

    def enable(self):
        """Turn mask application on."""
        self.enabled = True

    def disable(self):
        """Turn mask application off."""
        self.enabled = False

    def get_mask(self, tensor_name):
        """Return the mask registered for ``tensor_name`` or ``None``."""
        return self.masks.get(tensor_name, None)

    def register_mask(self, tensor_name, mask):
        """Register a mask to the module, a module can have multiple masks.

        The mask is wrapped in a non-trainable ``torch.nn.Parameter`` and is
        also stored on the module as ``<tensor_name>_mask``. Registering the
        same name again replaces the previous mask.
        """
        wrapped = torch.nn.Parameter(mask, requires_grad=False)
        self.masks[tensor_name] = wrapped
        setattr(self.module, f"{tensor_name}_mask", wrapped)

    def unregister_all(self):
        """Unregister all masks and remove the forward pre-hook."""
        if isinstance(self.module, torch.nn.Module):
            self.hook.remove()

        for tensor_name in self.masks:
            attr_name = f"{tensor_name}_mask"
            # Masks restored through ``deserialization`` were never set on the
            # module, so only delete attributes that actually exist.
            if hasattr(self.module, attr_name):
                delattr(self.module, attr_name)

        self.masks.clear()

    def serialization(self):
        """Return ``(unique_name, [masks])``."""
        return self.unique_name, [self.masks]

    def deserialization(self, value):
        """Restore the mask mapping from the list built by ``serialization``.

        Only ``self.masks`` is replaced; no attribute is set on the module.
        """
        self.masks = value[0]


class ChannelMasker(Masker):
    """Channel-wise module masking"""

    def __init__(self, module, unique_name):
        super().__init__(module, unique_name)

        # Input channel to be deleted
        self.in_remove_idx = None

        # Output channel to be deleted
        self.ot_remove_idx = None

        # Custom channel to be deleted
        self.custom_remove_idx = None

    @staticmethod
    def _merge_remove_idx(current, new_idx):
        """Merge a copy of ``new_idx`` into ``current`` and return the result.

        When ``current`` is ``None`` the copy is returned as is (unsorted);
        otherwise ``current`` is extended in place and then sorted.
        """
        if current is None:
            return deepcopy(new_idx)

        current.extend(deepcopy(new_idx))
        current.sort()
        return current

    def set_in_remove_idx(self, in_remove_idx):
        """Record (or append to) the input channel indices to remove."""
        self.in_remove_idx = self._merge_remove_idx(self.in_remove_idx, in_remove_idx)

    def set_ot_remove_idx(self, ot_remove_idx):
        """Record (or append to) the output channel indices to remove."""
        self.ot_remove_idx = self._merge_remove_idx(self.ot_remove_idx, ot_remove_idx)

    def set_custom_remove_idx(self, custom_remove_idx):
        """Record (or append to) the custom channel indices to remove."""
        self.custom_remove_idx = self._merge_remove_idx(self.custom_remove_idx, custom_remove_idx)

    def serialization(self):
        """Return ``(unique_name, [masks, in_remove_idx, ot_remove_idx])``."""
        # NOTE(review): custom_remove_idx is not part of the serialized state,
        # so it does not survive a round trip - confirm this is intended.
        return self.unique_name, [self.masks, self.in_remove_idx, self.ot_remove_idx]

    def deserialization(self, value):
        """Restore masks and input/output removal indices from ``value``."""
        self.masks = value[0]
        self.in_remove_idx = value[1]
        self.ot_remove_idx = value[2]