in XLM/src/model/__init__.py [0:0]
def build_model(params, dico):
"""
Build model.
"""
if params.encoder_only:
# build
model = TransformerModel(params, dico, is_encoder=True, with_output=True)
# reload pretrained word embeddings
if params.reload_emb != '':
word2id, embeddings = load_embeddings(params.reload_emb, params)
set_pretrain_emb(model, dico, word2id, embeddings)

        # reload a pretrained model
        if params.reload_model != '':
            logger.info("Reloading model from %s ..." % params.reload_model)
            reloaded = torch.load(params.reload_model, map_location=lambda storage, loc: storage.cuda(params.local_rank))['model']
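            # strip the 'module.' prefix left on parameter names by DataParallel / DistributedDataParallel checkpoints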
            if all([k.startswith('module.') for k in reloaded.keys()]):
                reloaded = {k[len('module.'):]: v for k, v in reloaded.items()}

            # # HACK to reload models with fewer layers
            # for i in range(12, 24):
            #     for k in TRANSFORMER_LAYER_PARAMS:
            #         k = k % i
            #         if k in model.state_dict() and k not in reloaded:
            #             logger.warning("Parameter %s not found. Ignoring ..." % k)
            #             reloaded[k] = model.state_dict()[k]

            model.load_state_dict(reloaded, strict=False)

        logger.info("Model: {}".format(model))
        logger.info("Number of parameters (model): %i" % sum([p.numel() for p in model.parameters() if p.requires_grad]))

        return model.cuda()

    else:
        # build
        encoder = TransformerModel(params, dico, is_encoder=True, with_output=False)  # TODO: only output when necessary - len(params.clm_steps + params.mlm_steps) > 0
        decoder = TransformerModel(params, dico, is_encoder=False, with_output=True)

        # reload pretrained word embeddings
        if params.reload_emb != '':
            word2id, embeddings = load_embeddings(params.reload_emb, params)
            set_pretrain_emb(encoder, dico, word2id, embeddings)
            set_pretrain_emb(decoder, dico, word2id, embeddings)

        # reload a pretrained model
        if params.reload_model != '':
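            # reload_model is given as 'encoder_path,decoder_path'; either side may be left empty, but not both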
            enc_path, dec_path = params.reload_model.split(',')
            assert not (enc_path == '' and dec_path == '')

            # reload encoder
            if enc_path != '':
                logger.info("Reloading encoder from %s ..." % enc_path)
                enc_reload = torch.load(enc_path, map_location=lambda storage, loc: storage.cuda(params.local_rank))
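                # checkpoints may store the weights under 'model' (single model) or 'encoder' (encoder-decoder)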
                enc_reload = enc_reload['model' if 'model' in enc_reload else 'encoder']
                if all([k.startswith('module.') for k in enc_reload.keys()]):
                    enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()}
                encoder.load_state_dict(enc_reload, strict=False)

            # reload decoder
            if dec_path != '':
                logger.info("Reloading decoder from %s ..." % dec_path)
                dec_reload = torch.load(dec_path, map_location=lambda storage, loc: storage.cuda(params.local_rank))
                dec_reload = dec_reload['model' if 'model' in dec_reload else 'decoder']
                if all([k.startswith('module.') for k in dec_reload.keys()]):
                    dec_reload = {k[len('module.'):]: v for k, v in dec_reload.items()}

                # If the pretrained model has extra (otherwise unused) layer weights, remap them to initialize the decoder
                # NB: dec_reload's 'pred_layer.proj.weight' is not in the decoder
                num_keys_fixed = 0
                for i in range(params.n_layers, 2 * params.n_layers):
                    keys_to_fix = [k for k in dec_reload.keys() if f'.{i}.' in k]
                    for k in keys_to_fix:
                        new_k = k.replace(f'.{i}.', f'.{i % params.n_layers}.')
                        dec_reload.pop(new_k)  # raises if new_k is absent, i.e. we only ever overwrite an existing key
                        dec_reload[new_k] = dec_reload.pop(k)
                        num_keys_fixed += 1
                logger.info("Keys fixed while reloading decoder: %i ..." % num_keys_fixed)
                for i in range(params.n_layers):
                    for name in DECODER_ONLY_PARAMS:
                        if name % i not in dec_reload:
                            logger.warning("Parameter %s not found." % (name % i))
                            dec_reload[name % i] = decoder.state_dict()[name % i]

                decoder.load_state_dict(dec_reload, strict=False)
logger.debug("Encoder: {}".format(encoder))
logger.debug("Decoder: {}".format(decoder))
logger.info("Number of parameters (encoder): %i" % sum([p.numel() for p in encoder.parameters() if p.requires_grad]))
logger.info("Number of parameters (decoder): %i" % sum([p.numel() for p in decoder.parameters() if p.requires_grad]))
return encoder.cuda(), decoder.cuda()
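

# Typical call site (a sketch, assuming `params` and `dico` are built the way the repo's
# training script builds them; the variable names below are illustrative):
#
#     if params.encoder_only:
#         model = build_model(params, dico)
#     else:
#         encoder, decoder = build_model(params, dico)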