jukebox/utils/fp16.py (238 lines of code) (raw):

# Utils for fp16 training.
import importlib
import math
import numpy as np
import torch
import jukebox.utils.dist_adapter as dist
from torch.optim import Optimizer
from torch._utils import _flatten_dense_tensors
from jukebox.utils.dist_utils import allreduce


def adam_step(p: torch.Tensor, out_p: torch.Tensor, exp_avg: torch.Tensor, exp_avg_sq: torch.Tensor,
              grad: torch.Tensor, lr: float, beta1: float, beta2: float, eps: float, scale: float,
              step: int, eps_mode: int, bias_correction: int, weight_decay: float):
    """Apply one Adam update in pure PyTorch (fallback when apex is unavailable).

    Takes the same positional arguments as apex's ``fused_adam_cuda.adam`` so either can
    be bound to ``fused_adam_step``. ``p``, ``exp_avg`` and ``exp_avg_sq`` are updated in
    place; ``grad`` itself is left unmodified.

    Args:
        p: parameter tensor, updated in place.
        out_p: accepted for signature compatibility only; ignored here.
        exp_avg: first-moment buffer, updated in place.
        exp_avg_sq: second-moment buffer, updated in place.
        grad: gradient, divided by ``scale`` (in an fp32 copy) before use.
        lr, beta1, beta2: usual Adam hyper-parameters.
        eps: added to ``sqrt(exp_avg_sq)`` (i.e. outside the square root).
        scale: factor the gradient is divided by (loss scale and/or clipping factor).
        step: step count used for bias correction (callers in this module pass 1 on
            the first update).
        eps_mode: only ``1`` (eps outside the sqrt) is supported.
        bias_correction: only ``1`` (enabled) is supported.
        weight_decay: coefficient of the ``weight_decay * p`` term that is added to the
            Adam direction and scaled by the same step size.

    Raises:
        NotImplementedError: if ``eps_mode`` or ``bias_correction`` is not ``1``.
    """
    # Explicit errors rather than asserts: asserts vanish under `python -O`, which would
    # silently compute a different update than the one requested.
    if bias_correction != 1:
        raise NotImplementedError("adam_step fallback only supports bias_correction=1")
    if eps_mode != 1:
        raise NotImplementedError("adam_step fallback only supports eps_mode=1 (eps outside sqrt)")
    # `Tensor.float()` returns the tensor itself when it is already fp32, so an in-place
    # division would rescale the caller's `p.grad`. Take an explicit fp32 copy instead.
    grad = grad.to(dtype=torch.float32, copy=True)
    grad.div_(scale)
    # Decay the first and second moment running average coefficient
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = exp_avg_sq.sqrt().add_(eps)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    step_size = lr * math.sqrt(bias_correction2) / bias_correction1
    p.add_(exp_avg / denom + weight_decay * p.float(), alpha=-step_size)


# Import fused_adam if we have apex, otherwise use regular adam
try:
    fused_adam_cuda = importlib.import_module("fused_adam_cuda")
    fused_adam_step = fused_adam_cuda.adam
    print("Using apex fused_adam_cuda")
except ModuleNotFoundError:
    fused_adam_step = adam_step


def backward(loss, params, scalar, fp16, logger):
    """Run the backward pass, optionally with dynamic loss scaling.

    Args:
        loss: scalar loss tensor.
        params: iterable of parameters whose gradient norm is reported.
        scalar: loss-scale holder exposing ``get_scale()`` / ``update_scale(overflow)``
            (e.g. ``LossScalar``). Only used when ``fp16`` is true.
        fp16: if false, a plain ``loss.backward()`` with scale 1.0 is performed.
        logger: object with a ``rank`` attribute; warnings are printed only on rank 0.

    Returns:
        ``(loss, scale, grad_norm, overflow_loss, overflow_grad)``. In the fp16 path the
        returned loss is detached and un-scaled. On a forward overflow, backward and the
        scale update are skipped, the returned grad norm is ``0.0`` and both overflow
        flags are true.
    """
    if not fp16:
        scale = 1.0
        loss.backward()
        gn = grad_norm(params, scale)
        return loss, scale, gn, False, False

    scale = scalar.get_scale()
    loss = loss.float() * scale
    # Overflow flags go through a MAX all-reduce, so the result is True if any
    # participant saw an overflow (presumably keeping all workers in agreement).
    overflow_loss = check_overflow(loss.item())
    overflow_loss = allreduce(int(overflow_loss), op=dist.ReduceOp.MAX) > 0
    if not overflow_loss:
        loss.backward()
        gn = grad_norm(params, scale)
        overflow_grad = check_overflow(gn)
        overflow_grad = allreduce(int(overflow_grad), op=dist.ReduceOp.MAX) > 0
        scalar.update_scale(overflow_grad)
    else:
        gn = 0.0
        overflow_grad = True
    loss = loss.detach().float() / scale  # Should delete computation graph for overflow
    if logger.rank == 0:
        if loss > 12.:
            print(f"\nWarning. Loss is {loss}")
        if overflow_loss:
            print(f"\nOverflow in forward. Loss {loss}, lgscale {np.log2(scale)}. Skipping batch completely (no backward, scale update)")
        elif overflow_grad:
            print(f"\nOverflow in backward. Loss {loss}, grad norm {gn}, lgscale {np.log2(scale)}, new lgscale {np.log2(scalar.get_scale())}")
    return loss, scale, gn, overflow_loss, overflow_grad


