in src/sase.py [0:0]
def main():
    """Encode sentences from a text file with a pretrained Transformer and save the embeddings.

    Reads one sentence per line from --input, tokenizes it with a SentencePiece
    model, runs each batch through the reloaded TransformerModel, keeps the
    hidden state at position 0 (the prepended EOS token) as the sentence
    embedding, and torch.save()s the stacked embeddings to --output.
    """
    parser = argparse.ArgumentParser(description="Sentence embedding extraction")
    parser.add_argument("--input", type=str, default="", help="input file")
    parser.add_argument("--model", type=str, default="", help="model path")
    parser.add_argument("--spm_model", type=str, default="", help="spm model path")
    parser.add_argument("--batch_size", type=int, default=64, help="batch size")
    parser.add_argument("--max_words", type=int, default=100, help="max words")
    parser.add_argument("--cuda", type=str, default="True", help="use cuda")
    parser.add_argument("--output", type=str, default="", help="output file")
    args = parser.parse_args()

    # cuda: validate and convert the flag without eval() — eval on CLI input is
    # unsafe, and assert-based validation disappears under `python -O`.
    if args.cuda not in ("True", "False"):
        parser.error("--cuda must be 'True' or 'False'")
    args.cuda = args.cuda == "True"

    # Reload a pretrained model; map to CPU when CUDA is disabled so a
    # GPU-saved checkpoint still loads on a CPU-only machine.
    reloaded = torch.load(args.model, map_location=None if args.cuda else "cpu")
    params = AttrDict(reloaded["params"])

    # Reload the SPM model
    spm_model = spm.SentencePieceProcessor()
    spm_model.Load(args.spm_model)

    # build dictionary / update parameters
    dico = Dictionary(reloaded["dico_id2word"], reloaded["dico_word2id"], reloaded["dico_counts"])
    params.n_words = len(dico)
    params.bos_index = dico.index(BOS_WORD)
    params.eos_index = dico.index(EOS_WORD)
    params.pad_index = dico.index(PAD_WORD)
    params.unk_index = dico.index(UNK_WORD)
    params.mask_index = dico.index(MASK_WORD)

    # build model / reload weights, stripping a possible DataParallel
    # 'module.' prefix; build the OrderedDict from pairs directly so the
    # original checkpoint key order is preserved.
    model = TransformerModel(params, dico, True, True)
    reloaded["model"] = OrderedDict(
        (key.replace("module.", ""), value) for key, value in reloaded["model"].items()
    )
    model.load_state_dict(reloaded["model"])
    model.eval()
    if args.cuda:
        model.cuda()

    # load sentences: one per line, SPM-tokenized, truncated to max_words - 1
    # pieces so the prepended EOS token still fits within max_words positions
    sentences = []
    with open(args.input) as f:
        for line in f:
            pieces = spm_model.EncodeAsPieces(line.rstrip())
            sentences.append(pieces[:args.max_words - 1])

    # encode sentences batch by batch
    embs = []
    for start in range(0, len(sentences), args.batch_size):
        batch = sentences[start:start + args.batch_size]
        # +1 accounts for the EOS token prepended to every sentence
        lengths = torch.LongTensor([len(s) + 1 for s in batch])
        bs, slen = len(batch), lengths.max().item()
        assert slen <= args.max_words  # internal invariant from the truncation above
        # (slen, bs) token-id matrix, padded with pad_index
        x = torch.LongTensor(slen, bs).fill_(params.pad_index)
        for k, sentence in enumerate(batch):
            ids = torch.LongTensor([params.eos_index] + [dico.index(w) for w in sentence])
            x[:len(ids), k] = ids
        if args.cuda:
            x = x.cuda()
            lengths = lengths.cuda()
        with torch.no_grad():
            # forward pass; [0] takes the hidden state of the leading EOS token
            # as the sentence embedding -> shape (bs, dim)
            embedding = model('fwd', x=x, lengths=lengths, langs=None, causal=False).contiguous()[0].cpu()
        embs.append(embedding)

    # save embeddings; squeeze(0) collapses the batch dim only when there is a
    # single sentence (preserved from the original behavior)
    torch.save(torch.cat(embs, dim=0).squeeze(0), args.output)