Broken audio processing

#42
by souflaeeh - opened

Audio processing seems to break for pretty much all use cases that don't exclusively involve transcription, summarization, or translation. For example, "Transcribe this audio" prompts seem to work well, but with other prompts the model flags the audio input as unusual. Example code adapted from the official audio processing guide:

from transformers import AutoProcessor, AutoModelForImageTextToText

GEMMA_MODEL_ID = "google/gemma-3n-E4B-it"

processor = AutoProcessor.from_pretrained(GEMMA_MODEL_ID)
model = AutoModelForImageTextToText.from_pretrained(
            GEMMA_MODEL_ID, torch_dtype="auto")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://ai.google.dev/gemma/docs/audio/roses-are.wav"},
            {"type": "text", "text": "1. Transcribe the audio\n2. Summarize the audio\n3. Did you notice anything unusual about the audio?"},
        ]
    }
]

input_ids = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True, return_dict=True,
        return_tensors="pt",
)
input_ids = input_ids.to(model.device, dtype=model.dtype)

outputs = model.generate(**input_ids, max_new_tokens=400)

text = processor.batch_decode(
    outputs,
    skip_special_tokens=False,
    clean_up_tokenization_spaces=False
)
print(text[0])

Output (user prompt omitted):

<start_of_turn>model
**1. Transcription of the audio:**

The audio consists of the phrase "Roses are red, violets are blue." repeated multiple times.

**2. Summary of the audio:**

The audio simply repeats the well-known rhyming couplet "Roses are red, violets are blue." over and over again. There is no variation in tone or pacing.

**3. Did you notice anything unusual about the audio?**

Yes, the most unusual thing about the audio is the **extreme repetition**. The phrase is played repeatedly, filling the entire duration of the audio. This is not a typical way to hear a common phrase, making it stand out.<end_of_turn>

This behavior isn't a hallucination that only shows up when the model is asked to find issues with the audio. It also occurs when trying to "voice-chat" with the model (sending user messages as audio): the model has trouble understanding the input and points out repeated phrases or letters, although it often still manages to respond correctly to such voice-chat messages while noting the weirdness. Because of this, I suspect the model is trained for and capable of handling such queries, but that there is an issue in the Transformers implementation (masking?).
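
For reference, this is roughly how I compared the number of audio placeholder tokens inserted into the prompt with the number of valid feature frames the processor produces for the 3-second clip. It's a quick diagnostic sketch reusing processor, model, and messages from the snippet above; I'm assuming the processor output contains an input_features_mask entry and that the placeholder id lives at config.audio_token_id, so adjust if your version differs:

inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True, return_dict=True,
        return_tensors="pt",
)

# Number of audio soft-token placeholders inserted into the text prompt
audio_token_id = model.config.audio_token_id
n_placeholder_tokens = (inputs["input_ids"] == audio_token_id).sum().item()

# Number of non-padded audio feature frames (assuming True marks valid frames)
n_valid_frames = inputs["input_features_mask"].sum().item()

print(f"audio placeholder tokens in prompt: {n_placeholder_tokens}")
print(f"valid audio feature frames: {n_valid_frames}")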

Google org

Hi @souflaeeh, apologies for the delay.
Can you please prompt the model as follows: 'Transcribe this audio into English summarize it, and identify any unusual characteristics'? This approach successfully produced a clear transcription of the 'Roses are red' poem, a brief summary of its content, and an observation that the audio quality was standard.

model
Transcription:

Roses are red, violets are blue.

Summary:

The audio contains the classic beginning of a poem or rhyme: "Roses are red, violets are blue." It is a simple, well-known phrase.

Unusual Observations:

There is nothing unusual about the audio

Thanks

Hi! With that prompt, it prints the following:

<start_of_turn>model
## Transcription:

Roses are red, violets are blue.

## Summary:

The audio contains a very short, common rhyming couplet. It states the well-known lines "Roses are red, violets are blue."

## Unusual Characteristics:

The most unusual characteristic of this audio is its **repetitive nature**. The word "ra" is repeated numerous times before the actual couplet is spoken. This is highly unexpected and doesn't contribute to the meaning or natural flow of the phrase. It suggests a technical issue, a glitch, or perhaps an intentional, albeit unusual, auditory effect. 

The spoken part of the couplet itself is clear and standard. However, the preceding repetition of "ra" is the most striking and unusual feature.<end_of_turn>

I'm not sure if this is helpful in any way, but I tried removing the 188 soft token padding, which leads to better results in my case.

Transcribe this audio into English summarize it, and identify any unusual characteristics<end_of_turn>
<start_of_turn>model
## Transcription:

Roses are red, violets are blue.

## Summary:

The audio contains the very beginning of a classic nursery rhyme: "Roses are red, violets are blue." It's a simple, well-known phrase often used as a starting point for poems or messages. 

## Unusual Characteristics:

There are **no unusual characteristics** in the audio. It's a clear, straightforward recitation of a very common and standard phrase. There's no unusual tone, background noise, or deviation from the expected pronunciation. It's a perfectly normal audio recording of this particular line.<end_of_turn>

Diff of my modifications:

diff --git a/huggingface-gemma3n/modular_gemma3n.py b/huggingface-gemma3n/modular_gemma3n.py
index 507c64983..3c5615238 100644
--- a/huggingface-gemma3n/modular_gemma3n.py
+++ b/huggingface-gemma3n/modular_gemma3n.py
@@ -2278,6 +2278,13 @@ class Gemma3nModel(PaliGemmaModel):
         Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
         equal to the length of multimodal features. If the lengths are different, an error is raised.
         """
+        def _count_feature_tokens(features: torch.Tensor) -> int:
+            if features.ndim == 3:
+                return features.shape[0] * features.shape[1]
+            if features.ndim == 2:
+                return features.shape[0]
+            raise ValueError(f"Expected features to be 2D or 3D, got shape={tuple(features.shape)}")
+
         if input_ids is None:
             special_image_mask = inputs_embeds == self.get_input_embeddings()(
                 torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
@@ -2293,18 +2300,20 @@ class Gemma3nModel(PaliGemmaModel):
             special_image_mask = input_ids == self.config.image_token_id
             special_audio_mask = input_ids == self.config.audio_token_id
 
-        n_image_tokens = special_image_mask.sum()
+        n_image_tokens = int(special_image_mask.sum().item())
         special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
         if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+            n_image_features = _count_feature_tokens(image_features)
             raise ValueError(
-                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0] * image_features.shape[1]}"
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features: {n_image_features}"
             )
 
-        n_audio_tokens = special_audio_mask.sum()
+        n_audio_tokens = int(special_audio_mask.sum().item())
         special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
         if audio_features is not None and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
+            n_audio_features = _count_feature_tokens(audio_features)
             raise ValueError(
-                f"Audio features and image tokens do not match: tokens: {n_audio_tokens}, features {audio_features.shape[0] * audio_features.shape[1]}"
+                f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features: {n_audio_features}"
             )
 
         return special_image_mask, special_audio_mask
@@ -2404,26 +2413,35 @@ class Gemma3nModel(PaliGemmaModel):
         # Merge text and audio
         if input_features is not None and input_features_mask is not None:
             audio_features, audio_mask = self.get_audio_features(input_features, ~input_features_mask)
-
-            # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
-            # text to account for this. However, the audio preprocessing and encoder do not gurarantee they will
-            # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
-            # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
-            # the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab.
-            audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device)
-            audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
-            audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features)
-
-            audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
-            extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
-            extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim)
-
-            audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
             audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
-            _, special_audio_mask = self.get_placeholder_mask(
-                input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features
+            audio_mask = audio_mask.to(device=audio_features.device)
+            valid_audio_mask = ~audio_mask
+            audio_features_to_scatter = audio_features[valid_audio_mask]
+
+            if input_ids is not None:
+                special_audio_mask = input_ids == self.config.audio_token_id
+            else:
+                audio_token_embed = self.get_input_embeddings()(
+                    torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_audio_mask = (inputs_embeds == audio_token_embed).all(-1)
+
+            limited_special_audio_mask = torch.zeros_like(special_audio_mask)
+            valid_audio_counts = valid_audio_mask.sum(dim=1).tolist()
+            for batch_idx, n_valid in enumerate(valid_audio_counts):
+                if n_valid == 0:
+                    continue
+                placeholder_positions = special_audio_mask[batch_idx].nonzero(as_tuple=False).squeeze(-1)
+                if n_valid > placeholder_positions.numel():
+                    raise ValueError(
+                        f"Audio features and audio tokens do not match: tokens: {placeholder_positions.numel()}, features: {n_valid}"
+                    )
+                limited_special_audio_mask[batch_idx, placeholder_positions[:n_valid]] = True
+
+            limited_special_audio_mask = (
+                limited_special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             )
-            inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
+            inputs_embeds = inputs_embeds.masked_scatter(limited_special_audio_mask, audio_features_to_scatter)
 
         outputs = self.language_model(
             input_ids=None,

Interesting! That is because in the Gemma 3n processor source code (specifically transformers version 4.53+) the audio_seq_length is hardcoded to default to 188 tokens. Since the input audio provided was only about 3 seconds, I think it was getting repeated to fill the remaining slots and hit the 188-token target.
Please try a longer audio input and check whether the issue still persists.
Also, Gemma 3n relies heavily on how you ask the question. Generic prompts can lead the model to hallucinate, repeat sounds, or misinterpret audio. Being explicit ("Transcribe into English, summarize, and identify unusual characteristics") gives the model a clearer task, which often results in cleaner, more structured output.
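
As a quick sketch, only the text entry in the messages from the first snippet needs to change to use the explicit prompt; the audio_seq_length check is my assumption about where the hardcoded 188 default lives, so the attribute name may differ in your transformers version:

# Check the hardcoded audio soft-token target (assumed attribute name)
print(getattr(processor, "audio_seq_length", None))  # expected: 188

messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": "https://ai.google.dev/gemma/docs/audio/roses-are.wav"},
            # One explicit, multi-part instruction instead of generic numbered questions
            {"type": "text", "text": "Transcribe this audio into English summarize it, and identify any unusual characteristics"},
        ]
    }
]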
