in utils.py [0:0]
def make_fresh_copy_of_pruned_network(model: nn.Module, model_kwargs: Dict):
    """Build a StandardConv copy of a wrapped (pruned) network.

    A new instance of ``type(model.module)`` is constructed with a
    ``Builder`` that uses ``"StandardConv"`` layers and the same norm type
    as ``model``. The copy is moved to GPU if ``model``'s first parameter
    lives on GPU. It is then wrapped in ``nn.DataParallel`` and loaded from
    ``model``'s state dict, minus every key ending in ``"1"``. Finally, for
    every module of ``model`` exposing ``get_weight()``, the returned
    weight (and bias, if present) overwrite the matching module's tensors
    in the copy.

    Args:
        model: Source network. Must expose ``.module`` (presumably an
            ``nn.DataParallel`` wrapper — TODO confirm); its type is
            instantiated to build the copy.
        model_kwargs: Extra keyword arguments forwarded to the model
            constructor alongside ``builder`` and ``block_builder``.

    Returns:
        The new network wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: If a module's ``get_weight()`` yields neither one nor
            two pieces.
    """
    norm_type_string = get_norm_type_string(model)
    builder = Builder(conv_type="StandardConv", bn_type=norm_type_string)
    fresh = type(model.module)(
        builder=builder, block_builder=builder, **model_kwargs
    )  # type: nn.Module

    # Need to move the copy to GPU before wrapping it in DataParallel.
    first_param = next(model.parameters(), None)
    if first_param is not None and first_param.is_cuda:
        fresh = fresh.cuda()
    fresh = nn.DataParallel(fresh)

    # Delete keys in place instead of rebuilding the mapping. state_dict()
    # returns an OrderedDict carrying a ``_metadata`` attribute (per-module
    # versions) that load_state_dict() reads; a fresh dict would drop it.
    # NOTE(review): keys ending in "1" are presumably pruning-specific
    # entries that the StandardConv copy does not define (the default
    # strict load would reject them) — confirm against the conv types.
    state_dict = model.state_dict()
    for key in [k for k in state_dict if k.endswith("1")]:
        del state_dict[key]
    fresh.load_state_dict(state_dict)

    # The only part we should need to fix are modules with a get_weight()
    # function: their effective weight (and optional bias) comes from that
    # call rather than from the raw state dict.
    name_to_copy = dict(fresh.named_modules())
    for name, module in model.named_modules():
        if not hasattr(module, "get_weight"):
            continue
        print(f"Adjusting weight at module {name}")
        pieces = module.get_weight()
        # Normalize to a sequence so a lone tensor and a 1-element
        # tuple/list are handled alike; the weight is always pieces[0].
        if not isinstance(pieces, (tuple, list)):
            pieces = (pieces,)
        if len(pieces) not in (1, 2):
            raise ValueError(f"Invalid len(pieces)={len(pieces)}")
        target = name_to_copy[name]
        target.weight.data = pieces[0]
        if len(pieces) == 2:
            target.bias.data = pieces[1]
    return fresh