in weak_to_strong/train.py [0:0]
def maybe_load_model(model):
    """Load saved weights into `model` if a finished run exists at `save_path`.

    A run is considered finished when `results.pkl` is present. Unless
    `force_retrain` is set (both read from the enclosing scope), the saved
    state dict is loaded and stashed into `custom_kwargs["state_dict"]` for
    the caller to pass to the model constructor.

    Returns:
        True if a checkpoint was found and loaded, False otherwise.
    """
    if os.path.exists(os.path.join(save_path, "results.pkl")) and not force_retrain:
        print("loading from", save_path)
        checkpoint_path = os.path.join(save_path, "pytorch_model.bin")
        if not os.path.exists(checkpoint_path):
            # No single-file checkpoint: assume a sharded checkpoint.
            # load_sharded_checkpoint expects the checkpoint *directory*
            # (it reads pytorch_model.bin.index.json inside it), so pass
            # save_path — not the nonexistent .bin path.
            load_sharded_checkpoint(model, save_path)
        else:
            state_dict = torch.load(checkpoint_path)
            # Strip the DataParallel/DDP "module" wrapper from key names so
            # the state dict matches an unwrapped model.
            state_dict = {
                k.replace("transformer.module", "transformer"): v
                for (k, v) in state_dict.items()
            }
            custom_kwargs["state_dict"] = state_dict
        return True
    return False