in d2go/data/gans.py [0:0]
def _gan_require_file(path):
    """Raise ``AssertionError("<path> is not valid!")`` unless ``path`` is a file.

    Used instead of a bare ``assert`` so the check survives ``python -O``.
    ``AssertionError`` is kept for compatibility with the previous behavior.
    """
    if not PathManager.isfile(path):
        raise AssertionError("{} is not valid!".format(path))


def _gan_file_or_none(path):
    """Return ``path`` if ``PathManager.isfile`` accepts it, otherwise ``None``."""
    return path if PathManager.isfile(path) else None


def inject_gan_datasets(cfg):
    """Register GAN train/test folder datasets and add them to ``cfg``.

    Does nothing unless ``cfg.D2GO_DATA.DATASETS.GAN_INJECTION.ENABLE`` is set.
    When enabled, it:

    * validates ``JSON_PATH`` and ``INPUT_SRC_DIR``, and also ``REAL_JSON_PATH``
      when ``REAL_SRC_DIR`` is a file;
    * picks ``LOCAL_DIR`` as the image directory, or a fresh temp dir when
      ``LOCAL_DIR`` is empty;
    * appends ``<NAME>_train`` / ``<NAME>_test`` to ``cfg.DATASETS.TRAIN`` /
      ``cfg.DATASETS.TEST`` (mutates ``cfg`` in place);
    * registers both datasets through ``register_folder_dataset``. The test
      split is capped at ``MAX_TEST_IMAGES``.

    Note: despite the ``*_SRC_DIR`` suffix, these config values are checked
    with ``PathManager.isfile``. Optional sources (GT, mask, input extras,
    real) that are not files are passed as ``None`` along with their folders.

    Args:
        cfg: config node exposing ``D2GO_DATA.DATASETS.GAN_INJECTION`` and
            ``DATASETS.TRAIN``/``DATASETS.TEST``, with ``merge_from_list``.

    Raises:
        AssertionError: if a required path is not a file. All validation
            happens before ``cfg`` is modified, so ``cfg`` stays untouched on
            failure.
    """
    gan_cfg = cfg.D2GO_DATA.DATASETS.GAN_INJECTION
    if not gan_cfg.ENABLE:
        return

    name = gan_cfg.NAME

    # Validate and resolve every input before any side effect (cfg mutation,
    # temp-dir creation), so a bad config leaves no half-applied state.
    json_path = gan_cfg.JSON_PATH
    _gan_require_file(json_path)
    input_src_path = gan_cfg.INPUT_SRC_DIR
    _gan_require_file(input_src_path)

    gt_src_path = _gan_file_or_none(gan_cfg.GT_SRC_DIR)
    mask_src_path = _gan_file_or_none(gan_cfg.MASK_SRC_DIR)
    input_extras_src_path = _gan_file_or_none(gan_cfg.INPUT_EXTRAS_SRC_DIR)
    real_src_path = _gan_file_or_none(gan_cfg.REAL_SRC_DIR)
    real_json_path = None
    if real_src_path is not None:
        # A real-image source is only usable with its annotation JSON.
        real_json_path = gan_cfg.REAL_JSON_PATH
        _gan_require_file(real_json_path)

    # NOTE(review): the temp dir is never removed here; presumably
    # register_folder_dataset populates these folders from the *_src_path
    # files later, so it must outlive this call — confirm.
    image_dir = gan_cfg.LOCAL_DIR if gan_cfg.LOCAL_DIR else tempfile.mkdtemp()
    dataset_dir = os.path.join(image_dir, name)

    def subfolder(src_path, subdir):
        # A folder is only meaningful when its source file exists.
        return None if src_path is None else os.path.join(dataset_dir, subdir)

    cfg.merge_from_list(
        [
            "DATASETS.TRAIN",
            list(cfg.DATASETS.TRAIN) + [name + "_train"],
            "DATASETS.TEST",
            list(cfg.DATASETS.TEST) + [name + "_test"],
        ]
    )

    # Positional argument order expected by register_folder_dataset.
    shared_args = (
        json_path,
        os.path.join(dataset_dir, "input"),
        subfolder(gt_src_path, "gt"),
        subfolder(mask_src_path, "mask"),
        subfolder(input_extras_src_path, "input_extras"),
        input_src_path,
        gt_src_path,
        mask_src_path,
        input_extras_src_path,
        real_json_path,
        subfolder(real_src_path, "real"),
        real_src_path,
    )
    register_folder_dataset(name + "_train", *shared_args)
    register_folder_dataset(
        name + "_test",
        *shared_args,
        max_num=gan_cfg.MAX_TEST_IMAGES,
    )