Skip to content

Commit 7448d87

Browse files
committed
[harness] XPU timer
1 parent 88dec86 commit 7448d87

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

ai_bench/harness/runner/kernel_bench_runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,10 @@ def run_kernels(self):
8181
fn = model.forward
8282
args = ai_hc.get_inputs(variant, inputs, device=self.device)
8383
# Dispatch by device: time on CPU/XPU, otherwise run the kernel once untimed.
if self.device.type == "cpu":
    meas = testing.time(fn, args, warmup=5, rep=20)
    print(f"time: {meas}us")
# BUG FIX: this branch was a second independent `if`, which re-bound the
# trailing `else` to the XPU check — on CPU the model was timed and then
# executed a second time, untimed, via the `else`. `elif` restores the
# intended three-way dispatch.
elif self.device.type == "xpu":
    # XPU measurements go through the profiler and are noisier, so use
    # more warmup and measurement iterations than the CPU path.
    meas = testing.time(fn, args, warmup=20, rep=100)
    print(f"time: {meas}us")
else:
    fn(*args)

ai_bench/harness/testing/timer.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Callable
2+
import itertools
23

34
import torch
45
from torch.profiler import ProfilerActivity
@@ -34,6 +35,65 @@ def time_cpu(fn: Callable, args: tuple, warmup: int = 25, rep: int = 100) -> flo
3435
return torch.mean(times).item()
3536

3637

38+
# Based on Intel XPU Triton backend benchmarks.
39+
def time_xpu(fn: Callable, args: tuple, warmup: int = 25, rep: int = 100) -> float:
40+
"""Measure execution time of the provided function on XPU.
41+
Args:
42+
fn: Function to measure
43+
args: Arguments to pass to the function
44+
warmup: Warmup iterations
45+
rep: Measurement iterations
46+
Returns:
47+
Mean runtime in microseconds
48+
"""
49+
50+
# A device buffer used to clear L2 cache between kernel runs.
51+
cache_size = 256 * 1024 * 1024
52+
cache = torch.empty(cache_size, dtype=torch.int8, device=torch.device("xpu"))
53+
54+
for _ in range(warmup):
55+
fn(*args)
56+
torch.accelerator.synchronize()
57+
58+
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
59+
for _ in range(rep):
60+
# Clear L2 cache.
61+
cache.zero_()
62+
torch.accelerator.synchronize()
63+
64+
with record_function("profiled_fn"):
65+
fn(*args)
66+
# Ensure all measurements are recorded.
67+
torch.accelerator.synchronize()
68+
69+
def extract_kernels(funcs):
70+
"""Traverse event tree recursively to extract device kernels."""
71+
kernels = []
72+
kernels += list(
73+
itertools.chain.from_iterable(
74+
map(lambda func: extract_kernels(func.cpu_children), funcs)
75+
)
76+
)
77+
kernels += list(itertools.chain.from_iterable([func.kernels for func in funcs]))
78+
return kernels
79+
80+
events = [e for e in prof.events() if e.name.startswith("profiled_fn")]
81+
kernels = [extract_kernels(func.cpu_children) for func in events]
82+
kernels = [kernel for kernel in kernels if kernel]
83+
if len(kernels) != rep:
84+
raise AssertionError("Unexpected number of profiled kernels")
85+
86+
times = torch.tensor(
87+
[sum([k.duration for k in kernel]) for kernel in kernels], dtype=torch.float
88+
)
89+
90+
# Trim extremes if there are enough measurements.
91+
if len(times) >= 10:
92+
times = torch.sort(times).values[1:-1]
93+
94+
return torch.mean(times).item()
95+
96+
3797
def time(
3898
fn: Callable,
3999
args: tuple,
@@ -53,4 +113,6 @@ def time(
53113
"""
54114
if not device or device.type == "cpu":
55115
return time_cpu(fn, args, warmup=warmup, rep=rep)
116+
if device.type == "xpu":
117+
return time_xpu(fn, args, warmup=warmup, rep=rep)
56118
raise ValueError("Unsupported device for timing")

0 commit comments

Comments
 (0)