def __init__()

in jat/modeling_jat.py [0:0]


    def __init__(self, hidden_size: int) -> None:
        """Declare three conv/norm/attention stages and the output projection.

        Channel widths go 4 -> 32 -> 64 -> 128, each convolution using a 3x3
        kernel with stride 2 and padding 1. The final linear layer takes a
        flattened 128 x 11 x 11 feature map.

        Args:
            hidden_size: Number of output features of the final linear layer.
        """
        super().__init__()

        # With kernel 3, stride 2, padding 1 a conv maps a side of n to ceil(n / 2).
        downsample = {"kernel_size": 3, "stride": 2, "padding": 1}
        in_channels, width1, width2, width3 = 4, 32, 64, 128
        # Side length after the three halvings; the size comments below imply
        # an 84x84 input (84 -> 42 -> 21 -> 11) -- TODO confirm against caller.
        final_side = 11

        # Submodules are created in the same order as before so that parameter
        # registration order and initialisation RNG draws stay unchanged.
        self.conv1 = nn.Conv2d(in_channels, width1, **downsample)  # 42x42
        self.norm1 = nn.InstanceNorm2d(width1)
        self.att1 = AttentionLayer(width1)

        self.conv2 = nn.Conv2d(width1, width2, **downsample)  # 21x21
        self.norm2 = nn.InstanceNorm2d(width2)
        self.att2 = AttentionLayer(width2)

        self.conv3 = nn.Conv2d(width2, width3, **downsample)  # 11x11
        self.norm3 = nn.InstanceNorm2d(width3)
        self.att3 = AttentionLayer(width3)

        # Input width must match the flattened output of the last stage.
        flat_features = width3 * final_side * final_side
        self.fc = nn.Linear(flat_features, hidden_size)