in classy_vision/tasks/classification_task.py [0:0]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
    """Instantiates a ClassificationTask from a configuration.

    Args:
        config: A configuration for a ClassificationTask.
            See :func:`__init__` for parameters expected in the config.

    Returns:
        A ClassificationTask instance.
    """
    test_only = config.get("test_only", False)
    if not test_only:
        # TODO Make distinction between epochs and phases in optimizer clear
        train_phases_per_epoch = config["dataset"]["train"].get(
            "phases_per_epoch", 1
        )
        optimizer_config = config["optimizer"]
        optimizer_config["num_epochs"] = (
            config["num_epochs"] * train_phases_per_epoch
        )
        optimizer = build_optimizer(optimizer_config)
        param_schedulers = build_optimizer_schedulers(optimizer_config)
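
    # Build the core components (datasets, loss, meters, model) from their
    # registered configs.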
    datasets = {}
    phase_types = ["train", "test"]
    for phase_type in phase_types:
        if phase_type in config["dataset"]:
            datasets[phase_type] = build_dataset(config["dataset"][phase_type])
    loss = build_loss(config["loss"])
    amp_args = config.get("amp_args")
    meters = build_meters(config.get("meters", {}))
    model = build_model(config["model"])
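
    # Mixup/CutMix augmentation is optional; "alpha" is the only required key
    # in the mixup config, every other argument falls back to a default.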
    mixup_transform = None
    if config.get("mixup") is not None:
        assert "alpha" in config["mixup"], "key alpha is missing in mixup dict"
        mixup_transform = MixupTransform(
            config["mixup"]["alpha"],
            num_classes=config["mixup"].get("num_classes"),
            cutmix_alpha=config["mixup"].get("cutmix_alpha", 0),
            cutmix_minmax=config["mixup"].get("cutmix_minmax"),
            mix_prob=config["mixup"].get("mix_prob", 1.0),
            switch_prob=config["mixup"].get("switch_prob", 0.5),
            mode=config["mixup"].get("mode", "batch"),
            label_smoothing=config["mixup"].get("label_smoothing", 0.0),
        )

    # hooks config is optional
    hooks_config = config.get("hooks")
    hooks = []
    if hooks_config is not None:
        hooks = build_hooks(hooks_config)
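
    # Translate the (optional) distributed config into the keyword arguments
    # expected by set_distributed_options; string values are mapped to the
    # corresponding enum members.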
    distributed_config = config.get("distributed", {})
    distributed_options = {
        "broadcast_buffers_mode": BroadcastBuffersMode[
            distributed_config.get("broadcast_buffers", "before_eval").upper()
        ],
        "batch_norm_sync_mode": BatchNormSyncMode[
            distributed_config.get("batch_norm_sync_mode", "disabled").upper()
        ],
        "batch_norm_sync_group_size": distributed_config.get(
            "batch_norm_sync_group_size", 0
        ),
        "find_unused_parameters": distributed_config.get(
            "find_unused_parameters", False
        ),
        "bucket_cap_mb": distributed_config.get("bucket_cap_mb", 25),
        "fp16_grad_compress": distributed_config.get("fp16_grad_compress", False),
    }
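
    # Assemble the task with the builder-style setters; optional settings fall
    # back to their defaults when absent from the config.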
    task = (
        cls()
        .set_num_epochs(config["num_epochs"])
        .set_test_phase_period(config.get("test_phase_period", 1))
        .set_loss(loss)
        .set_test_only(test_only)
        .set_model(model)
        .set_meters(meters)
        .set_amp_args(amp_args)
        .set_mixup_transform(mixup_transform)
        .set_distributed_options(**distributed_options)
        .set_hooks(hooks)
        .set_bn_weight_decay(config.get("bn_weight_decay", False))
        .set_clip_grad_norm(config.get("clip_grad_norm"))
        .set_simulated_global_batchsize(config.get("simulated_global_batchsize"))
        .set_use_sharded_ddp(config.get("use_sharded_ddp", False))
    )
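
    # The optimizer and its schedulers only exist when training is enabled.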
    if not test_only:
        task.set_optimizer(optimizer)
        task.set_optimizer_schedulers(param_schedulers)
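
    # Only override the task's default device choice when use_gpu is set
    # explicitly in the config.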
    use_gpu = config.get("use_gpu")
    if use_gpu is not None:
        task.set_use_gpu(use_gpu)

    for phase_type in datasets:
        task.set_dataset(datasets[phase_type], phase_type)

    # NOTE: this is a private member and only meant to be used for
    # logging/debugging purposes. See __repr__ implementation
    task._config = config

    return task