in egg/zoo/objects_game/train.py [0:0]
def get_params(params):
    """Build the objects_game command-line parser and parse ``params`` with it.

    The parser is handed to ``core.init`` together with ``params``; the
    resulting namespace is then validated with ``check_args``, printed, and
    returned.

    Args:
        params: Forwarded unchanged to ``core.init(parser, params)``.
            Presumably a list of command-line tokens (e.g. ``sys.argv[1:]``)
            — TODO confirm against ``core.init``.

    Returns:
        The argument namespace produced by ``core.init``.
    """
    parser = argparse.ArgumentParser()

    # Data is either synthesised from perceptual dimensions or loaded from
    # disk; argparse rejects passing both flags explicitly.
    input_data = parser.add_mutually_exclusive_group()
    input_data.add_argument(
        "--perceptual_dimensions",
        type=str,
        # Kept as a string here; presumably parsed into a list downstream
        # (e.g. in check_args) — NOTE(review): confirm.
        default="[4, 4, 4, 4, 4]",
        help="Number of features for every perceptual dimension",
    )
    input_data.add_argument(
        "--load_data_path",
        type=str,
        default=None,
        help="Path to .npz data file to load",
    )

    # --- data generation ---
    parser.add_argument(
        "--n_distractors",
        type=int,
        default=3,
        help="Number of distractor objects for the receiver (default: 3)",
    )
    # Sample counts use type=float so that scientific notation (1e5) is
    # accepted on the command line; presumably cast to int downstream.
    parser.add_argument(
        "--train_samples",
        type=float,
        default=1e5,
        help="Number of tuples in training data (default: 1e5)",
    )
    parser.add_argument(
        "--validation_samples",
        type=float,
        default=1e3,
        help="Number of tuples in validation data (default: 1e3)",
    )
    parser.add_argument(
        "--test_samples",
        type=float,
        default=1e3,
        help="Number of tuples in test data (default: 1e3)",
    )
    parser.add_argument(
        "--data_seed",
        type=int,
        default=111,
        help="Seed for random creation of train, validation and test tuples (default: 111)",
    )
    parser.add_argument(
        "--shuffle_train_data",
        action="store_true",
        default=False,
        help="Shuffle train data before every epoch (default: False)",
    )

    # --- agent architecture ---
    parser.add_argument(
        "--sender_hidden",
        type=int,
        default=50,
        help="Size of the hidden layer of Sender (default: 50)",
    )
    parser.add_argument(
        "--receiver_hidden",
        type=int,
        default=50,
        help="Size of the hidden layer of Receiver (default: 50)",
    )
    parser.add_argument(
        "--sender_embedding",
        type=int,
        default=10,
        help="Dimensionality of the embedding hidden layer for Sender (default: 10)",
    )
    parser.add_argument(
        "--receiver_embedding",
        type=int,
        default=10,
        help="Dimensionality of the embedding hidden layer for Receiver (default: 10)",
    )
    parser.add_argument(
        "--sender_cell",
        type=str,
        default="rnn",
        help="Type of the cell used for Sender {rnn, gru, lstm} (default: rnn)",
    )
    parser.add_argument(
        "--receiver_cell",
        type=str,
        default="rnn",
        help="Type of the cell used for Receiver {rnn, gru, lstm} (default: rnn)",
    )

    # --- optimisation ---
    parser.add_argument(
        "--sender_lr",
        type=float,
        default=1e-1,
        help="Learning rate for Sender's parameters (default: 1e-1)",
    )
    parser.add_argument(
        "--receiver_lr",
        type=float,
        default=1e-1,
        help="Learning rate for Receiver's parameters (default: 1e-1)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="GS temperature for the sender (default: 1.0)",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="gs",
        # The previous help text advertised "rf" as the default and was
        # missing a space between the two concatenated literals.
        help="Selects whether Reinforce or GumbelSoftmax relaxation is used for training {gs only at the moment} "
        "(default: gs)",
    )

    # --- output / evaluation / debugging ---
    parser.add_argument(
        "--output_json",
        action="store_true",
        default=False,
        help="If set, egg will output validation stats in json format (default: False)",
    )
    parser.add_argument(
        "--evaluate",
        action="store_true",
        default=False,
        help="Evaluate trained model on test data",
    )
    parser.add_argument(
        "--dump_data_folder",
        type=str,
        default=None,
        help="Folder where file with dumped data will be created",
    )
    parser.add_argument(
        "--dump_msg_folder",
        type=str,
        default=None,
        help="Folder where file with dumped messages will be created",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        help="Run egg/objects_game with pdb enabled",
    )

    # core.init adds the framework-wide options and performs the parse.
    args = core.init(parser, params)
    check_args(args)
    print(args)
    return args