train.py (161 lines of code) (raw):

import argparse

import torch

from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.engine import build_trainer

# custom
# NOTE(review): the modules below are never referenced by name in this file;
# they are presumably imported for their side effects (registering datasets /
# trainers with Dassl's registries) — confirm before removing any of them.
import datasets.oxford_pets
import datasets.oxford_flowers
import datasets.fgvc_aircraft
import datasets.dtd
import datasets.eurosat
import datasets.stanford_cars
import datasets.food101
import datasets.sun397
import datasets.caltech101
import datasets.ucf101
import datasets.imagenet
import datasets.cifar_

import datasets.imagenet_sketch
import datasets.imagenetv2
import datasets.imagenet_a
# import datasets.imagenet_r

import trainers.catex


def print_args(args, cfg):
    """Print every parsed CLI argument (sorted by name), then the full config."""
    print("***************")
    print("** Arguments **")
    print("***************")
    optkeys = list(args.__dict__.keys())
    optkeys.sort()
    for key in optkeys:
        print("{}: {}".format(key, args.__dict__[key]))
    print("************")
    print("** Config **")
    print("************")
    print(cfg)


def reset_cfg(cfg, args):
    """Overwrite config entries with values supplied on the command line.

    Most options are copied only when truthy, so flags left at their argparse
    defaults ("" / False / None) keep whatever the config files set.
    ``--seed`` is copied whenever it is not None: its default of -1 is written
    through (overriding any SEED from the config files, as before), and an
    explicit ``--seed 0`` is now honoured instead of silently dropped.

    Args:
        cfg: mutable (not yet frozen) yacs config node.
        args: namespace produced by the argparse parser in ``__main__``.
    """
    if args.root:
        cfg.DATASET.ROOT = args.root

    # These two keys are only created here when the flag is passed; main()
    # reads OOD_TRAIN defensively for that reason.
    if args.ood_test:
        cfg.TRAINER.OOD_TEST = args.ood_test

    if args.ood_train:
        cfg.TRAINER.OOD_TRAIN = args.ood_train

    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir

    if args.resume:
        cfg.RESUME = args.resume

    # `if args.seed:` would discard seed 0, although main() treats any
    # SEED >= 0 as a fixed seed.
    if args.seed is not None:
        cfg.SEED = args.seed

    if args.source_domains:
        cfg.DATASET.SOURCE_DOMAINS = args.source_domains

    if args.target_domains:
        cfg.DATASET.TARGET_DOMAINS = args.target_domains

    if args.transforms:
        cfg.INPUT.TRANSFORMS = args.transforms

    if args.trainer:
        cfg.TRAINER.NAME = args.trainer

    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone

    if args.head:
        cfg.MODEL.HEAD.NAME = args.head


def extend_cfg(cfg):
    """Add the CATEX-specific config keys (with defaults) to *cfg*.

    This runs before any config file is merged in setup_cfg(), so the YAML
    files and ``opts`` can override these keys (yacs refuses to merge keys
    that do not already exist in the node).

    Example of adding new keys:
        from yacs.config import CfgNode as CN
        cfg.TRAINER.MY_MODEL = CN()
        cfg.TRAINER.MY_MODEL.PARAM_A = 1.
        cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
        cfg.TRAINER.MY_MODEL.PARAM_C = False
    """
    from yacs.config import CfgNode as CN

    cfg.TRAINER.CATEX = CN()
    cfg.TRAINER.CATEX.N_CTX = 16  # number of context vectors
    cfg.TRAINER.CATEX.CSC = False  # class-specific context
    # initialization words a photo of a / ensemble / ensemble with learned
    cfg.TRAINER.CATEX.CTX_INIT = ""
    cfg.TRAINER.CATEX.PREC = "fp16"  # fp16, fp32, amp
    cfg.TRAINER.CATEX.CLASS_TOKEN_POSITION = "end"  # 'middle' or 'end' or 'front'

    cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new


def setup_cfg(args):
    """Build and freeze the experiment config.

    Sources are applied in this order (later ones win):
      0. Dassl defaults plus the keys added by extend_cfg();
      1. the dataset config file;
      2. the method config file;
      3. explicit CLI flags (reset_cfg);
      4. trailing ``KEY VALUE`` pairs collected in ``args.opts``.
    """
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # 1. From the dataset config file
    if args.dataset_config_file:
        cfg.merge_from_file(args.dataset_config_file)

    # 2. From the method config file
    if args.config_file:
        cfg.merge_from_file(args.config_file)

    # 3. From input arguments
    reset_cfg(cfg, args)

    # 4. From optional input arguments
    cfg.merge_from_list(args.opts)

    cfg.freeze()

    return cfg


