@@ -157,6 +157,17 @@ static inline void store_chunk(at::Float8_e4m3fn *output, CHUNK chunk) {
157157 _mm_storeu_si128 (reinterpret_cast <__m128i *>(output + 112 ),
158158 at::vec::CPU_CAPABILITY::cvtfp32_fp8e4m3 (x7));
159159}
160+
161+ // Prefetch all cache lines of an embedding row (all blocks).
162+ // emb_bytes = emb_dim * sizeof(data_t). Cache line = 64 bytes.
163+ template <typename data_t >
164+ static inline void _prefetch_emb_row (const data_t *base, int64_t emb_dim) {
165+ const char *ptr = reinterpret_cast <const char *>(base);
166+ const int64_t emb_bytes = emb_dim * static_cast <int64_t >(sizeof (data_t ));
167+ for (int64_t off = 0 ; off < emb_bytes; off += 64 ) {
168+ _mm_prefetch (ptr + off, _MM_HINT_T0);
169+ }
170+ }
160171#endif
161172
162173static inline void store_elem (float &out, float input) {
@@ -181,11 +192,28 @@ inline void _scaled_embedding_bag_krnl(
181192 const index_t *offsets, const data_t *weight, const double scale,
182193 output_t *result, const int64_t num_batch) {
183194#if defined(CPU_CAPABILITY_AVX512)
195+ // How many batch entries ahead to prefetch. Each entry has ~3 rows to fetch
196+ // from a 40M-row table; DRAM latency ~100 ns means we must keep enough
197+ // in-flight requests to hide latency.
198+ constexpr int64_t PREFETCH_DIST = 8 ;
184199 if (kHasAVX512 && emb_dim % 128 == 0 ) {
185200 constexpr int64_t block_dim = 128 ;
186201 const int64_t num_blocks = emb_dim / block_dim;
187202 __m512 scale_v = _mm512_set1_ps (scale);
188203 for (int64_t b = bs_begin; b < bs_end; ++b) {
204+ // Software prefetch for batch entry b+PREFETCH_DIST to overlap DRAM
205+ // latency (~100 ns per random access to large table) with AVX512 compute.
206+ const int64_t pref_b = b + PREFETCH_DIST;
207+ if (pref_b < bs_end) {
208+ const int64_t pref_start = offsets[pref_b];
209+ const int64_t pref_end = (pref_b + 1 == num_batch && last_offset != -1 )
210+ ? last_offset
211+ : offsets[pref_b + 1 ];
212+ for (int64_t pj = pref_start; pj < pref_end; ++pj) {
213+ _prefetch_emb_row (weight + indices[pj] * emb_dim, emb_dim);
214+ }
215+ }
216+
189217 __m512 x0, x1, x2, x3, x4, x5, x6, x7;
190218 __m512 y0, y1, y2, y3, y4, y5, y6, y7;
191219 int64_t start_idx = offsets[b];
0 commit comments