in metaicl/model.py [0:0]
def load(self, checkpoint=None, gpt2="gpt2-large"):
    '''
    checkpoint can be either a keyword of the model or a path to the checkpoint file
    '''
    # A "gpt*" keyword passed as `checkpoint` names the base LM, not a fine-tuned checkpoint.
    if checkpoint is not None and checkpoint.startswith("gpt"):
        gpt2 = checkpoint
        checkpoint = None

    if checkpoint is None:
        # No fine-tuned checkpoint: load the pretrained base LM directly.
        if gpt2.startswith("gpt2"):
            model = AutoModelForCausalLM.from_pretrained(gpt2)
        elif gpt2 == "gpt-j-6B":
            model = AutoModelForCausalLM.from_pretrained("/checkpoint/sewonmin/gpt-j")
        else:
            raise NotImplementedError(gpt2)
        self.model_name = gpt2
    else:
        self.model_name = checkpoint
        _id = get_checkpoint_id(checkpoint)
        if _id is not None:
            # `checkpoint` is a released-model keyword: resolve it to a local path
            # and download the checkpoint file if it is not already cached.
            method, setting, _id = _id
            keyword = checkpoint
            checkpoint = os.path.join("checkpoints", method, setting)
            if self.local_rank <= 0:
                if os.path.exists(checkpoint):
                    self.logger.info("Reusing checkpoint at %s" % checkpoint)
                else:
                    self.logger.info("Downloading %s in %s", keyword, checkpoint)
                    download_file(_id, checkpoint)
        assert os.path.exists(checkpoint), checkpoint
        if self.local_rank <= 0:
            self.logger.info("Loading the model from %s" % checkpoint)
        # Load the fine-tuned weights on top of the base gpt2 architecture.
        state_dict = torch.load(checkpoint)
        model = AutoModelForCausalLM.from_pretrained(gpt2, state_dict=state_dict)
    self.model = model
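
A minimal usage sketch follows, assuming the surrounding class is the MetaICLModel wrapper defined in metaicl/model.py; its constructor arguments, the checkpoint keyword, and the local path shown here are illustrative assumptions rather than values taken from this excerpt.

# Hypothetical usage sketch; MetaICLModel's constructor arguments, the
# checkpoint keyword, and the local path are assumptions for illustration.
import logging
from metaicl.model import MetaICLModel

logger = logging.getLogger(__name__)
model = MetaICLModel(logger=logger, out_dir="out")

# 1) Base pretrained LM only: the "gpt2*" name goes straight to
#    AutoModelForCausalLM.from_pretrained.
model.load(gpt2="gpt2-large")

# 2) Released-checkpoint keyword (assumed name): resolved via get_checkpoint_id,
#    downloaded into checkpoints/<method>/<setting> if missing, then loaded.
model.load(checkpoint="metaicl/hr_to_lr", gpt2="gpt2-large")

# 3) Local path to a checkpoint file (assumed path): loaded with torch.load and
#    applied on top of the base gpt2 architecture.
model.load(checkpoint="checkpoints/metaicl/hr_to_lr", gpt2="gpt2-large")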