in sparse_autoencoder/train.py [0:0]
def main():
    """Entry point: set up comms, build the sharded autoencoder, and run training.

    Work is split two ways: the dictionary directions are sharded across
    op shards, and the global batch is split across data-parallel replicas.
    """
    cfg = Config()
    comms = make_torch_comms(n_op_shards=cfg.n_op_shards, n_replicas=cfg.n_replicas)

    ## dataloading is left as an exercise for the reader
    acts_iter = ...
    stats_acts_sample = ...

    # Per-rank sizes: directions divided over op shards, batch over replicas.
    n_dirs_local = cfg.n_dirs // cfg.n_op_shards
    bs_local = cfg.bs // cfg.n_replicas

    ae = FastAutoencoder(
        n_dirs_local=n_dirs_local,
        d_model=cfg.d_model,
        k=cfg.k,
        auxk=cfg.auxk,
        # dead_toks_threshold is in tokens; convert to optimizer steps.
        dead_steps_threshold=cfg.dead_toks_threshold // cfg.bs,
        comms=comms,
    )
    ae.cuda()
    init_from_data_(ae, stats_acts_sample, comms)
    # IMPORTANT: make sure all DP ranks have the same params
    comms.init_broadcast_(ae)

    # Normalizer for reported MSE: reciprocal of the mean squared deviation
    # of the sample activations from their per-dimension mean.
    sample_f = stats_acts_sample.float()
    mse_scale = 1 / ((sample_f.mean(dim=0) - sample_f) ** 2).mean()
    comms.all_broadcast(mse_scale)
    mse_scale = mse_scale.item()

    logger = Logger(
        project=cfg.wandb_project,
        name=cfg.wandb_name,
        dummy=cfg.wandb_project is None,
    )

    # Loss = scaled reconstruction MSE + AuxK term reviving dead latents.
    # NOTE: logkv returns the value it logs, so the two terms can be summed.
    # The first lambda parameter shadows the outer `ae` by design.
    def loss_fn(model, acts_batch, recons, info, log):
        recon_term = log.logkv("train_recons", mse_scale * mse(recons, acts_batch))
        auxk_target = acts_batch - recons.detach() + model.pre_bias.detach()
        auxk_recons = model.decode_sparse(info["auxk_inds"], info["auxk_vals"])
        auxk_term = log.logkv(
            "train_maxk_recons",
            # nan_to_num guards the steps where no dead latents exist.
            cfg.auxk_coef * normalized_mse(auxk_recons, auxk_target).nan_to_num(0),
        )
        return recon_term + auxk_term

    training_loop_(
        ae,
        batch_tensors(acts_iter, bs_local, drop_last=True),
        loss_fn,
        lr=cfg.lr,
        eps=cfg.eps,
        clip_grad=cfg.clip_grad,
        ema_multiplier=cfg.ema_multiplier,
        logger=logger,
        comms=comms,
    )