torchbenchmark/util/backends/cudagraph.py (23 lines of code) (raw):
import torch
from typing import Tuple
def enable_cudagraph(
    model: 'torchbenchmark.util.model.BenchmarkModel',
    example_inputs: Tuple[torch.Tensor, ...],
    *,
    warmup_iters: int = 3,
) -> None:
    """Capture one full training step of ``model`` into a CUDA graph.

    The step is: forward through ``model.model``, loss via ``model.loss_fn``
    against ``model.example_outputs``, backward, and ``model.optimizer.step()``.
    The step is first run eagerly ``warmup_iters`` times on a side stream,
    then recorded once into a ``torch.cuda.CUDAGraph`` which is stored on
    ``model.g``.

    Args:
        model: Object exposing ``model``, ``optimizer``, ``loss_fn`` and
            ``example_outputs`` attributes. Gains a ``g`` attribute holding
            the captured graph.
        example_inputs: Positional inputs that are star-unpacked into
            ``model.model``. These exact tensors are the ones used during
            capture.
        warmup_iters: Number of eager training steps to run before capture.
            Must be at least 1. Defaults to 3 (the previous fixed value).

    Raises:
        ValueError: If ``warmup_iters`` is smaller than 1.
        RuntimeError: If no CUDA device is available.
    """
    # Validate cheap preconditions up front so that failures are explicit
    # instead of surfacing as an obscure error in the middle of capture.
    if warmup_iters < 1:
        raise ValueError(
            f"warmup_iters must be >= 1, got {warmup_iters}"
        )
    if not torch.cuda.is_available():
        raise RuntimeError("enable_cudagraph requires an available CUDA device")

    optimizer = model.optimizer
    loss_fn = model.loss_fn

    # warmup
    # Run the eager iterations on a side stream. The side stream first waits
    # for work already queued on the current stream, and the current stream
    # waits for the warmup to finish before capture starts.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(warmup_iters):
            optimizer.zero_grad(set_to_none=True)
            y_pred = model.model(*example_inputs)
            loss = loss_fn(y_pred, model.example_outputs)
            loss.backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)

    # capture
    g = torch.cuda.CUDAGraph()
    # Gradients are reset to None before capture, so the backward pass
    # recorded below starts from unset .grad attributes.
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        static_y_pred = model.model(*example_inputs)
        static_loss = loss_fn(static_y_pred, model.example_outputs)
        static_loss.backward()
        optimizer.step()

    # Expose the graph on the model; presumably the benchmark harness replays
    # it via ``model.g.replay()`` -- confirm against the caller.
    model.g = g