benchmarks/fp8/torchao/distrib_deepspeed.py [171:188]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    for batch in train_dataloader:
        outputs = model(**batch)
        data.append(batch.to("cpu"))
        model_outputs.append(outputs.logits.to("cpu"))
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.destroy()
    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
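
For context, the excerpt above calls an `evaluate_model` helper that lives in the benchmark's shared utilities; the body below is a minimal sketch of what such a helper might look like, not the actual implementation, assuming an `evaluate`-style metric object and cross-process gathering via `accelerator.gather_for_metrics`:

    # Hypothetical sketch only; names and metric handling are assumptions,
    # not the benchmark's real helper.
    import torch

    def evaluate_model(model, eval_dataloader, metric, accelerator=None):
        """Score the model on eval_dataloader and return the metric dict (e.g. accuracy/F1)."""
        model.eval()
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1)
            references = batch["labels"]
            if accelerator is not None:
                # Gather across ranks so every process computes the metric on the full eval set.
                predictions, references = accelerator.gather_for_metrics((predictions, references))
            metric.add_batch(predictions=predictions, references=references)
        return metric.compute()

The key detail the sketch illustrates is the gather step: under DeepSpeed each rank only sees its shard of the eval set, so predictions and references must be collected before `metric.compute()` for the accuracy/F1 assertions in the excerpt to be meaningful.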



benchmarks/fp8/transformer_engine/distrib_deepspeed.py [150:167]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        for batch in train_dataloader:
            outputs = model(**batch)
            data.append(batch.to("cpu"))
            model_outputs.append(outputs.logits.to("cpu"))
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.destroy()
    assert trained_model_results["accuracy"] > base_model_results["accuracy"], (
        f"Accuracy should be higher for the trained model: {trained_model_results['accuracy']} > {base_model_results['accuracy']}"
    )
    assert trained_model_results["f1"] > base_model_results["f1"], (
        f"F1 score should be higher for the trained model: {trained_model_results['f1']} > {base_model_results['f1']}"
    )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
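
Both excerpts also assume that `accelerator`, `model`, `optimizer`, `train_dataloader`, `eval_dataloader`, `lr_scheduler`, and `base_model_results` were created earlier in the script. A rough, illustrative setup is sketched below; the real benchmarks build these objects through their shared utilities and take the DeepSpeed/FP8 configuration from the launcher, so treat every detail here as an assumption:

    # Illustrative setup only; the actual benchmark wiring differs.
    from accelerate import Accelerator

    accelerator = Accelerator(mixed_precision="fp8")  # DeepSpeed config assumed to come from `accelerate launch`
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # Baseline metrics before training, compared against trained metrics in the asserts above.
    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)

The baseline evaluation must happen before the training loop so that the post-training asserts (`trained > base` for accuracy and F1) compare against an untrained reference.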



