in egg/zoo/language_bottleneck/guess_number/train.py [0:0]
def get_params(params):
    """Build this script's argument parser, parse ``params`` via ``core.init``,
    and fill in derived defaults.

    Args:
        params: Argument list forwarded to ``core.init(arg_parser=..., params=...)``;
            presumably a list of command-line style strings (e.g. ``sys.argv[1:]``)
            -- confirm against ``core.init``.

    Returns:
        The namespace returned by ``core.init``. ``sender_lr`` and
        ``receiver_lr`` fall back to ``args.lr`` when not given explicitly.

    Raises:
        ValueError: if ``batch_size`` is not positive or does not evenly divide
            ``n_examples_per_epoch``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_bits", type=int, default=8, help="")
    parser.add_argument("--bits_s", type=int, default=4, help="")
    parser.add_argument("--bits_r", type=int, default=4, help="")
    parser.add_argument(
        "--n_examples_per_epoch",
        type=int,
        default=8000,
        help="Number of examples seen in an epoch (default: 8000)",
    )
    parser.add_argument(
        "--sender_hidden",
        type=int,
        default=10,
        help="Size of the hidden layer of Sender (default: 10)",
    )
    parser.add_argument(
        "--receiver_hidden",
        type=int,
        default=10,
        help="Size of the hidden layer of Receiver (default: 10)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="GS temperature for the sender (default: 1.0)",
    )
    parser.add_argument(
        "--sender_entropy_coeff",
        type=float,
        default=1e-2,
        help="Entropy regularisation coeff for Sender (default: 1e-2)",
    )
    parser.add_argument(
        "--receiver_entropy_coeff",
        type=float,
        default=1e-2,
        help="Entropy regularisation coeff for Receiver (default: 1e-2)",
    )
    # None means "use the global --lr"; resolved after parsing below.
    parser.add_argument(
        "--sender_lr",
        type=float,
        default=None,
        help="Learning rate for Sender's parameters",
    )
    parser.add_argument(
        "--receiver_lr",
        type=float,
        default=None,
        help="Learning rate for Receiver's parameters",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="gs",
        help="Selects whether Reinforce or GumbelSoftmax relaxation is used for training {rf, gs,"
        " non_diff} (default: gs)",
    )
    parser.add_argument("--variable_length", action="store_true", default=False)
    parser.add_argument("--sender_cell", type=str, default="rnn")
    parser.add_argument("--receiver_cell", type=str, default="rnn")
    parser.add_argument(
        "--sender_emb",
        type=int,
        default=10,
        help="Size of the embeddings of Sender (default: 10)",
    )
    parser.add_argument(
        "--receiver_emb",
        type=int,
        default=10,
        help="Size of the embeddings of Receiver (default: 10)",
    )
    parser.add_argument(
        "--early_stopping_thr",
        type=float,
        default=0.99,
        help="Early stopping threshold on accuracy (default: 0.99)",
    )
    # core.init parses the arguments; `lr` and `batch_size` are not declared
    # above, so they are presumably common options added by core.init.
    args = core.init(arg_parser=parser, params=params)

    if args.sender_lr is None:
        args.sender_lr = args.lr
    if args.receiver_lr is None:
        args.receiver_lr = args.lr

    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # and a zero batch size would otherwise surface as ZeroDivisionError.
    if args.batch_size <= 0 or args.n_examples_per_epoch % args.batch_size != 0:
        raise ValueError(
            f"n_examples_per_epoch ({args.n_examples_per_epoch}) must be a "
            f"multiple of a positive batch_size ({args.batch_size})"
        )
    return args