in timesformer/utils/checkpoint.py [0:0]
def sub_to_normal_bn(sd):
    """
    Convert the Sub-BN parameters to normal BN parameters in a state dict.

    There are two copies of BN layers in a Sub-BN implementation: `bn.bn` and
    `bn.split_bn`. `bn.split_bn` is used during training and
    "compute_precise_bn". Before saving or evaluation, its stats are copied to
    `bn.bn`. We rename `bn.bn` to `bn` and store it to be consistent with normal
    BN layers.

    The input dict is left untouched; all work happens on a deep copy.

    Args:
        sd (OrderedDict): a dict of parameters which might contain Sub-BN
            parameters.
    Returns:
        new_sd (OrderedDict): a dict with Sub-BN parameters reshaped to
            normal parameters.
    Raises:
        AssertionError: if a 4-D `bn.weight` / `bn.bias` entry has a trailing
            dimension that is not 1 and therefore cannot be flattened to 1-D.
    """
    new_sd = copy.deepcopy(sd)
    # (suffix to look for, suffix to store it under instead)
    modifications = [
        ("bn.bn.running_mean", "bn.running_mean"),
        ("bn.bn.running_var", "bn.running_var"),
        ("bn.split_bn.num_batches_tracked", "bn.num_batches_tracked"),
    ]
    # Any key still containing one of these after renaming is dropped.
    to_remove = ["bn.bn.", ".split_bn."]

    # Iterate over the original dict so that `new_sd` can be mutated freely.
    for key in sd:
        for before, after in modifications:
            if key.endswith(before):
                # Strip exactly the matched suffix. Splitting on the pattern
                # would cut at its first occurrence and lose part of the
                # prefix if the pattern also appears earlier in the key.
                new_key = key[: -len(before)] + after
                new_sd[new_key] = new_sd.pop(key)

        # Renamed keys were already popped above, hence the membership check.
        if key in new_sd and any(rm in key for rm in to_remove):
            del new_sd[key]

    # Only values of existing keys are replaced below, so iterating over
    # `new_sd` while assigning into it is safe.
    for key in new_sd:
        if key.endswith(("bn.weight", "bn.bias")):
            shape = new_sd[key].size()
            if len(shape) == 4:
                # Raise explicitly instead of using `assert`, which is
                # stripped under `python -O` and would let the indexing below
                # silently discard data.
                if not all(d == 1 for d in shape[1:]):
                    raise AssertionError(
                        "Cannot flatten {} with shape {}: trailing "
                        "dimensions must all be 1.".format(key, tuple(shape))
                    )
                new_sd[key] = new_sd[key][:, 0, 0, 0]

    return new_sd