in jat/modeling_jat.py [0:0]
def output_rl(
self,
transformer_outputs,
continuous_observations: Optional[FloatTensor] = None,
discrete_observations: Optional[LongTensor] = None,
image_observations: Optional[FloatTensor] = None,
continuous_actions: Optional[FloatTensor] = None,
discrete_actions: Optional[LongTensor] = None,
rewards: Optional[FloatTensor] = None,
attention_mask: Optional[BoolTensor] = None,
return_loss: bool = True,
return_dict: Optional[bool] = None,
loss_weight: Optional[FloatTensor] = None,