# Automatic loss scaling
class LossScalar(object):
    """Static or dynamic loss scale.

    Passing ``loss_scale=None`` enables dynamic scaling starting at ``init_scale``:
    - an overflow halves the scale;
    - every ``scale_window`` consecutive non-overflow updates multiply it by
      ``scale_factor``, capped at ``max_loss_scale`` (2**24).

    Any other ``loss_scale`` value is used as a fixed scale.
    """

    def __init__(self, loss_scale, init_scale=2. ** 16, scale_factor=2. ** (1. / 1000), scale_window=1):
        if loss_scale is None:
            # Use dynamic loss scaling
            self.dynamic = True
            self.loss_scale = init_scale
        else:
            self.dynamic = False
            self.loss_scale = loss_scale
        self.max_loss_scale = 2. ** 24
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.unskipped = 0  # consecutive updates since the last overflow / growth
        self.overflow = False

    def get_scale(self):
        """Return the current loss scale."""
        return self.loss_scale

    def update_scale(self, overflow):
        """Adjust the scale after a step; only has an effect when scaling is dynamic."""
        if overflow and self.dynamic:
            self.loss_scale /= 2.
            self.unskipped = 0
        else:
            self.unskipped += 1
        if self.unskipped == self.scale_window and self.dynamic:
            self.loss_scale = min(self.max_loss_scale, self.loss_scale * self.scale_factor)
            self.unskipped = 0


def check_overflow(val):
    """Return True if ``val`` is +inf, -inf or NaN."""
    return not math.isfinite(val)


def grad_norm(params, scale, flat=False):
    """Return the global L2 norm of all gradients divided by ``scale``.

    Parameters without a gradient are ignored. With ``flat=True``, fp16 and non-fp16
    gradients are each flattened into one buffer (faster, more memory). Otherwise
    per-tensor norms are accumulated (slightly slower, less memory). fp16 norms are
    computed in fp32.
    """
    params = list(params)
    if flat:
        # Faster but more memory
        fp16_grads = [p.grad for p in params if p.grad is not None and p.data.dtype == torch.float16]
        fp16_norm = 0.0 if len(fp16_grads) == 0 else float(_flatten_dense_tensors(fp16_grads).norm(p=2, dtype=torch.float32))
        fp32_grads = [p.grad for p in params if p.grad is not None and p.data.dtype != torch.float16]
        fp32_norm = 0.0 if len(fp32_grads) == 0 else float(_flatten_dense_tensors(fp32_grads).norm(p=2))
        total_norm = (fp16_norm ** 2 + fp32_norm ** 2) ** 0.5
    else:
        # Slightly slower but less memory
        total_sq = 0.0
        for p in params:
            if p.grad is not None:
                total_sq += p.grad.norm(p=2, dtype=torch.float32) ** 2
        total_norm = float(total_sq ** 0.5)
    return total_norm / scale


def clipped_grad_scale(grad_norm, max_grad_norm, scale):
    """Return ``scale`` enlarged so that dividing gradients by it clips their norm.

    If ``grad_norm`` exceeds ``max_grad_norm``, the scale is multiplied by
    ``grad_norm / max_grad_norm``. Otherwise it is returned unchanged.
    """
    clip = grad_norm / max_grad_norm
    if clip > 1:
        scale = clip * scale
    return scale


