in src/trainer.py [0:0]
def __init__(self, modules, env, params):
    """
    Initialize the trainer.

    modules: dictionary mapping names to the modules (models) to train.
    env: environment providing the training tasks and data iterators.
    params: experiment parameters.
    """
    # modules / params
    self.modules = modules
    self.params = params
    self.env = env
    # epoch / iteration size
    self.epoch_size = params.epoch_size
    if self.epoch_size == -1:
        # fall back to the dataset size; this assumes `self.data` has been
        # populated before the trainer is constructed
        self.epoch_size = self.data
    assert self.epoch_size > 0
    # data iterators
    self.iterators = {}
    # set parameters
    self.set_parameters()
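    # NOTE: fp16 training requires an AMP optimization level >= 1, and gradient
    # accumulation is only supported when AMP is enabled (amp >= 0); the two
    # asserts below enforce this.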
    # float16 / distributed (no AMP)
    assert params.amp >= 1 or not params.fp16
    assert params.amp >= 0 or params.accumulate_gradients == 1
    if params.multi_gpu and params.amp == -1:
        logger.info("Using nn.parallel.DistributedDataParallel ...")
        for k in self.modules.keys():
            self.modules[k] = nn.parallel.DistributedDataParallel(
                self.modules[k],
                device_ids=[params.local_rank],
                output_device=params.local_rank,
                broadcast_buffers=True,
            )
    # set optimizers
    self.set_optimizers()
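    # When AMP is enabled, modules are wrapped only after the optimizers exist,
    # presumably because init_amp() needs both models and optimizers (as
    # apex.amp.initialize does).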
    # float16 / distributed (AMP)
    if params.amp >= 0:
        self.init_amp()
        if params.multi_gpu:
            logger.info("Using apex.parallel.DistributedDataParallel ...")
            for k in self.modules.keys():
                self.modules[k] = apex.parallel.DistributedDataParallel(
                    self.modules[k], delay_allreduce=True
                )
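    # Early stopping: params.stopping_criterion has the form "<metric>,<N>";
    # training stops after the metric has failed to improve N times. A leading
    # "_" marks a metric that should decrease rather than increase.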
    # stopping criterion used for early stopping
    if params.stopping_criterion != "":
        split = params.stopping_criterion.split(",")
        assert len(split) == 2 and split[1].isdigit()
        self.decrease_counts_max = int(split[1])
        self.decrease_counts = 0
        if split[0][0] == "_":
            self.stopping_criterion = (split[0][1:], False)
        else:
            self.stopping_criterion = (split[0], True)
        self.best_stopping_criterion = -1e12 if self.stopping_criterion[1] else 1e12
    else:
        self.stopping_criterion = None
        self.best_stopping_criterion = None
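    # Validation metrics use the same convention: a comma-separated list of
    # metric names, with a leading "_" for metrics that should decrease.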
    # validation metrics
    self.metrics = []
    metrics = [m for m in params.validation_metrics.split(",") if m != ""]
    for m in metrics:
        m = (m[1:], False) if m[0] == "_" else (m, True)
        self.metrics.append(m)
    self.best_metrics = {
        metric: (-1e12 if biggest else 1e12) for (metric, biggest) in self.metrics
    }
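    # Training statistics: processed_e / processed_w presumably count the
    # examples (equations) and words (tokens) seen since the last log; each
    # training task also gets lists accumulating its per-iteration values.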
    # training statistics
    self.epoch = 0
    self.n_iter = 0
    self.n_total_iter = 0
    self.stats = OrderedDict(
        [("processed_e", 0)]
        + [("processed_w", 0)]
        + sum(
            [[(x, []), (f"{x}-AVG-STOP-PROBS", [])] for x in env.TRAINING_TASKS], []
        )
    )
    self.last_time = time.time()
    # reload potential checkpoints
    self.reload_checkpoint()
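    # Exported samples are appended to <dump_path>/data.prefix; judging by the
    # file name and log message, they are stored in prefix notation. Export and
    # reload modes are mutually exclusive (enforced by the asserts below).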
    # file handler to export data
    if params.export_data:
        assert params.reload_data == ""
        params.export_path_prefix = os.path.join(params.dump_path, "data.prefix")
        self.file_handler_prefix = io.open(
            params.export_path_prefix, mode="a", encoding="utf-8"
        )
        logger.info(
            f"Data will be stored in prefix notation in: {params.export_path_prefix} ..."
        )
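    # reload_data format: "task1,train1,valid1,test1;task2,train2,valid2,test2"
    # (one 4-field entry per task, with unique task names and existing files).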
    # reload exported data
    if params.reload_data != "":
        assert params.num_workers in [0, 1]
        assert params.export_data is False
        s = [x.split(",") for x in params.reload_data.split(";") if len(x) > 0]
        assert (
            len(s) >= 1
            and all(len(x) == 4 for x in s)
            and len(s) == len(set([x[0] for x in s]))
        )
        self.data_path = {
            task: (train_path, valid_path, test_path)
            for task, train_path, valid_path, test_path in s
        }
        assert all(
            all(os.path.isfile(path) for path in paths)
            for paths in self.data_path.values()
        )
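        # a task must have reloaded data if and only if it is selected for training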
        for task in self.env.TRAINING_TASKS:
            assert (task in self.data_path) == (task in params.tasks)
    else:
        self.data_path = None
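    # A negative env_base_seed is replaced by a random one, so that each run
    # (and, presumably, each data-generating worker) starts from a fresh seed.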
    # create data loaders
    if not params.eval_only:
        if params.env_base_seed < 0:
            params.env_base_seed = np.random.randint(1_000_000_000)
        self.dataloader = {
            task: iter(self.env.create_train_iterator(task, self.data_path, params))
            for task in params.tasks
        }