in models/attentive_nas_dynamic_model.py [0:0]
def forward(self, x):
    """Run the currently active sub-network on a batch of images.

    The input is resized to ``self.active_resolution`` (square), then passed
    through ``first_conv``, the first ``self.runtime_depth[i]`` blocks of each
    stage listed in ``self.block_group_info``, ``last_conv``, global average
    pooling, optional dropout (training only) and finally ``classifier``.

    Args:
        x: Image batch. Bicubic interpolation and the pooling below require a
            4-D ``(N, C, H, W)`` tensor.

    Returns:
        Output of ``self.classifier`` applied to the pooled ``(N, C)``
        features. The batch dimension is preserved even when N == 1.
    """
    # Resize to the active (square) resolution. Both spatial dims are checked
    # so a non-square input whose width happens to match is still resized.
    target = self.active_resolution
    if x.size(-2) != target or x.size(-1) != target:
        # align_corners=False is PyTorch's default for bicubic; passing it
        # explicitly keeps behavior identical and avoids the
        # "default behavior changed" warning on older PyTorch versions.
        x = torch.nn.functional.interpolate(
            x, size=target, mode='bicubic', align_corners=False)

    x = self.first_conv(x)

    # Each stage lists the indices of its blocks; only the first `depth` of
    # them are active in the current sub-network configuration.
    for stage_id, block_idx in enumerate(self.block_group_info):
        depth = self.runtime_depth[stage_id]
        for idx in block_idx[:depth]:
            x = self.blocks[idx](x)

    x = self.last_conv(x)
    # Global average pooling over the two trailing (spatial) dims -> (N, C, 1, 1).
    x = x.mean(3, keepdim=True).mean(2, keepdim=True)
    # Flatten to (N, C). A bare torch.squeeze(x) would also drop the batch
    # dimension when N == 1, handing the classifier a 1-D tensor.
    x = torch.flatten(x, 1)

    if self.active_dropout_rate > 0 and self.training:
        x = torch.nn.functional.dropout(x, p=self.active_dropout_rate)
    return self.classifier(x)