in jat/modeling_jat.py [0:0]
def __init__(self, hidden_size: int) -> None:
    """Build three conv/norm/attention stages followed by a linear projection.

    Each stage is a 3x3 convolution with stride 2 and padding 1, an
    ``InstanceNorm2d`` and an ``AttentionLayer``, the latter two constructed
    with the stage's output channel count. The stages are registered as
    ``conv1``/``norm1``/``att1`` through ``conv3``/``norm3``/``att3``.

    Args:
        hidden_size: Number of output features of the final linear layer ``fc``.
    """
    super().__init__()

    # Channel progression: 4 input channels, then 32 -> 64 -> 128.
    channels = (4, 32, 64, 128)
    conv_options = {"kernel_size": 3, "stride": 2, "padding": 1}

    # Every stage halves the spatial size (rounding up): 42x42, 21x21, 11x11.
    # NOTE(review): those sizes imply an 84x84 input -- confirm against caller.
    # Submodules are created in conv/norm/att order, stage by stage, so the
    # registration order stays conv1, norm1, att1, conv2, ...
    for stage, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]), start=1):
        setattr(self, f"conv{stage}", nn.Conv2d(in_channels, out_channels, **conv_options))
        setattr(self, f"norm{stage}", nn.InstanceNorm2d(out_channels))
        setattr(self, f"att{stage}", AttentionLayer(out_channels))

    # The projection takes the flattened 128-channel, 11x11 feature map.
    final_side = 11
    flattened_features = channels[-1] * final_side * final_side
    self.fc = nn.Linear(flattened_features, hidden_size)