in curiosity/predictors.py [0:0]
def predict_json(self, inputs: JsonDict) -> JsonDict:
    """Predict label indices for the dialogs found in ``inputs``.

    Each entry of ``inputs['dialogs']`` is converted to an instance with
    ``self._dataset_reader.text_to_instance`` and passed to
    ``self.predict_instance``. Every field of the resulting prediction
    except ``'loss'`` is reduced with ``np.argmax(..., axis=1)`` and
    converted to a plain list.

    Only the first 30 dialogs are processed; the rest are ignored to
    save time.

    Returns:
        A list with one dict per processed dialog. Each dict holds the
        dialog's ``'dialog_id'`` followed by one list of argmax indices
        per retained prediction field.

    NOTE(review): the return annotation is ``JsonDict`` but the value
    returned is a list -- confirm what callers expect.
    """
    max_dialogs = 30  # cap on how many dialogs are predicted per call
    results = []
    for position, dialog in enumerate(inputs['dialogs']):
        if position == max_dialogs:
            # Early termination to save time
            break
        model_output = self.predict_instance(
            self._dataset_reader.text_to_instance(dialog)
        )
        # Start from the dialog id, then add the label predictions.
        labels = {'dialog_id': dialog['dialog_id']}
        labels.update(
            (field, np.argmax(scores, axis=1).tolist())
            for field, scores in model_output.items()
            if field != 'loss'  # the loss is not a label prediction
        )
        results.append(labels)
    return results