in d2/detr/detr.py [0:0]
def __init__(self, cfg):
    """Build the DETR model, its training criterion and the input normalizer.

    Args:
        cfg: Config object accessed by attribute. This constructor reads
            ``cfg.MODEL.DEVICE``, ``cfg.MODEL.MASK_ON``,
            ``cfg.MODEL.PIXEL_MEAN``, ``cfg.MODEL.PIXEL_STD`` and the
            ``cfg.MODEL.DETR.*`` options used below. ``cfg`` itself is also
            passed on to ``MaskedBackbone``.

    Side effects:
        Sets ``self.device``, ``self.num_classes``, ``self.mask_on``,
        ``self.detr``, ``self.criterion`` and ``self.normalizer`` (plus
        ``self.seg_postprocess`` when ``cfg.MODEL.MASK_ON`` is true) and
        moves the module to ``self.device``. When
        ``cfg.MODEL.DETR.FROZEN_WEIGHTS`` is a non-empty string and masks
        are enabled, the checkpoint at that path is read and progress
        messages are printed.
    """
    super().__init__()

    self.device = torch.device(cfg.MODEL.DEVICE)
    self.num_classes = cfg.MODEL.DETR.NUM_CLASSES
    self.mask_on = cfg.MODEL.MASK_ON

    detr_cfg = cfg.MODEL.DETR
    hidden_dim = detr_cfg.HIDDEN_DIM
    num_queries = detr_cfg.NUM_OBJECT_QUERIES
    # Transformer parameters:
    nheads = detr_cfg.NHEADS
    dropout = detr_cfg.DROPOUT
    dim_feedforward = detr_cfg.DIM_FEEDFORWARD
    enc_layers = detr_cfg.ENC_LAYERS
    dec_layers = detr_cfg.DEC_LAYERS
    pre_norm = detr_cfg.PRE_NORM
    # Loss parameters:
    giou_weight = detr_cfg.GIOU_WEIGHT
    l1_weight = detr_cfg.L1_WEIGHT
    deep_supervision = detr_cfg.DEEP_SUPERVISION
    no_object_weight = detr_cfg.NO_OBJECT_WEIGHT

    # The positional embedding is built with half of the hidden size
    # (presumably one half per spatial axis -- confirm against
    # PositionEmbeddingSine).
    pos_feats = hidden_dim // 2
    d2_backbone = MaskedBackbone(cfg)
    backbone = Joiner(d2_backbone, PositionEmbeddingSine(pos_feats, normalize=True))
    backbone.num_channels = d2_backbone.num_channels

    transformer = Transformer(
        d_model=hidden_dim,
        dropout=dropout,
        nhead=nheads,
        dim_feedforward=dim_feedforward,
        num_encoder_layers=enc_layers,
        num_decoder_layers=dec_layers,
        normalize_before=pre_norm,
        return_intermediate_dec=deep_supervision,
    )

    self.detr = DETR(
        backbone, transformer, num_classes=self.num_classes, num_queries=num_queries, aux_loss=deep_supervision
    )
    if self.mask_on:
        frozen_weights = detr_cfg.FROZEN_WEIGHTS
        freeze_detr = frozen_weights != ''
        if freeze_detr:
            print("LOAD pre-trained weights")
            # NOTE(security): depending on the PyTorch version, torch.load
            # may unpickle arbitrary objects. Only point FROZEN_WEIGHTS at
            # checkpoints from a trusted source.
            checkpoint = torch.load(frozen_weights, map_location="cpu")['model']
            # The checkpoint stores the detector under a leading "detr."
            # prefix; strip exactly that prefix (not every occurrence of
            # the substring) so the keys match ``self.detr``'s state dict.
            prefix = 'detr.'
            detr_state = {}
            for key, value in checkpoint.items():
                if key.startswith(prefix):
                    detr_state[key[len(prefix):]] = value
                else:
                    print(f"Skipping loading weight {key} from frozen model")
            del checkpoint
            self.detr.load_state_dict(detr_state)
            del detr_state
        # Wrap the detector with the segmentation head; the detector is
        # frozen only when pre-trained weights were supplied.
        self.detr = DETRsegm(self.detr, freeze_detr=freeze_detr)
        self.seg_postprocess = PostProcessSegm

    self.detr.to(self.device)

    # building criterion
    matcher = HungarianMatcher(cost_class=1, cost_bbox=l1_weight, cost_giou=giou_weight)
    weight_dict = {"loss_ce": 1, "loss_bbox": l1_weight, "loss_giou": giou_weight}
    if deep_supervision:
        # Add ``dec_layers - 1`` suffixed copies of every base weight
        # ("loss_ce_0", "loss_bbox_0", ...); presumably one set per
        # non-final decoder layer -- confirm against SetCriterion.
        aux_weight_dict = {
            k + f"_{i}": v
            for i in range(dec_layers - 1)
            for k, v in weight_dict.items()
        }
        weight_dict.update(aux_weight_dict)

    losses = ["labels", "boxes", "cardinality"]
    if self.mask_on:
        losses += ["masks"]
    self.criterion = SetCriterion(
        self.num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=no_object_weight, losses=losses,
    )
    self.criterion.to(self.device)

    # Reshape the per-channel statistics to (C, 1, 1) so they broadcast over
    # the trailing two dimensions of the input; -1 lets the channel count
    # follow the length of the configured lists instead of assuming 3.
    pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(-1, 1, 1)
    pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(-1, 1, 1)
    # The statistics are captured by the closure, not registered on the
    # module, so the ``self.to`` call below does not move them.
    self.normalizer = lambda x: (x - pixel_mean) / pixel_std
    self.to(self.device)