def main(args):
    """Build the config and trainer, then run the mode selected by *args*.

    Modes are checked in order, and the first match wins:
      * ``--ood-test``: optionally load weights from ``--model-dir``, then run
        ``trainer.test_ood``.
      * ``--eval-only``: load weights when CTX_INIT is empty, then run
        ``trainer.test``.
      * otherwise: optionally load weights, switch to the OOD training step
        when TRAINER.OOD_TRAIN is set, and train (skipped by ``--no-train``).

    Raises:
        ValueError: ``--ood-test`` is combined with ``--model-dir`` while
            TRAINER.CATEX.CTX_INIT is not one of '', None, 'ensemble',
            'ensemble_learned'.
    """
    cfg = setup_cfg(args)
    if cfg.SEED >= 0:
        print("Setting fixed seed: {}".format(cfg.SEED))
        set_random_seed(cfg.SEED)
    setup_logger(cfg.OUTPUT_DIR)

    if torch.cuda.is_available() and cfg.USE_CUDA:
        torch.backends.cudnn.benchmark = True

    # print_args(args, cfg)
    # print("Collecting env info ...")
    # print("** System info **\n{}\n".format(collect_env_info()))

    trainer = build_trainer(cfg)

    if args.ood_test:
        if args.model_dir != '':
            ctx_init = cfg.TRAINER.CATEX.CTX_INIT
            # Explicit check instead of `assert`, which `python -O` strips.
            if ctx_init not in ('', None, 'ensemble', 'ensemble_learned'):
                raise ValueError(
                    "--ood-test with --model-dir requires TRAINER.CATEX.CTX_INIT "
                    "to be one of '', None, 'ensemble', 'ensemble_learned'; "
                    "got {!r}".format(ctx_init)
                )
            trainer.load_model(args.model_dir, epoch=args.load_epoch)
        trainer.test_ood(model_directory=args.model_dir)
        return

    if args.eval_only:
        # A non-empty CTX_INIT means a fixed prompt initialisation, so no
        # learned weights are loaded in that case.
        if cfg.TRAINER.CATEX.CTX_INIT == '':
            trainer.load_model(args.model_dir, epoch=args.load_epoch)
        trainer.test()
        return

    if not args.no_train:
        if args.model_dir != '':
            trainer.load_model(args.model_dir, epoch=args.load_epoch)
        # TRAINER.OOD_TRAIN is only created by reset_cfg when --ood-train is
        # given, so read it with .get (CfgNode is a dict) rather than as an
        # attribute, which could raise AttributeError when it is absent.
        if cfg.TRAINER.get("OOD_TRAIN", False):
            trainer.forward_backward = trainer.forward_backward_ood
        trainer.train()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="", help="path to dataset")
    parser.add_argument("--output-dir", type=str, default="", help="output directory")
    parser.add_argument(
        "--resume",
        type=str,
        default="",
        help="checkpoint directory (from which the training resumes)",
    )
    parser.add_argument(
        "--seed", type=int, default=-1, help="only non-negative value enables a fixed seed"
    )
    parser.add_argument(
        "--source-domains", type=str, nargs="+", help="source domains for DA/DG"
    )
    parser.add_argument(
        "--target-domains", type=str, nargs="+", help="target domains for DA/DG"
    )
    parser.add_argument(
        "--transforms", type=str, nargs="+", help="data augmentation methods"
    )
    parser.add_argument(
        "--config-file", type=str, default="", help="path to config file"
    )
    parser.add_argument(
        "--dataset-config-file",
        type=str,
        default="",
        help="path to config file for dataset setup",
    )
    parser.add_argument("--ood-test", action="store_true", help="flag for ood test")
    parser.add_argument("--ood-train", action="store_true", help="flag for ood train")
    parser.add_argument("--trainer", type=str, default="", help="name of trainer")
    parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone")
    parser.add_argument("--head", type=str, default="", help="name of head")
    parser.add_argument("--eval-only", action="store_true", help="evaluation only")
    parser.add_argument(
        "--model-dir",
        type=str,
        default="",
        help="load model from this directory for eval-only mode",
    )
    parser.add_argument(
        "--load-epoch", type=int, help="load model weights at this epoch for evaluation"
    )
    parser.add_argument(
        "--no-train", action="store_true", help="do not call trainer.train()"
    )
    parser.add_argument(
        "opts",
        default=None,
        nargs=argparse.REMAINDER,
        help="modify config options using the command-line",
    )
    args = parser.parse_args()
    main(args)