Skip to content

Commit ab610dd

Browse files
sdbdszengyh1900
andauthored
[Fix] Powerpaint to load safetensors (open-mmlab#2088)
* fix load safetensors * fix lint --------- Co-authored-by: zengyh1900 <[email protected]>
1 parent 03e24cc commit ab610dd

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

projects/powerpaint/gradio_PowerPaint.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
StableDiffusionInpaintPipeline as Pipeline
1212
from pipeline.pipeline_PowerPaint_ControlNet import \
1313
StableDiffusionControlNetInpaintPipeline as controlnetPipeline
14+
from safetensors.torch import load_file
1415
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
1516
from utils.utils import TokenizerWrapper, add_tokens
1617

@@ -34,7 +35,9 @@
3435
initialize_tokens=['a', 'a', 'a'],
3536
num_vectors_per_token=10)
3637
pipe.unet.load_state_dict(
37-
torch.load('./models/unet/diffusion_pytorch_model.bin'), strict=False)
38+
load_file(
39+
'./models/unet/diffusion_pytorch_model.safetensors', device='cuda'),
40+
strict=False)
3841
pipe.text_encoder.load_state_dict(
3942
torch.load('./models/text_encoder/pytorch_model.bin'), strict=False)
4043
pipe = pipe.to('cuda')

0 commit comments

Comments
 (0)