# sparse_autoencoder/train.py — training entry point: main()


def main():
    """Build, initialize, and train a sharded FastAutoencoder.

    Reads all hyperparameters from ``Config``, sets up op-shard /
    data-parallel torch comms, scales the reconstruction loss by the
    inverse variance of a stats sample, and hands everything to
    ``training_loop_`` with a combined MSE + AuxK loss function.
    """
    cfg = Config()
    comms = make_torch_comms(n_op_shards=cfg.n_op_shards, n_replicas=cfg.n_replicas)

    ## dataloading is left as an exercise for the reader
    acts_iter = ...
    stats_acts_sample = ...

    # Per-rank sizes: latent directions are split across op shards,
    # the batch across data-parallel replicas.
    dirs_per_shard = cfg.n_dirs // cfg.n_op_shards
    batch_per_replica = cfg.bs // cfg.n_replicas

    ae = FastAutoencoder(
        n_dirs_local=dirs_per_shard,
        d_model=cfg.d_model,
        k=cfg.k,
        auxk=cfg.auxk,
        dead_steps_threshold=cfg.dead_toks_threshold // cfg.bs,
        comms=comms,
    )
    ae.cuda()
    init_from_data_(ae, stats_acts_sample, comms)
    # IMPORTANT: make sure all DP ranks have the same params
    comms.init_broadcast_(ae)

    # Inverse variance of the stats sample, used to normalize the MSE so
    # the reconstruction term is comparable across layers/models.
    sample_f = stats_acts_sample.float()
    mse_scale = 1 / ((sample_f.mean(dim=0) - sample_f) ** 2).mean()
    comms.all_broadcast(mse_scale)
    mse_scale = mse_scale.item()

    logger = Logger(
        project=cfg.wandb_project,
        name=cfg.wandb_name,
        dummy=cfg.wandb_project is None,
    )

    def loss_fn(ae, flat_acts_train_batch, recons, info, logger):
        # logkv evidently returns the value it logs, so the sum of the two
        # logged terms below is the value handed back to training_loop_
        # (presumably used as the training loss — confirm in training_loop_).
        # MSE
        mse_term = logger.logkv(
            "train_recons", mse_scale * mse(recons, flat_acts_train_batch)
        )
        # AuxK: reconstruct the residual error from the aux top-k latents;
        # nan_to_num guards the normalized MSE when no latents are dead.
        auxk_term = logger.logkv(
            "train_maxk_recons",
            cfg.auxk_coef
            * normalized_mse(
                ae.decode_sparse(
                    info["auxk_inds"],
                    info["auxk_vals"],
                ),
                flat_acts_train_batch - recons.detach() + ae.pre_bias.detach(),
            ).nan_to_num(0),
        )
        return mse_term + auxk_term

    training_loop_(
        ae,
        batch_tensors(
            acts_iter,
            batch_per_replica,
            drop_last=True,
        ),
        loss_fn,
        lr=cfg.lr,
        eps=cfg.eps,
        clip_grad=cfg.clip_grad,
        ema_multiplier=cfg.ema_multiplier,
        logger=logger,
        comms=comms,
    )