Skip to content

Commit 1087d59

Browse files
authored
[CPU] Add software prefetch to overlap bandwidth for scaled_embedding_bag (#4171)
* Add software prefetch to overlap bandwidth * Move prefetch into AVX512 * Update comments
1 parent f6f29f6 commit 1087d59

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,17 @@ static inline void store_chunk(at::Float8_e4m3fn *output, CHUNK chunk) {
157157
_mm_storeu_si128(reinterpret_cast<__m128i *>(output + 112),
158158
at::vec::CPU_CAPABILITY::cvtfp32_fp8e4m3(x7));
159159
}
160+
161+
// Prefetch every cache line overlapped by one embedding row.
//
// base    : pointer to the first element of the row.
// emb_dim : number of data_t elements in the row; the row occupies
//           emb_dim * sizeof(data_t) bytes. Non-positive sizes are a no-op.
//
// A prefetch pulls in the whole cache line that contains the given address,
// so one hint per 64-byte line is enough. The row is not required to start on
// a line boundary: the first hint is issued at `base` itself and the
// remaining hints are issued at the following line boundaries, so the
// trailing partial line of an unaligned row is covered as well. For a
// 64-byte-aligned row this issues exactly one hint per 64 bytes of the row.
template <typename data_t>
static inline void _prefetch_emb_row(const data_t *base, int64_t emb_dim) {
  constexpr int64_t kCacheLineBytes = 64;
  const char *ptr = reinterpret_cast<const char *>(base);
  const int64_t emb_bytes = emb_dim * static_cast<int64_t>(sizeof(data_t));
  if (emb_bytes <= 0) {
    return;
  }
  // Byte offset of `base` inside its cache line (0 when the row is aligned).
  const int64_t misalign = static_cast<int64_t>(
      reinterpret_cast<uintptr_t>(ptr) %
      static_cast<uintptr_t>(kCacheLineBytes));
  // First (possibly partial) line of the row.
  _mm_prefetch(ptr, _MM_HINT_T0);
  // Every later line, addressed at its boundary. All offsets stay strictly
  // inside the row, so no out-of-range pointer is ever formed.
  for (int64_t off = kCacheLineBytes - misalign; off < emb_bytes;
       off += kCacheLineBytes) {
    _mm_prefetch(ptr + off, _MM_HINT_T0);
  }
}
160171
#endif
161172

162173
static inline void store_elem(float &out, float input) {
@@ -181,11 +192,28 @@ inline void _scaled_embedding_bag_krnl(
181192
const index_t *offsets, const data_t *weight, const double scale,
182193
output_t *result, const int64_t num_batch) {
183194
#if defined(CPU_CAPABILITY_AVX512)
195+
// How many batch entries ahead to prefetch. Each entry has ~3 rows to fetch
196+
// from a 40M-row table; DRAM latency ~100 ns means we must keep enough
197+
// in-flight requests to hide latency.
198+
constexpr int64_t PREFETCH_DIST = 8;
184199
if (kHasAVX512 && emb_dim % 128 == 0) {
185200
constexpr int64_t block_dim = 128;
186201
const int64_t num_blocks = emb_dim / block_dim;
187202
__m512 scale_v = _mm512_set1_ps(scale);
188203
for (int64_t b = bs_begin; b < bs_end; ++b) {
204+
// Software prefetch for batch entry b+PREFETCH_DIST to overlap DRAM
205+
// latency (~100 ns per random access to large table) with AVX512 compute.
206+
const int64_t pref_b = b + PREFETCH_DIST;
207+
if (pref_b < bs_end) {
208+
const int64_t pref_start = offsets[pref_b];
209+
const int64_t pref_end = (pref_b + 1 == num_batch && last_offset != -1)
210+
? last_offset
211+
: offsets[pref_b + 1];
212+
for (int64_t pj = pref_start; pj < pref_end; ++pj) {
213+
_prefetch_emb_row(weight + indices[pj] * emb_dim, emb_dim);
214+
}
215+
}
216+
189217
__m512 x0, x1, x2, x3, x4, x5, x6, x7;
190218
__m512 y0, y1, y2, y3, y4, y5, y6, y7;
191219
int64_t start_idx = offsets[b];

0 commit comments

Comments
 (0)