hyperparams.py (651 lines of code) (raw):

import os

# Maps an hparam-set name (selected with --hparam_sets / --hps) to the
# Hyperparams object holding that set's values.
HPARAMS_REGISTRY = {}
DEFAULT_OUT_DIR = os.path.expandvars('$HOME/dist-aug')


class Hyperparams(dict):
    """A dict whose entries can also be read and written as attributes.

    ``h.name = value`` stores ``h['name'] = value``.  Reading ``h.name``
    returns the stored value, or ``None`` when ``name`` is not a key.

    Missing double-underscore names (``__deepcopy__``, ``__getstate__``, ...)
    raise ``AttributeError`` instead of returning ``None``.  Protocol code such
    as ``copy`` and ``pickle`` probes for those optional hooks and would
    otherwise mistake the ``None`` for an implementation of the hook.
    """

    def __getattr__(self, attr):
        """Return ``self[attr]``; ``None`` if absent (``AttributeError`` for dunders)."""
        # Only reached when normal attribute lookup has already failed, so
        # real dict methods are never shadowed by stored keys.
        try:
            return self[attr]
        except KeyError:
            if attr.startswith('__') and attr.endswith('__'):
                raise AttributeError(attr) from None
            return None

    def __setattr__(self, attr, value):
        """Store ``value`` under key ``attr`` (never in the instance ``__dict__``)."""
        self[attr] = value


# Full set of values registered as 'good_base_sm'.
good_baseline_sm = Hyperparams(
    float16=True,
    fp16_mean_var=True,
    fp16_allreduce=True,
    no_vocab_rounding=False,
    skip_initial_evals=True,
    n_ctx=2048,
    n_layer=32,
    n_head=4,
    n_batch=16,
    n_embd=256,
    activation='quick_gelu',
    optimizer='bs_adam',
    blocksparse_op=True,
    recompute=True,
    resid_pdrop=0.05,
    warmup_iters=7500,
    embd_pdrop=0.05,
    lr=0.0007,
    total_epochs=120,
    pos_embd_std=0.007,
    w_embd_std=0.013,
    fp16_loss_scale=2.0**16,
    merge_layer_allreduce=1,
    max_grad_norm=1.0,
    blocksize=64,
    attention_layers='a',
    mlp_w1=0.125,
    qk_w=0.125,
    v_w=0.125,
    post_w=0.125,
    mlp_w2=0.5,
    mlp_multiple=4.0,
    qk_ratio=1.0,
)
HPARAMS_REGISTRY['good_base_sm'] = good_baseline_sm

# The next two sets hold only a few keys and do not copy good_baseline_sm.
# NOTE(review): presumably meant to be layered on top of another set via a
# comma-separated --hparam_sets list -- confirm with callers.
good_baseline_med = Hyperparams(n_layer=64, lr=0.0005, n_batch=4)
HPARAMS_REGISTRY['good_base_med'] = good_baseline_med

good_baseline_large = Hyperparams(n_layer=64, n_head=16, n_embd=512, n_batch=1)
HPARAMS_REGISTRY['good_base_lg'] = good_baseline_large

# Sampling-related overrides; the remaining fields and the registration of
# this set follow directly after this statement.
sample_during_eval_8gpu = Hyperparams(sample_during_eval=True)
# Remaining fields of sample_during_eval_8gpu, followed by its registration.
sample_during_eval_8gpu.update(
    samples_to_generate=1,
    sample_batch=1,
    sample_grid_dim=4,
)
HPARAMS_REGISTRY['sample-during-eval-8gpu'] = sample_during_eval_8gpu

# Values registered under 'cifar10'.
c10 = Hyperparams(
    n_ctx=3072,
    dataset='cifar10',
    mlp_multiple=2.0,
    qk_ratio=0.5,
    n_embd=256,
)
HPARAMS_REGISTRY['cifar10'] = c10

# c10_dense layers three earlier sets (later ones win on shared keys) and
# then applies its own overrides.
c10_dense = Hyperparams(good_baseline_sm)
c10_dense.update(sample_during_eval_8gpu)
c10_dense.update(c10)
c10_dense.update(
    lr=0.00035,
    dynamic_loss_scaling=True,
    warmup_iters=15000,
    max_grad_norm=1.0,
    resid_pdrop=0.25,
    embd_pdrop=0.0,
    n_batch=2,
    n_layer=128,
    merge_layer_allreduce=4,
    n_head=2,
    total_epochs=140,
    qk_w=0.125,
    mlp_w1=0.125,
    mlp_w2=0.125,
    post_w=0.125,
    logits_w=0.0,
    pos_embd_std=0.01,
    w_embd_std=0.01,
    blocksize=32,
    l2_loss=0.01,
)
HPARAMS_REGISTRY['c10-dense'] = c10_dense

