in scripts/train_instance_seg.py [0:0]
def make_dataloader(args, config, rank, world_size):
    """Build the distributed training and validation dataloaders.

    Reads the "dataloader" section of *config* and constructs, for each of
    the train and validation splits: an ISSTransform, an ISSDataset rooted at
    ``args.data``, a DistributedARBatchSampler for this (rank, world_size)
    pair, and a DataLoader using the project's ``iss_collate_fn``.

    Returns:
        (train_dl, val_dl) tuple of DataLoader objects.
    """
    config = config["dataloader"]
    log_debug("Creating dataloaders for dataset in %s", args.data)

    def _build(set_key, batch_key, training):
        # Shared transform arguments; the training transform additionally
        # enables random flip / random scale augmentation.
        tf_args = [config.getint("shortest_size"),
                   config.getint("longest_max_size"),
                   config.getstruct("rgb_mean"),
                   config.getstruct("rgb_std")]
        if training:
            tf_args.append(config.getboolean("random_flip"))
            tf_args.append(config.getstruct("random_scale"))
        dataset = ISSDataset(args.data, config[set_key], ISSTransform(*tf_args))
        # The sampler shuffles only for the training split (last argument).
        sampler = DistributedARBatchSampler(
            dataset, config.getint(batch_key), world_size, rank, training)
        return data.DataLoader(dataset,
                               batch_sampler=sampler,
                               collate_fn=iss_collate_fn,
                               pin_memory=True,
                               num_workers=config.getint("num_workers"))

    train_dl = _build("train_set", "train_batch_size", True)
    val_dl = _build("val_set", "val_batch_size", False)
    return train_dl, val_dl