
[Bug] WanTransformerInferFirstBlock lost the class method 'switch_status' #902


Description

@jerry2102

It seems commit 8689e4c removed the class method `switch_status` of class `WanTransformerInferFirstBlock`, which was added in commit dcaefe6.
The diffs of these two commits are as follows.

With this bug, I assume the Wan model cannot work properly when using FBCache.

Can you help confirm this? Thanks @helloyongyang
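For context, any caller that still invokes `switch_status` on the infer object (as a feature-caching code path would, to track conditional vs. unconditional passes) now fails with an `AttributeError`. A minimal sketch of the failure mode; the stripped-down class here is illustrative, not the real implementation, whose `__init__` takes a config and sets up offload state:

```python
# Sketch: after commit 8689e4c, neither infer_conditional nor
# switch_status is defined on the class anymore.
class WanTransformerInfer:
    def __init__(self):
        pass  # switch_status / infer_conditional no longer set up here

infer = WanTransformerInfer()
try:
    infer.switch_status()  # method removed in 8689e4c
except AttributeError as e:
    print(f"AttributeError: {e}")
```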

commit 8689e4c7b4838992a38f882030af29af5a60372c
Author: helloyongyang <yongyang1030@163.com>
Date:   Fri Aug 15 11:52:00 2025 +0000

    update tea cache

diff --git a/lightx2v/models/networks/wan/infer/transformer_infer.py b/lightx2v/models/networks/wan/infer/transformer_infer.py
index f53842a4..e0ab522a 100755
--- a/lightx2v/models/networks/wan/infer/transformer_infer.py
+++ b/lightx2v/models/networks/wan/infer/transformer_infer.py
@@ -78,11 +78,6 @@ class WanTransformerInfer(BaseTransformerInfer):
         else:
             self.infer_func = self._infer_without_offload
 
-        self.infer_conditional = True
-
-    def switch_status(self):
-        self.infer_conditional = not self.infer_conditional
-
     def _calculate_q_k_len(self, q, k_lens):
         q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
         cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)

commit dcaefe635a3b8da6a1f8f80f4fd06d556b16bf50
Author: Yang Yong(雍洋) <yongyang1030@163.com>
Date:   Sun Jun 29 21:41:27 2025 +0800

    update feature caching (#78)
    
    Co-authored-by: Linboyan-trc <1584340372@qq.com>

diff --git a/lightx2v/models/networks/wan/infer/transformer_infer.py b/lightx2v/models/networks/wan/infer/transformer_infer.py
index 421e3534..174be2e6 100755
--- a/lightx2v/models/networks/wan/infer/transformer_infer.py
+++ b/lightx2v/models/networks/wan/infer/transformer_infer.py
@@ -4,10 +4,11 @@ from lightx2v.common.offload.manager import (
     WeightAsyncStreamManager,
     LazyWeightAsyncStreamManager,
 )
+from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
 
 
-class WanTransformerInfer:
+class WanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config
         self.task = config["task"]
@@ -49,8 +50,10 @@ class WanTransformerInfer:
         else:
             self.infer_func = self._infer_without_offload
 
-    def set_scheduler(self, scheduler):
-        self.scheduler = scheduler
+        self.infer_conditional = True
+
+    def switch_status(self):
+        self.infer_conditional = not self.infer_conditional
 
     def _calculate_q_k_len(self, q, k_lens):
         q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)

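Restoring the removed lines should be straightforward. A minimal sketch of the two pieces that commit dcaefe6 added, shown on a stripped-down class (the real `__init__` takes a config and sets up offload state):

```python
class WanTransformerInfer:
    def __init__(self):
        # Flag from dcaefe6: True on the conditional (guided) pass,
        # False on the unconditional pass.
        self.infer_conditional = True

    def switch_status(self):
        # Toggle between conditional and unconditional inference passes.
        self.infer_conditional = not self.infer_conditional

# Usage: each call flips the pass type.
infer = WanTransformerInfer()
infer.switch_status()
assert infer.infer_conditional is False
infer.switch_status()
assert infer.infer_conditional is True
```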
Metadata

Labels: bug (Something isn't working)