in train_extractive_reader.py [0:0]
def __init__(self, cfg: DictConfig):
    """Set up the extractive-reader trainer from ``cfg``.

    Steps, in order:

    1. Store the config and derive distributed bookkeeping
       (``shard_id`` and ``distributed_factor``).
    2. Look up a checkpoint with ``get_model_file``. When one is found, load it
       and apply its ``encoder_params`` to ``cfg`` before any component is
       built.
    3. Build the tensorizer, reader and optimizer with
       ``init_reader_components``. Then pass the reader and optimizer through
       ``setup_for_distributed_mode`` (device, GPU count, local rank, fp16
       settings).
    4. Reset the training-progress fields. If a checkpoint was loaded, hand it
       to ``self._load_saved_state``.

    Args:
        cfg: Run configuration. Fields read directly here are
            ``local_rank``, ``distributed_world_size``,
            ``checkpoint_file_name``, ``encoder.encoder_model_type``,
            ``device``, ``n_gpu``, ``fp16`` and ``fp16_opt_level``. The called
            helpers may read more.
    """
    self.cfg = cfg

    rank = cfg.local_rank
    # NOTE(review): -1 presumably marks a non-distributed run; confirm.
    if rank == -1:
        self.shard_id = 0
    else:
        self.shard_id = rank

    world_size = cfg.distributed_world_size
    # A falsy world size (e.g. None or 0) falls back to a factor of 1.
    self.distributed_factor = world_size if world_size else 1

    logger.info("***** Initializing components for training *****")

    checkpoint_path = get_model_file(cfg, cfg.checkpoint_file_name)
    checkpoint_state = None
    if checkpoint_path:
        checkpoint_state = load_states_from_checkpoint(checkpoint_path)
        # Applied to cfg before the components below are built, presumably so
        # the reader is constructed with the checkpoint's encoder settings.
        set_cfg_params_from_state(checkpoint_state.encoder_params, cfg)

    tensorizer, raw_reader, raw_optimizer = init_reader_components(
        cfg.encoder.encoder_model_type, cfg
    )
    self.reader, self.optimizer = setup_for_distributed_mode(
        raw_reader,
        raw_optimizer,
        cfg.device,
        cfg.n_gpu,
        cfg.local_rank,
        cfg.fp16,
        cfg.fp16_opt_level,
    )
    self.tensorizer = tensorizer

    # Fresh training-progress state. When resuming, _load_saved_state is
    # presumably what restores these; verify against its implementation.
    self.start_epoch = 0
    self.start_batch = 0
    self.scheduler_state = None
    self.best_validation_result = None
    self.best_cp_name = None

    if checkpoint_state:
        self._load_saved_state(checkpoint_state)