jax-projects/model_parallel/run_clm_mp.py (5 lines):
- line 450: # TODO: weights should be initialized in pjitted fun, this won't work for REALLY large models
- line 451: # TODO: when loading from pre-trained model we need to make sure the vocab is divisible by num_partitions
- line 486: # TODO: optax returns different state for different optimizers, how can we handle this generically ?
- line 506: # TODO: allow loading weights on CPU in pre-trained model
- line 533: # TODO: try to use TrainState instead of passing params and opt_state individually (see the TrainState sketch after this list)

rag/lightning_base.py (3 lines):
- line 76: # TODO: move to self.save_hyperparameters() (see the save_hyperparameters sketch after this list)
- line 166: num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
- line 379: # TODO: remove with PyTorch 1.6 since pl uses native amp

seq2seq-distillation/lightning_base.py (3 lines):
- line 76: # TODO: move to self.save_hyperparameters()
- line 166: num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
- line 369: # TODO: remove with PyTorch 1.6 since pl uses native amp

pplm/run_pplm.py (3 lines):
- line 125: # TODO fix this comment (SUMANTH)
- line 159: # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
- line 175: # TODO why we need to do this assignment and not just using unpert_past? (Sumanth)

rag-end2end-retriever/lightning_base.py (2 lines):
- line 76: # TODO: move to self.save_hyperparameters()
- line 168: num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores

seq2seq-distillation/_test_bash_script.py (2 lines):
- line 117: # TODO: turn on args.do_predict when PL bug fixed.
- line 197: # TODO: turn on args.do_predict when PL bug fixed.

longform-qa/eli5_app.py (1 line):
- line 57: wiki40b_gpu_index_flat.add(wiki40b_passage_reps) # TODO fix for larger GPU

performer/modeling_flax_performer.py (1 line):
- line 213: # TODO: Add ACT2FN reference to change activation function

information-gain-filtration/igf/igf.py (1 line):
- line 302: # TODO in original code this is written as number of actual batches seen

seq2seq-distillation/_test_seq2seq_examples.py (1 line):
- line 192: # TODO: understand why this breaks
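
The TODO at run_clm_mp.py line 533 asks for a single TrainState object instead of threading params and opt_state separately. Below is a minimal sketch of that idea, assuming a standard Flax + optax setup; the `model` argument, the learning rate, and the helper names are placeholders, not values taken from the script.

```python
# Minimal sketch, assuming a typical Flax + optax setup; `model` and the
# learning rate are illustrative placeholders, not run_clm_mp.py's values.
import optax
from flax.training import train_state


def create_train_state(model, params, learning_rate=3e-4):
    # TrainState bundles apply_fn, params, and the optimizer state into one
    # pytree, so a pjitted train step only has to carry a single object.
    tx = optax.adamw(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)


def apply_update(state, grads):
    # apply_gradients updates both the params and the internal optimizer state,
    # which also sidesteps the per-optimizer state handling flagged at line 486.
    return state.apply_gradients(grads=grads)
```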
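
The recurring `# TODO: move to self.save_hyperparameters()` in the lightning_base.py files refers to PyTorch Lightning's built-in hyperparameter bookkeeping. A hedged sketch of what that change could look like is below; the class name and constructor argument are illustrative, not the actual BaseTransformer signature from those scripts.

```python
# Illustrative sketch only; the class and argument names are placeholders,
# not copied from lightning_base.py.
import argparse

import pytorch_lightning as pl


class ExampleModule(pl.LightningModule):
    def __init__(self, hparams: argparse.Namespace):
        super().__init__()
        # Instead of assigning hparams manually, let Lightning record the init
        # arguments so they are checkpointed and restored automatically.
        self.save_hyperparameters(hparams)
        # Values remain accessible as before, e.g. self.hparams.gpus.
```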