# Excerpt: DINOHead.__init__
# Source file: easycv/models/detection/detectors/dino/dino_head.py

    def __init__(
            self,
            num_classes,
            embed_dims,
            in_channels=(512, 1024, 2048),
            query_dim=4,
            num_queries=300,
            num_select=300,
            random_refpoints_xy=False,
            num_patterns=0,
            dn_components=None,
            transformer=None,
            fix_refpoints_hw=-1,
            num_feature_levels=1,
            # two stage
            two_stage_type='standard',  # ['no', 'standard']
            two_stage_add_query_num=0,
            dec_pred_class_embed_share=True,
            dec_pred_bbox_embed_share=True,
            two_stage_class_embed_share=True,
            two_stage_bbox_embed_share=True,
            use_centerness=False,
            use_iouaware=False,
            losses_list=None,
            decoder_sa_type='sa',
            temperatureH=20,
            temperatureW=20,
            cost_dict=None,
            weight_dict=None,
            **kwargs):
        """Initialize the DINO detection head.

        Builds the matcher/criterion stack, the transformer neck, the
        positional encoding, per-level input projections, and the
        (optionally shared) classification / box / centerness / IoU
        prediction embeds for every decoder layer.

        Args:
            num_classes (int): Number of object categories.
            embed_dims (int): Transformer embedding dimension.
            in_channels (Sequence[int]): Channels of each backbone output
                used to build the input projections.
            query_dim (int): Dimensionality of the reference points
                (4 = cxcywh boxes).
            num_queries (int): Number of object queries.
            num_select (int): Top-k detections kept by post-processing.
            random_refpoints_xy (bool): Randomly initialize reference xy.
            num_patterns (int): Accepted for config compatibility; not
                used in this constructor.
            dn_components (dict | None): Contrastive denoising (CDN)
                config with keys 'dn_number', 'dn_box_noise_scale',
                'dn_label_noise_ratio', 'dn_labelbook_size'. When None,
                denoising is disabled.
            transformer (dict): Neck (transformer) config passed to
                ``build_neck``; must expose ``num_decoder_layers``.
            fix_refpoints_hw (int): Fixed h/w for reference points
                (-1 disables).
            num_feature_levels (int): Number of feature levels fed to the
                transformer. NOTE(review): the default
                ``two_stage_type='standard'`` requires
                ``num_feature_levels > 1`` (see the assert below), so the
                two defaults are mutually incompatible — confirm configs
                always override one of them.
            two_stage_type (str): 'no' or 'standard'.
            two_stage_add_query_num (int): Extra learned queries added in
                two-stage mode.
            dec_pred_class_embed_share (bool): Share one class embed
                across decoder layers.
            dec_pred_bbox_embed_share (bool): Share one bbox embed across
                decoder layers.
            two_stage_class_embed_share (bool): Reuse the decoder class
                embed for the encoder output head.
            two_stage_bbox_embed_share (bool): Reuse the decoder bbox
                embed for the encoder output head.
            use_centerness (bool): Add FCOS-style centerness branches.
            use_iouaware (bool): Add IoU-aware prediction branches.
            losses_list (list[str] | None): Losses for the criterion;
                defaults to ['labels', 'boxes'].
            decoder_sa_type (str): Decoder self-attention variant, one of
                ['sa', 'ca_label', 'ca_content'].
            temperatureH (int): Sine positional-encoding temperature (H).
            temperatureW (int): Sine positional-encoding temperature (W).
            cost_dict (dict | None): Hungarian matching costs; defaults
                to {'cost_class': 1, 'cost_bbox': 5, 'cost_giou': 2}.
            weight_dict (dict | None): Loss weights; defaults to
                {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}.
        """

        super(DINOHead, self).__init__()

        # Materialize per-instance copies of the (previously mutable)
        # defaults so instances can never share state through them.
        if losses_list is None:
            losses_list = ['labels', 'boxes']
        if cost_dict is None:
            cost_dict = {
                'cost_class': 1,
                'cost_bbox': 5,
                'cost_giou': 2,
            }
        if weight_dict is None:
            weight_dict = {
                'loss_ce': 1,
                'loss_bbox': 5,
                'loss_giou': 2,
            }

        self.matcher = HungarianMatcher(
            cost_dict=cost_dict, cost_class_type='focal_loss_cost')
        self.criterion = SetCriterion(
            num_classes,
            matcher=self.matcher,
            weight_dict=weight_dict,
            losses=losses_list,
            loss_class_type='focal_loss')
        if dn_components is not None:
            self.dn_criterion = CDNCriterion(
                num_classes,
                matcher=self.matcher,
                weight_dict=weight_dict,
                losses=losses_list,
                loss_class_type='focal_loss')
        self.postprocess = DetrPostProcess(
            num_select=num_select,
            use_centerness=use_centerness,
            use_iouaware=use_iouaware)
        self.transformer = build_neck(transformer)

        self.positional_encoding = PositionEmbeddingSineHW(
            embed_dims // 2,
            temperatureH=temperatureH,
            temperatureW=temperatureW,
            normalize=True)

        self.num_classes = num_classes
        self.num_queries = num_queries
        self.embed_dims = embed_dims
        self.query_dim = query_dim
        self.dn_components = dn_components

        self.random_refpoints_xy = random_refpoints_xy
        self.fix_refpoints_hw = fix_refpoints_hw

        # For denoising (CDN) training. The original code subscripted
        # dn_components unconditionally, which raised TypeError whenever
        # the default None was kept even though the criterion above was
        # guarded — mirror that guard here and fall back to a disabled
        # configuration.
        if dn_components is not None:
            self.dn_number = dn_components['dn_number']
            self.dn_box_noise_scale = dn_components['dn_box_noise_scale']
            self.dn_label_noise_ratio = dn_components['dn_label_noise_ratio']
            self.dn_labelbook_size = dn_components['dn_labelbook_size']
        else:
            # Denoising disabled: zero noise queries, labelbook spans the
            # plain class set.
            self.dn_number = 0
            self.dn_box_noise_scale = 0.0
            self.dn_label_noise_ratio = 0.0
            self.dn_labelbook_size = num_classes
        # +1 for the "unknown"/indicator entry used by CDN label noising.
        self.label_enc = nn.Embedding(self.dn_labelbook_size + 1, embed_dims)

        # Prepare input projection layers: one 1x1 conv + GroupNorm per
        # backbone level, then strided 3x3 convs to synthesize any extra
        # pyramid levels beyond what the backbone provides.
        self.num_feature_levels = num_feature_levels
        if num_feature_levels > 1:
            num_backbone_outs = len(in_channels)
            input_proj_list = []
            for i in range(num_backbone_outs):
                in_channels_i = in_channels[i]
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels_i, embed_dims, kernel_size=1),
                        nn.GroupNorm(32, embed_dims),
                    ))
            for _ in range(num_feature_levels - num_backbone_outs):
                # First extra level downsamples the deepest backbone map;
                # subsequent ones downsample the previous projected map.
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(
                            in_channels_i,
                            embed_dims,
                            kernel_size=3,
                            stride=2,
                            padding=1),
                        nn.GroupNorm(32, embed_dims),
                    ))
                in_channels_i = embed_dims
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            assert two_stage_type == 'no', 'two_stage_type should be no if num_feature_levels=1 !!!'
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(in_channels[-1], embed_dims, kernel_size=1),
                    nn.GroupNorm(32, embed_dims),
                )
            ])

        # Prepare prediction layers.
        self.dec_pred_class_embed_share = dec_pred_class_embed_share
        self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
        # Prototype class & box embeds (shared or deep-copied per layer).
        _class_embed = nn.Linear(embed_dims, num_classes)
        _bbox_embed = MLP(embed_dims, embed_dims, 4, 3)
        # Focal-loss prior initialization: bias such that initial class
        # probability is ~prior_prob; zero-init the final bbox layer.
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        _class_embed.bias.data = torch.ones(self.num_classes) * bias_value
        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)

        # Optional FCOS-style centerness and IoU-aware branches.
        self.use_centerness = use_centerness
        self.use_iouaware = use_iouaware
        if self.use_centerness:
            _center_embed = MLP(embed_dims, embed_dims, 1, 3)
        if self.use_iouaware:
            _iou_embed = MLP(embed_dims, embed_dims, 1, 3)

        # Shared embeds reuse the same module object for every decoder
        # layer; unshared embeds get independent deep copies.
        if dec_pred_bbox_embed_share:
            box_embed_layerlist = [
                _bbox_embed for i in range(transformer.num_decoder_layers)
            ]
            if self.use_centerness:
                center_embed_layerlist = [
                    _center_embed
                    for i in range(transformer.num_decoder_layers)
                ]
            if self.use_iouaware:
                iou_embed_layerlist = [
                    _iou_embed for i in range(transformer.num_decoder_layers)
                ]
        else:
            box_embed_layerlist = [
                copy.deepcopy(_bbox_embed)
                for i in range(transformer.num_decoder_layers)
            ]
            if self.use_centerness:
                center_embed_layerlist = [
                    copy.deepcopy(_center_embed)
                    for i in range(transformer.num_decoder_layers)
                ]
            if self.use_iouaware:
                iou_embed_layerlist = [
                    copy.deepcopy(_iou_embed)
                    for i in range(transformer.num_decoder_layers)
                ]

        if dec_pred_class_embed_share:
            class_embed_layerlist = [
                _class_embed for i in range(transformer.num_decoder_layers)
            ]
        else:
            class_embed_layerlist = [
                copy.deepcopy(_class_embed)
                for i in range(transformer.num_decoder_layers)
            ]
        self.bbox_embed = nn.ModuleList(box_embed_layerlist)
        self.class_embed = nn.ModuleList(class_embed_layerlist)
        # The decoder refines boxes layer-by-layer, so it needs direct
        # access to the per-layer embeds.
        self.transformer.decoder.bbox_embed = self.bbox_embed
        self.transformer.decoder.class_embed = self.class_embed

        if self.use_centerness:
            self.center_embed = nn.ModuleList(center_embed_layerlist)
            self.transformer.decoder.center_embed = self.center_embed
        if self.use_iouaware:
            self.iou_embed = nn.ModuleList(iou_embed_layerlist)
            self.transformer.decoder.iou_embed = self.iou_embed

        # Two-stage: the encoder output proposals need their own heads,
        # either shared with the decoder prototypes or deep copies.
        self.two_stage_type = two_stage_type
        self.two_stage_add_query_num = two_stage_add_query_num
        assert two_stage_type in [
            'no', 'standard'
        ], 'unknown param {} of two_stage_type'.format(two_stage_type)
        if two_stage_type != 'no':
            if two_stage_bbox_embed_share:
                # Sharing with the encoder head only makes sense when the
                # decoder layers themselves share one embed.
                assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
                self.transformer.enc_out_bbox_embed = _bbox_embed
                if self.use_centerness:
                    self.transformer.enc_out_center_embed = _center_embed
                if self.use_iouaware:
                    self.transformer.enc_out_iou_embed = _iou_embed
            else:
                self.transformer.enc_out_bbox_embed = copy.deepcopy(
                    _bbox_embed)
                if self.use_centerness:
                    self.transformer.enc_out_center_embed = copy.deepcopy(
                        _center_embed)
                if self.use_iouaware:
                    self.transformer.enc_out_iou_embed = copy.deepcopy(
                        _iou_embed)

            if two_stage_class_embed_share:
                assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
                self.transformer.enc_out_class_embed = _class_embed
            else:
                self.transformer.enc_out_class_embed = copy.deepcopy(
                    _class_embed)

            self.refpoint_embed = None
            if self.two_stage_add_query_num > 0:
                self.init_ref_points(two_stage_add_query_num)

        # Decoder self-attention variant; 'ca_label' injects a learned
        # label embedding into every decoder layer.
        self.decoder_sa_type = decoder_sa_type
        assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
        if decoder_sa_type == 'ca_label':
            self.label_embedding = nn.Embedding(num_classes, embed_dims)
            for layer in self.transformer.decoder.layers:
                layer.label_embedding = self.label_embedding
        else:
            for layer in self.transformer.decoder.layers:
                layer.label_embedding = None
            self.label_embedding = None