in opacus/layers/dp_multihead_attention.py [0:0]
def load_state_dict(self, state_dict):
    r"""
    Loads the module from a previously saved state.

    Supports loading from both :class:`torch.nn.MultiheadAttention` and
    :class:`opacus.layers.dp_multihead_attention.DPMultiheadAttention`.

    Args:
        state_dict: Please refer to
            https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html.
    """
    # nn.MultiheadAttention stores the q/k/v projections as one stacked tensor;
    # split it into the three per-projection linear layers used by this module.
    if "in_proj_weight" in state_dict:
        qweight, kweight, vweight = state_dict["in_proj_weight"].chunk(3, dim=0)

        state_dict["qlinear.weight"] = qweight
        state_dict["klinear.weight"] = kweight
        state_dict["vlinear.weight"] = vweight
        del state_dict["in_proj_weight"]

    if "in_proj_bias" in state_dict:
        qbias, kbias, vbias = state_dict["in_proj_bias"].chunk(3, dim=0)

        state_dict["qlinear.bias"] = qbias
        state_dict["klinear.bias"] = kbias
        state_dict["vlinear.bias"] = vbias
        del state_dict["in_proj_bias"]

    # bias_k / bias_v (present when add_bias_kv=True) are stored with shape
    # (1, 1, embed_dim); squeeze them to the 1-D bias expected here.
    if "bias_k" in state_dict:
        state_dict["seq_bias_k.bias"] = state_dict["bias_k"].squeeze()
        del state_dict["bias_k"]

    if "bias_v" in state_dict:
        state_dict["seq_bias_v.bias"] = state_dict["bias_v"].squeeze()
        del state_dict["bias_v"]

    # When q/k/v have different embedding dims, nn.MultiheadAttention keeps
    # separate projection weights; rename them to the per-projection layers.
    if "q_proj_weight" in state_dict:
        state_dict["qlinear.weight"] = state_dict["q_proj_weight"]
        del state_dict["q_proj_weight"]

    if "k_proj_weight" in state_dict:
        state_dict["klinear.weight"] = state_dict["k_proj_weight"]
        del state_dict["k_proj_weight"]

    if "v_proj_weight" in state_dict:
        state_dict["vlinear.weight"] = state_dict["v_proj_weight"]
        del state_dict["v_proj_weight"]

    # Remaining keys (e.g. out_proj.*) already use this module's names.
    super(DPMultiheadAttention, self).load_state_dict(state_dict)
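
A minimal usage sketch (not part of the module): because of the key remapping above, a state dict produced by a vanilla torch.nn.MultiheadAttention can be loaded directly into DPMultiheadAttention. The embed_dim/num_heads values below are illustrative, and the sketch assumes DPMultiheadAttention is importable from opacus.layers.

import torch.nn as nn

from opacus.layers import DPMultiheadAttention

embed_dim, num_heads = 16, 4  # illustrative sizes, not taken from the module

# A standard attention layer, e.g. from a pre-trained, non-private model.
orig_attn = nn.MultiheadAttention(embed_dim, num_heads, add_bias_kv=True)

# The DP-compatible replacement, built with matching hyperparameters.
dp_attn = DPMultiheadAttention(embed_dim, num_heads, add_bias_kv=True)

# load_state_dict above splits in_proj_weight / in_proj_bias and renames
# bias_k / bias_v so the vanilla checkpoint fits the per-projection layers.
dp_attn.load_state_dict(orig_attn.state_dict())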