in experiments/sgd/detector/train_detector.py [0:0]
def get_backbone(args):
    """Build a network and adapt it for use as a detection backbone.

    The architecture is resolved in one of two ways:

    1. ``args.arch`` names an attribute of the ``genotypes`` module. The
       network arguments are then built from ``args.init_channels`` and
       ``args.layers``, assuming either a ViT or a DARTS-style architecture.
    2. Otherwise ``args.arch`` is forwarded to ``DeepNets1M`` (together with
       ``args.split`` and ``args.data_dir``). That dataset must select
       exactly one architecture, and its ``net_args`` are used.

    If the resolved genotype is a string, it is treated as the name of a
    ``torchvision.models`` constructor. Otherwise a ``Network`` with 1000
    classes is built. The classifier and global pooling are then replaced by
    ``nn.Identity`` and forward hooks are attached to every module. The hook
    on the classifier reshapes its output back to the input shape of the
    (now identity) pooling layer, presumably so the detector receives a
    spatial feature map without the backbone's code being modified.

    Args:
        args: parsed command-line namespace. Attributes read here: ``arch``,
            ``init_channels``, ``layers``, ``split``, ``data_dir``,
            ``pretrained``, ``ckpt``.

    Returns:
        The modified model, with an ``out_channels`` attribute:
        ``fc.in_features`` for torchvision models, or
        ``C * len(genotype.normal_concat) * C_mult ** 2`` for ``Network``.

    Raises:
        ValueError: if ``DeepNets1M`` does not select exactly one architecture.
    """
    try:
        # getattr instead of eval: args.arch comes from the command line.
        # TypeError covers non-string values (e.g. an integer index).
        genotype = getattr(genotypes, args.arch)
    except (AttributeError, TypeError):
        # Not a named genotype: look the architecture up in DeepNets-1M.
        deepnets = DeepNets1M(split=args.split,
                              nets_dir=args.data_dir,
                              large_images=True,
                              arch=args.arch)
        # Explicit check rather than assert, which is stripped under -O.
        if len(deepnets) != 1:
            raise ValueError('one architecture must be chosen to train')
        net_args = deepnets[0].net_args
        if 'norm' in net_args and net_args['norm'] == 'bn':
            net_args['norm'] = 'bn-track'
        if net_args['genotype'] == ViT:
            net_args['stem_type'] = 1  # using ImageNet style stem even for ViT
    else:
        not_vit = genotype != ViT
        net_args = {
            'C': args.init_channels,  # 48 if genotype == DARTS else 128
            'genotype': genotype,
            'n_cells': args.layers,  # 14 if genotype == DARTS else 12
            'C_mult': int(not_vit) + 1,  # assume either ViT or DARTS-style architecture
            'preproc': not_vit,
            'stem_type': 1,  # assume that the ImageNet-style stem is used by default
        }

    num_classes = 1000
    is_torchvision = isinstance(net_args['genotype'], str)
    if is_torchvision:
        # The genotype string names a torchvision.models constructor.
        constructor = getattr(torchvision.models, net_args['genotype'])
        model = constructor(pretrained=bool(args.pretrained))
        model.out_channels = model.fc.in_features
    else:
        model = Network(num_classes=num_classes,
                        is_imagenet_input=True,
                        is_vit=False,
                        **net_args)
        model.out_channels = (net_args['C']
                              * len(net_args['genotype'].normal_concat)
                              * net_args['C_mult'] ** 2)

    if args.ckpt is not None or isinstance(model, torchvision.models.ResNet):
        model = pretrained_model(model, args.ckpt, num_classes, 1, GHN)

    # Allow the detector to use this backbone just as a feature extractor
    # without modifying backbone's code.
    def fw_hook(module, inputs, output):
        # Record each module's input shape. If the module has a ``prev_mod``
        # whose input shape is known, reshape this module's output to that
        # shape. Tuple inputs/outputs are reduced to their first element.
        if isinstance(inputs, tuple):
            inputs = inputs[0]
        if isinstance(output, tuple):
            output = output[0]
        module.input_sz = inputs.shape
        if hasattr(module, 'prev_mod') and hasattr(module.prev_mod, 'input_sz'):
            output = output.view(module.prev_mod.input_sz)
        return output

    if is_torchvision:
        model.fc = nn.Identity()
        model.avgpool = nn.Identity()
        head, pool = model.fc, model.avgpool
    else:
        model.classifier = nn.Identity()
        model.global_pooling = nn.Identity()
        head, pool = model.classifier, model.global_pooling

    # Register hooks BEFORE linking head -> pool. Assigning an nn.Module
    # attribute registers it as a child, so linking first would make
    # ``model.apply`` visit (and hook) the pooling module twice.
    model.apply(lambda m: m.register_forward_hook(fw_hook))
    head.prev_mod = pool
    return model