in notebooks/utils.py [0:0]
def _read_rulstm_ids(dataset, fname):
    """Read a header-less, single-column RULSTM id file as a pandas Series.

    Replaces the RULSTM reference idiom ``pd.read_csv(..., names=['id'],
    squeeze=True)``: the ``squeeze`` keyword was removed in pandas 2.0, so the
    column is selected explicitly instead. This always yields a Series (even
    for a one-row file), which matters because ``Series.isin`` on a DataFrame
    would iterate column names rather than the ids.
    """
    return pd.read_csv(osp.join(dataset.rulstm_annotation_dir, fname),
                       names=['id'])['id']


def epic100_unseen_tail_eval(probs, dataset):
    """Evaluate predictions on the EPIC-Kitchens-100 "unseen" and "tail" subsets.

    Args:
        probs: contains 3 elements: predictions for verb, noun and action, in
            that order. Each is indexed with a boolean mask built over the rows
            of ``dataset.df``, so row i is assumed to correspond to row i of
            ``dataset.df`` (TODO confirm with callers).
        dataset: object exposing ``rulstm_annotation_dir`` (directory holding
            the RULSTM ``validation_*_ids.csv`` files) and ``df`` (a DataFrame
            with ``narration_id``, ``verb_class``, ``noun_class`` and
            ``action_class`` columns).

    Returns:
        dict with keys ``vrec5_tail``, ``nrec5_tail``, ``arec5_tail``,
        ``vrec5_unseen``, ``nrec5_unseen`` and ``arec5_unseen``. Each value is
        the third element returned by ``compute_accuracy`` on that subset
        (named rec5 here; presumably recall@5 — verify against
        ``compute_accuracy``).
    """
    # Id files: based on https://github.com/fpv-iplab/rulstm/blob/d44612e4c351ff668f149e2f9bc870f1e000f113/RULSTM/main.py#L379
    # Masks: based on https://github.com/fpv-iplab/rulstm/blob/d44612e4c351ff668f149e2f9bc870f1e000f113/RULSTM/main.py#L495
    # NOTE(review): assumes dataset.df.narration_id holds the same kind of ids
    # as the RULSTM files; if they differ, isin() matches nothing — confirm.
    narration_ids = dataset.df.narration_id

    # A single "unseen" mask is shared by all three tasks.
    unseen_mask = narration_ids.isin(
        _read_rulstm_ids(dataset,
                         'validation_unseen_participants_ids.csv')).values
    # The "tail" subset has its own id file per task (verb, noun, action).
    tail_masks = [
        narration_ids.isin(_read_rulstm_ids(dataset, fname)).values
        for fname in ('validation_tail_verbs_ids.csv',
                      'validation_tail_nouns_ids.csv',
                      'validation_tail_actions_ids.csv')
    ]

    # (result-key prefix, label column) per task; task index i matches probs[i].
    tasks = (('v', 'verb_class'), ('n', 'noun_class'), ('a', 'action_class'))
    results = {}
    # Outer loop over splits, inner over tasks: same call order and key order
    # as computing tail v/n/a first, then unseen v/n/a.
    for split, masks in (('tail', tail_masks), ('unseen', [unseen_mask] * 3)):
        for task_idx, ((prefix, label_col), mask) in enumerate(
                zip(tasks, masks)):
            _, _, rec5, _, _ = compute_accuracy(
                probs[task_idx][mask], dataset.df[label_col].values[mask])
            results[f'{prefix}rec5_{split}'] = rec5
    return results