def load()

in metaicl/model.py [0:0]


    def load(self, checkpoint=None, gpt2="gpt2-large"):
        '''
        Load a causal LM into ``self.model`` and record its name in ``self.model_name``.

        checkpoint can be either keyword of the model or path to the checkpoint file.

        Args:
            checkpoint: one of
                - ``None``: load the base model named by ``gpt2``;
                - a string starting with "gpt": treated as the base model name
                  (it overrides ``gpt2``) and loaded as above;
                - a keyword for which ``get_checkpoint_id`` returns
                  ``(method, setting, id)``: resolved to
                  ``checkpoints/<method>/<setting>`` and fetched with
                  ``download_file`` if that path does not exist yet;
                - otherwise, a path to a state-dict file loaded with ``torch.load``.
            gpt2: base model name. Names starting with "gpt2" are passed to
                ``AutoModelForCausalLM.from_pretrained``; "gpt-j-6B" is loaded from a
                fixed local directory. When a checkpoint is used, this is the
                architecture the loaded state dict is passed to.

        Raises:
            NotImplementedError: no checkpoint is used and ``gpt2`` is not a
                supported base model name (the message is that name).
            AssertionError: the resolved checkpoint path does not exist.
        '''
        if checkpoint is not None and checkpoint.startswith("gpt"):
            # A base-model name was passed as the checkpoint: load that model directly.
            gpt2 = checkpoint
            checkpoint = None

        if checkpoint is None:
            if gpt2.startswith("gpt2"):
                model = AutoModelForCausalLM.from_pretrained(gpt2)
            elif gpt2 == "gpt-j-6B":
                # NOTE(review): hard-coded absolute path, presumably specific to the
                # original author's environment — confirm before running elsewhere.
                model = AutoModelForCausalLM.from_pretrained("/checkpoint/sewonmin/gpt-j")
            else:
                # `checkpoint` is always None on this branch; report the model name instead.
                raise NotImplementedError(gpt2)
            self.model_name = gpt2
        else:
            self.model_name = checkpoint
            _id = get_checkpoint_id(checkpoint)
            if _id is not None:
                method, setting, _id = _id
                keyword = checkpoint
                checkpoint = os.path.join("checkpoints", method, setting)
                if self.local_rank <= 0:
                    if os.path.exists(checkpoint):
                        self.logger.info("Reusing checkpoint at %s", checkpoint)
                    else:
                        # Only fetch when the file is missing, as the log messages state.
                        self.logger.info("Downloading %s in %s", keyword, checkpoint)
                        download_file(_id, checkpoint)
                # NOTE(review): only rank <= 0 downloads, and nothing here makes other
                # ranks wait for it, so in a multi-process run the existence check below
                # may run before the file exists — consider a barrier.

            assert os.path.exists(checkpoint), checkpoint
            if self.local_rank <= 0:
                self.logger.info("Loading the model from %s", checkpoint)
            # The state dict is only consumed by from_pretrained below, so load it on CPU
            # rather than on whatever device it was saved from (which could put every
            # rank's copy on the same GPU).
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(checkpoint, map_location="cpu")
            model = AutoModelForCausalLM.from_pretrained(gpt2, state_dict=state_dict)
        self.model = model