in projects_oss/detr/detr/d2/detr.py [0:0]
def __init__(self, cfg):
    """Build the DETR (or Deformable DETR) network, its criterion and input normalizer.

    Args:
        cfg: config node; all settings are read from ``cfg.MODEL`` (mainly
            ``cfg.MODEL.DETR`` and ``cfg.MODEL.BACKBONE``).

    Attributes set on ``self``:
        device: ``torch.device(cfg.MODEL.DEVICE)``.
        num_classes, mask_on, use_focal_loss: copied from ``cfg``.
        detr: ``DeformableDETR`` when ``cfg.MODEL.DETR.DEFORMABLE`` else ``DETR``;
            wrapped in ``DETRsegm`` when ``mask_on``.
        seg_postprocess: the ``PostProcessSegm`` class itself (only when ``mask_on``).
        criterion: ``FocalLossSetCriterion`` when ``use_focal_loss`` else ``SetCriterion``.
        normalizer: callable computing ``(x - PIXEL_MEAN) / PIXEL_STD``.

    Raises:
        NotImplementedError: if the backbone name contains neither "resnet" nor
            "fbnet" (case-insensitive) and ``cfg.MODEL.BACKBONE.SIMPLE`` is false.
    """
    import logging

    logger = logging.getLogger(__name__)

    super().__init__()
    self.device = torch.device(cfg.MODEL.DEVICE)
    self.num_classes = cfg.MODEL.DETR.NUM_CLASSES
    self.mask_on = cfg.MODEL.MASK_ON
    hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM
    num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES
    # Transformer parameters:
    nheads = cfg.MODEL.DETR.NHEADS
    dropout = cfg.MODEL.DETR.DROPOUT
    dim_feedforward = cfg.MODEL.DETR.DIM_FEEDFORWARD
    enc_layers = cfg.MODEL.DETR.ENC_LAYERS
    dec_layers = cfg.MODEL.DETR.DEC_LAYERS
    pre_norm = cfg.MODEL.DETR.PRE_NORM
    # Loss parameters:
    giou_weight = cfg.MODEL.DETR.GIOU_WEIGHT
    l1_weight = cfg.MODEL.DETR.L1_WEIGHT
    cls_weight = cfg.MODEL.DETR.CLS_WEIGHT
    deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION
    no_object_weight = cfg.MODEL.DETR.NO_OBJECT_WEIGHT
    # NOTE: "ENCODIND" is the spelling of the config key itself; do not "fix" it here.
    centered_position_encoding = cfg.MODEL.DETR.CENTERED_POSITION_ENCODIND
    num_feature_levels = cfg.MODEL.DETR.NUM_FEATURE_LEVELS
    # The sine positional embedding gets half the hidden size as its step count.
    N_steps = hidden_dim // 2

    # ---- Backbone -------------------------------------------------------
    backbone_name = cfg.MODEL.BACKBONE.NAME.lower()
    if "resnet" in backbone_name:
        d2_backbone = ResNetMaskedBackbone(cfg)
    elif "fbnet" in backbone_name:
        d2_backbone = FBNetMaskedBackbone(cfg)
    elif cfg.MODEL.BACKBONE.SIMPLE:
        d2_backbone = SimpleSingleStageBackbone(cfg)
    else:
        raise NotImplementedError(
            f"Unsupported DETR backbone {cfg.MODEL.BACKBONE.NAME!r}: expected a "
            "ResNet or FBNet backbone, or cfg.MODEL.BACKBONE.SIMPLE=True"
        )
    backbone = Joiner(
        d2_backbone,
        PositionEmbeddingSine(
            N_steps, normalize=True, centered=centered_position_encoding
        ),
    )
    backbone.num_channels = d2_backbone.num_channels
    self.use_focal_loss = cfg.MODEL.DETR.USE_FOCAL_LOSS

    # ---- Transformer + detection head -----------------------------------
    if cfg.MODEL.DETR.DEFORMABLE:
        transformer = DeformableTransformer(
            d_model=hidden_dim,
            nhead=nheads,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="relu",
            return_intermediate_dec=True,
            num_feature_levels=num_feature_levels,
            dec_n_points=4,
            enc_n_points=4,
            two_stage=cfg.MODEL.DETR.TWO_STAGE,
            two_stage_num_proposals=num_queries,
        )
        # NOTE(review): unlike DETR below, DeformableDETR is not given
        # use_focal_loss, while the matcher/criterion are — confirm that
        # USE_FOCAL_LOSS is expected to be on whenever DEFORMABLE is.
        self.detr = DeformableDETR(
            backbone,
            transformer,
            num_classes=self.num_classes,
            num_queries=num_queries,
            num_feature_levels=num_feature_levels,
            aux_loss=deep_supervision,
            with_box_refine=cfg.MODEL.DETR.WITH_BOX_REFINE,
            two_stage=cfg.MODEL.DETR.TWO_STAGE,
        )
    else:
        transformer = Transformer(
            d_model=hidden_dim,
            dropout=dropout,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            normalize_before=pre_norm,
            return_intermediate_dec=deep_supervision,
        )
        self.detr = DETR(
            backbone,
            transformer,
            num_classes=self.num_classes,
            num_queries=num_queries,
            aux_loss=deep_supervision,
            use_focal_loss=self.use_focal_loss,
        )

    # ---- Optional mask head ---------------------------------------------
    if self.mask_on:
        frozen_weights = cfg.MODEL.DETR.FROZEN_WEIGHTS
        freeze_detr = frozen_weights != ""
        if freeze_detr:
            logger.info("LOAD pre-trained weights")
            # Load onto CPU regardless of the device the checkpoint was saved from.
            # NOTE(review): torch.load unpickles arbitrary objects; FROZEN_WEIGHTS
            # must point at a trusted checkpoint.
            weight = torch.load(frozen_weights, map_location="cpu")["model"]
            # Keep only parameters of the wrapped ``detr`` submodule and strip
            # exactly one leading "detr." so the keys match ``self.detr``.
            prefix = "detr."
            new_weight = {}
            for k, v in weight.items():
                if k.startswith(prefix):
                    new_weight[k[len(prefix):]] = v
                else:
                    logger.info("Skipping loading weight %s from frozen model", k)
            del weight  # release the full checkpoint before loading
            self.detr.load_state_dict(new_weight)
            del new_weight
        self.detr = DETRsegm(self.detr, freeze_detr=freeze_detr)
        self.seg_postprocess = PostProcessSegm
    self.detr.to(self.device)

    # ---- Criterion ------------------------------------------------------
    matcher = HungarianMatcher(
        cost_class=cls_weight,
        cost_bbox=l1_weight,
        cost_giou=giou_weight,
        use_focal_loss=self.use_focal_loss,
    )
    weight_dict = {
        "loss_ce": cls_weight,
        "loss_bbox": l1_weight,
        "loss_giou": giou_weight,
    }
    if deep_supervision:
        # One "_<i>"-suffixed copy of every weight for i in [0, dec_layers - 1),
        # presumably one per intermediate decoder output.
        weight_dict.update(
            {
                f"{k}_{i}": v
                for i in range(dec_layers - 1)
                for k, v in weight_dict.items()
            }
        )
        # NOTE(review): no "_enc" weights are added for TWO_STAGE; if the
        # criterion emits encoder-proposal losses they would go unweighted —
        # verify against the criterion implementation.
    losses = ["labels", "boxes", "cardinality"]
    if self.mask_on:
        losses.append("masks")
    if self.use_focal_loss:
        self.criterion = FocalLossSetCriterion(
            self.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            losses=losses,
        )
    else:
        self.criterion = SetCriterion(
            self.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            eos_coef=no_object_weight,
            losses=losses,
        )
    self.criterion.to(self.device)

    # ---- Input normalization --------------------------------------------
    pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
    pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
    # NOTE(review): mean/std are captured by this closure instead of being
    # registered as buffers, so a later ``self.to(...)`` will not move them.
    self.normalizer = lambda x: (x - pixel_mean) / pixel_std
    self.to(self.device)