in parlai/scripts/train_model.py [0:0]
def __init__(self, opt):
    """
    Set up the training loop.

    Restores state when resuming from a checkpoint (init model path,
    accumulated training statistics, best validation score), optionally
    builds the dictionary, creates the agent and its task world, and
    initializes timers, stopping criteria, and optional loggers.

    :param opt:
        options dict controlling training. Mutated in place: may set
        ``init_model``, ``dict_file``, and ``validation_metric_mode``.
    :raises RuntimeError:
        if neither ``model_file`` nor ``dict_file`` is specified.
    """
    # if python is called from a non-interactive shell, like a bash script,
    # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
    # not produced. This line brings them back
    signal.signal(signal.SIGINT, signal.default_int_handler)
    # Possibly load from checkpoint
    trainstats_suffix = '.trainstats'  # we might load training statistics from here
    if (
        opt['load_from_checkpoint']
        and opt.get('model_file')
        and PathManager.exists(opt['model_file'] + '.checkpoint')
    ):
        # resume from the checkpoint copy of the model and its stats
        opt['init_model'] = opt['model_file'] + '.checkpoint'
        trainstats_suffix = '.checkpoint.trainstats'
    # Possibly build a dictionary (not all models do this).
    if not (opt.get('dict_file') or opt.get('model_file')):
        raise RuntimeError(
            'WARNING: For train_model, please specify either a '
            'model_file or dict_file.'
        )
    if 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            # default the dictionary path to live alongside the model file
            opt['dict_file'] = opt['model_file'] + '.dict'
        logging.info("building dictionary first...")
        build_dict(opt, skip_if_built=True)
    # Create model and assign it to the specified task
    self.agent = create_agent(opt)
    self.agent.opt.log()
    self.world = create_task(opt, self.agent)
    # set up timers
    self.train_time = Timer()
    self.validate_time = Timer()
    self.log_time = Timer()
    self.save_time = Timer()
    # counters for parleys and optimizer steps taken so far
    self.parleys = 0
    self._train_steps = 0
    self._last_log_steps = 0
    self.update_freq = opt.get('update_freq', 1)
    # stopping / logging / validation / saving thresholds; inf when unset.
    # time- and epoch-based criteria warn in distributed mode (hence
    # distributed_warn=True); step-based criteria are safe there.
    self.max_num_epochs = _num_else_inf(opt, 'num_epochs', distributed_warn=True)
    self.max_train_time = _num_else_inf(
        opt, 'max_train_time', distributed_warn=True
    )
    self.max_train_steps = _num_else_inf(opt, 'max_train_steps')
    self.log_every_n_secs = _num_else_inf(
        opt, 'log_every_n_secs', distributed_warn=True
    )
    self.log_every_n_steps = _num_else_inf(opt, 'log_every_n_steps')
    self.val_every_n_secs = _num_else_inf(
        opt, 'validation_every_n_secs', distributed_warn=True
    )
    self.val_every_n_epochs = _num_else_inf(
        opt, 'validation_every_n_epochs', distributed_warn=True
    )
    self.val_every_n_steps = _num_else_inf(opt, 'validation_every_n_steps')
    self.save_every_n_secs = _num_else_inf(
        opt, 'save_every_n_secs', distributed_warn=True
    )
    # smart defaults for --validation-metric-mode: loss-like metrics are
    # minimized, accuracy-like metrics are maximized; anything else
    # defaults to 'max' when the user did not specify a mode.
    if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
        opt['validation_metric_mode'] = 'min'
    elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
        opt['validation_metric_mode'] = 'max'
    if opt.get('validation_metric_mode') is None:
        opt['validation_metric_mode'] = 'max'
    self.last_valid_epoch = 0
    self._last_valid_steps = 0
    # +1 when bigger metric is better, -1 when smaller is better
    self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
    self.train_reports = []
    self.valid_reports = []
    self.final_valid_report = {}
    self.final_test_report = {}
    self.final_extra_valid_report = {}
    self.best_valid = None
    self.impatience = 0
    self.saved = False
    self.valid_worlds = None
    self.opt = opt
    # we may have been preempted, make sure we note that amount
    self._preempted_epochs = 0.0
    if opt.get('model_file') and PathManager.exists(
        opt['model_file'] + trainstats_suffix
    ):
        # looks like we were preempted. make sure we load up our total
        # training stats, etc
        with PathManager.open(opt['model_file'] + trainstats_suffix) as ts:
            obj = json.load(ts)
            self.parleys = obj.get('parleys', 0)
            self._preempted_epochs = obj.get('total_epochs', 0)
            self.train_time.total = obj.get('train_time', 0)
            self._train_steps = obj.get('train_steps', 0)
            self.impatience = obj.get('impatience', 0)
            self.valid_reports = obj.get('valid_reports', [])
            if self.valid_reports:
                self.last_valid_epoch = self.valid_reports[-1].get(
                    'total_epochs', 0.0
                )
            self.train_reports = obj.get('train_reports', [])
            if 'best_valid' in obj:
                self.best_valid = obj['best_valid']
            else:
                # old method: best validation score stored in a separate
                # .best_valid file. (model_file is known truthy here from
                # the enclosing condition, so no need to re-check it.)
                if PathManager.exists(opt['model_file'] + '.best_valid'):
                    with PathManager.open(
                        opt['model_file'] + ".best_valid", 'r'
                    ) as f:
                        # 'with' closes the file; no explicit close needed
                        x = f.readline()
                    self.best_valid = float(x)
    # loggers run on the primary worker only to avoid duplicate writes
    if opt['tensorboard_log'] and is_primary_worker():
        self.tb_logger = TensorboardLogger(opt)
    if opt['wandb_log'] and is_primary_worker():
        model = self.agent.model if hasattr(self.agent, 'model') else None
        self.wb_logger = WandbLogger(opt, model)