in codes/rnn_models.py [0:0]
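# NOTE: this excerpt assumes the module-level imports and helpers it uses
# (os, torch, get_model_args, get_model_config, getEmbeddingWeights, NLINet)
# are defined elsewhere in the repo; they are not shown in this snippet.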
def __init__(self, args):
    super().__init__(args)
    ## Model specific init
    params = get_model_args()
    params.encoder_type = args.encoder_type
    nlipath = "mnli" if "nlipath" not in args else args.nlipath
    params.nlipath = "ocnli" if "ocnli" in args.model_name else nlipath
    if "outputdir" in args:
        print(f"Changing outputdir to {args.outputdir}")
        params.outputdir = args.outputdir
    print(f"NLIPATH : {params.nlipath}")
    self.params = params
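    # checkpoints live under <outputdir>/<nlipath>/<encoder_type>/exp_seed_<seed>/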
    exp_folder = os.path.join(
        params.outputdir,
        params.nlipath,
        params.encoder_type,
        "exp_seed_{}".format(params.seed),
    )
    if not os.path.exists(exp_folder):
        os.makedirs(exp_folder)
    # path of the saved checkpoint for this encoder type
    save_folder_name = os.path.join(exp_folder, "model")
    params.outputmodelname = os.path.join(
        save_folder_name, "{}_model.pkl".format(params.encoder_type)
    )
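    # GloVe vocabulary weights are only loaded for the non-ocnli datasets;
    # the ocnli model is constructed with weights=None below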
    if params.nlipath != "ocnli":
        vocab_file = "rnn_models/vocab/glove_" + params.nlipath + "_vocab.txt"
        with open(vocab_file) as f:
            vocab = [line.rstrip("\n") for line in f]
        word_vec = getEmbeddingWeights(vocab, params.nlipath)
    else:
        word_vec = None
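    # build the model config, instantiate NLINet, and restore the trained
    # checkpoint on CPU for inference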
    model_config = get_model_config(params)
    print(
        params.nlipath,
        params.outputmodelname,
        word_vec is None,
        "-----------!!------------",
    )
    self.model = NLINet(model_config, weights=word_vec)
    self.model.load_state_dict(
        torch.load(params.outputmodelname, map_location=torch.device("cpu"))
    )
    self.model.eval()
    # nothing to do here, as it's handled internally
    self.label_fn = {"c": "c", "n": "n", "e": "e"}
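
A minimal usage sketch of this constructor. The enclosing class name `RNNModel` and the concrete `encoder_type` / `model_name` / `outputdir` values below are assumptions for illustration, not taken from the repo; `args` just needs the fields read above.

# Hypothetical usage sketch, not part of rnn_models.py.
from argparse import Namespace

args = Namespace(
    encoder_type="BLSTMEncoder",   # assumed encoder name
    model_name="rnn_mnli_model",   # an "ocnli" substring here would switch nlipath to ocnli
    nlipath="mnli",
    outputdir="saved_models",
)
model_wrapper = RNNModel(args)     # RNNModel is a placeholder for the actual class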