in depth_upsampling/models/__init__.py [0:0]
def get_network(network, upsampling_factor):
    """Build the requested model and print its parameter counts.

    Args:
        network: Name of the architecture to build, either ``'MSG'`` or
            ``'MSPF'``.
        upsampling_factor: Forwarded unchanged to the model constructor.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``network`` is not one of the supported names.
    """
    # Create model
    if network == 'MSG':
        model = MSGNet(upsampling_factor)
    elif network == 'MSPF':
        model = MSPF(upsampling_factor)
        # Only the decoder sub-module is passed through weights_init_xavier;
        # the rest of the model keeps whatever its constructor set up.
        model.decoder.apply(weights_init_xavier)
    else:
        raise ValueError(f'No such network ({network})')

    # numel() returns an exact Python int per parameter. The previous
    # np.prod(shape) form returned float 1.0 for 0-dim parameters (empty
    # shape), which turned the printed total into a float.
    num_params = sum(p.numel() for p in model.parameters())
    print("Total number of parameters: {}".format(num_params))

    # Same count restricted to parameters that have requires_grad set.
    num_params_update = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    print("Total number of learning parameters: {}".format(num_params_update))
    return model