tinynn/prune/admm_pruner.py
import os

import torch
import torch.distributed as dist

from tinynn.util.util import get_logger
from tinynn.prune.oneshot_pruner import OneShotChannelPruner

log = get_logger(__name__)


class ADMMPruner(OneShotChannelPruner):
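    """Structured channel pruner based on ADMM (Alternating Direction Method of Multipliers).

    Training alternates for `admm_iterations` rounds between two steps:
      1. Train the model for `admm_epoch` epochs with the original task loss plus an
         augmented-Lagrangian penalty 0.5 * rho * ||W - Z + U||^2 on every prunable
         (center) layer, which pulls the weights W towards their sparse target Z.
      2. Update the ADMM variables: Z is re-projected onto the channel sparsity pattern
         chosen by the importance metric, and the dual variable U accumulates the
         residual W - Z.
    Afterwards the maskers are reset and the regular one-shot channel pruning
    (`OneShotChannelPruner.prune`) removes the selected channels for good.
    """
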
    required_params = ('sparsity', 'metrics', 'admm_iterations', 'admm_epoch', 'rho', 'admm_lr')
    required_context_params = ('val_loader', 'train_loader', 'train_func', 'validate_func', 'optimizer', 'criterion')
    default_values = {'admm_save_freq': 1, 'admm_valid_freq': 1, 'admm_dir': 'admm_train/'}
    context_from_params_dict = {
        'optimizer': ['admm_optimizer', 'optimizer'],
        'criterion': ['admm_criterion', 'criterion'],
    }
    condition_dict = {
        'admm_iterations': lambda x: 0 < x,
        'admm_epoch': lambda x: 0 < x,
        'rho': lambda x: 0 < x < 1,
        'admm_lr': lambda x: 0 < x < 1,
    }

    admm_iterations: int
    admm_epoch: int
    rho: float
    admm_lr: float
    admm_dir: str
    admm_save_freq: int
    admm_valid_freq: int

    def __init__(self, model, dummy_input, config, context):
        super().__init__(model, dummy_input, config)
        self.parse_context(context)

        # ADMM variables, keyed by the unique name of each prunable (center) layer:
        # Z holds the sparse target of the weights, U holds the scaled dual (residual) variable.
        self.Z = {}
        self.U = {}

    def prune(self):
        """Run the ADMM training loop and then perform the final one-shot channel pruning"""

        log.info('Start ADMM training')

        old_criterion = self.context.criterion
        self.register_mask()

        for iteration in range(1, self.admm_iterations + 1):
            for epoch in range(1, self.admm_epoch + 1):
                self.context.epoch = epoch + (iteration - 1) * self.admm_epoch
                self.adjust_learning_rate()
                self.context.criterion = self.construct_admm_criterion(old_criterion)

                if dist.is_available() and dist.is_initialized():
                    self.context.train_loader.sampler.set_epoch(self.context.epoch)

                self.context.train_func(self.model, self.context)

                # Saving and validation are only done on the main process (rank 0) under DDP
                if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
                    if self.context.epoch % self.admm_save_freq == 0:
                        save_path = os.path.join(self.admm_dir, f'epoch_{self.context.epoch}.pth')
                        os.makedirs(os.path.dirname(save_path), exist_ok=True)
                        log.info(f'Saving model to {save_path}')
                        torch.save(self.model.state_dict(), save_path)

                    if self.context.validate_func is not None and self.context.epoch % self.admm_valid_freq == 0:
                        # According to https://github.com/pytorch/pytorch/issues/54059, when validating via DDP,
                        # it needs to be done on the original module.
                        if dist.is_available() and dist.is_initialized():
                            self.context.validate_func(self.model.module, self.context)
                        else:
                            self.context.validate_func(self.model, self.context)

                self.context.best_epoch = self.context.epoch

            # Update Z and U once every `admm_epoch` epochs
            self.admm_params_update()

        self.context.criterion = old_criterion

        # Need to reset the masker before the final pruning
        self.graph_modifier.reset_masker()
        super().prune()

    def adjust_learning_rate(self):
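        """Step-decay LR schedule that restarts at the beginning of every ADMM iteration.

        The LR starts at `admm_lr` and is divided by 10 roughly every third of an ADMM
        iteration. For example, with admm_epoch=9 and admm_lr=1e-3, epochs 1-3 of each
        iteration run at 1e-3, epochs 4-6 at 1e-4 and epochs 7-9 at 1e-5.
        """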
        epoch = self.context.epoch
        admm_epoch = self.admm_epoch

        if (epoch - 1) % admm_epoch == 0:
            # The first epoch of each ADMM iteration restarts from the base LR
            lr = self.admm_lr
        else:
            admm_epoch_offset = (epoch - 1) % admm_epoch

            # The LR is divided by 10 roughly every 1/3 of an ADMM iteration (admm_epoch epochs)
            admm_step = admm_epoch / 3

            lr = self.admm_lr * (0.1 ** (admm_epoch_offset // admm_step))

        for param_group in self.context.optimizer.param_groups:
            param_group['lr'] = lr

    def construct_admm_criterion(self, old_criterion):
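        """Wrap `old_criterion` with the ADMM augmented-Lagrangian penalty.

        The returned criterion computes
            loss = old_criterion(output, target)
                   + sum over center layers of 0.5 * rho * ||W - Z + U||_2^2
        so that gradient descent drags each weight tensor W towards its sparse target Z,
        while the dual variable U keeps track of the accumulated residual.
        """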
        def criterion_func(output, target):
            loss = old_criterion(output, target)
            for n in self.center_nodes:
                if n.unique_name not in self.Z or n.unique_name not in self.U:
                    continue
                loss += (
                    0.5
                    * self.rho
                    * (torch.norm(n.module.weight - self.Z[n.unique_name] + self.U[n.unique_name], p=2) ** 2)
                )
            return loss

        return criterion_func

    def register_mask(self):
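        """Initialize the ADMM variables from the initial one-shot pruning masks.

        Z starts as the masked copy of each prunable center layer's weight and U starts
        at zero. The masks themselves are disabled afterwards, so they only define the
        target sparsity pattern and never zero out the weights during ADMM training.
        """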
        super().register_mask()

        for m in self.graph_modifier.modifiers.values():
            # Disable the mask here so that it is only used to update U and Z.
            # It is never applied, so the ADMM training process is not affected.
            m.disable_mask()

            if m.node in self.center_nodes and m.dim_changes_info.pruned_idx_o:
                device = m.module().weight.data.device
                self.Z[m.unique_name()] = m.module().weight.detach() * m.masker().get_mask('weight').to(device=device)
                self.U[m.unique_name()] = torch.zeros_like(self.Z[m.unique_name()], device=device)

    def admm_params_update(self):
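        """Perform one round of ADMM Z/U updates (called after every `admm_epoch` training epochs).

        For every dependent center layer, this roughly computes
            Z = project(W + U)    # keep the channels ranked highest by the metric
            U = U + (W - Z)       # dual (residual) update
        where the projection is realised by recomputing the channel masks on W + U and
        multiplying Z with the fresh mask. The updated tensors are then broadcast from
        rank 0 so that all DDP workers share the same ADMM state.
        """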
        self.graph_modifier.reset_masker()

        importance = {}

        for sub_graph in self.graph_modifier.sub_graphs.values():
            for m in sub_graph.modifiers:
                if m.node in self.center_nodes and m in sub_graph.dependent_centers:
                    if m.unique_name() not in self.U:
                        continue
                    # Z-update: rank channels of W + U with the importance metric
                    self.Z[m.unique_name()] = (m.module().weight + self.U[m.unique_name()]).detach()
                    importance[m.unique_name()] = self.metric_func(self.Z[m.unique_name()], m.module())

            sub_graph.calc_prune_idx(importance, self.sparsity)
            log.info(f'Pruning indices for subgraph [{sub_graph.center}] have been computed')

        for m in self.graph_modifier.modifiers.values():
            m.register_mask(self.graph_modifier.modifiers, importance, self.sparsity)

            # Disable the mask here so that it is only used to update U and Z.
            # It is never applied, so the ADMM training process is not affected.
            m.disable_mask()

            if m.node in self.center_nodes and m.dim_changes_info.pruned_idx_o:
                weight = m.module().weight
                self.Z[m.unique_name()] = self.Z[m.unique_name()] * m.masker().get_mask('weight')
                # U-update: accumulate the residual between the weights and their sparse target
                self.U[m.unique_name()] = weight - self.Z[m.unique_name()] + self.U[m.unique_name()]

        # Sync the ADMM parameters across DDP workers
        if dist.is_available() and dist.is_initialized():
            for state in (self.U, self.Z):
                for param in state.values():
                    dist.broadcast(param, 0)
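
# A minimal usage sketch (illustrative only; `get_model`, `train_one_epoch`, `validate`,
# the data loaders and the DLContext-style `context` object are assumptions that live
# outside this file):
#
#     model = get_model()
#     context = DLContext()
#     context.train_loader, context.val_loader = train_loader, val_loader
#     context.train_func, context.validate_func = train_one_epoch, validate
#     context.optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
#     context.criterion = torch.nn.CrossEntropyLoss()
#
#     config = {
#         'sparsity': 0.5,
#         'metrics': 'l2_norm',       # channel importance metric
#         'admm_iterations': 10,
#         'admm_epoch': 9,
#         'rho': 1e-3,
#         'admm_lr': 1e-3,
#     }
#     pruner = ADMMPruner(model, torch.randn(1, 3, 224, 224), config, context)
#     pruner.prune()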