11/*
2- * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 , NVIDIA CORPORATION.
33 * SPDX-License-Identifier: Apache-2.0
44 */
55
@@ -27,6 +27,10 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
2727
2828 DataT acc[P::AccRowsPerTh][P::AccColsPerTh];
2929
30+ size_t n_blocks_y;
31+ size_t block_x;
32+ size_t block_y;
33+
3034 public:
3135 DI EpsUnexpL2SqNeighborhood (bool * _adj,
3236 IdxT* _vd,
@@ -36,9 +40,17 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
3640 IdxT _n,
3741 IdxT _k,
3842 DataT _eps,
39- char * _smem)
40- : BaseClass(_x, _y, _m, _n, _k, _smem), adj(_adj), eps(_eps), vd(_vd), smem(_smem)
43+ char * _smem,
44+ size_t _n_blocks_y)
45+ : BaseClass(_x, _y, _m, _n, _k, _smem),
46+ adj(_adj),
47+ eps(_eps),
48+ vd(_vd),
49+ smem(_smem),
50+ n_blocks_y(_n_blocks_y)
4151 {
52+ block_x = static_cast <size_t >(blockIdx .x ) / n_blocks_y;
53+ block_y = static_cast <size_t >(blockIdx .x ) % n_blocks_y;
4254 }
4355
4456 DI void run ()
@@ -51,7 +63,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
5163 private:
5264 DI void prolog ()
5365 {
54- this ->ldgXY (IdxT ( blockIdx . x ) * P::Mblk, IdxT ( blockIdx . y ) * P::Nblk, 0 );
66+ this ->ldgXY (block_x * P::Mblk, block_y * P::Nblk, 0 );
5567#pragma unroll
5668 for (int i = 0 ; i < P::AccRowsPerTh; ++i) {
5769#pragma unroll
@@ -67,7 +79,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
6779 DI void loop ()
6880 {
6981 for (int kidx = P::Kblk; kidx < this ->k ; kidx += P::Kblk) {
70- this ->ldgXY (IdxT ( blockIdx . x ) * P::Mblk, IdxT ( blockIdx . y ) * P::Nblk, kidx);
82+ this ->ldgXY (block_x * P::Mblk, block_y * P::Nblk, kidx);
7183 accumulate (); // on the previous k-block
7284 this ->stsXY ();
7385 __syncthreads ();
@@ -79,8 +91,8 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
7991
8092 DI void epilog ()
8193 {
82- IdxT startx = blockIdx . x * P::Mblk + this ->accrowid ;
83- IdxT starty = blockIdx . y * P::Nblk + this ->acccolid ;
94+ IdxT startx = block_x * P::Mblk + this ->accrowid ;
95+ IdxT starty = block_y * P::Nblk + this ->acccolid ;
8496 auto lid = raft::laneId ();
8597 IdxT sums[P::AccRowsPerTh];
8698#pragma unroll
@@ -126,7 +138,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
126138 __syncthreads (); // so that we can safely reuse smem
127139 int gid = this ->accrowid ;
128140 int lid = this ->acccolid ;
129- auto cidx = IdxT ( blockIdx . x ) * P::Mblk + gid;
141+ auto cidx = block_x * P::Mblk + gid;
130142 IdxT totalSum = 0 ;
131143 // update the individual vertex degrees
132144#pragma unroll
@@ -157,11 +169,18 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
157169}; // struct EpsUnexpL2SqNeighborhood
158170
159171template <typename DataT, typename IdxT, typename Policy>
160- __launch_bounds__ (Policy::Nthreads, 2 ) RAFT_KERNEL epsUnexpL2SqNeighKernel (
161- bool * adj, IdxT* vd, const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k, DataT eps)
172+ __launch_bounds__ (Policy::Nthreads, 2 ) RAFT_KERNEL epsUnexpL2SqNeighKernel (bool * adj,
173+ IdxT* vd,
174+ const DataT* x,
175+ const DataT* y,
176+ IdxT m,
177+ IdxT n,
178+ IdxT k,
179+ DataT eps,
180+ size_t n_blocks_y)
162181{
163182 extern __shared__ char smem[];
164- EpsUnexpL2SqNeighborhood<DataT, IdxT, Policy> obj (adj, vd, x, y, m, n, k, eps, smem);
183+ EpsUnexpL2SqNeighborhood<DataT, IdxT, Policy> obj (adj, vd, x, y, m, n, k, eps, smem, n_blocks_y );
165184 obj.run ();
166185}
167186
@@ -177,10 +196,12 @@ void epsUnexpL2SqNeighImpl(bool* adj,
177196 cudaStream_t stream)
178197{
179198 typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy Policy;
180- dim3 grid (raft::ceildiv<int >(m, Policy::Mblk), raft::ceildiv<int >(n, Policy::Nblk));
199+ size_t n_blocks_x = raft::ceildiv<size_t >(m, Policy::Mblk);
200+ size_t n_blocks_y = raft::ceildiv<size_t >(n, Policy::Nblk);
201+ dim3 grid (n_blocks_x * n_blocks_y);
181202 dim3 blk (Policy::Nthreads);
182203 epsUnexpL2SqNeighKernel<DataT, IdxT, Policy>
183- <<<grid, blk, Policy::SmemSize, stream>>> (adj, vd, x, y, m, n, k, eps);
204+ <<<grid, blk, Policy::SmemSize, stream>>> (adj, vd, x, y, m, n, k, eps, n_blocks_y );
184205 RAFT_CUDA_TRY (cudaGetLastError ());
185206}
186207
0 commit comments