def run_emb()

in benchmarks/dlrm/ubench/dlrm_ubench_train_embeddingbag_driver.py


def run_emb(args, run_dataset):
    # Assumption is that all tables are identical in terms of shape, number of accesses, and batch size
    assert len(run_dataset) == 1
    B = run_dataset[0][3]  # batch size
    T = run_dataset[0][4]  # number of tables
    Ds = [run_dataset[0][1]] * T  # embedding dimension of each table
    D = np.average(Ds)
    E = run_dataset[0][0]  # number of rows per table
    L = run_dataset[0][2]  # pooling factor (lookups per bag)
    weights_precision = str_to_sparsetype(args.weights_precision)
    output_dtype = str_to_sparsetype(args.output_dtype)

    forward_only = args.forward_only

    optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if args.row_wise else OptimType.EXACT_ADAGRAD
    managed_option = (
        EmbeddingLocation.DEVICE
        if torch.cuda.is_available()
        else EmbeddingLocation.HOST
    )

    if weights_precision in (SparseType.INT4, SparseType.INT8):
        # this is inference only, so no optimizer
        emb = IntNBitTableBatchedEmbeddingBagsCodegen(
            [("", E, d, weights_precision, managed_option) for d in Ds],
            bounds_check_mode=BoundsCheckMode.WARNING,
            output_dtype=output_dtype,
        ).cuda()
        emb.initialize_weights()
        forward_only = True
    else:
        emb = SplitTableBatchedEmbeddingBagsCodegen(
            [
                (
                    E,
                    d,
                    managed_option,
                    ComputeDevice.CUDA if torch.cuda.is_available() else ComputeDevice.CPU,
                )
                for d in Ds
            ],
            optimizer=optimizer,
            learning_rate=0.1,
            eps=0.1,
            weights_precision=weights_precision,
            output_dtype=output_dtype,
        ).cuda()
    isIntNTableBatched = isinstance(emb, IntNBitTableBatchedEmbeddingBagsCodegen)

    param_size_multiplier = PRECISION_SIZE_MULTIPLIER[weights_precision]

    print(
        f"Forward, B: {B}, "
        f"E: {E}, T: {T}, D: {D}, L: {L}, W: {args.weighted}, "
    )
    requests = bench.split_table_batched_embeddings_benchmark.generate_requests(
        args.warmups + args.steps,
        B,
        T,
        L,
        E,
        alpha=args.alpha,
        weights_precision=args.weights_precision,
        weighted=args.weighted,
    )
    if isIntNTableBatched:
        requests = [(a.int(), b.int(), c if c else None) for (a, b, c) in requests]
    warmup_requests, requests = requests[:args.warmups], requests[args.warmups:]

    # warmup iterations (not timed)
    for (indices, offsets, weights) in warmup_requests:
        emb.forward(indices, offsets, weights)

    # forward
    time_per_iter = bench.split_table_batched_embeddings_benchmark.benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb.forward(
            indices,
            offsets,
            per_sample_weights,
        ),
        flush_gpu_cache_size_mb=args.flush_gpu_cache_size_mb,
    )
    bytes_per_iter = B * L * D * T * param_size_multiplier  # bytes of embedding rows read per forward iteration

    if forward_only:
        return time_per_iter, bytes_per_iter

    grad_output = torch.randn(B, sum(Ds)).cuda()
    # backward
    time_per_iter = bench.split_table_batched_embeddings_benchmark.benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb(
            indices.long(),
            offsets.long(),
            per_sample_weights,
        ).backward(grad_output),
        flush_gpu_cache_size_mb=args.flush_gpu_cache_size_mb,
    )
    bytes_per_iter = B * L * D * T * param_size_multiplier * 3  # forward plus backward traffic, estimated as 3x the bytes read in forward

    return time_per_iter, bytes_per_iter
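
For context, a minimal, hypothetical sketch of how run_emb could be invoked, assuming a CUDA device and the module's own imports (torch, numpy, fbgemm_gpu) are available. The attribute names on args mirror what run_emb reads above; the precision strings, table sizes, and iteration counts are illustrative assumptions, and the run_dataset tuple layout (E, D, L, B, T) follows the indexing at the top of the function.

from argparse import Namespace

# Illustrative values only; the precision strings must be ones accepted by str_to_sparsetype.
args = Namespace(
    weights_precision="fp16",
    output_dtype="fp16",
    forward_only=False,
    row_wise=True,
    weighted=False,
    alpha=1.0,
    warmups=5,
    steps=100,
    flush_gpu_cache_size_mb=0,
)

# One entry: (E, D, L, B, T) = (rows per table, embedding dim, pooling factor, batch size, number of tables)
run_dataset = [(1000000, 128, 20, 512, 8)]

time_per_iter, bytes_per_iter = run_emb(args, run_dataset)
print(f"{time_per_iter * 1.0e6:.0f} us/iter, {bytes_per_iter / 1.0e9:.2f} GB/iter")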