in flsim/utils/example_utils.py [0:0]
def __init__(self, in_channels, num_classes, dropout_rate=0):
    """Create the sub-modules of the network.

    Registers four 3x3 ``nn.Conv2d`` layers (``self.layers``), one
    ``GroupNorm -> ReLU -> MaxPool2d`` sequence (``self.gn_relu``), a
    dropout layer (``self.dropout``) and a linear layer (``self.fc``).

    Args:
        in_channels: Input channel count of the first convolution.
        num_classes: Output size of the final linear layer.
        dropout_rate: Probability handed to ``nn.Dropout``. Defaults to 0.
    """
    super(SimpleConvNet, self).__init__()

    # Hyper-parameters used by every convolution below.
    self.out_channels = 32
    self.stride = 1
    self.padding = 2

    # Only the first convolution consumes ``in_channels``; each later one
    # takes the ``out_channels`` produced by its predecessor.
    num_convs = 4
    conv_inputs = [in_channels] + [self.out_channels] * (num_convs - 1)
    self.layers = nn.ModuleList(
        [
            nn.Conv2d(
                in_channels=conv_in,
                out_channels=self.out_channels,
                kernel_size=3,
                stride=self.stride,
                padding=self.padding,
            )
            for conv_in in conv_inputs
        ]
    )

    # num_groups == num_channels, so every channel forms its own group.
    normalise = nn.GroupNorm(
        num_groups=self.out_channels,
        num_channels=self.out_channels,
        affine=True,
    )
    activate = nn.ReLU()
    downsample = nn.MaxPool2d(kernel_size=2, stride=2)
    self.gn_relu = nn.Sequential(normalise, activate, downsample)

    self.dropout = nn.Dropout(dropout_rate)

    # NOTE(review): the classifier input width is
    # out_channels * (stride + padding) ** 2 (= 288 with the values above).
    # Presumably this matches the flattened feature map for one specific
    # input resolution -- confirm against forward() before changing
    # stride/padding or the input size.
    side = self.stride + self.padding
    num_features = self.out_channels * side * side
    self.fc = nn.Linear(num_features, num_classes)