in easycv/core/evaluation/nuscenes_eval.py [0:0]
def load_gt(nusc: NuScenes, eval_split: str, box_cls, verbose: bool = False):
    """
    Loads ground truth boxes from DB.

    Each loaded box additionally carries an ``index`` field: the 1-based
    position of its sample inside the scene, obtained by walking the
    ``next`` links starting at the scene's ``first_sample_token``.

    :param nusc: A NuScenes instance.
    :param eval_split: The evaluation split for which we load GT boxes.
    :param box_cls: Type of box to load. Only CustomDetectionBox is
        implemented here; TrackingBox is recognised but not supported.
    :param verbose: Whether to print messages to stdout.
    :return: The GT boxes.
    :raises ValueError: If ``eval_split`` is not one of the known splits.
    :raises AssertionError: If the DB has no samples, if ``eval_split`` does
        not match ``nusc.version``, or if the test split is requested
        without annotations.
    :raises NotImplementedError: If an annotation is encountered while
        ``box_cls`` is TrackingBox or any other unsupported type.
    :raises Exception: If a GT annotation has more than one attribute.
    """
    # Init.
    is_detection = box_cls == CustomDetectionBox
    if is_detection:
        attribute_map = {a['token']: a['name'] for a in nusc.attribute}

    if verbose:
        print('Loading annotations for {} split from nuScenes version: {}'.
              format(eval_split, nusc.version))

    # Read out all sample_tokens in DB.
    sample_tokens_all = [s['token'] for s in nusc.sample]
    assert len(sample_tokens_all) > 0, 'Error: Database has no samples!'

    # Only keep samples from this split.
    splits = create_splits_scenes()

    # Check compatibility of split with nusc_version.
    version = nusc.version
    if eval_split in {'train', 'val', 'train_detect', 'train_track'}:
        required_suffix = 'trainval'
    elif eval_split in {'mini_train', 'mini_val'}:
        required_suffix = 'mini'
    elif eval_split == 'test':
        required_suffix = 'test'
    else:
        raise ValueError(
            'Error: Requested split {} which this function cannot map to the correct NuScenes version.'
            .format(eval_split))
    assert version.endswith(required_suffix), \
        'Error: Requested split {} which is not compatible with NuScenes version {}'.format(eval_split, version)

    if eval_split == 'test':
        # Check that you aren't trying to cheat :).
        assert len(nusc.sample_annotation) > 0, \
            'Error: You are trying to evaluate on the test set but you do not have the annotations!'

    # Map every sample token to its 1-based position within its scene.
    index_map = {}
    for scene in nusc.scene:
        first_sample_token = scene['first_sample_token']
        sample = nusc.get('sample', first_sample_token)
        index_map[first_sample_token] = 1
        index = 2
        while sample['next'] != '':
            sample = nusc.get('sample', sample['next'])
            index_map[sample['token']] = index
            index += 1

    # Build the scene-name set once so the per-sample membership test is O(1).
    split_scene_names = set(splits[eval_split])
    sample_tokens = []
    for sample_token in sample_tokens_all:
        scene_token = nusc.get('sample', sample_token)['scene_token']
        scene_record = nusc.get('scene', scene_token)
        if scene_record['name'] in split_scene_names:
            sample_tokens.append(sample_token)

    all_annotations = EvalBoxes()

    # Load annotations and filter predictions and annotations.
    for sample_token in tqdm.tqdm(sample_tokens, leave=verbose):
        sample = nusc.get('sample', sample_token)

        sample_boxes = []
        for sample_annotation_token in sample['anns']:
            sample_annotation = nusc.get('sample_annotation',
                                         sample_annotation_token)
            if is_detection:
                # Get label name in detection task and filter unused labels.
                detection_name = category_to_detection_name(
                    sample_annotation['category_name'])
                if detection_name is None:
                    continue

                # Get attribute_name.
                attr_tokens = sample_annotation['attribute_tokens']
                attr_count = len(attr_tokens)
                if attr_count == 0:
                    attribute_name = ''
                elif attr_count == 1:
                    attribute_name = attribute_map[attr_tokens[0]]
                else:
                    raise Exception(
                        'Error: GT annotations must not have more than one attribute!'
                    )

                sample_boxes.append(
                    box_cls(
                        token=sample_annotation_token,
                        sample_token=sample_token,
                        translation=sample_annotation['translation'],
                        size=sample_annotation['size'],
                        rotation=sample_annotation['rotation'],
                        velocity=nusc.box_velocity(
                            sample_annotation['token'])[:2],
                        num_pts=sample_annotation['num_lidar_pts'] +
                        sample_annotation['num_radar_pts'],
                        detection_name=detection_name,
                        detection_score=-1.0,  # GT samples do not have a score.
                        attribute_name=attribute_name,
                        visibility=sample_annotation['visibility_token'],
                        index=index_map[sample_token]))
            elif box_cls == TrackingBox:
                # Raise explicitly instead of `assert False`: asserts are
                # removed under `python -O`, which would silently yield
                # empty ground truth for tracking boxes.
                raise NotImplementedError(
                    'Error: Loading ground truth for TrackingBox is not supported!'
                )
            else:
                raise NotImplementedError('Error: Invalid box_cls %s!' %
                                          box_cls)

        # Called once per sample, including when sample_boxes is empty.
        all_annotations.add_boxes(sample_token, sample_boxes)

    if verbose:
        print('Loaded ground truth annotations for {} samples.'.format(
            len(all_annotations.sample_tokens)))

    return all_annotations