def get_backbone()

in experiments/sgd/detector/train_detector.py [0:0]


def get_backbone(args):
    """Build a classification network and adapt it to serve as a detection backbone.

    The architecture is resolved from ``args.arch``. If ``genotypes`` defines a name
    equal to ``args.arch``, a ``Network`` is built from that genotype using
    ``args.init_channels`` and ``args.layers``. Otherwise the single matching
    architecture is loaded from the DeepNets-1M split ``args.split`` in
    ``args.data_dir``. Genotypes given as strings are instantiated from
    ``torchvision.models``. Weights are loaded through ``pretrained_model`` when
    ``args.ckpt`` is set, and always for torchvision ResNets.

    The global pooling and the classifier are then replaced by ``nn.Identity``.
    Forward hooks are added so that ``model(x)`` returns the last feature map
    instead of logits.

    Args:
        args: parsed options; reads ``arch``, ``init_channels``, ``layers``,
            ``split``, ``data_dir``, ``pretrained`` and ``ckpt``.

    Returns:
        The backbone ``nn.Module``. Its ``out_channels`` attribute holds the number
        of channels of the feature map it returns.
    """
    net_args = _backbone_net_args(args)
    is_torchvision = isinstance(net_args['genotype'], str)

    num_classes = 1000  # classifier size used at build/load time; the head is discarded below
    if is_torchvision:
        model_fn = getattr(torchvision.models, net_args['genotype'])
        model = model_fn(pretrained=bool(args.pretrained))
        model.out_channels = model.fc.in_features
    else:
        model = Network(num_classes=num_classes,
                        is_imagenet_input=True,
                        is_vit=False,
                        **net_args)
        # NOTE(review): assumes the final cell concatenates len(normal_concat) tensors of
        # C * C_mult**2 channels each -- confirm against Network's channel schedule.
        model.out_channels = (net_args['C'] * len(net_args['genotype'].normal_concat)
                              * (net_args['C_mult'] ** 2))

    if args.ckpt is not None or isinstance(model, torchvision.models.ResNet):
        model = pretrained_model(model, args.ckpt, num_classes, 1, GHN)

    if is_torchvision:
        return _as_feature_extractor(model, pool_name='avgpool', head_name='fc')
    return _as_feature_extractor(model, pool_name='global_pooling', head_name='classifier')


def _backbone_net_args(args):
    """Return the keyword arguments describing the architecture selected by ``args.arch``."""
    # getattr instead of eval: same lookup for plain names, no code execution from CLI input.
    genotype = getattr(genotypes, str(args.arch), None)
    if genotype is not None:
        return {'C': args.init_channels,  # 48 if genotype == DARTS else 128
                'genotype': genotype,
                'n_cells': args.layers,   # 14 if genotype == DARTS else 12
                'C_mult': int(genotype != ViT) + 1,  # assume either ViT or DARTS-style architecture
                'preproc': genotype != ViT,
                'stem_type': 1}  # assume that the ImageNet-style stem is used by default

    # Not a named genotype: look the architecture up in DeepNets-1M.
    deepnets = DeepNets1M(split=args.split,
                          nets_dir=args.data_dir,
                          large_images=True,
                          arch=args.arch)
    assert len(deepnets) == 1, 'one architecture must be chosen to train'
    net_args = dict(deepnets[0].net_args)  # copy so the dataset's graph object is not mutated
    if net_args.get('norm') == 'bn':
        # presumably BN that tracks running statistics -- confirm in Network
        net_args['norm'] = 'bn-track'
    if net_args['genotype'] == ViT:
        net_args['stem_type'] = 1  # using ImageNet style stem even for ViT
    return net_args


def _as_feature_extractor(model, pool_name, head_name):
    """Make ``model(x)`` return its last feature map instead of class logits.

    Replaces ``model.<pool_name>`` and ``model.<head_name>`` with ``nn.Identity``.
    Forward hooks are then registered so the backbone's own ``forward`` needs no changes:

    * the pooling layer records the shape of the tensor it receives;
    * the head reshapes its (flattened) input back to that shape;
    * a tuple returned by ``model`` is reduced to its first element.

    Hooks are installed only on these three modules. Inner layers keep their normal
    outputs, and inner layers called with keyword arguments are unaffected.
    """
    pool, head = nn.Identity(), nn.Identity()
    setattr(model, pool_name, pool)
    setattr(model, head_name, head)

    def record_input_size(module, inputs, output):
        module.input_sz = inputs[0].shape

    def restore_feature_map(module, inputs, output):
        # `pool` is captured by closure rather than stored as an attribute of `head`:
        # assigning a Module attribute would register it as a duplicate submodule.
        if hasattr(pool, 'input_sz'):
            return output.reshape(pool.input_sz)
        return output

    def first_output(module, inputs, output):
        return output[0] if isinstance(output, tuple) else output

    pool.register_forward_hook(record_input_size)
    head.register_forward_hook(restore_feature_map)
    model.register_forward_hook(first_output)
    return model