def create_model()

in self_supervision_benchmark/modeling/jigsaw/alexnet_jigsaw_finetune_full.py [0:0]


def create_model(model, data, labels, split):
    num_classes = cfg.MODEL.NUM_CLASSES
    test_mode = False
    if split in ['test', 'val']:
        test_mode = True

    if cfg.MODEL.FC_INIT_TYPE == 'gaussian':
        fc_weight_init = ('GaussianFill', {'std': 0.01})
        fc_bias_init = ('ConstantFill', {'value': 0.})
    elif cfg.MODEL.FC_INIT_TYPE == 'xavier':
        fc_weight_init = ('XavierFill', {})
        fc_bias_init = ('ConstantFill', {'value': 0.})
    elif cfg.MODEL.FC_INIT_TYPE == 'msra':
        fc_weight_init = ('MSRAFill', {})
        fc_bias_init = ('ConstantFill', {'value': 0.})

    ################################ conv1 ####################################
    conv1 = model.Conv(
        'data', 'conv1_s0', 3, 96, 11, stride=4,
        weight_init=('GaussianFill', {'std': 0.01}),
        bias_init=('ConstantFill', {'value': 0.0}),
    )
    relu1 = model.Relu(conv1, conv1)
    pool1 = model.MaxPool(relu1, 'pool1_s0', kernel=3, stride=2)
    lrn1 = model.LRN(pool1, 'norm1_s0', size=5, alpha=0.0001, beta=0.75)

    ################################ conv2 ####################################
    conv2 = model.Conv(
        lrn1, 'conv2_s0', 96, 256, 5,
        weight_init=('GaussianFill', {'std': 0.01}),
        bias_init=('ConstantFill', {'value': 1.0}), pad=2, group=2,
    )
    relu2 = model.Relu(conv2, conv2)
    pool2 = model.MaxPool(relu2, 'pool2_s0', kernel=3, stride=2)
    lrn2 = model.LRN(pool2, 'norm2_s0', size=5, alpha=0.0001, beta=0.75)

    ################################ conv3 ####################################
    conv3 = model.Conv(
        lrn2, 'conv3_s0', 256, 384, 3,
        weight_init=('GaussianFill', {'std': 0.01}),
        bias_init=('ConstantFill', {'value': 0.0}), pad=1,
    )
    bn3 = model.SpatialBN(
        conv3, 'conv3_s0_bn', 384, epsilon=cfg.MODEL.BN_EPSILON,
        momentum=cfg.MODEL.BN_MOMENTUM, is_test=test_mode,
    )
    if cfg.MODEL.BN_NO_SCALE_SHIFT:
        model.param_init_net.ConstantFill([bn3 + '_s'], bn3 + '_s', value=1.0)
        model.param_init_net.ConstantFill([bn3 + '_b'], bn3 + '_b', value=0.0)
    relu3 = model.Relu(bn3, bn3 + '_relu')

    ################################ conv4 ####################################
    conv4 = model.Conv(
        relu3, 'conv4_s0', 384, 384, 3,
        weight_init=('GaussianFill', {'std': 0.01}),
        bias_init=('ConstantFill', {'value': 1.0}), pad=1, group=2,
    )
    bn4 = model.SpatialBN(
        conv4, 'conv4_s0_bn', 384, epsilon=cfg.MODEL.BN_EPSILON,
        momentum=cfg.MODEL.BN_MOMENTUM, is_test=test_mode,
    )
    if cfg.MODEL.BN_NO_SCALE_SHIFT:
        model.param_init_net.ConstantFill([bn4 + '_s'], bn4 + '_s', value=1.0)
        model.param_init_net.ConstantFill([bn4 + '_b'], bn4 + '_b', value=0.0)
    relu4 = model.Relu(bn4, bn4 + '_relu')

    ################################ conv5 ####################################
    conv5 = model.Conv(
        relu4, 'conv5_s0', 384, 256, 3,
        weight_init=('GaussianFill', {'std': 0.01}),
        bias_init=('ConstantFill', {}), pad=1, group=2,
    )
    bn5 = model.SpatialBN(
        conv5, 'conv5_s0_bn', 256, epsilon=cfg.MODEL.BN_EPSILON,
        momentum=cfg.MODEL.BN_MOMENTUM, is_test=test_mode,
    )
    if cfg.MODEL.BN_NO_SCALE_SHIFT:
        model.param_init_net.ConstantFill([bn5 + '_s'], bn5 + '_s', value=1.0)
        model.param_init_net.ConstantFill([bn5 + '_b'], bn5 + '_b', value=0.0)
    relu5 = model.Relu(bn5, bn5 + '_relu')
    pool5 = model.MaxPool(relu5, 'pool5_s0', kernel=3, stride=2)

    ################################## fc6 #####################################
    # in Jigsaw pretext, the fc6-8 are different so can't initialize from those
    fc6 = model.FC(
        pool5, 'fc6', 256 * 6 * 6, 4096,
        weight_init=fc_weight_init, bias_init=fc_bias_init,
    )
    fc6_reshape, _ = model.net.Reshape(
        fc6, [fc6 + '_reshape', fc6 + '_old_shape'], shape=(-1, 4096, 1, 1)
    )
    fc6_bn = model.SpatialBN(
        fc6_reshape, fc6_reshape + '_bn', 4096, epsilon=cfg.MODEL.BN_EPSILON,
        momentum=cfg.MODEL.BN_MOMENTUM, is_test=test_mode,
    )
    if cfg.MODEL.BN_NO_SCALE_SHIFT:
        model.param_init_net.ConstantFill(
            [fc6_bn + '_s'], fc6_bn + '_s', value=1.0
        )
        model.param_init_net.ConstantFill(
            [fc6_bn + '_b'], fc6_bn + '_b', value=0.0
        )
    blob_out = model.Relu(fc6_bn, fc6_bn + '_relu')
    if split not in ['test', 'val']:
        blob_out = model.Dropout(
            blob_out, blob_out + '_dropout', ratio=0.5, is_test=test_mode
        )

    ################################## fc7 #####################################
    # for fc7 and fc8, we still name blobs with suffix 's0' so that we can
    # distinguish it with the jigsaw model otherwise the blob names will clash
    # when using checkpointed model from jigsaw and fc7/8 will have dimension
    # mismatch
    fc7 = model.FC(
        blob_out, 'fc7_s0', 4096, 4096,
        weight_init=fc_weight_init, bias_init=fc_bias_init,
    )

    fc7_reshape, _ = model.net.Reshape(
        fc7, [fc7 + '_reshape', fc7 + '_old_shape'], shape=(-1, 4096, 1, 1)
    )
    fc7_bn = model.SpatialBN(
        fc7_reshape, fc7_reshape + '_bn', 4096, epsilon=cfg.MODEL.BN_EPSILON,
        momentum=cfg.MODEL.BN_MOMENTUM, is_test=test_mode,
    )
    if cfg.MODEL.BN_NO_SCALE_SHIFT:
        model.param_init_net.ConstantFill(
            [fc7_bn + '_s'], fc7_bn + '_s', value=1.0
        )
        model.param_init_net.ConstantFill(
            [fc7_bn + '_b'], fc7_bn + '_b', value=0.0
        )
    blob_out = model.Relu(fc7_bn, fc7_bn + '_relu')
    if split not in ['test', 'val']:
        blob_out = model.Dropout(
            blob_out, blob_out + '_dropout', ratio=0.5, is_test=test_mode
        )

    ################################## fc8 #####################################
    fc8 = model.FC(
        blob_out, 'fc8_s0', 4096, num_classes,
        weight_init=fc_weight_init,
        bias_init=fc_bias_init,
    )

    ################################## Sigmoid #################################
    # Sigmoid loss since VOC07 iss multi-label
    model.net.Alias(fc8, 'pred')
    sigmoid = model.net.Sigmoid('pred', 'sigmoid')
    scale = 1. / cfg.NUM_DEVICES
    if split == 'train':
        loss = model.net.SigmoidCrossEntropyLoss(
            ['pred', labels], 'loss', scale=scale
        )
    elif split in ['test', 'val']:
        loss = None
    return model, sigmoid, loss