def _get_proj_keys_from_state_dict()

in paq/retrievers/retriever_utils.py [0:0]


def _get_proj_keys_from_state_dict(state_dict):
    weight_key = [k for k in state_dict.keys() if 'encode_proj' in k and 'weight' in k]
    bias_key = [k for k in state_dict.keys() if 'encode_proj' in k and 'bias' in k]
    assert len(weight_key) == 1 == len(bias_key)
    weight_key, bias_key = weight_key[0], bias_key[0]
    return weight_key, bias_key