in code/scripts/lstm_alone.py [0:0]
# Imports needed by evaluate(); the surrounding script also defines ctx,
# label2idx, intent2idx, model_dir, log, PAD, conll_prediction_file and the
# helpers ICSL, icsl_transform, merge_slots, load_tsv and label2index.
import collections
import os
import subprocess

import gluonnlp as nlp
import mxnet as mx
import sklearn.metrics

def evaluate(model=None, model_name='', eval_input=''):
"""Evaluate the model on validation dataset.
"""
    ## Load the pretrained BERT model; only its vocabulary is used in this function
bert, vocabulary = nlp.model.get_model('bert_12_768_12',
dataset_name='wiki_multilingual_uncased',
pretrained=True,
ctx=ctx,
use_pooler=False,
use_decoder=False,
use_classifier=False)
    if model is None:
        # Rebuild the model and restore its saved parameters.
        assert model_name != '', 'model_name is required when no model is given'
        model = ICSL(len(vocabulary), num_slot_labels=len(label2idx), num_intents=len(intent2idx))
        model.initialize(ctx=ctx)
        model.hybridize(static_alloc=True)
        model.load_parameters(os.path.join(model_dir, model_name + '.params'), ctx=ctx)
    idx2label = {idx: label for label, idx in label2idx.items()}
    ## Load the evaluation dataset
    field_separator = nlp.data.Splitter('\t')
    field_indices = [1, 3, 4, 0]
    eval_data = nlp.data.TSVDataset(filename=eval_input,
                                    field_separator=field_separator,
                                    num_discard_samples=1,  # skip the header row
                                    field_indices=field_indices)
bert_tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=True)
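    # dev_alignment maps example id -> the subword alignment returned by
    # icsl_transform; merge_slots() later uses it to project wordpiece-level
    # slot predictions back onto whole words.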
dev_alignment = {}
eval_data_transform = []
for sample in eval_data:
sample, alignment = icsl_transform(sample, vocabulary, label2idx, intent2idx, bert_tokenizer)
eval_data_transform += [sample]
dev_alignment[sample[0]] = alignment
log.info('The number of examples after preprocessing: {}'
.format(len(eval_data_transform)))
test_batch_size = 16
pad_token_id = vocabulary[PAD]
pad_label_id = label2idx[PAD]
    batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Stack(),                            # example ids
        nlp.data.batchify.Pad(axis=0, pad_val=pad_token_id),  # token ids
        nlp.data.batchify.Pad(axis=0, pad_val=pad_label_id),  # slot label ids
        nlp.data.batchify.Stack('float32'),                   # intent ids
        nlp.data.batchify.Stack('float32'))                   # valid lengths
eval_dataloader = mx.gluon.data.DataLoader(
eval_data_transform,
batchify_fn=batchify_fn,
num_workers=4, batch_size=test_batch_size, shuffle=False, last_batch='keep')
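    # For intuition: Pad fills every sequence in a batch up to the longest one
    # (a minimal sketch with toy values, not data from this task):
    #
    #   pad = nlp.data.batchify.Pad(axis=0, pad_val=0)
    #   pad([[4, 21, 7], [5, 9]])  # -> [[4, 21, 7], [5, 9, 0]]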
    _Result = collections.namedtuple('_Result', ['intent', 'slot_labels'])
all_results = {}
total_num = 0
for data in eval_dataloader:
        # The gold slot/intent labels are discarded here; scoring below
        # re-reads them from the raw TSV file.
        example_ids, token_ids, _, _, valid_length = data
total_num += len(token_ids)
# load data to GPU
token_ids = token_ids.astype('float32').as_in_context(ctx[0])
valid_length = valid_length.astype('float32').as_in_context(ctx[0])
# forward computation
intent_pred, slot_pred = model(token_ids)
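        # intent_pred: (batch, num_intents); slot_pred: (batch, seq_len, num_slot_labels).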
intent_pred = intent_pred.asnumpy()
slot_pred = slot_pred.asnumpy()
valid_length = valid_length.asnumpy()
for eid, y_intent, y_slot, length in zip(example_ids, intent_pred, slot_pred, valid_length):
eid = eid.asscalar()
            length = int(length) - 2  # drop the [CLS] and [SEP] positions
intent_id = y_intent.argmax(axis=-1)
slot_ids = y_slot.argmax(axis=-1).tolist()[:length]
slot_names = [idx2label[idx] for idx in slot_ids]
            # Merge subword-level predictions back into word-level labels.
            merged_slot_names = merge_slots(slot_names, dev_alignment[eid] + [length])
            # Keep only the first prediction per example id.
            if eid not in all_results:
                all_results[eid] = _Result(intent_id, merged_slot_names)
    # Re-read the raw evaluation file to obtain gold intents and slot labels.
    example_ids, utterances, labels, intents = load_tsv(eval_input)
pred_intents = []
label_intents = []
for eid, intent in zip(example_ids, intents):
label_intents.append(label2index(intent2idx, intent))
pred_intents.append(all_results[eid].intent)
intent_acc = sklearn.metrics.accuracy_score(label_intents, pred_intents)
log.info("Intent Accuracy: %.4f" % intent_acc)
    # Exact match: both the intent and the full slot sequence must be correct.
    pred_icsl = []
    label_icsl = []
for eid, intent, slot_labels in zip(example_ids, intents, labels):
label_icsl.append(str(label2index(intent2idx, intent)) + ' ' + ' '.join(slot_labels))
pred_icsl.append(str(all_results[eid].intent) + ' ' + ' '.join(all_results[eid].slot_labels))
exact_match = sklearn.metrics.accuracy_score(label_icsl, pred_icsl)
log.info("Exact Match: %.4f" % exact_match)
with open(conll_prediction_file, "w") as fw:
        for eid, utterance, gold_labels in zip(example_ids, utterances, labels):
            preds = all_results[eid].slot_labels
            for w, l, p in zip(utterance, gold_labels, preds):
                fw.write(' '.join([w, l, p]) + '\n')
            fw.write('\n')
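    # The file written above is in conlleval format: one token per line as
    # "word gold-label predicted-label", with a blank line between utterances.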
proc = subprocess.Popen(["perl", "conlleval.pl"], stdin=subprocess.PIPE, stdout=subprocess.PIPE)
with open(conll_prediction_file) as f:
stdout = proc.communicate(f.read().encode())[0]
    # conlleval.pl prints a summary; its second line reports the overall
    # scores and ends with the FB1 (slot F1) value.
    result = stdout.decode('utf-8').split('\n')[1]
    slot_f1 = float(result.split()[-1].strip())
log.info("Slot Labeling: %s" % result)
return intent_acc, slot_f1
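
# Example invocation (a sketch; the model name and file path below are
# placeholders, not values shipped with this script):
#
#   intent_acc, slot_f1 = evaluate(model_name='icsl_lstm',
#                                  eval_input='data/dev.tsv')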