forked from civitai/civitai_comfy_nodes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcivitai_checkpoint_loader.py
111 lines (83 loc) · 4.01 KB
/
civitai_checkpoint_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import hashlib
import json
import os
import requests
import sys
import time
from tqdm import tqdm
import folder_paths
import comfy.sd
import comfy.utils
from nodes import CheckpointLoaderSimple
from .CivitAI_Model import CivitAI_Model
from .utils import short_paths_map, model_path
# Absolute directory that contains this source file.
ROOT_PATH = os.path.dirname(os.path.abspath(__file__))

# Directory list ComfyUI registers for the "checkpoints" folder type. This is
# the list object held by folder_paths itself (not a copy).
CHECKPOINTS = folder_paths.folder_names_and_paths["checkpoints"][0]

# First registered checkpoint directory.
CHECKPOINT_PATH = CHECKPOINTS[0]

# Console prefix: ANSI bold + blue "[CivitAI] " followed by an attribute reset.
MSG_PREFIX = '\33[1m\33[34m[CivitAI] \33[0m'
class CivitAI_Checkpoint_Loader:
    """
    Implements the CivitAI Checkpoint Loader node for ComfyUI.

    If ``ckpt_name`` is ``'none'`` the model identified by the ``ckpt_air``
    string (``"<model_id>@<version_id>"``) is fetched through
    ``CivitAI_Model.download()`` and then loaded; otherwise the selected local
    checkpoint is loaded and its CivitAI ids are looked up by SHA-256. When
    ids are known and workflow metadata is present, ``"<model_id>@<version>"``
    is recorded in ``workflow['extra']['ckpt_airs']``.
    """

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"
    CATEGORY = "CivitAI/Loaders"

    def __init__(self):
        # ComfyUI's stock checkpoint loader; created lazily on first use.
        self.ckpt_loader = None

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs for ComfyUI.

        ``ckpt_name`` offers ``'none'`` (meaning "download via ckpt_air")
        followed by the local checkpoint files; ``download_path`` offers the
        short labels produced by ``short_paths_map(CHECKPOINTS)``.
        """
        # Build a new list rather than inserting into the one folder_paths
        # returns, which some ComfyUI versions cache and hand out again.
        checkpoints = ['none'] + list(folder_paths.get_filename_list("checkpoints"))
        checkpoint_paths = short_paths_map(CHECKPOINTS)
        return {
            "required": {
                "ckpt_air": ("STRING", {"default": "{model_id}@{model_version}", "multiline": False}),
                "ckpt_name": (checkpoints,),
            },
            "optional": {
                "api_key": ("STRING", {"default": "", "multiline": False}),
                "download_chunks": ("INT", {"default": 4, "min": 1, "max": 12, "step": 1}),
                "download_path": (list(checkpoint_paths.keys()),),
            },
            "hidden": {
                "extra_pnginfo": "EXTRA_PNGINFO"
            }
        }

    @staticmethod
    def _parse_air(ckpt_air):
        """Split ``"<model_id>[@<version_id>]"`` into ``(model_id, version_id)`` ints.

        Empty parts become ``None``. Raises ``ValueError`` with a descriptive
        message when a non-empty part is not an integer (for example the
        unedited ``{model_id}@{model_version}`` placeholder).
        """
        model_part, _, version_part = ckpt_air.partition('@')
        model_part, version_part = model_part.strip(), version_part.strip()
        try:
            ckpt_id = int(model_part) if model_part else None
            version_id = int(version_part) if version_part else None
        except ValueError as err:
            raise ValueError(
                f"Invalid ckpt_air {ckpt_air!r}: expected '<model_id>' or "
                f"'<model_id>@<version_id>' with integer ids"
            ) from err
        return ckpt_id, version_id

    @staticmethod
    def _ckpt_airs(extra_pnginfo):
        """Return the workflow's ``extra['ckpt_airs']`` list, creating it if missing.

        Also creates the ``extra`` dict when the workflow lacks one. Returns
        ``None`` when there is no workflow metadata to record into.
        """
        if not extra_pnginfo or 'workflow' not in extra_pnginfo:
            return None
        extra = extra_pnginfo['workflow'].setdefault('extra', {})
        return extra.setdefault('ckpt_airs', [])

    @classmethod
    def _record_air(cls, extra_pnginfo, model_id, version_id):
        """Append ``"<model_id>@<version_id>"`` to the workflow's ckpt_airs once."""
        airs = cls._ckpt_airs(extra_pnginfo)
        if airs is None:
            return
        air = f'{model_id}@{version_id}'
        if air not in airs:
            airs.append(air)

    @staticmethod
    def _resolve_download_path(download_path):
        """Map a short ``download_path`` label to its checkpoint directory.

        Falls back to the first checkpoint directory when the label is unset
        or not one of the labels offered by ``INPUT_TYPES``.
        """
        checkpoint_paths = short_paths_map(CHECKPOINTS)
        if download_path and download_path in checkpoint_paths:
            return checkpoint_paths[download_path]
        return CHECKPOINTS[0]

    def load_checkpoint(self, ckpt_air, ckpt_name, api_key=None, download_chunks=None, download_path=None, extra_pnginfo=None):
        """Load (downloading first if requested) a checkpoint.

        Args:
            ckpt_air: ``"<model_id>@<version_id>"``; only used when
                ``ckpt_name == 'none'``.
            ckpt_name: local checkpoint name, or ``'none'`` to download.
            api_key: token forwarded to ``CivitAI_Model``.
            download_chunks: forwarded to ``CivitAI_Model``.
            download_path: short directory label from ``INPUT_TYPES``.
            extra_pnginfo: hidden workflow metadata; its
                ``workflow['extra']['ckpt_airs']`` list is created/extended
                in place.

        Returns:
            ``(model, clip, vae, {"extra_pnginfo": extra_pnginfo})`` built from
            ``CheckpointLoaderSimple.load_checkpoint``, or ``(None, None, None)``
            when the download reports failure.

        Raises:
            ValueError: if ``ckpt_air`` is malformed when downloading.
        """
        # Make sure the metadata carries a ckpt_airs list even if no AIR is
        # recorded below.
        self._ckpt_airs(extra_pnginfo)

        if not self.ckpt_loader:
            self.ckpt_loader = CheckpointLoaderSimple()

        if ckpt_name == 'none':
            ckpt_id, version_id = self._parse_air(ckpt_air)
            civitai_model = CivitAI_Model(
                model_id=ckpt_id,
                model_version=version_id,
                model_types=["Checkpoint",],
                token=api_key,
                save_path=self._resolve_download_path(download_path),
                model_paths=CHECKPOINTS,
                download_chunks=download_chunks,
            )
            if not civitai_model.download():
                return None, None, None
            ckpt_name = civitai_model.name
            self._record_air(extra_pnginfo, civitai_model.model_id, civitai_model.version)
        else:
            ckpt_path = model_path(ckpt_name, CHECKPOINTS)
            model_id, version_id, _details = CivitAI_Model.sha256_lookup(ckpt_path)
            if model_id and version_id:
                self._record_air(extra_pnginfo, model_id, version_id)
            # Logged only here: ckpt_path is bound solely in this branch.
            print(f"{MSG_PREFIX}Loading checkpoint from disk: {ckpt_path}")

        out = self.ckpt_loader.load_checkpoint(ckpt_name=ckpt_name)
        return out[0], out[1], out[2], { "extra_pnginfo": extra_pnginfo }