class FP16FusedAdam(Optimizer):
    """Adam that stores the moment buffers of fp16 parameters in fp16.

    For fp16 parameters, ``exp_avg`` / ``exp_avg_sq`` are stored as half tensors divided
    by a per-tensor factor (``scale_exp_avg`` / ``scale_exp_avg_sq``). The factor is
    chosen after each step so the largest magnitude maps to about ``FLOAT16_MAX``. The
    buffers are expanded back to fp32 for every update. Other parameters keep their
    moments in the parameter's own dtype.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        bias_correction=True,
        betas=(0.9, 0.999),
        eps=1e-8,
        eps_inside_sqrt=False,
        weight_decay=0.0,
        amsgrad=False,
    ):
        if amsgrad:
            raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
        defaults = dict(
            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay
        )
        super().__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1
        self.FLOAT16_MAX = 65504.0  # largest finite fp16 value
        self.init_state()

    def init_state(self):
        """Eagerly create optimizer state for every parameter in every group."""
        for group in self.param_groups:
            for p in group["params"]:
                assert p.requires_grad
                self._init_state_for(p)

    def _init_state_for(self, p):
        """Create zeroed Adam state for ``p`` if it has none yet."""
        state = self.state[p]
        if len(state) == 0:
            state["step"] = 0
            # Exponential moving average of gradient values
            state["exp_avg"] = torch.zeros_like(p.data)
            # Exponential moving average of squared gradient values
            state["exp_avg_sq"] = torch.zeros_like(p.data)
            if p.data.dtype == torch.float16:
                state["scale_exp_avg"] = 1.0
                state["scale_exp_avg_sq"] = 1.0

    def step(self, closure=None, scale=1.0):
        """Performs a single optimization step. Scales gradients down by scale
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            scale (float, optional): factor to divide gradient tensor values
                by before applying to weights. (default: 1)
        """
        loss = None
        if closure is not None:
            loss = closure()

        # Empty placeholder for the kernel's `out_p` argument (ignored by the fallback).
        out_p = torch.tensor([], dtype=torch.float)
        for group in self.param_groups:
            bias_correction = 1 if group["bias_correction"] else 0
            beta1, beta2 = group["betas"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                # Also covers parameters added after construction via add_param_group().
                self._init_state_for(p)
                state = self.state[p]
                is_fp16 = p.data.dtype == torch.float16
                if is_fp16:
                    # Expand the stored half-precision moments back to fp32.
                    exp_avg = state["exp_avg"].float() * state["scale_exp_avg"]
                    exp_avg_sq = state["exp_avg_sq"].float() * state["scale_exp_avg_sq"]
                else:
                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                state["step"] += 1
                fused_adam_step(
                    p.data,
                    out_p,
                    exp_avg,
                    exp_avg_sq,
                    grad,
                    group["lr"],
                    beta1,
                    beta2,
                    group["eps"],
                    scale,
                    state["step"],
                    self.eps_mode,
                    bias_correction,
                    group["weight_decay"],
                )
                if is_fp16:
                    # Re-compress: pick factors so max |value| / factor stays near
                    # FLOAT16_MAX; 1e-8 avoids division by zero for all-zero buffers.
                    state["scale_exp_avg"] = (
                        1e-8 + float(torch.norm(exp_avg, float("inf"))) / self.FLOAT16_MAX
                    )
                    state["scale_exp_avg_sq"] = (
                        1e-8 + float(torch.norm(exp_avg_sq, float("inf"))) / self.FLOAT16_MAX
                    )
                    state["exp_avg"] = (exp_avg / state["scale_exp_avg"]).half()
                    state["exp_avg_sq"] = (exp_avg_sq / state["scale_exp_avg_sq"]).half()
        return loss


class FusedAdam(Optimizer):
    """Adam with fp32 moment buffers, using apex's fused kernel when available."""

    def __init__(
        self,
        params,
        lr=1e-3,
        bias_correction=True,
        betas=(0.9, 0.999),
        eps=1e-8,
        eps_inside_sqrt=False,
        weight_decay=0.0,
        amsgrad=False,
    ):
        if amsgrad:
            raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
        defaults = dict(
            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay
        )
        super().__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1

    def step(self, closure=None, scale=1.0):
        """Performs a single optimization step. Scales gradients down by scale
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            scale (float, optional): factor to divide gradient tensor values
                by before applying to weights. (default: 1)
        """
        loss = None
        if closure is not None:
            loss = closure()

        # Empty placeholder for the kernel's `out_p` argument (ignored by the fallback).
        out_p = torch.tensor([], dtype=torch.float)
        for group in self.param_groups:
            bias_correction = 1 if group["bias_correction"] else 0
            beta1, beta2 = group["betas"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]
                # State initialization (moments always kept in fp32)
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data, dtype=torch.float32)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data, dtype=torch.float32)
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                state["step"] += 1
                fused_adam_step(
                    p.data,
                    out_p,
                    exp_avg,
                    exp_avg_sq,
                    grad,
                    group["lr"],
                    beta1,
                    beta2,
                    group["eps"],
                    scale,
                    state["step"],
                    self.eps_mode,
                    bias_correction,
                    group["weight_decay"],
                )
        return loss