# utils/pipeline_utils.py
import torch


def cudagraph(f):
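    """Decorator that caches a torch.cuda.CUDAGraph per set of tensor-kwarg
    shapes: the first call with a given shape key records the graph, and
    subsequent calls with matching shapes replay it."""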
    from torch.utils._pytree import tree_map_only

    _graphs = {}

    def f_(*args, **kwargs):
        # key the graph cache on the shapes of all tensor kwargs
        key = hash(tuple(tuple(kwargs[a].shape) for a in sorted(kwargs.keys())
                         if isinstance(kwargs[a], torch.Tensor)))
        if key in _graphs:
            # use the cached wrapper if one exists. this will perform CUDAGraph replay
            wrapped, *_ = _graphs[key]
            return wrapped(*args, **kwargs)
        # record a new CUDAGraph and cache it for future use
        g = torch.cuda.CUDAGraph()
        in_args, in_kwargs = tree_map_only(torch.Tensor, lambda t: t.clone(), (args, kwargs))
        f(*in_args, **in_kwargs)  # stream warmup
        with torch.cuda.graph(g):
            out_tensors = f(*in_args, **in_kwargs)

        def wrapped(*args, **kwargs):
            # note that CUDAGraphs require inputs / outputs to be in fixed memory locations.
            # inputs must be copied into the fixed input memory locations.
            [a.copy_(b) for a, b in zip(in_args, args) if isinstance(a, torch.Tensor)]
            for k in kwargs:
                if isinstance(kwargs[k], torch.Tensor):
                    in_kwargs[k].copy_(kwargs[k])
            g.replay()
            # clone() outputs on the way out to disconnect them from the fixed output memory
            # locations. this allows for CUDAGraph reuse without accidentally overwriting memory
            return [o.clone() for o in out_tensors]

        # cache function that does CUDAGraph replay
        _graphs[key] = (wrapped, g, in_args, in_kwargs, out_tensors)
        return wrapped(*args, **kwargs)

    return f_
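

# Minimal usage sketch (illustrative, not part of the original module): assumes a
# CUDA device is available and that the decorated function takes its tensors as
# keyword arguments and returns a list of tensors, matching what `wrapped` expects.
if __name__ == "__main__":
    @cudagraph
    def scaled_add(*, x, y):
        # toy element-wise op; returned as a list so the clone() loop above applies
        return [x * 2.0 + y]

    a = torch.randn(1024, device="cuda")
    b = torch.randn(1024, device="cuda")
    first = scaled_add(x=a, y=b)   # first call with this shape key: records the graph
    second = scaled_add(x=a, y=b)  # same shapes: replays the cached graph
    torch.testing.assert_close(first[0], second[0])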