in get_training_params.py [0:0]
def get_method_config(args, setting):
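    """Build the full training configuration for ``args.method`` under ``setting``.

    Loads the base YAML config for the (setting, method) pair, applies the
    normalization, model, and dataset overrides from ``args``, resolves the
    save directory, and returns the updated config dict.
    """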
method = args.method
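    # Configs are laid out as configs/<setting>/<method>.yaml.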
    base_config_path = f"configs/{setting}/{method}.yaml"
    try:
        base_config = utils.get_yaml_config(base_config_path)
    except Exception as e:
        raise ValueError(
            f"{setting}/{method} is not a valid training configuration"
        ) from e
params = base_config["parameters"]
# Update base config with specific parameters
# Set normalization layers
config_norm = get_config_norm(params)
norm_to_use = config_norm if args.norm is None else args.norm
norm_types = ["IN", "BN", "GN"]
if norm_to_use not in norm_types:
raise ValueError(
f"Norm {norm_to_use} not valid. Supported normalization types: IN, BN, GN."
)
    base_bn_type = base_config["parameters"]["bn_type"]
    base_block_bn_type = base_config["parameters"]["block_bn_type"]
    # Detect which norm token each configured layer-type string contains,
    # guarding against an unrecognized base config.
    base_bn = base_block_bn = None
    for n in norm_types:
        if n in base_bn_type:
            base_bn = n
        if n in base_block_bn_type:
            base_block_bn = n
    if base_bn is None or base_block_bn is None:
        raise ValueError(
            f"No known norm type in {base_bn_type}/{base_block_bn_type}."
        )
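    # Quantized training with BN keeps a dedicated quantized BN class; otherwise
    # the detected norm token is swapped for the requested one in both type strings.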
if setting == "quantized" and norm_to_use == "BN":
base_config["parameters"]["bn_type"] = "QuantStandardBN"
base_config["parameters"]["block_bn_type"] = "QuantStandardBN"
else:
base_config["parameters"]["bn_type"] = base_bn_type.replace(
base_bn, norm_to_use
)
base_config["parameters"]["block_bn_type"] = base_block_bn_type.replace(
base_block_bn, norm_to_use
)
    # BN keeps running statistics for eval, so force track_running_stats on
    if norm_to_use == "BN":
        norm_kwargs = base_config["parameters"].get("norm_kwargs") or {}
        norm_kwargs["track_running_stats"] = True
        base_config["parameters"]["norm_kwargs"] = norm_kwargs
# Update dataset
dataset = args.dataset
base_config["parameters"]["dataset"] = dataset
# Update model
model_name = args.model.lower()
base_config["parameters"]["model_config"]["model_class"] = model_name
    model_kwargs = params.get("model_kwargs", None)
    if model_kwargs is not None:
        base_model_kwargs = base_config["parameters"]["model_config"].get(
            "model_kwargs", None
        )
        if base_model_kwargs is not None:
            base_model_kwargs.update(model_kwargs)
        else:
            base_config["parameters"]["model_config"]["model_kwargs"] = model_kwargs
    else:
        # Guarantee the key exists for the cleanup below.
        base_config["parameters"]["model_config"].setdefault("model_kwargs", {})
    # Remove channel_selection_active for models without this parameter
    if model_name not in ("cpreresnet20",):
        base_config["parameters"]["model_config"]["model_kwargs"].pop(
            "channel_selection_active", None
        )
    # For unstructured sparsity and quantized training, GN uses 32 groups.
    # GN is not supported with structured sparsity (the pruned channel count
    # is not always divisible by 32); use IN there instead.
if norm_to_use == "GN":
if setting in ("unstructured_sparsity", "quantized"):
num_groups = 32
else:
raise NotImplementedError(
f"GroupNorm disabled for setting={setting}."
)
        norm_kwargs = base_config["parameters"].get("norm_kwargs") or {}
        norm_kwargs["num_groups"] = num_groups
        base_config["parameters"]["norm_kwargs"] = norm_kwargs
# Set default model training parameters
model_training_params = training_params.model_data_params(args)
base_config["parameters"].update(model_training_params)
# Update training parameters with user-specified ones
if setting == "unstructured_sparsity":
args_dict = unstructured_args_dict(args)
elif setting == "structured_sparsity":
args_dict = structured_args_dict(args, base_config)
elif setting == "quantized":
args_dict = quantized_args_dict(args)
else:
args_dict = {}
base_config["parameters"].update(args_dict)
# Get/make save directory
args_save_dir = args.save_dir
if args_save_dir is None:
config_save_dir = params["save_dir"]
save_dir = utils.create_save_dir(
config_save_dir, method, model_name, dataset, norm_to_use
)
else:
save_dir = utils.create_save_dir(
args_save_dir, method, model_name, dataset, norm_to_use, False
)
base_config["parameters"]["save_dir"] = save_dir
# Print training details
utils.print_train_params(
base_config, setting, method, norm_to_use, save_dir
)
return base_config
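
# Hedged usage sketch (illustration only, not part of the original module).
# It assumes get_method_config is driven from an argparse CLI; the attribute
# names (method, norm, dataset, model, save_dir) are the ones read above,
# while the flag defaults, the "unstructured_sparsity" setting, and any extra
# attributes consumed by training_params.model_data_params are assumptions.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--method", required=True)
    parser.add_argument("--norm", default=None, choices=["IN", "BN", "GN"])
    parser.add_argument("--dataset", default="cifar10")
    parser.add_argument("--model", default="cpreresnet20")
    parser.add_argument("--save_dir", default=None)
    args = parser.parse_args()

    config = get_method_config(args, setting="unstructured_sparsity")
    print(config["parameters"]["save_dir"])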