in benchmarks/pipe.py [0:0]
def train(model_config, model, benchmark_config, model_specs, args):
    """Run the pipe benchmark training loop and measure throughput.

    The optimizer is built by calling ``model_config["optimizer"]`` with
    ``model.parameters()``. Each batch is split into an input and a target
    shifted by one position along the first dimension. The batch is then run
    forward and backward, gradients are clipped by value to
    ``model_specs["clip_value"]``, and the optimizer is stepped.

    When ``model`` has a ``group`` attribute, it is treated as one stage of a
    pipeline in that process group:

    * only the first stage moves inputs to ``get_device(model, 0)``;
    * only the last stage computes the loss and calls ``loss.backward()``;
    * the other stages call ``model.back_helper(output)`` instead;
    * intermediate stages iterate over synthetic dataloaders.

    Args:
        model_config: mapping with ``"dataset_info"`` and ``"optimizer"``
            (a callable accepting the model parameters).
        model: the model to train.
        benchmark_config: mapping with ``"criterion"``.
        model_specs: mapping with ``"vocab_size"`` and ``"clip_value"``.
        args: parsed arguments. When ``args.max_batch`` is truthy, the loop
            stops once the batch index exceeds it.

    Returns:
        ``(words_per_second, last_loss)`` on the last global rank and
        ``(0.0, 0.0)`` on every other rank. Throughput excludes batch 0,
        which is treated as warm-up.

    Raises:
        RuntimeError: if the forward pass raises (the original exception is
            chained), or if the dataloader yields fewer than two batches.
    """
    lm_dataloader, _, _ = utils.get_data_loader(model_config["dataset_info"], args, benchmark_config, model_specs)
    criterion = benchmark_config["criterion"]
    vocab_size = model_specs["vocab_size"]
    optimizer_factory = model_config["optimizer"]

    model.train()
    utils.log_number_of_parameters(model)

    optimizer = optimizer_factory(model.parameters())

    pipe_group = model.group if hasattr(model, "group") else None
    # A process group's rank and size do not change during training, so the
    # stage roles are resolved once instead of on every batch.
    is_first_stage = pipe_group is None or pipe_group.rank() == 0
    is_last_stage = pipe_group is None or pipe_group.rank() == pipe_group.size() - 1

    # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
    # Intermediate pipeline stages are handed synthetic batches instead of the real dataset.
    if not is_first_stage and not is_last_stage:
        lm_dataloader, _, _ = get_synthetic_dataloaders(args, benchmark_config, model_specs)

    log_interval = 1
    total_loss = 0.0
    total_tokens = 0
    total_tokens_per_log_interval = 0
    # Number of batches folded into `total_loss` since the last report.
    batches_in_interval = 0
    start_time = time.time()
    epoch_start_time = 0.0

    def get_batch(source):
        # The target is the input shifted by one position along the first dimension.
        seq_len = len(source) - 1
        data = source[0:seq_len]
        target = source[1 : 1 + seq_len]
        return data, target

    for i, batch in enumerate(lm_dataloader):
        if i == 1:
            # Batch 0 is treated as warm-up and excluded from the throughput figure.
            epoch_start_time = time.time()
        if args.max_batch and i > args.max_batch:
            break

        source, target = get_batch(batch)
        if i > 0:
            total_tokens += source.numel()

        optimizer.zero_grad()
        try:
            if is_first_stage:
                output = model(source.to(get_device(model, 0)))
            else:
                output = model(source)
        except Exception as e:
            raise RuntimeError(f"training failed on {dist.get_rank()}") from e

        if is_last_stage:
            target = target.to(get_device(model, -1))
            output = output.to(target.device)
            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            loss.backward()
            del target
        else:
            model.back_helper(output)
        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
        optimizer.step()

        if is_last_stage:
            total_loss += loss.item()
            total_tokens_per_log_interval += source.numel()
            batches_in_interval += 1
            if i % log_interval == 0 and i > 0:
                # Average over the batches actually accumulated. Batch 0 is never
                # reported on its own, so the first interval holds one batch more
                # than `log_interval`. Dividing by `log_interval` would overstate
                # the loss (and hence the perplexity) in that report.
                cur_loss = total_loss / batches_in_interval
                elapsed = time.time() - start_time
                if dist.get_rank() == dist.get_world_size() - 1:
                    logging.debug(
                        "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                            i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                        )
                    )
                total_tokens_per_log_interval = 0
                total_loss = 0.0
                batches_in_interval = 0
                start_time = time.time()

    if epoch_start_time == 0:
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size " " of the dataset and rerun the benchmark."
        )
    wps = total_tokens / (time.time() - epoch_start_time)

    if dist.get_rank() == dist.get_world_size() - 1:
        # NOTE(review): assumes the last global rank is also the last pipeline
        # stage; otherwise `loss` was never assigned on this rank.
        return wps, loss.item()
    return 0.0, 0.0