Releases: fla-org/flash-linear-attention

v0.4.2

12 Mar 14:45
ca910f8

What's Changed

  • [Misc] Use autopep8 to keep style by @zhiyuan1i in #697
  • [Misc] Reduce D2H/H2D Sync by @zhiyuan1i in #698
  • [Conv] Support mixed mode (Triton fwd and CUDA bwd) by @zhiyuan1i in #699
  • [Test] Add memory guard fixtures for CUDA memory safety testing by @zhiyuan1i in #700
  • [KDA] Add lowerbound gate function by @zhiyuan1i in #701
  • [KDA] Remove deprecated head_first in kda gate func by @zhiyuan1i in #702
  • [KDA] Speed up chunk_kda by introducing lowerbound gate by @zhiyuan1i in #703
  • [NSA] fix varlen related logic in cmp dkv kernel by @yibozhong in #707
  • [DPLR] Speed up DPLR by lowerbound gate by @zhiyuan1i in #709
  • [Misc] Enhance non-CUDA platform CI by @zhiyuan1i in #708
  • [Conv] Add non-contiguous tensor support for convolution ops by @zhiyuan1i in #712
  • [Conv] Fix corner case by @zhiyuan1i in #714
  • [Conv] Refactor dh0 into separate Triton kernel and add gradient tests by @zhiyuan1i in #717
  • [Misc] Wrap exp/log math ops with @triton.jit and enforce float32 pre… by @zhiyuan1i in #720
  • [Conv] Refactor into packages by @zhiyuan1i in #722
  • [Conv] Clean up duplicate chunk_indices calculation by @zhiyuan1i in #724
  • [DPLR] Add disable_recompute support for DPLR chunk op by @zhiyuan1i in #726
  • [CP] fuse fwd/bwd kernels and fix IMA in long context by @zhiyuan1i in #733
  • [KCP] add KCP.md; fix fp32 precision in M matrix chain; cleanup CP tests by @zhiyuan1i in #740
  • [Backend] Introduce dispatch system by @zhiyuan1i in #741
  • [Backend] Select from available backends based on priority order by @zhiyuan1i in #742
  • [MAMBA2] fix initialization for mamba2 by @mayank31398 in #739
  • [KDA] Refactor interface by @zhiyuan1i in #744
  • [Deltarule] Added intra-card context parallel optimization for KDA and GDN by @zhiyuan1i in #743
  • [Misc] centralize reference implementations in naive.py by @zhiyuan1i in #746
  • [Cache] Fix get_seq_length to return per-layer length by @zhiyuan1i in #748
  • fewer l2norm recompilations by @tyler-romero in #745
  • [OJA] Integrate Gated OJA Rule by @AwesomeSeq in #730
  • [Misc] remove redundant dot precision param in KDA recompute_w_u by @KevinZeng08 in #750
  • Fix shared memory guards for AMD RDNA GPUs (64KB shared mem) by @Gildoniel in #751
  • [Fix] Guard A_log and dt_bias re-initialization against loaded checkpoint values in GatedDeltaNet, Comba, and KDA by @ljxw88 in #754
  • [MAMBA-2] fix mamba-2 init for FSDP-2 with DTensors by @mayank31398 in #753
  • [Deltarule] Add cache for intra CP by @zhiyuan1i in #755
  • Update Windows warning to detect triton-windows by @erm14254 in #757
  • [Misc] Reduce recompile by @zhiyuan1i in #764
  • [PaTH] Prevent int32 index overflow for long sequences by @zhixuan-lin in #769
  • [Test] Add regression tests for cache seen-token bug (GH-766) by @zhiyuan1i in #768
  • [Misc] removes all @torch.jit.script decorators from the codebase by @zhiyuan1i in #767
  • [Conv] Fix invalid memory access by @zhiyuan1i in #774
  • [KDA][GDN] Support transpose_state_layout for [V,K] state memory layout by @zhiyuan1i in #776
  • [GDN] Enhance Triton 3.2 compatibility by @winglet0996 in #773
  • [Model] Unify cache function by @zhiyuan1i in #777

New Contributors

Full Changelog: v0.4.1...v0.4.2

🎄 v0.4.1

24 Dec 18:07
3a904f0

What's Changed

New Contributors

Full Changelog: v0.4.0...v0.4.1

v0.4.0

27 Oct 08:18

🧠 New Models

What's Changed

New Contributors

Full Changelog: v0.3.2...v0.4.0

v0.3.2

10 Sep 07:43
f7d95fa

📣 Highlights

Starting with this release, every time we ship a new version of flash-linear-attention, we will simultaneously publish fla-core: a minimal-dependency subset of the main repo that contains only the essentials.
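
For reference, below is a minimal sketch of how the slim package is expected to be used. It assumes fla-core is published on PyPI under that exact distribution name and exposes the same fla import namespace as the full flash-linear-attention package; the packaging details are assumptions drawn from the announcement above, not confirmed specifics.

# Minimal sketch (assumed packaging): install the slim distribution with
#   pip install fla-core
# and import it under the usual `fla` namespace.
import importlib.metadata

import fla  # assumed to expose the same namespace as flash-linear-attention

# The fla-core build is assumed to track the main release version.
print(importlib.metadata.version("fla-core"))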

🧠 New Models

What's Changed

Full Changelog: v0.3.1...v0.3.2

v0.3.1

26 Aug 20:28
80acaeb

What's Changed

New Contributors

Full Changelog: v0.3.0...v0.3.1

v0.3.0

14 Jul 09:49
17dd566

Highlights

🧠 New Models

We are excited to expand our model library with four powerful new architectures.

What's Changed

New Contributors

Full Changelog: v0.2.2...v0.3.0

v0.2.2

05 Jun 16:50

What's Changed

New Contributors

Full Changelog: v0.2.1...v0.2.2

v0.2.1

23 Apr 17:08
a670dff

Highlights

🚀 Performance Boost for DeltaNet

We've achieved a notable performance improvement for (Gated) DeltaNet models. The optimization focused on the fused LayerNormGated layer, particularly for small head dimensions, resulting in a 1.1x speedup.

Below are the benchmarks for 1B-parameter models, tested on 4k sequences in varlen mode on a single H100 GPU:

Model                TPS (K tokens/s)
Transformer++        53.8
DeltaNet (before)    48.6
DeltaNet (after)     54.0

The results were obtained by running:

python -m benchmarks.benchmark_training_throughput \
  --name delta_net \
  --batch_size 1 \
  --seq_len 32768 \
  --context_len 4096 \
  --varlen \
  --steps 512
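
For context, here is a minimal sketch of instantiating a DeltaNet model through the Hugging Face interface, assuming fla is installed and fla.models exposes DeltaNetConfig registered with transformers' Auto classes; the default config below is illustrative and not the exact 1B benchmark setup.

import torch
from transformers import AutoModelForCausalLM

from fla.models import DeltaNetConfig  # assumed to be registered with the Auto classes

config = DeltaNetConfig()  # defaults only; the benchmarked ~1B configuration may differ
model = AutoModelForCausalLM.from_config(config).to(torch.bfloat16).cuda()

input_ids = torch.randint(0, config.vocab_size, (1, 4096), device="cuda")
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss)  # sanity-check the loss on a random batch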

What's Changed

New Contributors

Full Changelog: v0.2.0...v0.2.1

v0.2.0

11 Apr 20:31
6bfd5e6

What's Changed

  • [Attn] Delete V reduction & Enable 256 headdim tests by @yzhangcs in #273
  • [RWKV7] Add more elementwise kernels by @zhiyuan1i in #271
  • [CI] Remove cache and disable full test on Arc GPU by @zhiyuan1i in #274
  • [Fox] Add model/layer/kernel impls w/ varlen support by @yzhangcs in #275
  • [FoX] Simplify some tests and enhance tiling by @zhiyuan1i in #277
  • [Test] Remove some warnings and correct condition checks by @zhiyuan1i in #278
  • [CI] auto-cancel workflows on PR merge via concurrency group by @zhiyuan1i in #280
  • [Test] use tl.float16 instead of tl.bfloat16 by @zhiyuan1i in #281
  • [OP] replace tl.exp, tl.log, tl.log2 with fast ops when FLA_USE_FAST_OPS=1 by @zhiyuan1i in #276
  • [FoX] Rename fox to forgetting_attn by @yzhangcs in #282
  • [DeltaNet] WY repr speedup by @yzhangcs in #279
  • [README] Add --no-use-pep517 flag for faster installation by @zhiyuan1i in #286
  • [FoX] Skip test D>128 on RTX4090 by @zhiyuan1i in #287
  • [FoX] Test different forget gate initialization ranges by @zhixuan-lin in #291
  • [FoX] Fix class inheritance for ForgettingTransformerForCausalLM by @zhixuan-lin in #293
  • [CI] use latest stable triton by @zhiyuan1i in #294
  • [Triton] use tl.gather to enhance performance by @zhiyuan1i in #270
  • [WY representation] Faster lower triangle inverse by @sustcsonglin in #289
  • [GroupNorm] Add argument is_rms_norm to GroupNorm by @zhixuan-lin in #295
  • [GroupNorm] Return correct residual in reference implementation by @zhixuan-lin in #297
  • [CI] Don't show Triton autotune logs in CI by @zhiyuan1i in #298
  • [FoX] Use GroupNorm for QK-norm implementation in FoX by @zhixuan-lin in #299
  • [Utils] Update H100 and A100 configs by @zhiyuan1i in #306
  • Pass shifted labels and add a warning to RWKV-7 initialization. by @Triang-jyed-driung in #304
  • [Misc.] Update imports for GatedDeltaProduct by @yzhangcs in #309
  • [FAQ] Rewrite the nightly installation instructions by @zhiyuan1i in #305
  • Add unit tests for model forward and variable-length checks by @yzhangcs in #310
  • [Test] Improve path handling and test file detection by @zhiyuan1i in #311
  • [ShortConv] Adjust input shape according to cu_seqlens by @yzhangcs in #316
  • [Tests] Add unit tests for generation with padding by @yzhangcs in #312
  • [Testing] Update testing.py by @zhiyuan1i in #320
  • [DeltaNet] optimize chunk_delta_h by @sustcsonglin in #315
  • [CI] Only cancel in-progress CI for pull requests by @zhiyuan1i in #321
  • [Test] Skip some tests on arcA770 by @zhiyuan1i in #322
  • [API] Update head_first parameter default to False by @yzhangcs in #324
  • [Rotary] Remove max_seqlen parameter and adjust related logic by @yzhangcs in #326
  • [DeltaProduct] Remove unnecessary config parameter. by @JulienSiems in #325
  • fix the training problem of GatedDeltaProduct by @ridgerchu in #327
  • [Linear Attn] Fix head_first tests by @yzhangcs in #330
  • [Deprecated] Remove head_first option in gla variants by @yzhangcs in #337
  • [Test] Ensure most tests on Triton 3.2.0 and add 4096 seq_length in tests [skip test] by @zhiyuan1i in #300
  • [FoX] Merge code to FlashAttention | support batch inference by @sustcsonglin in #333
  • [DeltaNet] Delete head_first option for all by @yzhangcs in #338
  • [WIP] Remove head_first option by @yzhangcs in #339
  • [RWKV7] add input_precision param [skip test] by @zhiyuan1i in #335
  • [Testing] Add recursive dependency finding for test discovery by @zhiyuan1i in #341
  • [WIP] Delete head_first option for cumsum by @yzhangcs in #342
  • [WIP] Delete head_first tests for DeltaNet/GLA by @yzhangcs in #344
  • [Attn] Remove head_first & rename offsets to cu_seqlens by @yzhangcs in #345
  • [RWKV7] Drop some kernels to enhance speed by @zhiyuan1i in #346
  • Remove the head_first arg from several token mixing layer fns. by @yzhangcs in #347

New Contributors

Full Changelog: v0.1.2...v0.2.0

v0.1.2

31 Mar 06:30
53b3ac7

What's Changed

  • [RWKV7] fix RWKV7Attention.__init__ by @exhyy in #238
  • fix(triton): remove num_warps=8 in bwd_prepare_wy_repr_kernel to avoid MMA layout assertion on non-Ampere GPUs. by @kugwzk in #240
  • [Fix]: reshape o before o_proj in linear_attn layer. by @Luther-Sparks in #243
  • [CI] Separate tests into compile, normal, and varlen by @zhiyuan1i in #247
  • [ABC] Add use_rope parameter to ABCAttention and ABCConfig & Fix compiler bugs in kernels by @yzhangcs in #248
  • [CI] trigger GPU workflow only on pull_request events by @zhiyuan1i in #249
  • Create test_linearatten.py by @kangyiyang in #250
  • [CI] Fix all errors and enable testing for PRs by @zhiyuan1i in #251
  • [CI] add H100 GPU by @zhiyuan1i in #254
  • [Gated DeltaNet] fix gdn kernel bugs on h100 when vdim=64 by @kugwzk in #256
  • [Test] Enhance support for NVIDIA Hopper GPU by @zhiyuan1i in #257
  • [FAQ] Update triton-nightly links by @yzhangcs in #259
  • [Attn] Add triton impls for MHA/GQA by @yzhangcs in #260
  • [Attn] Use larger block size for hopper devices by @yzhangcs in #261
  • [Attn] Enable test for attn by @zhiyuan1i in #262
  • [CI] fix a syntax error in triton-nightly by @zhiyuan1i in #263
  • Bump fla to v0.1.2 by @yzhangcs in #264

New Contributors

Full Changelog: v0.1.1...v0.1.2