in src/deep_baselines/run.py [0:0]
def _csv_to_ints(csv_text):
    """Split a comma-separated string such as ``"64,128"`` into a list of ints.

    ``int`` tolerates surrounding whitespace, so ``"64, 128"`` is accepted too.
    """
    return [int(item) for item in csv_text.split(",")]


def args2config(args, config):
    """Copy command-line overrides from ``args`` onto ``config`` (mutated in place).

    Most options only override ``config`` when the corresponding argument is
    truthy, so ``None`` / empty / zero values leave the existing config value
    untouched. A few options are copied unconditionally:
    ``embedding_trainable``, VirHunter's ``one_hot_encode``, and VirSeeker's
    ``embedding`` and ``bidirectional``.

    Comma-separated size arguments (``kernel_nums``, ``kernel_sizes``,
    ``fc_sizes``) are parsed into ints so they can be used directly as layer
    sizes.

    :param args: parsed command-line namespace.
    :param config: model config object whose attributes are overwritten.
    :raises ValueError: if a size list contains a non-integer entry, or if
        ``CHEER-CatWCNN`` is given fewer than two ``fc_sizes``.
    """
    # Options shared by every baseline model.
    if args.dropout:
        config.dropout = args.dropout
    if args.bias:
        config.bias = args.bias
    if args.pos_weight:
        config.pos_weight = args.pos_weight
    if args.weight:
        config.weight = args.weight
    if args.num_labels:
        config.num_labels = args.num_labels
    # Binary tasks always use two output labels, whatever --num_labels says.
    if args.task_type in ["binary_class", "binary-class"]:
        config.num_labels = 2
    if args.seq_max_length:
        config.max_position_embeddings = args.seq_max_length
    config.embedding_trainable = args.embedding_trainable
    if args.embedding_dim:
        config.embedding_dim = args.embedding_dim

    if args.model_type in ["CHEER-CatWCNN", "CHEER-WDCNN", "CHEER-WCNN"]:
        # embedding_dim has already been applied above for every model type.
        if args.channel_in:
            config.channel_in = args.channel_in
        if args.kernel_nums:
            config.kernel_nums = _csv_to_ints(args.kernel_nums)
        if args.kernel_sizes:
            config.kernel_sizes = _csv_to_ints(args.kernel_sizes)
        if args.fc_sizes:
            fc_sizes = _csv_to_ints(args.fc_sizes)
            if args.model_type == "CHEER-CatWCNN":
                # The concatenated variant has two fully-connected layers.
                if len(fc_sizes) < 2:
                    raise ValueError(
                        "CHEER-CatWCNN requires two comma-separated fc_sizes, got %r"
                        % args.fc_sizes
                    )
                config.fc_size1 = fc_sizes[0]
                config.fc_size2 = fc_sizes[1]
            else:
                config.fc_size = fc_sizes[0]
    elif args.model_type == "VirHunter":
        # VirHunter uses a single conv layer, so only the first value is kept.
        if args.kernel_nums:
            config.kernel_num = _csv_to_ints(args.kernel_nums)[0]
        if args.kernel_sizes:
            config.kernel_size = _csv_to_ints(args.kernel_sizes)[0]
        if args.fc_sizes:
            config.fc_size = _csv_to_ints(args.fc_sizes)[0]
        config.one_hot_encode = args.one_hot_encode
    elif args.model_type == "Virtifier":
        if args.embedding_init:
            config.embedding_init = args.embedding_init
        if args.embedding_init_path:
            config.embedding_init_path = args.embedding_init_path
        if args.bidirectional:
            config.bidirectional = args.bidirectional
        if args.num_layers:
            config.num_layers = args.num_layers
        if args.hidden_dim:
            config.hidden_dim = args.hidden_dim
        # NOTE(review): padding_idx=0 is falsy and so never overrides the
        # config value here; confirm that is intended.
        if args.padding_idx:
            config.padding_idx = args.padding_idx
        if args.fc_sizes:
            config.fc_size = _csv_to_ints(args.fc_sizes)[0]
    elif args.model_type == "VirSeeker":
        config.embedding = args.embedding
        config.bidirectional = args.bidirectional
        if args.num_layers:
            config.num_layers = args.num_layers
        if args.hidden_dim:
            config.hidden_dim = args.hidden_dim
        # NOTE(review): padding_idx=0 is falsy and so never overrides the
        # config value here; confirm that is intended.
        if args.padding_idx:
            config.padding_idx = args.padding_idx