# sample_workloads/lit-gpt-demo/utilities/nsight_callbacks.py
import torch
import sys
from typing import Any
import nvtx
class NsightCallback:
    """Training-loop callback that drives NVIDIA Nsight Systems capture.

    Starts the CUDA profiler on the first microbatch of every Nth global
    (optimizer) step and stops it after that step's last microbatch, so one
    full global step is captured per profiling window.  Also wraps the
    backward pass in an NVTX range and prints a heartbeat (with peak CUDA
    memory) at the end of every global step.
    """

    def __init__(self, nsys_profile_step_multiple: int = 5) -> None:
        # Capture every Nth global step; default matches the previously
        # hard-coded value of 5.
        self.nsys_profile_step_multiple = nsys_profile_step_multiple
        # Handle returned by nvtx.start_range(); None while no range is open.
        self.backward_nvtx_range = None

    def on_train_batch_start(
        self, batch_idx: int, gradient_accumulation_steps: int
    ) -> None:
        """Start Nsys capture on the first microbatch of a profiled step.

        NOTE: true division is deliberate — ``batch_idx /
        gradient_accumulation_steps`` is an exact multiple of
        ``nsys_profile_step_multiple`` only on the *first* microbatch of
        that global step, so cudaProfilerStart fires once per profiled
        step instead of once per microbatch.
        """
        global_batch_idx = batch_idx / gradient_accumulation_steps
        if (
            global_batch_idx > 0
            and global_batch_idx % self.nsys_profile_step_multiple == 0
        ):
            print("Starting Nsys profiling")
            torch.cuda.cudart().cudaProfilerStart()

    def on_train_batch_end(
        self, batch_idx: int, gradient_accumulation_steps: int
    ) -> None:
        """Stop Nsys capture after a profiled step and log a heartbeat."""
        global_batch_idx = batch_idx // gradient_accumulation_steps
        global_batch_offset = batch_idx % gradient_accumulation_steps
        is_last_microbatch = global_batch_offset == gradient_accumulation_steps - 1
        if (
            global_batch_idx > 1
            and global_batch_idx % self.nsys_profile_step_multiple == 0
            and is_last_microbatch
        ):
            print("Stopping Nsys profiling")
            torch.cuda.cudart().cudaProfilerStop()
        if is_last_microbatch:
            print(f"HEARTBEAT: {global_batch_idx=}, {batch_idx=}")
            print(
                f"Max memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
            )
            # Make progress visible immediately when stdout/stderr are
            # block-buffered (e.g. redirected to a log file).
            sys.stdout.flush()
            sys.stderr.flush()

    def on_before_backward(self) -> None:
        """Open an NVTX range covering the backward pass."""
        self.backward_nvtx_range = nvtx.start_range(message="backward", color="red")

    def on_after_backward(self) -> None:
        """Close the NVTX range opened in ``on_before_backward``.

        Uses an explicit ``is not None`` check (an opaque handle of 0 would
        be falsy and leak the range) and clears the handle so a stray second
        call cannot end the same range twice.
        """
        if self.backward_nvtx_range is not None:
            nvtx.end_range(self.backward_nvtx_range)
            self.backward_nvtx_range = None

    def on_train_epoch_start(self) -> None:
        """Reset CUDA peak-memory stats so per-epoch maxima are accurate."""
        print("Resetting max memory allocation")
        torch.cuda.reset_peak_memory_stats()