
Add Gemma 3n model to KerasHub #2606

Open
laxmareddyp wants to merge 20 commits into keras-team:master from laxmareddyp:gemma3n_model

Conversation

Collaborator

@laxmareddyp laxmareddyp commented Feb 19, 2026

Description of the change

This PR completes the implementation of the Gemma 3n model, building upon the foundations laid in #2404.
It introduces critical architectural features, ensures numerical accuracy against the reference implementation, and streamlines the codebase for production readiness.

Note: Special thanks to @harshaljanjani for the initial work and foundations laid in #2404.

Key Changes & Improvements:

KV Sharing Implementation:

  • Integrated Key-Value (KV) sharing to optimize memory usage and inference efficiency, aligning with the Gemma 3 architecture.
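The memory-saving idea behind KV sharing can be sketched as follows. The layer layout here (which layers share, and from where) is purely illustrative, not the actual Gemma 3n configuration:

```python
import numpy as np

rng = np.random.default_rng(0)
num_layers, seq, dim = 6, 5, 8

# Hypothetical layout: the last 2 layers reuse the KV of layer 3 instead
# of computing their own projections.
kv_share_from = {4: 3, 5: 3}

kv_cache = {}
x = rng.normal(size=(seq, dim))
for layer in range(num_layers):
    if layer in kv_share_from:
        # Reuse an earlier layer's keys/values: no new projection, no new cache.
        k, v = kv_cache[kv_share_from[layer]]
    else:
        k = x @ rng.normal(size=(dim, dim))
        v = x @ rng.normal(size=(dim, dim))
        kv_cache[layer] = (k, v)
    # ... attention for this layer would consume (k, v) here ...

# Only 4 of the 6 layers hold a KV cache entry.
assert len(kv_cache) == 4
```

During generation this shrinks the KV cache (and the per-step projection work) in proportion to the number of sharing layers.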

Causal Masking:

  • Implemented proper causal masking logic to ensure the model correctly handles autoregressive sequences.
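For reference, a minimal numpy sketch of the mask this implies: a lower-triangular boolean matrix over query/key positions, so each token attends only to itself and earlier tokens:

```python
import numpy as np

def causal_mask(seq_len):
    # True where query position i may attend to key position j (j <= i).
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

mask = causal_mask(4)
# Row 2 (the third token) may attend to positions 0, 1, 2 but not 3.
assert mask[2].tolist() == [True, True, True, False]
```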

Numerical Parity Fixes:

  • Identified and resolved several sources of numerical divergence. The implementation now achieves acceptable numerical parity with the original Hugging Face weights.

Code Refactoring & Cleanup:

  • Removed redundant logic and consolidated overlapping code paths from the previous draft.
  • Deleted unnecessary files to keep the keras-hub directory clean and maintainable.

Relationship to Previous Work:

Reference

Colab Notebook

Numerical Verification Results:

Text-only Validation

  • Predicted tokens: ✅ 100% match (35/35 positions)
  • Mean absolute difference: 0.00045175
  • Elements within 1e-3: 99.73%

Multimodal Validation (text + image + audio):

  • Predicted tokens: ✅ 100% match (460/460 positions)
  • Mean absolute difference: 0.00082702
  • Per-modality breakdown:
    Text (14 positions): mean=0.00126, max=0.01021
    Vision (257 positions): mean=0.00072, max=0.00865
    Audio (189 positions): mean=0.00094, max=0.02665
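As a quick consistency check, the overall multimodal mean absolute difference reported above is recoverable as the position-weighted average of the per-modality means:

```python
# Per-modality position counts and mean absolute differences from the
# validation run above.
counts = {"text": 14, "vision": 257, "audio": 189}
means = {"text": 0.00126, "vision": 0.00072, "audio": 0.00094}

total = sum(counts.values())
overall = sum(counts[m] * means[m] for m in counts) / total

assert total == 460
# Matches the reported overall mean of 0.00082702 to within rounding.
assert abs(overall - 0.00082702) < 1e-5
```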

Note on tolerance

  • Gemma3n has a uniquely deep architecture — 30 decoder layers with AltUp (4-way prediction/correction), Laurel blocks, and per-layer input gating.

  • Cross-framework float32 rounding differences (JAX/XLA vs PyTorch) accumulate ~5.6e-06 per layer, compounding to ~4.5e-04 at the logit level.

  • Layer-by-layer debugging confirmed that input embeddings match perfectly (0.00 diff) and error grows linearly through the decoder stack — there is no implementation bug.

  • At atol=1e-3, 99.7% of logit elements match; at atol=1e-4, approximately 70% match.

  • The 100% token prediction match at every position confirms the conversion is functionally correct.
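The atol sensitivity can be illustrated with a synthetic stand-in: reference logits plus iid Gaussian noise of a comparable mean absolute scale. The real error is not iid, so the exact percentages differ from those above, but the qualitative picture is the same: element-wise matches drop sharply with atol while argmax (token prediction) is essentially unaffected.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for the logit comparison: 35 positions, 1000-way vocab,
# noise std chosen so mean |noise| is near the observed ~4.5e-4.
ref = rng.normal(size=(35, 1000))
ours = ref + rng.normal(scale=5.6e-4, size=ref.shape)

match_1e3 = np.isclose(ours, ref, atol=1e-3, rtol=0).mean()
match_1e4 = np.isclose(ours, ref, atol=1e-4, rtol=0).mean()
tokens_match = (ours.argmax(-1) == ref.argmax(-1)).mean()

# Element-wise agreement depends strongly on atol; argmax is robust to
# noise at this scale because logit gaps are far larger than 1e-3.
assert match_1e3 > match_1e4
assert tokens_match > 0.9
```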

Note on Parameter Count Mismatch:

  • The KerasHub Gemma3n backbone has 5,439,595,456 parameters, while the HF model has 5,439,438,272 (a difference of +157,184).
  • This is expected. The difference comes entirely from the pre-existing MobileNetV5 implementation in KerasHub, which sets use_bias=True for its rms_norm convolutions.
  • The HF/timm implementation does not use bias here.
  • We deliberately left MobileNetV5 untouched to preserve its existing Kaggle preset and avoid breaking downstream models.
  • The extra bias weights are initialized to zero during conversion and do not affect the output (as proven by the 100% token match in multimodal validation).
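The parameter arithmetic above checks out exactly:

```python
# Reported backbone parameter counts from the conversion run.
keras_hub_params = 5_439_595_456
hf_params = 5_439_438_272

# The gap is exactly the extra rms_norm bias terms in the pre-existing
# KerasHub MobileNetV5 vision tower (use_bias=True vs. HF/timm's no-bias).
assert keras_hub_params - hf_params == 157_184
```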

At atol=1e-3:
Screenshot 2026-02-23 at 10 30 56 PM
Screenshot 2026-02-23 at 10 31 45 PM

At atol=1e-4:

Screenshot 2026-02-23 at 10 33 43 PM Screenshot 2026-02-23 at 10 34 06 PM

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and works with all backends (TensorFlow, JAX, and PyTorch).
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have followed the Keras Hub Model contribution guidelines in making these changes.
  • I have followed the Keras Hub API design guidelines in making these changes.
  • I have signed the Contributor License Agreement.

@github-actions github-actions bot added the Gemma Gemma model specific issues label Feb 19, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @laxmareddyp, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request finalizes the integration of the Gemma 3n model into KerasHub, building upon previous foundational work. It introduces crucial architectural elements for multimodal processing, ensuring numerical accuracy against reference implementations and streamlining the overall code for production readiness. The changes enable the model to effectively handle text, image, and audio inputs with optimized performance.

Highlights

  • KV Sharing Implementation: Integrated Key-Value (KV) sharing to optimize memory usage and inference efficiency, aligning with the Gemma 3 architecture.
  • Causal Masking: Implemented proper causal masking logic to ensure the model correctly handles autoregressive sequences.
  • Numerical Parity Fixes: Identified and resolved several sources of numerical divergence, achieving acceptable numerical parity with the original Hugging Face weights.
  • Code Refactoring & Cleanup: Removed redundant logic, consolidated overlapping code paths, and deleted unnecessary files for a cleaner and more maintainable codebase.


Changelog
  • keras_hub/api/layers/__init__.py
    • Added imports for Gemma3nAudioConverter and Gemma3nImageConverter.
  • keras_hub/api/models/__init__.py
    • Added imports for Gemma3nBackbone, Gemma3nCausalLM, Gemma3nCausalLMPreprocessor, and Gemma3nTokenizer.
  • keras_hub/api/tokenizers/__init__.py
    • Added import for Gemma3nTokenizer.
  • keras_hub/src/models/gemma3n/gemma3n_attention.py
    • Added Gemma3nAudioRelativePositionEmbedding for audio sequence position embeddings.
    • Added Gemma3nTextAttention for multi-head attention in text sequences.
    • Added Gemma3nAudioAttention for chunk-based audio sequence attention.
  • keras_hub/src/models/gemma3n/gemma3n_audio_converter.py
    • Added Gemma3nAudioConverter class for converting raw audio waveforms into log-mel spectrograms.
  • keras_hub/src/models/gemma3n/gemma3n_audio_converter_test.py
    • Added tests for Gemma3nAudioConverter, covering output shape, padding, and normalization.
  • keras_hub/src/models/gemma3n/gemma3n_audio_encoder.py
    • Added Gemma3nAudioSubSampleConvProjection for subsampling audio features.
    • Added Gemma3nAudioConformerBlock for conformer architecture in audio processing.
    • Added Gemma3nAudioEncoder as the main audio encoder for the Gemma3n model.
  • keras_hub/src/models/gemma3n/gemma3n_audio_layers.py
    • Added Gemma3nAudioCumulativeGroupNorm for cumulative group normalization of audio features.
    • Added Gemma3nAudioSSCPConvBlock for spectrogram sub-sampling convolutional preprocessing.
    • Added Gemma3nAudioConformerFeedForward for feed-forward module in Conformer blocks.
    • Added Gemma3nAudioConformerAttention for multi-head self-attention in Conformer blocks.
  • keras_hub/src/models/gemma3n/gemma3n_backbone.py
    • Added Gemma3nMultimodalEmbedder for handling multimodal embeddings.
    • Added Gemma3nMultimodalEmbeddingProcessor for interleaving text, vision, and audio embeddings.
    • Added Gemma3nBackbone as the multimodal transformer backbone for the Gemma3n model.
  • keras_hub/src/models/gemma3n/gemma3n_backbone_test.py
    • Added tests for Gemma3nBackbone, covering multimodal and text-only configurations, and model saving.
  • keras_hub/src/models/gemma3n/gemma3n_causal_lm.py
    • Added Gemma3nCausalLM class for end-to-end multimodal causal language modeling.
  • keras_hub/src/models/gemma3n/gemma3n_causal_lm_preprocessor.py
    • Added Gemma3nCausalLMPreprocessor for multimodal input preprocessing for Gemma3nCausalLM.
  • keras_hub/src/models/gemma3n/gemma3n_causal_lm_preprocessor_test.py
    • Added tests for Gemma3nCausalLMPreprocessor, including text-only, vision, audio, and multimodal scenarios.
  • keras_hub/src/models/gemma3n/gemma3n_causal_lm_test.py
    • Added tests for Gemma3nCausalLM, covering basic functionality, flash attention, early stopping, and multimodal generation.
  • keras_hub/src/models/gemma3n/gemma3n_image_converter.py
    • Added Gemma3nImageConverter class for image preprocessing in Gemma3n models.
  • keras_hub/src/models/gemma3n/gemma3n_text_decoder.py
    • Added Gemma3nTextDecoderBlock for implementing a single Gemma3n decoder block.
  • keras_hub/src/models/gemma3n/gemma3n_text_layers.py
    • Added Gemma3nTextScaledWordEmbedding for scaled word embeddings.
    • Added Gemma3nTextMLP for Gemma3n-specific feed-forward networks with activation sparsity.
    • Added Gemma3nTextLaurelBlock for low-rank residual blocks.
    • Added Gemma3nTextAltUp for the Alternating Update (AltUp) mechanism.
  • keras_hub/src/models/gemma3n/gemma3n_text_model.py
    • Added Gemma3nTextModel as the core Gemma3n text model layer.
  • keras_hub/src/models/gemma3n/gemma3n_tokenizer.py
    • Added Gemma3nTokenizer class for tokenizing raw strings into integer sequences with special token handling.
  • keras_hub/src/models/gemma3n/gemma3n_tokenizer_test.py
    • Added tests for Gemma3nTokenizer, covering basic tokenization and error handling for missing special tokens.
  • keras_hub/src/models/gemma3n/rms_normalization.py
    • Added Gemma3nRMSNorm class for Gemma 3n specific RMS normalization.
  • keras_hub/src/tests/mocks/mock_gemma3n_tokenizer.py
    • Added MockGemma3nTokenizer for testing purposes, including special token definitions.
  • tools/checkpoint_conversion/convert_gemma3n_checkpoints.py
    • Added a script to convert Hugging Face Gemma3n model checkpoints to Keras format, including validation.
  • tools/sentencepiece_testing/create_gemma3n_test_proto.py
    • Added a utility script to create a SentencePiece proto file specifically for Gemma3n testing, including its special tokens.

@gemini-code-assist gemini-code-assist bot left a comment
Code Review

This pull request introduces a comprehensive implementation of the Gemma 3n model, including its multimodal capabilities for text, image, and audio. While no specific security vulnerabilities were identified, there are critical issues with backend-agnosticism in the preprocessing layers. Specifically, Gemma3nAudioConverter and Gemma3nCausalLMPreprocessor rely on TensorFlow-specific operations (tf.signal, tf.py_function, tf.strings, tf.RaggedTensor) rather than keras.ops, which is required for compatibility with JAX and PyTorch backends according to the KerasHub style guide. These layers must be refactored to use keras.ops to ensure backend-agnosticism. Addressing these issues will make this an excellent contribution.

@laxmareddyp laxmareddyp added the kokoro:force-run Runs Tests on GPU label Feb 24, 2026
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Feb 24, 2026
@laxmareddyp laxmareddyp added the kokoro:force-run Runs Tests on GPU label Feb 24, 2026
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Feb 24, 2026
@sachinprasadhs sachinprasadhs added the new model For PRs that contribute a new model to the Keras Hub registry. label Feb 24, 2026
@sachinprasadhs sachinprasadhs left a comment

Thank you!
I have reviewed a few files and made comments; please check.

Comment on lines +31 to +33
conf_num_attention_heads,
conf_attention_context_left,
conf_attention_context_right,

conf_num_attention_heads --> num_attention_heads
conf_attention_context_left --> num_attention_context_left
conf_attention_context_right --> num_attention_context_right


Args:
hidden_size: int. The size of the hidden state.
conf_num_attention_heads: int. The number of attention heads.

conf_num_attention_heads --> num_attention_heads


Same changes in arg and in description

Comment on lines +577 to +583
conf_num_attention_heads: int. The number of attention heads.
conf_attention_chunk_size: int. The size of each processing chunk.
conf_attention_context_right: int. The number of steps to attend to in
the future.
conf_attention_context_left: int. The number of steps to attend to in
the past, including the current step.
conf_attention_logit_cap: float. The soft cap value to apply to the

conf_num_attention_heads --> num_attention_heads
conf_attention_chunk_size --> num_attention_context_right
conf_attention_context_left --> num_attention_context_left
conf_attention_logit_cap --> attention_logit_cap
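For reviewers' reference, the context_left/context_right pair documented above defines a local attention window. A minimal numpy sketch of the mask it implies, under the docstring's convention that context_left includes the current step (chunking is omitted for brevity):

```python
import numpy as np

def local_attention_mask(seq_len, context_left, context_right):
    # Position i may attend to positions j with
    # i - context_left < j <= i + context_right,
    # where context_left counts the current step as 1.
    q = np.arange(seq_len)[:, None]
    k = np.arange(seq_len)[None, :]
    return (k > q - context_left) & (k <= q + context_right)

mask = local_attention_mask(6, context_left=3, context_right=1)
# Position 3 attends to 1, 2, 3 (past incl. current) and 4 (future).
assert mask[3].tolist() == [False, True, True, True, True, False]
```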

Comment on lines +3 to +4
import keras
import numpy as np

Add from keras import ops and then use the ops as ops.xxx instead of keras.ops.xxx everywhere.

np.arange(num_timescales, dtype="float32")
* -log_timescale_increment
)
self.inv_timescales = keras.ops.expand_dims(

After importing ops, use ops.expand_dims, follow this for all the ops.

self._allow_non_tensor_positional_args = True
self.built = True

def _create_fb_matrix(

avoid abbreviation, name it something like _create_filterbank_matrix

return batch_outputs_features, None
return batch_outputs_features, batch_outputs_masks

def call(

Add call argument details.


@keras_hub_export("keras_hub.layers.Gemma3nAudioConverter")
class Gemma3nAudioConverter(keras.layers.Layer):
"""Converts raw audio waveforms into log-mel spectrograms.

Add example usage section.

Comment on lines +467 to +469
"mel_floor": self.mel_floor_arg,
"per_bin_mean": self.per_bin_mean_arg,
"per_bin_stddev": self.per_bin_stddev_arg,

Keep the names constant and avoid suffix like _arg

@@ -0,0 +1,580 @@
import keras

Import also from keras import ops and use ops.xxx in the file.

@sachinprasadhs sachinprasadhs left a comment

Few more file reviews.

Comment on lines +200 to +211
conf_residual_weight: float. The weight for the residual connection in
the feed-forward layers.
conf_num_attention_heads: int. The number of attention heads.
conf_attention_chunk_size: int. The size of chunks for local attention.
conf_attention_context_right: int. The right context size for local
attention.
conf_attention_context_left: int. The left context size for local
attention.
conf_attention_logit_cap: float. The maximum value for the attention
logits.
conf_conv_kernel_size: int. The kernel size for the 1D convolution
layer.

Remove this conf_ prefix, not Keras Hub standard.

Comment on lines +300 to +302
def compute_output_shape(self, input_shape):
audio_encodings_shape, _ = input_shape
return audio_encodings_shape

This is not consistent with the build shape check; this handles only one type.

current_f_for_block_input = input_feat_size
self.calculated_block_padding = []
self.calculated_f_out_dims = []
for i in range(2):

Why is this hardcoded? Is it always assumed to be 2 for all configs, or should it be the length of sscp_conv_kernel_size?

Comment on lines +110 to +155
def build(self, input_shape):
_, t_in, f_in = input_shape
conv0_input_shape = (None, 1, t_in, f_in)
self.conv_0.build(conv0_input_shape)
if t_in is not None:
pad_t_top_0, pad_t_bottom_0 = self.calculated_block_padding[0][2:4]
kernel_h_0, _ = self.sscp_conv_kernel_size[0]
stride_h_0, _ = self.sscp_conv_stride_size[0]
t_padded_0 = t_in + pad_t_top_0 + pad_t_bottom_0
t_out_0 = (t_padded_0 - kernel_h_0) // stride_h_0 + 1
else:
t_out_0 = None
c_out_0 = self.sscp_conv_channel_size[0]
f_out_0 = self.calculated_f_out_dims[0]
conv1_input_shape = (None, c_out_0, t_out_0, f_out_0)
self.conv_1.build(conv1_input_shape)
if t_out_0 is not None:
t_padded_1 = (
t_out_0
+ self.calculated_block_padding[1][2]
+ self.calculated_block_padding[1][3]
)
kernel_h_1, _ = self.sscp_conv_kernel_size[1]
stride_h_1, _ = self.sscp_conv_stride_size[1]
t_out_1 = (t_padded_1 - kernel_h_1) // stride_h_1 + 1
else:
t_out_1 = None
c_out_1 = self.sscp_conv_channel_size[1]
f_out_1 = self.calculated_f_out_dims[1]
proj_input_shape = (None, t_out_1, f_out_1 * c_out_1)
self.input_proj_linear.build(proj_input_shape)
super().build(input_shape)

def compute_output_shape(self, input_shape):
b, t_in, f_in = input_shape
if t_in is not None:
_, _, pad_t_top_0, pad_t_bottom_0 = self.calculated_block_padding[0]
kernel_h_0, _ = self.sscp_conv_kernel_size[0]
stride_h_0, _ = self.sscp_conv_stride_size[0]
t_padded_0 = t_in + pad_t_top_0 + pad_t_bottom_0
t_out_0 = (t_padded_0 - kernel_h_0) // stride_h_0 + 1
_, _, pad_t_top_1, pad_t_bottom_1 = self.calculated_block_padding[1]
kernel_h_1, _ = self.sscp_conv_kernel_size[1]
stride_h_1, _ = self.sscp_conv_stride_size[1]
t_padded_1 = t_out_0 + pad_t_top_1 + pad_t_bottom_1
t_out_1 = (t_padded_1 - kernel_h_1) // stride_h_1 + 1

This is also doing direct indexing of [0] and [1], based on the assumption that sscp_conv_channel_size has only 2 elements.
It's better to add a validation check that sscp_conv_channel_size does not exceed 2 elements and document why it is this way.

Comment on lines +472 to +474
time_stride_product = 1
for stride_pair in self.sscp_conv_stride_size:
time_stride_product *= stride_pair[0]

This is not being used in the code.

max_position_embeddings: int. The maximum sequence length.
vocab_size_per_layer_input: int. The vocab size for per-layer inputs.
hidden_size_per_layer_input: int. The hidden size for per-layer inputs.
altup_num_inputs: int. The number of inputs for the AltUp mechanism.

Alternating Updates (AltUp) for better clarity.
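As context for reviewers, the AltUp mechanism referenced here keeps several parallel copies of the hidden state and runs only one of them through each expensive decoder layer. A rough numpy sketch of the predict/correct idea; the coefficient shapes and values below are illustrative only, not the actual Gemma 3n parameterization:

```python
import numpy as np

rng = np.random.default_rng(0)
num_inputs, seq, dim = 4, 6, 8  # num_inputs mirrors altup_num_inputs

# Four parallel copies of the hidden state.
x = rng.normal(size=(num_inputs, seq, dim))

# Predict: mix the copies with learned coefficients (random stand-ins here).
predict_coefs = np.eye(num_inputs) + 0.1 * rng.normal(size=(num_inputs, num_inputs))
predicted = np.einsum("ij,jsd->isd", predict_coefs, x)

# Only the "active" copy goes through the (expensive) decoder layer;
# a cheap nonlinearity stands in for the layer here.
active = 0
layer_out = np.tanh(predicted[active])

# Correct: propagate the innovation (layer output minus prediction)
# back into every copy with learned per-copy weights.
innovation = layer_out - predicted[active]
correct_coefs = 1.0 + 0.1 * rng.normal(size=(num_inputs,))
corrected = predicted + correct_coefs[:, None, None] * innovation[None]

assert corrected.shape == x.shape
```

The payoff is a wider effective state at roughly the cost of a single-width layer per step.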

Comment on lines +199 to +207
dtype=input_ids_spec.dtype
if hasattr(input_ids_spec.dtype, "name")
else "float32",
)
num_layers = self.language_model.num_hidden_layers
per_layer_hidden_size = self.language_model.hidden_size_per_layer_input
per_layer_inputs_spec = keras.KerasTensor(
shape=(batch_size, seq_len, num_layers, per_layer_hidden_size),
dtype=input_ids_spec.dtype

Doesn't it have to be model compute dtype? not the token id dtype?

inputs_embeds,
)
if self.audio_encoder and self.embed_audio:
audio_mask = input_ids >= self.embed_audio.vocab_offset

This is not consistent with the vision upper bound logic.

input_data = {
"token_ids": np.random.randint(0, 50, size=(1, 16), dtype="int32"),
"attention_mask": np.ones((1, 1, 16, 16), dtype=bool),
"pixel_values": np.random.rand(1, 1, 224, 224, 3).astype("float32"),

Doesn't it have to be "images" instead of "pixel_values" for the model input, as per the implementation?

audio_indices = inputs.get("audio_indices", None)
vision_mask = inputs.get("vision_mask", None)
audio_mask = inputs.get("audio_mask", None)
audios = inputs.get("audios", None)

What is the use of this? This is not being used.

@sachinprasadhs sachinprasadhs left a comment

Reviewed the rest of the files. Please address them and mark the resolved comments as Resolved.
For the generic comments (naming conventions, import style, etc.), apply them to all the files.

Comment on lines +364 to +373
input_features is not None
and len(keras.ops.shape(input_features)) == 2
):
input_features = keras.ops.expand_dims(input_features, axis=0)
if (
input_features_mask is not None
and len(keras.ops.shape(input_features_mask)) == 1
):
input_features_mask = keras.ops.expand_dims(
input_features_mask, axis=0

This logic is contradicting to the docstring, as per the docstring, unbatched input feature should be (num_audios, audio_seq_len, feature_size) but here you are checking only rank 2. Same for input_features_mask

if len(audios.shape) > 1:
audios = tf.RaggedTensor.from_tensor(audios)
else:
audios = tf.ragged.constant([audios.numpy()], dtype=tf.float32)

I suspect .numpy() would fail in graph mode, and currently this will not be caught by any of the tests.

):
# If a 4D attention mask is passed,
# squeeze it to 2D for standard processing.
if padding_mask is not None and len(keras.ops.shape(padding_mask)) == 4:

Use static rank for reliable result or to avoid failure in Graph mode len(padding_mask.shape)

Comment on lines +224 to +226
decoder_mask = merge_padding_and_attention_mask(
inputs=x, padding_mask=padding_mask, attention_mask=None
)

Here padding_mask which is passed is 4D, but in merge_padding_and_attention_mask it is documented as 2D, please check.

Comment on lines +317 to +320
reshape_shape = modalities_shape[:-1] + (
self.altup_num_inputs,
self.altup_num_inputs,
)

It's better to use ops.concatenate

MODEL_CONFIGS = {"mobilenetv5_300m_enc": mobilenetv5_config}


def convert_model(hf_config, dtype=None):

Is it possible to move the weight and config conversion under utils/transformers and keep the validation and other code here?


Labels

Gemma Gemma model specific issues new model For PRs that contribute a new model to the Keras Hub registry.
