Commit bc989eb

vdeborto authored and Hackable Diffusion Authors committed
RFM: Update Riemannian Documentation
PiperOrigin-RevId: 890469073
1 parent 0805a54 commit bc989eb

5 files changed: +292 −24 lines changed

docs/architecture.md

Lines changed: 46 additions & 3 deletions
@@ -140,14 +140,57 @@ print(f"Output shape: {output.shape}")
For simpler, non-image data, a `ConditionalMLP` backbone is provided. It
processes the input `x`, combines it with conditioning embeddings, and passes it
through a series of dense layers. This module is mainly used for testing
purposes.

### `RiemannianConditionalBackbone`

(`lib/architecture/riemannian.py`)

A specialized wrapper for any `ConditionalBackbone` that handles Riemannian
manifold constraints. Its primary role is to ensure that the model's output
`velocity` is a valid **tangent vector** at the point `xt`.

This is achieved by applying the manifold's **`project`** operator to the raw
output of the underlying backbone.

#### Riemannian Projections

Each manifold defines a `project(x, v)` method that ensures the output $$v$$ is
a valid tangent vector at point $$x$$.

* **Sphere ($$S^d$$)**: The projection is
  $$v_{\text{tangent}} = v - \langle x, v \rangle x$$, which removes the
  component of $$v$$ parallel to $$x$$.
* **SO(3)**: The projection maps a $$3 \times 3$$ matrix $$V$$ to the tangent
  space $$T_R SO(3)$$ by computing the skew-symmetric part of the relative
  velocity in the Lie algebra: $$R \cdot \text{skew}(R^T V)$$, where
  $$\text{skew}(\Omega) = 0.5(\Omega - \Omega^T)$$.

By wrapping a standard neural network (e.g., a UNet) in this backbone, we can
learn complex velocity fields on manifolds using standard architectures.
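
Both projections fit in a few lines. The following is a minimal NumPy sketch
(the library itself is JAX-based, and `project_sphere` / `project_so3` are
illustrative names, not the library API):

```python
import numpy as np

def project_sphere(x, v):
    # v_tangent = v - <x, v> x: drop the component of v parallel to x.
    return v - np.dot(x, v) * x

def project_so3(R, V):
    # Skew-symmetric part of the body-frame velocity: R @ skew(R^T V).
    omega = R.T @ V
    return R @ (0.5 * (omega - omega.T))

x = np.array([0.0, 0.0, 1.0])            # point on S^2
v = np.array([1.0, 2.0, 3.0])            # arbitrary ambient vector
v_tan = project_sphere(x, v)
assert abs(np.dot(x, v_tan)) < 1e-12     # tangency: <x, v_tan> = 0

R = np.eye(3)                            # identity rotation
V = np.arange(9.0).reshape(3, 3)
V_tan = project_so3(R, V)
assert np.allclose(V_tan, -V_tan.T)      # R^T V_tan is skew-symmetric (R = I here)
```

Note that for the sphere the projection is idempotent: projecting an already
tangent vector leaves it unchanged.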

#### Example Usage

```python
from hackable_diffusion.lib import manifolds
from hackable_diffusion.lib.architecture.riemannian import RiemannianConditionalBackbone
from hackable_diffusion.lib.architecture.mlp import ConditionalMLP

# 1. Choose a manifold
manifold = manifolds.Sphere()

# 2. Create a standard backbone
mlp = ConditionalMLP(num_features=256, num_layers=3)

# 3. Wrap it in a RiemannianConditionalBackbone
model = RiemannianConditionalBackbone(
    backbone=mlp,
    manifold=manifold,
)
```

The conditioning mechanism is simpler here, limited to `SUM` or `CONCATENATE` of
the conditioning embeddings with the intermediate representation of `x`.

### `MultiHeadAttention`

(`lib/architecture/attention.py`)

docs/corruption.md

Lines changed: 125 additions & 5 deletions
@@ -16,13 +16,14 @@ various corruption processes for both continuous and discrete data.

The main components are:

* **`CorruptionProcess` Protocol**: An interface that standardizes how
  corruption is applied.
* **Schedules**: Functions that define the rate and nature of corruption over
  time `t`.
* **Process Implementations**: Concrete classes like `GaussianProcess` for
  continuous data (e.g., images), `CategoricalProcess` for discrete data
  (e.g., labels, tokens), and `RiemannianProcess` for data on Riemannian
  manifolds.


## `CorruptionProcess` Protocol

@@ -230,3 +231,122 @@ print(f"Logits target shape: {target_info['logits'].shape}")

* The model prediction for discrete data is expected to be logits over the
  categories. `convert_predictions` will then convert these logits to a
  predicted `x0` (via argmax).

## `RiemannianProcess`

(`lib/corruption/riemannian.py`)

This process implements **Riemannian Flow Matching (RFM)**, a generalization of
Flow Matching to smooth Riemannian manifolds. Unlike standard diffusion, which
relies on Gaussian noise, RFM uses the manifold's intrinsic geometry to
interpolate between data and noise distributions.

### Mathematical Foundations: Continuous-time Flow Matching

Let $$(\mathcal{M}, g)$$ be a $$d$$-dimensional smooth Riemannian manifold. A
probability path $$p_t$$ on $$\mathcal{M}$$ can be defined via the **Continuity
Equation**:

$$\frac{\partial p_t}{\partial t} + \operatorname{div}_g (p_t v_t) = 0$$

where $$\operatorname{div}_g$$ is the Riemannian divergence operator and
$$v_t \in T_x \mathcal{M}$$ is a time-dependent vector field. Riemannian Flow
Matching aims to find a vector field $$v_\theta(x, t)$$ that generates a path
$$p_t$$ such that $$p_0$$ is the data distribution and $$p_1$$ is an invariant
noise distribution.

### Riemannian Concepts: Exp, Log, and Geodesics

The geometry of the manifold is abstracted through three key operations
implemented in `lib/manifolds.py`:

#### 1. Exponential Mapping ($$\text{Exp}_x$$)

The exponential map $$\text{Exp}_x : T_x \mathcal{M} \to \mathcal{M}$$ maps a
tangent vector $$v \in T_x \mathcal{M}$$ onto the manifold. Intuitively, if you
start at point $$x$$ and walk in the direction of $$v$$ for unit time along the
unique "straightest" path (geodesic), you arrive at $$\text{Exp}_x(v)$$.

In the library, this is used during **sampling** (to move from $$x_t$$ to
$$x_{t-dt}$$) and to construct geodesics.

#### 2. Logarithm Mapping ($$\text{Log}_x$$)

The logarithm map $$\text{Log}_x : \mathcal{M} \to T_x \mathcal{M}$$ is the
inverse of the exponential map (where defined). Given two points
$$x, y \in \mathcal{M}$$, $$\text{Log}_x(y)$$ returns the tangent vector at
$$x$$ that points toward $$y$$ along the shortest geodesic. The length of this
vector equals the Riemannian distance between the two points:
$$\|\text{Log}_x(y)\|_g = d_g(x, y)$$.

In the library, this is used during **training** to find the direction of the
conditional flow between noise and data.
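
For intuition, the sphere's Exp and Log maps can be written in a few lines of
NumPy. This is an illustrative sketch only; the library's implementations live
in `lib/manifolds.py` and are JAX-based with extra numerical safeguards:

```python
import numpy as np

def sphere_exp(x, v):
    # Walk from x along tangent vector v for unit time on S^d.
    theta = np.linalg.norm(v)
    if theta < 1e-12:
        return x
    return np.cos(theta) * x + np.sin(theta) / theta * v

def sphere_log(x, y):
    # Tangent vector at x pointing toward y, with length d_g(x, y).
    cos_theta = np.clip(np.dot(x, y), -1.0, 1.0)
    theta = np.arccos(cos_theta)
    if theta < 1e-12:
        return np.zeros_like(x)
    return theta / np.sin(theta) * (y - cos_theta * x)

x = np.array([1.0, 0.0, 0.0])
y = np.array([0.0, 1.0, 0.0])
v = sphere_log(x, y)
assert np.isclose(np.linalg.norm(v), np.pi / 2)   # geodesic distance from x to y
assert np.allclose(sphere_exp(x, v), y)           # Exp inverts Log
```

The round-trip check at the end is exactly the defining relationship between
the two maps.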

#### 3. Geodesic Mapping ($$\gamma$$)

A geodesic is the generalization of a straight line to curved spaces. The
unique geodesic path starting at $$x$$ and ending at $$y$$ can be parameterized
by $$t \in [0, 1]$$ as:

$$\gamma(t) = \text{Exp}_x(t \cdot \text{Log}_x(y))$$

This mapping ensures that the interpolation between distributions stays on the
manifold and follows the shortest possible paths, which is the cornerstone of
Riemannian Flow Matching.
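
A self-contained NumPy sketch of this parameterization on the sphere (the
helpers repeat the Exp/Log formulas above; names are illustrative, not the
library API):

```python
import numpy as np

def sphere_exp(x, v):
    theta = np.linalg.norm(v)
    return x if theta < 1e-12 else np.cos(theta) * x + np.sin(theta) / theta * v

def sphere_log(x, y):
    c = np.clip(np.dot(x, y), -1.0, 1.0)
    theta = np.arccos(c)
    return np.zeros_like(x) if theta < 1e-12 else theta / np.sin(theta) * (y - c * x)

def geodesic(x, y, t):
    # gamma(t) = Exp_x(t * Log_x(y)): constant-speed shortest path on S^d.
    return sphere_exp(x, t * sphere_log(x, y))

x = np.array([1.0, 0.0, 0.0])
y = np.array([0.0, 0.0, 1.0])
mid = geodesic(x, y, 0.5)
assert np.isclose(np.linalg.norm(mid), 1.0)        # stays on the sphere
assert np.isclose(np.dot(mid, x), np.dot(mid, y))  # midpoint is equidistant
```

Unlike straight-line interpolation in the ambient space, every point of
`geodesic(x, y, t)` remains on the manifold.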

### The Riemannian Flow Matching Loss

$$\mathcal{L}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1], x_0 \sim p_0, x_1 \sim p_1} [ \| v_{\theta}(x_t, t) - u_t(x_t | x_0, x_1) \|_{g}^2 ]$$

where the conditional velocity field $$u_t(x|x_0, x_1)$$ is derived from a
conditional probability path $$p_t(x|x_0, x_1)$$ that satisfies the continuity
equation. In this library, we use **geodesic paths** for the conditional
interpolation:

1. **Conditional Path**: $$x_t = \text{Exp}_{x_1}(\alpha(t) \text{Log}_{x_1}(x_0))$$
2. **Conditional Velocity**: $$u_t(x_t | x_0, x_1) = \dot{\alpha}(t) \cdot \frac{d}{ds} \text{Exp}_{x_1}(s \text{Log}_{x_1}(x_0)) \big|_{s=\alpha(t)}$$

For the standard `LinearRiemannianSchedule`, $$\alpha(t) = 1 - t$$: at $$t = 0$$
we have $$\alpha = 1$$ and $$x_{t=0} = x_0$$ (clean data), while at $$t = 1$$ we
have $$\alpha = 0$$ and $$x_{t=1} = x_1$$ (noise). The path therefore flows from
data to noise as $$t$$ increases, matching the library's convention that
$$t = 0$$ is clean data and $$t = 1$$ is noise.
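
The conditional path and its velocity can be checked numerically on the sphere
under the $$\alpha(t) = 1 - t$$ convention. A hedged NumPy sketch, where the
helper names and the finite-difference velocity are illustrative rather than
the library implementation:

```python
import numpy as np

def sphere_exp(x, v):
    theta = np.linalg.norm(v)
    return x if theta < 1e-12 else np.cos(theta) * x + np.sin(theta) / theta * v

def sphere_log(x, y):
    c = np.clip(np.dot(x, y), -1.0, 1.0)
    theta = np.arccos(c)
    return np.zeros_like(x) if theta < 1e-12 else theta / np.sin(theta) * (y - c * x)

def x_t(t, x0, x1):
    # Conditional path with alpha(t) = 1 - t: clean data at t=0, noise at t=1.
    return sphere_exp(x1, (1.0 - t) * sphere_log(x1, x0))

x0 = np.array([1.0, 0.0, 0.0])   # data sample
x1 = np.array([0.0, 1.0, 0.0])   # noise sample
t, eps = 0.3, 1e-6
xt = x_t(t, x0, x1)
u_fd = (x_t(t + eps, x0, x1) - x_t(t - eps, x0, x1)) / (2 * eps)  # finite-difference u_t
assert np.allclose(x_t(0.0, x0, x1), x0)      # t=0 recovers clean data
assert np.allclose(x_t(1.0, x0, x1), x1)      # t=1 recovers noise
assert abs(np.dot(xt, u_fd)) < 1e-5           # velocity is tangent at x_t
```

The endpoint checks make the direction of the flow explicit, and the last
assertion shows why the regression target is always a tangent vector.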

### Supported Manifolds (`lib/manifolds.py`)

Each manifold implements the `Manifold` protocol, providing core geometric
operations with an emphasis on numerical stability.

#### 1. Unit Hypersphere ($$S^d$$)

Points $$x \in \mathbb{R}^{d+1}$$ such that $$\|x\|_2 = 1$$. The tangent space
$$T_x S^d$$ is the subspace $$\{v \in \mathbb{R}^{d+1} \mid \langle x, v \rangle = 0\}$$.

* **Exp**: $$\text{Exp}_x(v) = \cos(\|v\|)x + \text{sinc}(\|v\|)v$$
* **Log**: $$\text{Log}_x(y) = \frac{\theta}{\sin \theta}(y - \cos \theta \, x)$$, where $$\theta = \arccos(\langle x, y \rangle)$$
* **Velocity**: The time-derivative along the geodesic:
  $$u_t = -\theta \sin(\theta t)x_1 + \cos(\theta t) \text{Log}_{x_1}(x_0)$$

The implementation uses an **unnormalized sinc trick**
($$\text{sinc}(x) = \frac{\sin x}{x}$$) to handle the singularity at
$$\theta = 0$$ gracefully.
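
The sinc trick can be demonstrated in isolation. The naive $$\sin(x)/x$$ hits
0/0 at the origin, while a guarded form returns the correct limit. This sketch
uses `np.sinc` (which NumPy defines as $$\sin(\pi x)/(\pi x)$$) as one possible
safe implementation; the library may guard the singularity differently:

```python
import numpy as np

def naive_sinc(x):
    return np.sin(x) / x                 # 0/0 -> nan at x = 0

def safe_sinc(x):
    # Unnormalized sinc via np.sinc, which computes sin(pi*x)/(pi*x).
    return np.sinc(x / np.pi)

with np.errstate(invalid="ignore"):
    assert np.isnan(naive_sinc(np.float64(0.0)))   # naive form blows up
assert safe_sinc(0.0) == 1.0                       # safe form returns the limit
assert np.isclose(safe_sinc(0.5), np.sin(0.5) / 0.5)
```

Away from zero both forms agree to machine precision, so the trick changes
nothing except at the singularity.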

#### 2. Special Orthogonal Group ($$SO(3)$$)

Points $$R$$ are $$3 \times 3$$ rotation matrices. The tangent space
$$T_R SO(3)$$ is isomorphic to the Lie algebra $$\mathfrak{so}(3)$$ of
skew-symmetric matrices via $$R \cdot \omega^\wedge$$.

* **Exp**: Computed via **Rodrigues' rotation formula**:
  $$\text{Exp}_R(v) = R (I + \text{sinc}(\theta)\omega^\wedge + \text{cosc}(\theta)(\omega^\wedge)^2)$$, where $$\theta = \|\omega\|$$.
* **Log**: Maps $$R_1^T R_0$$ to its rotation axis and angle $$\theta$$.
* **Velocity**: $$u_t = x_t \cdot \text{Log}(x_1^T x_0)$$.

The library uses a safe **cosc trick**
($$\text{cosc}(x) = \frac{1 - \cos x}{x^2} = \frac{1}{2} \text{sinc}(\frac{x}{2})^2$$)
to ensure numerical stability in the Rodrigues formula.
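
A NumPy sketch of the Rodrigues formula at the identity ($$R = I$$), using the
sinc/cosc parameterization above; `hat` and `so3_exp` are illustrative names,
not the library API:

```python
import numpy as np

def hat(omega):
    # omega^: maps a 3-vector to the corresponding skew-symmetric matrix.
    wx, wy, wz = omega
    return np.array([[0.0, -wz, wy],
                     [wz, 0.0, -wx],
                     [-wy, wx, 0.0]])

def so3_exp(omega):
    # Rodrigues' formula: I + sinc(theta) W + cosc(theta) W^2, W = hat(omega).
    theta = np.linalg.norm(omega)
    sinc = np.sinc(theta / np.pi)                   # sin(theta)/theta, safe at 0
    cosc = 0.5 * np.sinc(theta / (2 * np.pi)) ** 2  # (1-cos(theta))/theta^2, safe at 0
    W = hat(omega)
    return np.eye(3) + sinc * W + cosc * (W @ W)

R = so3_exp(np.array([0.0, 0.0, np.pi / 2]))   # quarter turn about the z-axis
assert np.allclose(R @ R.T, np.eye(3))          # orthogonal
assert np.isclose(np.linalg.det(R), 1.0)        # proper rotation
assert np.allclose(R @ np.array([1.0, 0.0, 0.0]), [0.0, 1.0, 0.0])
```

Note the cosc coefficient reuses the identity
$$\frac{1-\cos x}{x^2} = \frac{1}{2}\text{sinc}(\frac{x}{2})^2$$, so both
coefficients stay finite as $$\theta \to 0$$.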

#### 3. Flat Torus ($$[0, 1]^d$$)

The torus is a flat space with periodic boundary conditions.

* **Metric**: Standard Euclidean metric $$g = I$$.
* **Geodesics**: Straight lines modulo 1.
* **Velocity**: Constant velocity $$u = \text{Log}_{x_1}(x_0) = ((x_0 - x_1 + 0.5) \bmod 1) - 0.5$$.
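
The wrap-around behaviour of the torus Log is easy to check numerically. A
small sketch (`torus_log` is an illustrative name for the formula above):

```python
import numpy as np

def torus_log(x1, x0):
    # Shortest signed displacement from x1 to x0 on [0, 1)^d with wrap-around.
    return np.mod(x0 - x1 + 0.5, 1.0) - 0.5

assert np.isclose(torus_log(0.9, 0.1), 0.2)    # wraps through 1.0, not back through 0.5
assert np.isclose(torus_log(0.1, 0.9), -0.2)   # symmetric in the other direction
assert np.isclose(torus_log(0.2, 0.4), 0.2)    # no wrap needed
```

The `+0.5 ... -0.5` shift centers the modular difference in $$[-0.5, 0.5)$$, so
the returned displacement always corresponds to the shortest geodesic.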

### Example Usage

```python
import jax
import jax.numpy as jnp

from hackable_diffusion.lib import manifolds
from hackable_diffusion.lib.corruption.riemannian import RiemannianProcess
from hackable_diffusion.lib.corruption.schedules import LinearRiemannianSchedule

# 1. Define manifold and process
manifold = manifolds.Sphere()
schedule = LinearRiemannianSchedule()
process = RiemannianProcess(manifold=manifold, schedule=schedule)

# 2. Corrupt data
subkey = jax.random.PRNGKey(0)
x0 = jnp.array([[1.0, 0.0, 0.0]])  # Point on S2
time = jnp.array([0.5])
xt, target_info = process.corrupt(subkey, x0, time)

# target_info['velocity'] is the regression target u_t
```

docs/index.md

Lines changed: 11 additions & 6 deletions
@@ -66,8 +66,9 @@ system for encoding and injecting conditioning signals via

This module defines the **forward process** of diffusion. It includes
implementations for corrupting data with noise, such as `GaussianProcess` for
continuous data, `CategoricalProcess` for discrete data, and `RiemannianProcess`
for data on Riemannian manifolds (e.g., Sphere, SO(3), Torus). It also defines
the noise `schedules` that govern the corruption over time.

### [Inference Function](./inference.md)

@@ -101,14 +102,18 @@ The `notebooks/` directory contains a set of tutorials that demonstrate how to
use the library to train and sample from diffusion models. These serve as
excellent starting points for understanding the library's components in action.

* **`2d_training.ipynb`**: A minimal example that trains a diffusion model on
  a simple 2D toy dataset.
* **`mnist.ipynb`**: Trains a standard continuous diffusion model (Gaussian
  process) on the MNIST dataset, demonstrating image data handling.
* **`mnist_discrete.ipynb`**: Trains a discrete diffusion model on MNIST,
  treating pixel values as categorical data. This showcases the use of
  `CategoricalProcess`.
* **`mnist_multimodal.ipynb`**: A more advanced example that trains a
  multimodal model to jointly generate MNIST images with discrete and
  continuous diffusion models, demonstrating the "Nested" design pattern in a
  practical setting.
* **`riemannian_sphere_training.ipynb`**: Demonstrates Riemannian Flow
  Matching on the unit sphere S^2.
* **`riemannian_torus_ode_to_sde.ipynb`**: Shows how to use Riemannian Flow
  Matching on the torus manifold for both ODE and SDE sampling.

docs/loss.md

Lines changed: 36 additions & 0 deletions
@@ -158,3 +158,39 @@ loss requires a `DiscreteSchedule`.

This is a concrete implementation that computes discrete diffusion loss without
any weighting (i.e. weight=1).

---

## Riemannian Flow Matching Loss

Training a Riemannian Flow Matching (RFM) model requires a loss function that
respects the intrinsic geometry of the manifold $$(\mathcal{M}, g)$$.

### Metric-Aware Loss

The **Riemannian Flow Matching loss** is defined as the squared norm of the
difference between the model's velocity prediction $$v_\theta$$ and the true
geodesic velocity $$u_t$$:

$$\mathcal{L}(\theta) = \mathbb{E}_{t, x_0, x_1} [ \| v_{\theta}(x_t, t) - u_t(x_t | x_0, x_1) \|_{g}^2 ]$$

where the norm is induced by the Riemannian metric $$g$$ at point $$x_t$$:

$$\| v \|_{g} = \sqrt{g_{x_t}(v, v)}$$

### Implementation for Embedded Manifolds

For many manifolds implemented in this library (like the Sphere $$S^d$$ or
$$SO(3)$$), the Riemannian metric is induced by the standard Euclidean metric
of the ambient space $$\mathbb{R}^n$$. In these cases, the loss simplifies to:

$$\mathcal{L}(\theta) = \mathbb{E}_{t, x_0, x_1} [ \| v_{\theta}(x_t, t) - u_t(x_t | x_0, x_1) \|_{2}^2 ]$$

**Crucially**, this equivalence only holds if $$v_{\theta}$$ and $$u_t$$ are
both valid **tangent vectors** (i.e., $$v, u \in T_{x_t} \mathcal{M}$$). The
library ensures this via:

1. **True Target**: The `RiemannianProcess` returns a $$u_t$$ that is
   mathematically guaranteed to be tangent to the manifold.
2. **Model Prediction**: The **`RiemannianConditionalBackbone`** (see
   [Architecture docs](./architecture.md)) acts as a wrapper that projects the
   raw model output onto the tangent space $$T_{x_t} \mathcal{M}$$ before
   computing the loss.

By enforcing the tangent space constraint, the RFM objective can be optimized
using standard MSE loss while remaining geometrically rigorous.
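
A minimal NumPy sketch of this embedded-manifold loss on the sphere, assuming
the target is already tangent; `project_sphere` and `rfm_loss` are illustrative
names, not the library API:

```python
import numpy as np

def project_sphere(x, v):
    # Project an ambient vector onto the tangent space at x (unit sphere).
    return v - np.sum(x * v, axis=-1, keepdims=True) * x

def rfm_loss(v_pred, u_target, x_t):
    # Euclidean MSE between the tangent-projected prediction and the target.
    v_tan = project_sphere(x_t, v_pred)
    return np.mean(np.sum((v_tan - u_target) ** 2, axis=-1))

x_t = np.array([[0.0, 0.0, 1.0]])        # batch of one point on S^2
u_t = np.array([[0.3, -0.2, 0.0]])       # tangent target (<x_t, u_t> = 0)
raw = np.array([[0.3, -0.2, 5.0]])       # raw network output, off the tangent space
assert np.isclose(rfm_loss(raw, u_t, x_t), 0.0)  # normal component is projected away
```

The final assertion illustrates the point made above: once the prediction is
projected, only its tangential component enters the MSE, so the loss is blind
to any off-manifold component of the raw output.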
