33# from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py
44
55import math
6+ import warnings
67from typing import Optional , Tuple
78
89import torch
1213import triton .language as tl
1314from einops import rearrange
1415
15- from fla .modules .activations import ACT2FN
1616from fla .ops .utils import prepare_chunk_indices , prepare_sequence_ids
17- from fla .utils import checkpoint , get_multiprocessor_count , input_guard
17+ from fla .utils import get_multiprocessor_count , input_guard
1818
1919try :
20- from causal_conv1d import causal_conv1d_fn , causal_conv1d_update
20+ from causal_conv1d import causal_conv1d_fn
21+ from causal_conv1d import causal_conv1d_update as causal_conv1d_update_cuda
2122except ImportError :
2223 causal_conv1d_fn = None
23- causal_conv1d_update = None
24+ causal_conv1d_update_cuda = None
2425
2526
2627@triton .heuristics ({
@@ -74,8 +75,8 @@ def causal_conv1d_fwd_kernel(
7475 m_d = o_d < D
7576
7677 if HAS_WEIGHT :
77- # [D , W]
78- b_w = tl .load (weight + o_d [:, None ] * W + o_w , mask = m_d [:, None ], other = 0 )
78+ # [BD , W]
79+ b_w = tl .load (weight + o_d [:, None ] * W + o_w , mask = m_d [:, None ], other = 0 ). to ( tl . float32 )
7980
8081 b_y = tl .zeros ((BT , BD ), dtype = tl .float32 )
8182 for i_w in tl .static_range (- W + 1 , 1 ):
@@ -87,7 +88,7 @@ def causal_conv1d_fwd_kernel(
8788 b_yi *= tl .sum (b_w * (o_w == (i_w + W - 1 )), 1 )
8889 b_y += b_yi
8990 if HAS_BIAS :
90- b_y += tl .load (bias + o_d , mask = m_d )
91+ b_y += tl .load (bias + o_d , mask = m_d ). to ( tl . float32 )
9192
9293 if ACTIVATION == 'swish' or ACTIVATION == 'silu' :
9394 b_y = b_y * tl .sigmoid (b_y )
@@ -159,7 +160,7 @@ def causal_conv1d_bwd_kernel(
159160 if HAS_WEIGHT :
160161 p_x = tl .make_block_ptr (x + bos * D , (T , D ), (D , 1 ), (i_t * BT , i_d * BD ), (BT , BD ), (1 , 0 ))
161162 b_x = tl .load (p_x , boundary_check = (0 , 1 ))
162- # [D , W]
163+ # [BD , W]
163164 b_w = tl .load (weight + o_d [:, None ] * W + o_w , mask = m_d [:, None ], other = 0 )
164165
165166 b_dx = tl .zeros ((BT , BD ), dtype = tl .float32 )
@@ -196,6 +197,65 @@ def causal_conv1d_bwd_kernel(
196197 tl .store (p_dx , tl .cast (b_dx , dtype = p_dx .dtype .element_ty , fp_downcast_rounding = 'rtne' ), boundary_check = (0 , 1 ))
197198
198199
@triton.heuristics({
    'HAS_WEIGHT': lambda args: args['weight'] is not None,
    'HAS_BIAS': lambda args: args['bias'] is not None,
    'HAS_RESIDUAL': lambda args: args['residual'] is not None,
})
@triton.jit
def causal_conv1d_update_kernel(
    x,
    cache,
    residual,
    y,
    weight,
    bias,
    D: tl.constexpr,
    W: tl.constexpr,
    BD: tl.constexpr,
    ACTIVATION: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
):
    """Single-token decoding step of a causal depthwise conv1d.

    Shifts each sequence's cache left by one position, appends the incoming
    token `x` as the newest entry, computes the convolution output `y` for
    that token, and writes the shifted cache back in place.

    Grid: (cdiv(D, BD), N), where N is the number of sequences and each
    program handles a BD-wide slice of the D channels for one sequence.
    """
    i_d, i_n = tl.program_id(0), tl.program_id(1)

    o_d = i_d * BD + tl.arange(0, BD)
    o_w = tl.arange(0, W)
    m_d = o_d < D
    # selects the first W-1 taps of the shifted cache; the last tap is `x`
    m_c = o_w < W - 1

    # [BD] newest token, promoted to fp32 for accumulation
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=0).to(tl.float32)

    # shift the cache by 1 with the oldest entry being discarded
    # (block ptr starts at column 1, so columns 1..W-1 land in slots 0..W-2)
    p_cache = tl.make_block_ptr(cache + i_n * D * W, (D, W), (W, 1), (i_d * BD, 1), (BD, W), (1, 0))
    # [BD, W]
    b_cache = tl.load(p_cache, boundary_check=(0, 1)).to(tl.float32)
    b_cache = tl.where(m_c[None, :], b_cache, b_x[:, None])

    if HAS_WEIGHT:
        # [BD, W]; cast to fp32, consistent with causal_conv1d_fwd_kernel
        b_w = tl.load(weight + o_d[:, None] * W + o_w, mask=m_d[:, None], other=0).to(tl.float32)
        b_y = tl.sum(b_cache * b_w, 1)
    else:
        b_y = tl.sum(b_cache, 1)
    if HAS_BIAS:
        b_y += tl.load(bias + o_d, mask=m_d).to(tl.float32)

    if ACTIVATION == 'swish' or ACTIVATION == 'silu':
        b_y = b_y * tl.sigmoid(b_y)

    if HAS_RESIDUAL:
        b_y += tl.load(residual + i_n * D + o_d, mask=m_d, other=0).to(tl.float32)

    tl.store(y + i_n * D + o_d, tl.cast(b_y, dtype=y.dtype.element_ty, fp_downcast_rounding='rtne'), mask=m_d)

    b_cache = tl.cast(b_cache, dtype=cache.dtype.element_ty, fp_downcast_rounding='rtne')
    # update the cache in-place
    p_cache = tl.make_block_ptr(cache + i_n * D * W, (D, W), (W, 1), (i_d * BD, 0), (BD, W), (1, 0))
    tl.store(p_cache, b_cache, boundary_check=(0, 1))
257+
258+
199259def causal_conv1d_fwd (
200260 x : torch .Tensor ,
201261 weight : torch .Tensor ,
@@ -236,17 +296,18 @@ def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B)
236296
237297def causal_conv1d_bwd (
238298 x : torch .Tensor ,
239- weight : torch .Tensor ,
240- bias : torch .Tensor ,
241- residual : torch .Tensor ,
242299 dy : torch .Tensor ,
300+ weight : Optional [torch .Tensor ] = None ,
301+ bias : Optional [torch .Tensor ] = None ,
302+ residual : Optional [torch .Tensor ] = None ,
243303 activation : Optional [str ] = None ,
244304 cu_seqlens : Optional [torch .Tensor ] = None
245305):
246306 shape = x .shape
247307 if x .shape [- 1 ] != weight .shape [0 ]:
248308 x = rearrange (x , 'b t ... -> b t (...)' )
249- B , T , D , W = * x .shape , weight .shape [1 ]
309+ B , T , D = x .shape
310+ W = weight .shape [1 ] if weight is not None else None
250311 BT = min (64 , triton .next_power_of_2 (triton .cdiv (max (16 , B * T ), get_multiprocessor_count (x .device .index ))))
251312 chunk_indices = prepare_chunk_indices (cu_seqlens , BT ) if cu_seqlens is not None else None
252313 NT = len (chunk_indices ) if cu_seqlens is not None else triton .cdiv (T , BT )
@@ -295,6 +356,42 @@ def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B)
295356 return dx .view (shape ), dw , db , dr
296357
297358
@input_guard
def causal_conv1d_update(
    x: torch.Tensor,
    cache: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    One-step (single-token) causal conv1d update, Triton implementation.

    Args:
        x:
            Input of shape `[..., D]`; one new token per sequence.
        cache:
            Convolution state of shape `[N, D, W]`, updated **in-place**.
        residual:
            Optional residual of the same shape as `x`, added to the output.
        weight:
            Optional depthwise filter of shape `[D, W]`.
        bias:
            Optional bias of shape `[D]`.
        activation:
            `None`, `'silu'` or `'swish'`.

    Returns:
        A tuple `(y, cache)` where `y` has the same shape as `x` and `cache`
        is the same (in-place updated) tensor that was passed in.
    """
    shape = x.shape
    if weight is not None and x.shape[-1] != weight.shape[0]:
        x = rearrange(x, 'b t ... -> b t (...)')
    *_, D = x.shape
    N = x.numel() // D
    # NOTE(review): when `weight is None`, W is None but the kernel still uses
    # W for its tap indices — verify the weight-less path is actually supported
    W = weight.shape[1] if weight is not None else None
    BD = 16

    y = torch.empty_like(x)
    # NOTE: autotuning is disabled as cache is updated in-place
    def grid(meta): return (triton.cdiv(D, meta['BD']), N)
    causal_conv1d_update_kernel[grid](
        x=x,
        cache=cache,
        residual=residual,
        y=y,
        weight=weight,
        bias=bias,
        D=D,
        W=W,
        BD=BD,
        ACTIVATION=activation,
        num_warps=32,
    )
    return y.view(shape), cache
393+
394+
298395class CausalConv1dFunction (torch .autograd .Function ):
299396
300397 @staticmethod
@@ -319,9 +416,9 @@ def backward(ctx, dy: torch.Tensor):
319416 x , weight , bias , residual = ctx .saved_tensors
320417 dx , dw , db , dr = causal_conv1d_bwd (
321418 x = x ,
419+ dy = dy ,
322420 weight = weight ,
323421 bias = bias ,
324- dy = dy ,
325422 residual = residual ,
326423 activation = ctx .activation ,
327424 cu_seqlens = ctx .cu_seqlens
@@ -338,10 +435,10 @@ def causal_conv1d(
338435 cu_seqlens : Optional [torch .Tensor ] = None
339436):
340437 """
341- Implementation of Causal Conv1d layer used by Mamba/Mamba2 and DeltaNet.
438+ A causal 1D convolution implementation that powers Mamba/Mamba2 and DeltaNet architectures .
342439
343- If ` residual` is provided, it functions as the Canon operation
344- as mentioned in https://papers.ssrn.com/sol3/papers.cfm?abstract_id=5240330.
440+ When a residual connection is provided, this implements the Canon operation
441+ described in the paper at https://papers.ssrn.com/sol3/papers.cfm?abstract_id=5240330.
345442
346443 Args:
347444 x:
@@ -386,39 +483,6 @@ def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None):
386483 return out .to (dtype = u .dtype )
387484
388485
389- @checkpoint
390- def proj_then_conv1d (
391- x : torch .Tensor ,
392- proj_weight : torch .Tensor ,
393- conv1d_weight : torch .Tensor ,
394- conv1d_bias : Optional [torch .Tensor ] = None ,
395- cache : Optional [torch .Tensor ] = None
396- ) -> torch .Tensor :
397- # We do matmul and transpose BLH -> HBL at the same time
398- x = rearrange (proj_weight @ rearrange (x , "b t d -> d (b t)" ), "d (b t) -> b d t" , t = x .shape [- 2 ])
399-
400- if causal_conv1d_fn is None :
401- raise ImportError ("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first." )
402- if cache is None :
403- x = causal_conv1d_fn (
404- x = x ,
405- weight = rearrange (conv1d_weight , "d 1 w -> d w" ),
406- bias = conv1d_bias ,
407- activation = "silu" ,
408- ).transpose (1 , 2 )
409- else :
410- assert x .shape [- 1 ] == 1 , "Only support decoding with 1 token at a time for now"
411- x = x .squeeze (- 1 )
412- x = causal_conv1d_update (
413- x = x ,
414- weight = rearrange (conv1d_weight , "d 1 w -> d w" ),
415- bias = conv1d_bias ,
416- cache = cache ,
417- activation = "silu" ,
418- )
419- return x
420-
421-
422486@triton .jit
423487def causal_conv1d_varlen_states_fwd_kernel (
424488 x ,
@@ -478,6 +542,7 @@ def __init__(
478542 use_fast_conv1d : Optional [bool ] = True ,
479543 device : Optional [torch .device ] = None ,
480544 dtype : Optional [torch .dtype ] = None ,
545+ ** kwargs ,
481546 ):
482547 super ().__init__ (
483548 in_channels = hidden_size ,
@@ -492,11 +557,22 @@ def __init__(
492557
493558 self .hidden_size = hidden_size
494559 self .activation = None
560+
561+ if kernel_size % 2 != 0 :
562+ raise ValueError ("Kernel size must be the power of 2" )
495563 if activation is not None :
496564 assert activation in ['silu' , 'swish' ], f"Activation `{ activation } ` not supported yet."
497565 self .activation = activation
498566
499567 self .use_fast_conv1d = use_fast_conv1d
568+ if use_fast_conv1d :
569+ if causal_conv1d_fn is None :
570+ warnings .warn (
571+ "The `use_fast_conv1d` parameter is set to `True`, but `causal_conv1d_fn` is not available. "
572+ "Switching to the Triton implementation instead. "
573+ "Consider installing `causal_conv1d` to enable the CUDA implementation."
574+ )
575+ self .use_fast_conv1d = False
500576
501577 def extra_repr (self ):
502578 s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
@@ -568,7 +644,7 @@ def forward(
568644 else :
569645 cache [:, :, - min (W , T ):].copy_ (rearrange (x [..., - min (W , T ):, :], 'n w d -> n d w' ))
570646
571- if self .use_fast_conv1d and causal_conv1d_fn is None :
647+ if not self .use_fast_conv1d :
572648 y = causal_conv1d (
573649 x = x ,
574650 weight = rearrange (self .weight , "d 1 w -> d w" ),
@@ -580,29 +656,27 @@ def forward(
580656 return y , cache
581657
582658 x = rearrange (x , 'b t d -> b d t' )
583- if self .use_fast_conv1d :
584- # Sequence index for each token. Used for varlen.
585- # Suppose a batch consists of two sequences with lengths 3 and 4,
586- # seq_idx=[0, 0, 0, 1, 1, 1, 1] for this batch.
587- # NOTE: No need to provide this arg if `cu_seqlens` is passed.
588- # This arg is just for BC, and will be removed in the future.
589- # [B, T]
590- seq_idx = kwargs .get ('seq_idx' , None )
591- if cu_seqlens is not None and seq_idx is None :
592- seq_idx = prepare_sequence_ids (cu_seqlens ).to (torch .int32 ).unsqueeze (0 )
593- y = causal_conv1d_fn (
594- x = x ,
595- weight = rearrange (self .weight , "d 1 w -> d w" ),
596- bias = self .bias ,
597- activation = self .activation ,
598- seq_idx = seq_idx ,
599- )
600- else :
601- if cu_seqlens is not None :
602- raise ValueError ("`cu_seqlens` is not supported for the naive Pytorch version" )
603- y = self ._conv_forward (x , self .weight , self .bias )[..., :x .shape [- 1 ]]
604- if self .activation is not None :
605- y = ACT2FN [self .activation ](x )
659+ # Sequence index for each token. Used for varlen.
660+ # Suppose a batch consists of two sequences with lengths 3 and 4,
661+ # seq_idx=[0, 0, 0, 1, 1, 1, 1] for this batch.
662+ # NOTE: No need to provide this arg if `cu_seqlens` is passed.
663+ # This arg is just for BC, and will be removed in the future.
664+ # [B, T]
665+ seq_idx = kwargs .get ('seq_idx' , None )
666+ if cu_seqlens is not None and seq_idx is None :
667+ seq_idx = prepare_sequence_ids (cu_seqlens ).to (torch .int32 ).unsqueeze (0 )
668+
669+ # equivalent to:
670+ # y = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
671+ # if self.activation is not None:
672+ # y = ACT2FN[self.activation](x)
673+ y = causal_conv1d_fn (
674+ x = x ,
675+ weight = rearrange (self .weight , "d 1 w -> d w" ),
676+ bias = self .bias ,
677+ activation = self .activation ,
678+ seq_idx = seq_idx ,
679+ )
606680 y = rearrange (y , 'b d t -> b t d' )
607681 if residual is not None :
608682 y = y + residual
@@ -615,26 +689,30 @@ def step(
615689 cache : torch .Tensor ,
616690 cu_seqlens : Optional [torch .LongTensor ] = None
617691 ):
618- shape = x .shape
619- x = x .squeeze (0 ) if cu_seqlens is not None else x .squeeze (1 )
620- if self .use_fast_conv1d :
621- y = causal_conv1d_update (
692+ # NOTE: we follow the fast mode that updates the cache in-place
693+ if not self .use_fast_conv1d :
694+ return causal_conv1d_update (
622695 x = x ,
623- conv_state = cache ,
624- weight = rearrange (self .weight , "d 1 w -> d w" ),
696+ cache = cache ,
697+ residual = residual ,
698+ weight = self .weight ,
625699 bias = self .bias ,
626700 activation = self .activation ,
627701 )
628- else :
629- dtype = x .dtype
630- # we follow the fast mode that updates the cache in-place
631- cache .copy_ (cache .roll (shifts = - 1 , dims = - 1 ))
632- cache [:, :, - 1 ] = x
633- y = torch .sum (cache * rearrange (self .weight , "d 1 w -> d w" ), dim = - 1 )
634- if self .bias is not None :
635- y = y + self .bias
636- if self .activation is not None :
637- y = ACT2FN [self .activation ](y ).to (dtype = dtype )
702+
703+ shape = x .shape
704+ x = x .squeeze (0 ) if cu_seqlens is not None else x .squeeze (1 )
705+ # equivalent to:
706+ # cache.copy_(cache.roll(shifts=-1, dims=-1))
707+ # cache[:, :, -1] = x
708+ # y = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
709+ y = causal_conv1d_update_cuda (
710+ x = x ,
711+ conv_state = cache ,
712+ weight = rearrange (self .weight , "d 1 w -> d w" ),
713+ bias = self .bias ,
714+ activation = self .activation ,
715+ )
638716 y = y .view (shape )
639717 if residual is not None :
640718 y = y + residual
0 commit comments