11from collections .abc import Callable
2+ import itertools
23
34import torch
45from torch .profiler import ProfilerActivity
@@ -34,6 +35,65 @@ def time_cpu(fn: Callable, args: tuple, warmup: int = 25, rep: int = 100) -> flo
3435 return torch .mean (times ).item ()
3536
3637
# Based on Intel XPU Triton backend benchmarks.
def time_xpu(fn: Callable, args: tuple, warmup: int = 25, rep: int = 100) -> float:
    """Measure execution time of the provided function on XPU.

    Args:
        fn: Function to measure
        args: Arguments to pass to the function
        warmup: Warmup iterations
        rep: Measurement iterations

    Returns:
        Mean runtime in microseconds

    Raises:
        AssertionError: If the profiler did not record exactly ``rep``
            non-empty kernel groups for the profiled function.
    """

    # A device buffer used to clear L2 cache between kernel runs, so every
    # measured iteration starts from a cold cache.
    cache_size = 256 * 1024 * 1024
    cache = torch.empty(cache_size, dtype=torch.int8, device=torch.device("xpu"))

    for _ in range(warmup):
        fn(*args)
    torch.accelerator.synchronize()

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
        for _ in range(rep):
            # Clear L2 cache.
            cache.zero_()
            torch.accelerator.synchronize()

            with record_function("profiled_fn"):
                fn(*args)
        # Ensure all measurements are recorded.
        torch.accelerator.synchronize()

    def extract_kernels(funcs):
        """Traverse event tree recursively to extract device kernels."""
        # Children-first accumulation order is preserved: all kernels found
        # under child events precede the direct kernels of `funcs`.
        kernels = []
        for func in funcs:
            kernels.extend(extract_kernels(func.cpu_children))
        kernels.extend(itertools.chain.from_iterable(func.kernels for func in funcs))
        return kernels

    # One event per `record_function("profiled_fn")` region; collect the
    # device kernels launched inside each region.
    events = [e for e in prof.events() if e.name.startswith("profiled_fn")]
    kernels = [extract_kernels(event.cpu_children) for event in events]
    kernels = [kernel for kernel in kernels if kernel]
    if len(kernels) != rep:
        raise AssertionError("Unexpected number of profiled kernels")

    # Total device time per measured iteration (profiler durations are in
    # microseconds).
    times = torch.tensor(
        [sum(k.duration for k in kernel) for kernel in kernels], dtype=torch.float
    )

    # Trim extremes (fastest and slowest sample) if there are enough
    # measurements.
    if len(times) >= 10:
        times = torch.sort(times).values[1:-1]

    return torch.mean(times).item()
96+
3797def time (
3898 fn : Callable ,
3999 args : tuple ,
@@ -53,4 +113,6 @@ def time(
53113 """
54114 if not device or device .type == "cpu" :
55115 return time_cpu (fn , args , warmup = warmup , rep = rep )
116+ if device .type == "xpu" :
117+ return time_xpu (fn , args , warmup = warmup , rep = rep )
56118 raise ValueError ("Unsupported device for timing" )
0 commit comments