# tinynn/prune/netadapt_pruner.py
import copy
import math
import multiprocessing as mp
import os
import shutil
import sys
import typing
import torch
from tinynn.graph import modifier
from tinynn.util.util import get_logger
from tinynn.prune import OneShotChannelPruner
# `concurrent.futures.ProcessPoolExecutor` only accepts the `mp_context`,
# `initializer` and `initargs` arguments since Python 3.7, so the `futures3`
# backport is used on older interpreters
if sys.version_info.major == 3 and sys.version_info.minor < 7:
    from futures3.process import ProcessPoolExecutor
else:
    from concurrent.futures import ProcessPoolExecutor
log = get_logger(__name__)
def device_init(device_ids):
    """Pool-worker initializer that pins each process to its own CUDA device"""
    if torch.cuda.is_available():
        device_id = device_ids.get()
        log.info(f'Init pool process with cuda id {device_id}')
        torch.cuda.set_device(device_id)
class NetAdaptPruner(OneShotChannelPruner):
required_params = (
'budget_type',
'metrics',
'netadapt_max_iter',
'budget_reduce_rate_init',
'budget_reduce_rate_decay',
'netadapt_lr',
)
required_context_params = ('val_loader', 'train_loader', 'train_func', 'validate_func', 'optimizer', 'criterion')
default_values = {'netadapt_dir': 'netadapt_train/', 'netadapt_max_rounds': -1, 'netadapt_min_feature_size': 8}
context_from_params_dict = {
'optimizer': ['netadapt_optimizer', 'optimizer'],
'criterion': ['netadapt_criterion', 'criterion'],
}
condition_dict = {
'netadapt_max_iter': lambda x: 0 < x,
'budget_reduce_rate_init': lambda x: 0 <= x <= 1,
'budget_reduce_rate_decay': lambda x: 0 <= x <= 1,
'netadapt_lr': lambda x: 0 < x < 1,
'netadapt_max_rounds': lambda x: x >= -1,
'budget_type': lambda x: x in ('flops', 'weights', 'latency'),
}
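    # A hypothetical config for illustration (keys follow `required_params` and
    # `condition_dict` above; the values are placeholders, not recommendations):
    #   {
    #       'budget_type': 'flops',
    #       'budget_ratio': 0.5,
    #       'metrics': 'l2_norm',  # must name a function in `tinynn.graph.modifier`
    #       'netadapt_max_iter': 500,
    #       'budget_reduce_rate_init': 0.25,
    #       'budget_reduce_rate_decay': 0.98,
    #       'netadapt_lr': 0.001,
    #   }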
budget: int
budget_ratio: float
budget_type: str
budget_reduce_rate_init: float
budget_reduce_rate_decay: float
netadapt_max_iter: int
netadapt_lr: float
netadapt_max_rounds: int
netadapt_min_feature_size: int
netadapt_dir: str
init_flops: int
def __init__(self, model, dummy_input, config, context):
self.default_sparsity = 0.0
self.original_channels = {}
super().__init__(model, dummy_input, config)
self.parse_context(context)
self.iteration = 1
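        # Record the original output channel count of every center node,
        # which `get_sparsity_state` later uses for sparsity bookkeeping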
for node in self.center_nodes:
self.original_channels[node.unique_name] = node.module.weight.shape[0]
def parse_config(self):
"""Parses the context and copy the needed items to the pruner"""
super(OneShotChannelPruner, self).parse_config()
all_param_keys = list(self.required_params) + list(self.default_values.keys())
for param_key in all_param_keys:
if param_key not in ['sparsity', 'metrics']:
setattr(self, param_key, self.config[param_key])
metrics = self.config['metrics']
if hasattr(modifier, metrics):
self.metric_func = getattr(modifier, metrics)
else:
            raise Exception(f'{metrics} is not a known metric for {type(self).__name__}')
budget = None
budget_ratio = None
if self.budget_type != 'flops':
# TODO: Implement budget type: weights, latency
raise NotImplementedError('Only `budget_type == "flops"` is supported')
if 'budget' in self.config:
budget = self.config['budget']
if not isinstance(budget, int):
raise Exception('The type of `budget` should be int')
if budget < 0:
raise Exception('The value of `budget` doesn\'t meet the requirement: x >= 0')
if 'budget_ratio' in self.config:
budget_ratio = self.config['budget_ratio']
if not isinstance(budget_ratio, float):
raise Exception('The type of `budget_ratio` should be float')
if budget_ratio < 0 or budget_ratio > 1:
raise Exception('The value of `budget_ratio` doesn\'t meet the requirement: 0 <= x <= 1')
if (budget is not None) != (budget_ratio is not None):
self.init_flops = self.calc_flops()
if budget_ratio is not None:
self.budget = int(budget_ratio * self.init_flops)
else:
self.budget = budget
log.info(f'Global Target/Initial FLOPS: {self.budget}/{self.init_flops}')
else:
            raise Exception('Exactly one of `budget` and `budget_ratio` should be defined for NetAdaptPruner')
def generate_config(self, path: str) -> None:
"""Generates a new copy the updated configuration with the given path"""
super(OneShotChannelPruner, self).generate_config(path)
def get_sparsity_state(self) -> typing.Dict[str, float]:
"""Calculate the sparsity of the subgraphs in the original model"""
sparsity = {}
for node in self.center_nodes:
sparsity[node.unique_name] = (
self.original_channels[node.unique_name] - node.module.weight.shape[0]
) / self.original_channels[node.unique_name]
return sparsity
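    # e.g. a center node pruned from 64 to 48 output channels is reported
    # with a sparsity of (64 - 48) / 64 = 0.25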
def get_pruned_subgraph_info(
self, iteration: int, subgraph_id: int, current_flops: float, target_flops: float
) -> typing.Tuple[float, typing.Dict[str, float], int]:
"""Prunes the given subgraph so that the flops of the model <
target flops and finetunes model with the pruned subgraph"""
if torch.cuda.is_available():
self.context.device = torch.device("cuda", torch.cuda.current_device())
else:
self.context.device = torch.device("cpu")
# Copies the model, otherwise the one in the main process will be updated as well
self.model = copy.deepcopy(self.model)
self.model.eval()
self.reset()
log.info(f'Processing subgraph index {subgraph_id} at iteration: {iteration}, device: {self.context.device}')
sparsity, new_flops = self.find_prune_plan(subgraph_id, current_flops, target_flops)
if len(sparsity) == 0:
# Cannot find a plan, early stop
log.warning(f'Subgraph: {subgraph_id}, Iteration: {iteration}, cannot find a plan')
return -1, {}, -1
# Apply the mask to get a pruned model
super().apply_mask()
# Regenerate the optimizer, since the model has changed
self.context.optimizer = type(self.context.optimizer)(
self.model.parameters(), **self.context.optimizer.defaults
)
# Use fork to speed up data loading in child processes
mp_context = mp.get_context('fork')
self.context.train_loader.multiprocessing_context = mp_context
self.context.val_loader.multiprocessing_context = mp_context
num_iter_per_epoch = len(self.context.train_loader)
max_epoch = math.ceil(self.netadapt_max_iter / num_iter_per_epoch)
max_iter_last_epoch = (self.netadapt_max_iter - 1) % num_iter_per_epoch + 1
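        # e.g. netadapt_max_iter=500 with 128 iterations per epoch schedules 4 epochs,
        # the last of which stops after (500 - 1) % 128 + 1 = 116 iterations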
for epoch in range(1, max_epoch + 1):
self.context.epoch = epoch
if epoch == max_epoch:
old_max_iter = self.context.max_iteration
self.context.max_iteration = max_iter_last_epoch
self.context.train_func(self.model, self.context)
if epoch == max_epoch:
self.context.max_iteration = old_max_iter
save_path = os.path.join(self.netadapt_dir, f'iter_{iteration}_subgraph_{subgraph_id}.pth')
os.makedirs(os.path.dirname(save_path), exist_ok=True)
log.info("Saving model to {}".format(save_path))
torch.save(self.model.state_dict(), save_path)
acc = self.context.validate_func(self.model, self.context)
self.context.best_epoch = self.context.epoch
log.info(f'Subgraph: {subgraph_id}, Iteration: {iteration}, FLOPS: {new_flops}, Accuracy: {acc}')
del self.model
return acc, sparsity, new_flops
def find_prune_plan(
self, subgraph_id: int, pre_flops: float, target_flops: float
) -> typing.Tuple[typing.Dict[str, float], int]:
"""Figures out the best plan to prune the given subgraph so that the flops of the model < target flops"""
nodes = []
for m in self.graph_modifier.sub_graphs[subgraph_id]:
if m.node in self.center_nodes and m.output_modify_:
nodes.append(m.node)
num_out_channels = nodes[0].next_tensors[0].shape[1]
# All possible number of channels according to `netadapt_min_feature_size`
candidate_channels = list(
range(
num_out_channels // self.netadapt_min_feature_size * self.netadapt_min_feature_size,
self.netadapt_min_feature_size - 1,
-self.netadapt_min_feature_size,
)
)
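        # e.g. num_out_channels=50 with netadapt_min_feature_size=8 yields
        # candidate_channels = [48, 40, 32, 24, 16, 8]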
# Possible sparsity for pruning the model
sparsity_per_step = list(
map(
lambda x: (num_out_channels - x[0]) / num_out_channels,
zip(candidate_channels[1:], candidate_channels[:-1]),
)
)
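        # continuing the example above: sparsity_per_step =
        # [(50 - c) / 50 for c in (40, 32, 24, 16, 8)] = [0.2, 0.36, 0.52, 0.68, 0.84]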
# Init sparsity dictionary
self.sparsity = self.sparsity.fromkeys(self.sparsity, 0.0)
post_flops = pre_flops
diff_flops = None
possible_diff_flops = None
for idx, s in enumerate(sparsity_per_step):
for node in nodes:
self.sparsity[node.unique_name] = s
            # Fast-forward when the per-step FLOPs delta is already known
            if diff_flops is not None:
                post_flops -= diff_flops
                if post_flops >= target_flops:
                    continue
target_out_channels = candidate_channels[idx + 1]
for node in nodes:
self.sparsity[node.unique_name] = (num_out_channels - target_out_channels) / num_out_channels
# Reset masks
if idx != 0:
self.graph_modifier.unregister_masker()
self.graph_modifier.reset_masker()
# Get the updated masks
super().register_mask()
# Update flops for the current graph
if diff_flops is None:
next_flops = self.calc_flops()
if idx == 1:
                    # Skip further calculations if FLOPs vary in proportion to the output channel size
num_cur_flops = post_flops
num_next_flops = next_flops
if (
num_cur_flops % candidate_channels[idx] == 0
and num_next_flops % candidate_channels[idx + 1] == 0
and num_cur_flops // candidate_channels[idx] == num_next_flops // candidate_channels[idx + 1]
):
diff_flops = num_cur_flops - num_next_flops
else:
possible_diff_flops = num_cur_flops - num_next_flops
elif idx == 2:
                    # Skip further calculations if FLOPs(n) - FLOPs(n - 1) is constant
if post_flops - next_flops == possible_diff_flops:
diff_flops = possible_diff_flops
possible_diff_flops = None
post_flops = next_flops
log.info(
f'Subgraph: {subgraph_id}, '
f'Channels: {candidate_channels[idx + 1]}/{num_out_channels}, '
f'FLOPS(pre/post/target): {pre_flops}/{post_flops}/{target_flops:.2f}'
)
# Early stop if we get the desired sparsity
if post_flops < target_flops:
sparsity = copy.deepcopy(self.sparsity)
return sparsity, post_flops
return {}, -1
def prune_subgraph(self, sparsity: typing.Dict[str, float]) -> None:
"""Prunes the model with the sparsity dictionary given"""
self.sparsity = copy.deepcopy(sparsity)
super().prune()
def prune(self):
"""The main function for pruning"""
        # PyTorch forbids initializing CUDA in forked subprocesses,
        # so `spawn` is used here instead
mp_context = mp.get_context('spawn')
if torch.cuda.is_available():
max_workers = torch.cuda.device_count()
else:
max_workers = 1
available_cores = mp_context.Queue()
for i in range(max_workers):
available_cores.put(i)
with ProcessPoolExecutor(
max_workers=max_workers, mp_context=mp_context, initializer=device_init, initargs=(available_cores,)
) as pool:
cur_flops = self.calc_flops()
while (
self.netadapt_max_rounds == -1 or self.iteration <= self.netadapt_max_rounds
) and cur_flops > self.budget:
log.info(f"Start iteration {self.iteration}")
# Acquire target FLOPs for the current ratio
target_flops = cur_flops - self.budget_reduce_rate_init * cur_flops * (
self.budget_reduce_rate_decay ** (self.iteration - 1)
)
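                # e.g. budget_reduce_rate_init=0.25 and budget_reduce_rate_decay=0.98 target a
                # 25% FLOPs reduction at iteration 1, 24.5% at iteration 2, and so on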
# Regenerate modifier, graph and center nodes before sending to child processes
if self.iteration != 1:
self.reset()
# Prepare jobs for child processes
num_subgraphs = len(self.graph_modifier.sub_graphs)
results = pool.map(
self.get_pruned_subgraph_info,
[self.iteration] * num_subgraphs,
range(num_subgraphs),
[cur_flops] * num_subgraphs,
[target_flops] * num_subgraphs,
)
# Get the step we should take by comparing the best accuracy among all sub jobs
max_idx, (best_acc, best_sparsity, best_flops) = max(enumerate(results), key=lambda x: x[1][0])
if len(best_sparsity) == 0:
log.error('All subgraphs yield invalid result, stopping')
break
# Sync model back to the main process
self.prune_subgraph(best_sparsity)
load_path = os.path.join(self.netadapt_dir, f'iter_{self.iteration}_subgraph_{max_idx}.pth')
self.model.load_state_dict(torch.load(load_path, map_location='cpu'))
global_sparsity = self.get_sparsity_state()
global_sparsity['best_flops'] = best_flops
global_sparsity['iteration'] = self.iteration
sparsity_path = os.path.join(self.netadapt_dir, f'iter_{self.iteration}_sparsity.yml')
super(OneShotChannelPruner, self).generate_config(sparsity_path, global_sparsity)
save_path = os.path.join(self.netadapt_dir, f'iter_{self.iteration}.pth')
shutil.copy(load_path, save_path)
# Print info and update current flops
log.info(f'Iter: {self.iteration}, Acc: {best_acc}, FLOPS: {best_flops}, Subgraph: {max_idx}')
cur_flops = best_flops
self.iteration += 1
def restore(self, iteration: int) -> torch.nn.Module:
"""Restores a model at specific iteration"""
sparsity_path = os.path.join(self.netadapt_dir, f'iter_{iteration}_sparsity.yml')
weights_path = os.path.join(self.netadapt_dir, f'iter_{iteration}.pth')
for path in (sparsity_path, weights_path):
if not os.path.exists(path):
raise FileNotFoundError(f'{path} is required for restoring the model')
config = self.load_config(sparsity_path)
self.prune_subgraph(config)
self.model.load_state_dict(torch.load(weights_path, map_location='cpu'))
self.iteration = config['iteration'] + 1
log.info(f"Restored from iteration {iteration}")
return self.model
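# A minimal usage sketch, assuming an image-classification setup. `MyModel`,
# `train_one_epoch` and `evaluate` are hypothetical placeholders, and `config`
# mirrors the example near `required_params` above:
#
#   from tinynn.util.train_util import DLContext
#
#   model = MyModel()
#   dummy_input = torch.randn(1, 3, 224, 224)
#
#   context = DLContext()
#   context.train_loader, context.val_loader = ...  # your DataLoaders
#   context.train_func, context.validate_func = train_one_epoch, evaluate
#   context.optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
#   context.criterion = torch.nn.CrossEntropyLoss()
#
#   pruner = NetAdaptPruner(model, dummy_input, config, context)
#   pruner.prune()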