def load_state_dict()

in projects/light_whoami/agents/expanded_attention.py


    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        """
        Load the state dict into model.

        Override TA.load_state_dict to build the expanded attention
        """
        try:
            super().load_state_dict(state_dict)
        except RuntimeError:
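            # The strict load failed, most likely because the checkpoint was
            # trained before these modules existed; backfill the missing keys
            # from the freshly initialized model (or remap existing ones) and
            # retry the load at the end.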
            if not [k for k in state_dict if 'extra_input_attention' in k]:
                if self.opt['expanded_attention_init_weights'] == 'random':
                    logging.info('Loading Random Init for Expanded Attention')
                    state_dict.update(
                        {
                            **{
                                k: v
                                for k, v in self.model.state_dict().items()
                                if 'extra_input_attention' in k
                            },
                            **{
                                k: v
                                for k, v in self.model.state_dict().items()
                                if 'extra_input_norm' in k
                            },
                        }
                    )
                elif self.opt['expanded_attention_init_weights'] == 'encoder_attention':
                    logging.info('Loading Encoder Attention for Expanded Attention')
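                    # Reuse each decoder layer's encoder-attention (and its
                    # norm2) weights under the corresponding extra_input_* keys.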
                    state_dict.update(
                        {
                            **{
                                k.replace(
                                    'encoder_attention', 'extra_input_attention'
                                ): v
                                for k, v in state_dict.items()
                                if 'decoder' in k and 'encoder_attention' in k
                            },
                            **{
                                k.replace('norm2', 'extra_input_norm'): v
                                for k, v in state_dict.items()
                                if 'decoder' in k and 'norm2' in k
                            },
                        }
                    )
            if not [k for k in state_dict if 'classifier_model' in k]:
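                # The checkpoint has no classifier weights; take them from the
                # current model.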
                logging.info('Adding Classifier Model Weights')
                state_dict.update(
                    {
                        k: v
                        for k, v in self.model.state_dict().items()
                        if 'classifier_model' in k
                    }
                )
            if not [k for k in state_dict if 'mask_linear' in k]:
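                # Likewise for the trainable mask layer.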
                logging.info('Adding trainable mask Weights')
                state_dict.update(
                    {
                        k: v
                        for k, v in self.model.state_dict().items()
                        if 'mask_linear' in k
                    }
                )
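            # Retry now that all expected keys are present.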
            super().load_state_dict(state_dict)
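

The fallback pattern above (catch the strict-load RuntimeError, backfill or remap the missing checkpoint keys, then retry) is independent of ParlAI. Below is a minimal, self-contained sketch of the same idea; OldModel, NewModel, attn_proj, extra_proj, and load_with_backfill are hypothetical names used only for illustration, and the remapping mirrors the 'encoder_attention' branch above.

import torch
import torch.nn as nn


class OldModel(nn.Module):
    # Stands in for a checkpoint trained before the extra module existed.
    def __init__(self):
        super().__init__()
        self.attn_proj = nn.Linear(8, 8)


class NewModel(nn.Module):
    # Same architecture plus a new extra_proj module (hypothetical analogue
    # of extra_input_attention).
    def __init__(self):
        super().__init__()
        self.attn_proj = nn.Linear(8, 8)
        self.extra_proj = nn.Linear(8, 8)

    def load_with_backfill(self, state_dict):
        try:
            # Strict load fails because the checkpoint lacks extra_proj.* keys.
            self.load_state_dict(state_dict)
        except RuntimeError:
            # Remap the analogous old weights onto the new module's keys and
            # retry, mirroring the 'encoder_attention' initialization above.
            state_dict.update(
                {
                    k.replace('attn_proj', 'extra_proj'): v
                    for k, v in state_dict.items()
                    if 'attn_proj' in k
                }
            )
            self.load_state_dict(state_dict)


if __name__ == '__main__':
    checkpoint = OldModel().state_dict()
    model = NewModel()
    model.load_with_backfill(checkpoint)
    # The new module now starts from the old attention weights.
    assert torch.equal(model.extra_proj.weight, model.attn_proj.weight)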