in training/distributed_training/pytorch/model_parallel/bert/bert_example/sagemaker_smp_pretrain.py [0:0]
def prepare_model_and_optimizer(args, device):
    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Pad the vocabulary so its size is divisible by 8, which keeps the embedding
    # and output GEMM dimensions Tensor Core friendly under mixed precision.
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)
    if args.use_sequential > 0:
        config.use_sequential = True
    else:
        config.use_sequential = False
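
    # Use the fused bias+GeLU kernel in training mode and enable activation
    # checkpointing if requested on the command line.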
    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)
    model.checkpoint_activations(args.checkpoint_activations)

    if args.smp > 0:
        # SMP: Use the DistributedModel container to provide the model
        # to be partitioned across different ranks. For the rest of the script,
        # the returned DistributedModel object should be used in place of
        # the model provided for DistributedModel class instantiation.
        model = smp.DistributedModel(model)
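
    # Checkpoint handling: either start fresh at step 0, resume from a (possibly
    # partial) SMP checkpoint synced down from S3 into output_dir, or warm-start
    # from an explicit init_checkpoint.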
    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if not args.init_checkpoint:
            if not args.s3_checkpoint_uri:
                raise ValueError("s3_checkpoint_uri must be set when init_checkpoint is not set")
            if smp.local_rank() == 0:
                sync_s3_checkpoints_to_local(args.output_dir, args.s3_checkpoint_uri)
            smp.barrier()
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [f for f in os.listdir(args.output_dir) if ".pt" in f]
            args.resume_step = max(
                [int(x.split(".pt")[0].split("_")[1].strip()) for x in model_names]
            )

        global_step = args.resume_step if not args.init_checkpoint else 0

        # SMP: Load a model that was saved with smp.save
        if not args.init_checkpoint:
            checkpoint = smp.load(
                os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)),
                partial=args.partial_checkpoint,
            )
        else:
            checkpoint = smp.load(args.init_checkpoint)

        model.load_state_dict(checkpoint["model"], strict=False)

        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("Resuming from step", args.resume_step)
    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "gamma", "beta", "LayerNorm"]

    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
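
    # FusedLAMB is Apex's fused implementation of the LAMB optimizer, the usual
    # choice for large-batch BERT pretraining.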
    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)
    if args.smp > 0:
        # SMP: Use DistributedOptimizer, which allows the loading of optimizer state
        # for a distributed model and also provides APIs to obtain the local optimizer
        # state for the current mp_rank.
        optimizer = smp.DistributedOptimizer(optimizer)
    lr_scheduler = PolyWarmUpScheduler(
        optimizer, warmup=args.warmup_proportion, total_steps=args.max_steps
    )
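
    # Apex AMP O2 casts the model to FP16 while the optimizer keeps FP32 master
    # weights; loss scaling is dynamic when args.loss_scale == 0, otherwise static,
    # and the scaler is initialized to args.init_loss_scale.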
    if args.fp16:
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level="O2",
                loss_scale="dynamic",
                cast_model_outputs=torch.float16,
            )
        else:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level="O2",
                loss_scale=args.loss_scale,
                cast_model_outputs=torch.float16,
            )
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale
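
    # When resuming, restore the optimizer state saved in the checkpoint. For phase 2,
    # or when warm-starting from init_checkpoint, the scheduler-related entries kept in
    # the optimizer state (step, t_total, warmup, lr) are first overridden with the
    # current run's settings.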
    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint["optimizer"]["state"].keys())
            # Override hyperparameters from the previous checkpoint with the current run's values
            for key in keys:
                checkpoint["optimizer"]["state"][key]["step"] = global_step
            for idx, _ in enumerate(checkpoint["optimizer"]["param_groups"]):
                checkpoint["optimizer"]["param_groups"][idx]["step"] = global_step
                checkpoint["optimizer"]["param_groups"][idx]["t_total"] = args.max_steps
                checkpoint["optimizer"]["param_groups"][idx]["warmup"] = args.warmup_proportion
                checkpoint["optimizer"]["param_groups"][idx]["lr"] = args.learning_rate
        optimizer.load_state_dict(checkpoint["optimizer"])

        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint["optimizer"])
            for param, saved_param in zip(
                amp.master_params(optimizer), checkpoint["master params"]
            ):
                param.data.copy_(saved_param.data)
    # if args.local_rank != -1:
    #     if not args.allreduce_post_accumulation:
    #         model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size())
    #     else:
    #         flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,))
    # elif args.n_gpu > 1:
    #     model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
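
# Illustrative usage sketch (not part of the original script). Assuming argument
# parsing and smp.init() have already happened earlier in the training script,
# the returned objects would typically be consumed along these lines:
#
#     device = torch.device("cuda", smp.local_rank())
#     model, optimizer, lr_scheduler, checkpoint, global_step, criterion = \
#         prepare_model_and_optimizer(args, device)
#
# with the training loop then driving optimizer, lr_scheduler, and criterion,
# and global_step/checkpoint used to continue an interrupted run.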