in projects/light_whoami/agents/expanded_attention.py [0:0]
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
    """
    Load the state dict into the model.

    Override TA.load_state_dict to build the expanded attention.
    """
    try:
        super().load_state_dict(state_dict)
    except RuntimeError:
        if not any('extra_input_attention' in k for k in state_dict):
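            # Patch in the missing expanded-attention weights, using the init
            # strategy selected by self.opt['expanded_attention_init_weights'].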
            if self.opt['expanded_attention_init_weights'] == 'random':
                logging.info('Loading Random Init for Expanded Attention')
                state_dict.update(
                    {
                        k: v
                        for k, v in self.model.state_dict().items()
                        if 'extra_input_attention' in k
                        or 'extra_input_norm' in k
                    }
                )
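            # Warm-start the new attention from the decoder's pretrained
            # encoder attention (and its norm2 layer norms).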
            elif (
                self.opt['expanded_attention_init_weights']
                == 'encoder_attention'
            ):
                logging.info('Loading Encoder Attention for Expanded Attention')
                state_dict.update(
                    {
                        **{
                            k.replace(
                                'encoder_attention', 'extra_input_attention'
                            ): v
                            for k, v in state_dict.items()
                            if 'decoder' in k and 'encoder_attention' in k
                        },
                        **{
                            k.replace('norm2', 'extra_input_norm'): v
                            for k, v in state_dict.items()
                            if 'decoder' in k and 'norm2' in k
                        },
                    }
                )
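        # Checkpoints trained without the classifier lack these keys; copy the
        # freshly initialized classifier_model weights from the current model.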
        if not any('classifier_model' in k for k in state_dict):
            logging.info('Adding Classifier Model Weights')
            state_dict.update(
                {
                    k: v
                    for k, v in self.model.state_dict().items()
                    if 'classifier_model' in k
                }
            )
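        # Likewise for the trainable attention-mask projection (mask_linear).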
        if not any('mask_linear' in k for k in state_dict):
            logging.info('Adding Trainable Mask Weights')
            state_dict.update(
                {
                    k: v
                    for k, v in self.model.state_dict().items()
                    if 'mask_linear' in k
                }
            )
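        # Retry the load now that all expanded-attention parameters exist.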
        super().load_state_dict(state_dict)