def on_train_batch_start()

in sample_workloads/lit-gpt-demo/utilities/nsight_callbacks.py [0:0]


    def on_train_batch_start(self, batch_idx: int, gradient_accumulation_steps: int) -> None:
        """Start the CUDA profiler (for Nsys capture) on selected batches.

        The profiler is started only when ``batch_idx`` is a positive multiple of
        ``gradient_accumulation_steps`` and the resulting step index
        ``batch_idx // gradient_accumulation_steps`` is a multiple of
        ``self.nsys_profile_step_multiple``. Presumably ``batch_idx`` counts
        micro-batches, so this fires on the first micro-batch of every
        ``nsys_profile_step_multiple``-th optimizer step — confirm against caller.

        Args:
            batch_idx: Index of the current (micro-)batch.
            gradient_accumulation_steps: Number of batches per accumulated step;
                expected to be positive.

        Raises:
            ZeroDivisionError: If ``gradient_accumulation_steps`` is 0.
        """
        # Explicit integer boundary check. The earlier float division
        # (batch_idx / steps) only skipped mid-accumulation batches implicitly,
        # because e.g. 2.5 % 2 != 0; integer arithmetic states the intent directly.
        # The modulo is evaluated first so a zero step count still raises.
        is_step_boundary = batch_idx % gradient_accumulation_steps == 0
        if batch_idx <= 0 or not is_step_boundary:
            return

        global_batch_idx = batch_idx // gradient_accumulation_steps
        if global_batch_idx % self.nsys_profile_step_multiple == 0:
            print("Starting Nsys profiling")
            torch.cuda.cudart().cudaProfilerStart()