Skip to content

Commit 739d6ec

Browse files
authored
add a timestep scale for sana-sprint teacher model (#11150)
1 parent 1ddf3f3 commit 739d6ec

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -326,6 +326,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
326326
Whether to use elementwise affinity in the normalization layer.
327327
norm_eps (`float`, defaults to `1e-6`):
328328
The epsilon value for the normalization layer.
329+
qk_norm (`str`, *optional*, defaults to `None`):
330+
The normalization to use for the query and key.
331+
timestep_scale (`float`, defaults to `1.0`):
332+
The factor by which the timesteps are multiplied before they are passed to the transformer.
329333
"""
330334

331335
_supports_gradient_checkpointing = True
@@ -355,6 +359,7 @@ def __init__(
355359
guidance_embeds: bool = False,
356360
guidance_embeds_scale: float = 0.1,
357361
qk_norm: Optional[str] = None,
362+
timestep_scale: float = 1.0,
358363
) -> None:
359364
super().__init__()
360365

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -938,6 +938,7 @@ def __call__(
938938

939939
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
940940
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
941+
timestep = timestep * self.transformer.config.timestep_scale
941942

942943
# predict noise model_output
943944
noise_pred = self.transformer(

0 commit comments

Comments
 (0)