Merged
21 commits
6f52606
feat: add per-model FP8 layerwise casting for VRAM reduction
Pfannkuchensack Mar 6, 2026
bf3bd2e
feat: add FP8 storage option to Model Manager UI
Pfannkuchensack Mar 6, 2026
afe246e
ruff format
Pfannkuchensack Mar 6, 2026
2262d8d
Merge branch 'main' into feature/fp8-layerwise-casting
JPPhoto Mar 9, 2026
5327df8
Merge branch 'main' into feature/fp8-layerwise-casting
JPPhoto Mar 11, 2026
6c13fca
Merge branch 'main' into feature/fp8-layerwise-casting
JPPhoto Mar 20, 2026
0d7b39f
fix: enable FP8 layerwise casting for checkpoint Flux models
Pfannkuchensack Mar 21, 2026
a0df643
fix: exclude Z-Image from FP8 due to diffusers layerwise casting bug
Pfannkuchensack Mar 21, 2026
06ad3c7
fix: detect model dtype for FP8 compute instead of using global dtype
Pfannkuchensack Mar 21, 2026
025759f
Remove call for _should_use_fp8 in z-image
Pfannkuchensack Mar 21, 2026
8ddb200
Merge branch 'main' into feature/fp8-layerwise-casting
Pfannkuchensack Mar 26, 2026
9798012
Merge branch 'main' into feature/fp8-layerwise-casting
Pfannkuchensack Mar 31, 2026
2b0af7c
Merge remote-tracking branch 'upstream/main' into feature/fp8-layerwi…
Pfannkuchensack May 9, 2026
55d41a6
Merge branch 'main' + exclude VAEs from FP8 layerwise casting
Pfannkuchensack May 9, 2026
f0a53a5
fix(fp8): invalidate cache on settings change, exception-safe nn.Modu…
Pfannkuchensack May 11, 2026
f841598
fix(fp8): honor class swap for LoRA patches, evict stale locked entri…
Pfannkuchensack May 11, 2026
ae2068a
Merge branch 'main' into feature/fp8-layerwise-casting
JPPhoto May 11, 2026
f94b705
fix(fp8): switch nn.Module FP8 wrapper to hooks so CustomLinear dispa…
Pfannkuchensack May 11, 2026
458a425
Merge branch 'feature/fp8-layerwise-casting' of https://github.com/Pf…
Pfannkuchensack May 11, 2026
2598698
Add docs for fp8
Pfannkuchensack May 11, 2026
9a4a2f8
Merge branch 'main' into feature/fp8-layerwise-casting
lstein May 12, 2026
89 changes: 89 additions & 0 deletions docs/src/content/docs/configuration/fp8-storage.mdx
@@ -0,0 +1,89 @@
---
title: FP8 Storage
sidebar:
order: 3
---

import { Steps } from '@astrojs/starlight/components';

FP8 Storage roughly halves a model's VRAM footprint (relative to FP16/BF16 weights) by keeping weights on the GPU in 8-bit floating-point format (`float8_e4m3fn`). During inference, each layer's weights are cast on the fly up to the compute precision (FP16/BF16) for the forward pass and cast back to FP8 afterwards, so the arithmetic itself never runs in FP8 and quality is largely preserved.
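
As a rough back-of-the-envelope check on the halving claim, the sketch below compares weight storage at 2 bytes per parameter (FP16/BF16) with 1 byte per parameter (FP8). The 12-billion parameter count and the helper name are illustrative assumptions, and the estimate ignores activations and the layers that stay in higher precision.

```python
def weight_footprint_gib(num_params: int, bytes_per_param: int) -> float:
    """Return the weight storage size in GiB for a given element width."""
    return num_params * bytes_per_param / (1024**3)


# Hypothetical parameter count, chosen only for illustration.
num_params = 12_000_000_000

fp16_gib = weight_footprint_gib(num_params, 2)  # FP16/BF16: 2 bytes per weight
fp8_gib = weight_footprint_gib(num_params, 1)  # float8_e4m3fn: 1 byte per weight

print(f"FP16/BF16: {fp16_gib:.1f} GiB, FP8: {fp8_gib:.1f} GiB")
```

Real savings will be somewhat below 50% because norm layers, embeddings, and other skipped modules keep their original dtype.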

It pairs well with [Low-VRAM mode](/configuration/low-vram-mode/): low-VRAM mode streams layers between RAM and VRAM, while FP8 Storage shrinks the layers themselves.

## Requirements

- **NVIDIA GPU on Windows or Linux.** FP8 Storage uses CUDA tensor types and is silently disabled on CPU and MPS.
- **CUDA 12.x and recent PyTorch.** The `float8_e4m3fn` dtype was added in PyTorch 2.1 — InvokeAI's bundled versions satisfy this.

There is no hardware requirement for FP8 *compute* — InvokeAI casts back to FP16/BF16 for math. This means FP8 Storage works on GPUs that do not natively support FP8 matmul (e.g. RTX 30-series), at a small per-step throughput cost.

## Enabling FP8 Storage

FP8 Storage is a **per-model setting**, configured from the Model Manager:

<Steps>
1. Open the **Model Manager**.
2. Select a model (Main, ControlNet, or T2I-Adapter).
3. Under **Default Settings**, toggle **FP8 Storage (Save VRAM)**.
4. Click **Save**.
</Steps>

The setting takes effect the next time the model is loaded. If the model is already in the cache, InvokeAI evicts the cached copy automatically so the new setting applies. If a generation is currently using the model, the eviction is deferred until that generation finishes.
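
The immediate versus deferred eviction described above can be pictured with a toy cache. This is a minimal sketch, not InvokeAI's `ModelCache`: the `Record` and `ToyCache` classes are hypothetical stand-ins, and returning `0` for a locked entry is an assumption about how a drop count might be reported.

```python
from dataclasses import dataclass


@dataclass
class Record:
    """Toy stand-in for a cache record (not the real CacheRecord)."""

    locks: int = 0
    is_stale: bool = False


class ToyCache:
    def __init__(self) -> None:
        self.records: dict[str, Record] = {}

    def drop_model(self, key: str) -> int:
        """Evict an idle entry now; mark a locked entry stale. Returns entries evicted."""
        record = self.records.get(key)
        if record is None:
            return 0
        if record.locks > 0:
            record.is_stale = True  # eviction is deferred to unlock()
            return 0
        del self.records[key]
        return 1

    def unlock(self, key: str) -> None:
        """Release one lock and evict the entry if it was marked stale."""
        record = self.records[key]
        record.locks -= 1
        if record.locks == 0 and record.is_stale:
            del self.records[key]


cache = ToyCache()
cache.records["model-a"] = Record(locks=1)  # a generation is using the model

dropped_now = cache.drop_model("model-a")  # settings changed mid-generation
still_cached = "model-a" in cache.records
cache.unlock("model-a")  # generation finishes
cached_after_unlock = "model-a" in cache.records
```

In this sketch the entry survives the settings change while locked and disappears as soon as the last lock is released, so the next load rebuilds the model with the new setting.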

:::tip[When to enable]
Enable FP8 Storage on large models that don't fit comfortably in VRAM — FLUX dev/Klein, large SDXL checkpoints, ControlNet-XL adapters. For smaller SD1 / SD2 models, the savings are negligible and not worth the small precision trade-off.
:::

## What FP8 Storage applies to

FP8 Storage is applied **only** to model types where the precision trade-off is acceptable:

| Model type | FP8 applied? |
| ----------------------------- | -------------------------------------- |
| Main models (SD1, SD2, SDXL) | Yes |
| FLUX.1 / FLUX.2 Klein | Yes |
| ControlNet, T2I-Adapter | Yes |
| VAE | No — visible decode-quality regression |
| Text encoders, tokenizers | No — small models, no benefit |
| Z-Image (any variant) | No — dtype mismatch with skipped layers|
| LoRA, ControlLoRA | No — patched into base, not run alone |

Within a supported model, **norm layers, position/patch embeddings, and `proj_in`/`proj_out` are skipped** so precision-sensitive tiny learned scalars (e.g. FLUX `RMSNorm.scale`) aren't crushed to FP8. This mirrors the diffusers default skip list.
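
To see which module paths the skip list catches, the sketch below applies the same patterns that the plain `nn.Module` fallback in this PR uses (`_FP8_DEFAULT_SKIP_PATTERNS`) with `re.search`. The module paths are hypothetical examples, and the behavior shown is for the fallback path only; models handled by diffusers' own `enable_layerwise_casting` follow diffusers' matching rules.

```python
import re

# Patterns copied from `_FP8_DEFAULT_SKIP_PATTERNS` in this PR.
SKIP_PATTERNS = ("pos_embed", "patch_embed", "norm", r"^proj_in$", r"^proj_out$")


def is_skipped(module_path: str) -> bool:
    """Return True if the dotted module path matches any skip pattern."""
    return any(re.search(pattern, module_path) for pattern in SKIP_PATTERNS)


# Hypothetical module paths, for illustration only.
examples = [
    "transformer_blocks.0.norm1.linear",
    "pos_embed.proj",
    "proj_out",
    "transformer_blocks.0.attn.to_q",
    "transformer_blocks.0.proj_out",
]
for path in examples:
    print(f"{path}: {'skipped' if is_skipped(path) else 'cast to FP8'}")
```

Note that `^proj_in$` and `^proj_out$` are anchored, so in this fallback they match only a module whose full path is exactly `proj_in` or `proj_out`; a nested `transformer_blocks.0.proj_out` is still cast.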

## Quality trade-offs

FP8 Storage is **near-lossless** for most workloads because:

- Norms and embeddings (the precision-sensitive layers) are skipped.
- The actual matmul still happens in FP16/BF16 — FP8 is only the on-GPU storage format.

That said, two cases deserve attention:

- **VAEs**: FP8 visibly degrades decode quality, so VAEs are never cast (the toggle has no effect on VAE submodels).
- **Heavy LoRA stacks**: patching is unaffected, but very precision-sensitive LoRAs may show slight drift. Compare side-by-side outputs if your workflow depends on subtle LoRA behavior.

If you see unexpected quality regressions, disable FP8 Storage on the affected model and re-run.

## Combining with Low-VRAM mode and quantized models

- **FP8 + partial loading**: fully supported. FP8 Storage shrinks the layers; partial loading streams them between RAM and VRAM as needed. Use both on tight VRAM budgets.
- **FP8 + GGUF / NF4 / int8 quantized checkpoints**: these formats already have their own storage precision. FP8 Storage is not applied on top — the toggle is silently a no-op for quantized formats, since the loader returns a different module type.

## Troubleshooting

### "I toggled FP8 Storage but VRAM usage didn't change"

The cache eviction is immediate for idle models, but **deferred until the next unlock** if the model is mid-generation. Wait for the current generation to finish, then start a new one — the next load will use the new setting.

If VRAM still hasn't dropped:

- Check the InvokeAI log for `FP8 layerwise casting enabled for <model name>`. If the line isn't there, the model is on the exclusion list (VAE, text encoder, Z-Image, LoRA — see table above).
- Confirm you are on CUDA. FP8 Storage is silently disabled on CPU and MPS.
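
If the log is long, a small script can pull out the relevant lines. The message format below follows the `logger.info` call added in this PR; the log prefix, the sample lines, and the helper name are assumptions made for illustration.

```python
import re

# Matches the message emitted by `_apply_fp8_layerwise_casting` in this PR.
LOG_PATTERN = re.compile(
    r"FP8 layerwise casting enabled for (?P<name>.+?) "
    r"\(storage=(?P<storage>\S+), compute=(?P<compute>\S+), param_size=(?P<size_mb>\d+)MB\)"
)


def find_fp8_models(log_lines: list[str]) -> dict[str, int]:
    """Map each model name with FP8 enabled to its reported parameter size in MB."""
    found: dict[str, int] = {}
    for line in log_lines:
        match = LOG_PATTERN.search(line)
        if match:
            found[match.group("name")] = int(match.group("size_mb"))
    return found


# Hypothetical log excerpt; the prefix and values are made up.
sample_log = [
    "INFO --> Updated model: abc123",
    "INFO --> FP8 layerwise casting enabled for My Flux Model "
    "(storage=float8_e4m3fn, compute=torch.bfloat16, param_size=11350MB)",
]
fp8_models = find_fp8_models(sample_log)
```

A model that is missing from the result was either excluded or never reloaded after the toggle.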

### Quality regression on a specific model

Disable FP8 Storage for that model in Model Manager and reload. If quality is restored, the model has FP8-sensitive layers that fall outside the default skip list. Please open an issue with the model name and a side-by-side comparison.

### "RuntimeError: ... float8_e4m3fn ..."

You're on a PyTorch version that predates FP8 support. Reinstall InvokeAI using the official launcher — the bundled torch version supports FP8.
30 changes: 30 additions & 0 deletions invokeai/app/api/routers/model_manager.py
@@ -437,7 +437,17 @@ async def update_model_record(
    logger = ApiDependencies.invoker.services.logger
    record_store = ApiDependencies.invoker.services.model_manager.store
    try:
        previous_config = record_store.get_model(key)
        config = record_store.update_model(key, changes=changes, allow_class_change=True)
        # Settings that change how the model loads (e.g. fp8_storage, cpu_only) are baked into the cached
        # nn.Module at load time, so toggling them on a cached model is otherwise silently a no-op until
        # the entry is evicted. Drop any unlocked cached entries for this model so the next load rebuilds.
        if _load_settings_changed(previous_config, config):
            dropped = ApiDependencies.invoker.services.model_manager.load.ram_cache.drop_model(key)
            if dropped:
                logger.info(
                    f"Dropped {dropped} cached entr{'y' if dropped == 1 else 'ies'} for model {key} after settings change."
                )
        config = prepare_model_config_for_response(config, ApiDependencies)
        logger.info(f"Updated model: {key}")
    except UnknownModelException as e:
@@ -448,6 +458,26 @@ async def update_model_record(
    return config


_LOAD_AFFECTING_SETTINGS: tuple[str, ...] = ("fp8_storage", "cpu_only")


def _load_settings_changed(previous: AnyModelConfig, updated: AnyModelConfig) -> bool:
    """Return True if any setting that influences how the model is loaded changed.

    Such settings are read by the loader during `_load_model` and baked into the resulting
    nn.Module, so a cached entry built under the old value must be evicted for the change
    to take effect.
    """
    if getattr(previous, "cpu_only", None) != getattr(updated, "cpu_only", None):
        return True
    previous_settings = getattr(previous, "default_settings", None)
    updated_settings = getattr(updated, "default_settings", None)
    for field in _LOAD_AFFECTING_SETTINGS:
        if getattr(previous_settings, field, None) != getattr(updated_settings, field, None):
            return True
    return False
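
Reviewer-style illustration of the comparison above: a minimal sketch that re-implements the same `getattr`-based check against `SimpleNamespace` stand-ins instead of real `AnyModelConfig` objects. The `make_config` helper and the lowercase function name are hypothetical.

```python
from types import SimpleNamespace

LOAD_AFFECTING_SETTINGS = ("fp8_storage", "cpu_only")


def load_settings_changed(previous, updated) -> bool:
    """Same comparison as `_load_settings_changed`, written against duck-typed configs."""
    if getattr(previous, "cpu_only", None) != getattr(updated, "cpu_only", None):
        return True
    previous_settings = getattr(previous, "default_settings", None)
    updated_settings = getattr(updated, "default_settings", None)
    return any(
        getattr(previous_settings, field, None) != getattr(updated_settings, field, None)
        for field in LOAD_AFFECTING_SETTINGS
    )


def make_config(fp8_storage=None, width=None):
    """Build a hypothetical config object with only the fields used here."""
    return SimpleNamespace(default_settings=SimpleNamespace(fp8_storage=fp8_storage, width=width))


toggled = load_settings_changed(make_config(fp8_storage=None), make_config(fp8_storage=True))
unrelated = load_settings_changed(make_config(width=512), make_config(width=1024))
no_settings = load_settings_changed(SimpleNamespace(default_settings=None), make_config(fp8_storage=True))
none_to_false = load_settings_changed(make_config(fp8_storage=None), make_config(fp8_storage=False))
```

Because missing attributes default to `None`, a config without `default_settings` is handled without special cases. One consequence worth noting is that `None` to `False` also counts as a change and triggers an eviction, even though the loader treats both as disabled.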


@model_manager_router.get(
"/i/{key}/image",
operation_id="get_model_image",
4 changes: 4 additions & 0 deletions invokeai/backend/model_manager/configs/controlnet.py
@@ -54,6 +54,10 @@
class ControlAdapterDefaultSettings(BaseModel):
    # This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
    preprocessor: str | None
    fp8_storage: bool | None = Field(
        default=None,
        description="Store weights in FP8 to reduce VRAM usage (~50% savings). Weights are cast to compute dtype during inference.",
    )
    model_config = ConfigDict(extra="forbid")

    @classmethod
4 changes: 4 additions & 0 deletions invokeai/backend/model_manager/configs/main.py
@@ -52,6 +52,10 @@ class MainModelDefaultSettings(BaseModel):
    height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model")
    guidance: float | None = Field(default=None, ge=1, description="Default Guidance for this model")
    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")
    fp8_storage: bool | None = Field(
        default=None,
        description="Store weights in FP8 to reduce VRAM usage (~50% savings). Weights are cast to compute dtype during inference.",
    )

    model_config = ConfigDict(extra="forbid")

175 changes: 175 additions & 0 deletions invokeai/backend/model_manager/load/load_default.py
@@ -1,6 +1,7 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Default implementation of model loading in InvokeAI."""

import re
from logging import Logger
from pathlib import Path
from typing import Optional
@@ -21,6 +22,35 @@
)
from invokeai.backend.util.devices import TorchDevice

# Layer classes that benefit from FP8 storage. Mirrors diffusers'
# `_GO_LC_SUPPORTED_PYTORCH_LAYERS` so the plain-nn.Module fallback path makes the same
# precision/quality trade-offs as the ModelMixin path. Notably excludes norm and embedding
# wrapper modules — those are handled by their direct param types (Embedding is included
# but pos_embed/patch_embed are filtered by `_FP8_DEFAULT_SKIP_PATTERNS`).
_FP8_SUPPORTED_PYTORCH_LAYERS: tuple[type[torch.nn.Module], ...] = (
    torch.nn.Linear,
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.Conv3d,
    torch.nn.ConvTranspose1d,
    torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d,
    torch.nn.Embedding,
)

# Module-path regexes (matched against `named_modules()` dotted paths) for precision-sensitive
# layers that should never be cast to FP8. Mirrors diffusers' `DEFAULT_SKIP_MODULES_PATTERN`
# — without these, FLUX RMSNorm.scale and similar tiny learned scalars get crushed to FP8 and
# inference quality degrades. Includes anything named `norm`, position/patch embeddings, and
# the in/out projection of transformer blocks.
_FP8_DEFAULT_SKIP_PATTERNS: tuple[str, ...] = (
    "pos_embed",
    "patch_embed",
    "norm",
    r"^proj_in$",
    r"^proj_out$",
)


# TO DO: The loader is not thread safe!
class ModelLoader(ModelLoaderBase):
@@ -124,6 +154,151 @@ def get_size_fs(
            variant=config.repo_variant if isinstance(config, Diffusers_Config_Base) else None,
        )

    def _should_use_fp8(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> bool:
        """Check if FP8 layerwise casting should be applied to a model."""
        # FP8 storage only works on CUDA
        if self._torch_device.type != "cuda":
            return False

        # Z-Image has dtype mismatch issues with diffusers' layerwise casting
        # (skipped modules produce bf16, hooked modules expect fp16).
        from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType

        if hasattr(config, "base") and config.base == BaseModelType.ZImage:
            return False

        # VAEs are excluded: fp8 storage causes noticeable quality degradation in decode.
        if hasattr(config, "type") and config.type == ModelType.VAE:
            return False

        # LoRAs (including ControlLoRA) are excluded: they are not run as a standalone forward pass,
        # they are patched into a base model, so the layerwise-casting hooks would never fire. The
        # toggle is also hidden in the UI for ControlLoRA; this guard handles legacy persisted values.
        if hasattr(config, "type") and config.type in (ModelType.LoRA, ModelType.ControlLoRa):
            return False

        # Don't apply FP8 to text encoders, tokenizers, schedulers, VAEs, etc.
        _excluded_submodel_types = {
            SubModelType.TextEncoder,
            SubModelType.TextEncoder2,
            SubModelType.TextEncoder3,
            SubModelType.Tokenizer,
            SubModelType.Tokenizer2,
            SubModelType.Tokenizer3,
            SubModelType.Scheduler,
            SubModelType.SafetyChecker,
            SubModelType.VAE,
            SubModelType.VAEDecoder,
            SubModelType.VAEEncoder,
        }
        if submodel_type in _excluded_submodel_types:
            return False

        # Check default_settings.fp8_storage (Main models, ControlNet)
        if hasattr(config, "default_settings") and config.default_settings is not None:
            if hasattr(config.default_settings, "fp8_storage") and config.default_settings.fp8_storage is True:
                return True

        return False

    def _apply_fp8_layerwise_casting(
        self, model: AnyModel, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
    ) -> AnyModel:
        """Apply FP8 layerwise casting to a model if enabled in its config."""
        if not self._should_use_fp8(config, submodel_type):
            return model

        storage_dtype = torch.float8_e4m3fn
        compute_dtype = self._torch_dtype

        # Detect the model's current dtype to use as compute dtype, since models
        # (e.g. Flux) may require a specific dtype (bf16) that differs from the global torch dtype (fp16).
        if isinstance(model, torch.nn.Module):
            first_param = next(model.parameters(), None)
            if first_param is not None:
                compute_dtype = first_param.dtype

        from diffusers.models.modeling_utils import ModelMixin

        if isinstance(model, ModelMixin):
            model.enable_layerwise_casting(
                storage_dtype=storage_dtype,
                compute_dtype=compute_dtype,
            )
        elif isinstance(model, torch.nn.Module):
            self._apply_fp8_to_nn_module(model, storage_dtype=storage_dtype, compute_dtype=compute_dtype)
        else:
            return model

        param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
        self._logger.info(
            f"FP8 layerwise casting enabled for {config.name} "
            f"(storage=float8_e4m3fn, compute={compute_dtype}, "
            f"param_size={param_bytes / (1024**2):.0f}MB)"
        )
        return model

    @staticmethod
    def _apply_fp8_to_nn_module(model: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
        """Apply FP8 layerwise casting to a plain nn.Module.

        Mirrors diffusers' `apply_layerwise_casting` semantics: only the layer classes in
        `_FP8_SUPPORTED_PYTORCH_LAYERS` are cast, and modules whose dotted path matches any of
        `_FP8_DEFAULT_SKIP_PATTERNS` (norm, pos_embed, patch_embed, proj_in/out) are skipped.
        Without the skip list, precision-sensitive tiny learned scalars (e.g. FLUX RMSNorm.scale)
        get crushed to FP8 and quality degrades noticeably.
        """
        for module_name, module in model.named_modules():
            if not isinstance(module, _FP8_SUPPORTED_PYTORCH_LAYERS):
                continue
            if any(re.search(pattern, module_name) for pattern in _FP8_DEFAULT_SKIP_PATTERNS):
                continue
            params = list(module.parameters(recurse=False))
            if not params:
                continue

            for param in params:
                param.data = param.data.to(storage_dtype)

            ModelLoader._wrap_forward_with_fp8_cast(module, storage_dtype, compute_dtype)

    @staticmethod
    def _wrap_forward_with_fp8_cast(
        module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype
    ) -> None:
        """Register pre/post forward hooks that cast params to compute dtype on entry and back
        to storage dtype on exit.

        We use hooks (rather than overriding `module.forward`) for two reasons:

        1. **Correct dispatch after `apply_custom_layers_to_model`.** `ModelCache.put()` calls
           `apply_custom_layers_to_model`, which creates a NEW `CustomLinear` instance and
           shares the original `Linear.__dict__` (see `wrap_custom_layer`). Anything stored in
           that dict (including an instance-level `forward` attribute) gets carried over to
           the new object. An overridden `forward` would close over the OLD instance, so calls
           to the new `CustomLinear` would silently route to `Linear.forward(old_instance, ...)`
           and bypass the LoRA-patch-aware branch in `CustomLinear.forward`. Hooks, by contrast,
           live in `_forward_hooks` / `_forward_pre_hooks` and are dispatched by
           `nn.Module.__call__` with the *actual* called instance, so they run on the new
           `CustomLinear` and the class's `forward` is still resolved normally.

        2. **Exception safety.** `register_forward_hook(..., always_call=True)` fires the
           post-hook even when `forward` raises. The plain pre-hook/post-hook pair without
           `always_call` would leave params in compute dtype on exception, defeating FP8
           storage savings and making cache size accounting stale.
        """

        def pre_hook(mod: torch.nn.Module, _args: object) -> None:
            for p in mod.parameters(recurse=False):
                p.data = p.data.to(compute_dtype)

        def post_hook(mod: torch.nn.Module, _args: object, _output: object) -> None:
            for p in mod.parameters(recurse=False):
                p.data = p.data.to(storage_dtype)

        module.register_forward_pre_hook(pre_hook)
        module.register_forward_hook(post_hook, always_call=True)
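
Reviewer-style illustration of the first reason in the docstring above: a pure-Python sketch of why an instance-level `forward` override misroutes after the instance `__dict__` is shared, while a hook does not. The toy `Linear`, `CustomLinear`, and `wrap` below are simplified stand-ins, not the real torch or InvokeAI classes, and the one-list hook mechanism only imitates how `nn.Module.__call__` dispatches hooks.

```python
class Linear:
    """Toy stand-in for torch.nn.Linear with a minimal pre-hook mechanism."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.pre_hooks = []

    def forward(self, x: str) -> str:
        return f"Linear.forward({x})"

    def __call__(self, x: str) -> str:
        for hook in self.pre_hooks:
            hook(self)  # hooks receive the instance that was actually called
        return self.forward(x)


class CustomLinear(Linear):
    def forward(self, x: str) -> str:
        return f"CustomLinear.forward({x})"


def wrap(old: Linear) -> CustomLinear:
    """Create a new CustomLinear that shares the old instance's __dict__."""
    new = CustomLinear.__new__(CustomLinear)
    new.__dict__ = old.__dict__
    return new


# Approach 1: override `forward` on the instance. The override lives in __dict__,
# so the wrapped object inherits it and still calls the OLD instance's method.
a = Linear("a")
original_forward = a.forward  # bound to the old Linear instance
a.forward = lambda x: original_forward(x)
overridden_result = wrap(a)("x")

# Approach 2: register a hook. The hook list is shared too, but `forward` is
# resolved on the class of the instance that was called.
b = Linear("b")
seen = []
b.pre_hooks.append(lambda mod: seen.append(type(mod).__name__))
hooked_result = wrap(b)("x")
```

In the first approach the wrapped object never reaches `CustomLinear.forward`; in the second, the hook sees the `CustomLinear` instance and the subclass `forward` runs as intended.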

    # This needs to be implemented in the subclass
    def _load_model(
        self,
@@ -17,6 +17,11 @@ class CacheRecord:
    # Model in memory.
    cached_model: CachedModelWithPartialLoad | CachedModelOnlyFullLoad
    _locks: int = 0
    # Set by ModelCache.drop_model() when the entry was locked at invalidation time.
    # ModelCache.unlock() evicts the entry as soon as the last lock releases so a setting
    # change (e.g. fp8_storage toggled during an in-flight generation) takes effect on the
    # next load instead of silently being ignored.
    is_stale: bool = False

    def lock(self) -> None:
        """Lock this record."""