jax-projects/big_bird/sweep_flax.yaml (16 lines of code) (raw):

command:
    - python3
    - train.py
method: random
parameters:
    lr:
        values: [4e-5, 3e-5]
    warmup_steps:
        values: [20000, 15000, 10000, 5000]
    weight_decay:
        distribution: normal
        mu: 1e-2
        sigma: 2e-3
metric:
    name: eval_loss
    goal: minimize
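
The search space in this config can be illustrated with a minimal sketch. It assumes the file is a random-search sweep configuration in which each trial picks `lr` and `warmup_steps` uniformly from the listed values and draws `weight_decay` from a normal distribution with the given `mu` and `sigma`. The helper `sample_config` is hypothetical and is not part of the project; it only mimics how one trial's hyperparameters might be drawn, not how the sweep tool itself is implemented.

```python
import random

# Search space copied from the sweep config.
LR_VALUES = [4e-5, 3e-5]
WARMUP_STEPS_VALUES = [20000, 15000, 10000, 5000]
WEIGHT_DECAY_MU = 1e-2
WEIGHT_DECAY_SIGMA = 2e-3


def sample_config(rng: random.Random) -> dict:
    """Draw one hyperparameter set (hypothetical helper, for illustration only)."""
    return {
        # Discrete parameters: assumed to be chosen uniformly from the listed values.
        "lr": rng.choice(LR_VALUES),
        "warmup_steps": rng.choice(WARMUP_STEPS_VALUES),
        # Continuous parameter: normal distribution with mean mu and std sigma.
        "weight_decay": rng.gauss(WEIGHT_DECAY_MU, WEIGHT_DECAY_SIGMA),
    }


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed so repeated runs print the same draws
    for _ in range(3):
        print(sample_config(rng))
```

Each sampled dictionary corresponds to one run of `python3 train.py`, and runs would be compared on `eval_loss`, with lower values preferred. Because the normal distribution is unbounded, a draw far from `mu` is possible in principle, though with `sigma` at one fifth of `mu` a negative `weight_decay` would be very rare.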