in domainbed_measures/model_spec.py [0:0]
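import argparse
import logging
import random
from typing import Dict

import numpy as np
import torch

# NOTE: the import paths below are assumptions about the surrounding
# repository layout (DomainBed-style); the original file may resolve
# `datasets`, `misc`, `split_dataset`, `algorithms`, and
# `DeterministicFastDataLoader` from different modules.
from domainbed import algorithms, datasets
from domainbed.lib import misc
from domainbed.lib.misc import split_dataset
from domainbed_measures.fast_data_loader import DeterministicFastDataLoader
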
def load_from_state(args: argparse.Namespace,
                    hparams: Dict,
                    algorithm_dict: Dict,
                    dirty_ood_split: str,
                    target_test_env: int,
                    include_index: bool = False):
"""Load from the model checkpoints things like weights and dataloaders.
Args:
args: Args used to train the model
hparams: Hyperparameters used for the model run
algorithm_dict: Weights of the trained model
target_test_env: Test environment we want to load (among a list of
competing alternatives)
include_index: Whether to include the index of a datapoint along
with label
"""
    # Set random seeds so splits and loader order match the training run.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if target_test_env not in args.test_envs:
        raise ValueError("Target test environment must be in "
                         "list of test envs used for model.")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.dataset in datasets.DATASETS:
        dataset = vars(datasets)[args.dataset](args.data_dir, args.test_envs,
                                               hparams)
    else:
        raise NotImplementedError(
            "Dataset not supported: {}".format(args.dataset))
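    # Split each environment into an 'in' (training) part and a held-out
    # 'out' part, seeded the same way as during DomainBed training so the
    # splits line up with the ones the checkpoint was trained on.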
    in_splits = []
    out_splits = []
    for env_i, env in enumerate(dataset):
        out, in_ = split_dataset(env,
                                 int(len(env) * args.holdout_fraction),
                                 misc.seed_hash(args.trial_seed, env_i),
                                 include_index=include_index)
        if hparams['class_balanced']:
            in_weights = misc.make_weights_for_balanced_classes(in_)
            out_weights = misc.make_weights_for_balanced_classes(out)
        else:
            in_weights, out_weights = None, None
        in_splits.append((in_, in_weights))
        out_splits.append((out, out_weights))
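    # Training loaders iterate over the 'in' splits of the non-test
    # environments; all loaders are deterministic so that repeated passes
    # over the data yield the same batches.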
    train_loaders = [
        DeterministicFastDataLoader(dataset=env,
                                    batch_size=hparams['batch_size'],
                                    num_workers=dataset.N_WORKERS)
        for i, (env, _) in enumerate(in_splits) if i not in args.test_envs
    ]
    eval_loaders = [
        DeterministicFastDataLoader(dataset=env,
                                    batch_size=9,
                                    num_workers=dataset.N_WORKERS)
        for env, _ in (in_splits + out_splits)
    ]
    eval_loader_names = ['env{}_in'.format(i) for i in range(len(in_splits))]
    eval_loader_names += [
        'env{}_out'.format(i) for i in range(len(out_splits))
    ]
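    # WD eval loaders: the held-out 'out' splits of the training
    # (non-test) environments.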
    wd_eval_loader_names = [
        'env{}_out'.format(i) for i, _ in enumerate(in_splits)
        if i not in args.test_envs
    ]
    # Dirty OOD is the split of the target test environment that we touch
    # when computing measures.
    dirty_ood_eval_loader_names = [
        'env{}_{}'.format(i, dirty_ood_split) for i, _ in enumerate(in_splits)
        if i == target_test_env
    ]
    # Clean OOD is the complementary split of the same environment, which
    # the measures never touch.
    clean_ood_split = 'out' if dirty_ood_split == 'in' else 'in'
    clean_ood_eval_loader_names = [
        'env{}_{}'.format(i, clean_ood_split) for i, _ in enumerate(in_splits)
        if i == target_test_env
    ]
    train_loader_names = [
        'env{}_in'.format(i) for i, _ in enumerate(in_splits)
        if i not in args.test_envs
    ]
logging.info("WD eval loaders:")
logging.info(wd_eval_loader_names)
logging.info("Dirty OOD eval loaders:")
logging.info(dirty_ood_eval_loader_names)
logging.info("Clean OOD eval loaders:")
logging.info(clean_ood_eval_loader_names)
logging.info("Train loaders:")
logging.info(train_loader_names)
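    # Bucket each eval loader by its role (WD / dirty OOD / clean OOD).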
    wd_eval_loaders = []
    dirty_ood_eval_loaders = []
    clean_ood_eval_loaders = []
    for this_name, this_loader in zip(eval_loader_names, eval_loaders):
        if this_name in wd_eval_loader_names:
            wd_eval_loaders.append(this_loader)
        elif this_name in dirty_ood_eval_loader_names:
            dirty_ood_eval_loaders.append(this_loader)
        elif this_name in clean_ood_eval_loader_names:
            clean_ood_eval_loaders.append(this_loader)
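    # Rebuild the algorithm with the same constructor arguments used at
    # training time, then load the checkpointed weights into it.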
    algorithm_class = algorithms.get_algorithm_class(args.algorithm)
    algorithm = algorithm_class(dataset.input_shape, dataset.num_classes,
                                len(dataset) - len(args.test_envs), hparams)
    if algorithm_dict is not None:
        algorithm.load_state_dict(algorithm_dict, strict=True)
    return (algorithm, train_loaders, wd_eval_loaders, dirty_ood_eval_loaders,
            clean_ood_eval_loaders, dataset.num_classes)
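
# Example usage (a minimal sketch; the checkpoint keys 'args',
# 'model_hparams', and 'model_dict' and the Namespace reconstruction are
# assumptions about how training checkpoints are saved, not part of this
# file):
#
#   checkpoint = torch.load(checkpoint_path)
#   args = argparse.Namespace(**checkpoint['args'])
#   (algorithm, train_loaders, wd_loaders, dirty_ood_loaders,
#    clean_ood_loaders, num_classes) = load_from_state(
#        args,
#        hparams=checkpoint['model_hparams'],
#        algorithm_dict=checkpoint['model_dict'],
#        dirty_ood_split='out',
#        target_test_env=args.test_envs[0])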