in benchmarks/dlrm/ubench/dlrm_ubench_train_embeddingbag_driver.py
def run_emb(args, run_dataset):
# Assumption: all tables are identical in shape, number of accesses, and batch size
assert len(run_dataset) == 1
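# Each run_dataset entry is (E, D, L, B, T): rows per table, embedding dim, pooling factor (lookups per sample), batch size, number of tables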
B = run_dataset[0][3]
T = run_dataset[0][4]
Ds = [run_dataset[0][1]] * T
D = np.average(Ds)
E = run_dataset[0][0]
L = run_dataset[0][2]
weights_precision = str_to_sparsetype(args.weights_precision)
output_dtype = str_to_sparsetype(args.output_dtype)
forward_only = args.forward_only
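# Row-wise Adagrad keeps a single optimizer state value per embedding row; exact Adagrad keeps one per element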
optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if args.row_wise else OptimType.EXACT_ADAGRAD
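# Keep the embedding weights in GPU (device) memory when CUDA is available, otherwise in host memory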
managed_option = (
EmbeddingLocation.DEVICE
if torch.cuda.is_available()
else EmbeddingLocation.HOST
)
if weights_precision in (SparseType.INT4, SparseType.INT8):
# quantized (INT4/INT8) tables are inference-only, so no optimizer is attached
emb = IntNBitTableBatchedEmbeddingBagsCodegen(
[("", E, d, weights_precision, managed_option) for d in Ds],
bounds_check_mode=BoundsCheckMode.WARNING,
output_dtype=output_dtype,
).cuda()
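# Allocate and initialize the quantized weight buffers before benchmarking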
emb.initialize_weights()
forward_only = True
else:
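# Trainable tables: forward plus a backward pass with the optimizer update fused in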
compute_device = ComputeDevice.CUDA if torch.cuda.is_available() else ComputeDevice.CPU
emb = SplitTableBatchedEmbeddingBagsCodegen(
[(E, d, managed_option, compute_device) for d in Ds],
optimizer=optimizer,
learning_rate=0.1,
eps=0.1,
weights_precision=weights_precision,
output_dtype=output_dtype,
).cuda()
isIntNTableBatched = isinstance(emb, IntNBitTableBatchedEmbeddingBagsCodegen)
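# Bytes per embedding element for the chosen weights precision; used for the bandwidth accounting below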
param_size_multiplier = PRECISION_SIZE_MULTIPLIER[weights_precision]
print(
f"Forward, B: {B}, "
f"E: {E}, T: {T}, D: {D}, L: {L}, W: {args.weighted}, "
)
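# Generate synthetic lookup requests; each is an (indices, offsets, per_sample_weights) tuple, with args.alpha controlling how skewed the index distribution is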
requests = bench.split_table_batched_embeddings_benchmark.generate_requests(
args.warmups + args.steps,
B,
T,
L,
E,
alpha=args.alpha,
weights_precision=args.weights_precision,
weighted=args.weighted,
)
if isIntNTableBatched:
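# The quantized (IntNBit) path expects int32 indices and offsets, so cast down from int64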
requests = [(a.int(), b.int(), c if c else None) for (a, b, c) in requests]
warmup_requests, requests = requests[:args.warmups], requests[args.warmups:]
# warmup iterations (not included in the timed measurement)
for (indices, offsets, weights) in warmup_requests:
emb.forward(indices, offsets, weights)
# forward
time_per_iter = bench.split_table_batched_embeddings_benchmark.benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb.forward(
indices,
offsets,
per_sample_weights,
),
flush_gpu_cache_size_mb=args.flush_gpu_cache_size_mb,
)
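# Forward traffic: B * T * L embedding rows are read, each of (average) dimension D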
bytes_per_iter = B * L * D * T * param_size_multiplier
if forward_only:
return time_per_iter, bytes_per_iter
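# The pooled output has shape (B, sum(Ds)); reuse one random gradient of that shape for every backward call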
grad_output = torch.randn(B, sum(Ds)).cuda()
# backward
time_per_iter = bench.split_table_batched_embeddings_benchmark.benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb(
indices.long(),
offsets.long(),
per_sample_weights,
).backward(grad_output),
flush_gpu_cache_size_mb=args.flush_gpu_cache_size_mb,
)
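# The 3x factor is a rough accounting of the extra traffic in the fused backward: read rows, read gradients, write updated rows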
bytes_per_iter = B * L * D * T * param_size_multiplier * 3
return time_per_iter, bytes_per_iter