def parse_network_from_config()

in sample_info/modules/nn_utils.py [0:0]


# Checkpoint loaded for the 'emnist-letters-pretrained-cnn' config.
_EMNIST_LETTERS_CNN_CHECKPOINT = 'sample_info/modules/resources/emnist_letters_cnn_pretrained.mdl'
# Class count the EMNIST-letters branch treats as the checkpoint's native head size.
_EMNIST_LETTERS_NUM_CLASSES = 26
# Default/native head size of the leaky ResNets (used when `pretrained` requires a 1000-way head).
_IMAGENET_NUM_CLASSES = 1000
# Config names handled by `_parse_leaky_resnet`.
_LEAKY_RESNET_NAMES = ('leaky-resnet18', 'leaky-resnet34', 'leaky-resnet50')


def _parse_emnist_letters_pretrained_cnn(args):
    """Load the pretrained EMNIST-letters CNN, re-heading it if another class count is requested.

    :param args: dict config; the optional key 'num_classes' (default 26) sets the output size.
    :return: tuple ``(net, output_shape)`` where ``output_shape == (None, num_classes)``.
    """
    net = utils.load(path=_EMNIST_LETTERS_CNN_CHECKPOINT, methods=methods,
                     device='cpu').classifier.sequential_model

    num_classes = args.get('num_classes', _EMNIST_LETTERS_NUM_CLASSES)
    if num_classes != _EMNIST_LETTERS_NUM_CLASSES:
        # Replace the final layer with a freshly initialized linear layer of the requested size.
        # Assumes the last child exposes `in_features` (i.e. is a torch.nn.Linear) — TODO confirm.
        layers = list(net.children())
        layers[-1] = torch.nn.Linear(in_features=layers[-1].in_features, out_features=num_classes)
        net = torch.nn.Sequential(*layers)

    output_shape = (None, num_classes)
    print("output.shape:", output_shape)
    return net, output_shape


def _parse_leaky_resnet(args, input_shape):
    """Build a leaky-ReLU ResNet-18/34/50 from a dict config.

    Recognized keys: 'net' (one of `_LEAKY_RESNET_NAMES`), 'norm_layer' ('GroupNorm', 'none',
    anything else -> BatchNorm2d), 'num_classes' (default 1000) and 'pretrained' (default False).

    :param args: dict config.
    :param input_shape: input shape used to infer the network's output shape.
    :return: tuple ``(net, output_shape)``.
    """
    from sample_info.modules.resnet_leaky_relu import resnet18, resnet34, resnet50

    resnet_fn = {
        'leaky-resnet18': resnet18,
        'leaky-resnet34': resnet34,
        'leaky-resnet50': resnet50,
    }[args['net']]

    norm_layer_name = args.get('norm_layer', '')
    if norm_layer_name == 'GroupNorm':
        norm_layer = nn_utils_base.group_norm_partial_apply_fn(num_groups=32)
    elif norm_layer_name == 'none':
        # Factory that ignores the channel count and returns a no-op module.
        norm_layer = (lambda num_channels: nn_utils_base.Identity())
    else:
        # Missing key or any other value falls back to batch normalization.
        norm_layer = torch.nn.BatchNorm2d

    num_classes = args.get('num_classes', _IMAGENET_NUM_CLASSES)
    pretrained = args.get('pretrained', False)

    if pretrained and num_classes != _IMAGENET_NUM_CLASSES:
        # Build with the 1000-way head (presumably so the pretrained weights match), then swap
        # in a freshly initialized head of the requested size.
        net = resnet_fn(norm_layer=norm_layer, num_classes=_IMAGENET_NUM_CLASSES, pretrained=pretrained)
        net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
    else:
        net = resnet_fn(norm_layer=norm_layer, num_classes=num_classes, pretrained=pretrained)

    output_shape = nn_utils_base.infer_shape([net], input_shape)
    print("output.shape:", output_shape)
    return net, output_shape


def parse_network_from_config(args, input_shape, detailed_output=False):
    """Parse a network from a JSON-style config.

    Project-specific architectures are handled here; everything else is delegated to
    ``nn_utils_base.parse_network_from_config``:

    * ``{'net': 'emnist-letters-pretrained-cnn', ...}`` loads a pretrained EMNIST-letters CNN.
    * ``{'net': 'leaky-resnet18' | 'leaky-resnet34' | 'leaky-resnet50', ...}`` builds a
      leaky-ReLU ResNet.

    :param args: config; a dict whose 'net' key selects a project architecture, or any config
        the base parser accepts (dicts without a 'net' key are also forwarded to it).
    :param input_shape: input shape used for output-shape inference (unused by the EMNIST branch).
    :param detailed_output: if True, wrap the network in ``SimpleDetailedOutputWrapper``.
    :return: tuple ``(net, output_shape)``.
    """
    net_name = args.get('net') if isinstance(args, dict) else None

    if net_name == 'emnist-letters-pretrained-cnn':
        net, output_shape = _parse_emnist_letters_pretrained_cnn(args)
    elif net_name in _LEAKY_RESNET_NAMES:
        net, output_shape = _parse_leaky_resnet(args, input_shape)
    else:
        # parse general-case networks
        net, output_shape = nn_utils_base.parse_network_from_config(args, input_shape)

    if detailed_output:
        net = SimpleDetailedOutputWrapper(net)

    return net, output_shape