torchbenchmark/models/timm_efficientdet/loader.py (67 lines of code) (raw):
from effdet.data import resolve_input_config, SkipSubset
from effdet import create_loader, create_dataset, create_evaluator
from effdet.anchors import Anchors, AnchorLabeler
from effdet.data.dataset_config import CocoCfg
from dataclasses import dataclass, field
from typing import Dict
@dataclass
class Coco2017MinimalCfg(CocoCfg):
    """COCO-2017 dataset config variant backed by a single small annotation file.

    Both the ``train`` and ``val`` splits point at the same annotation file
    (``annotations/instances_val2017_100.json``) and the same image directory
    (``val2017``), and both are flagged as having labels.
    """

    variant: str = '2017-minimal'
    # A factory is used so every config instance gets its own fresh dicts.
    splits: Dict[str, dict] = field(default_factory=lambda: {
        'train': {
            'ann_filename': 'annotations/instances_val2017_100.json',
            'img_dir': 'val2017',
            'has_labels': True,
        },
        'val': {
            'ann_filename': 'annotations/instances_val2017_100.json',
            'img_dir': 'val2017',
            'has_labels': True,
        },
    })
def create_datasets_and_loaders(
        args,
        model_config,
        transform_train_fn=None,
        transform_eval_fn=None,
        collate_fn=None,
):
    """Build the train/eval datasets, their loaders, and the evaluator.

    Args:
        args: Command line args / config for training. The attributes read
            here are: dataset, root, bench_labeler, batch_size, prefetcher,
            reprob, remode, recount, train_interpolation, workers,
            distributed, pin_mem and val_skip.
        model_config: Model specific configuration dict / struct.
        transform_train_fn: Override default image + annotation transforms
            for the train loader (see note in loaders.py).
        transform_eval_fn: Override default image + annotation transforms
            for the eval loader (see note in loaders.py).
        collate_fn: Override default fast collate function (passed to both
            loaders).

    Returns:
        A 5-tuple ``(loader_train, loader_eval, evaluator, dataset_train,
        dataset_eval)``. When ``args.val_skip > 1`` the returned
        ``dataset_eval`` is the ``SkipSubset`` wrapper, not the raw dataset.
    """
    input_config = resolve_input_config(args, model_config=model_config)

    dataset_train, dataset_eval = create_dataset(
        args.dataset,
        args.root,
        custom_dataset_cfg=Coco2017MinimalCfg(),
    )

    # The labeler is only set up in the loader/collate_fn when it is not
    # enabled in the model bench.
    labeler = (
        None
        if args.bench_labeler
        else AnchorLabeler(
            Anchors.from_config(model_config),
            model_config.num_classes,
            match_threshold=0.5,
        )
    )

    # Keyword arguments that are identical for the train and eval loaders.
    shared_loader_kwargs = dict(
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=labeler,
        collate_fn=collate_fn,
    )

    # NOTE: color_jitter (args.color_jitter) and auto_augment (args.aa) are
    # not forwarded to the train loader.
    train_interpolation = args.train_interpolation or input_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        is_training=True,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        interpolation=train_interpolation,
        transform_fn=transform_train_fn,
        **shared_loader_kwargs,
    )

    # Optionally wrap the eval dataset in SkipSubset using args.val_skip.
    if args.val_skip > 1:
        dataset_eval = SkipSubset(dataset_eval, args.val_skip)

    loader_eval = create_loader(
        dataset_eval,
        is_training=False,
        interpolation=input_config['interpolation'],
        transform_fn=transform_eval_fn,
        **shared_loader_kwargs,
    )

    evaluator = create_evaluator(
        args.dataset,
        loader_eval.dataset,
        distributed=args.distributed,
        pred_yxyx=False,
    )

    return loader_train, loader_eval, evaluator, dataset_train, dataset_eval