Skip to content

Commit 7fc0439

Browse files
agalashovHackable Diffusion Authors
authored andcommitted
AR Diffusion sampler for hackable_diffusion
PiperOrigin-RevId: 922902317
1 parent 0093c0c commit 7fc0439

3 files changed

Lines changed: 871 additions & 1 deletion

File tree

hackable_diffusion/lib/sampling/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"""Sampling."""
1616

1717
# pylint: disable=g-importing-member
18+
from hackable_diffusion.lib.sampling.ar_diffusion_sampler import ARStateHandler
19+
from hackable_diffusion.lib.sampling.ar_diffusion_sampler import AutoregressiveDiffusionSampler
1820
from hackable_diffusion.lib.sampling.base import DiffusionStep
1921
from hackable_diffusion.lib.sampling.base import DiffusionStepTree
2022
from hackable_diffusion.lib.sampling.base import SamplerStep
@@ -32,8 +34,8 @@
3234
from hackable_diffusion.lib.sampling.discrete_step_sampler import NoRemaskingFn
3335
from hackable_diffusion.lib.sampling.discrete_step_sampler import RemaskingFn
3436
from hackable_diffusion.lib.sampling.discrete_step_sampler import RescaledRemaskingFn
35-
from hackable_diffusion.lib.sampling.discrete_step_sampler import RoutingStrategy
3637
from hackable_diffusion.lib.sampling.discrete_step_sampler import Routing
38+
from hackable_diffusion.lib.sampling.discrete_step_sampler import RoutingStrategy
3739
from hackable_diffusion.lib.sampling.discrete_step_sampler import UnMaskingStep
3840
from hackable_diffusion.lib.sampling.gaussian_step_sampler import AdjustedDDIMStep
3941
from hackable_diffusion.lib.sampling.gaussian_step_sampler import DDIMStep
Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
# Copyright 2026 Hackable Diffusion Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Autoregressive diffusion sampler.
16+
17+
This module implements an autoregressive generation loop where each generation
18+
step produces a fixed-length "canvas" of data via diffusion sampling. The canvas
19+
is then post-processed and integrated into the running sampler state.
20+
21+
The overall sampling flow is:
22+
23+
conditioning
24+
25+
26+
┌─────────────────────┐
27+
│ ARStateHandler │
28+
│ .init_ar_state │──── Build initial state from conditioning
29+
└─────────────────────┘
30+
31+
(SamplerState, batch_size)
32+
33+
34+
┌──────────────────────────────┐
35+
│ AR Loop │
36+
│ (up to max_num_canvases) │
37+
│ │
38+
│ ┌────────────────────────┐ │
39+
│ │ EarlyStoppingFn │──┼──▶ break if done
40+
│ └────────────────────────┘ │
41+
│ │ │
42+
│ ▼ │
43+
│ ┌────────────────────────┐ │
44+
│ │ DiffusionProcess │ │
45+
│ │ .sample_from_invariant│──┼──▶ initialize noisy canvas
46+
│ └────────────────────────┘ │
47+
│ │ │
48+
│ ▼ │
49+
│ ┌────────────────────────┐ │
50+
│ │ ARStateHandler │ │
51+
│ │ .create_conditioning │──┼──▶ extract diffusion
52+
│ │ _from_state │ │ conditioning from state
53+
│ └────────────────────────┘ │
54+
│ │ │
55+
│ ▼ │
56+
│ ┌────────────────────────┐ │
57+
│ │ DiffusionSampler │ │
58+
│ │ (canvas_sampler) │──┼──▶ denoise canvas via
59+
│ └────────────────────────┘ │ diffusion sampling
60+
│ │ │
61+
│ ▼ │
62+
│ ┌────────────────────────┐ │
63+
│ │ ARStateHandler │ │
64+
│ │ .update_ar_state │──┼──▶ update state
65+
│ └────────────────────────┘ │
66+
│ │ │
67+
│ └───────────┐ │
68+
│ next │ │
69+
│ canvas │ │
70+
└──────────────────────────────┘
71+
72+
73+
┌─────────────────────┐
74+
│ ARStateHandler │
75+
│ .finalize_ar_state │──── Extract generated output
76+
└─────────────────────┘
77+
78+
79+
output data
80+
81+
The architecture is model-agnostic: all model-specific logic is injected via
82+
the ``ARStateHandler`` protocol, which encapsulates:
83+
84+
- ``init_ar_state``: Initializes the state.
85+
- ``update_ar_state``: Handles canvas post-processing and state bookkeeping
86+
after each canvas is sampled.
87+
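- ``finalize_ar_state``: Extracts the final generated output from the state.
- ``create_conditioning_from_state``: Extracts the conditioning passed to the
canvas sampler from the state.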
88+
89+
An ``EarlyStoppingFn`` Protocol controls when to terminate the AR loop early.
90+
91+
The AR loop uses ``jax.lax.while_loop`` and terminates when ``max_num_canvases``
92+
is reached or an early stopping condition is met.
93+
"""
94+
95+
from __future__ import annotations
96+
97+
import dataclasses
98+
from typing import Any, Protocol
99+
100+
from hackable_diffusion.lib import corruption
101+
from hackable_diffusion.lib import hd_typing
102+
from hackable_diffusion.lib import inference
103+
from hackable_diffusion.lib.sampling.sampling import DiffusionSampler
104+
import jax
105+
import jax.numpy as jnp
106+
import kauldron.ktyping as kt
107+
from kauldron.ktyping import Bool, PRNGKey
108+
109+
################################################################################
110+
# MARK: Type aliases
111+
################################################################################
112+
113+
SamplerState = dict[str, Any]
114+
DataArray = hd_typing.DataArray
115+
Conditioning = hd_typing.Conditioning
116+
InferenceFn = inference.InferenceFn
117+
118+
################################################################################
119+
# MARK: ARStateHandler
120+
################################################################################
121+
122+
123+
class ARStateHandler(Protocol):
  """Lifecycle hooks for the autoregressive (AR) sampler state.

  Implement this protocol to inject model-specific logic into
  ``AutoregressiveDiffusionSampler``. The sampler calls the hooks in this
  order: ``init_ar_state`` once, then ``create_conditioning_from_state`` and
  ``update_ar_state`` once per canvas, and finally ``finalize_ar_state``.

  Methods:
    init_ar_state: Creates the initial state from conditioning.
    create_conditioning_from_state: Extracts the subset of sampler state
      needed as conditioning for the diffusion sampler.
    update_ar_state: Post-processes a sampled canvas and updates the state
      (e.g. updates the KV cache for LLMs).
    finalize_ar_state: Extracts the final generated output from the state.
  """

  def init_ar_state(
      self,
      *,
      batch_size: int,
      conditioning: Conditioning,
      canvas_length: int,
      max_num_canvases: int,
  ) -> SamplerState:
    """Builds the sampler state used as the initial AR loop carry.

    Args:
      batch_size: Batch size of the generation.
      conditioning: Conditioning supplied by the caller of the sampler.
      canvas_length: Number of elements per canvas.
      max_num_canvases: Maximum number of canvases that will be generated.

    Returns:
      The initial sampler state.
    """

  def create_conditioning_from_state(
      self,
      sampler_state: SamplerState,
  ) -> Conditioning:
    """Derives the conditioning for the next canvas from the current state.

    Args:
      sampler_state: Current sampler state.

    Returns:
      Conditioning handed to the canvas diffusion sampler.
    """

  def update_ar_state(
      self,
      canvas: DataArray,
      sampler_state: SamplerState,
  ) -> SamplerState:
    """Integrates a freshly sampled canvas into the sampler state.

    The sampler calls this inside the body of ``jax.lax.while_loop``, so the
    returned state has to keep the pytree structure, shapes and dtypes of the
    state it received.

    Args:
      canvas: Canvas produced by the diffusion sampler for this iteration.
      sampler_state: Sampler state before the canvas is integrated.

    Returns:
      The updated sampler state.
    """

  def finalize_ar_state(
      self,
      sampler_state: SamplerState,
  ) -> DataArray:
    """Reads the generated output out of the final sampler state.

    Args:
      sampler_state: Sampler state after the AR loop has terminated.

    Returns:
      The generated output.
    """
167+
168+
169+
################################################################################
170+
# MARK: EarlyStoppingFn
171+
################################################################################
172+
173+
174+
class EarlyStoppingFn(Protocol):
  """Callable deciding whether the AR loop should terminate early.

  Implementations receive the full sampler state and must return a JAX
  boolean *scalar*, where ``True`` means "stop". The canonical implementation
  evaluates ``jnp.all(sampler_state['done'])``, with ``done`` being a boolean
  array holding one entry per batch element.
  """

  def __call__(self, sampler_state: SamplerState) -> Bool['']:
    """Returns true when the AR loop should terminate.

    Args:
      sampler_state: Current sampler state.

    Returns:
      A boolean scalar; ``True`` requests termination of the AR loop.
    """
185+
186+
187+
class DoneEarlyStoppingFn(EarlyStoppingFn):
  """Early-stopping rule that fires once all batch elements are done."""

  def __call__(self, sampler_state: SamplerState) -> Bool['']:
    """Reduces ``sampler_state['done']`` to a single stop flag.

    Args:
      sampler_state: Current sampler state; must contain a ``'done'`` entry.

    Returns:
      ``jnp.all`` over ``sampler_state['done']``.

    Raises:
      ValueError: If ``sampler_state`` has no ``'done'`` entry.
    """
    if 'done' in sampler_state:
      return jnp.all(sampler_state['done'])
    raise ValueError(
        'DoneEarlyStoppingFn requires sampler_state["done"] to be set.'
    )
196+
197+
198+
################################################################################
199+
# MARK: Sampler
200+
################################################################################
201+
202+
203+
@dataclasses.dataclass(kw_only=True, frozen=True)
class AutoregressiveDiffusionSampler:
  """Generates data by autoregressively sampling fixed-length canvases.

  The generation loop is a ``jax.lax.while_loop`` (for JIT compatibility).
  Before every canvas the loop predicate checks that ``early_stopping_fn``
  does not request termination and that fewer than ``max_num_canvases``
  canvases have been produced. Each iteration then:
    1. Draws an initial noisy canvas of ``canvas_length`` elements from
       ``diffusion_process.sample_from_invariant``.
    2. Denoises it with ``canvas_sampler``, conditioned on
       ``state_handler.create_conditioning_from_state``.
    3. Passes the sampled canvas to ``state_handler.update_ar_state`` for
       post-processing and state bookkeeping.

  After the loop, ``state_handler.finalize_ar_state`` extracts the final
  generated output.

  Attributes:
    canvas_sampler: Diffusion sampler that denoises a single canvas.
    diffusion_process: Noise process used to initialize canvases.
    canvas_length: Number of elements per canvas.
    max_num_canvases: Maximum number of canvases to generate.
    state_handler: Manages the AR state lifecycle (init, update, finalize).
    data_dtype: Data type of the generated output.
    data_shape: Additional dimensions of the generated output (e.g., spatial
      dimensions for images).
    early_stopping_fn: Determines whether to terminate the AR loop early.
  """

  canvas_sampler: DiffusionSampler
  diffusion_process: corruption.CategoricalProcess
  canvas_length: int
  max_num_canvases: int
  state_handler: ARStateHandler
  data_dtype: jnp.dtype
  data_shape: tuple[int, ...]
  early_stopping_fn: EarlyStoppingFn = DoneEarlyStoppingFn()

  @kt.typechecked
  def __call__(
      self,
      diffusion_inference_fn: inference.InferenceFn,
      batch_size: int,
      rng: PRNGKey,
      conditioning: Conditioning,
  ) -> tuple[DataArray, SamplerState]:
    """Generates data autoregressively via discrete diffusion.

    Uses ``jax.lax.while_loop`` for JIT compatibility with true early
    stopping.

    Args:
      diffusion_inference_fn: Model inference function called during diffusion
        sampling.
      batch_size: Batch size of the generation.
      rng: JAX PRNG key, split per canvas for reproducibility.
      conditioning: Conditioning for the generation (e.g., text prompts, images,
        or any modality-specific inputs).

    Returns:
      A tuple of (generated output, final sampler state).
    """

    sampler_state = self.state_handler.init_ar_state(
        batch_size=batch_size,
        conditioning=conditioning,
        canvas_length=self.canvas_length,
        max_num_canvases=self.max_num_canvases,
    )

    # The canvas shape does not change across iterations; compute it once.
    canvas_shape = (batch_size, self.canvas_length) + tuple(self.data_shape)

    # Carry: (sampler_state, rng, step_counter)
    init_carry = (sampler_state, rng, jnp.int32(0))

    def _cond_fn(carry):
      state, _, step = carry
      # `jnp.logical_not` / `jnp.logical_and` (rather than `~` / `&`) keep the
      # predicate a boolean scalar even if `early_stopping_fn` returns a plain
      # Python bool, for which `~True == -2` would yield an integer predicate.
      should_continue = jnp.logical_not(self.early_stopping_fn(state))
      less_than_max_canvases = step < self.max_num_canvases
      return jnp.logical_and(should_continue, less_than_max_canvases)

    def _body_fn(carry):
      state, loop_rng, step = carry

      # Propagate random number generator: one key is carried to the next
      # iteration, the other two are consumed by this canvas.
      loop_rng, canvas_init_rng, canvas_sampler_rng = jax.random.split(
          loop_rng, 3
      )

      # Create new canvas.
      initial_canvas = self.diffusion_process.sample_from_invariant(
          key=canvas_init_rng,
          data_spec=jnp.zeros(canvas_shape, dtype=self.data_dtype),
      )

      # Sample canvas via diffusion.
      # TODO: Implement returning the whole sampling trajectory.
      last_step, _ = self.canvas_sampler(
          inference_fn=diffusion_inference_fn,
          rng=canvas_sampler_rng,
          initial_noise=initial_canvas,
          conditioning=self.state_handler.create_conditioning_from_state(
              sampler_state=state
          ),
      )
      sampled_canvas = last_step.xt

      # Post-process canvas and update sampler state.
      state = self.state_handler.update_ar_state(
          canvas=sampled_canvas, sampler_state=state
      )

      return (state, loop_rng, step + 1)

    sampler_state, _, _ = jax.lax.while_loop(_cond_fn, _body_fn, init_carry)

    # Read-out the final output.
    output = self.state_handler.finalize_ar_state(sampler_state=sampler_state)
    return output, sampler_state

0 commit comments

Comments
 (0)