in jat/modeling_jat.py [0:0]
def __init__(self, config: JatConfig) -> None:
    """Assemble the transformer backbone and the per-modality encoders/decoders.

    Args:
        config: Model configuration. This constructor reads ``vocab_size``,
            ``hidden_size``, ``max_discrete_value``, ``max_continuous_size``,
            ``observation_loss_coef`` and ``action_loss_coef`` from it, and
            forwards the whole object to ``GPTNeoModel`` and
            ``ViTPatchEmbeddings``.
    """
    super().__init__(config)

    # Sizes pulled from the config once, up front.
    n_vocab = config.vocab_size
    width = config.hidden_size
    n_discrete = config.max_discrete_value
    n_continuous = config.max_continuous_size

    # Narrow per-slot feature size used inside both multi-discrete stacks,
    # and the size obtained once the per-slot features are flattened together.
    bottleneck = width // 50
    flat_bottleneck = n_discrete * bottleneck

    # Loss weights are only stored here; nothing in this constructor uses them.
    self.observation_loss_coef = config.observation_loss_coef
    self.action_loss_coef = config.action_loss_coef

    # ---- Transformer backbone ----
    self.transformer = GPTNeoModel(config)

    # ---- Encoders ----
    self.vit_encoder = ViTPatchEmbeddings(config)
    # Not a new layer: this is the very same module object as the
    # transformer's `wte`, so its parameters are shared.
    self.single_discrete_encoder = self.transformer.wte
    self.continuous_encoder = nn.Linear(n_continuous, width)
    # Shape notes below follow the author's annotations, with X = max_discrete_value
    # and H = hidden_size.
    multi_discrete_encoder_layers = [
        self.single_discrete_encoder,  # (B, L, X, H)
        nn.Linear(width, bottleneck),  # (B, L, X, H // 50)
        nn.ReLU(),
        nn.Flatten(start_dim=2),  # (B, L, X * (H // 50))
        # Output is H - 1 wide: one feature is held back to account for the reward.
        nn.Linear(flat_bottleneck, width - 1),  # (B, L, H - 1)
    ]
    self.multi_discrete_encoder = nn.Sequential(*multi_discrete_encoder_layers)
    self.image_encoder = DualBatchReshapeWrapper(ImageEncoder(width))

    # ---- Decoders ----
    self.single_discrete_decoder = nn.Linear(width, n_vocab, bias=False)
    self.continuous_decoder = nn.Linear(width, n_continuous)
    multi_discrete_decoder_layers = [
        nn.Linear(width, flat_bottleneck),  # (B, L, X * (H // 50))
        nn.Unflatten(dim=2, unflattened_size=(n_discrete, bottleneck)),  # (B, L, X, H // 50)
        nn.ReLU(),
        nn.Linear(bottleneck, width),  # (B, L, X, H)
        nn.ReLU(),
        # 8 output logits per slot: the max possible value in the dataset is 8.
        nn.Linear(width, 8, bias=False),  # (B, L, X, 8)
    ]
    self.multi_discrete_decoder = nn.Sequential(*multi_discrete_decoder_layers)
    self.image_decoder = DualBatchReshapeWrapper(ImageDecoder(width))

    # Initialize weights and apply final processing
    self.post_init()