# Flow Matching in PyTorch
This repository contains a simple PyTorch implementation of the paper [Flow Matching for Generative Modeling](https://arxiv.org/abs/2210.02747).
## 2D Flow Matching Example
The GIF below shows a single Gaussian distribution being mapped to a checkerboard distribution, with the vector field visualized.
Here is another example, using the `moons` dataset.
## Getting Started
Clone the repository and set up the Python environment.
```bash
git clone https://github.com/keishihara/flow-matching.git
cd flow-matching
```
Make sure you have Python 3.10+ installed.
To set up the Python environment using `uv`:
```bash
uv sync
source .venv/bin/activate
```
Alternatively, using `pip`:
```bash
python -m venv .venv
source .venv/bin/activate
pip install -e .
```
## Conditional Flow Matching [Lipman+ 2023]
This section implements Conditional Flow Matching (CFM) as described in the original paper [1]. Some components of the code are adapted from [3] and [4].
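For reference, the CFM training objective with the optimal-transport conditional path from [1] can be sketched in plain Python. Each data point `x1` is paired with Gaussian noise `x0` and a time `t ~ U[0, 1]`, and the model regresses onto the target velocity `u_t = x1 - (1 - sigma_min) * x0` at `x_t = (1 - (1 - sigma_min) * t) * x0 + t * x1`. This is a minimal sketch under stated assumptions: the names `cfm_pair` and `cfm_loss`, the list-based vectors, and the default `sigma_min` are hypothetical, and the training scripts in this repository operate on PyTorch tensors and may use a different path or `sigma_min`.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]


def cfm_pair(
    x0: Sequence[float], x1: Sequence[float], t: float, sigma_min: float = 1e-4
) -> Tuple[Vector, Vector]:
    """Point x_t on the OT conditional path of [1] and its target velocity u_t."""
    scale = 1.0 - (1.0 - sigma_min) * t
    x_t = [scale * a + t * b for a, b in zip(x0, x1)]
    # For this path the target velocity is constant in t.
    u_t = [b - (1.0 - sigma_min) * a for a, b in zip(x0, x1)]
    return x_t, u_t


def cfm_loss(
    velocity: Callable[[float, Vector], Vector],
    batch: Sequence[Sequence[float]],
    sigma_min: float = 1e-4,
    rng: Optional[random.Random] = None,
) -> float:
    """Monte Carlo estimate of E ||v(t, x_t) - u_t||^2 over one batch of data points."""
    rng = rng or random.Random(0)  # fixed seed by default, for reproducibility
    total = 0.0
    for x1 in batch:
        x0 = [rng.gauss(0.0, 1.0) for _ in x1]  # noise sample x0 ~ N(0, I)
        t = rng.random()  # t ~ U[0, 1)
        x_t, u_t = cfm_pair(x0, x1, t, sigma_min)
        v = velocity(t, x_t)
        total += sum((vi - ui) ** 2 for vi, ui in zip(v, u_t))
    return total / len(batch)
```

With a neural network in place of `velocity`, minimizing this loss over many batches is the whole training loop. A small `sigma_min` leaves a Gaussian of standard deviation `sigma_min` around each data point at `t = 1`.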
### 2D Toy Datasets
You can train CFM models on 2D synthetic datasets such as `checkerboard` and `moons`. Specify the dataset name with the `--dataset` option. Training parameters are predefined in the script, and visualizations of the training results are saved to the `outputs/` directory. Model checkpoints are not included in the repository, since they are easily reproduced with the default settings.
```bash
python train_flow_matching_2d.py --dataset checkerboard
```
The vector fields and generated samples, like the GIFs shown at the top of this README, are saved to the `outputs/cfm/` directory.
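The exact data generator used by `train_flow_matching_2d.py` is not shown here. As an illustration, a checkerboard distribution is commonly sampled by drawing a uniform column coordinate and then placing the row coordinate in one of two cells, shifted by one cell on odd columns. The function `sample_checkerboard`, the `[-4, 4)` range, and the cell size are hypothetical choices for this sketch, not the repository's settings.

```python
import math
import random
from typing import List, Optional, Tuple


def sample_checkerboard(
    n: int, rng: Optional[random.Random] = None
) -> List[Tuple[float, float]]:
    """Draw n points uniformly from the 8 occupied squares of a 4x4 checkerboard on [-4, 4)^2."""
    rng = rng or random.Random(0)
    points = []
    for _ in range(n):
        x = 4.0 * rng.random() - 2.0  # column coordinate in [-2, 2)
        # Pick one of two unit cells per column: [0, 1) or [-2, -1) ...
        y = rng.random() - 2.0 * rng.randint(0, 1)
        # ... shifted up by one cell on odd columns, which produces the checker pattern.
        y += math.floor(x) % 2
        points.append((2.0 * x, 2.0 * y))  # scale each cell to side length 2
    return points
```

Every generated point lands on a square whose row and column indices have an even sum, which is what gives the pattern its checkerboard look.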
### Image Datasets
You can also train class-conditional CFM models on popular image classification datasets. Both the generated samples and the model checkpoints are saved to the `outputs/cfm` directory. For a full list of training parameters, run `python train_flow_matching_on_image.py --help`.
To train a class-conditional CFM on the MNIST dataset, run:
```bash
python train_flow_matching_on_image.py --do_train --dataset mnist
```
After training, generate samples with:
```bash
python train_flow_matching_on_image.py --do_sample --dataset mnist
```
The generated samples are saved to the `outputs/cfm/mnist/` directory.
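Conceptually, a trained CFM model defines a velocity field, and generating a sample means integrating the ODE `dx/dt = v(t, x)` from Gaussian noise at `t = 0` to `t = 1`. The solver and step count used by `--do_sample` are not documented here, so the fixed-step Euler scheme below is an assumption (an adaptive or higher-order solver may be used instead), and `euler_sample` is a hypothetical helper.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def euler_sample(
    velocity: Callable[[float, Vector], Vector],
    x0: Sequence[float],
    num_steps: int = 100,
) -> Vector:
    """Integrate dx/dt = velocity(t, x) from t = 0 to t = 1 with fixed-step Euler."""
    x = list(x0)
    dt = 1.0 / num_steps
    for k in range(num_steps):
        t = k * dt  # left endpoint of the current step
        v = velocity(t, x)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x
```

For a class-conditional model, the label would typically be bound into `velocity` with a closure, for example `lambda t, x: model(t, x, label)`, where `model` is a stand-in for the trained network.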
## Rectified Flow [Liu+ 2023]
This is an implementation of the Reflow model (specifically, 2-Rectified Flow) from the Rectified Flow paper [2].
### 2D Synthetic Data
Reflow is implemented for the same 2D synthetic datasets as CFM. Because Reflow is trained on pairs of noise and samples generated by a pretrained CFM model (a form of distillation), you must specify a pretrained CFM checkpoint to train it.
For example, to train on the `checkerboard` dataset with a pretrained CFM checkpoint:
```bash
python train_reflow_2d.py --dataset checkerboard --pretrained-model outputs/cfm/checkerboard/ckpt.pth
```
The training results, including vector field visualizations and generated samples, are saved to the `outputs/reflow/` directory.
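The Reflow step described in [2] can be sketched as follows. The pretrained CFM model acts as a fixed teacher: each noise sample `x0` is pushed through the teacher's ODE to get a coupled endpoint `x1`, and the new model is then trained on the straight line between them, `x_t = (1 - t) * x0 + t * x1`, with target velocity `x1 - x0`. The helper names, the Euler integrator, and `num_steps` are illustrative assumptions, not the internals of `train_reflow_2d.py`.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]
Velocity = Callable[[float, Vector], Vector]


def integrate(velocity: Velocity, x0: Sequence[float], num_steps: int) -> Vector:
    """Fixed-step Euler integration of dx/dt = velocity(t, x) over t in [0, 1]."""
    x, dt = list(x0), 1.0 / num_steps
    for k in range(num_steps):
        v = velocity(k * dt, x)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


def make_reflow_pairs(
    teacher: Velocity,
    dim: int,
    n: int,
    num_steps: int = 100,
    rng: Optional[random.Random] = None,
) -> List[Tuple[Vector, Vector]]:
    """Couple each noise sample x0 with the teacher's ODE output x1 = ODE(x0)."""
    rng = rng or random.Random(0)
    pairs = []
    for _ in range(n):
        x0 = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        pairs.append((x0, integrate(teacher, x0, num_steps)))
    return pairs


def reflow_target(
    x0: Sequence[float], x1: Sequence[float], t: float
) -> Tuple[Vector, Vector]:
    """Straight-line interpolation x_t and its constant target velocity x1 - x0."""
    x_t = [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]
    u_t = [b - a for a, b in zip(x0, x1)]
    return x_t, u_t
```

The student is then fit with the same squared-error regression as CFM, but on these fixed `(x0, x1)` couplings instead of independently drawn noise and data.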
### Comparison of the Sampling Processes of CFM and Reflow
To compare CFM and Reflow on 2D datasets, run:
```bash
python plot_comparison_2d.py --dataset checkerboard
```
The resulting GIFs are saved to the `outputs/comparisons/` directory. Below is an example comparison of the two methods on the `checkerboard` dataset:
## References
- [1] Lipman, Yaron, et al. "Flow Matching for Generative Modeling." [arXiv:2210.02747](https://arxiv.org/abs/2210.02747)
- [2] Liu, Xingchao, et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." [arXiv:2209.03003](https://arxiv.org/abs/2209.03003)
- [3] [facebookresearch/flow_matching](https://github.com/facebookresearch/flow_matching)
- [4] [atong01/conditional-flow-matching](https://github.com/atong01/conditional-flow-matching)