
[Feat] Adds LongCat-AudioDiT pipeline #13390

Open
RuixiangMa wants to merge 11 commits into huggingface:main from RuixiangMa:longcataudiodit

Conversation


@RuixiangMa RuixiangMa commented Apr 2, 2026

What does this PR do?

Adds LongCat-AudioDiT model support to diffusers.

Although LongCat-AudioDiT can be used for TTS-like generation, it is fundamentally a diffusion-based audio generation model (text conditioning + iterative latent denoising + VAE decoding) rather than a conventional autoregressive TTS model, so I think it fits naturally into diffusers.
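For intuition, the structure described above can be sketched as a toy flow-matching denoising loop. All names below are illustrative, not the PR's actual API:

```python
# Toy flow-matching denoising loop mirroring the pipeline structure
# described above (conditioning -> iterative denoising -> decode).
# Illustrative only; the real model predicts velocity with a DiT.

def toy_velocity(x, sigma, target):
    # Stand-in for the conditioned DiT: on the straight path
    # x = sigma * noise + (1 - sigma) * target, the flow-matching
    # velocity d(x)/d(sigma) is (x - target) / sigma.
    return (x - target) / sigma

def toy_denoise(noise, target, num_steps):
    # Uniform sigma grid from 1.0 (pure noise) down to 0.0 (clean sample).
    sigmas = [1.0 - i / num_steps for i in range(num_steps + 1)]
    x = noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        v = toy_velocity(x, sigma, target)
        x = x + (sigma_next - sigma) * v  # Euler step
    return x

print(toy_denoise(noise=10.0, target=3.0, num_steps=4))  # 3.0
```

The real pipeline replaces `toy_velocity` with the text-conditioned transformer (plus classifier-free guidance) and decodes the final latents with the VAE.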

Test

import soundfile as sf
import torch
from diffusers import LongCatAudioDiTPipeline

pipeline = LongCatAudioDiTPipeline.from_pretrained(
    "meituan-longcat/LongCat-AudioDiT-1B",
    torch_dtype=torch.float16,
)
pipeline = pipeline.to("cuda")

audio = pipeline(
    prompt="A calm ocean wave ambience with soft wind in the background.",
    audio_end_in_s=5.0,
    num_inference_steps=16,
    guidance_scale=4.0,
    output_type="pt",
).audios

output = audio[0, 0].float().cpu().numpy()
sf.write("longcat.wav", output, pipeline.sample_rate)

Result

longcat.wav

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa RuixiangMa changed the title Longcataudiodit [Feat] Adds LongCat-AudioDiT support Apr 2, 2026
@RuixiangMa RuixiangMa changed the title [Feat] Adds LongCat-AudioDiT support [Feat] Adds LongCat-AudioDiT pipeline Apr 2, 2026
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@dg845 dg845 requested review from dg845 and yiyixuxu April 4, 2026 00:31
)


def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor:
Collaborator

Similarly, I think we should inline _pixel_shuffle_1d in UpsampleShortcut following #13390 (comment).
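For background, a 1D pixel shuffle rearranges channel blocks into the length dimension, (batch, channels × r, length) → (batch, channels, length × r). A pure-Python sketch of those semantics (the channel ordering in the PR's actual `_pixel_shuffle_1d` may differ):

```python
# Pure-Python sketch of a 1D pixel shuffle on nested lists:
# (batch, channels * factor, length) -> (batch, channels, length * factor).
# The interleaving used here, out[b][c][l * r + i] = x[b][c * r + i][l],
# mirrors torch.nn.PixelShuffle's convention; the PR's helper may differ.

def pixel_shuffle_1d(x, factor):
    batch = len(x)
    in_channels = len(x[0])
    length = len(x[0][0])
    assert in_channels % factor == 0
    out_channels = in_channels // factor
    return [
        [
            [x[b][c * factor + i][l] for l in range(length) for i in range(factor)]
            for c in range(out_channels)
        ]
        for b in range(batch)
    ]

print(pixel_shuffle_1d([[[1, 2], [3, 4]]], 2))  # [[[1, 3, 2, 4]]]
```

In the actual model this is a single `view`/`permute`/`reshape` on a tensor, which is why inlining it into `UpsampleShortcut` is cheap.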

Comment on lines +515 to +519
        self.time_embed = AudioDiTTimestepEmbedding(dim)
        self.input_embed = AudioDiTEmbedder(latent_dim, dim)
        self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
        self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
        self.blocks = nn.ModuleList(
Collaborator

Suggested change:

        self.time_embed = AudioDiTTimestepEmbedding(dim)
        self.input_embed = AudioDiTEmbedder(latent_dim, dim)
        self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
        self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
        self.blocks = nn.ModuleList(

See #13390 (comment).

Comment on lines +584 to +589
        batch_size = hidden_states.shape[0]
        if timestep.ndim == 0:
            timestep = timestep.repeat(batch_size)
        timestep_embed = self.time_embed(timestep)
        text_mask = encoder_attention_mask.bool()
        encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
Collaborator

Suggested change:

        batch_size = hidden_states.shape[0]
        if timestep.ndim == 0:
            timestep = timestep.repeat(batch_size)
        timestep_embed = self.time_embed(timestep)
        text_mask = encoder_attention_mask.bool()
        encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)

Can you also refactor forward here so that it is better organized, following #13390 (comment)? See for example the QwenImageTransformer2DModel.forward method:

Author

Reorganized parts of forward incrementally; kept the current structure otherwise to avoid unnecessary behavioral churn.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 10, 2026
@RuixiangMa
Author

> Thanks for iterating! I left some follow-up comments.

Thx, PTAL

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@dg845
Collaborator

dg845 commented Apr 11, 2026

@bot /style

@github-actions
Contributor

github-actions bot commented Apr 11, 2026

Style bot fixed some files and pushed the changes.

@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 11, 2026
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@github-actions github-actions bot added utils size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 11, 2026
@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 11, 2026
@RuixiangMa
Author

These CI failures do not appear to be related to this PR.


def _get_uniform_flow_match_scheduler_sigmas(num_inference_steps: int) -> list[float]:
    num_inference_steps = max(int(num_inference_steps), 2)
    num_updates = num_inference_steps - 1
Collaborator

I think we should define num_inference_steps to match the number of function evaluations we're performing (that is, to have the same semantics that num_updates currently has), which is the usual diffusers behavior. This would also allow us to remove the behavior where we overwrite num_inference_steps=1 below in __call__.
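To make the off-by-one concrete, here is a sketch of the two conventions, with hypothetical helper names: the PR counts sigma grid points, while diffusers conventionally counts scheduler updates (model evaluations):

```python
# Hypothetical helpers contrasting the two num_inference_steps semantics.

def sigmas_grid_points(num_inference_steps):
    # PR-style: n sigma grid points -> only n - 1 scheduler updates,
    # which is why num_inference_steps=1 has to be overwritten below.
    n = max(int(num_inference_steps), 2)
    return [1.0 - i / (n - 1) for i in range(n)]

def sigmas_num_updates(num_inference_steps):
    # Usual diffusers semantics: n updates -> n + 1 sigma boundaries,
    # so every requested step performs one model evaluation.
    n = int(num_inference_steps)
    return [1.0 - i / n for i in range(n + 1)]

print(sigmas_grid_points(5))  # 5 grid points -> 4 updates
print(sigmas_num_updates(5))  # 6 boundaries -> 5 updates
```

With the second convention, `num_inference_steps=1` is a valid single Euler step and needs no special casing.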

return {key[len(prefix) :]: value for key, value in state_dict.items() if key.startswith(prefix)}


def _get_uniform_flow_match_scheduler_sigmas(num_inference_steps: int) -> list[float]:
Collaborator

I think we should inline _get_uniform_flow_match_scheduler_sigmas into __call__ so that it's easier to understand how the sigma schedule is being prepared. See e.g. Flux2Pipeline for an example of this:

sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas

We generally prefer not to have too many small functions in the pipeline code.

Comment on lines +54 to +60
def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float:
    if isinstance(text, list):
        if not text:
            return 0.0
        return max(_approx_duration_from_text(prompt, max_duration=max_duration) for prompt in text)

    en_dur_per_char = 0.082
Collaborator

Suggested change:

  def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float:
-     if isinstance(text, list):
-         if not text:
-             return 0.0
-         return max(_approx_duration_from_text(prompt, max_duration=max_duration) for prompt in text)
+     if not text:
+         return 0.0
+     if isinstance(text, str):
+         text = [text]
      en_dur_per_char = 0.082

nit: I think refactoring this function to be non-recursive (by making it work naturally with a list of strings) would make it more clear.

        first_hidden = F.layer_norm(first_hidden, (first_hidden.shape[-1],), eps=1e-6)
        prompt_embeds = prompt_embeds + first_hidden
        lengths = attention_mask.sum(dim=1).to(device)
        return prompt_embeds.float(), lengths
Collaborator

Suggested change:

-         return prompt_embeds.float(), lengths
+         return prompt_embeds, lengths

Do we need to call .float() on prompt_embeds here? I think we should generally respect the output dtype from self.text_encoder.

        )
        self.scheduler.set_begin_index(0)
        timesteps = self.scheduler.timesteps
        sample = latents
Collaborator

I think using the standard name latents instead of sample would be more clear. It would also work better with PipelineTesterMixin tests.

        if latents is None:
            duration = max(1, min(duration, max_duration))

        text_condition, text_condition_len = self.encode_prompt(normalized_prompts, device)
Collaborator

Suggested change:

-         text_condition, text_condition_len = self.encode_prompt(normalized_prompts, device)
+         prompt_embeds, text_condition_len = self.encode_prompt(normalized_prompts, device)

Similarly to #13390 (comment), I think using the standard name prompt_embeds would be better here.

Comment on lines +536 to +537
        if not return_dict:
            return (waveform,)
Collaborator

Suggested change:

+       self.maybe_free_model_hooks()
        if not return_dict:
            return (waveform,)

Calling self.maybe_free_model_hooks() allows the pipeline to clear model hooks correctly, such as those used to support model offloading.

Comment on lines +527 to +530
        if output_type == "latent":
            if not return_dict:
                return (sample,)
            return AudioPipelineOutput(audios=sample)
Collaborator

@dg845 dg845 Apr 11, 2026

Suggested change:

-       if output_type == "latent":
-           if not return_dict:
-               return (sample,)
-           return AudioPipelineOutput(audios=sample)
+       if output_type == "latent":
+           waveform = sample

A little simpler. Also makes it so that we don't have to call self.maybe_free_model_hooks() twice (see #13390 (comment)).

                latent_cond=latent_cond,
            ).sample
            pred = null_pred + (pred - null_pred) * guidance_scale
            sample = self.scheduler.step(pred, t, sample, return_dict=False)[0]
Collaborator

Suggested change:

            sample = self.scheduler.step(pred, t, sample, return_dict=False)[0]
+           if callback_on_step_end is not None:
+               callback_kwargs = {}
+               for k in callback_on_step_end_tensor_inputs:
+                   callback_kwargs[k] = locals()[k]
+               callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+               sample = callback_outputs.pop("latents", latent)
+               text_condition = callback_outputs.pop("prompt_embeds", prompt_embeds)

Example for supporting callbacks. This assumes we use the standard names latents and prompt_embeds (see #13390 (comment), #13390 (comment)). See also how e.g. Flux2Pipeline supports callbacks:

if callback_on_step_end is not None:
    callback_kwargs = {}
    for k in callback_on_step_end_tensor_inputs:
        callback_kwargs[k] = locals()[k]
    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

        guidance_scale: float = 4.0,
        generator: torch.Generator | list[torch.Generator] | None = None,
        output_type: str = "np",
        return_dict: bool = True,
Collaborator

Suggested change:

        return_dict: bool = True,
+       callback_on_step_end: Callable[[int, int], None] | None = None,
+       callback_on_step_end_tensor_inputs: list[str] = ["latents"],

Follow-up for callback support (see #13390 (comment)).



class LongCatAudioDiTPipeline(DiffusionPipeline):
    model_cpu_offload_seq = "text_encoder->transformer->vae"
Collaborator

Suggested change:

    model_cpu_offload_seq = "text_encoder->transformer->vae"
+   _callback_tensor_inputs = ["latents", "prompt_embeds"]

Follow up for callback support (#13390 (comment)). The callback tests specifically check for the name latents here, which is one reason to use it over samples.



class LongCatAudioDiTTransformer(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = False
Collaborator

Suggested change:

    _supports_gradient_checkpointing = False
+   _repeated_blocks = ["AudioDiTBlock"]

Setting _repeated_blocks here enables regional compilation support. This also allows us to not skip the TestLongCatAudioDiTTransformerCompile.test_torch_compile_repeated_blocks test.


Labels

documentation (Improvements or additions to documentation), models, pipelines, size/L (PR with diff > 200 LOC), tests, utils

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants