Skip to content

Commit dfd9c4c

Browse files
committed
Refactor preview to match the other callbacks
1 parent a539cb5 commit dfd9c4c

File tree

5 files changed

+90
-94
lines changed

5 files changed

+90
-94
lines changed

examples/cli/main.cpp

+13-18
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,10 @@ struct SDParams {
134134
float skip_layer_start = 0.01;
135135
float skip_layer_end = 0.2;
136136

137-
sd_preview_policy_t preview_method = SD_PREVIEW_NONE;
138-
int preview_interval = 1;
139-
std::string preview_path = "preview.png";
140-
bool taesd_preview = false;
137+
sd_preview_t preview_method = SD_PREVIEW_NONE;
138+
int preview_interval = 1;
139+
std::string preview_path = "preview.png";
140+
bool taesd_preview = false;
141141
};
142142

143143
void print_params(SDParams params) {
@@ -653,7 +653,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
653653
invalid_arg = true;
654654
break;
655655
}
656-
params.preview_method = (sd_preview_policy_t)preview_method;
656+
params.preview_method = (sd_preview_t)preview_method;
657657
} else if (arg == "--preview-interval") {
658658
if (++i >= argc) {
659659
invalid_arg = true;
@@ -836,6 +836,7 @@ int main(int argc, const char* argv[]) {
836836
preview_path = params.preview_path.c_str();
837837

838838
sd_set_log_callback(sd_log_cb, (void*)&params);
839+
sd_set_preview_callback((sd_preview_cb_t)step_callback, params.preview_method, params.preview_interval);
839840

840841
if (params.verbose) {
841842
print_params(params);
@@ -1010,10 +1011,7 @@ int main(int argc, const char* argv[]) {
10101011
params.skip_layers.size(),
10111012
params.slg_scale,
10121013
params.skip_layer_start,
1013-
params.skip_layer_end,
1014-
params.preview_method,
1015-
params.preview_interval,
1016-
(step_callback_t)step_callback);
1014+
params.skip_layer_end);
10171015
} else {
10181016
sd_image_t input_image = {(uint32_t)params.width,
10191017
(uint32_t)params.height,
@@ -1081,10 +1079,7 @@ int main(int argc, const char* argv[]) {
10811079
params.skip_layers.size(),
10821080
params.slg_scale,
10831081
params.skip_layer_start,
1084-
params.skip_layer_end,
1085-
params.preview_method,
1086-
params.preview_interval,
1087-
(step_callback_t)step_callback);
1082+
params.skip_layer_end);
10881083
}
10891084
}
10901085

@@ -1123,19 +1118,19 @@ int main(int argc, const char* argv[]) {
11231118

11241119
std::string dummy_name, ext, lc_ext;
11251120
bool is_jpg;
1126-
size_t last = params.output_path.find_last_of(".");
1121+
size_t last = params.output_path.find_last_of(".");
11271122
size_t last_path = std::min(params.output_path.find_last_of("/"),
11281123
params.output_path.find_last_of("\\"));
1129-
if (last != std::string::npos // filename has extension
1130-
&& (last_path == std::string::npos || last > last_path)) {
1124+
if (last != std::string::npos // filename has extension
1125+
&& (last_path == std::string::npos || last > last_path)) {
11311126
dummy_name = params.output_path.substr(0, last);
11321127
ext = lc_ext = params.output_path.substr(last);
11331128
std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower);
11341129
is_jpg = lc_ext == ".jpg" || lc_ext == ".jpeg" || lc_ext == ".jpe";
11351130
} else {
11361131
dummy_name = params.output_path;
11371132
ext = lc_ext = "";
1138-
is_jpg = false;
1133+
is_jpg = false;
11391134
}
11401135
// appending ".png" to absent or unknown extension
11411136
if (!is_jpg && lc_ext != ".png") {
@@ -1147,7 +1142,7 @@ int main(int argc, const char* argv[]) {
11471142
continue;
11481143
}
11491144
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext;
1150-
if(is_jpg) {
1145+
if (is_jpg) {
11511146
stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
11521147
results[i].data, 90, get_image_params(params, params.seed + i).c_str());
11531148
printf("save result JPEG image to '%s'\n", final_image_path.c_str());

stable-diffusion.cpp

+34-54
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ class StableDiffusionGGML {
808808
int step,
809809
struct ggml_tensor* latents,
810810
enum SDVersion version,
811-
sd_preview_policy_t preview_mode,
811+
sd_preview_t preview_mode,
812812
ggml_tensor* result,
813813
std::function<void(int, sd_image_t)> step_callback) {
814814
const uint32_t channel = 3;
@@ -919,14 +919,11 @@ class StableDiffusionGGML {
919919
const std::vector<float>& sigmas,
920920
int start_merge_step,
921921
SDCondition id_cond,
922-
std::vector<int> skip_layers = {},
923-
float slg_scale = 0,
924-
float skip_layer_start = 0.01,
925-
float skip_layer_end = 0.2,
926-
ggml_tensor* noise_mask = nullptr,
927-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
928-
int preview_interval = 1,
929-
std::function<void(int, sd_image_t)> step_callback = nullptr) {
922+
std::vector<int> skip_layers = {},
923+
float slg_scale = 0,
924+
float skip_layer_start = 0.01,
925+
float skip_layer_end = 0.2,
926+
ggml_tensor* noise_mask = nullptr) {
930927
size_t steps = sigmas.size() - 1;
931928
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
932929
// print_ggml_tensor(noise);
@@ -958,7 +955,8 @@ class StableDiffusionGGML {
958955
struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x);
959956

960957
struct ggml_tensor* preview_tensor = NULL;
961-
if (preview_mode != SD_PREVIEW_NONE && preview_mode != SD_PREVIEW_PROJ) {
958+
auto sd_preview_mode = sd_get_preview_mode();
959+
if (sd_preview_mode != SD_PREVIEW_NONE && sd_preview_mode != SD_PREVIEW_PROJ) {
962960
preview_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32,
963961
(denoised->ne[0] * 8),
964962
(denoised->ne[1] * 8),
@@ -1106,10 +1104,11 @@ class StableDiffusionGGML {
11061104
pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
11071105
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
11081106
}
1109-
1110-
if (step_callback != nullptr) {
1111-
if (step % preview_interval == 0) {
1112-
preview_image(work_ctx, step, denoised, version, preview_mode, preview_tensor, step_callback);
1107+
auto sd_preview_cb = sd_get_preview_callback();
1108+
auto sd_preview_mode = sd_get_preview_mode();
1109+
if (sd_preview_cb != NULL) {
1110+
if (step % sd_get_preview_interval() == 0) {
1111+
preview_image(work_ctx, step, denoised, version, sd_preview_mode, preview_tensor, sd_preview_cb);
11131112
}
11141113
}
11151114
return denoised;
@@ -1334,14 +1333,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13341333
float style_ratio,
13351334
bool normalize_input,
13361335
std::string input_id_images_path,
1337-
std::vector<int> skip_layers = {},
1338-
float slg_scale = 0,
1339-
float skip_layer_start = 0.01,
1340-
float skip_layer_end = 0.2,
1341-
ggml_tensor* masked_image = NULL,
1342-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
1343-
int preview_interval = 1,
1344-
std::function<void(int, sd_image_t)> step_callback = nullptr) {
1336+
std::vector<int> skip_layers = {},
1337+
float slg_scale = 0,
1338+
float skip_layer_start = 0.01,
1339+
float skip_layer_end = 0.2,
1340+
ggml_tensor* masked_image = NULL) {
13451341
if (seed < 0) {
13461342
// Generally, when using the provided command line, the seed is always >0.
13471343
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1597,10 +1593,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
15971593
slg_scale,
15981594
skip_layer_start,
15991595
skip_layer_end,
1600-
noise_mask,
1601-
preview_mode,
1602-
preview_interval,
1603-
step_callback);
1596+
noise_mask);
16041597

16051598
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
16061599
// print_ggml_tensor(x_0);
@@ -1668,14 +1661,11 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
16681661
float style_ratio,
16691662
bool normalize_input,
16701663
const char* input_id_images_path_c_str,
1671-
int* skip_layers = NULL,
1672-
size_t skip_layers_count = 0,
1673-
float slg_scale = 0,
1674-
float skip_layer_start = 0.01,
1675-
float skip_layer_end = 0.2,
1676-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
1677-
int preview_interval = 1,
1678-
step_callback_t step_callback = NULL) {
1664+
int* skip_layers = NULL,
1665+
size_t skip_layers_count = 0,
1666+
float slg_scale = 0,
1667+
float skip_layer_start = 0.01,
1668+
float skip_layer_end = 0.2) {
16791669
std::vector<int> skip_layers_vec(skip_layers, skip_layers + skip_layers_count);
16801670
LOG_DEBUG("txt2img %dx%d", width, height);
16811671
if (sd_ctx == NULL) {
@@ -1693,7 +1683,8 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
16931683
if (sd_ctx->sd->stacked_id) {
16941684
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
16951685
}
1696-
if (preview_mode != SD_PREVIEW_NONE && preview_mode != SD_PREVIEW_PROJ) {
1686+
auto sd_preview_mode = sd_get_preview_mode();
1687+
if (sd_preview_mode != SD_PREVIEW_NONE && sd_preview_mode != SD_PREVIEW_PROJ) {
16971688
params.mem_size *= 2;
16981689
}
16991690
params.mem_size += width * height * 3 * sizeof(float);
@@ -1756,10 +1747,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
17561747
slg_scale,
17571748
skip_layer_start,
17581749
skip_layer_end,
1759-
NULL,
1760-
preview_mode,
1761-
preview_interval,
1762-
step_callback);
1750+
NULL);
17631751

17641752
size_t t1 = ggml_time_ms();
17651753

@@ -1788,14 +1776,11 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
17881776
float style_ratio,
17891777
bool normalize_input,
17901778
const char* input_id_images_path_c_str,
1791-
int* skip_layers = NULL,
1792-
size_t skip_layers_count = 0,
1793-
float slg_scale = 0,
1794-
float skip_layer_start = 0.01,
1795-
float skip_layer_end = 0.2,
1796-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
1797-
int preview_interval = 1,
1798-
step_callback_t step_callback = NULL) {
1779+
int* skip_layers = NULL,
1780+
size_t skip_layers_count = 0,
1781+
float slg_scale = 0,
1782+
float skip_layer_start = 0.01,
1783+
float skip_layer_end = 0.2) {
17991784
std::vector<int> skip_layers_vec(skip_layers, skip_layers + skip_layers_count);
18001785
LOG_DEBUG("img2img %dx%d", width, height);
18011786
if (sd_ctx == NULL) {
@@ -1941,10 +1926,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
19411926
slg_scale,
19421927
skip_layer_start,
19431928
skip_layer_end,
1944-
masked_image,
1945-
preview_mode,
1946-
preview_interval,
1947-
step_callback);
1929+
masked_image);
19481930

19491931
size_t t2 = ggml_time_ms();
19501932

@@ -2047,9 +2029,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
20472029
-1,
20482030
SDCondition(NULL, NULL, NULL),
20492031
{},
2050-
0, 0, 0, NULL,
2051-
(sd_preview_policy_t)0, 1,
2052-
NULL);
2032+
0, 0, 0, NULL);
20532033

20542034
int64_t t2 = ggml_time_ms();
20552035
LOG_INFO("sampling completed, taking %.2fs", (t2 - t1) * 1.0f / 1000);

stable-diffusion.h

+13-20
Original file line numberDiff line numberDiff line change
@@ -107,31 +107,32 @@ enum sd_log_level_t {
107107
SD_LOG_ERROR
108108
};
109109

110-
enum sd_preview_policy_t {
110+
enum sd_preview_t {
111111
SD_PREVIEW_NONE,
112112
SD_PREVIEW_PROJ,
113113
SD_PREVIEW_TAE,
114114
SD_PREVIEW_VAE,
115115
N_PREVIEWS
116116
};
117117

118+
typedef struct {
119+
uint32_t width;
120+
uint32_t height;
121+
uint32_t channel;
122+
uint8_t* data;
123+
} sd_image_t;
124+
118125
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
119126
typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
127+
typedef void (*sd_preview_cb_t)(int, sd_image_t);
128+
120129

121130
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
122131
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
123-
SD_API sd_progress_cb_t sd_get_progress_callback();
124-
SD_API void* sd_get_progress_callback_data();
132+
SD_API void sd_set_preview_callback(sd_preview_cb_t cb, sd_preview_t mode, int interval);
125133
SD_API int32_t get_num_physical_cores();
126134
SD_API const char* sd_get_system_info();
127135

128-
typedef struct {
129-
uint32_t width;
130-
uint32_t height;
131-
uint32_t channel;
132-
uint8_t* data;
133-
} sd_image_t;
134-
135136
typedef struct sd_ctx_t sd_ctx_t;
136137

137138
SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
@@ -160,8 +161,6 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
160161

161162
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
162163

163-
typedef void (*step_callback_t)(int, sd_image_t);
164-
165164
SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
166165
const char* prompt,
167166
const char* negative_prompt,
@@ -183,10 +182,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
183182
size_t skip_layers_count,
184183
float slg_scale,
185184
float skip_layer_start,
186-
float skip_layer_end,
187-
sd_preview_policy_t preview_mode,
188-
int preview_interval,
189-
step_callback_t step_callback);
185+
float skip_layer_end);
190186

191187
SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
192188
sd_image_t init_image,
@@ -212,10 +208,7 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
212208
size_t skip_layers_count,
213209
float slg_scale,
214210
float skip_layer_start,
215-
float skip_layer_end,
216-
sd_preview_policy_t preview_mode,
217-
int preview_interval,
218-
step_callback_t step_callback);
211+
float skip_layer_end);
219212

220213
SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
221214
sd_image_t init_image,

util.cpp

+23-2
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,10 @@ int32_t get_num_physical_cores() {
234234
static sd_progress_cb_t sd_progress_cb = NULL;
235235
void* sd_progress_cb_data = NULL;
236236

237+
static sd_preview_cb_t sd_preview_cb = NULL;
238+
sd_preview_t sd_preview_mode = SD_PREVIEW_NONE;
239+
int sd_preview_interval = 1;
240+
237241
std::u32string utf8_to_utf32(const std::string& utf8_str) {
238242
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
239243
return converter.from_bytes(utf8_str);
@@ -407,10 +411,27 @@ void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
407411
sd_progress_cb = cb;
408412
sd_progress_cb_data = data;
409413
}
410-
sd_progress_cb_t sd_get_progress_callback(){
414+
void sd_set_preview_callback(sd_preview_cb_t cb, sd_preview_t mode = SD_PREVIEW_PROJ, int interval = 1) {
415+
sd_preview_cb = cb;
416+
sd_preview_mode = mode;
417+
sd_preview_interval = interval;
418+
}
419+
420+
sd_preview_cb_t sd_get_preview_callback() {
421+
return sd_preview_cb;
422+
}
423+
424+
sd_preview_t sd_get_preview_mode() {
425+
return sd_preview_mode;
426+
}
427+
int sd_get_preview_interval() {
428+
return sd_preview_interval;
429+
}
430+
431+
sd_progress_cb_t sd_get_progress_callback() {
411432
return sd_progress_cb;
412433
}
413-
void* sd_get_progress_callback_data(){
434+
void* sd_get_progress_callback_data() {
414435
return sd_progress_cb_data;
415436
}
416437
const char* sd_get_system_info() {

util.h

+7
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ std::string trim(const std::string& s);
5454

5555
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text);
5656

57+
sd_progress_cb_t sd_get_progress_callback();
58+
void* sd_get_progress_callback_data();
59+
60+
sd_preview_cb_t sd_get_preview_callback();
61+
sd_preview_t sd_get_preview_mode();
62+
int sd_get_preview_interval();
63+
5764
#define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
5865
#define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
5966
#define LOG_WARN(format, ...) log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__)

0 commit comments

Comments
 (0)