def get_network()

in depth_upsampling/models/__init__.py [0:0]


def get_network(network, upsampling_factor):
    """Build a depth-upsampling network by name and report its size.

    Args:
        network: Architecture name; either ``'MSG'`` (builds ``MSGNet``) or
            ``'MSPF'`` (builds ``MSPF``).
        upsampling_factor: Passed unchanged to the selected network's
            constructor.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``network`` is not one of the supported names.

    Side effects:
        Prints the total parameter count and the number of parameters whose
        ``requires_grad`` flag is set.
    """
    # Create model
    if network == 'MSG':
        model = MSGNet(upsampling_factor)
    elif network == 'MSPF':
        model = MSPF(upsampling_factor)
        # Only the MSPF decoder is re-initialised here; other submodules are
        # left exactly as the constructor built them.
        model.decoder.apply(weights_init_xavier)
    else:
        raise ValueError(f'No such network ({network})')

    # numel() gives an exact int for every tensor, including 0-dim parameters.
    # np.prod(()) would instead yield the float 1.0 and turn the total into a
    # float.
    params = list(model.parameters())
    num_params = sum(p.numel() for p in params)
    print("Total number of parameters: {}".format(num_params))

    num_params_update = sum(p.numel() for p in params if p.requires_grad)
    print("Total number of learning parameters: {}".format(num_params_update))
    return model