in benchmarks/experimental/offload.py [0:0]
def train(model_config, model, benchmark_config, model_specs, args):
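    """Run a short training loop for the offload benchmark and return (wps, last_loss).

    Expected inputs (inferred from the accesses below):
      - model_config["data"]: tuple whose first element is the LM dataloader
      - model_config["optimizer"]: optimizer constructor, called with model.parameters()
      - benchmark_config["criterion"]: loss function
      - model_specs["vocab_size"], model_specs["clip_value"]
      - args.use_profiler: enable profiling and chrome trace export
    """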
    device = torch.device("cuda")
    torch.cuda.set_device(0)

    lm_dataloader, _, _ = model_config["data"]
    criterion = benchmark_config["criterion"]
    vocab_size = model_specs["vocab_size"]
    optimizer = model_config["optimizer"]

    model.train()
    log_number_of_parameters(model)

    total_loss = 0.0
    word_counter = 0

    optimizer = optimizer(model.parameters())

    total_tokens = 0
    total_tokens_per_log_interval = 0
    bptt = 2
    start_time = time.time()
    epoch_start_time = 0.0
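
    # Next-token targets: the target for position t is the token at position t + 1.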
    def get_batch(source):
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target
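
    # Main loop: run a handful of batches, timing tokens/sec from the second batch onwards.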
    for i, batch in enumerate(lm_dataloader):
        # TODO(anj): Make this a flag for both "lm" and "seq" models.
        if i == 5:
            break

        if i == 1:
            epoch_start_time = time.time()

        source, target = get_batch(batch)
        source, target = source.cuda(), target.cuda()

        if i > 0:
            total_tokens += source.numel()
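
        # Run one optimizer step; when args.use_profiler is set, each phase is recorded
        # under its own label so FW/loss/BW/step show up separately in the trace.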
        with _get_profiler_context(args.use_profiler) as prof:
            optimizer.zero_grad()
            with _get_profiler_record_context("FW pass", args.use_profiler):
                output = model(source)
            with _get_profiler_record_context("Loss", args.use_profiler):
                loss = criterion(output.view(-1, vocab_size), target.view(-1))
            with _get_profiler_record_context("BW pass", args.use_profiler):
                loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
            with _get_profiler_record_context("Opt step", args.use_profiler):
                optimizer.step()

        total_loss += loss.item()
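
        # Report words-per-second, average loss, and perplexity for the completed interval.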
        log_interval = 1
        total_tokens_per_log_interval += source.numel()
        if i % log_interval == 0 and i > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print(
                "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                    i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                )
            )
            total_tokens_per_log_interval = 0
            total_loss = 0
            start_time = time.time()

        if args.use_profiler:
            prof.export_chrome_trace("/tmp/offload_prof")
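
    # Throughput excludes batch 0, which is treated as warm-up (timing starts at i == 1).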
    if epoch_start_time != 0:
        wps = total_tokens / (time.time() - epoch_start_time)
    else:
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size of the dataset and rerun the benchmark."
        )
    return wps, loss.item()