in jat/modeling_jat.py [0:0]
def forward(
self,
input_ids: Optional[LongTensor] = None,
pixel_values: Optional[FloatTensor] = None,
continuous_observations: Optional[FloatTensor] = None,
discrete_observations: Optional[LongTensor] = None,
image_observations: Optional[FloatTensor] = None,
continuous_actions: Optional[FloatTensor] = None,
discrete_actions: Optional[LongTensor] = None,
rewards: Optional[FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None,
attention_mask: Optional[BoolTensor] = None,
token_type_ids: Optional[LongTensor] = None,
position_ids: Optional[LongTensor] = None,
return_loss: bool = True,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
loss_weight: Optional[FloatTensor] = None,