in ModelConf.py [0:0]
def configurate_training_params(self):
    """Finalize training-related configuration after raw parsing.

    Validates the optimizer section (train phase only), resolves the
    effective batch size for single- and multi-GPU runs, normalizes
    sequence-length settings, derives text-preprocessing boolean flags,
    picks a default tokenizer by language, and decides between GPU and
    CPU execution.

    Raises (via ``self.raise_configuration_error``) when a required
    ``training_params`` entry is missing.
    """
    # --- optimizer (only required when actually training) ---
    if self.phase == 'train':
        if self.optimizer is None:
            self.raise_configuration_error('training_params.optimizer')
        if 'name' not in self.optimizer:
            self.raise_configuration_error('training_params.optimizer.name')
        self.optimizer_name = self.optimizer['name']
        if 'params' not in self.optimizer:
            self.raise_configuration_error('training_params.optimizer.params')
        self.optimizer_params = self.optimizer['params']
        # A learning rate passed via command-line params overrides the conf file.
        if hasattr(self.params, 'learning_rate') and self.params.learning_rate:
            self.optimizer_params['lr'] = self.params.learning_rate

    # --- batch size ---
    # the batch_size in conf file is the batch_size on each GPU
    self.batch_size_each_gpu = self.batch_size
    if hasattr(self.params, 'batch_size') and self.params.batch_size:
        self.batch_size_each_gpu = self.params.batch_size
    if self.batch_size_each_gpu is None:
        self.raise_configuration_error('training_params.batch_size')
    self.batch_size_total = self.batch_size_each_gpu
    if torch.cuda.device_count() > 1:
        # Global batch size scales with GPU count; the progress-report
        # interval shrinks accordingly so the wall-clock cadence of log
        # output stays roughly the same.
        self.batch_size_total = torch.cuda.device_count() * self.batch_size_each_gpu
        self.batch_num_to_show_results = self.batch_num_to_show_results // torch.cuda.device_count()

    if hasattr(self.params, 'max_epoch') and self.params.max_epoch:
        self.max_epoch = self.params.max_epoch

    if self.valid_times_per_epoch is not None:
        logging.info("configuration[training_params][valid_times_per_epoch] is deprecated, please use configuration[training_params][steps_per_validation] instead")

    # --- sequence length ---
    # fixed_lengths takes precedence over max_lengths; sequence tagging
    # must preserve original token positions, so both limits are dropped.
    if self.fixed_lengths:
        self.max_lengths = None
    if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
        self.fixed_lengths = None
        self.max_lengths = None

    # --- text preprocessing flags ---
    # Membership tests already yield booleans; no ternary needed.
    self.__text_preprocessing = self.text_preprocessing
    self.DBC2SBC = 'DBC2SBC' in self.__text_preprocessing
    self.unicode_fix = 'unicode_fix' in self.__text_preprocessing
    self.remove_stopwords = 'remove_stopwords' in self.__text_preprocessing

    # --- tokenizer default (by language) ---
    if self.tokenizer is None:
        self.tokenizer = 'jieba' if self.language == 'chinese' else 'nltk'

    # --- GPU/CPU selection (skipped in the 'cache' phase) ---
    if self.phase != 'cache':
        if torch.cuda.is_available() and torch.cuda.device_count() > 0 and self.use_gpu:
            # Lazy %-args: the message is only formatted if this level is enabled.
            logging.info("Activating GPU mode, there are %d GPUs available", torch.cuda.device_count())
        else:
            self.use_gpu = False
            logging.info("Activating CPU mode")