def __init__()

in projects_oss/detr/detr/d2/detr.py [0:0]


    def __init__(self, cfg):
        """Build the DETR model, its training criterion and the input normalizer.

        Everything is derived from ``cfg``: the masked backbone, the
        (deformable or vanilla) transformer, the optional segmentation wrapper,
        the Hungarian matcher and the set criterion.

        Args:
            cfg: config node. This method reads ``cfg.MODEL.DEVICE``,
                ``cfg.MODEL.MASK_ON``, ``cfg.MODEL.BACKBONE.NAME``,
                ``cfg.MODEL.BACKBONE.SIMPLE``, ``cfg.MODEL.PIXEL_MEAN``,
                ``cfg.MODEL.PIXEL_STD`` and the ``cfg.MODEL.DETR.*`` options
                referenced below.

        Raises:
            NotImplementedError: if the backbone name contains neither
                "resnet" nor "fbnet" (case-insensitive) and
                ``cfg.MODEL.BACKBONE.SIMPLE`` is falsy.
        """
        super().__init__()

        self.device = torch.device(cfg.MODEL.DEVICE)

        self.num_classes = cfg.MODEL.DETR.NUM_CLASSES
        self.mask_on = cfg.MODEL.MASK_ON
        hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM
        num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES
        # Transformer parameters:
        nheads = cfg.MODEL.DETR.NHEADS
        dropout = cfg.MODEL.DETR.DROPOUT
        dim_feedforward = cfg.MODEL.DETR.DIM_FEEDFORWARD
        enc_layers = cfg.MODEL.DETR.ENC_LAYERS
        dec_layers = cfg.MODEL.DETR.DEC_LAYERS
        pre_norm = cfg.MODEL.DETR.PRE_NORM

        # Loss parameters:
        giou_weight = cfg.MODEL.DETR.GIOU_WEIGHT
        l1_weight = cfg.MODEL.DETR.L1_WEIGHT
        cls_weight = cfg.MODEL.DETR.CLS_WEIGHT
        deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION
        no_object_weight = cfg.MODEL.DETR.NO_OBJECT_WEIGHT
        # NOTE(review): the key is spelled "ENCODIND" here; presumably this
        # matches the config schema defined elsewhere -- verify before renaming.
        centered_position_encoding = cfg.MODEL.DETR.CENTERED_POSITION_ENCODIND
        num_feature_levels = cfg.MODEL.DETR.NUM_FEATURE_LEVELS

        # Half of the hidden size is handed to the sine position embedding.
        N_steps = hidden_dim // 2

        # Select the masked backbone wrapper from the backbone name; the
        # SIMPLE flag is only consulted when neither name matches.
        backbone_name = cfg.MODEL.BACKBONE.NAME.lower()
        if "resnet" in backbone_name:
            d2_backbone = ResNetMaskedBackbone(cfg)
        elif "fbnet" in backbone_name:
            d2_backbone = FBNetMaskedBackbone(cfg)
        elif cfg.MODEL.BACKBONE.SIMPLE:
            d2_backbone = SimpleSingleStageBackbone(cfg)
        else:
            raise NotImplementedError(
                f"Unsupported backbone for DETR: {cfg.MODEL.BACKBONE.NAME!r}"
            )

        backbone = Joiner(
            d2_backbone,
            PositionEmbeddingSine(
                N_steps, normalize=True, centered=centered_position_encoding
            ),
        )
        backbone.num_channels = d2_backbone.num_channels
        self.use_focal_loss = cfg.MODEL.DETR.USE_FOCAL_LOSS

        if cfg.MODEL.DETR.DEFORMABLE:
            # The deformable variant always returns intermediate decoder
            # outputs (independent of DEEP_SUPERVISION) and does not take
            # the PRE_NORM option.
            transformer = DeformableTransformer(
                d_model=hidden_dim,
                nhead=nheads,
                num_encoder_layers=enc_layers,
                num_decoder_layers=dec_layers,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation="relu",
                return_intermediate_dec=True,
                num_feature_levels=num_feature_levels,
                dec_n_points=4,
                enc_n_points=4,
                two_stage=cfg.MODEL.DETR.TWO_STAGE,
                two_stage_num_proposals=num_queries,
            )

            self.detr = DeformableDETR(
                backbone,
                transformer,
                num_classes=self.num_classes,
                num_queries=num_queries,
                num_feature_levels=num_feature_levels,
                aux_loss=deep_supervision,
                with_box_refine=cfg.MODEL.DETR.WITH_BOX_REFINE,
                two_stage=cfg.MODEL.DETR.TWO_STAGE,
            )
        else:
            transformer = Transformer(
                d_model=hidden_dim,
                dropout=dropout,
                nhead=nheads,
                dim_feedforward=dim_feedforward,
                num_encoder_layers=enc_layers,
                num_decoder_layers=dec_layers,
                normalize_before=pre_norm,
                return_intermediate_dec=deep_supervision,
            )

            self.detr = DETR(
                backbone,
                transformer,
                num_classes=self.num_classes,
                num_queries=num_queries,
                aux_loss=deep_supervision,
                use_focal_loss=self.use_focal_loss,
            )

        if self.mask_on:
            frozen_weights = cfg.MODEL.DETR.FROZEN_WEIGHTS
            # An empty string means "no pre-trained detector to load/freeze".
            has_frozen_weights = frozen_weights != ""
            if has_frozen_weights:
                print("LOAD pre-trained weights")
                # The map_location callable returns each storage unchanged,
                # i.e. tensors are not moved to another device while loading.
                weight = torch.load(
                    frozen_weights, map_location=lambda storage, loc: storage
                )["model"]
                # Keep only entries whose key contains "detr." and drop that
                # substring so the keys line up with ``self.detr``'s own
                # state dict; everything else is reported and skipped.
                new_weight = {}
                for k, v in weight.items():
                    if "detr." in k:
                        new_weight[k.replace("detr.", "")] = v
                    else:
                        print(f"Skipping loading weight {k} from frozen model")
                del weight
                self.detr.load_state_dict(new_weight)
                del new_weight
            # Wrap the detector with the segmentation model; the detector is
            # frozen only when pre-trained weights were supplied.
            self.detr = DETRsegm(self.detr, freeze_detr=has_frozen_weights)
            # Stores the PostProcessSegm class itself, not an instance.
            self.seg_postprocess = PostProcessSegm

        self.detr.to(self.device)

        # building criterion
        matcher = HungarianMatcher(
            cost_class=cls_weight,
            cost_bbox=l1_weight,
            cost_giou=giou_weight,
            use_focal_loss=self.use_focal_loss,
        )
        weight_dict = {
            "loss_ce": cls_weight,
            "loss_bbox": l1_weight,
            "loss_giou": giou_weight,
        }
        if deep_supervision:
            # Add a "<name>_<i>" copy of every base weight for
            # i in [0, dec_layers - 1).
            aux_weight_dict = {}
            for i in range(dec_layers - 1):
                aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
            weight_dict.update(aux_weight_dict)
        losses = ["labels", "boxes", "cardinality"]
        if self.mask_on:
            losses.append("masks")
        if self.use_focal_loss:
            self.criterion = FocalLossSetCriterion(
                self.num_classes,
                matcher=matcher,
                weight_dict=weight_dict,
                losses=losses,
            )
        else:
            self.criterion = SetCriterion(
                self.num_classes,
                matcher=matcher,
                weight_dict=weight_dict,
                eos_coef=no_object_weight,
                losses=losses,
            )
        self.criterion.to(self.device)

        # view(-1, 1, 1) instead of a hard-coded channel count, so configs with
        # a different number of PIXEL_MEAN / PIXEL_STD entries also work.
        # Presumably broadcast against (C, H, W) images -- the normalizer's
        # callers are not visible here.
        pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(-1, 1, 1)
        pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(-1, 1, 1)
        self.normalizer = lambda x: (x - pixel_mean) / pixel_std
        self.to(self.device)