def __init__()

in easycv/models/detection/detectors/yolox/yolo_head_template.py [0:0]


    def __init__(self,
                 num_classes=80,
                 model_type='s',
                 strides=[8, 16, 32],
                 in_channels=[256, 512, 1024],
                 act='silu',
                 conv_type='conv',
                 stage='CLOUD',
                 obj_loss_type='BCE',
                 reg_loss_type='giou',
                 decode_in_inference=True,
                 width=None):
        """Build the per-level stems, cls/reg conv towers, prediction convs and losses.

        Args:
            num_classes (int): number of detection classes. Default: 80.
            model_type (str): key into ``self.param_map``; when ``width`` is
                None the width is taken from ``self.param_map[model_type][1]``.
                Default: 's'.
            strides (list): strides of the feature levels. The sequence is
                copied into ``self.strides`` so the default list is never
                shared between instances. Default: [8, 16, 32].
            in_channels (list): per-level input channels before width scaling
                (each is multiplied by ``width``). Default: [256, 512, 1024].
            act (str): activation type of the convs. Default: 'silu'.
            conv_type (str): conv used in the cls/reg towers: 'conv'
                (BaseConv), 'dwconv' (DWConv) or 'repconv' (RepVGGBlock).
                Any other value logs a warning and falls back to 'repconv'.
                Default: 'conv'.
            stage (str): model stage, distinguishing the edge head from the
                cloud head; stored as ``self.stage``. Default: 'CLOUD'.
            obj_loss_type (str): loss for the objectness confidence; only
                'BCE' is supported. Default: 'BCE'.
            reg_loss_type (str): ``loss_type`` forwarded to ``YOLOX_IOULoss``
                for box regression. Default: 'giou'.
            decode_in_inference (bool): stored as
                ``self.decode_in_inference``; set to False for deployment.
                Default: True.
            width (float, optional): channel width multiplier. Must be given
                when ``model_type`` is not a key of ``self.param_map``.
                Default: None.

        Raises:
            AssertionError: if ``width`` is None and ``model_type`` is not a
                key of ``self.param_map``.
            KeyError: if ``obj_loss_type`` is not 'BCE'.
        """
        super().__init__()
        if width is None:
            # Derive the width from the known model presets; an unknown
            # preset requires an explicit width.
            assert model_type in self.param_map, \
                'Unknow model type must have a given width!'
            width = self.param_map[model_type][1]

        self.width = width
        self.n_anchors = 1
        self.num_classes = num_classes
        self.stage = stage
        self.decode_in_inference = decode_in_inference  # for deploy, set to False

        # NOTE: the registration order of these containers (and the order in
        # which modules are appended below) fixes the order of parameters()
        # and state_dict() entries; keep it stable for checkpoint/optimizer
        # compatibility.
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.obj_preds = nn.ModuleList()
        self.stems = nn.ModuleList()

        conv_map = {
            'conv': BaseConv,
            'dwconv': DWConv,
            'repconv': RepVGGBlock,
        }
        if conv_type not in conv_map:
            logging.warning(
                'YOLOX-PAI tood head conv_type must in [conv, dwconv, repconv], otherwise we use repconv as default'
            )
            conv_type = 'repconv'
        Conv = conv_map[conv_type]

        # Hidden width shared by every stem output, tower and prediction input.
        hidden = int(256 * width)

        def _tower():
            # Two stacked 3x3 convs at the hidden width (one cls or reg tower).
            return nn.Sequential(*[
                Conv(
                    in_channels=hidden,
                    out_channels=hidden,
                    ksize=3,
                    stride=1,
                    act=act,
                ) for _ in range(2)
            ])

        def _pred(out_channels):
            # 1x1 prediction conv mapping the hidden width to out_channels.
            return nn.Conv2d(
                in_channels=hidden,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )

        for level_in_channels in in_channels:
            self.stems.append(
                BaseConv(
                    in_channels=int(level_in_channels * width),
                    out_channels=hidden,
                    ksize=1,
                    stride=1,
                    act=act,
                ))
            self.cls_convs.append(_tower())
            self.reg_convs.append(_tower())
            self.cls_preds.append(_pred(self.n_anchors * self.num_classes))
            self.reg_preds.append(_pred(4))
            self.obj_preds.append(_pred(self.n_anchors * 1))

        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction='none')

        self.use_l1 = False
        self.l1_loss = nn.L1Loss(reduction='none')

        self.iou_loss = YOLOX_IOULoss(
            reduction='none', loss_type=reg_loss_type)

        self.obj_loss_type = obj_loss_type
        if obj_loss_type == 'BCE':
            self.obj_loss = nn.BCEWithLogitsLoss(reduction='none')
        else:
            raise KeyError('Undefined loss type: {}'.format(obj_loss_type))

        # Copy so instances never share (or mutate) the mutable default list.
        self.strides = list(strides)
        # One independent placeholder grid per level; list multiplication
        # would make every entry alias the same tensor object.
        self.grids = [torch.zeros(1) for _ in in_channels]