# get_param()
#
# from experiments/codes/utils/util.py [0:0]


def get_param(params, param_name_dict, partial_name, ignore_classify=True, eps=0.00001):
    """Get a parameter tensor from the param list by partial name match.

    If exactly one name matches, that tensor is returned directly.  If
    exactly two names match, they are expected to be a weight-norm pair
    ("{name}_g", "{name}_v") and the de-normalized parameter
    v * g / (||v|| + eps) is reconstructed and returned.

    Arguments:
        params {list} -- list of parameter tensors
        param_name_dict {dict} -- mapping of parameter name -> index into params
        partial_name {str} -- substring to search for in the parameter names
        ignore_classify {bool} -- if True, drop matches whose name contains "classify"
        eps {float} -- small constant added to ||v|| to avoid division by zero

    Raises:
        AssertionError: if no match, an ambiguous match, or a malformed
            weight-norm pair is found

    Returns:
        [tensor] -- the matched (possibly de-normalized) parameter tensor
    """
    match = [name for name in param_name_dict if partial_name in name]
    if ignore_classify:
        match = [name for name in match if "classify" not in name]

    if len(match) == 1:
        return params[param_name_dict[match[0]]]

    if len(match) == 2:
        # Two hits should be a weight-norm pair; names look like
        # "{layername}.{layer}.{param_name}_{g/v}".  Verify both share the
        # same stem before combining them.
        # NOTE: explicit raise (not a bare `assert`) so the check survives -O.
        p1_head = "_".join(match[0].split("_")[:-1])
        p2_head = "_".join(match[1].split("_")[:-1])
        if p1_head != p2_head:
            raise AssertionError(
                "two matches found but they do not share a stem: %s" % match
            )
        name_d = {name.split("_")[-1]: name for name in match}
        if "g" not in name_d or "v" not in name_d:
            # Previously this fell through to a confusing KeyError.
            raise AssertionError(
                "two matches found but missing _g/_v suffixes: %s" % match
            )
        g = params[param_name_dict[name_d["g"]]]
        v = params[param_name_dict[name_d["v"]]]
        # Reconstruct the de-normalized weight: w = v * g / (||v|| + eps).
        return v * (g / (torch.norm(v) + eps)).expand_as(v)

    # Zero or three-plus matches: the search is empty or ambiguous.
    raise AssertionError("arg not found")