def get_active_subnet()

in models/attentive_nas_dynamic_model.py [0:0]


    def get_active_subnet(self, preserve_weight=True):
        """Assemble a static model from the currently active sub-network.

        Each dynamic component (first conv, the active blocks of every stage,
        the head and the classifier) is asked for its own active subnet, and
        the pieces are wired into an ``AttentiveNasStaticModel``. The whole
        extraction runs under ``torch.no_grad()``.

        Args:
            preserve_weight: Forwarded unchanged to every component's
                ``get_active_subnet`` call. Presumably controls whether the
                supernet weights are copied into the extracted layers --
                confirm against the layer implementations.

        Returns:
            An ``AttentiveNasStaticModel`` built at ``self.active_resolution``
            with the BN parameters reported by ``self.get_bn_param()`` applied.
        """
        with torch.no_grad():
            # The stem is always extracted with 3 input channels.
            stem = self.first_conv.get_active_subnet(3, preserve_weight)
            channels = stem.out_channels

            # Walk the stages, keeping only the first `runtime_depth[stage]`
            # blocks of each one, and thread the channel count through.
            active_blocks = []
            for stage_id, stage_indices in enumerate(self.block_group_info):
                n_active = self.runtime_depth[stage_id]
                for idx in stage_indices[:n_active]:
                    dynamic_block = self.blocks[idx]
                    conv = dynamic_block.mobile_inverted_conv.get_active_subnet(
                        channels, preserve_weight
                    )
                    if dynamic_block.shortcut is None:
                        shortcut = None
                    else:
                        shortcut = dynamic_block.shortcut.get_active_subnet(
                            channels, preserve_weight
                        )
                    static_block = MobileInvertedResidualBlock(conv, shortcut)
                    active_blocks.append(static_block)
                    # The next block consumes this block's output channels.
                    channels = static_block.mobile_inverted_conv.out_channels

            if self.use_v3_head:
                head = self.last_conv
                expand_layer = head.final_expand_layer.get_active_subnet(
                    channels, preserve_weight
                )
                # NOTE(review): the mix layer's input width is hard-coded as
                # 6x the last block's output channels; this presumably mirrors
                # the expand layer's expansion ratio -- confirm.
                mix_layer = head.feature_mix_layer.get_active_subnet(
                    channels * 6, preserve_weight
                )
                in_features = mix_layer.out_channels
                last_conv = nn.Sequential(
                    expand_layer,
                    nn.AdaptiveAvgPool2d((1, 1)),
                    mix_layer,
                )
            else:
                last_conv = self.last_conv.get_active_subnet(channels, preserve_weight)
                in_features = last_conv.out_channels

            classifier = self.classifier.get_active_subnet(in_features, preserve_weight)

            static_model = AttentiveNasStaticModel(
                stem,
                active_blocks,
                last_conv,
                classifier,
                self.active_resolution,
                use_v3_head=self.use_v3_head,
            )
            static_model.set_bn_param(**self.get_bn_param())
            return static_model