Dassl.pytorch/dassl/utils/torchtools.py

""" Modified from https://github.com/KaiyangZhou/deep-person-reid """ import pickle import shutil import os.path as osp import warnings from functools import partial from collections import OrderedDict import torch import torch.nn as nn from .tools import mkdir_if_missing __all__ = [ "save_checkpoint", "load_checkpoint", "resume_from_checkpoint", "open_all_layers", "open_specified_layers", "count_num_param", "load_pretrained_weights", "init_network_weights", ] def save_checkpoint( state, save_dir, is_best=False, remove_module_from_keys=True, model_name="" ): r"""Save checkpoint. Args: state (dict): dictionary. save_dir (str): directory to save checkpoint. is_best (bool, optional): if True, this checkpoint will be copied and named ``model-best.pth.tar``. Default is False. remove_module_from_keys (bool, optional): whether to remove "module." from layer names. Default is True. model_name (str, optional): model name to save. """ mkdir_if_missing(save_dir) if remove_module_from_keys: # remove 'module.' in state_dict's keys state_dict = state["state_dict"] new_state_dict = OrderedDict() for k, v in state_dict.items(): if k.startswith("module."): k = k[7:] new_state_dict[k] = v state["state_dict"] = new_state_dict # save model epoch = state["epoch"] if not model_name: model_name = "model.pth.tar-" + str(epoch) fpath = osp.join(save_dir, model_name) torch.save(state, fpath) print(f"Checkpoint saved to {fpath}") # save current model name checkpoint_file = osp.join(save_dir, "checkpoint") checkpoint = open(checkpoint_file, "w+") checkpoint.write("{}\n".format(osp.basename(fpath))) checkpoint.close() if is_best: best_fpath = osp.join(osp.dirname(fpath), "model-best.pth.tar") shutil.copy(fpath, best_fpath) print('Best checkpoint saved to "{}"'.format(best_fpath)) def load_checkpoint(fpath): r"""Load checkpoint. ``UnicodeDecodeError`` can be well handled, which means python2-saved files can be read from python3. Args: fpath (str): path to checkpoint. Returns: dict Examples:: >>> fpath = 'log/my_model/model.pth.tar-10' >>> checkpoint = load_checkpoint(fpath) """ if fpath is None: raise ValueError("File path is None") if not osp.exists(fpath): raise FileNotFoundError('File is not found at "{}"'.format(fpath)) map_location = None if torch.cuda.is_available() else "cpu" try: checkpoint = torch.load(fpath, map_location=map_location) except UnicodeDecodeError: pickle.load = partial(pickle.load, encoding="latin1") pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1") checkpoint = torch.load( fpath, pickle_module=pickle, map_location=map_location ) except Exception: print('Unable to load checkpoint from "{}"'.format(fpath)) raise return checkpoint def resume_from_checkpoint(fdir, model, optimizer=None, scheduler=None): r"""Resume training from a checkpoint. This will load (1) model weights and (2) ``state_dict`` of optimizer if ``optimizer`` is not None. Args: fdir (str): directory where the model was saved. model (nn.Module): model. optimizer (Optimizer, optional): an Optimizer. scheduler (Scheduler, optional): an Scheduler. Returns: int: start_epoch. 


def resume_from_checkpoint(fdir, model, optimizer=None, scheduler=None):
    r"""Resume training from a checkpoint.

    This will load (1) model weights and (2) ``state_dict`` of
    optimizer if ``optimizer`` is not None.

    Args:
        fdir (str): directory where the model was saved.
        model (nn.Module): model.
        optimizer (Optimizer, optional): an Optimizer.
        scheduler (Scheduler, optional): a Scheduler.

    Returns:
        int: start_epoch.

    Examples::
        >>> fdir = 'log/my_model'
        >>> start_epoch = resume_from_checkpoint(fdir, model, optimizer, scheduler)
    """
    with open(osp.join(fdir, "checkpoint"), "r") as checkpoint:
        model_name = checkpoint.readlines()[0].strip("\n")
    fpath = osp.join(fdir, model_name)

    print('Loading checkpoint from "{}"'.format(fpath))
    checkpoint = load_checkpoint(fpath)
    model.load_state_dict(checkpoint["state_dict"])
    print("Loaded model weights")

    if optimizer is not None and "optimizer" in checkpoint.keys():
        optimizer.load_state_dict(checkpoint["optimizer"])
        print("Loaded optimizer")

    if scheduler is not None and "scheduler" in checkpoint.keys():
        scheduler.load_state_dict(checkpoint["scheduler"])
        print("Loaded scheduler")

    start_epoch = checkpoint["epoch"]
    print("Previous epoch: {}".format(start_epoch))

    return start_epoch


def adjust_learning_rate(
    optimizer,
    base_lr,
    epoch,
    stepsize=20,
    gamma=0.1,
    linear_decay=False,
    final_lr=0,
    max_epoch=100,
):
    r"""Adjust learning rate.

    Deprecated.
    """
    if linear_decay:
        # linearly decay learning rate from base_lr to final_lr
        frac_done = epoch / max_epoch
        lr = frac_done*final_lr + (1.0-frac_done) * base_lr
    else:
        # decay learning rate by gamma for every stepsize
        lr = base_lr * (gamma**(epoch // stepsize))

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def set_bn_to_eval(m):
    r"""Set BatchNorm layers to eval mode."""
    # 1. no update for running mean and var
    # 2. scale and shift parameters are still trainable
    classname = m.__class__.__name__
    if classname.find("BatchNorm") != -1:
        m.eval()


def open_all_layers(model):
    r"""Open all layers in model for training.

    Examples::
        >>> open_all_layers(model)
    """
    model.train()
    for p in model.parameters():
        p.requires_grad = True


def open_specified_layers(model, open_layers):
    r"""Open specified layers in model for training while keeping
    other layers frozen.

    Args:
        model (nn.Module): neural net model.
        open_layers (str or list): layers open for training.

    Examples::
        >>> # Only model.classifier will be updated.
        >>> open_layers = 'classifier'
        >>> open_specified_layers(model, open_layers)
        >>> # Only model.fc and model.classifier will be updated.
        >>> open_layers = ['fc', 'classifier']
        >>> open_specified_layers(model, open_layers)
    """
    if isinstance(model, nn.DataParallel):
        model = model.module

    if isinstance(open_layers, str):
        open_layers = [open_layers]

    for layer in open_layers:
        assert hasattr(model, layer), f"{layer} is not an attribute of the model"

    for name, module in model.named_children():
        if name in open_layers:
            module.train()
            for p in module.parameters():
                p.requires_grad = True
        else:
            module.eval()
            for p in module.parameters():
                p.requires_grad = False


def count_num_param(model=None, params=None):
    r"""Count number of parameters in a model.

    Args:
        model (nn.Module): network model.
        params: network model's params.

    Examples::
        >>> model_size = count_num_param(model)
    """
    if model is not None:
        return sum(p.numel() for p in model.parameters())

    if params is not None:
        s = 0
        for p in params:
            if isinstance(p, dict):
                s += p["params"].numel()
            else:
                s += p.numel()
        return s

    raise ValueError("At least one of model and params must be provided.")
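

# A two-stage fine-tuning sketch built on open_specified_layers / open_all_layers
# (illustrative only; `model`, `max_epoch`, `fixbase_epochs` and the 'classifier'
# layer name are hypothetical, not defined in this module):
#
#   >>> for epoch in range(max_epoch):
#   ...     if epoch < fixbase_epochs:
#   ...         open_specified_layers(model, "classifier")  # train the head only
#   ...     else:
#   ...         open_all_layers(model)  # afterwards, fine-tune the whole network
#   ...     ...  # run one training epoch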


def load_pretrained_weights(model, weight_path):
    r"""Load pretrained weights to model.

    Features::
        - Incompatible layers (unmatched in name or size) will be ignored.
        - Can automatically deal with keys containing "module.".

    Args:
        model (nn.Module): network model.
        weight_path (str): path to pretrained weights.

    Examples::
        >>> weight_path = 'log/my_model/model-best.pth.tar'
        >>> load_pretrained_weights(model, weight_path)
    """
    checkpoint = load_checkpoint(weight_path)
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        if k.startswith("module."):
            k = k[7:]  # discard module.

        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        warnings.warn(
            f"Cannot load {weight_path} (check the key names manually)"
        )
    else:
        print(f"Successfully loaded pretrained weights from {weight_path}")
        if len(discarded_layers) > 0:
            print(
                f"Layers discarded due to unmatched keys or size: {discarded_layers}"
            )


def init_network_weights(model, init_type="normal", gain=0.02):
    r"""Initialize network weights.

    Conv/Linear weights are initialized according to ``init_type``
    ("normal", "xavier", "kaiming" or "orthogonal") and biases are set
    to zero; BatchNorm/InstanceNorm affine parameters are set to 1 and 0.
    """

    def _init_func(m):
        classname = m.__class__.__name__

        if hasattr(m, "weight") and (
            classname.find("Conv") != -1 or classname.find("Linear") != -1
        ):
            if init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, gain)
            elif init_type == "xavier":
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == "kaiming":
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                nn.init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError
            if hasattr(m, "bias") and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)

        elif classname.find("BatchNorm") != -1:
            nn.init.constant_(m.weight.data, 1.0)
            nn.init.constant_(m.bias.data, 0.0)

        elif classname.find("InstanceNorm") != -1:
            if m.weight is not None and m.bias is not None:
                nn.init.constant_(m.weight.data, 1.0)
                nn.init.constant_(m.bias.data, 0.0)

    model.apply(_init_func)
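

# init_network_weights usage sketch (illustrative only; `net` is a hypothetical
# nn.Module built here just for the example):
#
#   >>> net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
#   >>> init_network_weights(net, init_type="kaiming")
#   >>> count_num_param(net)  # number of parameters in the small example net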