in archs/models.py [0:0]
def __init__(self,
             dset,
             args,
             num_layers=2,
             num_modules_per_layer=3,
             stoch_sample=False,
             use_full_model=False,
             num_classes=[2],  # NOTE(review): mutable default; unused in this body, kept for interface compat
             gater_type='general'):
    """Build a gated modular composition model.

    Creates attribute/object embedders (initialized from GloVe vectors,
    pretrained SVM classifier weights, or one-hot vectors depending on
    ``args``) and the modular composition + gating networks.

    :param dset: dataset providing ``attrs``/``objs`` (and ``predicates``),
        ``attr2idx``/``obj2idx`` lookups and ``feat_dim``.
    :param args: experiment config; fields read here: ``compose_type``,
        ``glove_init``, ``clf_init``, ``embed_rank``, ``emb_dim``,
        ``data_dir``, ``static_inp``.
    :param num_layers: number of layers in the modular network.
    :param num_modules_per_layer: modules per layer.
    :param stoch_sample: whether modules are sampled stochastically.
    :param use_full_model: bypass gating and use every module.
    :param num_classes: unused here; kept for interface compatibility.
    :param gater_type: gating-network variant passed to ``modular_general``.
    :raises NotImplementedError: if ``args.compose_type`` is not ``'nn'``.
    """
    CompositionalModel.__init__(self, dset, args)
    self.train_forward = self.train_forward_softmax
    self.compose_type = args.compose_type

    glove_path = 'data/glove/glove.6B.300d.txt'

    # Embedding width depends on the chosen initialization scheme:
    # GloVe vectors are 300-d, SVM weight vectors are 512-d.
    gating_in_dim = 128
    if args.glove_init:
        gating_in_dim = 300
    elif args.clf_init:
        gating_in_dim = 512

    if self.compose_type == 'nn':
        tdim = gating_in_dim * 2
        inter_tdim = self.args.embed_rank
        # Change this to allow only obj, only attr gatings
        # The extra row (index len(...)) is the padding entry.
        self.attr_embedder = nn.Embedding(
            len(dset.attrs) + 1,
            gating_in_dim,
            padding_idx=len(dset.attrs),
        )
        self.obj_embedder = nn.Embedding(
            len(dset.objs) + 1,
            gating_in_dim,
            padding_idx=len(dset.objs),
        )
        # Initialize the embedder weights with GloVe vectors or with
        # pretrained SVM classifier weights.
        if args.glove_init:
            # Write through .data so the copies are not tracked by
            # autograd; the final (padding) row stays untouched.
            pretrained_weight = utils.load_word_embeddings(
                glove_path, dset.attrs)
            self.attr_embedder.weight.data[:-1, :].copy_(pretrained_weight)
            pretrained_weight = utils.load_word_embeddings(
                glove_path, dset.objs)
            self.obj_embedder.weight.data[:-1, :].copy_(pretrained_weight)
        elif args.clf_init:
            # One SVM weight vector per attribute / object, loaded from
            # '<data_dir>/svm/{attr,obj}_<id>'.
            for idx, attr in enumerate(dset.attrs):
                at_id = self.dset.attr2idx[attr]
                weight = torch.load(
                    '%s/svm/attr_%d' % (args.data_dir,
                                        at_id)).coef_.squeeze()
                self.attr_embedder.weight.data[idx].copy_(
                    torch.from_numpy(weight))
            for idx, obj in enumerate(dset.objs):
                obj_id = self.dset.obj2idx[obj]
                weight = torch.load(
                    '%s/svm/obj_%d' % (args.data_dir,
                                       obj_id)).coef_.squeeze()
                self.obj_embedder.weight.data[idx].copy_(
                    torch.from_numpy(weight))
        else:
            # No pretrained init: one-hot attribute embeddings plus
            # GloVe-initialized object embeddings.
            n_attr = len(dset.predicates)
            gating_in_dim = 300
            tdim = gating_in_dim * 2 + n_attr
            self.attr_embedder = nn.Embedding(
                n_attr,
                n_attr,
            )
            self.attr_embedder.weight.data.copy_(
                torch.from_numpy(np.eye(n_attr)))
            self.obj_embedder = nn.Embedding(
                len(dset.objs) + 1,
                gating_in_dim,
                padding_idx=len(dset.objs),
            )
            pretrained_weight = utils.load_word_embeddings(
                glove_path, dset.objs)
            self.obj_embedder.weight.data[:-1, :].copy_(pretrained_weight)
    else:
        raise NotImplementedError(
            'compose_type %r is not supported' % self.compose_type)

    self.comp_network, self.gating_network, self.nummods, _ = modular_general(
        num_layers=num_layers,
        num_modules_per_layer=num_modules_per_layer,
        feat_dim=dset.feat_dim,
        inter_dim=args.emb_dim,
        stoch_sample=stoch_sample,
        use_full_model=use_full_model,
        tdim=tdim,
        inter_tdim=inter_tdim,
        gater_type=gater_type,
    )

    # Optionally freeze the input embedders (no gradient updates).
    if args.static_inp:
        for param in self.attr_embedder.parameters():
            param.requires_grad = False
        for param in self.obj_embedder.parameters():
            param.requires_grad = False