in recommenders/models/newsrec/newsrec_utils.py [0:0]
def check_nn_config(f_config):
    """Check neural networks configuration.

    Verifies that every parameter required by the configured model type is
    present, that ``data_format`` matches what the model type expects, and
    then hands the configuration to ``check_type`` (defined elsewhere in this
    module) for further checks.

    Model types are matched exactly against their all-lowercase or
    all-uppercase spelling (e.g. ``"nrms"`` / ``"NRMS"``). Any other model
    type has no required parameters here and is only passed to ``check_type``.

    Args:
        f_config (dict): Neural network configuration.

    Raises:
        ValueError: If a required parameter is missing, or if ``data_format``
            does not match the value expected for the model type.
        KeyError: If ``"model_type"`` itself is absent (assuming ``f_config``
            is a plain dict).
    """
    model_type = f_config["model_type"]

    # The order of each list is significant: the first missing entry is the
    # one reported in the error message.
    if model_type in ("nrms", "NRMS"):
        required_parameters = [
            "title_size",
            "his_size",
            "wordEmb_file",
            "wordDict_file",
            "userDict_file",
            "npratio",
            "data_format",
            "word_emb_dim",
            # nrms
            "head_num",
            "head_dim",
            # attention
            "attention_hidden_dim",
            "loss",
            "dropout",
        ]
    elif model_type in ("naml", "NAML"):
        required_parameters = [
            "title_size",
            "body_size",
            "his_size",
            "wordEmb_file",
            "subvertDict_file",
            "vertDict_file",
            "wordDict_file",
            "userDict_file",
            "npratio",
            "data_format",
            "word_emb_dim",
            "vert_emb_dim",
            "subvert_emb_dim",
            # naml
            "filter_num",
            "cnn_activation",
            "window_size",
            "dense_activation",
            # attention
            "attention_hidden_dim",
            "loss",
            "dropout",
        ]
    elif model_type in ("lstur", "LSTUR"):
        required_parameters = [
            "title_size",
            "his_size",
            "wordEmb_file",
            "wordDict_file",
            "userDict_file",
            "npratio",
            "data_format",
            "word_emb_dim",
            # lstur
            "gru_unit",
            "type",
            "filter_num",
            "cnn_activation",
            "window_size",
            # attention
            "attention_hidden_dim",
            "loss",
            "dropout",
        ]
    elif model_type in ("npa", "NPA"):
        required_parameters = [
            "title_size",
            "his_size",
            "wordEmb_file",
            "wordDict_file",
            "userDict_file",
            "npratio",
            "data_format",
            "word_emb_dim",
            # npa
            "user_emb_dim",
            "filter_num",
            "cnn_activation",
            "window_size",
            # attention
            "attention_hidden_dim",
            "loss",
            "dropout",
        ]
    else:
        # Unknown model types have nothing to verify at this stage.
        required_parameters = []

    # check required parameters
    for param in required_parameters:
        if param not in f_config:
            raise ValueError("Parameters {0} must be set".format(param))

    # check that the data format matches the model type; "data_format" is
    # guaranteed to be present here for the model types tested below.
    if model_type in ("nrms", "NRMS", "lstur", "LSTUR"):
        if f_config["data_format"] != "news":
            raise ValueError(
                "For nrms and lstur model, data format must be 'news', but your set is {0}".format(
                    f_config["data_format"]
                )
            )
    elif model_type in ("naml", "NAML"):
        if f_config["data_format"] != "naml":
            raise ValueError(
                "For naml model, data format must be 'naml', but your set is {0}".format(
                    f_config["data_format"]
                )
            )

    check_type(f_config)