Skip to content

Commit e1ff260

Browse files
authored
fix: sdxl build oom (#151)
1 parent 7e10c9c commit e1ff260

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

src/streamdiffusion/acceleration/tensorrt/builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def build(
6060
opt_batch_size=opt_batch_size,
6161
onnx_opset=onnx_opset,
6262
)
63+
self.network = self.network.to("cpu")
6364
del self.network
6465
gc.collect()
6566
torch.cuda.empty_cache()
@@ -89,7 +90,6 @@ def build(
8990
build_all_tactics=build_all_tactics,
9091
build_enable_refit=build_enable_refit,
9192
)
92-
9393
for file in os.listdir(os.path.dirname(engine_path)):
9494
if file.endswith('.engine'):
9595
continue

src/streamdiffusion/modules/controlnet_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ def _load_pytorch_controlnet_model(self, model_id: str, conditioning_channels: O
610610
)
611611
else:
612612
controlnet = ControlNetModel.from_pretrained(model_id, **load_kwargs)
613-
controlnet = controlnet.to(device=self.device, dtype=self.dtype)
613+
controlnet = controlnet.to(dtype=self.dtype)
614614
# Track model_id for updater diffing
615615
try:
616616
setattr(controlnet, 'model_id', model_id)

src/streamdiffusion/wrapper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,9 +1042,9 @@ def _load_model(
10421042
traceback.print_exc()
10431043
raise RuntimeError(error_msg)
10441044
else:
1045-
if hasattr(pipe, "text_encoder") and pipe.text_encoder is not None:
1045+
if not compile_engines_only and hasattr(pipe, "text_encoder") and pipe.text_encoder is not None:
10461046
pipe.text_encoder = pipe.text_encoder.to(device=self.device)
1047-
if hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2 is not None:
1047+
if not compile_engines_only and hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2 is not None:
10481048
pipe.text_encoder_2 = pipe.text_encoder_2.to(device=self.device)
10491049

10501050
# If we get here, the model loaded successfully - break out of retry loop
@@ -1571,7 +1571,7 @@ def _load_model(
15711571
if self.use_safety_checker or safety_checker_engine_exists:
15721572
if not safety_checker_engine_exists:
15731573
from transformers import AutoModelForImageClassification
1574-
self.safety_checker = AutoModelForImageClassification.from_pretrained(safety_checker_model_id).to("cuda")
1574+
self.safety_checker = AutoModelForImageClassification.from_pretrained(safety_checker_model_id)
15751575

15761576
safety_checker_model = NSFWDetector(
15771577
device=self.device,
@@ -1586,7 +1586,7 @@ def _load_model(
15861586
model_config=safety_checker_model,
15871587
batch_size=self.batch_size if self.mode == "txt2img" else stream.frame_bff_size,
15881588
cuda_stream=None,
1589-
load_engine=load_engine,
1589+
load_engine=False,
15901590
)
15911591

15921592
if load_engine:

0 commit comments

Comments
 (0)