Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,39 @@ def __init__(self, key, patches, convert_func=None, set_func=None):
self.set_func = set_func

def __call__(self, weight):
# Detect SageAttention and skip conversion for compatibility
sage_attention_active = False
try:
import comfy.cli_args
sage_attention_active = hasattr(comfy.cli_args.args, 'use_sage_attention') and \
comfy.cli_args.args.use_sage_attention
except:
pass

intermediate_dtype = weight.dtype
if self.convert_func is not None:

# Skip convert_func when SageAttention is active (compatibility mode)
if self.convert_func is not None and not sage_attention_active:
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
elif sage_attention_active and self.convert_func is not None:
logging.debug(f"Skipping convert_func for {self.key} (SageAttention compatibility)")

if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is None:
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
else:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
# Skip set_func when SageAttention is active (compatibility mode)
if not sage_attention_active:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
else:
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))

out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is not None:

# Skip set_func when SageAttention is active (compatibility mode)
if self.set_func is not None and not sage_attention_active:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
else:
return out
Expand Down