Skip to content

Commit efc6db8

Browse files
committed
Refactor preview to match the other callbacks
1 parent 53903e6 commit efc6db8

File tree

5 files changed

+90
-94
lines changed

5 files changed

+90
-94
lines changed

examples/cli/main.cpp

+13-18
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,10 @@ struct SDParams {
137137
float skip_layer_start = 0.01;
138138
float skip_layer_end = 0.2;
139139

140-
sd_preview_policy_t preview_method = SD_PREVIEW_NONE;
141-
int preview_interval = 1;
142-
std::string preview_path = "preview.png";
143-
bool taesd_preview = false;
140+
sd_preview_t preview_method = SD_PREVIEW_NONE;
141+
int preview_interval = 1;
142+
std::string preview_path = "preview.png";
143+
bool taesd_preview = false;
144144
};
145145

146146
void print_params(SDParams params) {
@@ -666,7 +666,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
666666
invalid_arg = true;
667667
break;
668668
}
669-
params.preview_method = (sd_preview_policy_t)preview_method;
669+
params.preview_method = (sd_preview_t)preview_method;
670670
} else if (arg == "--preview-interval") {
671671
if (++i >= argc) {
672672
invalid_arg = true;
@@ -850,6 +850,7 @@ int main(int argc, const char* argv[]) {
850850
preview_path = params.preview_path.c_str();
851851

852852
sd_set_log_callback(sd_log_cb, (void*)&params);
853+
sd_set_preview_callback((sd_preview_cb_t)step_callback, params.preview_method, params.preview_interval);
853854

854855
if (params.verbose) {
855856
print_params(params);
@@ -1025,10 +1026,7 @@ int main(int argc, const char* argv[]) {
10251026
params.skip_layers.size(),
10261027
params.slg_scale,
10271028
params.skip_layer_start,
1028-
params.skip_layer_end,
1029-
params.preview_method,
1030-
params.preview_interval,
1031-
(step_callback_t)step_callback);
1029+
params.skip_layer_end);
10321030
} else {
10331031
sd_image_t input_image = {(uint32_t)params.width,
10341032
(uint32_t)params.height,
@@ -1097,10 +1095,7 @@ int main(int argc, const char* argv[]) {
10971095
params.skip_layers.size(),
10981096
params.slg_scale,
10991097
params.skip_layer_start,
1100-
params.skip_layer_end,
1101-
params.preview_method,
1102-
params.preview_interval,
1103-
(step_callback_t)step_callback);
1098+
params.skip_layer_end);
11041099
}
11051100
}
11061101

@@ -1139,19 +1134,19 @@ int main(int argc, const char* argv[]) {
11391134

11401135
std::string dummy_name, ext, lc_ext;
11411136
bool is_jpg;
1142-
size_t last = params.output_path.find_last_of(".");
1137+
size_t last = params.output_path.find_last_of(".");
11431138
size_t last_path = std::min(params.output_path.find_last_of("/"),
11441139
params.output_path.find_last_of("\\"));
1145-
if (last != std::string::npos // filename has extension
1146-
&& (last_path == std::string::npos || last > last_path)) {
1140+
if (last != std::string::npos // filename has extension
1141+
&& (last_path == std::string::npos || last > last_path)) {
11471142
dummy_name = params.output_path.substr(0, last);
11481143
ext = lc_ext = params.output_path.substr(last);
11491144
std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower);
11501145
is_jpg = lc_ext == ".jpg" || lc_ext == ".jpeg" || lc_ext == ".jpe";
11511146
} else {
11521147
dummy_name = params.output_path;
11531148
ext = lc_ext = "";
1154-
is_jpg = false;
1149+
is_jpg = false;
11551150
}
11561151
// appending ".png" to absent or unknown extension
11571152
if (!is_jpg && lc_ext != ".png") {
@@ -1163,7 +1158,7 @@ int main(int argc, const char* argv[]) {
11631158
continue;
11641159
}
11651160
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext;
1166-
if(is_jpg) {
1161+
if (is_jpg) {
11671162
stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
11681163
results[i].data, 90, get_image_params(params, params.seed + i).c_str());
11691164
printf("save result JPEG image to '%s'\n", final_image_path.c_str());

stable-diffusion.cpp

+34-54
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,7 @@ class StableDiffusionGGML {
810810
int step,
811811
struct ggml_tensor* latents,
812812
enum SDVersion version,
813-
sd_preview_policy_t preview_mode,
813+
sd_preview_t preview_mode,
814814
ggml_tensor* result,
815815
std::function<void(int, sd_image_t)> step_callback) {
816816
const uint32_t channel = 3;
@@ -922,14 +922,11 @@ class StableDiffusionGGML {
922922
const std::vector<float>& sigmas,
923923
int start_merge_step,
924924
SDCondition id_cond,
925-
std::vector<int> skip_layers = {},
926-
float slg_scale = 0,
927-
float skip_layer_start = 0.01,
928-
float skip_layer_end = 0.2,
929-
ggml_tensor* noise_mask = nullptr,
930-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
931-
int preview_interval = 1,
932-
std::function<void(int, sd_image_t)> step_callback = nullptr) {
925+
std::vector<int> skip_layers = {},
926+
float slg_scale = 0,
927+
float skip_layer_start = 0.01,
928+
float skip_layer_end = 0.2,
929+
ggml_tensor* noise_mask = nullptr) {
933930
size_t steps = sigmas.size() - 1;
934931
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
935932
// print_ggml_tensor(noise);
@@ -961,7 +958,8 @@ class StableDiffusionGGML {
961958
struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x);
962959

963960
struct ggml_tensor* preview_tensor = NULL;
964-
if (preview_mode != SD_PREVIEW_NONE && preview_mode != SD_PREVIEW_PROJ) {
961+
auto sd_preview_mode = sd_get_preview_mode();
962+
if (sd_preview_mode != SD_PREVIEW_NONE && sd_preview_mode != SD_PREVIEW_PROJ) {
965963
preview_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32,
966964
(denoised->ne[0] * 8),
967965
(denoised->ne[1] * 8),
@@ -1109,10 +1107,11 @@ class StableDiffusionGGML {
11091107
pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
11101108
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
11111109
}
1112-
1113-
if (step_callback != nullptr) {
1114-
if (step % preview_interval == 0) {
1115-
preview_image(work_ctx, step, denoised, version, preview_mode, preview_tensor, step_callback);
1110+
auto sd_preview_cb = sd_get_preview_callback();
1111+
auto sd_preview_mode = sd_get_preview_mode();
1112+
if (sd_preview_cb != NULL) {
1113+
if (step % sd_get_preview_interval() == 0) {
1114+
preview_image(work_ctx, step, denoised, version, sd_preview_mode, preview_tensor, sd_preview_cb);
11161115
}
11171116
}
11181117
return denoised;
@@ -1338,14 +1337,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13381337
float style_ratio,
13391338
bool normalize_input,
13401339
std::string input_id_images_path,
1341-
std::vector<int> skip_layers = {},
1342-
float slg_scale = 0,
1343-
float skip_layer_start = 0.01,
1344-
float skip_layer_end = 0.2,
1345-
ggml_tensor* masked_image = NULL,
1346-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
1347-
int preview_interval = 1,
1348-
std::function<void(int, sd_image_t)> step_callback = nullptr) {
1340+
std::vector<int> skip_layers = {},
1341+
float slg_scale = 0,
1342+
float skip_layer_start = 0.01,
1343+
float skip_layer_end = 0.2,
1344+
ggml_tensor* masked_image = NULL) {
13491345
if (seed < 0) {
13501346
// Generally, when using the provided command line, the seed is always >0.
13511347
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1602,10 +1598,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
16021598
slg_scale,
16031599
skip_layer_start,
16041600
skip_layer_end,
1605-
noise_mask,
1606-
preview_mode,
1607-
preview_interval,
1608-
step_callback);
1601+
noise_mask);
16091602

16101603
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
16111604
// print_ggml_tensor(x_0);
@@ -1674,14 +1667,11 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
16741667
float style_ratio,
16751668
bool normalize_input,
16761669
const char* input_id_images_path_c_str,
1677-
int* skip_layers = NULL,
1678-
size_t skip_layers_count = 0,
1679-
float slg_scale = 0,
1680-
float skip_layer_start = 0.01,
1681-
float skip_layer_end = 0.2,
1682-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
1683-
int preview_interval = 1,
1684-
step_callback_t step_callback = NULL) {
1670+
int* skip_layers = NULL,
1671+
size_t skip_layers_count = 0,
1672+
float slg_scale = 0,
1673+
float skip_layer_start = 0.01,
1674+
float skip_layer_end = 0.2) {
16851675
std::vector<int> skip_layers_vec(skip_layers, skip_layers + skip_layers_count);
16861676
LOG_DEBUG("txt2img %dx%d", width, height);
16871677
if (sd_ctx == NULL) {
@@ -1699,7 +1689,8 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
16991689
if (sd_ctx->sd->stacked_id) {
17001690
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
17011691
}
1702-
if (preview_mode != SD_PREVIEW_NONE && preview_mode != SD_PREVIEW_PROJ) {
1692+
auto sd_preview_mode = sd_get_preview_mode();
1693+
if (sd_preview_mode != SD_PREVIEW_NONE && sd_preview_mode != SD_PREVIEW_PROJ) {
17031694
params.mem_size *= 2;
17041695
}
17051696
params.mem_size += width * height * 3 * sizeof(float);
@@ -1763,10 +1754,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
17631754
slg_scale,
17641755
skip_layer_start,
17651756
skip_layer_end,
1766-
NULL,
1767-
preview_mode,
1768-
preview_interval,
1769-
step_callback);
1757+
NULL);
17701758

17711759
size_t t1 = ggml_time_ms();
17721760

@@ -1796,14 +1784,11 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
17961784
float style_ratio,
17971785
bool normalize_input,
17981786
const char* input_id_images_path_c_str,
1799-
int* skip_layers = NULL,
1800-
size_t skip_layers_count = 0,
1801-
float slg_scale = 0,
1802-
float skip_layer_start = 0.01,
1803-
float skip_layer_end = 0.2,
1804-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
1805-
int preview_interval = 1,
1806-
step_callback_t step_callback = NULL) {
1787+
int* skip_layers = NULL,
1788+
size_t skip_layers_count = 0,
1789+
float slg_scale = 0,
1790+
float skip_layer_start = 0.01,
1791+
float skip_layer_end = 0.2) {
18071792
std::vector<int> skip_layers_vec(skip_layers, skip_layers + skip_layers_count);
18081793
LOG_DEBUG("img2img %dx%d", width, height);
18091794
if (sd_ctx == NULL) {
@@ -1950,10 +1935,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
19501935
slg_scale,
19511936
skip_layer_start,
19521937
skip_layer_end,
1953-
masked_image,
1954-
preview_mode,
1955-
preview_interval,
1956-
step_callback);
1938+
masked_image);
19571939

19581940
size_t t2 = ggml_time_ms();
19591941

@@ -2057,9 +2039,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
20572039
-1,
20582040
SDCondition(NULL, NULL, NULL),
20592041
{},
2060-
0, 0, 0, NULL,
2061-
(sd_preview_policy_t)0, 1,
2062-
NULL);
2042+
0, 0, 0, NULL);
20632043

20642044
int64_t t2 = ggml_time_ms();
20652045
LOG_INFO("sampling completed, taking %.2fs", (t2 - t1) * 1.0f / 1000);

stable-diffusion.h

+13-20
Original file line numberDiff line numberDiff line change
@@ -109,31 +109,32 @@ enum sd_log_level_t {
109109
SD_LOG_ERROR
110110
};
111111

112-
enum sd_preview_policy_t {
112+
// Selects how intermediate preview images are produced while sampling.
// The sampler allocates a dedicated preview tensor (8x the latent width/height) for every
// mode other than SD_PREVIEW_NONE and SD_PREVIEW_PROJ, and txt2img doubles params.mem_size
// for those same modes.
// NOTE(review): PROJ presumably projects latents directly, TAE/VAE presumably decode through
// the tiny / full autoencoder - confirm against preview_image().
// N_PREVIEWS is the trailing enumerator, so it evaluates to the number of modes above it (4).
enum sd_preview_t {
113113
SD_PREVIEW_NONE,
114114
SD_PREVIEW_PROJ,
115115
SD_PREVIEW_TAE,
116116
SD_PREVIEW_VAE,
117117
N_PREVIEWS
118118
};
119119

120+
// Image descriptor exchanged through the public API: img2img/img2vid take one as input,
// txt2img/img2img return an array of them, and sd_preview_cb_t receives one by value.
// It is declared ahead of the callback typedefs because sd_preview_cb_t refers to it.
// NOTE(review): data presumably points to width * height * channel bytes of pixel data;
// ownership of the buffer is not visible here - confirm before freeing or retaining it.
typedef struct {
121+
uint32_t width;
122+
uint32_t height;
123+
uint32_t channel;
124+
uint8_t* data;
125+
} sd_image_t;
126+
120127
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
121128
typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
129+
// Preview callback type registered through sd_set_preview_callback().
// NOTE(review): the int is presumably the sampling step and the sd_image_t the preview for
// that step - confirm in preview_image(). Unlike sd_log_cb_t and sd_progress_cb_t this
// signature carries no `void* data` user pointer.
typedef void (*sd_preview_cb_t)(int, sd_image_t);
130+
122131

123132
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
124133
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
125-
SD_API sd_progress_cb_t sd_get_progress_callback();
126-
SD_API void* sd_get_progress_callback_data();
134+
// Registers the process-wide preview callback, replacing the per-call preview arguments
// that txt2img/img2img used to take.
//   cb       - callback handed to preview_image(); when NULL the sampler skips previews.
//   mode     - preview mode (see sd_preview_t); read by the sampler and by txt2img.
//   interval - preview stride; the sampler previews on steps where step % interval == 0,
//              so pass a value >= 1.
SD_API void sd_set_preview_callback(sd_preview_cb_t cb, sd_preview_t mode, int interval);
127135
SD_API int32_t get_num_physical_cores();
128136
SD_API const char* sd_get_system_info();
129137

130-
typedef struct {
131-
uint32_t width;
132-
uint32_t height;
133-
uint32_t channel;
134-
uint8_t* data;
135-
} sd_image_t;
136-
137138
typedef struct sd_ctx_t sd_ctx_t;
138139

139140
SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
@@ -162,8 +163,6 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
162163

163164
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
164165

165-
typedef void (*step_callback_t)(int, sd_image_t);
166-
167166
SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
168167
const char* prompt,
169168
const char* negative_prompt,
@@ -186,10 +185,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
186185
size_t skip_layers_count,
187186
float slg_scale,
188187
float skip_layer_start,
189-
float skip_layer_end,
190-
sd_preview_policy_t preview_mode,
191-
int preview_interval,
192-
step_callback_t step_callback);
188+
float skip_layer_end);
193189

194190
SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
195191
sd_image_t init_image,
@@ -216,10 +212,7 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
216212
size_t skip_layers_count,
217213
float slg_scale,
218214
float skip_layer_start,
219-
float skip_layer_end,
220-
sd_preview_policy_t preview_mode,
221-
int preview_interval,
222-
step_callback_t step_callback);
215+
float skip_layer_end);
223216

224217
SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
225218
sd_image_t init_image,

util.cpp

+23-2
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,10 @@ int32_t get_num_physical_cores() {
247247
static sd_progress_cb_t sd_progress_cb = NULL;
248248
void* sd_progress_cb_data = NULL;
249249

250+
// Preview settings shared by the whole process: written by sd_set_preview_callback() and
// read back through sd_get_preview_callback() / sd_get_preview_mode() /
// sd_get_preview_interval().
// NOTE(review): only the callback pointer is `static`; the mode and interval variables have
// external linkage - confirm whether that is intended.
static sd_preview_cb_t sd_preview_cb = NULL;
251+
sd_preview_t sd_preview_mode = SD_PREVIEW_NONE;
252+
int sd_preview_interval = 1;
253+
250254
std::u32string utf8_to_utf32(const std::string& utf8_str) {
251255
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
252256
return converter.from_bytes(utf8_str);
@@ -420,10 +424,27 @@ void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
420424
sd_progress_cb = cb;
421425
sd_progress_cb_data = data;
422426
}
423-
sd_progress_cb_t sd_get_progress_callback(){
427+
// Stores the preview callback together with its mode and interval in the file-level
// preview globals.
//   cb       - callback later handed to preview_image(); NULL makes the sampler skip previews.
//   mode     - preview mode returned by sd_get_preview_mode().
//   interval - preview stride. Values below 1 are stored as 1: the sampler evaluates
//              `step % sd_get_preview_interval()`, and a zero divisor is undefined behavior.
// The default arguments exist only on this definition; the declaration in
// stable-diffusion.h has none, so callers outside this file must pass all three.
void sd_set_preview_callback(sd_preview_cb_t cb, sd_preview_t mode = SD_PREVIEW_PROJ, int interval = 1) {
    sd_preview_cb       = cb;
    sd_preview_mode     = mode;
    sd_preview_interval = interval < 1 ? 1 : interval;
}
432+
433+
// Returns the callback last stored by sd_set_preview_callback(), or NULL (the initial
// value) when none has been registered.
sd_preview_cb_t sd_get_preview_callback() {
434+
return sd_preview_cb;
435+
}
436+
437+
// Returns the preview mode last stored by sd_set_preview_callback(); SD_PREVIEW_NONE until
// a callback has been registered.
sd_preview_t sd_get_preview_mode() {
438+
return sd_preview_mode;
439+
}
440+
// Returns the preview interval last stored by sd_set_preview_callback() (initially 1).
// The sampler uses this value as the divisor in `step % sd_get_preview_interval()`.
int sd_get_preview_interval() {
441+
return sd_preview_interval;
442+
}
443+
444+
// Returns the progress callback stored by sd_set_progress_callback(), or NULL (the initial
// value) when none has been registered.
sd_progress_cb_t sd_get_progress_callback() {
424445
return sd_progress_cb;
425446
}
426-
void* sd_get_progress_callback_data(){
447+
// Returns the user-data pointer stored alongside the progress callback by
// sd_set_progress_callback(); NULL until one has been set.
void* sd_get_progress_callback_data() {
427448
return sd_progress_cb_data;
428449
}
429450
const char* sd_get_system_info() {

util.h

+7
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ std::string trim(const std::string& s);
5454

5555
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text);
5656

57+
// Accessors for the callback state held in util.cpp. They are declared here without SD_API;
// this commit removes the SD_API declarations of the progress getters from
// stable-diffusion.h.
sd_progress_cb_t sd_get_progress_callback();
58+
void* sd_get_progress_callback_data();
59+
60+
sd_preview_cb_t sd_get_preview_callback();
61+
sd_preview_t sd_get_preview_mode();
62+
int sd_get_preview_interval();
63+
5764
#define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
5865
#define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
5966
#define LOG_WARN(format, ...) log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__)

0 commit comments

Comments
 (0)