Demystifying Transition Matching

This repository contains the official implementation of Demystifying Transition Matching: When and Why It Can Beat Flow Matching, accepted to the 29th International Conference on Artificial Intelligence and Statistics (AISTATS 2026).

Flow Matching (FM) underpins many state-of-the-art generative models, yet Transition Matching (TM) can achieve higher sample quality with fewer steps. We answer when and why: for unimodal Gaussian targets, TM attains strictly lower KL divergence than FM at any finite step count, because its stochastic latent updates preserve target covariance that deterministic FM underestimates. For Gaussian mixtures with well-separated modes, the distribution is approximately locally unimodal within each component, and TM retains this advantage, explaining its strong performance in multimodal settings.

These theoretical gains translate to practice. Across image and video generation benchmarks, TM consistently achieves better quality under the same or lower compute budgets, reaching competitive or superior performance with fewer sampling steps.
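The univariate Gaussian KL divergence used in such step-count comparisons has a closed form. As a minimal illustration (not code from the paper) of how variance underestimation alone produces positive KL:

```python
import math

def kl_gaussian(mu0, var0, mu1, var1):
    """KL( N(mu0, var0) || N(mu1, var1) ) for univariate Gaussians."""
    return 0.5 * (var0 / var1 + (mu1 - mu0) ** 2 / var1
                  - 1.0 + math.log(var1 / var0))

# A sampler that underestimates the target variance (var0 < var1)
# incurs strictly positive KL even when the mean is exact:
print(kl_gaussian(0.0, 0.8, 0.0, 1.0))  # > 0
```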


Class-Conditioned Image Generation


Frame-Conditioned Video Generation

For detailed theoretical analysis and additional experiments, see the full paper.




πŸ“ Project Structure

TransitionFlowMatching/
β”œβ”€β”€ TM-Image/                        # Image generation module
β”‚   β”œβ”€β”€ main_mar.py                  # Training & inference entry point
β”‚   β”œβ”€β”€ main_cache.py                # VAE latent caching
β”‚   β”œβ”€β”€ engine_mar.py                # Training & evaluation engine
β”‚   β”œβ”€β”€ models/
β”‚   β”‚   β”œβ”€β”€ mar.py                   # MAR backbone
β”‚   β”‚   β”œβ”€β”€ dtm.py                   # Discrete Transition Matching
β”‚   β”‚   β”œβ”€β”€ ar_backbone.py           # Causal autoregressive backbone
β”‚   β”‚   β”œβ”€β”€ diffloss.py              # Diffusion head
β”‚   β”‚   └── vae.py                   # VAE architecture
β”‚   β”œβ”€β”€ diffusion/                   # Gaussian diffusion utilities
β”‚   β”œβ”€β”€ util/                        # Data loading, LR scheduling, misc
β”‚   └── fid_stats/                   # Pre-computed FID statistics
β”‚
β”œβ”€β”€ TM-Video/                        # Video generation module
β”‚   β”œβ”€β”€ main.py                      # Hydra-based entry point
β”‚   β”œβ”€β”€ algorithms/
β”‚   β”‚   β”œβ”€β”€ dfot/                    # Diffusion Forcing Transformer
β”‚   β”‚   β”œβ”€β”€ vae/                     # Video/image VAE
β”‚   β”‚   └── common/                  # Shared components & metrics
β”‚   β”œβ”€β”€ configurations/              # Hydra YAML configs
β”‚   β”‚   β”œβ”€β”€ algorithm/               # Model configs (token_video, etc.)
β”‚   β”‚   β”œβ”€β”€ dataset/                 # Dataset configs (Kinetics-600, etc.)
β”‚   β”‚   β”œβ”€β”€ experiment/              # Experiment configs
β”‚   β”‚   └── shortcut/                # Pre-built configs (@DiT/token_XL)
β”‚   β”œβ”€β”€ datasets/                    # Dataset implementations
β”‚   β”œβ”€β”€ experiments/                 # Training/evaluation scripts
β”‚   └── utils/                       # Checkpointing, W&B, distributed
β”‚
β”œβ”€β”€ requirements.txt
β”œβ”€β”€ LICENSE
β”œβ”€β”€ CODE_OF_CONDUCT.md
└── CONTRIBUTING.md

πŸ”§ Setup

Prerequisites

  • Python 3.10
  • CUDA-compatible GPU(s)
  • Conda package manager

Installation

conda create -n tm python=3.10
conda activate tm
pip install -r requirements.txt

πŸ–ΌοΈ Image Generation (TM-Image)

Navigate to ./TM-Image for the image generation module, built on top of MAR (Masked Autoregressive image generation).

Available Models

| Model | Description |
| --- | --- |
| mar_large | MAR backbone (large) |
| dtm_large | Discrete Transition Matching (large) |
| fm_large | Flow Matching baseline (large) |

Dataset & Pretrained Models

  1. Download the ImageNet dataset and place it in your IMAGENET_PATH.
  2. Download the pretrained VAE by running download_pretrained_vae() in download.py.

πŸ’Ύ (Optional) Caching VAE Latents

Since data augmentation only involves center cropping and random flipping, VAE latents can be pre-computed to speed up training:

torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
  main_cache.py \
  --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 \
  --batch_size 128 \
  --data_path ${IMAGENET_PATH} --cached_path ${CACHED_PATH}

πŸ‹οΈ Training

export NODE_RANK=0
export MASTER_ADDR=<your_master_address>
export MASTER_PORT=29500

torchrun --nproc_per_node=${N_GPU} --nnodes=${NNODES} \
  --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
  main_mar.py \
  --img_size 256 \
  --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
  --model ${MODEL} --diffloss_d ${DIFF_DEPTH} --diffloss_w ${DIFF_CH} \
  --epochs ${EPOCHS} --warmup_epochs 100 --batch_size 128 --diffusion_batch_mul 1 \
  --output_dir ${OUTPUT_DIR} \
  --use_cached --cached_path ${CACHED_PATH} --save_last_freq 100 \
  --blr 1.0e-4 --lr 0.0005 --T ${T}
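MAR-style codebases typically derive the effective learning rate from the base rate --blr via a linear scaling rule when --lr is not set explicitly. The helper below is an assumption for illustration only (effective_lr is hypothetical, not a function in this repo; check main_mar.py for the exact formula used):

```python
def effective_lr(blr, batch_size_per_gpu, num_gpus, base_batch=256):
    """Linear LR scaling rule common in MAE/MAR-style training
    (assumption: applied when --lr is not given explicitly)."""
    total_batch = batch_size_per_gpu * num_gpus
    return blr * total_batch / base_batch

# e.g. blr=1e-4 with per-GPU batch 128 on 8 GPUs -> total batch 1024:
print(effective_lr(1e-4, 128, 8))  # 0.0004
```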

Key arguments:

| Argument | Description | Default |
| --- | --- | --- |
| N_GPU | Number of GPUs per node | — |
| NNODES | Number of nodes | — |
| MODEL | Model type (mar_large, fm_large, dtm_large) | — |
| DIFF_DEPTH | Diffusion head depth | 6 |
| DIFF_CH | Diffusion head channel size | 1024 |
| EPOCHS | Training epochs | 500 |
| T | Discretized steps for TM | 128 |

πŸ“Š Inference & Evaluation

torchrun --nproc_per_node=8 --nnodes=1 main_mar.py \
  --model ${MODEL} --diffloss_d ${DIFF_DEPTH} --diffloss_w ${DIFF_CH} \
  --eval_bsz 128 --num_images 50000 \
  --cfg 4.0 --cfg_schedule constant --temperature 1.0 \
  --data_path ./dataset/imagenet_10k --evaluate \
  --output_dir ${OUTPUT_DIR} --resume ${CKPT_PATH}/checkpoint-400.pth \
  --num_iter ${NUM_ITER} --num_sampling_steps ${NUM_SAMPLING_STEP} --T ${T}

| Argument | Description |
| --- | --- |
| CKPT_PATH | Path to model checkpoint |
| NUM_ITER | Backbone sampling steps (N) |
| NUM_SAMPLING_STEP | Flow head sampling steps (S) |
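Since TM factorizes sampling into N backbone iterations, each running S flow-head steps, a rough cost model is N backbone passes plus N·S head passes. The helper below is illustrative only (forward_passes is hypothetical; the real compute split also depends on the relative sizes of backbone and head):

```python
def forward_passes(num_iter, num_sampling_steps):
    """Illustrative cost model: the backbone runs once per outer
    iteration (NUM_ITER), and the flow head runs NUM_SAMPLING_STEP
    steps inside each. Returns (backbone passes, head passes)."""
    backbone = num_iter
    head = num_iter * num_sampling_steps
    return backbone, head

# Same backbone count, fewer flow-head steps -> fewer head passes:
print(forward_passes(64, 4))  # (64, 256)
print(forward_passes(64, 1))  # (64, 64)
```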

🎬 Video Generation (TM-Video)

Navigate to ./TM-Video for the video generation module, built on top of Diffusion Forcing Transformer. Configuration is managed via Hydra.

Supported Datasets

| Dataset | Config key |
| --- | --- |
| Kinetics-600 | kinetics_600 |
| RealEstate10K | realestate10k |
| Minecraft | minecraft |

πŸ“‘ (Optional) Connect to Weights & Biases

We use Weights & Biases for experiment tracking. Sign up if you don't have an account, and modify wandb.entity in config.yaml to your user/organization name.

πŸ‹οΈ Training

python -m main +name=${TRAIN_RUNNAME} \
  dataset=kinetics_600 \
  algorithm=token_video \
  experiment=video_generation \
  @DiT/token_XL
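Hydra resolves dotted overrides such as dataset=kinetics_600 and nested keys like experiment.ema.enable=True into a single config tree. A simplified stand-in for intuition (apply_overrides is illustrative, not Hydra itself; real Hydra also handles type coercion, +/~ prefixes, config groups, and interpolation):

```python
def apply_overrides(cfg, overrides):
    """Minimal sketch of Hydra-style dotted overrides:
    'a.b=1' sets cfg['a']['b'] = '1' (values kept as strings here)."""
    for ov in overrides:
        path, _, value = ov.partition("=")
        node = cfg
        *parents, leaf = path.split(".")
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return cfg

cfg = apply_overrides({}, ["dataset.context_length=1",
                           "experiment.ema.enable=True"])
print(cfg)
# {'dataset': {'context_length': '1'}, 'experiment': {'ema': {'enable': 'True'}}}
```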

πŸ“Š Inference & Evaluation

python -m main +name=${EVAL_RUNNAME} \
  experiment.ema.enable=True \
  dataset.context_length=1 \
  algorithm.backbone.num_sampling_steps=${NUM_SAMPLING_STEPS} \
  algorithm.diffusion.sampling_timesteps=${SAMPLING_TIMESTEPS} \
  dataset=kinetics_600 \
  algorithm=token_video \
  experiment=video_generation \
  @DiT/token_XL \
  'experiment.tasks=[validation]' \
  experiment.validation.batch_size=50 \
  dataset.num_eval_videos=50000 \
  'algorithm.logging.metrics=[fvd, vbench, is, fid, lpips, mse, ssim, psnr]' \
  load=${CKPT}

| Argument | Description |
| --- | --- |
| dataset.context_length | Number of ground-truth frames for context |
| algorithm.diffusion.sampling_timesteps | TM backbone sampling steps (N) |
| algorithm.backbone.num_sampling_steps | TM flow head sampling steps (S) |
| load | Checkpoint file to load |

πŸ“ Evaluation Metrics

| Metric | Domain | Description |
| --- | --- | --- |
| FID | Image / Video | Fréchet Inception Distance |
| IS | Image / Video | Inception Score |
| FVD | Video | Fréchet Video Distance |
| VBench | Video | Comprehensive video quality assessment |
| LPIPS | Video | Learned Perceptual Image Patch Similarity |
| MSE | Video | Mean Squared Error |
| SSIM | Video | Structural Similarity Index |
| PSNR | Video | Peak Signal-to-Noise Ratio |
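MSE and PSNR in the table are directly related by a standard formula; a quick reference implementation (psnr here is a generic helper, not repo code):

```python
import math

def psnr(mse, max_val=1.0):
    """Peak Signal-to-Noise Ratio in dB for signals in [0, max_val].
    Standard formula: 10 * log10(max_val^2 / MSE)."""
    if mse == 0:
        return float("inf")  # identical signals
    return 10.0 * math.log10(max_val ** 2 / mse)

print(psnr(0.01))  # 20.0
```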

πŸ™ Acknowledgements

The image generation codebase is built upon MAR. We appreciate the authors for releasing their code.

For video generation, this repository builds on Boyuan Chen's research template repo. Per its license, please keep this attribution in README.md and the LICENSE file to credit the author.


πŸ“„ Citations

If you find this repository useful, please consider citing:

@article{kim2025demystifying,
  title={Demystifying Transition Matching: When and Why It Can Beat Flow Matching},
  author={Kim, Jaihoon and Saha, Rajarshi and Sung, Minhyuk and Park, Youngsuk},
  journal={arXiv preprint arXiv:2510.17991},
  year={2025}
}

@article{li2024autoregressive,
  title={Autoregressive Image Generation without Vector Quantization},
  author={Li, Tianhong and Tian, Yonglong and Li, He and Deng, Mingyang and He, Kaiming},
  journal={arXiv preprint arXiv:2406.11838},
  year={2024}
}

@misc{song2025historyguidedvideodiffusion,
  title={History-Guided Video Diffusion}, 
  author={Kiwhan Song and Boyuan Chen and Max Simchowitz and Yilun Du and Russ Tedrake and Vincent Sitzmann},
  year={2025},
  eprint={2502.06764},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2502.06764}, 
}

πŸ”’ Security

See CONTRIBUTING for more information.

βš–οΈ License

This project is licensed under the Apache-2.0 License.
