maga_transformer/plugins/ret_hidden_states.py (5 lines of code):

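from typing import Any

import torch


class CustomPlugin(object):
    # Response hook: ignore the generated text (`response`) and return the
    # model's hidden states unchanged.
    def modify_response_plugin(self, response: str, hidden_states: torch.Tensor, **kwargs: Any):
        return hidden_states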
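The file only defines the hook, not the code that calls it. The sketch below is one plausible way a server could invoke such a plugin, assuming that whatever `modify_response_plugin` returns is used in place of the response text. The `apply_response_plugin` helper is hypothetical and not part of maga_transformer's documented API. A plain nested list stands in for the `torch.Tensor` so that the example runs without torch installed.

```python
from typing import Any, Callable, Optional


class CustomPlugin(object):
    # Same hook as in ret_hidden_states.py. hidden_states is typed as Any here
    # only so the sketch runs without torch installed.
    def modify_response_plugin(self, response: str, hidden_states: Any, **kwargs: Any) -> Any:
        return hidden_states


def apply_response_plugin(plugin: Optional[object], response: str,
                          hidden_states: Any, **kwargs: Any) -> Any:
    """Hypothetical caller (assumption): if the plugin defines
    modify_response_plugin, its return value replaces the response;
    otherwise the response text passes through unchanged."""
    hook: Optional[Callable[..., Any]] = getattr(plugin, "modify_response_plugin", None)
    if hook is None:
        return response
    return hook(response, hidden_states, **kwargs)


fake_hidden = [[0.1, 0.2], [0.3, 0.4]]  # stand-in for a torch.Tensor
print(apply_response_plugin(CustomPlugin(), "hello", fake_hidden))  # [[0.1, 0.2], [0.3, 0.4]]
print(apply_response_plugin(None, "hello", fake_hidden))            # hello
```

With this plugin loaded, a caller of this shape would receive the hidden states instead of the generated text. This matches what the file name `ret_hidden_states` suggests.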