in training/distributed_training/pytorch/model_parallel/bert/bert_example/sagemaker_smp_pretrain.py [0:0]
def main():
    global timeout_sent

    args = parse_arguments()
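    # Seed every RNG with a per-rank offset so each data-parallel worker shuffles and
    # masks data differently while the overall run stays reproducible.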
    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)

    # Prepare optimizer
    (
        model,
        optimizer,
        lr_scheduler,
        checkpoint,
        global_step,
        criterion,
    ) = prepare_model_and_optimizer(args, device)

    raw_train_start = None
    most_recent_ckpts_paths = []
    average_loss = 0.0  # averaged loss every args.log_freq steps
    epoch = 0
    training_steps = 0
    test_losses = []
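    # Single-worker process pool used to load the next training shard in the background
    # while the current shard is being consumed (see dataset_future below).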
    pool = ProcessPoolExecutor(1)

    # Note: we loop over epochs indefinitely; termination is handled via the iteration count.
    while True:
        thread = None
        restored_data_loader = None
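        # Fresh run (or a later epoch): enumerate and shuffle the training shards.
        # Otherwise resume from the shard list and data loader recorded in the checkpoint.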
        if (
            not args.resume_from_checkpoint
            or epoch > 0
            or (args.phase2 and global_step < 1)
            or args.init_checkpoint
        ):
            files = [
                os.path.join(args.input_dir, f)
                for f in os.listdir(args.input_dir)
                if os.path.isfile(os.path.join(args.input_dir, f)) and "training" in f
            ]
            files.sort()
            num_files = len(files)
            random.Random(args.seed + epoch).shuffle(files)
            f_start_id = 0
        else:
            f_start_id = checkpoint["files"][0]
            files = checkpoint["files"][1:]
            args.resume_from_checkpoint = False
            num_files = len(files)
            # may not exist in all checkpoints
            epoch = checkpoint.get("epoch", 0)
            restored_data_loader = checkpoint.get("data_loader", None)

        shared_file_list = {}
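        # Determine the data-parallel group size and rank: prefer SMP's data-parallel
        # group, fall back to torch.distributed, else run as a single process.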
        if smp.is_initialized():
            dpsize = smp.dp_size()
            dprank = smp.dp_rank()
        elif torch.distributed.is_initialized():
            dpsize = get_world_size()
            dprank = get_rank()
        else:
            dpsize = 1
            dprank = 0
        dparallel = dpsize > 1

        # remainder is only nonzero when there are more data-parallel ranks than shards;
        # default it to 0 so the shard-index arithmetic below is always defined.
        remainder = 0
        if dparallel and dpsize > num_files:
            remainder = dpsize % num_files
            data_file = files[(f_start_id * dpsize + dprank + remainder * f_start_id) % num_files]
        else:
            data_file = files[(f_start_id * dpsize + dprank) % num_files]

        previous_file = data_file
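        # Build a DataLoader for the first shard unless one was restored from the checkpoint.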
        if restored_data_loader is None:
            train_data = pretraining_dataset(data_file, args.max_predictions_per_seq)
            train_sampler = RandomSampler(train_data)
            train_dataloader = DataLoader(
                train_data,
                sampler=train_sampler,
                batch_size=args.train_batch_size * args.n_gpu,
                num_workers=4,
                worker_init_fn=worker_init,
                pin_memory=True,
                drop_last=True,
            )
            # shared_file_list["0"] = (train_dataloader, data_file)
        else:
            train_dataloader = restored_data_loader
            restored_data_loader = None
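        # Scratch buffer handed to take_optimizer_step; only allocated when gradients are
        # all-reduced after accumulation (args.allreduce_post_accumulation).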
        overflow_buf = None
        if args.allreduce_post_accumulation:
            overflow_buf = torch.cuda.IntTensor([0])
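        # Iterate over the remaining shards of this epoch. While the current shard is being
        # trained on, the next shard is loaded asynchronously through the process pool.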
        for f_id in range(f_start_id + 1, len(files)):
            if get_world_size() > num_files:
                data_file = files[
                    (f_id * get_world_size() + get_rank() + remainder * f_id) % num_files
                ]
            else:
                data_file = files[(f_id * get_world_size() + get_rank()) % num_files]

            previous_file = data_file
            dataset_future = pool.submit(
                create_pretraining_dataset,
                data_file,
                args.max_predictions_per_seq,
                shared_file_list,
                args,
                worker_init,
            )

            train_iter = (
                tqdm(train_dataloader, desc="Iteration", disable=args.disable_progress_bar)
                if is_main_process()
                else train_dataloader
            )

            if raw_train_start is None:
                raw_train_start = time.time()
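            # Inner loop over mini-batches from the current shard.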
            for step, batch in enumerate(train_iter):
                training_steps += 1
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch

                if args.do_train:
                    model.train()
                    if args.smp > 0:
                        loss_mbs = smp_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            optimizer,
                            criterion,
                            step,
                        )
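                        # smp_step returns one loss per microbatch (loss_mbs); average them
                        # across microbatches into a single scalar.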
                        loss = loss_mbs.reduce_mean()
                        if smp.rank() == 0:
                            print("Loss:", loss.item())
                    else:
                        loss = train_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            optimizer,
                            criterion,
                            step,
                        )
                    divisor = 1
                    average_loss += loss.item()
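                    # Step the LR schedule and optimizer only once every
                    # gradient_accumulation_steps mini-batches.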
                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(
                            args, optimizer, model, overflow_buf, global_step
                        )
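                    # On the last step (max steps reached or a scheduler timeout was signalled),
                    # compute the loss averaged over the final logging window and across ranks;
                    # otherwise reset the running average at each logging interval.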
                    if global_step >= args.steps_this_run or timeout_sent:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = (
                            int(training_steps / args.gradient_accumulation_steps) % args.log_freq
                        )
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps * divisor)
                        if torch.distributed.is_initialized():
                            average_loss /= get_world_size()
                            torch.distributed.all_reduce(average_loss)
                        final_loss = loss.item()
                    elif training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
                        average_loss = 0
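                    # Checkpoint at the configured interval, on the final step, or when a
                    # timeout was signalled.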
                    if (
                        global_step >= args.steps_this_run
                        or training_steps
                        % (args.num_steps_per_checkpoint * args.gradient_accumulation_steps)
                        == 0
                        or timeout_sent
                    ):
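                        # Only the rank that is 0 in its data-parallel group writes checkpoints;
                        # with partial=True each mp_rank saves only its own model partition.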
                        if smp.dp_rank() == 0 and not args.skip_checkpoint:
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(
                                    args.output_dir, "ckpt_{}.pt".format(global_step)
                                )
                            else:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step + args.phase1_end_step),
                                )
                            if args.do_train:
                                save_dict = {
                                    "model": model.local_state_dict(),
                                    "optimizer": optimizer.local_state_dict(),
                                    "files": [f_id] + files,
                                    "epoch": epoch,
                                    "data_loader": None
                                    if global_step >= args.steps_this_run
                                    else train_dataloader,
                                }
                                if args.fp16:
                                    save_dict["master params"] = list(amp.master_params(optimizer))
                                # SMP: Checkpoint mp_rank specific state
                                smp.save(save_dict, output_save_file, partial=True)

                                most_recent_ckpts_paths.append(output_save_file)
                                if len(most_recent_ckpts_paths) > 3 and (
                                    args.smp == 0 or smp.dp_rank() == 0
                                ):
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                                    os.remove(ckpt_to_be_removed + f"_{smp.mp_rank()}")

                        # Exiting the training due to hitting max steps, or being sent a
                        # timeout from the cluster scheduler
                        if global_step >= args.steps_this_run or timeout_sent:
                            del train_dataloader
                            # thread.join()
                            if smp.dp_rank() == 0 and args.save_full:
                                output_save_file = os.path.join(
                                    args.output_dir, "ckpt_{}.pt".format(global_step)
                                )
                                save_dict = {
                                    "model": model.local_state_dict(),
                                    "optimizer": optimizer.local_state_dict(),
                                    "files": [f_id] + files,
                                    "epoch": epoch,
                                    "data_loader": None
                                    if global_step >= args.steps_this_run
                                    else train_dataloader,
                                }
                                if args.fp16:
                                    save_dict["master params"] = list(amp.master_params(optimizer))
                                # SMP: Save a single checkpoint containing entire model parameters
                                smp.save(save_dict, output_save_file, partial=False)
                            smp.barrier()
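                            # Local rank 0 on each host uploads its checkpoint directory to the
                            # job's S3 checkpoint prefix (derived from SM_MODULE_DIR) before exiting.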
                            if smp.local_rank() == 0:
                                print("Start syncing model checkpoints to s3")
                                base_s3_path = os.path.dirname(
                                    os.path.dirname(os.getenv("SM_MODULE_DIR", ""))
                                )
                                curr_host = os.getenv("SM_CURRENT_HOST")
                                full_s3_path = f"{base_s3_path}/checkpoints/{curr_host}/"
                                sync_local_checkpoints_to_s3(
                                    local_path=args.output_dir, s3_path=full_s3_path
                                )
                                print("Finished syncing model checkpoints to s3")
                            return args, final_loss, train_time_raw, global_step
                else:
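                    # Evaluation-only path: compute the test loss without gradients and
                    # return its mean once the step budget is exhausted.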
                    model.eval()
                    with torch.no_grad():
                        loss = test_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            criterion,
                            step,
                        )
                        print(f"global_step {global_step} Test Loss:", loss)
                        test_losses.append(loss)
                    global_step += 1
                    if global_step >= args.steps_this_run:
                        return sum(test_losses) / len(test_losses)

            del train_dataloader
            # thread.join()
            # Make sure pool has finished and switch train_dataloader
            # NOTE: Will block until complete
            train_dataloader, data_file = dataset_future.result(timeout=None)

        epoch += 1