lib/torch_util.py
import functools
import itertools
import math
import os
import pickle
import re
import subprocess
import tempfile
from contextlib import contextmanager
from hashlib import md5, sha1
import numpy as np
import torch as th
import torch.distributed as dist
import torch.distributions as dis
import torch.nn.functional as F
from torch import nn
import lib.tree_util as tree_util
from lib import misc
def contextmanager_to_decorator(cm):
    """
    Turn a context manager factory (e.g. th.no_grad) into a decorator that runs
    the decorated function entirely inside that context.
    """
    def decorator(fn):
        @functools.wraps(fn)
        def newfn(*args, **kwargs):
            with cm():
                return fn(*args, **kwargs)
        return newfn
    return decorator
def have_cuda():
    # Check for a usable GPU at runtime; th.has_cuda only reports whether torch
    # was built with CUDA support.
    return th.cuda.is_available()
def default_device_type():
return "cuda" if have_cuda() else "cpu"
no_grad = contextmanager_to_decorator(th.no_grad)
DEFAULT_DEVICE = th.device(type=default_device_type())
def set_default_torch_device(device):
global DEFAULT_DEVICE
DEFAULT_DEVICE = th.device(device)
def dev():
return DEFAULT_DEVICE
def zeros(*args, **kwargs):
return th.zeros(*args, **kwargs, device=dev())
def ones(*args, **kwargs):
return th.ones(*args, **kwargs, device=dev())
def arange(*args, **kwargs):
return th.arange(*args, **kwargs, device=dev())
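# Usage sketch for the device helpers above (illustrative, not part of the
# original module):
#
#   set_default_torch_device("cpu")      # override the default, e.g. for tests
#   x = zeros(2, 3)                      # same as th.zeros(2, 3, device=dev())
#   t = arange(5, dtype=th.int64)        # extra kwargs pass through to torch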
def NormedLinear(*args, scale=1.0, dtype=th.float32, **kwargs):
"""
nn.Linear but with normalized fan-in init
"""
dtype = parse_dtype(dtype)
if dtype == th.float32:
out = nn.Linear(*args, **kwargs)
elif dtype == th.float16:
out = LinearF16(*args, **kwargs)
else:
raise ValueError(dtype)
out.weight.data *= scale / out.weight.norm(dim=1, p=2, keepdim=True)
if kwargs.get("bias", True):
out.bias.data *= 0
return out
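# Usage sketch (illustrative): each row of the weight is initialized to an L2
# norm equal to `scale`, and the bias (when present) starts at zero.
#
#   layer = NormedLinear(512, 256, scale=0.1)
#   assert th.allclose(layer.weight.data.norm(dim=1), th.full((256,), 0.1), atol=1e-5)
#   assert th.all(layer.bias.data == 0)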
class LinearF16(nn.Linear):
def forward(self, x):
return F.linear(x, self.weight.half(), self.bias.half() if self.bias is not None else None)
class LayerNormF16(nn.LayerNorm):
def forward(self, x):
return F.layer_norm(x, self.normalized_shape, self.weight.half(), self.bias.half(), self.eps)
def LayerNorm(*args, dtype=th.float32, **kwargs):
dtype = parse_dtype(dtype)
if dtype == th.float32:
out = nn.LayerNorm(*args, **kwargs)
elif dtype == th.float16:
out = LayerNormF16(*args, **kwargs)
else:
raise ValueError(dtype)
out.weight.no_scale = True
return out
def flatten_image(x):
"""
Flattens last three dims
"""
*batch_shape, h, w, c = x.shape
return x.reshape((*batch_shape, h * w * c))
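# Usage sketch (illustrative): leading batch dims are kept, the trailing
# (H, W, C) dims are merged into a single feature axis.
#
#   frames = th.zeros(8, 16, 64, 64, 3)   # (batch, time, H, W, C)
#   flat = flatten_image(frames)          # shape (8, 16, 64 * 64 * 3)
#   assert flat.shape == (8, 16, 12288)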
def sequential(layers, x, *args, diag_name=None, use_checkpoint=False):
    """
    Apply layers to x in order. diag_name and use_checkpoint are accepted for
    interface compatibility but are not used in this implementation.
    """
    for layer in layers:
        x = layer(x, *args)
    return x
@no_grad
def load_average_with_metadata(paths, overrides):
n_models = len(paths)
model, metadata = load_with_metadata(paths[0], overrides=overrides)
for p in model.parameters():
p.mul_(1 / n_models)
    for path in paths[1:]:
        new_model, _ = load_with_metadata(path, overrides=overrides)
        for (n1, p1), (n2, p2) in misc.safezip(model.named_parameters(), new_model.named_parameters()):
            assert n1 == n2, f"names {n1} and {n2} don't match"
            p1.add_(p2, alpha=1 / n_models)
return model, metadata
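# Usage sketch (illustrative; the checkpoint paths are hypothetical and
# load_with_metadata is assumed to be defined elsewhere in this library):
#
#   model, metadata = load_average_with_metadata(
#       ["run1/model.pt", "run2/model.pt"], overrides={}
#   )
#   # model's parameters are the elementwise average of the two checkpoints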
def save_kwargs(fn):
"""
This decorator passes through the user-provided kwargs and adds one more, called
save_kwargs, mapping to {"create_fn" : name_of_decorated_fn, "kwargs" : other_kwargs}
You put on this decorator on a function that creates a pytorch module. This will
save the kwargs and the function that was used to create the module.
This lets us restore the model state later.
"""
@functools.wraps(fn)
def wrapper(**kwargs):
if "save_kwargs" in kwargs:
return fn(**kwargs)
else:
sk = {**kwargs, "create_fn": f"{fn.__module__}:{fn.__name__}"}
return fn(save_kwargs=sk, **kwargs)
return wrapper
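# Usage sketch (illustrative; make_mlp is a hypothetical factory, not part of
# this module):
#
#   @save_kwargs
#   def make_mlp(*, insize, outsize, save_kwargs=None):
#       model = nn.Sequential(NormedLinear(insize, 64), nn.ReLU(), NormedLinear(64, outsize))
#       model.save_kwargs = save_kwargs   # stash creation info for later re-creation
#       return model
#
#   mlp = make_mlp(insize=10, outsize=2)
#   # mlp.save_kwargs == {"insize": 10, "outsize": 2, "create_fn": "<module>:make_mlp"}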
def parse_dtype(x):
if isinstance(x, th.dtype):
return x
elif isinstance(x, str):
if x == "float32" or x == "float":
return th.float32
elif x == "float64" or x == "double":
return th.float64
elif x == "float16" or x == "half":
return th.float16
elif x == "uint8":
return th.uint8
elif x == "int8":
return th.int8
elif x == "int16" or x == "short":
return th.int16
elif x == "int32" or x == "int":
return th.int32
elif x == "int64" or x == "long":
return th.int64
elif x == "bool":
return th.bool
else:
raise ValueError(f"cannot parse {x} as a dtype")
else:
raise TypeError(f"cannot parse {type(x)} as dtype")
def index(x, i):
"""
Batched, broadcasting index of x along dimension i.ndim.
For example, if x has shape (1, 2, 3, 4, 5) and i has shape (1, 1, 3)
then the result has shape (1, 2, 3, 5) and each value in i must be between 0 and 3.
"""
assert x.ndim >= i.ndim + 1
gather_dim = i.ndim
while i.ndim < x.ndim:
i = i.unsqueeze(-1)
expand_shape = list(x.shape)
expand_shape[gather_dim] = 1
i = i.expand(*expand_shape)
xi = th.gather(x, gather_dim, i)
assert xi.shape[gather_dim] == 1
return xi.squeeze(gather_dim)
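# Usage sketch (illustrative): the case from index()'s docstring.
#
#   x = th.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5)
#   i = th.zeros(1, 1, 3, dtype=th.int64)   # values index x's dimension of size 4
#   out = index(x, i)
#   assert out.shape == (1, 2, 3, 5)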