def __init__()

in jat/modeling_jat.py [0:0]


    def __init__(self, config: JatConfig) -> None:
        """Build the GPT-Neo backbone together with its per-modality encoders and decoders.

        Args:
            config: model configuration. This constructor reads ``vocab_size``, ``hidden_size``,
                ``max_discrete_value``, ``max_continuous_size``, ``observation_loss_coef`` and
                ``action_loss_coef`` from it, and hands it whole to ``GPTNeoModel`` and
                ``ViTPatchEmbeddings``.
        """
        super().__init__(config)

        # Sizes taken from the config.
        n_tokens = config.vocab_size
        width = config.hidden_size
        n_discrete = config.max_discrete_value
        continuous_dim = config.max_continuous_size
        # Per-element bottleneck width shared by the multi-discrete encoder and decoder.
        narrow = width // 50

        # Loss coefficients, kept on the module (presumably consumed by forward - not visible here).
        self.observation_loss_coef = config.observation_loss_coef
        self.action_loss_coef = config.action_loss_coef

        # Backbone.
        self.transformer = GPTNeoModel(config)

        # ---- Encoders: inputs -> hidden-size features ----
        self.vit_encoder = ViTPatchEmbeddings(config)
        # Alias of the backbone's token embedding: the same module object, hence shared weights.
        self.single_discrete_encoder = self.transformer.wte
        self.continuous_encoder = nn.Linear(continuous_dim, width)
        # The Flatten output X * (H // 50) must match the Linear input n_discrete * (H // 50),
        # so X is expected to equal max_discrete_value (presumably inputs are padded to it).
        self.multi_discrete_encoder = nn.Sequential(
            self.single_discrete_encoder,  # -> (B, L, X, H)
            nn.Linear(width, narrow),  # -> (B, L, X, H // 50)
            nn.ReLU(),
            nn.Flatten(start_dim=2),  # -> (B, L, X * (H // 50))
            nn.Linear(n_discrete * narrow, width - 1),  # -> (B, L, H - 1); the -1 leaves room for the reward
        )
        self.image_encoder = DualBatchReshapeWrapper(ImageEncoder(width))

        # ---- Decoders: hidden states -> per-modality outputs ----
        self.single_discrete_decoder = nn.Linear(width, n_tokens, bias=False)
        self.continuous_decoder = nn.Linear(width, continuous_dim)
        self.multi_discrete_decoder = nn.Sequential(
            nn.Linear(width, n_discrete * narrow),  # -> (B, L, X * (H // 50))
            nn.Unflatten(dim=2, unflattened_size=(n_discrete, narrow)),  # -> (B, L, X, H // 50)
            nn.ReLU(),
            nn.Linear(narrow, width),  # -> (B, L, X, H)
            nn.ReLU(),
            # -> (B, L, X, 8): fixed at 8 because the max possible value in the dataset is 8
            # (presumably one score per candidate value).
            nn.Linear(width, 8, bias=False),
        )
        self.image_decoder = DualBatchReshapeWrapper(ImageDecoder(width))

        # Initialize weights and apply final processing.
        self.post_init()