# get_model() — from src/model.py

def get_model(args, vocab_size):
    """Build the full set-prediction model from parsed command-line args.

    Args:
        args: namespace of hyperparameters. Fields read here: embed_size,
            dropout_encoder, image_model, dataset, decoder, dropout_decoder,
            pred_cardinality, maxnumlabels, ff_layers, tf_layers, n_att,
            label_loss, perminv, th, replacement, U.
        vocab_size: size of the label vocabulary; the last index
            (vocab_size - 1) is reserved as the padding value.

    Returns:
        A SetPred model wrapping the image encoder, the selected decoder
        and the label / eos / cardinality loss criteria.

    Raises:
        ValueError: if args.decoder is not one of 'ff', 'lstm', 'tf'.
    """
    # build image encoder
    encoder_image = EncoderCNN(args.embed_size, args.dropout_encoder,
                               args.image_model)

    # Datasets whose samples may carry an empty label set.
    use_empty_set = args.dataset in ['coco', 'nuswide']

    # build set predictor
    if args.decoder == 'ff':
        print(
            'Building feed-forward decoder. Embed size {} / Dropout {} / '
            'Cardinality Prediction {} / Max. Num. Labels {} / Num. Layers {}'.format(
                args.embed_size, args.dropout_decoder, args.pred_cardinality, args.maxnumlabels,
                args.ff_layers),
            flush=True)

        decoder = FFDecoder(
            args.embed_size,
            vocab_size,
            args.embed_size,
            dropout=args.dropout_decoder,
            pred_cardinality=args.pred_cardinality,
            nobjects=args.maxnumlabels,
            n_layers=args.ff_layers,
            use_empty_set=use_empty_set)

    elif args.decoder == 'lstm':
        print(
            'Building LSTM decoder. Embed size {} / Dropout {} / Max. Num. Labels {}. '.format(
                args.embed_size, args.dropout_decoder, args.maxnumlabels),
            flush=True)

        decoder = DecoderRNN(
            args.embed_size,
            args.embed_size,
            vocab_size,
            dropout=args.dropout_decoder,
            seq_length=args.maxnumlabels,
            num_instrs=1)

    elif args.decoder == 'tf':
        print(
            'Building Transformer decoder. Embed size {} / Dropout {} / Max. Num. Labels {} / '
            'Num. Attention Heads {} / Num. Layers {}.'.format(
                args.embed_size, args.dropout_decoder, args.maxnumlabels, args.n_att,
                args.tf_layers),
            flush=True)

        decoder = DecoderTransformer(
            args.embed_size,
            vocab_size,
            dropout=args.dropout_decoder,
            seq_length=args.maxnumlabels,
            num_instrs=1,
            attention_nheads=args.n_att,
            pos_embeddings=False,
            num_layers=args.tf_layers,
            learned=False,
            normalize_before=True)

    else:
        # Fail fast with a clear message instead of a NameError on
        # `decoder` at the SetPred(...) call below.
        raise ValueError(
            "Unknown decoder '{}'; expected 'ff', 'lstm' or 'tf'.".format(args.decoder))

    # Label and eos losses. For the 'ff' decoder BCE is applied on raw
    # logits (BCEWithLogitsLoss); the other decoders get plain BCELoss —
    # presumably because they emit probabilities; confirm against decoders.
    label_losses = {
        'bce': nn.BCEWithLogitsLoss(reduction='mean') if args.decoder == 'ff' else nn.BCELoss(reduction='mean'),
        'iou': softIoULoss(reduction='mean'),
        'td': targetDistLoss(reduction='mean'),
    }
    # Last vocabulary index doubles as the padding token.
    pad_value = vocab_size - 1
    print('Using {} loss.'.format(args.label_loss), flush=True)
    if args.decoder == 'ff':
        # Feed-forward decoder: set loss only, no end-of-sequence head.
        label_loss = label_losses[args.label_loss]
        eos_loss = None
    elif args.decoder in ['tf', 'lstm'] and args.perminv:
        # Permutation-invariant training: set loss plus a BCE eos head.
        label_loss = label_losses[args.label_loss]
        eos_loss = nn.BCELoss(reduction='mean')
    else:
        # Ordered sequence prediction: per-step cross-entropy, padding ignored.
        label_loss = nn.CrossEntropyLoss(ignore_index=pad_value, reduction='mean')
        eos_loss = None

    # cardinality loss
    if args.pred_cardinality == 'dc':
        print('Using Dirichlet-Categorical cardinality loss.', flush=True)
        cardinality_loss = DCLoss(U=args.U, dataset=args.dataset, reduction='mean')
    elif args.pred_cardinality == 'cat':
        print('Using categorical cardinality loss.', flush=True)
        cardinality_loss = nn.CrossEntropyLoss(reduction='mean')
    else:
        print('Using no cardinality loss.', flush=True)
        cardinality_loss = None

    model = SetPred(
        decoder,
        encoder_image,
        args.maxnumlabels,
        crit=label_loss,
        crit_eos=eos_loss,
        crit_cardinality=cardinality_loss,
        pad_value=pad_value,
        perminv=args.perminv,
        decoder_ff=(args.decoder == 'ff'),
        th=args.th,
        loss_label=args.label_loss,
        replacement=args.replacement,
        card_type=args.pred_cardinality,
        dataset=args.dataset,
        U=args.U,
        use_empty_set=use_empty_set)

    return model