Skip to content

Commit a46f204

Browse files
committed
[ShortConv] Add triton kernels for inference
1 parent d0596d7 commit a46f204

File tree

3 files changed

+337
-350
lines changed

3 files changed

+337
-350
lines changed

fla/modules/convolution.py

Lines changed: 167 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py
44

55
import math
6+
import warnings
67
from typing import Optional, Tuple
78

89
import torch
@@ -12,15 +13,15 @@
1213
import triton.language as tl
1314
from einops import rearrange
1415

15-
from fla.modules.activations import ACT2FN
1616
from fla.ops.utils import prepare_chunk_indices, prepare_sequence_ids
17-
from fla.utils import checkpoint, get_multiprocessor_count, input_guard
17+
from fla.utils import get_multiprocessor_count, input_guard
1818

1919
try:
20-
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
20+
from causal_conv1d import causal_conv1d_fn
21+
from causal_conv1d import causal_conv1d_update as causal_conv1d_update_cuda
2122
except ImportError:
2223
causal_conv1d_fn = None
23-
causal_conv1d_update = None
24+
causal_conv1d_update_cuda = None
2425

2526

2627
@triton.heuristics({
@@ -74,8 +75,8 @@ def causal_conv1d_fwd_kernel(
7475
m_d = o_d < D
7576

7677
if HAS_WEIGHT:
77-
# [D, W]
78-
b_w = tl.load(weight + o_d[:, None] * W + o_w, mask=m_d[:, None], other=0)
78+
# [BD, W]
79+
b_w = tl.load(weight + o_d[:, None] * W + o_w, mask=m_d[:, None], other=0).to(tl.float32)
7980

8081
b_y = tl.zeros((BT, BD), dtype=tl.float32)
8182
for i_w in tl.static_range(-W + 1, 1):
@@ -87,7 +88,7 @@ def causal_conv1d_fwd_kernel(
8788
b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
8889
b_y += b_yi
8990
if HAS_BIAS:
90-
b_y += tl.load(bias + o_d, mask=m_d)
91+
b_y += tl.load(bias + o_d, mask=m_d).to(tl.float32)
9192

9293
if ACTIVATION == 'swish' or ACTIVATION == 'silu':
9394
b_y = b_y * tl.sigmoid(b_y)
@@ -159,7 +160,7 @@ def causal_conv1d_bwd_kernel(
159160
if HAS_WEIGHT:
160161
p_x = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
161162
b_x = tl.load(p_x, boundary_check=(0, 1))
162-
# [D, W]
163+
# [BD, W]
163164
b_w = tl.load(weight + o_d[:, None] * W + o_w, mask=m_d[:, None], other=0)
164165

165166
b_dx = tl.zeros((BT, BD), dtype=tl.float32)
@@ -196,6 +197,65 @@ def causal_conv1d_bwd_kernel(
196197
tl.store(p_dx, tl.cast(b_dx, dtype=p_dx.dtype.element_ty, fp_downcast_rounding='rtne'), boundary_check=(0, 1))
197198

198199

200+
@triton.heuristics({
201+
'HAS_WEIGHT': lambda args: args['weight'] is not None,
202+
'HAS_BIAS': lambda args: args['bias'] is not None,
203+
'HAS_RESIDUAL': lambda args: args['residual'] is not None,
204+
})
205+
@triton.jit
206+
def causal_conv1d_update_kernel(
207+
x,
208+
cache,
209+
residual,
210+
y,
211+
weight,
212+
bias,
213+
D: tl.constexpr,
214+
W: tl.constexpr,
215+
BD: tl.constexpr,
216+
ACTIVATION: tl.constexpr,
217+
HAS_WEIGHT: tl.constexpr,
218+
HAS_BIAS: tl.constexpr,
219+
HAS_RESIDUAL: tl.constexpr,
220+
):
221+
i_d, i_n = tl.program_id(0), tl.program_id(1)
222+
223+
o_d = i_d * BD + tl.arange(0, BD)
224+
o_w = tl.arange(0, W)
225+
m_d = o_d < D
226+
m_c = o_w < W - 1
227+
228+
# [BD]
229+
b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=0).to(tl.float32)
230+
231+
# shift the cache by 1 with the last one being discarded
232+
p_cache = tl.make_block_ptr(cache + i_n * D*W, (D, W), (W, 1), (i_d * BD, 1), (BD, W), (1, 0))
233+
# [BD, W]
234+
b_cache = tl.load(p_cache, boundary_check=(0, 1)).to(tl.float32)
235+
b_cache = tl.where(m_c[None, :], b_cache, b_x[:, None])
236+
237+
if HAS_WEIGHT:
238+
b_w = tl.load(weight + o_d[:, None] * W + o_w, mask=m_d[:, None], other=0)
239+
b_y = tl.sum(b_cache * b_w, 1)
240+
else:
241+
b_y = tl.sum(b_cache, 1)
242+
if HAS_BIAS:
243+
b_y += tl.load(bias + o_d, mask=m_d)
244+
245+
if ACTIVATION == 'swish' or ACTIVATION == 'silu':
246+
b_y = b_y * tl.sigmoid(b_y)
247+
248+
if HAS_RESIDUAL:
249+
b_y += tl.load(residual + i_n * D + o_d, mask=m_d, other=0)
250+
251+
tl.store(y + i_n * D + o_d, tl.cast(b_y, dtype=y.dtype.element_ty, fp_downcast_rounding='rtne'), mask=m_d)
252+
253+
b_cache = tl.cast(b_cache, dtype=cache.dtype.element_ty, fp_downcast_rounding='rtne')
254+
# update the cache in-place
255+
p_cache = tl.make_block_ptr(cache + i_n * D*W, (D, W), (W, 1), (i_d * BD, 0), (BD, W), (1, 0))
256+
tl.store(p_cache, b_cache, boundary_check=(0, 1))
257+
258+
199259
def causal_conv1d_fwd(
200260
x: torch.Tensor,
201261
weight: torch.Tensor,
@@ -236,17 +296,18 @@ def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B)
236296

237297
def causal_conv1d_bwd(
238298
x: torch.Tensor,
239-
weight: torch.Tensor,
240-
bias: torch.Tensor,
241-
residual: torch.Tensor,
242299
dy: torch.Tensor,
300+
weight: Optional[torch.Tensor] = None,
301+
bias: Optional[torch.Tensor] = None,
302+
residual: Optional[torch.Tensor] = None,
243303
activation: Optional[str] = None,
244304
cu_seqlens: Optional[torch.Tensor] = None
245305
):
246306
shape = x.shape
247307
if x.shape[-1] != weight.shape[0]:
248308
x = rearrange(x, 'b t ... -> b t (...)')
249-
B, T, D, W = *x.shape, weight.shape[1]
309+
B, T, D = x.shape
310+
W = weight.shape[1] if weight is not None else None
250311
BT = min(64, triton.next_power_of_2(triton.cdiv(max(16, B*T), get_multiprocessor_count(x.device.index))))
251312
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
252313
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
@@ -295,6 +356,42 @@ def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B)
295356
return dx.view(shape), dw, db, dr
296357

297358

359+
@input_guard
360+
def causal_conv1d_update(
361+
x: torch.Tensor,
362+
cache: torch.Tensor,
363+
residual: Optional[torch.Tensor] = None,
364+
weight: Optional[torch.Tensor] = None,
365+
bias: Optional[torch.Tensor] = None,
366+
activation: Optional[str] = None
367+
) -> torch.Tensor:
368+
shape = x.shape
369+
if weight is not None and x.shape[-1] != weight.shape[0]:
370+
x = rearrange(x, 'b t ... -> b t (...)')
371+
*_, D = x.shape
372+
N = x.numel() // D
373+
W = weight.shape[1] if weight is not None else None
374+
BD = 16
375+
376+
y = torch.empty_like(x)
377+
# NOTE: autotuning is disabled as cache is updated in-place
378+
def grid(meta): return (triton.cdiv(D, meta['BD']), N)
379+
causal_conv1d_update_kernel[grid](
380+
x=x,
381+
cache=cache,
382+
residual=residual,
383+
y=y,
384+
weight=weight,
385+
bias=bias,
386+
D=D,
387+
W=W,
388+
BD=BD,
389+
ACTIVATION=activation,
390+
num_warps=32,
391+
)
392+
return y.view(shape), cache
393+
394+
298395
class CausalConv1dFunction(torch.autograd.Function):
299396

300397
@staticmethod
@@ -319,9 +416,9 @@ def backward(ctx, dy: torch.Tensor):
319416
x, weight, bias, residual = ctx.saved_tensors
320417
dx, dw, db, dr = causal_conv1d_bwd(
321418
x=x,
419+
dy=dy,
322420
weight=weight,
323421
bias=bias,
324-
dy=dy,
325422
residual=residual,
326423
activation=ctx.activation,
327424
cu_seqlens=ctx.cu_seqlens
@@ -338,10 +435,10 @@ def causal_conv1d(
338435
cu_seqlens: Optional[torch.Tensor] = None
339436
):
340437
"""
341-
Implementation of Causal Conv1d layer used by Mamba/Mamba2 and DeltaNet.
438+
A causal 1D convolution implementation that powers Mamba/Mamba2 and DeltaNet architectures.
342439
343-
If `residual` is provided, it functions as the Canon operation
344-
as mentioned in https://papers.ssrn.com/sol3/papers.cfm?abstract_id=5240330.
440+
When a residual connection is provided, this implements the Canon operation
441+
described in the paper at https://papers.ssrn.com/sol3/papers.cfm?abstract_id=5240330.
345442
346443
Args:
347444
x:
@@ -386,39 +483,6 @@ def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None):
386483
return out.to(dtype=u.dtype)
387484

388485

389-
@checkpoint
390-
def proj_then_conv1d(
391-
x: torch.Tensor,
392-
proj_weight: torch.Tensor,
393-
conv1d_weight: torch.Tensor,
394-
conv1d_bias: Optional[torch.Tensor] = None,
395-
cache: Optional[torch.Tensor] = None
396-
) -> torch.Tensor:
397-
# We do matmul and transpose BLH -> HBL at the same time
398-
x = rearrange(proj_weight @ rearrange(x, "b t d -> d (b t)"), "d (b t) -> b d t", t=x.shape[-2])
399-
400-
if causal_conv1d_fn is None:
401-
raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
402-
if cache is None:
403-
x = causal_conv1d_fn(
404-
x=x,
405-
weight=rearrange(conv1d_weight, "d 1 w -> d w"),
406-
bias=conv1d_bias,
407-
activation="silu",
408-
).transpose(1, 2)
409-
else:
410-
assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
411-
x = x.squeeze(-1)
412-
x = causal_conv1d_update(
413-
x=x,
414-
weight=rearrange(conv1d_weight, "d 1 w -> d w"),
415-
bias=conv1d_bias,
416-
cache=cache,
417-
activation="silu",
418-
)
419-
return x
420-
421-
422486
@triton.jit
423487
def causal_conv1d_varlen_states_fwd_kernel(
424488
x,
@@ -478,6 +542,7 @@ def __init__(
478542
use_fast_conv1d: Optional[bool] = True,
479543
device: Optional[torch.device] = None,
480544
dtype: Optional[torch.dtype] = None,
545+
**kwargs,
481546
):
482547
super().__init__(
483548
in_channels=hidden_size,
@@ -492,11 +557,22 @@ def __init__(
492557

493558
self.hidden_size = hidden_size
494559
self.activation = None
560+
561+
if kernel_size % 2 != 0:
562+
raise ValueError("Kernel size must be the power of 2")
495563
if activation is not None:
496564
assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
497565
self.activation = activation
498566

499567
self.use_fast_conv1d = use_fast_conv1d
568+
if use_fast_conv1d:
569+
if causal_conv1d_fn is None:
570+
warnings.warn(
571+
"The `use_fast_conv1d` parameter is set to `True`, but `causal_conv1d_fn` is not available. "
572+
"Switching to the Triton implementation instead. "
573+
"Consider installing `causal_conv1d` to enable the CUDA implementation."
574+
)
575+
self.use_fast_conv1d = False
500576

501577
def extra_repr(self):
502578
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
@@ -568,7 +644,7 @@ def forward(
568644
else:
569645
cache[:, :, -min(W, T):].copy_(rearrange(x[..., -min(W, T):, :], 'n w d -> n d w'))
570646

571-
if self.use_fast_conv1d and causal_conv1d_fn is None:
647+
if not self.use_fast_conv1d:
572648
y = causal_conv1d(
573649
x=x,
574650
weight=rearrange(self.weight, "d 1 w -> d w"),
@@ -580,29 +656,27 @@ def forward(
580656
return y, cache
581657

582658
x = rearrange(x, 'b t d -> b d t')
583-
if self.use_fast_conv1d:
584-
# Sequence index for each token. Used for varlen.
585-
# Suppose a batch consists of two sequences with lengths 3 and 4,
586-
# seq_idx=[0, 0, 0, 1, 1, 1, 1] for this batch.
587-
# NOTE: No need to provide this arg if `cu_seqlens` is passed.
588-
# This arg is just for BC, and will be removed in the future.
589-
# [B, T]
590-
seq_idx = kwargs.get('seq_idx', None)
591-
if cu_seqlens is not None and seq_idx is None:
592-
seq_idx = prepare_sequence_ids(cu_seqlens).to(torch.int32).unsqueeze(0)
593-
y = causal_conv1d_fn(
594-
x=x,
595-
weight=rearrange(self.weight, "d 1 w -> d w"),
596-
bias=self.bias,
597-
activation=self.activation,
598-
seq_idx=seq_idx,
599-
)
600-
else:
601-
if cu_seqlens is not None:
602-
raise ValueError("`cu_seqlens` is not supported for the naive Pytorch version")
603-
y = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
604-
if self.activation is not None:
605-
y = ACT2FN[self.activation](x)
659+
# Sequence index for each token. Used for varlen.
660+
# Suppose a batch consists of two sequences with lengths 3 and 4,
661+
# seq_idx=[0, 0, 0, 1, 1, 1, 1] for this batch.
662+
# NOTE: No need to provide this arg if `cu_seqlens` is passed.
663+
# This arg is just for BC, and will be removed in the future.
664+
# [B, T]
665+
seq_idx = kwargs.get('seq_idx', None)
666+
if cu_seqlens is not None and seq_idx is None:
667+
seq_idx = prepare_sequence_ids(cu_seqlens).to(torch.int32).unsqueeze(0)
668+
669+
# equivalent to:
670+
# y = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
671+
# if self.activation is not None:
672+
# y = ACT2FN[self.activation](x)
673+
y = causal_conv1d_fn(
674+
x=x,
675+
weight=rearrange(self.weight, "d 1 w -> d w"),
676+
bias=self.bias,
677+
activation=self.activation,
678+
seq_idx=seq_idx,
679+
)
606680
y = rearrange(y, 'b d t -> b t d')
607681
if residual is not None:
608682
y = y + residual
@@ -615,26 +689,30 @@ def step(
615689
cache: torch.Tensor,
616690
cu_seqlens: Optional[torch.LongTensor] = None
617691
):
618-
shape = x.shape
619-
x = x.squeeze(0) if cu_seqlens is not None else x.squeeze(1)
620-
if self.use_fast_conv1d:
621-
y = causal_conv1d_update(
692+
# NOTE: we follow the fast mode that updates the cache in-place
693+
if not self.use_fast_conv1d:
694+
return causal_conv1d_update(
622695
x=x,
623-
conv_state=cache,
624-
weight=rearrange(self.weight, "d 1 w -> d w"),
696+
cache=cache,
697+
residual=residual,
698+
weight=self.weight,
625699
bias=self.bias,
626700
activation=self.activation,
627701
)
628-
else:
629-
dtype = x.dtype
630-
# we follow the fast mode that updates the cache in-place
631-
cache.copy_(cache.roll(shifts=-1, dims=-1))
632-
cache[:, :, -1] = x
633-
y = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
634-
if self.bias is not None:
635-
y = y + self.bias
636-
if self.activation is not None:
637-
y = ACT2FN[self.activation](y).to(dtype=dtype)
702+
703+
shape = x.shape
704+
x = x.squeeze(0) if cu_seqlens is not None else x.squeeze(1)
705+
# equivalent to:
706+
# cache.copy_(cache.roll(shifts=-1, dims=-1))
707+
# cache[:, :, -1] = x
708+
# y = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
709+
y = causal_conv1d_update_cuda(
710+
x=x,
711+
conv_state=cache,
712+
weight=rearrange(self.weight, "d 1 w -> d w"),
713+
bias=self.bias,
714+
activation=self.activation,
715+
)
638716
y = y.view(shape)
639717
if residual is not None:
640718
y = y + residual

0 commit comments

Comments
 (0)