
Add Deepseek 3.1 & Deepseek 3.1 base #2607

Open
cheachu wants to merge 10 commits into keras-team:master from cheachu:deepseek-v3.1

Conversation

@cheachu cheachu commented Feb 22, 2026

Description of the change

This PR introduces the DeepSeek-V3.1 model architecture to Keras Hub, fully compliant with the Keras 3 backend-agnostic design (TensorFlow, JAX, PyTorch).
Key architectural components implemented in this PR:

  • Multi-head Latent Attention (MLA): Implemented the MLA absorption trick to significantly reduce KV cache size during inference without materializing full per-head Key/Value tensors.
  • DeepSeekMoE (Mixture-of-Experts): Implemented the auxiliary-loss-free top-K routing mechanism using Sigmoid affinity scores. Crucially, the expert routing and computation loop has been fully vectorized using ops.einsum and batched kernel tensors, ensuring it is 100% XLA-compatible (jit_compile=True) and avoids the severe graph bloat associated with iterating over 256 experts.
  • YaRN RoPE Scaling: Added Yet another RoPE extensioN (YaRN) for effective long-context scaling.
  • RMSNorm & SwiGLU FFN: Implemented precise RMS normalization (using float32 casting for numerical stability) and dense SwiGLU layers for the initial non-MoE transformer layers.
  • Tokenizer & CausalLM: Added the BytePairTokenizer (based on the official DeepSeek implementation) along with the CausalLM and CausalLMPreprocessor wrappers for text generation.
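The routing scheme described above can be sketched as follows. This is a minimal NumPy stand-in for the keras.ops version (function and weight names are illustrative, not the PR's actual API): sigmoid affinities, top-K selection, normalization of only the selected scores, and a dense combine matrix so a single einsum evaluates all experts with no Python loop.

```python
import numpy as np

def moe_forward(x, gate_w, w_in, w_out, top_k=2, epsilon=1e-9):
    # Sigmoid affinity score of every token for every expert.
    scores = 1.0 / (1.0 + np.exp(-x @ gate_w))          # (tokens, experts)
    # Keep the top-K experts per token.
    top_idx = np.argsort(scores, axis=-1)[:, -top_k:]   # (tokens, K)
    top_scores = np.take_along_axis(scores, top_idx, axis=-1)
    # Normalize only the selected K scores (no auxiliary balance loss).
    weights = top_scores / (top_scores.sum(-1, keepdims=True) + epsilon)
    # Scatter the weights into a dense (tokens, experts) combine matrix so
    # every expert runs in one batched einsum instead of a loop over experts.
    combine = np.zeros_like(scores)
    np.put_along_axis(combine, top_idx, weights, axis=-1)
    hidden = np.maximum(np.einsum("td,edh->teh", x, w_in), 0.0)  # placeholder
    # activation; the real experts use SwiGLU.
    expert_out = np.einsum("teh,ehd->ted", hidden, w_out)
    return np.einsum("te,ted->td", combine, expert_out)
```

Note that this dense form computes every expert for every token, which is fine for a readability sketch; the vectorized production kernel keeps the same einsum structure while remaining XLA-compilable.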

Reference

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and works with all backends (TensorFlow, JAX, and PyTorch).
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have followed the Keras Hub Model contribution guidelines in making these changes.
  • I have followed the Keras Hub API design guidelines in making these changes.
  • I have signed the Contributor License Agreement.

@gemini-code-assist
Contributor

Summary of Changes

Hello @cheachu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands Keras Hub's capabilities by integrating the advanced DeepSeek-V3.1 large language model. The implementation focuses on cutting-edge architectural components like Multi-head Latent Attention for efficient memory usage during inference, a highly optimized Mixture-of-Experts system for scalable computation, and advanced positional encoding techniques for handling longer contexts. This addition provides Keras users with a powerful, state-of-the-art language model that is fully compatible with Keras 3's multi-backend design.

Highlights

  • DeepSeek-V3.1 Model Integration: Introduced the DeepSeek-V3.1 model architecture to Keras Hub, ensuring full compliance with the Keras 3 backend-agnostic design (TensorFlow, JAX, PyTorch).
  • Multi-head Latent Attention (MLA): Implemented MLA with an absorption trick to significantly reduce KV cache size during inference by compressing keys and values through a shared low-rank latent vector, avoiding materialization of full per-head Key/Value tensors.
  • DeepSeekMoE (Mixture-of-Experts): Integrated an auxiliary-loss-free top-K routing mechanism using Sigmoid affinity scores, with expert routing and computation fully vectorized using 'ops.einsum' for 100% XLA-compatibility and reduced graph bloat.
  • YaRN RoPE Scaling: Added Yet another RoPE extensioN (YaRN) to support effective long-context scaling by applying differential scaling to RoPE frequencies based on wavelength.
  • RMSNorm & SwiGLU FFN: Implemented precise RMS normalization (using float32 casting for numerical stability) and dense SwiGLU layers for the initial non-MoE transformer layers.
  • Tokenizer & CausalLM: Included the BytePairTokenizer (based on the official DeepSeek implementation) along with 'CausalLM' and 'CausalLMPreprocessor' wrappers for text generation.
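The latent-compression idea behind MLA can be illustrated roughly as follows. This is a hypothetical sketch, not the PR's code: only the shared low-rank latent is cached, and per-head keys/values are expanded from it when attending; with the absorption trick the key up-projection is folded into the query path so full per-head K is never materialized at all.

```python
import numpy as np

def mla_compress(h, w_dkv, w_uk, w_uv):
    """Cache a shared low-rank latent instead of full per-head K/V.

    h: (seq, dim); w_dkv: (dim, latent);
    w_uk, w_uv: (latent, num_heads * head_dim).
    """
    c_kv = h @ w_dkv   # (seq, latent): this is all the KV cache stores
    k = c_kv @ w_uk    # per-head keys, expanded lazily at attention time
    v = c_kv @ w_uv    # per-head values, likewise expanded on the fly
    return c_kv, k, v
```

With latent much smaller than num_heads * head_dim, the cached tensor is a fraction of the size of conventional K and V caches.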

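The RMSNorm mentioned in the highlights can be sketched like this, with the reduction done in float32 for numerical stability regardless of the compute dtype. A minimal NumPy stand-in for the Keras layer; names are illustrative.

```python
import numpy as np

def rms_norm(x, scale, epsilon=1e-6):
    # Cast to float32 so the mean of squares is accumulated at full
    # precision even when x is float16/bfloat16.
    x32 = x.astype(np.float32)
    variance = np.mean(np.square(x32), axis=-1, keepdims=True)
    normed = x32 / np.sqrt(variance + epsilon)
    # Apply the learned scale, then cast back to the input dtype.
    return (normed * scale).astype(x.dtype)
```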

Changelog
  • keras_hub/api/models/__init__.py
    • Added imports for DeepSeekV31Backbone, DeepSeekV31CausalLM, DeepSeekV31CausalLMPreprocessor, and DeepSeekV31Tokenizer to expose them in the API.
  • keras_hub/api/tokenizers/__init__.py
    • Added import for DeepSeekV31Tokenizer to make it accessible via the tokenizers API.
  • keras_hub/src/models/deepseek_v31/__init__.py
    • Added new file to export DeepSeekV31 components and register their presets.
  • keras_hub/src/models/deepseek_v31/deepseek_v31_attention.py
    • Added new file implementing the DeepSeek V31 Multi-head Latent Attention (MLA) layer, including YaRN RoPE scaling.
  • keras_hub/src/models/deepseek_v31/deepseek_v31_backbone.py
    • Added new file defining the DeepSeek V31 core transformer backbone, incorporating MLA and a configurable Mixture-of-Experts (MoE) or dense FFN.
  • keras_hub/src/models/deepseek_v31/deepseek_v31_backbone_test.py
    • Added new file containing unit tests for the DeepSeekV31Backbone, including cache building and parameter count verification.
  • keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm.py
    • Added new file implementing the DeepSeek V31 Causal Language Model, which pairs the backbone with an LM head and supports MLA-compatible KV caching for generation.
  • keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm_preprocessor.py
    • Added new file for the DeepSeek V31 Causal LM preprocessor, designed to prepare text inputs for the DeepSeekV31CausalLM.
  • keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm_preprocessor_test.py
    • Added new file with unit tests for the DeepSeekV31CausalLMPreprocessor, verifying preprocessing logic and special token handling.
  • keras_hub/src/models/deepseek_v31/deepseek_v31_causal_lm_test.py
    • Added new file containing unit tests for the DeepSeekV31CausalLM, covering basic functionality and text generation.
  • keras_hub/src/models/deepseek_v31/deepseek_v31_decoder_block.py
    • Added new file defining the DeepSeek V31 transformer decoder block, including DeepSeekV31RMSNorm and DeepSeekV31DenseFeedForward layers.
  • keras_hub/src/models/deepseek_v31/deepseek_v31_moe.py
    • Added new file implementing the DeepSeek V31 Mixture-of-Experts (MoE) layer with sigmoid-based routing and vectorized expert computation.
  • keras_hub/src/models/deepseek_v31/deepseek_v31_presets.py
    • Added new file to define preset configurations and metadata for DeepSeek V31 backbone, tokenizer, and preprocessor models.
  • keras_hub/src/models/deepseek_v31/deepseek_v31_tokenizer.py
    • Added new file implementing the DeepSeek V31 BytePairTokenizer, including special token handling and SentencePiece proto support.
  • keras_hub/src/models/deepseek_v31/deepseek_v31_tokenizer_test.py
    • Added new file with unit tests for the DeepSeekV31Tokenizer, covering basic tokenization and special token properties.
Activity
  • No human activity (comments, reviews) was detected on this pull request yet.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces the DeepSeek-V3.1 model architecture to Keras Hub, including Multi-head Latent Attention (MLA), DeepSeekMoE, YaRN RoPE Scaling, RMSNorm, SwiGLU FFN, Tokenizer, and CausalLM. A critical security vulnerability was identified in the DeepSeekV31Tokenizer class where the proto argument, accepting a file path to a SentencePiece model, is loaded without checking Keras's safe_mode. This bypasses security protections against arbitrary file reads during model deserialization. It is recommended to include the necessary safe_mode check. Additionally, there are minor issues related to naming conventions in the presets file and a potential for improved clarity in the _yarn_inv_freq method's constant.

Comment on lines +59 to +64
    if proto is not None:
        try:
            import sentencepiece as spm

            sp = spm.SentencePieceProcessor()
            sp.Load(proto)
gemini-code-assist (Contributor), severity: medium (security)

The proto argument allows loading a SentencePiece model from an arbitrary file path. This bypasses Keras's safe_mode protection, which is intended to prevent arbitrary file reads during deserialization of untrusted models. The parent class BytePairTokenizer implements a check for safe_mode when loading from a path, and DeepSeekV31Tokenizer should do the same for the proto argument to maintain a consistent security posture.

Suggested change
    - if proto is not None:
    -     try:
    -         import sentencepiece as spm
    -         sp = spm.SentencePieceProcessor()
    -         sp.Load(proto)
    + if proto is not None:
    +     from keras.src.saving import serialization_lib
    +
    +     if isinstance(proto, str) and serialization_lib.in_safe_mode():
    +         raise ValueError(
    +             "Requested the loading of a SentencePiece proto file outside of the "
    +             "model archive. This carries a potential risk of loading "
    +             "arbitrary and sensitive files and thus it is disallowed "
    +             "by default. If you trust the source of the artifact, you "
    +             "can override this error by passing `safe_mode=False` to "
    +             "the loading function, or calling "
    +             "`keras.config.enable_unsafe_deserialization()`."
    +         )
    +     try:
    +         import sentencepiece as spm
    +         sp = spm.SentencePieceProcessor()
    +         sp.Load(proto)


    # Wavelength = 2π / freq. High-freq → small wavelength, low-freq →
    # large wavelength. YaRN applies more scaling to low-freq dimensions.
    wavelengths = 2.0 * 3.14159265358979 / freqs
gemini-code-assist (Contributor), severity: medium

The constant 3.14159265358979 is a magic number. It would be more readable and maintainable to define this as a named constant, e.g., PI = 3.14159265358979.

Suggested change
    - wavelengths = 2.0 * 3.14159265358979 / freqs
    + PI = 3.14159265358979
    + wavelengths = 2.0 * PI / freqs


    # Metadata for loading pretrained model weights and configurations.
    backbone_presets = {
        "deepseek_v3_base": {
gemini-code-assist (Contributor), severity: medium

The preset name deepseek_v3_base does not align with the model name DeepSeekV31 used throughout the codebase. The repository's naming conventions (Rule 55) state that preset names should be in snake_case and follow the pattern <model_name>_<component_type>.py. To maintain consistency, this should be deepseek_v31_base.

backbone_presets = {
    "deepseek_v31_base": {

                "37B activated parameters."
            ),
            "params": 671000000000,
            "path": "deepseek_v3",
gemini-code-assist (Contributor), severity: medium

The path metadata field should reflect the model's specific version, deepseek_v31, for clarity and consistency with the model's naming. This helps in easily identifying the model version when loading presets.

            "path": "deepseek_v31",

            "params": 671000000000,
            "path": "deepseek_v3",
            "model_type": "MoE",
            "tokenizer": "DeepSeekV3Tokenizer",
gemini-code-assist (Contributor), severity: medium

The tokenizer name DeepSeekV3Tokenizer should be updated to DeepSeekV31Tokenizer to match the model's version and maintain consistency across the project. This ensures that the correct tokenizer is referenced for the DeepSeekV3.1 model.

            "tokenizer": "DeepSeekV31Tokenizer",

    preprocessor_presets = {
        "deepseek_v3_base": {
            "metadata": {
                "description": "DeepSeek V3 preprocessor.",
gemini-code-assist (Contributor), severity: medium

The path metadata field should reflect the model's specific version, deepseek_v31, for clarity and consistency with the model's naming. This helps in easily identifying the model version when loading presets.

            "path": "deepseek_v31",

            "metadata": {
                "description": "DeepSeek V3 preprocessor.",
                "path": "deepseek_v3",
            },
gemini-code-assist (Contributor), severity: medium

The Kaggle handle should reflect the model's specific version, deepseek_v31, for clarity and consistency with the model's naming. This helps in easily identifying the model version when loading presets.

        "kaggle_handle": "kaggle://deepseek-ai/deepseek-v31/preprocessor/1",

                "path": "deepseek_v3",
            },
            "kaggle_handle": "kaggle://deepseek-ai/deepseek-v3/preprocessor/1",
        },
gemini-code-assist (Contributor), severity: medium

The preset name deepseek_v3 does not align with the model name DeepSeekV31 used throughout the codebase. The repository's naming conventions (Rule 55) state that preset names should be in snake_case and follow the pattern <model_name>_<component_type>.py. To maintain consistency, this should be deepseek_v31.

    "deepseek_v31": {

        },
        "deepseek_v3": {
            "metadata": {
                "description": "DeepSeek V3 preprocessor.",
gemini-code-assist (Contributor), severity: medium

The path metadata field should reflect the model's specific version, deepseek_v31, for clarity and consistency with the model's naming. This helps in easily identifying the model version when loading presets.

            "path": "deepseek_v31",

            "metadata": {
                "description": "DeepSeek V3 preprocessor.",
                "path": "deepseek_v3",
            },
gemini-code-assist (Contributor), severity: medium

The Kaggle handle should reflect the model's specific version, deepseek_v31, for clarity and consistency with the model's naming. This helps in easily identifying the model version when loading presets.

        "kaggle_handle": "kaggle://deepseek-ai/deepseek-v31/preprocessor/1",


cheachu commented Feb 22, 2026

/gemini review

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This is an impressive and comprehensive contribution, adding the DeepSeek-V3.1 model to Keras Hub. The implementation is well-structured, follows the repository's modular design, and makes good use of Keras 3 features for backend-agnostic code. The vectorized MoE layer and MLA implementation are particularly well done. I have a few suggestions to improve code clarity, consistency, and testing, primarily related to adhering to the style guide and fixing an empty test case. Overall, great work on this complex model implementation.

Comment on lines +41 to +51
    def test_backbone_basics(self):
        original_assert_dtype_equal = self.assertDTypeEqual

        def assert_dtype_flexible(tensor, expected_dtype, msg=None):
            actual_dtype = str(tensor.dtype)
            allowed_dtypes = ["float16", "bfloat16"]
            if actual_dtype not in allowed_dtypes:
                self.fail(
                    msg
                    or f"Tensor dtype {actual_dtype} not in allowed {allowed_dtypes}"  # noqa: E501
                )
gemini-code-assist (Contributor), severity: high

This test case test_backbone_basics is currently empty and does not perform any checks. It should be implemented to validate the backbone's basic functionality using self.run_backbone_test, as shown in the repository's testing guidelines (lines 462-468).

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=DeepSeekV31Backbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 5, 64),
        )
References
  1. The style guide requires using helper methods like self.run_backbone_test() to verify basic usage and shape inference for backbone models. (link)


    # Wavelength = 2π / freq. High-freq → small wavelength, low-freq →
    # large wavelength. YaRN applies more scaling to low-freq dimensions.
    PI = 3.14159265358979
gemini-code-assist (Contributor), severity: medium

For better precision and code clarity, it's recommended to use math.pi instead of a hardcoded value for PI. You can remove this line and use math.pi directly in the calculation on the next line. You'll also need to add import math at the top of the file.
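As a rough illustration of the wavelength-based scheme this comment refers to, using math.pi as suggested: the sketch below is hypothetical (the cutoff wavelengths and linear ramp are placeholders; real YaRN derives them from rotation counts relative to the original training context length).

```python
import math
import numpy as np

def yarn_inv_freq(head_dim, base=10000.0, factor=4.0,
                  low_wavelen=256.0, high_wavelen=32.0):
    # Standard RoPE inverse frequencies over the even dimensions.
    freqs = 1.0 / base ** (np.arange(0, head_dim, 2) / head_dim)
    # Wavelength = 2*pi / freq; math.pi instead of a magic number.
    wavelengths = 2.0 * math.pi / freqs
    # Ramp from 0 (short wavelengths, left unscaled) to 1 (long
    # wavelengths, fully interpolated by `factor`).
    ramp = np.clip(
        (wavelengths - high_wavelen) / (low_wavelen - high_wavelen), 0.0, 1.0
    )
    return freqs * (1.0 - ramp) + (freqs / factor) * ramp
```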

Comment on lines +48 to +49
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
gemini-code-assist (Contributor), severity: medium

The __init__ method should have explicit arguments instead of *args, **kwargs to improve clarity and align with the repository's style guide (see lines 233-240 of CONTRIBUTING_MODELS.md).

Suggested change
    - def __init__(self, *args, **kwargs):
    -     super().__init__(*args, **kwargs)
    + def __init__(
    +     self,
    +     tokenizer,
    +     sequence_length=1024,
    +     add_start_token=True,
    +     add_end_token=True,
    +     **kwargs,
    + ):
    +     super().__init__(
    +         tokenizer=tokenizer,
    +         sequence_length=sequence_length,
    +         add_start_token=add_start_token,
    +         add_end_token=add_end_token,
    +         **kwargs,
    +     )
References
  1. The style guide provides an example for Preprocessor __init__ methods with explicit arguments for tokenizer and sequence_length. (link)


    # Normalize only the selected K scores (eq. 13).
    top_k_weights = top_k_scores / (
        ops.sum(top_k_scores, axis=-1, keepdims=True) + 1e-9
    )
gemini-code-assist (Contributor), severity: medium

The epsilon value 1e-9 is hardcoded. For consistency with other layers like DeepSeekV31RMSNorm and for better maintainability, consider making this a configurable parameter in the __init__ method with a default value. This would involve adding an epsilon argument to __init__ and updating get_config.
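One way to realize this suggestion could look roughly like the sketch below, with NumPy standing in for keras.ops. The class and method names are illustrative, not the PR's actual API; the real layer would accept epsilon in __init__ and report it from get_config.

```python
import numpy as np

class TopKGate:
    """Top-K score normalization with a configurable epsilon."""

    def __init__(self, epsilon=1e-9):
        self.epsilon = epsilon

    def normalize(self, top_k_scores):
        # Normalize only the selected K scores; epsilon guards against
        # an all-zero row without being a hardcoded magic number.
        denom = top_k_scores.sum(axis=-1, keepdims=True) + self.epsilon
        return top_k_scores / denom

    def get_config(self):
        # Serializing epsilon keeps reloaded models numerically identical.
        return {"epsilon": self.epsilon}
```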

@sachinprasadhs added the "new model" label (For PRs that contribute a new model to the Keras Hub registry) on Feb 24, 2026
… format checks + removed the hardcoded pi and epsilon values

cheachu commented Feb 25, 2026

/gemini review

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a comprehensive and high-quality implementation of the DeepSeek-V3.1 model, including its novel Multi-head Latent Attention (MLA) and Mixture-of-Experts (MoE) components. The code is well-structured, follows the repository's style guide closely, and includes thorough testing. My feedback is minor and focuses on a couple of style improvements to enhance code consistency and readability.


    # Wavelength = 2π / freq. High-freq → small wavelength, low-freq →
    # large wavelength. YaRN applies more scaling to low-freq dimensions.
    import math
gemini-code-assist (Contributor), severity: medium

The math module is imported locally within the _yarn_inv_freq method. According to Python style guidelines (PEP 8), imports should be placed at the top of the file. Please move import math to the top of the file to make dependencies clear and avoid re-importing.

References
  1. The repository style guide requires using ruff for code formatting (line 736). ruff would flag this as an out-of-place import (E402). Moving imports to the top of the file is a standard Python convention (PEP 8) that improves readability by making dependencies clear at a glance. (link)

cheachu (Author)

Cool :)
Will do that!


cheachu commented Feb 25, 2026

Hi @sachinprasadhs 👋
Can you review this? :)

