in src/model.py [0:0]
def get_model(args, vocab_size):
    """Build the full set-prediction model from command-line options.

    Constructs the image encoder, the requested decoder ('ff', 'lstm' or
    'tf'), the label/eos/cardinality losses, and wraps everything in a
    ``SetPred`` module.

    Args:
        args: parsed argparse namespace; fields read here include
            ``embed_size``, ``dropout_encoder``, ``dropout_decoder``,
            ``image_model``, ``dataset``, ``decoder``, ``maxnumlabels``,
            ``pred_cardinality``, ``ff_layers``, ``tf_layers``, ``n_att``,
            ``label_loss``, ``perminv``, ``th``, ``replacement`` and ``U``.
        vocab_size: size of the label vocabulary; the last index
            (``vocab_size - 1``) is treated as the padding value.

    Returns:
        A ``SetPred`` instance combining encoder, decoder and losses.

    Raises:
        ValueError: if ``args.decoder`` is not one of 'ff', 'lstm', 'tf'.
    """
    # build image encoder
    encoder_image = EncoderCNN(args.embed_size, args.dropout_encoder,
                               args.image_model)
    # datasets whose images may legitimately have an empty label set
    use_empty_set = args.dataset in ('coco', 'nuswide')

    # build set predictor
    if args.decoder == 'ff':
        print(
            'Building feed-forward decoder. Embed size {} / Dropout {} / '
            'Cardinality Prediction {} / Max. Num. Labels {} / Num. Layers {}'.format(
                args.embed_size, args.dropout_decoder, args.pred_cardinality, args.maxnumlabels,
                args.ff_layers),
            flush=True)
        decoder = FFDecoder(
            args.embed_size,
            vocab_size,
            args.embed_size,
            dropout=args.dropout_decoder,
            pred_cardinality=args.pred_cardinality,
            nobjects=args.maxnumlabels,
            n_layers=args.ff_layers,
            use_empty_set=use_empty_set)
    elif args.decoder == 'lstm':
        print(
            'Building LSTM decoder. Embed size {} / Dropout {} / Max. Num. Labels {}. '.format(
                args.embed_size, args.dropout_decoder, args.maxnumlabels),
            flush=True)
        decoder = DecoderRNN(
            args.embed_size,
            args.embed_size,
            vocab_size,
            dropout=args.dropout_decoder,
            seq_length=args.maxnumlabels,
            num_instrs=1)
    elif args.decoder == 'tf':
        print(
            'Building Transformer decoder. Embed size {} / Dropout {} / Max. Num. Labels {} / '
            'Num. Attention Heads {} / Num. Layers {}.'.format(
                args.embed_size, args.dropout_decoder, args.maxnumlabels, args.n_att,
                args.tf_layers),
            flush=True)
        decoder = DecoderTransformer(
            args.embed_size,
            vocab_size,
            dropout=args.dropout_decoder,
            seq_length=args.maxnumlabels,
            num_instrs=1,
            attention_nheads=args.n_att,
            pos_embeddings=False,
            num_layers=args.tf_layers,
            learned=False,
            normalize_before=True)
    else:
        # Fail fast with a clear message: the original code would only
        # surface an unknown decoder as a NameError much later.
        raise ValueError(
            "Unknown decoder '{}'; expected 'ff', 'lstm' or 'tf'.".format(args.decoder))

    # label and eos loss
    # NOTE: the 'ff' decoder emits raw logits, hence BCEWithLogitsLoss;
    # the autoregressive decoders are assumed to emit probabilities.
    label_losses = {
        'bce': nn.BCEWithLogitsLoss(reduction='mean') if args.decoder == 'ff' else nn.BCELoss(reduction='mean'),
        'iou': softIoULoss(reduction='mean'),
        'td': targetDistLoss(reduction='mean'),
    }
    # last vocabulary index is reserved as padding
    pad_value = vocab_size - 1

    print('Using {} loss.'.format(args.label_loss), flush=True)
    if args.decoder == 'ff':
        label_loss = label_losses[args.label_loss]
        eos_loss = None
    elif args.decoder in ['tf', 'lstm'] and args.perminv:
        # permutation-invariant training: set loss on labels + separate
        # binary loss for end-of-sequence prediction
        label_loss = label_losses[args.label_loss]
        eos_loss = nn.BCELoss(reduction='mean')
    else:
        # plain autoregressive training: token-level cross-entropy,
        # padding positions ignored
        label_loss = nn.CrossEntropyLoss(ignore_index=pad_value, reduction='mean')
        eos_loss = None

    # cardinality loss
    if args.pred_cardinality == 'dc':
        print('Using Dirichlet-Categorical cardinality loss.', flush=True)
        cardinality_loss = DCLoss(U=args.U, dataset=args.dataset, reduction='mean')
    elif args.pred_cardinality == 'cat':
        print('Using categorical cardinality loss.', flush=True)
        cardinality_loss = nn.CrossEntropyLoss(reduction='mean')
    else:
        print('Using no cardinality loss.', flush=True)
        cardinality_loss = None

    model = SetPred(
        decoder,
        encoder_image,
        args.maxnumlabels,
        crit=label_loss,
        crit_eos=eos_loss,
        crit_cardinality=cardinality_loss,
        pad_value=pad_value,
        perminv=args.perminv,
        decoder_ff=args.decoder == 'ff',
        th=args.th,
        loss_label=args.label_loss,
        replacement=args.replacement,
        card_type=args.pred_cardinality,
        dataset=args.dataset,
        U=args.U,
        use_empty_set=use_empty_set)
    return model