# c10_sparse: c10_dense plus a different attention pattern and data split
# sizes; note the registry key is 'c10-gemnet'.
c10_sparse = Hyperparams(
    c10_dense,
    blocksize=32,
    local_attn_ctx=96,
    attention_layers='bT,b,b,b',
    test_size=2000,
    datapoints=48000,
)
HPARAMS_REGISTRY['c10-gemnet'] = c10_sparse

# c10_58m is an unmodified copy of c10_sparse under another name.
c10_58m = Hyperparams(c10_sparse)
HPARAMS_REGISTRY['c10-58m'] = c10_58m

c10_58m_rot = Hyperparams(
    c10_58m,
    use_rotation=True,
    total_epochs=10000,
    resid_pdrop=0.01,
)
HPARAMS_REGISTRY['c10-58m-rot'] = c10_58m_rot

c10_58m_rot_tr = Hyperparams(
    c10_58m,
    use_rotation=True,
    use_transposition=True,
    total_epochs=10000,
    resid_pdrop=0.01,
)
HPARAMS_REGISTRY['c10-58m-rot-tr'] = c10_58m_rot_tr

# First fields of c10_15m_dense; the rest and its registration follow
# directly after this statement.
c10_15m_dense = Hyperparams(c10_dense, n_layer=32, n_batch=16)
# Remaining fields of c10_15m_dense, followed by its registration.
c10_15m_dense.update(
    resid_pdrop=0.005,
    total_epochs=10000,
    test_size=2000,
    datapoints=48000,
)
HPARAMS_REGISTRY['c10_15m_dense'] = c10_15m_dense

# c10_15m: c10_sparse with its own depth, batch size, dropout and epochs.
# Every set below in this group copies c10_15m and adds a few keys.
c10_15m = Hyperparams(
    c10_sparse,
    n_layer=32,
    n_batch=16,
    resid_pdrop=0.005,
    total_epochs=10000,
)
HPARAMS_REGISTRY['c10_15m'] = c10_15m

c10_15m_rot = Hyperparams(c10_15m, use_rotation=True)
HPARAMS_REGISTRY['c10_15m_rot'] = c10_15m_rot

c10_15m_rot_tr = Hyperparams(c10_15m, use_rotation=True, use_transposition=True)
HPARAMS_REGISTRY['c10_15m_rot_tr'] = c10_15m_rot_tr

c10_15m_tr = Hyperparams(c10_15m, use_transposition=True)
HPARAMS_REGISTRY['c10_15m_tr'] = c10_15m_tr

c10_15m_rev = Hyperparams(c10_15m, use_reverse=True)
HPARAMS_REGISTRY['c10_15m_rev'] = c10_15m_rev

c10_15m_c = Hyperparams(c10_15m, use_color=True)
HPARAMS_REGISTRY['c10_15m_c'] = c10_15m_c

c10_15m_js = Hyperparams(c10_15m, use_jigsaw=True, jigsaw_grid_size=2)
HPARAMS_REGISTRY['c10_15m_js'] = c10_15m_js

c10_15m_lr = Hyperparams(c10_15m, aug=True)
HPARAMS_REGISTRY['c10_15m_lr'] = c10_15m_lr

c10_15m_ra_n2_m3 = Hyperparams(
    c10_15m,
    rand_augment=True,
    rand_augment_conditioning=True,
    rand_augment_n=2,
    rand_augment_m=3,
)
HPARAMS_REGISTRY['c10_15m_ra_n2_m3'] = c10_15m_ra_n2_m3

c10_15m_ra_n1_m2 = Hyperparams(
    c10_15m,
    rand_augment=True,
    rand_augment_conditioning=True,
    rand_augment_n=1,
    rand_augment_m=2,
)
HPARAMS_REGISTRY['c10_15m_ra_n1_m2'] = c10_15m_ra_n1_m2

# Starts as a plain copy of c10_15m; its overrides and registration follow
# directly after this statement.
c10_15m_i32_nocond = Hyperparams(c10_15m)
# Overrides for c10_15m_i32_nocond, followed by its registration.
c10_15m_i32_nocond.update(
    dataset='imagenet32cifar',
    use_imagenet_fraction=1.0,
    eval_after_n_examples=48000,
    use_dataset_conditioning=True,
    use_unconditional_augmentation=True,
)
HPARAMS_REGISTRY['c10_15m_i32_nocond'] = c10_15m_i32_nocond

