in easycv/models/detection/detectors/dino/dino_head.py
def __init__(
self,
num_classes,
embed_dims,
in_channels=[512, 1024, 2048],
query_dim=4,
num_queries=300,
num_select=300,
random_refpoints_xy=False,
num_patterns=0,
dn_components=None,
transformer=None,
fix_refpoints_hw=-1,
num_feature_levels=1,
# two stage
two_stage_type='standard', # ['no', 'standard']
two_stage_add_query_num=0,
dec_pred_class_embed_share=True,
dec_pred_bbox_embed_share=True,
two_stage_class_embed_share=True,
two_stage_bbox_embed_share=True,
use_centerness=False,
use_iouaware=False,
losses_list=['labels', 'boxes'],
decoder_sa_type='sa',
temperatureH=20,
temperatureW=20,
cost_dict={
'cost_class': 1,
'cost_bbox': 5,
'cost_giou': 2,
},
weight_dict={
'loss_ce': 1,
'loss_bbox': 5,
'loss_giou': 2
},
**kwargs):
super(DINOHead, self).__init__()
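        # Hungarian (bipartite) matcher with a focal-loss classification cost;
        # shared by the main set criterion and the optional denoising criterion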
self.matcher = HungarianMatcher(
cost_dict=cost_dict, cost_class_type='focal_loss_cost')
self.criterion = SetCriterion(
num_classes,
matcher=self.matcher,
weight_dict=weight_dict,
losses=losses_list,
loss_class_type='focal_loss')
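        # CDN criterion supervises the extra noised ground-truth (denoising)
        # queries appended to the matching queries during DN training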
if dn_components is not None:
self.dn_criterion = CDNCriterion(
num_classes,
matcher=self.matcher,
weight_dict=weight_dict,
losses=losses_list,
loss_class_type='focal_loss')
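        # post-processing keeps the top num_select predictions per image and can
        # fold the centerness / IoU-aware predictions into the final scores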
self.postprocess = DetrPostProcess(
num_select=num_select,
use_centerness=use_centerness,
use_iouaware=use_iouaware)
self.transformer = build_neck(transformer)
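        # sine position embedding with independent temperatures along H and W;
        # each axis gets embed_dims // 2 channels so their concat is embed_dims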
self.positional_encoding = PositionEmbeddingSineHW(
embed_dims // 2,
temperatureH=temperatureH,
temperatureW=temperatureW,
normalize=True)
self.num_classes = num_classes
self.num_queries = num_queries
self.embed_dims = embed_dims
self.query_dim = query_dim
self.dn_components = dn_components
self.random_refpoints_xy = random_refpoints_xy
self.fix_refpoints_hw = fix_refpoints_hw
        # for dn training; guard against dn_components=None (the default), the
        # same way dn_criterion is built conditionally above
        if self.dn_components is not None:
            self.dn_number = self.dn_components['dn_number']
            self.dn_box_noise_scale = self.dn_components['dn_box_noise_scale']
            self.dn_label_noise_ratio = self.dn_components[
                'dn_label_noise_ratio']
            self.dn_labelbook_size = self.dn_components['dn_labelbook_size']
            # label embedding for DN queries; one extra slot beyond the
            # labelbook size
            self.label_enc = nn.Embedding(self.dn_labelbook_size + 1,
                                          embed_dims)
        else:
            self.label_enc = None
# prepare input projection layers
self.num_feature_levels = num_feature_levels
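        # each backbone level is projected to embed_dims with a 1x1 conv +
        # GroupNorm; extra levels beyond the backbone outputs are created by
        # stride-2 3x3 convs stacked on the previous level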
if num_feature_levels > 1:
num_backbone_outs = len(in_channels)
input_proj_list = []
for i in range(num_backbone_outs):
in_channels_i = in_channels[i]
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels_i, embed_dims, kernel_size=1),
nn.GroupNorm(32, embed_dims),
))
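            # in_channels_i still holds the last backbone level's channel count,
            # so the first extra level downsamples that map; later extras start
            # from embed_dims (reassigned at the end of each iteration)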
for _ in range(num_feature_levels - num_backbone_outs):
input_proj_list.append(
nn.Sequential(
nn.Conv2d(
in_channels_i,
embed_dims,
kernel_size=3,
stride=2,
padding=1),
nn.GroupNorm(32, embed_dims),
))
in_channels_i = embed_dims
self.input_proj = nn.ModuleList(input_proj_list)
else:
            assert two_stage_type == 'no', \
                'two_stage_type must be "no" when num_feature_levels == 1'
self.input_proj = nn.ModuleList([
nn.Sequential(
nn.Conv2d(in_channels[-1], embed_dims, kernel_size=1),
nn.GroupNorm(32, embed_dims),
)
])
# prepare pred layers
self.dec_pred_class_embed_share = dec_pred_class_embed_share
self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
# prepare class & box embed
_class_embed = nn.Linear(embed_dims, num_classes)
_bbox_embed = MLP(embed_dims, embed_dims, 4, 3)
# init the two embed layers
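        # focal-loss prior initialization (as in RetinaNet): the class bias is
        # set so the initial foreground probability is ~prior_prob, and the
        # bbox head's last layer starts from zero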
prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)
_class_embed.bias.data = torch.ones(self.num_classes) * bias_value
nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
# fcos centerness & iou-aware & tokenlabel
self.use_centerness = use_centerness
self.use_iouaware = use_iouaware
if self.use_centerness:
_center_embed = MLP(embed_dims, embed_dims, 1, 3)
if self.use_iouaware:
_iou_embed = MLP(embed_dims, embed_dims, 1, 3)
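        # build one prediction head per decoder layer: the same module is
        # shared across layers when *_embed_share is True, otherwise each
        # layer gets an independent deep copy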
        num_dec_layers = transformer.num_decoder_layers
        if dec_pred_bbox_embed_share:
            box_embed_layerlist = [_bbox_embed for _ in range(num_dec_layers)]
            if self.use_centerness:
                center_embed_layerlist = [
                    _center_embed for _ in range(num_dec_layers)
                ]
            if self.use_iouaware:
                iou_embed_layerlist = [
                    _iou_embed for _ in range(num_dec_layers)
                ]
        else:
            box_embed_layerlist = [
                copy.deepcopy(_bbox_embed) for _ in range(num_dec_layers)
            ]
            if self.use_centerness:
                center_embed_layerlist = [
                    copy.deepcopy(_center_embed)
                    for _ in range(num_dec_layers)
                ]
            if self.use_iouaware:
                iou_embed_layerlist = [
                    copy.deepcopy(_iou_embed) for _ in range(num_dec_layers)
                ]
        if dec_pred_class_embed_share:
            class_embed_layerlist = [
                _class_embed for _ in range(num_dec_layers)
            ]
        else:
            class_embed_layerlist = [
                copy.deepcopy(_class_embed) for _ in range(num_dec_layers)
            ]
self.bbox_embed = nn.ModuleList(box_embed_layerlist)
self.class_embed = nn.ModuleList(class_embed_layerlist)
self.transformer.decoder.bbox_embed = self.bbox_embed
self.transformer.decoder.class_embed = self.class_embed
if self.use_centerness:
self.center_embed = nn.ModuleList(center_embed_layerlist)
self.transformer.decoder.center_embed = self.center_embed
if self.use_iouaware:
self.iou_embed = nn.ModuleList(iou_embed_layerlist)
self.transformer.decoder.iou_embed = self.iou_embed
# two stage
self.two_stage_type = two_stage_type
self.two_stage_add_query_num = two_stage_add_query_num
        assert two_stage_type in ['no', 'standard'], \
            'unknown two_stage_type: {}'.format(two_stage_type)
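        # in two-stage mode the encoder output also needs class/box (and
        # optional centerness/IoU) heads to score and refine the proposals fed
        # to the decoder; these may share weights with the decoder heads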
if two_stage_type != 'no':
if two_stage_bbox_embed_share:
assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
self.transformer.enc_out_bbox_embed = _bbox_embed
if self.use_centerness:
self.transformer.enc_out_center_embed = _center_embed
if self.use_iouaware:
self.transformer.enc_out_iou_embed = _iou_embed
else:
self.transformer.enc_out_bbox_embed = copy.deepcopy(
_bbox_embed)
if self.use_centerness:
self.transformer.enc_out_center_embed = copy.deepcopy(
_center_embed)
if self.use_iouaware:
self.transformer.enc_out_iou_embed = copy.deepcopy(
_iou_embed)
if two_stage_class_embed_share:
assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
self.transformer.enc_out_class_embed = _class_embed
else:
self.transformer.enc_out_class_embed = copy.deepcopy(
_class_embed)
self.refpoint_embed = None
if self.two_stage_add_query_num > 0:
self.init_ref_points(two_stage_add_query_num)
self.decoder_sa_type = decoder_sa_type
assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
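        # 'ca_label' replaces plain decoder self-attention with cross-attention
        # to a learned label embedding shared by all decoder layers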
if decoder_sa_type == 'ca_label':
self.label_embedding = nn.Embedding(num_classes, embed_dims)
for layer in self.transformer.decoder.layers:
layer.label_embedding = self.label_embedding
else:
for layer in self.transformer.decoder.layers:
layer.label_embedding = None
self.label_embedding = None