in xformers/benchmarks/benchmark_triton_fused_linear.py [0:0]
from typing import Any, Dict, List, Optional

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.activations import Activation, build_activation
from xformers.triton import FusedLinear

# SHAPES and get_metrics_transform are defined elsewhere in this file (not shown in this excerpt).


def bench_linear(activations: List[Optional[Activation]]):
    device = torch.device("cuda")
    for dtype in [
        torch.float16,
        torch.float32,
    ]:
        for backward in [True, False]:
            for activation in activations:
                results: Dict[str, Any] = {}

                for bias in [False, True]:
                    for B, M, K in SHAPES:
                        a = torch.rand(
                            B, M, K, device=device, dtype=dtype, requires_grad=backward
                        )
                        # PyTorch reference: linear layer followed by the activation
                        torch_linear = torch.nn.Linear(K, 4 * K, bias=bias).to(
                            dtype=dtype, device=device
                        )
                        torch_activation = build_activation(activation)

                        # Fused equivalent: FusedLinear applies the bias and activation
                        # inside the Triton matmul kernel
                        fused_linear = FusedLinear(
                            K, 4 * K, bias=bias, activation=activation
                        ).to(dtype=dtype, device=device)
                        # One timed step: forward pass, plus a backward pass through a
                        # scalar norm when `backward` is set
                        def torch_step(x):
                            y = torch_activation(torch_linear(x))
                            if backward:
                                torch.norm(y).backward()
                            return y

                        def triton_step(x):
                            y = fused_linear(x)
                            if backward:
                                torch.norm(y).backward()
                            return y
                        # Maps a measured runtime to the throughput figure reported
                        # below (TFlop/s)
                        metrics_transform = get_metrics_transform(
                            activation,
                            a,
                            torch_linear.weight,
                            torch_linear.bias,
                            backward,
                        )
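                        # Rough arithmetic behind the TFlop/s number (an assumption about
                        # get_metrics_transform, not read from its implementation): the
                        # (B, M, K) x (K, 4K) matmul costs ~2 * B * M * K * 4K FLOPs in the
                        # forward pass, roughly 3x that once the backward GEMMs are
                        # included; dividing by the measured time in seconds and by 1e12
                        # gives TFlop/s.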
                        for testcase in [
                            TestCase(
                                torch_step,
                                "pytorch - {} - {} bias - fw{}".format(
                                    activation,
                                    "no" if not bias else "",
                                    "+bw" if backward else "",
                                ),
                            ),
                            TestCase(
                                triton_step,
                                "triton - {} - {} bias - fw{}".format(
                                    activation,
                                    "no" if not bias else "",
                                    "+bw" if backward else "",
                                ),
                            ),
                        ]:
                            # First entry of do_bench's returned timings (ms)
                            time = triton.testing.do_bench(
                                lambda: testcase.function(a)
                            )[0]

                            key = f"B={B}, M={M}, K={K}"
                            if key not in results:
                                results[key] = {}

                            metric = metrics_transform(time)
                            results[key][testcase.name] = f"{metric:.1f}"
                pretty_print(
                    results,
                    title="\n --- Type: {} ---".format(dtype),
                    units="TFlops/s",
                )

                _type = "_fp16" if dtype == torch.float16 else "_fp32"
                title = "FusedLinear" + _type + "_FW"
                if backward:
                    title += "_BW"
                title += "_" + activation.value if activation else "_none"
                pretty_plot(results, title, "TFlops/s", dash_key="pytorch")
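

# Minimal usage sketch (an assumption, not part of the excerpt): sweep every activation
# in the Activation enum plus a no-activation baseline. The real script may add argument
# parsing or GPU-capability checks on top of this.
if __name__ == "__main__":
    bench_linear([None] + [act for act in Activation])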