# Same as the set above but without use_unconditional_augmentation.
c10_15m_i32_cond = Hyperparams(
    c10_15m,
    dataset='imagenet32cifar',
    use_imagenet_fraction=1.0,
    eval_after_n_examples=48000,
    use_dataset_conditioning=True,
)
HPARAMS_REGISTRY['c10_15m_i32_cond'] = c10_15m_i32_cond

# The two "ss" sets keep c10_15m's dataset and add an auxiliary dataset.
c10_15m_ss_i32_nocond = Hyperparams(
    c10_15m,
    auxiliary_dataset='imagenet32',
    auxiliary_dataset_fraction=0.5,
    use_dataset_conditioning=True,
    use_unconditional_augmentation=True,
)
HPARAMS_REGISTRY['c10_15m_ss_i32_nocond'] = c10_15m_ss_i32_nocond

c10_15m_ss_i32_cond = Hyperparams(
    c10_15m,
    auxiliary_dataset='imagenet32',
    auxiliary_dataset_fraction=0.5,
    use_dataset_conditioning=True,
)
HPARAMS_REGISTRY['c10_15m_ss_i32_cond'] = c10_15m_ss_i32_cond

# Randomly-determined-order sets: one on the dense base, the rest on c10_15m
# differing only in the seed.
c10_15m_dense_rd = Hyperparams(
    c10_15m_dense,
    use_randomly_determined_order=True,
    randomly_determined_order_num_perms=3,
    randomly_determined_order_seed=42,
)
HPARAMS_REGISTRY['c10_15m_dense_rd'] = c10_15m_dense_rd

c10_15m_rd = Hyperparams(
    c10_15m,
    use_randomly_determined_order=True,
    randomly_determined_order_num_perms=3,
    randomly_determined_order_seed=42,
)
HPARAMS_REGISTRY['c10_15m_rd'] = c10_15m_rd

# The seed and the registration of this set follow directly after this
# statement.
c10_15m_rd_s314 = Hyperparams(
    c10_15m,
    use_randomly_determined_order=True,
    randomly_determined_order_num_perms=3,
)
# Seed for c10_15m_rd_s314, followed by its registration.
c10_15m_rd_s314.update(randomly_determined_order_seed=314)
HPARAMS_REGISTRY['c10_15m_rd_s314'] = c10_15m_rd_s314

c10_15m_rd_s2718 = Hyperparams(
    c10_15m,
    use_randomly_determined_order=True,
    randomly_determined_order_num_perms=3,
    randomly_determined_order_seed=2718,
)
HPARAMS_REGISTRY['c10_15m_rd_s2718'] = c10_15m_rd_s2718

c10_15m_rd_s1618 = Hyperparams(
    c10_15m,
    use_randomly_determined_order=True,
    randomly_determined_order_num_perms=3,
    randomly_determined_order_seed=1618,
)
HPARAMS_REGISTRY['c10_15m_rd_s1618'] = c10_15m_rd_s1618

# imagenet64_8gpu copies c10_sparse and overrides model size, dataset,
# context length and several weight-scale values.
imagenet64_8gpu = Hyperparams(
    c10_sparse,
    n_batch=16,
    n_embd=512,
    n_layer=28,
    n_head=4,
    dataset='imagenet64',
    blocksize=64,
    local_attn_ctx=128,
    lr=0.00025,
    n_ctx=8192,
    resid_pdrop=0.01,
    embd_pdrop=0.01,
    total_epochs=50,
    mlp_w1=0.125,
    qk_w=0.125,
    v_w=0.125,
    post_w=0.125,
    mlp_w2=0.5,
    mlp_multiple=4.0,
    qk_ratio=1.0,
)
HPARAMS_REGISTRY['imagenet64-8gpu'] = imagenet64_8gpu

# First part of c10_150m_baseline; its remaining fields and registration
# follow directly after this statement.
c10_150m_baseline = Hyperparams(
    imagenet64_8gpu,
    blocksize=32,
    local_attn_ctx=96,
    n_batch=2,
    lr=0.00015,
    merge_layer_allreduce=4,
    n_layer=48,
    resid_pdrop=0.005,
    pos_embd_std=0.01,
    w_embd_std=0.01,
    dynamic_loss_scaling=True,
    embd_pdrop=0.0,
    mlp_w2=0.125,
    n_ctx=3072,
    n_head=16,
)
# Remaining fields of c10_150m_baseline, followed by its registration.
c10_150m_baseline.update(
    attention_layers='b,bT,b,b',
    dataset='cifar10',
    total_epochs=10000,
    test_size=2000,
    datapoints=48000,
)
HPARAMS_REGISTRY['c10_150m_baseline'] = c10_150m_baseline

