configs/imagenet_text2image_movq_conv.yaml (99 lines of code) (raw):
# wandb settings.
wandb:
  # Explicitly null: no entity is configured in this file.
  entity: null
# Experiment bookkeeping: naming, output location, example counts, and periodic intervals.
experiment:
  project: 'muse'
  name: 'imagenet-movq-conv'
  output_dir: 'imagenet-movq-conv'

  # Total number of ImageNet examples.
  max_train_examples: 1281167
  max_eval_examples: 12800

  # Periodic-action intervals.
  # NOTE(review): presumably counted in training steps -- confirm against the trainer.
  save_every: 1000
  eval_every: 1000
  generate_every: 1000
  log_every: 30
  log_grad_norm_every: 500

  # Resume behaviour.
  resume_from_checkpoint: false
  resume_lr_scheduler: true
# Model definition: pretrained VQ model and text encoder checkpoints, plus the
# transformer hyperparameters.
model:
  vq_model:
    type: 'movq'
    pretrained: 'openMUSE/movq-lion-high-res-f8-16384'

  text_encoder:
    type: 't5'
    pretrained: 'openMUSE/t5-v1_1-large-enc'

  transformer:
    # 16384 + 1 = 16385 (VQ codebook entries + <mask>); 16400 is used for even
    # division by 8.
    vocab_size: 16400
    max_position_embeddings: 256
    hidden_size: 1024
    num_hidden_layers: 24
    num_attention_heads: 16
    intermediate_size: 4096
    add_cross_attention: true
    encoder_hidden_size: 1024
    project_encoder_hidden_states: false
    codebook_size: 16384
    num_vq_tokens: 1024
    initializer_range: 0.02
    norm_type: 'rmsnorm'
    # Written with an explicit decimal point: YAML 1.1 resolvers (e.g. plain
    # PyYAML) read a bare `1e-6` as a string rather than a float.
    layer_norm_eps: 1.0e-6
    use_normformer: false
    use_encoder_layernorm: true
    use_mlm_layer: true
    use_mlm_layernorm: true
    use_bias: false
    hidden_dropout: 0.0
    attention_dropout: 0.0
    use_codebook_size_for_output: true
    use_conv_in_out: true
    patch_size: 2

  gradient_checkpointing: true
  enable_xformers_memory_efficient_attention: true
# Dataset definition: shard locations, loader parameters, and preprocessing.
dataset:
  type: 'classification'
  params:
    # Each URL is a `pipe:` command that copies a shard from S3 to stdout
    # (`aws s3 cp <shard> -`); `{a..b}` is a brace range covering several tar shards.
    # NOTE(review): brace expansion is presumably done by the data loader -- confirm.
    train_shards_path_or_url: 'pipe:aws s3 cp s3://muse-datasets/imagenet-wds/imagenet-train-{000000..000320}.tar -'
    eval_shards_path_or_url: 'pipe:aws s3 cp s3://muse-datasets/imagenet-wds/imagenet-val-{000000..000012}.tar -'
    # Absolute, machine-specific path; adjust for the environment the job runs in.
    imagenet_class_mapping_path: '/fsx/suraj/data/imagenet-class-mapping.json'
    # Plain key name. A dotted name such as `dataset.params.validation_prompts_file`
    # at this position is one literal key, not a nested path, so it would never be
    # found as `validation_prompts_file` under `dataset.params`.
    validation_prompts_file: null
    # Interpolation reference: takes its value from `batch_size` under `training`.
    batch_size: ${training.batch_size}
    shuffle_buffer_size: 1000
    num_workers: 4
    resolution: 256
    pin_memory: true
    persistent_workers: true
  preprocessing:
    max_seq_length: 16
    resolution: 256
    center_crop: true
    random_flip: false
# Optimizer selection and hyperparameters.
optimizer:
  name: 'fused_adamw'
  # Default AdamW params.
  params:
    # Exponent floats carry an explicit decimal point: YAML 1.1 resolvers (e.g.
    # plain PyYAML) read a bare `1e-4` as a string rather than a float.
    learning_rate: 1.0e-4
    # Scale learning rate by total batch size.
    scale_lr: false
    beta1: 0.9
    beta2: 0.999
    weight_decay: 0.01
    epsilon: 1.0e-8
# Learning-rate schedule.
lr_scheduler:
  scheduler: 'constant_with_warmup'
  params:
    # Interpolation reference: takes its value from `optimizer.params.learning_rate`,
    # so the rate is defined in one place only.
    learning_rate: ${optimizer.params.learning_rate}
    warmup_steps: 1000
# Training-loop settings.
training:
  gradient_accumulation_steps: 2
  batch_size: 64
  # Kept quoted on purpose: a bare `no` is a boolean false under YAML 1.1.
  mixed_precision: 'no'
  enable_tf32: true
  use_ema: false
  seed: 9345104
  max_train_steps: 200000
  overfit_one_batch: false
  cond_dropout_prob: 0.1
  min_masking_rate: 0.0
  label_smoothing: 0.0
  # Explicitly null: no value is configured for the gradient-norm limit.
  max_grad_norm: null
  guidance_scale: 2.0
  generation_timesteps: 8

  # Related to VAE code sampling.
  use_soft_code_target: false
  use_stochastic_code: false
  soft_code_temp: 1.0