[Feat] Adds LongCat-AudioDiT pipeline #13390
[Feat] Adds LongCat-AudioDiT pipeline #13390RuixiangMa wants to merge 11 commits intohuggingface:mainfrom
Conversation
Signed-off-by: Lancer <maruixiang6688@gmail.com>
9c4613f to
d2a2621
Compare
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
| ) | ||
|
|
||
|
|
||
| def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor: |
There was a problem hiding this comment.
Similarly, I think we should inline _pixel_shuffle_1d in UpsampleShortcut following #13390 (comment).
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
src/diffusers/models/transformers/transformer_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
src/diffusers/models/transformers/transformer_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
src/diffusers/models/transformers/transformer_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
src/diffusers/models/transformers/transformer_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
src/diffusers/models/transformers/transformer_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
| self.time_embed = AudioDiTTimestepEmbedding(dim) | ||
| self.input_embed = AudioDiTEmbedder(latent_dim, dim) | ||
| self.text_embed = AudioDiTEmbedder(dit_text_dim, dim) | ||
| self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0) | ||
| self.blocks = nn.ModuleList( |
There was a problem hiding this comment.
| self.time_embed = AudioDiTTimestepEmbedding(dim) | |
| self.input_embed = AudioDiTEmbedder(latent_dim, dim) | |
| self.text_embed = AudioDiTEmbedder(dit_text_dim, dim) | |
| self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0) | |
| self.blocks = nn.ModuleList( | |
| self.time_embed = AudioDiTTimestepEmbedding(dim) | |
| self.input_embed = AudioDiTEmbedder(latent_dim, dim) | |
| self.text_embed = AudioDiTEmbedder(dit_text_dim, dim) | |
| self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0) | |
| self.blocks = nn.ModuleList( |
See #13390 (comment).
src/diffusers/models/transformers/transformer_longcat_audio_dit.py
Outdated
Show resolved
Hide resolved
| batch_size = hidden_states.shape[0] | ||
| if timestep.ndim == 0: | ||
| timestep = timestep.repeat(batch_size) | ||
| timestep_embed = self.time_embed(timestep) | ||
| text_mask = encoder_attention_mask.bool() | ||
| encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask) |
There was a problem hiding this comment.
| batch_size = hidden_states.shape[0] | |
| if timestep.ndim == 0: | |
| timestep = timestep.repeat(batch_size) | |
| timestep_embed = self.time_embed(timestep) | |
| text_mask = encoder_attention_mask.bool() | |
| encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask) | |
| batch_size = hidden_states.shape[0] | |
| if timestep.ndim == 0: | |
| timestep = timestep.repeat(batch_size) | |
| timestep_embed = self.time_embed(timestep) | |
| text_mask = encoder_attention_mask.bool() | |
| encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask) |
Can you also refactor forward here so that it is better organized, following #13390 (comment)? See for example the QwenImageTransformer2DModel.forward method:
There was a problem hiding this comment.
Reorganized parts of forward incrementally; kept the current structure otherwise to avoid unnecessary behavioral churn.
Thx, PTAL |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
@bot /style |
|
Style bot fixed some files and pushed the changes. |
|
These CI failures do not appear to be related to this PR |
|
|
||
| def _get_uniform_flow_match_scheduler_sigmas(num_inference_steps: int) -> list[float]: | ||
| num_inference_steps = max(int(num_inference_steps), 2) | ||
| num_updates = num_inference_steps - 1 |
There was a problem hiding this comment.
I think we should define num_inference_steps to match the number of function evaluations we're performing (that is, to have the same semantics that num_updates currently has), which is the usual diffusers behavior. This would also allow us to remove the behavior where we overwrite num_inference_steps=1 below in __call__.
| return {key[len(prefix) :]: value for key, value in state_dict.items() if key.startswith(prefix)} | ||
|
|
||
|
|
||
| def _get_uniform_flow_match_scheduler_sigmas(num_inference_steps: int) -> list[float]: |
There was a problem hiding this comment.
I think we should inline _get_uniform_flow_match_scheduler_sigmas into __call__ so that it's easier to understand how the sigma schedule is being prepared. See e.g. Flux2Pipeline for an example of this:
We generally prefer not to have too many small functions in the pipeline code.
| def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float: | ||
| if isinstance(text, list): | ||
| if not text: | ||
| return 0.0 | ||
| return max(_approx_duration_from_text(prompt, max_duration=max_duration) for prompt in text) | ||
|
|
||
| en_dur_per_char = 0.082 |
There was a problem hiding this comment.
| def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float: | |
| if isinstance(text, list): | |
| if not text: | |
| return 0.0 | |
| return max(_approx_duration_from_text(prompt, max_duration=max_duration) for prompt in text) | |
| en_dur_per_char = 0.082 | |
| def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float: | |
| if not text: | |
| return 0.0 | |
| if isinstance(text, str): | |
| text = [text] | |
| en_dur_per_char = 0.082 |
nit: I think refactoring this function to be non-recursive (by making it work naturally with a list of strings) would make it more clear.
| first_hidden = F.layer_norm(first_hidden, (first_hidden.shape[-1],), eps=1e-6) | ||
| prompt_embeds = prompt_embeds + first_hidden | ||
| lengths = attention_mask.sum(dim=1).to(device) | ||
| return prompt_embeds.float(), lengths |
There was a problem hiding this comment.
| return prompt_embeds.float(), lengths | |
| return prompt_embeds, lengths |
Do we need to call .float() on prompt_embeds here? I think we should generally respect the output dtype from self.text_encoder.
| ) | ||
| self.scheduler.set_begin_index(0) | ||
| timesteps = self.scheduler.timesteps | ||
| sample = latents |
There was a problem hiding this comment.
I think using the standard name latents instead of sample would be more clear. It would also work better with PipelineTesterMixin tests.
| if latents is None: | ||
| duration = max(1, min(duration, max_duration)) | ||
|
|
||
| text_condition, text_condition_len = self.encode_prompt(normalized_prompts, device) |
There was a problem hiding this comment.
| text_condition, text_condition_len = self.encode_prompt(normalized_prompts, device) | |
| prompt_embeds, text_condition_len = self.encode_prompt(normalized_prompts, device) |
Similarly to #13390 (comment), I think using the standard name prompt_embeds would be better here.
| if not return_dict: | ||
| return (waveform,) |
There was a problem hiding this comment.
| if not return_dict: | |
| return (waveform,) | |
| self.maybe_free_model_hooks() | |
| if not return_dict: | |
| return (waveform,) |
Calling self.maybe_free_model_hooks() allows the pipeline to clear model hooks correctly, such as those used to support model offloading.
| if output_type == "latent": | ||
| if not return_dict: | ||
| return (sample,) | ||
| return AudioPipelineOutput(audios=sample) |
There was a problem hiding this comment.
| if output_type == "latent": | |
| if not return_dict: | |
| return (sample,) | |
| return AudioPipelineOutput(audios=sample) | |
| if output_type == "latent": | |
| waveform = sample |
A little simpler. Also makes it so that we don't have to call self.maybe_free_model_hooks() twice (see #13390 (comment)).
| latent_cond=latent_cond, | ||
| ).sample | ||
| pred = null_pred + (pred - null_pred) * guidance_scale | ||
| sample = self.scheduler.step(pred, t, sample, return_dict=False)[0] |
There was a problem hiding this comment.
| sample = self.scheduler.step(pred, t, sample, return_dict=False)[0] | |
| sample = self.scheduler.step(pred, t, sample, return_dict=False)[0] | |
| if callback_on_step_end is not None: | |
| callback_kwargs = {} | |
| for k in callback_on_step_end_tensor_inputs: | |
| callback_kwargs[k] = locals()[k] | |
| callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) | |
| sample = callback_outputs.pop("latents", latent) | |
| text_condition = callback_outputs.pop("prompt_embeds", prompt_embeds) |
Example for supporting callbacks. This assumes we use the standard names latents and prompt_embeds (see #13390 (comment), #13390 (comment)). See also how e.g. Flux2Pipeline supports callbacks:
diffusers/src/diffusers/pipelines/flux2/pipeline_flux2.py
Lines 993 to 997 in dc8d903
| guidance_scale: float = 4.0, | ||
| generator: torch.Generator | list[torch.Generator] | None = None, | ||
| output_type: str = "np", | ||
| return_dict: bool = True, |
There was a problem hiding this comment.
| return_dict: bool = True, | |
| return_dict: bool = True, | |
| callback_on_step_end: Callable[[int, int], None] | None = None, | |
| callback_on_step_end_tensor_inputs: list[str] = ["latents"], |
Follow-up for callback support (see #13390 (comment)).
|
|
||
|
|
||
| class LongCatAudioDiTPipeline(DiffusionPipeline): | ||
| model_cpu_offload_seq = "text_encoder->transformer->vae" |
There was a problem hiding this comment.
| model_cpu_offload_seq = "text_encoder->transformer->vae" | |
| model_cpu_offload_seq = "text_encoder->transformer->vae" | |
| _callback_tensor_inputs = ["latents", "prompt_embeds"] |
Follow up for callback support (#13390 (comment)). The callback tests specifically check for the name latents here, which is one reason to use it over samples.
|
|
||
|
|
||
| class LongCatAudioDiTTransformer(ModelMixin, ConfigMixin): | ||
| _supports_gradient_checkpointing = False |
There was a problem hiding this comment.
| _supports_gradient_checkpointing = False | |
| _supports_gradient_checkpointing = False | |
| _repeated_blocks = ["AudioDiTBlock"] |
Setting _repeated_blocks here enables regional compilation support. This also allows us to not skip the TestLongCatAudioDiTTransformerCompile.test_torch_compile_repeated_blocks test.
What does this PR do?
Adds LongCat-AudioDiT model support to diffusers.
Although LongCat-AudioDiT can be used for TTS-like generation, it is fundamentally a diffusion-based audio generation model (text conditioning + iterative latent denoising + VAE decoding) rather than a conventional autoregressive TTS model, so i think it fits naturally into diffusers.
Test
Result
longcat.wav
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.