# The pgd sets differ only in linf_pgd_epsilon and linf_pgd_n.
c10_150m_pgd1 = Hyperparams(
    c10_150m_baseline,
    use_linf_pgd=True,
    linf_pgd_epsilon=1.0,
    linf_pgd_n=1,
    linf_pgd_a=1.0,
)
HPARAMS_REGISTRY['c10_150m_pgd1'] = c10_150m_pgd1

c10_150m_pgd3 = Hyperparams(
    c10_150m_baseline,
    use_linf_pgd=True,
    linf_pgd_epsilon=2.0,
    linf_pgd_n=3,
    linf_pgd_a=1.0,
)
HPARAMS_REGISTRY['c10_150m_pgd3'] = c10_150m_pgd3

c10_150m_pgd4 = Hyperparams(
    c10_150m_baseline,
    use_linf_pgd=True,
    linf_pgd_epsilon=3.0,
    linf_pgd_n=4,
    linf_pgd_a=1.0,
)
HPARAMS_REGISTRY['c10_150m_pgd4'] = c10_150m_pgd4

c10_150m_pgd5 = Hyperparams(
    c10_150m_baseline,
    use_linf_pgd=True,
    linf_pgd_epsilon=4.0,
    linf_pgd_n=5,
    linf_pgd_a=1.0,
)
HPARAMS_REGISTRY['c10_150m_pgd5'] = c10_150m_pgd5

c10_150m_rot = Hyperparams(c10_150m_baseline, use_rotation=True)
HPARAMS_REGISTRY['c10_150m_rot'] = c10_150m_rot

c10_150m_tr = Hyperparams(c10_150m_baseline, use_transposition=True)
HPARAMS_REGISTRY['c10_150m_tr'] = c10_150m_tr

c10_150m_js = Hyperparams(c10_150m_baseline, use_jigsaw=True, jigsaw_grid_size=2)
HPARAMS_REGISTRY['c10_150m_js'] = c10_150m_js

c10_150m_color = Hyperparams(c10_150m_baseline, use_color=True)
HPARAMS_REGISTRY['c10_150m_color'] = c10_150m_color

# c10_150m_tr is built a second time with the same contents as above; the
# statement that follows registers it under the same key again.
c10_150m_tr = Hyperparams(c10_150m_baseline, use_transposition=True)
# Registers the c10_150m_tr object built just before this point.
HPARAMS_REGISTRY['c10_150m_tr'] = c10_150m_tr

# Underscore-keyed variants of c10_150m_baseline.
c10_150m_rot_tr = Hyperparams(
    c10_150m_baseline,
    use_rotation=True,
    use_transposition=True,
)
HPARAMS_REGISTRY['c10_150m_rot_tr'] = c10_150m_rot_tr

c10_150m_rot_js = Hyperparams(
    c10_150m_baseline,
    use_rotation=True,
    use_jigsaw=True,
    jigsaw_grid_size=2,
)
HPARAMS_REGISTRY['c10_150m_rot_js'] = c10_150m_rot_js

c10_150m_rot_js_tr = Hyperparams(
    c10_150m_baseline,
    use_rotation=True,
    use_jigsaw=True,
    jigsaw_grid_size=2,
    use_transposition=True,
)
HPARAMS_REGISTRY['c10_150m_rot_js_tr'] = c10_150m_rot_js_tr

c10_150m_rot_js_tr_c = Hyperparams(
    c10_150m_baseline,
    use_rotation=True,
    use_jigsaw=True,
    jigsaw_grid_size=2,
    use_transposition=True,
    use_color=True,
)
HPARAMS_REGISTRY['c10_150m_rot_js_tr_c'] = c10_150m_rot_js_tr_c

c10_150m_imagenet = Hyperparams(
    c10_150m_baseline,
    dataset='imagenet32cifar',
    use_imagenet_fraction=1.0,
    eval_after_n_examples=48000,
    use_dataset_conditioning=True,
)
HPARAMS_REGISTRY['c10_150m_imagenet'] = c10_150m_imagenet

c10_150m_aug = Hyperparams(c10_150m_baseline, aug=True, resid_pdrop=0.40)
HPARAMS_REGISTRY['c10_150m_aug'] = c10_150m_aug

# The two randaugment sets differ only in rand_augment_conditioning.
c10_150m_randaugment_dataaug = Hyperparams(
    c10_150m_baseline,
    rand_augment=True,
    rand_augment_n=2,
    rand_augment_m=3,
)
HPARAMS_REGISTRY['c10_150m_randaugment_dataaug'] = c10_150m_randaugment_dataaug

