in easycv/models/detection3d/detectors/mvx_two_stage.py [0:0]
def __init__(self,
             pts_voxel_layer=None,
             pts_voxel_encoder=None,
             pts_middle_encoder=None,
             pts_fusion_layer=None,
             img_backbone=None,
             pts_backbone=None,
             img_neck=None,
             pts_neck=None,
             pts_bbox_head=None,
             img_roi_head=None,
             img_rpn_head=None,
             train_cfg=None,
             test_cfg=None,
             pretrained=None,
             init_cfg=None):
    """Build the point-cloud and image branches of the detector.

    Every component is optional and is only constructed (and set as an
    attribute) when its config is given.

    Args:
        pts_voxel_layer (dict, optional): Keyword arguments for
            ``Voxelization``.
        pts_voxel_encoder (dict, optional): Config passed to
            ``builder.build_voxel_encoder``.
        pts_middle_encoder (dict, optional): Config passed to
            ``builder.build_middle_encoder``.
        pts_fusion_layer (dict, optional): Config passed to
            ``builder.build_fusion_layer``.
        img_backbone (dict, optional): Config for the image backbone.
        pts_backbone (dict, optional): Config for the point backbone.
        img_neck (dict, optional): Config for the image neck.
        pts_neck (dict, optional): Config for the point neck.
        pts_bbox_head (dict, optional): Config for the point bbox head.
            ``train_cfg.pts`` / ``test_cfg.pts`` are injected into a copy
            of it under the keys ``train_cfg`` / ``test_cfg``; the
            caller's object is not modified.
        img_roi_head (dict, optional): Config for the image RoI head.
        img_rpn_head (dict, optional): Config for the image RPN head.
        train_cfg (optional): Training config. When ``pts_bbox_head`` is
            given and ``train_cfg`` is truthy, it must expose a ``pts``
            attribute. Stored as ``self.train_cfg``.
        test_cfg (optional): Testing config, same contract as
            ``train_cfg``. Stored as ``self.test_cfg``.
        pretrained (dict, optional): Checkpoints loaded after weight
            initialisation, given as ``{'img': ..., 'pts': ...}`` (both
            keys optional). The ``'img'`` checkpoint is loaded into the
            image backbone and the image RoI head, the ``'pts'``
            checkpoint into the point backbone.
        init_cfg (dict, optional): Forwarded to the base class.

    Raises:
        ValueError: If ``pretrained`` is neither ``None`` nor a dict.
    """
    import copy

    super(MVXTwoStageDetector, self).__init__(init_cfg=init_cfg)

    # Validate ``pretrained`` before building any sub-module so that a bad
    # value fails fast instead of after all the construction work.
    if pretrained is None:
        img_pretrained = None
        pts_pretrained = None
    elif isinstance(pretrained, dict):
        img_pretrained = pretrained.get('img', None)
        pts_pretrained = pretrained.get('pts', None)
    else:
        raise ValueError(
            f'pretrained should be a dict, got {type(pretrained)}')

    # NOTE: sub-modules are built in the same order as before; attribute
    # assignment order presumably fixes module registration order, so keep
    # it stable.
    if pts_voxel_layer:
        self.pts_voxel_layer = Voxelization(**pts_voxel_layer)
    if pts_voxel_encoder:
        self.pts_voxel_encoder = builder.build_voxel_encoder(
            pts_voxel_encoder)
    if pts_middle_encoder:
        self.pts_middle_encoder = builder.build_middle_encoder(
            pts_middle_encoder)
    if pts_backbone:
        self.pts_backbone = builder.build_backbone(pts_backbone)
    if pts_fusion_layer:
        self.pts_fusion_layer = builder.build_fusion_layer(
            pts_fusion_layer)
    if pts_neck is not None:
        self.pts_neck = builder.build_neck(pts_neck)
    if pts_bbox_head:
        # Work on a private deep copy: updating the caller's config in place
        # would leak train_cfg/test_cfg into it (and, for dict types whose
        # update merges nested dicts, into its nested values too).
        pts_bbox_head = copy.deepcopy(pts_bbox_head)
        pts_train_cfg = train_cfg.pts if train_cfg else None
        pts_bbox_head.update(train_cfg=pts_train_cfg)
        pts_test_cfg = test_cfg.pts if test_cfg else None
        pts_bbox_head.update(test_cfg=pts_test_cfg)
        self.pts_bbox_head = builder.build_head(pts_bbox_head)
    if img_backbone:
        self.img_backbone = builder.build_backbone(img_backbone)
    if img_neck is not None:
        self.img_neck = builder.build_neck(img_neck)
    if img_rpn_head is not None:
        self.img_rpn_head = builder.build_head(img_rpn_head)
    if img_roi_head is not None:
        self.img_roi_head = builder.build_head(img_roi_head)

    self.train_cfg = train_cfg
    self.test_cfg = test_cfg

    # Run the default initialisation first, then overwrite with pretrained
    # weights (presumably so a later init_weights() call does not clobber
    # them -- confirm against the base class).
    self.init_weights()
    logger = get_root_logger()

    # (is the sub-module present?, attribute name, checkpoint to load).
    # getattr is only reached when the ``with_*`` flag is true, so absent
    # sub-modules are never accessed.
    pretrained_targets = (
        (self.with_img_backbone, 'img_backbone', img_pretrained),
        (self.with_img_roi_head, 'img_roi_head', img_pretrained),
        (self.with_pts_backbone, 'pts_backbone', pts_pretrained),
    )
    for enabled, attr_name, checkpoint in pretrained_targets:
        if enabled and checkpoint is not None:
            # strict=False: the same 'img' checkpoint feeds two different
            # sub-modules, so keys that do not match a module are tolerated.
            load_checkpoint(
                getattr(self, attr_name),
                checkpoint,
                strict=False,
                logger=logger)