in ModelConf.py [0:0]
def check_conf(self):
    """Verify that the loaded configuration is legal for the current mode and phase.

    Performs four groups of checks:
      1. 'philly' mode: data/model/embedding paths must come from the training
         params, not from local paths defined in the configuration file.
      2. Phase-specific inputs: the data files required by the current phase
         ('train' / 'test' / 'predict') must be defined and exist on disk.
      3. Enumerations: language, problem type and (for sequence tagging)
         tagging scheme must be supported members.
      4. Metrics / prediction fields must be legal for the problem type;
         checks that need the loaded problem are deferred via
         ``self.metrics_post_check`` / ``self.predict_fields_post_check``.

    Raises:
        AssertionError: if a configuration constraint is violated.
        Exception: if illegal metrics or prediction fields are requested.
    """
    # In philly mode, ensure the data and model etc. are not the local paths
    # defined in the configuration file.
    if self.mode == 'philly':
        assert not (hasattr(self.params, 'train_data_path') and self.params.train_data_path is None and hasattr(self, 'train_data_path') and self.train_data_path), \
            'In philly mode, but you define a local train_data_path:%s in your configuration file' % self.train_data_path
        assert not (hasattr(self.params, 'valid_data_path') and self.params.valid_data_path is None and hasattr(self, 'valid_data_path') and self.valid_data_path), \
            'In philly mode, but you define a local valid_data_path:%s in your configuration file' % self.valid_data_path
        assert not (hasattr(self.params, 'test_data_path') and self.params.test_data_path is None and hasattr(self, 'test_data_path') and self.test_data_path), \
            'In philly mode, but you define a local test_data_path:%s in your configuration file' % self.test_data_path
        if self.phase == 'train':
            assert hasattr(self.params, 'model_save_dir') and self.params.model_save_dir, \
                'In philly mode, you must define a model save dir through the training params'
            assert not (self.params.pretrained_model_path is None and self.pretrained_model_path), \
                'In philly mode, but you define a local pretrained model path:%s in your configuration file' % self.pretrained_model_path
            assert not (self.pretrained_model_path is None and self.params.pretrained_emb_path is None and self.pretrained_emb_path), \
                'In philly mode, but you define a local pretrained embedding:%s in your configuration file' % self.pretrained_emb_path
        elif self.phase == 'test' or self.phase == 'predict':
            assert not (self.params.previous_model_path is None and self.previous_model_path), \
                'In philly mode, but you define a local model trained previously %s in your configuration file' % self.previous_model_path

    # Check the inputs required by the current phase.
    # NOTE: os.path.isfile cannot detect hdfs files.
    if self.phase == 'train':
        assert self.train_data_path is not None, "Please define train_data_path"
        assert os.path.isfile(self.train_data_path), "Training data %s does not exist!" % self.train_data_path
        assert self.valid_data_path is not None, "Please define valid_data_path"
        # BUGFIX: the message previously said "Training data" for the validation set.
        assert os.path.isfile(self.valid_data_path), "Validation data %s does not exist!" % self.valid_data_path
        if hasattr(self, 'pretrained_emb_type') and self.pretrained_emb_type:
            # BUGFIX: the %s placeholder was never filled in the original message.
            assert self.pretrained_emb_type in {'glove', 'word2vec', 'fasttext'}, \
                'Embedding type %s is not supported! We support glove, word2vec, fasttext now.' % self.pretrained_emb_type
        if hasattr(self, 'pretrained_emb_binary_or_text') and self.pretrained_emb_binary_or_text:
            # BUGFIX: the %s placeholder was never filled in the original message.
            assert self.pretrained_emb_binary_or_text in {'text', 'binary'}, \
                'Embedding file type %s is not supported! We support text and binary.' % self.pretrained_emb_binary_or_text
    elif self.phase == 'test':
        assert self.test_data_path is not None, "Please define test_data_path"
        # BUGFIX: the message previously said "Training data" for the test set.
        assert os.path.isfile(self.test_data_path), "Test data %s does not exist!" % self.test_data_path
    elif self.phase == 'predict':
        assert self.predict_data_path is not None, "Please define predict_data_path"
        # BUGFIX: the message previously said "Training data" for the predict set.
        assert os.path.isfile(self.predict_data_path), "Predict data %s does not exist!" % self.predict_data_path

    # Check language type against the supported enum members.
    SUPPORTED_LANGUAGES = set(LanguageTypes._member_names_)
    assert self.language in SUPPORTED_LANGUAGES, \
        "Language type %s is not supported now. Supported types: %s" % (self.language, ",".join(SUPPORTED_LANGUAGES))

    # Check problem type against the supported enum members.
    SUPPORTED_PROBLEMS = set(ProblemTypes._member_names_)
    assert self.problem_type in SUPPORTED_PROBLEMS, \
        "Data type %s is not supported now. Supported types: %s" % (self.problem_type, ",".join(SUPPORTED_PROBLEMS))

    if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
        SUPPORTED_TAGGING_SCHEMES = set(TaggingSchemes._member_names_)
        # BUGFIX: typo "proble" -> "problem" in the message.
        assert self.tagging_scheme is not None, \
            "For sequence tagging problem, tagging scheme must be defined at configuration['inputs']['tagging_scheme']!"
        assert self.tagging_scheme in SUPPORTED_TAGGING_SCHEMES, \
            "Tagging scheme %s is not supported now. Supported schemes: %s" % (self.tagging_scheme, ",".join(SUPPORTED_TAGGING_SCHEMES))

        # The max_lengths of all the inputs and targets should be consistent.
        if self.max_lengths:
            max_lengths = list(self.max_lengths.values())
            for i in range(len(max_lengths) - 1):
                assert max_lengths[i] == max_lengths[i + 1], \
                    "For sequence tagging tasks, the max_lengths of all the inputs and targets should be consistent!"

    # Check applicable metrics for the problem type.
    if self.phase == 'train' or self.phase == 'test':
        self.metrics_post_check = set()  # metrics to re-check after the problem is loaded
        diff = set(self.metrics) - SupportedMetrics[ProblemTypes[self.problem_type]]
        illegal_metrics = []
        for diff_metric in diff:
            # NOTE(review): unsupported metrics WITHOUT an '@' fall through
            # silently here — possibly intentional; verify against callers.
            if diff_metric.find('@') != -1:
                field, target = diff_metric.split('@')
                if field != 'auc':
                    illegal_metrics.append(diff_metric)
                else:
                    if target != 'average':
                        self.metrics_post_check.add(diff_metric)
        if len(illegal_metrics) > 0:
            raise Exception("Metrics %s are not supported for %s tasks!" % (",".join(list(illegal_metrics)), self.problem_type))

    # Check predict fields for the problem type.
    if self.phase == 'predict':
        self.predict_fields_post_check = set()  # fields to re-check after the problem is loaded
        diff = set(self.predict_fields) - PredictionTypes[ProblemTypes[self.problem_type]]
        illegal_fields = []
        for diff_field in diff:
            if diff_field.find('@') != -1 and diff_field.startswith('confidence'):
                field, target = diff_field.split('@')
                if field != 'confidence':
                    illegal_fields.append(diff_field)
                else:
                    # Whether the target exists in the output dictionary is
                    # unknown here; check after the problem is loaded.
                    self.predict_fields_post_check.add(diff_field)
            else:
                illegal_fields.append(diff_field)
        if len(illegal_fields) > 0:
            raise Exception("The prediction fields %s is/are not supported!" % ",".join(illegal_fields))