c10_150m_randaugment_distaug = Hyperparams(
    c10_150m_baseline,
    rand_augment=True,
    rand_augment_conditioning=True,
    rand_augment_n=2,
    rand_augment_m=3,
)
HPARAMS_REGISTRY['c10_150m_randaugment_distaug'] = c10_150m_randaugment_distaug

# Hyphen-keyed variants.  c10_150m_rot and c10_150m_rot_tr rebind module-level
# names that were already used for the underscore-keyed entries; the registry
# keeps the earlier objects under their own keys.
c10_150m_rot = Hyperparams(c10_150m_baseline, use_rotation=True)
HPARAMS_REGISTRY['c10-150m-rot'] = c10_150m_rot

c10_150m_rot_c_tr = Hyperparams(
    c10_150m_baseline,
    use_rotation=True,
    use_color=True,
    use_transposition=True,
)
HPARAMS_REGISTRY['c10-150m-rot-c-tr'] = c10_150m_rot_c_tr

c10_150m_rot_c_tr_js = Hyperparams(
    c10_150m_baseline,
    use_rotation=True,
    use_color=True,
    use_transposition=True,
    use_jigsaw=True,
    jigsaw_grid_size=2,
)
HPARAMS_REGISTRY['c10-150m-rot-c-tr-js'] = c10_150m_rot_c_tr_js

c10_150m_rot_tr_js = Hyperparams(
    c10_150m_baseline,
    use_rotation=True,
    use_transposition=True,
    use_jigsaw=True,
    jigsaw_grid_size=2,
)
HPARAMS_REGISTRY['c10-150m-rot-tr-js'] = c10_150m_rot_tr_js

c10_150m_rot_c = Hyperparams(c10_150m_baseline, use_rotation=True, use_color=True)
HPARAMS_REGISTRY['c10-150m-rot-c'] = c10_150m_rot_c

c10_150m_rot_tr = Hyperparams(
    c10_150m_baseline,
    use_rotation=True,
    use_transposition=True,
)
HPARAMS_REGISTRY['c10-150m-rot-tr'] = c10_150m_rot_tr

# First fields of c10_150m_rot_tr_ra_n2_m3; the rand_augment fields and the
# registration follow directly after this statement.
c10_150m_rot_tr_ra_n2_m3 = Hyperparams(
    c10_150m_baseline,
    use_rotation=True,
    use_transposition=True,
)
# rand_augment fields of c10_150m_rot_tr_ra_n2_m3, followed by its
# registration.
c10_150m_rot_tr_ra_n2_m3.update(
    rand_augment=True,
    rand_augment_n=2,
    rand_augment_m=3,
    rand_augment_conditioning=True,
    rand_augment_rate=0.5,
)
HPARAMS_REGISTRY['c10-150m-rot-tr-ra-n2-m3'] = c10_150m_rot_tr_ra_n2_m3

c10_150m_rot_tr_ra_n1_m2 = Hyperparams(
    c10_150m_baseline,
    use_rotation=True,
    use_transposition=True,
    rand_augment=True,
    rand_augment_n=1,
    rand_augment_m=2,
    rand_augment_conditioning=True,
    rand_augment_rate=0.5,
)
HPARAMS_REGISTRY['c10-150m-rot-tr-ra-n1-m2'] = c10_150m_rot_tr_ra_n1_m2

c10_150m_rot_c_tr_js_ra_n1_m2 = Hyperparams(
    c10_150m_baseline,
    use_rotation=True,
    use_color=True,
    use_transposition=True,
    use_jigsaw=True,
    jigsaw_grid_size=2,
    rand_augment=True,
    rand_augment_n=1,
    rand_augment_m=2,
    rand_augment_conditioning=True,
    rand_augment_rate=0.5,
)
HPARAMS_REGISTRY['c10-150m-rot-c-tr-js-ra-n1-m2'] = c10_150m_rot_c_tr_js_ra_n1_m2

c10_150m_c_tr = Hyperparams(c10_150m_baseline, use_color=True, use_transposition=True)
HPARAMS_REGISTRY['c10-150m-c-tr'] = c10_150m_c_tr

# c10_10m_baseline: c10_150m_baseline with a different n_embd and n_batch.
c10_10m_baseline = Hyperparams(c10_150m_baseline, n_embd=128, n_batch=16)
HPARAMS_REGISTRY['c10_10m_baseline'] = c10_10m_baseline

# Registered by the statement that follows directly after this one.
c10_10m_rot = Hyperparams(c10_10m_baseline, use_rotation=True)
# Registers the c10_10m_rot object built just before this point.
HPARAMS_REGISTRY['c10_10m_rot'] = c10_10m_rot

