tinynn/util/train_util.py (130 lines of code) (raw):

import os import typing import torch from .util import get_logger log = get_logger(__name__, 'INFO') class DLContext(object): def __init__( self, train_loader=None, train_sampler=None, val_loader=None, criterion=None, optimizer=None, scheduler=None, iter_scheduler=None, warmup_scheduler=None, epoch=0, max_epoch=None, iteration=0, warmup_iteration=0, max_iteration=None, gpu=None, device=None, grad_scaler=None, print_freq=50, train_func=None, validate_fun=None, custom_args: dict = None, ): self.train_loader = train_loader self.train_sampler = train_sampler self.val_loader = val_loader self.criterion = criterion self.optimizer = optimizer self.scheduler = scheduler self.iter_scheduler = iter_scheduler self.warmup_scheduler = warmup_scheduler self.epoch = epoch self.max_epoch = max_epoch self.iteration = iteration self.warmup_iteration = warmup_iteration self.max_iteration = max_iteration self.gpu = gpu self.device = device self.grad_scaler = grad_scaler self.print_freq = print_freq self.best_acc = 0.0 self.best_epoch = -1 self.train_func = train_func self.validate_func = validate_fun self.custom_args = custom_args class AverageMeter(object): """ Computes and stores the average and current value """ def __init__(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def train( model, context: DLContext, train_func: typing.Callable[[torch.nn.Module, DLContext], None], validate_func: typing.Callable[[torch.nn.Module, DLContext], float], distributed: bool = False, main_process: bool = True, qat: bool = False, work_dir: str = None, ): """The main function for the whole train process Args: model: The model to be trained context (DLContext): the DLContext object train_func (typing.Callable[[torch.nn.Module, DLContext], None]): The function to train the model by one step validate_func (typing.Callable[[torch.nn.Module, DLContext], float]): The function to get the \ accuracy of the model distributed (bool, optional): Whether DDP training is used. Defaults to False. main_process (bool, optional): Whether the code runs in the main process. Defaults to True. qat (bool, optional): Whether to perform quantization-aware training. Defaults to False. work_dir (str, optional): Working directory. Defaults to None, which means "out". """ if work_dir is None: work_dir = 'out' os.makedirs(work_dir, exist_ok=True) if isinstance(context.criterion, torch.nn.Module): context.criterion = context.criterion.to(device=context.device) best_acc = 0 for i in range(context.max_epoch): context.epoch = i # calling `set_epoch` is required in distributd training if distributed: context.train_loader.sampler.set_epoch(i) train_func(model, context) # qat specific if qat: if context.epoch == context.max_epoch // 3: log.info("freeze quantizer parameters") model.apply(torch.quantization.disable_observer) elif context.epoch == context.max_epoch // 3 * 2: log.info("freeze batch norm mean and variance estimates") model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) # only validate and save model in the main process if main_process: if distributed: # According to https://github.com/pytorch/pytorch/issues/54059, when validating via DDP, # it needs to be done on the original module. acc = validate_func(model.module, context) else: acc = validate_func(model, context) if qat: if context.epoch == context.max_epoch - 1: save_path = os.path.join(work_dir, 'qat_last_model.pth') torch.save(model.state_dict(), save_path) else: if acc > best_acc: best_acc = acc log.info(f"Best acc: {best_acc}") save_path = os.path.join(work_dir, 'best_model.pth') torch.save(model, save_path) # only validate the final model in the main process if main_process: if not qat: load_path = os.path.join(work_dir, 'best_model.pth') model.load_state_dict(torch.load(load_path).state_dict()) validate_func(model, context) def get_device() -> torch.device: """Gets the default device Returns: [torch.device]: The default device """ if torch.cuda.is_available(): device = torch.device("cuda", 0) else: device = torch.device("cpu") return device def get_module_device(module: torch.nn.Module) -> typing.Optional[torch.device]: """Gets the device of the module Args: module (torch.nn.Module): The given module Returns: typing.Optional[torch.device]: The device of the module """ assert isinstance(module, torch.nn.Module) device = None try: first_param = next(module.parameters()) device = first_param.device except StopIteration: pass return device