maga_transformer/plugins/ret_hidden_states.py (5 lines of code):
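from typing import Any

import torch


class CustomPlugin(object):
    def modify_response_plugin(self, response: str, hidden_states: torch.Tensor, **kwargs: Any):
        # Discard the decoded text `response` (and any extra kwargs) and
        # return the hidden states tensor unchanged.
        return hidden_states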
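To illustrate the hook's contract without loading a model, the sketch below re-declares the plugin with a loose type hint so it runs without torch, and calls it the way a server might. The helper `apply_response_plugin` is hypothetical and is not part of maga_transformer; the real call site is not shown in this file. A nested list stands in for the `torch.Tensor` of hidden states.

```python
from typing import Any, List


class CustomPlugin(object):
    # Same behavior as the plugin above; hidden_states is typed loosely here
    # so this sketch runs without torch installed.
    def modify_response_plugin(self, response: str, hidden_states: Any, **kwargs: Any):
        return hidden_states


def apply_response_plugin(plugin: Any, response: str, hidden_states: Any, **kwargs: Any) -> Any:
    # Hypothetical call site (assumption): pass the decoded text and hidden
    # states to the plugin hook and use its return value as the final response.
    return plugin.modify_response_plugin(response, hidden_states, **kwargs)


# Fake hidden states: 2 tokens x 3 features (stand-in for a torch.Tensor).
fake_hidden: List[List[float]] = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
result = apply_response_plugin(CustomPlugin(), "generated text", fake_hidden, request_id=7)
print(result)  # [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
```

Because the hook ignores `response` and any extra keyword arguments, a caller that uses its return value receives the hidden states in place of the generated text.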