# c10_2m_baseline: c10_150m_baseline with a different n_embd, n_batch, n_head.
c10_2m_baseline = Hyperparams(c10_150m_baseline, n_embd=64, n_batch=16, n_head=8)
HPARAMS_REGISTRY['c10_2m_baseline'] = c10_2m_baseline

c10_2m_rot = Hyperparams(c10_2m_baseline, use_rotation=True)
HPARAMS_REGISTRY['c10_2m_rot'] = c10_2m_rot

# i64_150m_32gpu copies imagenet64_8gpu and applies its own overrides;
# dropout_broadcast_dims is stored explicitly as None.
i64_150m_32gpu = Hyperparams(
    imagenet64_8gpu,
    n_batch=4,
    lr=0.00015,
    l2_loss=0.001,
    total_epochs=10000,
    merge_layer_allreduce=4,
    n_layer=48,
    resid_pdrop=0.005,
    blocksize=32,
    pos_embd_std=0.01,
    w_embd_std=0.01,
    dropout_broadcast_dims=None,
    dynamic_loss_scaling=True,
    embd_pdrop=0.0,
    mlp_w2=0.125,
    n_ctx=12288,
    n_head=16,
    attention_layers='b,bT,b,b',
)
HPARAMS_REGISTRY['i64_150m_32gpu'] = i64_150m_32gpu

# Note the registry keys of the next two sets carry an extra '_32gpu' suffix
# compared with the variable names.
i64_150m_32gpu_rot = Hyperparams(i64_150m_32gpu, use_rotation=True)
HPARAMS_REGISTRY['i64_150m_32gpu_rot_32gpu'] = i64_150m_32gpu_rot

i64_150m_32gpu_rot_tr = Hyperparams(
    i64_150m_32gpu,
    use_rotation=True,
    use_transposition=True,
)
HPARAMS_REGISTRY['i64_150m_32gpu_rot_tr_32gpu'] = i64_150m_32gpu_rot_tr

# i64_300m_64gpu: i64_150m_32gpu with a different n_layer and n_batch.
i64_300m_64gpu = Hyperparams(i64_150m_32gpu, n_layer=96, n_batch=2)
HPARAMS_REGISTRY['i64_300m_64gpu'] = i64_300m_64gpu

i64_300m_64gpu_rot = Hyperparams(i64_300m_64gpu, use_rotation=True)
HPARAMS_REGISTRY['i64_300m_64gpu_rot'] = i64_300m_64gpu_rot

# use_transposition and the registration of this set follow directly after
# this statement.
i64_300m_64gpu_rot_tr = Hyperparams(i64_300m_64gpu, use_rotation=True)
# Last field of i64_300m_64gpu_rot_tr, followed by its registration.
i64_300m_64gpu_rot_tr.use_transposition = True
HPARAMS_REGISTRY['i64_300m_64gpu_rot_tr'] = i64_300m_64gpu_rot_tr

i64_300m_64gpu_rot_c_tr = Hyperparams()
i64_300m_64gpu_rot_c_tr.update(i64_300m_64gpu)
i64_300m_64gpu_rot_c_tr.use_rotation = True
i64_300m_64gpu_rot_c_tr.use_color = True
i64_300m_64gpu_rot_c_tr.use_transposition = True
HPARAMS_REGISTRY['i64_300m_64gpu_rot_c_tr'] = i64_300m_64gpu_rot_c_tr


def parse_args_and_update_hparams(H, parser, s=None):
    """Parse the command line and merge the result into ``H`` in place.

    The arguments are parsed twice.  The first pass only finds out which
    hparam sets were requested via ``--hparam_sets``.  Each requested set is
    then installed as parser defaults, in the order given, so later sets
    override earlier ones.  The second pass produces the final values, so
    options given explicitly still override the sets.

    Args:
        H: mapping updated in place with every parsed option.
        parser: an ``argparse.ArgumentParser`` that already has the options
            registered (see ``add_arguments``).  Its defaults are modified.
        s: optional list of argument strings; ``None`` means ``sys.argv``.

    Returns:
        None.  ``H`` is updated in place.

    Raises:
        KeyError: a requested hparam set is not in ``HPARAMS_REGISTRY``.
        ValueError: an hparam set contains a key the parser does not define.
    """
    args = parser.parse_args(s)
    valid_args = set(vars(args))
    hparam_sets = [x for x in args.hparam_sets.split(',') if x]
    for hp_set in hparam_sets:
        hps = HPARAMS_REGISTRY[hp_set]
        # Reject typos in a set before they silently become unused defaults.
        for k in hps:
            if k not in valid_args:
                raise ValueError(f"{k} not in default args")
        parser.set_defaults(**hps)
    # Re-parse the SAME argument list.  Previously this call took no argument
    # and so fell back to sys.argv even when the caller had supplied ``s``.
    H.update(vars(parser.parse_args(s)))
    # H is updated in place, so return nothing.


