in ssl/real-dataset/main.py [0:0]
def main(args):
    """Run self-supervised pre-training (SimCLR or BYOL) driven by ``args``.

    The function logs the invocation context, builds the multi-view training
    dataset, constructs the online/target encoders (and, for BYOL, an optional
    predictor head), creates the optimizers, and hands everything to the
    matching trainer. The resolved configuration is written to ``args.pt``;
    when per-epoch evaluation is disabled, a final linear evaluation is run
    and its result is written to ``eval.pth``.

    Args:
        args: Run configuration. It is accessed both by attribute
            (``args.gpu``, ``args.seed``, ``args.dataset_path``) and by key,
            exposes ``pretty()``, and is converted with ``hydra2dict`` part-way
            through -- presumably a Hydra/OmegaConf config; confirm with the
            caller.

    Raises:
        RuntimeError: If the configured dataset, optimizer, or method is not
            one of the supported values.
    """
    log.info("Command line: \n\n" + common_utils.pretty_print_cmd(sys.argv))
    log.info(f"Working dir: {os.getcwd()}")
    log.info("\n" + common_utils.get_git_hash())
    log.info("\n" + common_utils.get_git_diffs())

    # Restrict the visible GPUs before the first CUDA query below.
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu}"
    torch.manual_seed(args.seed)
    log.info(args.pretty())

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    log.info(f"Training with: {device}")

    data_transform = get_simclr_data_transforms_train(args['dataset'])
    data_transform_identity = get_simclr_data_transforms_test(args['dataset'])
    # Each sample yields two views through the training transform plus one
    # view through the test transform.
    multi_view_transform = MultiViewDataInjector(
        [data_transform, data_transform, data_transform_identity])

    if args["dataset"] == "stl10":
        train_dataset = datasets.STL10(args.dataset_path, split='train+unlabeled', download=True,
                                       transform=multi_view_transform)
    elif args["dataset"] == "cifar10":
        train_dataset = datasets.CIFAR10(args.dataset_path, train=True, download=True,
                                         transform=multi_view_transform)
    else:
        raise RuntimeError(f"Unknown dataset! {args['dataset']}")

    # From here on the configuration is handled as a plain dict.
    args = hydra2dict(args)
    train_params = args["trainer"]
    if train_params["projector_same_as_predictor"]:
        train_params["projector_params"] = train_params["predictor_params"]

    multi_gpu = torch.cuda.device_count() > 1

    # online network
    online_network = ResNet18(dataset=args["dataset"], options=train_params["projector_params"],
                              **args['network']).to(device)
    if multi_gpu:
        online_network = torch.nn.parallel.DataParallel(online_network)

    pretrained_path = args['network']['pretrained_path']
    if pretrained_path:
        try:
            load_params = torch.load(pretrained_path, map_location=torch.device(device))
        except FileNotFoundError:
            log.info("Pre-trained weights not found. Training from scratch.")
        else:
            # Accept either a checkpoint that nests the weights under
            # 'online_network_state_dict' or a bare state dict, and load it
            # exactly once. Loading the whole checkpoint object a second time
            # (as the previous code did) fails strict key matching.
            if isinstance(load_params, dict) and 'online_network_state_dict' in load_params:
                state_dict = load_params['online_network_state_dict']
            else:
                state_dict = load_params
            online_network.load_state_dict(state_dict)
            log.info("Load from {}.".format(pretrained_path))

    # predictor network
    if train_params["has_predictor"] and args["method"] == "byol":
        predictor = MLPHead(in_channels=args['network']['projection_head']['projection_size'],
                            **args['network']['predictor_head'],
                            options=train_params["predictor_params"]).to(device)
        if multi_gpu:
            predictor = torch.nn.parallel.DataParallel(predictor)
    else:
        predictor = None

    # target encoder
    target_network = ResNet18(dataset=args["dataset"], options=train_params["projector_params"],
                              **args['network']).to(device)
    if multi_gpu:
        target_network = torch.nn.parallel.DataParallel(target_network)

    params = online_network.parameters()

    # Save network and parameters.
    torch.save(args, "args.pt")

    if args["eval_after_each_epoch"]:
        evaluator = Evaluator(args["dataset"], args["dataset_path"], args["test"]["batch_size"])
    else:
        evaluator = None

    if args["use_optimizer"] == "adam":
        # Only lr and weight_decay are forwarded to Adam; other optimizer
        # params in the config are ignored for this choice.
        optimizer = torch.optim.Adam(params, lr=args['optimizer']['params']["lr"],
                                     weight_decay=args["optimizer"]["params"]['weight_decay'])
    elif args["use_optimizer"] == "sgd":
        optimizer = torch.optim.SGD(params, **args['optimizer']['params'])
    else:
        raise RuntimeError(f"Unknown optimizer! {args['use_optimizer']}")

    if args["predictor_optimizer_same"]:
        args["predictor_optimizer"] = args["optimizer"]

    # Always bind the name so the BYOL branch below never hits an unbound
    # local when there is no predictor or it is not being trained.
    # NOTE(review): confirm BYOLTrainer tolerates predictor_optimizer=None.
    predictor_optimizer = None
    if predictor is not None and train_params["train_predictor"]:
        predictor_optimizer = torch.optim.SGD(predictor.parameters(),
                                              **args['predictor_optimizer']['params'])

    ## SimCLR scheduler
    if args["method"] == "simclr":
        trainer = SimCLRTrainer(log_dir="./", model=online_network, optimizer=optimizer,
                                evaluator=evaluator, device=device, params=args["trainer"])
    elif args["method"] == "byol":
        trainer = BYOLTrainer(log_dir="./",
                              online_network=online_network,
                              target_network=target_network,
                              optimizer=optimizer,
                              predictor_optimizer=predictor_optimizer,
                              predictor=predictor,
                              device=device,
                              evaluator=evaluator,
                              **args['trainer'])
    else:
        raise RuntimeError(f'Unknown method {args["method"]}')

    trainer.train(train_dataset)

    # Without per-epoch evaluation, run a single linear evaluation at the end.
    if not args["eval_after_each_epoch"]:
        result_eval = linear_eval(args["dataset"], args["dataset_path"],
                                  args["test"]["batch_size"], ["./"], [])
        torch.save(result_eval, "eval.pth")