benchmarks/fp8/torchao/ddp.py [63:78]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def train_baseline():
    # Seed with a fixed value before building the model and data pipeline.
    set_seed(42)
    # Helper defined outside this block; it hands back the five training
    # objects unpacked here for MODEL_NAME.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
    # Record the qualified names of the first and last `torch.nn.Linear`
    # modules in `named_modules()` iteration order. Both stay None when the
    # model contains no Linear module; with exactly one, both hold its name.
    first_linear = None
    last_linear = None
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if first_linear is None:
                first_linear = name
            # Overwritten on every Linear hit, so it ends up as the last one.
            last_linear = name
    # Pre-bind the two boundary names so the resulting callable can be passed
    # as a module filter below.
    # NOTE(review): presumably `filter_linear_layers` excludes the first/last
    # linear layers from conversion -- confirm against its definition.
    func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
    # Accelerator is created with default arguments; here it is only used to
    # obtain the target device.
    accelerator = Accelerator()
    device = accelerator.device
    # The model is moved to the device before the float8 conversion call.
    model.to(device)

    # The return value is not used.
    # NOTE(review): presumably the conversion mutates `model` in place --
    # confirm against the torchao documentation.
    convert_to_float8_training(model, module_filter_fn=func)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



benchmarks/fp8/torchao/fsdp.py [69:84]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def train_baseline():
    # Seed with a fixed value before building the model and data pipeline.
    set_seed(42)
    # Helper defined outside this block; it hands back the five training
    # objects unpacked here for MODEL_NAME.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
    # Record the qualified names of the first and last `torch.nn.Linear`
    # modules in `named_modules()` iteration order. Both stay None when the
    # model contains no Linear module; with exactly one, both hold its name.
    first_linear = None
    last_linear = None
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if first_linear is None:
                first_linear = name
            # Overwritten on every Linear hit, so it ends up as the last one.
            last_linear = name
    # Pre-bind the two boundary names so the resulting callable can be passed
    # as a module filter below.
    # NOTE(review): presumably `filter_linear_layers` excludes the first/last
    # linear layers from conversion -- confirm against its definition.
    func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
    # Accelerator is created with default arguments; here it is only used to
    # obtain the target device.
    accelerator = Accelerator()
    device = accelerator.device
    # The model is moved to the device before the float8 conversion call.
    model.to(device)

    # The return value is not used.
    # NOTE(review): presumably the conversion mutates `model` in place --
    # confirm against the torchao documentation.
    convert_to_float8_training(model, module_filter_fn=func)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