def add_arguments(parser):
    """Register every command-line option on ``parser`` and return it.

    Boolean options come as ``store_true`` flags; several have a ``--no_*``
    counterpart writing ``False`` to the same destination so that a default
    installed by an hparam set can be switched off from the command line.

    NOTE(review): the ``c10_150m_pgd*`` hparam sets in this file set
    ``linf_pgd_epsilon``, ``linf_pgd_n`` and ``linf_pgd_a``, but no such
    options are registered here, so selecting those sets makes
    ``parse_args_and_update_hparams`` raise ``ValueError``.  Confirm the
    intended types/defaults before adding them.
    """
    parser.add_argument('--out_dir', type=str, default=DEFAULT_OUT_DIR)
    parser.add_argument('--desc', type=str, default='test')
    parser.add_argument('--print_params', action="store_true")
    # Comma-separated names of HPARAMS_REGISTRY entries.
    parser.add_argument('--hparam_sets', '--hps', type=str, default='')

    # dataset params
    parser.add_argument('--dataset', type=str, default="cifar10")
    parser.add_argument('--auxiliary_dataset', type=str, default=None)
    parser.add_argument('--auxiliary_dataset_fraction', type=float, default=0.5)
    parser.add_argument('--auxiliary_dataset_subset_size', type=int, default=None)
    parser.add_argument('--auxiliary_dataset_seed', type=int, default=42)

    # Training params
    parser.add_argument('--n_batch', type=int, default=128)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)

    # Transformer architectural parameters
    parser.add_argument('--n_embd', type=int, default=512)
    parser.add_argument('--n_ctx', type=int, default=256)
    parser.add_argument('--n_head', type=int, default=8)
    parser.add_argument('--n_layer', type=int, default=6)
    parser.add_argument('--dropout_broadcast_dims', type=str, default=None)
    parser.add_argument('--embd_pdrop', type=float, default=0.1)
    parser.add_argument('--resid_pdrop', type=float, default=0.1)
    parser.add_argument('--mlp_multiple', type=float, default=4.0)
    parser.add_argument('--qk_ratio', type=float, default=1.0)
    parser.add_argument('--attention_layers', type=str, default='a')
    parser.add_argument('--local_attn_ctx', type=int, default=64)
    parser.add_argument('--pos_embd_std', type=float, default=0.007)
    parser.add_argument('--w_embd_std', type=float, default=0.013)
    parser.add_argument('--mlp_w1', type=float, default=0.125)
    parser.add_argument('--mlp_w2', type=float, default=0.125)
    parser.add_argument('--qk_w', type=float, default=0.125)
    parser.add_argument('--v_w', type=float, default=0.125)
    parser.add_argument('--post_w', type=float, default=0.125)
    parser.add_argument('--logits_w', type=float, default=0.125)
    parser.add_argument('--preconv_w', type=float, default=0.125)

    # rand augment params
    # https://arxiv.org/pdf/1909.13719.pdf
    parser.add_argument('--rand_augment', action="store_true")
    parser.add_argument('--rand_augment_conditioning', action="store_true")
    parser.add_argument('--rand_augment_rate', type=float, default=0.95)
    parser.add_argument('--rand_augment_n', type=int, default=1)  # Number of sequential perturbations -- range [1, 3]
    parser.add_argument('--rand_augment_m', type=int, default=2)  # Magnitude of perturbations -- range [2, 30]

    # Distr Aug Params
    parser.add_argument('--aug', action='store_true')
    parser.add_argument('--permute_embeddings', dest='permute_embeddings', action="store_true")
    parser.add_argument('--no_permute_embeddings', dest='permute_embeddings', action="store_false")
    # Called after both flags are registered, making True the default.
    parser.set_defaults(permute_embeddings=True)
    parser.add_argument('--use_imagenet_fraction', type=float, default=1.0)
    parser.add_argument('--unaugmented_data_rate', type=float, default=None)
    parser.add_argument('--use_rotation', action="store_true")
    parser.add_argument('--use_dataset_conditioning', action="store_true")
    parser.add_argument('--no_dataset_conditioning', action="store_false", dest="use_dataset_conditioning")
    parser.add_argument('--use_color', action="store_true")
    parser.add_argument('--use_transposition', action="store_true")
    parser.add_argument('--use_randomly_determined_order', action="store_true")
    parser.add_argument('--randomly_determined_order_num_perms', type=int, default=3)
    parser.add_argument('--randomly_determined_order_seed', type=int, default=42)
    parser.add_argument('--randomly_determined_order_use_lookahead', action="store_true")
    parser.add_argument('--use_reverse', action="store_true")
    parser.add_argument('--use_linf_pgd', action="store_true")
    parser.add_argument('--use_jigsaw', action="store_true")
    parser.add_argument('--jigsaw_grid_size', type=int, default=2)
    parser.add_argument('--use_unconditional_augmentation', action='store_true')
    parser.add_argument('--datapoints', type=int, default=None)
    parser.add_argument('--test_size', type=int, default=None)

    # Training params
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--aug_seed', type=int, default=314)
    parser.add_argument('--optimizer', type=str, default='bs_adam')
    parser.add_argument('--activation', type=str, default='quick_gelu')
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--l2_loss', type=float, default=0.0)
    parser.add_argument('--recompute', action="store_true", dest="recompute")
    parser.add_argument('--no_recompute', action="store_false", dest="recompute")
    parser.add_argument('--float16', action="store_true")
    parser.add_argument('--no_float16', action="store_false", dest='float16')
    parser.add_argument('--blocksparse_op', action="store_true")
    parser.add_argument('--no_blocksparse_op', action="store_false", dest="blocksparse_op")
    parser.add_argument('--blocksize', type=int, default=64)
    parser.add_argument('--fp16_allreduce', action="store_true")
    parser.add_argument('--no_fp16_allreduce', action="store_false", dest='fp16_allreduce')
    parser.add_argument('--merge_layer_allreduce', default=0, type=int)
    parser.add_argument('--fp32_gains_biases', action="store_true")
    parser.add_argument('--fp16_loss_scale', type=float, default=2.0**16)
    parser.add_argument('--min_loss_scale', type=float, default=2.0**10)
    parser.add_argument('--fp16_loss_freq', type=int, default=1000)
    parser.add_argument('--fp16_mean_var', action='store_true')
    parser.add_argument('--no_fp16_mean_var', action='store_false', dest='fp16_mean_var')
    parser.add_argument('--dynamic_loss_scaling', action='store_true')
    parser.add_argument('--no_dynamic_loss_scaling', action='store_false', dest='dynamic_loss_scaling')
    parser.add_argument('--lr', type=float, default=0.0005)
    parser.add_argument('--lr_offset', type=int, default=0)
    parser.add_argument('--decay_lr_linearly', action="store_true")
    parser.add_argument('--no_vocab_rounding', action="store_true")
    parser.add_argument('--disable_ema_vars', action="store_true")
    parser.add_argument('--total_epochs', type=int, default=100)
    parser.add_argument('--exit_after_n_epochs', type=int, default=None)
    parser.add_argument('--warmup_iters', type=int, default=5000)
    parser.add_argument('--weights_beta', type=float, default=0.999)
    parser.add_argument('--iters_per_log', type=int, default=500)
    parser.add_argument('--aug_eval', type=str, default=None)
    parser.add_argument('--aug_eval_n_examples', type=int, default=None)
    parser.add_argument('--eval_after_n_examples', type=int, default=None)
    parser.add_argument('--epochs_per_save', type=int, default=1)
    parser.add_argument('--epochs_per_backup', type=int, default=1)
    parser.add_argument('--epochs_per_eval', type=int, default=1)

    # eval stuff
    parser.add_argument('--skip_initial_evals', action="store_true")
    parser.add_argument('--eval_and_exit', action="store_true")
    parser.add_argument('--no_skip_initial_evals', action="store_false", dest='skip_initial_evals')
    parser.add_argument('--eval_test', action="store_true")
    parser.add_argument('--eval_start_idx', type=int, default=0)
    parser.add_argument('--eval_n_examples', type=int, default=100000)

    # Generating unconditional samples
    parser.add_argument('--sample_batch', type=int, default=4)
    parser.add_argument('--samples_to_generate', type=int, default=4)
    parser.add_argument('--sample_grid_dim', type=int, default=4)
    parser.add_argument('--sample_and_exit', action="store_true")
    parser.add_argument('--sample_during_eval', action="store_true")
    parser.add_argument('--sample_f16', action="store_true")
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--no_sample_during_eval', action="store_false", dest='sample_during_eval')

    # Restoring jobs
    parser.add_argument('--restore_path', type=str, default='')
    return parser