in models/attentive_nas_dynamic_model.py
def get_active_subnet(self, preserve_weight=True):
    with torch.no_grad():
        # Stem: extract the active static first conv (3 input channels for RGB).
        first_conv = self.first_conv.get_active_subnet(3, preserve_weight)

        # Blocks: walk each stage, keeping only the first `runtime_depth` blocks.
        blocks = []
        input_channel = first_conv.out_channels
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append(MobileInvertedResidualBlock(
                    self.blocks[idx].mobile_inverted_conv.get_active_subnet(input_channel, preserve_weight),
                    self.blocks[idx].shortcut.get_active_subnet(input_channel, preserve_weight) if self.blocks[idx].shortcut is not None else None
                ))
                # The next block's input width is this block's active output width.
                input_channel = stage_blocks[-1].mobile_inverted_conv.out_channels
            blocks += stage_blocks

        # Head: either a single last conv, or a MobileNetV3-style head
        # (final expand -> global average pool -> feature mix).
        if not self.use_v3_head:
            last_conv = self.last_conv.get_active_subnet(input_channel, preserve_weight)
            in_features = last_conv.out_channels
        else:
            final_expand_layer = self.last_conv.final_expand_layer.get_active_subnet(input_channel, preserve_weight)
            feature_mix_layer = self.last_conv.feature_mix_layer.get_active_subnet(input_channel * 6, preserve_weight)
            in_features = feature_mix_layer.out_channels
            last_conv = nn.Sequential(
                final_expand_layer,
                nn.AdaptiveAvgPool2d((1, 1)),
                feature_mix_layer,
            )

        classifier = self.classifier.get_active_subnet(in_features, preserve_weight)

        # Assemble the standalone static model and copy over the BN settings.
        _subnet = AttentiveNasStaticModel(
            first_conv, blocks, last_conv, classifier,
            self.active_resolution, use_v3_head=self.use_v3_head,
        )
        _subnet.set_bn_param(**self.get_bn_param())
        return _subnet
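
For context, a minimal usage sketch. The constructor argument and the sampling call below are assumptions based on the surrounding AttentiveNAS codebase, not part of this excerpt; only get_active_subnet and active_resolution appear above.

    # Hypothetical usage: activate a sub-architecture, then extract it as a
    # standalone static model that inherits the supernet's trained weights.
    model = AttentiveNasDynamicModel(supernet_cfg)  # supernet_cfg: assumed config object
    model.sample_active_subnet()                    # assumed sampler; picks a random architecture
    subnet = model.get_active_subnet(preserve_weight=True)

    res = model.active_resolution
    subnet.eval()
    logits = subnet(torch.randn(1, 3, res, res))

Note that weight-sharing NAS pipelines typically recalibrate the extracted subnet's BatchNorm running statistics on a few training batches before evaluation, since the supernet's aggregated statistics do not match any single subnet.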