in experiments/codes/utils/util.py [0:0]
def _weight_norm_denominator(v, g):
    """Return the L2 norm of ``v`` in a shape that broadcasts against ``g``.

    If ``g`` has the same rank as ``v`` and exactly one non-singleton
    dimension ``k`` whose size equals ``v.size(k)``, one norm is computed per
    slice of ``v`` along ``k`` and the result is reshaped to ``g``'s shape.
    In every other case (e.g. a 0-dim ``g``) the norm of the whole of ``v``
    is returned.
    """
    if g.dim() == v.dim():
        keep = [d for d in range(v.dim()) if g.size(d) != 1]
        if len(keep) == 1 and g.size(keep[0]) == v.size(keep[0]):
            k = keep[0]
            # Move dim k to the front, flatten the rest, take one norm per row.
            per_slice = v.transpose(0, k).reshape(v.size(k), -1).norm(dim=1)
            return per_slice.view(g.shape)
    return v.norm()


def get_param(params, param_name_dict, partial_name, ignore_classify=True, eps=0.00001):
    """Get a parameter tensor from the param list by partial name.

    Every key of ``param_name_dict`` that contains ``partial_name`` is a
    candidate. A single candidate is returned as-is. Exactly two candidates
    are accepted only if they form a weight-norm pair ``<head>_g`` /
    ``<head>_v``. The pair is combined as ``v * g / (||v|| + eps)``. The norm
    is taken per slice along the single non-singleton dimension of ``g`` when
    ``g`` has that layout (as in ``torch.nn.utils.weight_norm``), and over
    the whole of ``v`` otherwise.

    Arguments:
        params {list} -- list of params, indexed by the values of param_name_dict
        param_name_dict {dict} -- dict of name:location_id
        partial_name {str} -- substring to search for in the param names
        ignore_classify {bool} -- drop candidate names containing "classify"
            (default: True)
        eps {float} -- added to the norm of v to avoid division by zero
    Raises:
        AssertionError: ["arg not found"] if zero or more than two names match;
            also raised if two names match but are not a g/v pair sharing
            the same head
    Returns:
        [tensor] -- [parameter tensor, reconstructed for weight-norm pairs]
    """
    match = [name for name in param_name_dict if partial_name in name]
    if ignore_classify:
        match = [name for name in match if "classify" not in name]
    if len(match) == 1:
        return params[param_name_dict[match[0]]]
    if len(match) == 2:
        # Weight-norm names look like "{layername}.{layer}.{param_name}_{g/v}".
        # Split on the last "_" only, so underscores inside the head survive.
        split = {}
        for name in match:
            head, _, suffix = name.rpartition("_")
            split[suffix] = (head, name)
        # Explicit check instead of `assert` (stripped under -O). It also
        # rejects non-g/v pairs here, rather than failing later with a KeyError.
        if set(split) != {"g", "v"} or split["g"][0] != split["v"][0]:
            raise AssertionError(
                "ambiguous arg {!r}: {} is not a weight-norm g/v pair".format(
                    partial_name, match
                )
            )
        param_g = params[param_name_dict[split["g"][1]]]
        param_v = params[param_name_dict[split["v"][1]]]
        denom = _weight_norm_denominator(param_v, param_g) + eps
        return param_v * (param_g / denom).expand_as(param_v)
    raise AssertionError("arg not found")