in archs/models.py [0:0]
def __init__(self, dset, model):
    """Precompute index pairs, pair masks and the scoring function.

    Args:
        dset: dataset object. This method reads ``attr2idx``, ``obj2idx``,
            ``pairs``, ``train_pairs``, ``objs`` and ``phase`` from it, plus
            ``val_pairs`` or ``test_pairs`` depending on the phase.
        model: the model under evaluation; only its class name is inspected.

    Attributes set:
        dset: the dataset passed in.
        pairs: ``torch.LongTensor`` of (attr_idx, obj_idx) rows, one per
            entry of ``dset.pairs``.
        train_pairs: list of (attr_idx, obj_idx) tuples for
            ``dset.train_pairs``.
        test_pairs: list of (attr_idx, obj_idx) tuples for the pairs used
            in the current phase (always includes the train pairs).
        closed_mask: ``torch.ByteTensor`` over ``dset.pairs``; 1 where the
            pair belongs to the current phase's pair set.
        seen_mask: ``torch.ByteTensor`` over ``dset.pairs``; 1 where the
            pair is a train pair.
        oracle_obj_mask: stacked ``torch.ByteTensor`` with one row per
            entry of ``dset.objs``; row ``i`` is 1 at every pair whose
            object equals ``dset.objs[i]``.
        score_model: ``self.score_clf_model`` when the model's class name
            contains ``'VisualProduct'``, else ``self.score_manifold_model``.
    """
    self.dset = dset

    def to_index_pairs(text_pairs):
        # ('sliced', 'apple') -> (attr_idx, obj_idx)
        return [(dset.attr2idx[a], dset.obj2idx[o]) for a, o in text_pairs]

    # [('sliced', 'apple'), ('ripe', 'apple'), ...]
    #   --> torch.LongTensor([[0, 1], [1, 1], ...])
    all_index_pairs = to_index_pairs(dset.pairs)
    self.train_pairs = to_index_pairs(dset.train_pairs)
    self.pairs = torch.LongTensor(all_index_pairs)

    # Pick the closed-world pair set for the current phase. Train pairs are
    # always part of it; val/test phases add their own pairs on top.
    if dset.phase == 'train':
        print('Evaluating with train pairs')
        closed_world = set(dset.train_pairs)
    elif dset.phase == 'val':
        print('Evaluating with val pairs')
        closed_world = set(dset.val_pairs + dset.train_pairs)
    else:
        print('Evaluating with test pairs')
        closed_world = set(dset.test_pairs + dset.train_pairs)
    self.test_pairs = to_index_pairs(closed_world)

    # 0/1 masks aligned with dset.pairs.
    self.closed_mask = torch.ByteTensor(
        [int(pair in closed_world) for pair in dset.pairs]
    )
    seen = set(dset.train_pairs)
    self.seen_mask = torch.ByteTensor(
        [int(pair in seen) for pair in dset.pairs]
    )

    # Object-oracle setting: for each object, mark the pairs that contain it.
    per_object_masks = [
        torch.ByteTensor([int(target == obj) for _, obj in dset.pairs])
        for target in dset.objs
    ]
    self.oracle_obj_mask = torch.stack(per_object_masks, 0)

    # Models whose class name contains 'VisualProduct' are scored as
    # classifiers; everything else goes through the manifold scorer.
    class_name = model.__class__.__name__
    self.score_model = (
        self.score_clf_model
        if 'VisualProduct' in class_name
        else self.score_manifold_model
    )