in utils/torch_utils.py [0:0]
def _esod_extra_layers(cfg_path: str) -> tuple:
    """Return ``(start, count)`` describing the ESOD-only layers for the architecture in *cfg_path*.

    The ESOD model has ``count`` extra layers beginning at layer index ``start``
    that have no counterpart in the pretrained weights. Architectures are matched
    by case-insensitive substring, in the order listed below; the first match wins.

    Raises:
        NotImplementedError: if *cfg_path* names none of the known architectures.
    """
    name = cfg_path.lower()
    # (substring, first ESOD-only layer index, number of ESOD-only layers)
    for arch, start, count in (
        ("yolo", 5, 3),
        ("rtmdet", 7, 3),
        ("retina", 4, 3),
        ("gpvit", 1, 5),
    ):
        if arch in name:
            return start, count
    raise NotImplementedError(f"Loading from pretrained weights for {cfg_path} needs to be specified.")


def _esod_pretrained_key(key: str, start: int, count: int):
    """Map an ESOD state-dict key to its pretrained counterpart, or ``None`` for ESOD-only layers.

    Keys are expected to look like ``'<prefix>.<layer>.<rest>'``. Layers before
    *start* map unchanged; layers in ``[start, start + count)`` exist only in the
    ESOD model and map to ``None``; later layers are shifted down by *count*.
    Keys whose second component is not an integer layer index map unchanged.
    """
    items = key.split('.')
    try:
        layer = int(items[1])
    except (IndexError, ValueError):
        # No layer index (e.g. a top-level buffer): match the key directly.
        return key
    if layer < start:
        return key
    if layer < start + count:
        return None
    return '.'.join([items[0], str(layer - count), *items[2:]])


def intersect_dicts(da, db, exclude=(), cfg_path: str = ""):
    """Intersect two state dicts on matching keys and shapes, taking values from *da*.

    Args:
        da: pretrained state dict; the returned values come from here.
        db: state dict of the model being initialised; the returned keys come from here.
        exclude: substrings; any key containing one of them is skipped.
        cfg_path: model config path. If it does not contain ``'esod'``
            (case-insensitive), keys are matched one-to-one. Otherwise the
            architecture named in it (yolo / rtmdet / retina / gpvit) selects
            which layers of *db* are ESOD-only (skipped) and how later layer
            indices are shifted to find their counterpart in *da*.

    Returns:
        dict mapping keys of *db* to the shape-compatible values of *da*.

    Raises:
        NotImplementedError: for an ESOD *cfg_path* naming an unknown architecture.
    """
    if 'esod' not in cfg_path.lower():
        print(f'Unknown ESOD model type: {cfg_path}. Load pretrained weights directly.')
        return {k: v for k, v in da.items()
                if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}

    print(f'Try to load pretrained weights for {cfg_path}.')
    start, count = _esod_extra_layers(cfg_path)

    intersect_dict = {}
    for k, v in db.items():
        if any(x in k for x in exclude):
            continue
        k_ = _esod_pretrained_key(k, start, count)
        if k_ is not None and k_ in da and v.shape == da[k_].shape:
            intersect_dict[k] = da[k_]

    # Warn when only a small fraction of the model could be initialised; an empty
    # db has nothing missing, so avoid dividing by zero there.
    ratio = len(intersect_dict) / len(db) if db else 1.0
    if ratio < 0.6:
        warnings.warn(f'Only {len(intersect_dict)}/{len(db)} ({ratio:.1%}) items are loaded from pretrained weights.')
    return intersect_dict