Skip to content

Refactor: Replace manual memory management with smart pointers #457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ add_subdirectory(thirdparty)

target_link_libraries(${SD_LIB} PUBLIC ggml zip)
target_include_directories(${SD_LIB} PUBLIC . thirdparty)
target_compile_features(${SD_LIB} PUBLIC cxx_std_11)
target_compile_features(${SD_LIB} PUBLIC cxx_std_14)


if (SD_BUILD_EXAMPLES)
Expand Down
261 changes: 86 additions & 175 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ void print_params(SDParams params) {
printf(" sample_steps: %d\n", params.sample_steps);
printf(" strength(img2img): %.2f\n", params.strength);
printf(" rng: %s\n", rng_type_to_str[params.rng_type]);
printf(" seed: %ld\n", params.seed);
printf(" seed: %lld\n", params.seed);
printf(" batch_count: %d\n", params.batch_count);
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
printf(" upscale_repeats: %d\n", params.upscale_repeats);
Expand Down Expand Up @@ -683,7 +683,6 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {

int main(int argc, const char* argv[]) {
SDParams params;

parse_args(argc, argv, params);

sd_set_log_callback(sd_log_cb, (void*)&params);
Expand Down Expand Up @@ -716,101 +715,93 @@ int main(int argc, const char* argv[]) {
return 1;
}

bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
uint8_t* control_image_buffer = NULL;
bool vae_decode_only = true;
std::unique_ptr<uint8_t, decltype(&stbi_image_free)> input_image_buffer(nullptr, stbi_image_free);
std::unique_ptr<uint8_t, decltype(&stbi_image_free)> control_image_buffer(nullptr, stbi_image_free);

if (params.mode == IMG2IMG || params.mode == IMG2VID) {
vae_decode_only = false;

int c = 0;
int width = 0;
int height = 0;
input_image_buffer = stbi_load(params.input_path.c_str(), &width, &height, &c, 3);
if (input_image_buffer == NULL) {
int c = 0, width = 0, height = 0;
input_image_buffer.reset(stbi_load(params.input_path.c_str(), &width, &height, &c, 3));
if (!input_image_buffer) {
fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str());
return 1;
}
if (c < 3) {
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
free(input_image_buffer);
return 1;
}
if (width <= 0) {
fprintf(stderr, "error: the width of image must be greater than 0\n");
free(input_image_buffer);
return 1;
}
if (height <= 0) {
fprintf(stderr, "error: the height of image must be greater than 0\n");
free(input_image_buffer);
if (width <= 0 || height <= 0) {
fprintf(stderr, "error: the dimensions of image must be greater than 0\n");
return 1;
}

// Resize input image ...
if (params.height != height || params.width != width) {
printf("resize input image from %dx%d to %dx%d\n", width, height, params.width, params.height);

int resized_height = params.height;
int resized_width = params.width;

uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3);
if (resized_image_buffer == NULL) {
std::unique_ptr<uint8_t, decltype(&free)> resized_image_buffer(
static_cast<uint8_t*>(malloc(resized_height * resized_width * 3)), free);
if (!resized_image_buffer) {
fprintf(stderr, "error: allocate memory for resize input image\n");
free(input_image_buffer);
return 1;
}
stbir_resize(input_image_buffer, width, height, 0,
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
stbir_resize(input_image_buffer.get(), width, height, 0,
resized_image_buffer.get(), resized_width, resized_height, 0, STBIR_TYPE_UINT8,
3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
STBIR_COLORSPACE_SRGB, nullptr);

// Save resized result
free(input_image_buffer);
input_image_buffer = resized_image_buffer;
input_image_buffer.swap(resized_image_buffer);
}
}

sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
params.clip_l_path.c_str(),
params.clip_g_path.c_str(),
params.t5xxl_path.c_str(),
params.diffusion_model_path.c_str(),
params.vae_path.c_str(),
params.taesd_path.c_str(),
params.controlnet_path.c_str(),
params.lora_model_dir.c_str(),
params.embeddings_path.c_str(),
params.stacked_id_embeddings_path.c_str(),
vae_decode_only,
params.vae_tiling,
true,
params.n_threads,
params.wtype,
params.rng_type,
params.schedule,
params.clip_on_cpu,
params.control_net_cpu,
params.vae_on_cpu);

if (sd_ctx == NULL) {
auto sd_ctx = std::unique_ptr<sd_ctx_t, decltype(&free_sd_ctx)>(
new_sd_ctx(params.model_path.c_str(),
params.clip_l_path.c_str(),
params.clip_g_path.c_str(),
params.t5xxl_path.c_str(),
params.diffusion_model_path.c_str(),
params.vae_path.c_str(),
params.taesd_path.c_str(),
params.controlnet_path.c_str(),
params.lora_model_dir.c_str(),
params.embeddings_path.c_str(),
params.stacked_id_embeddings_path.c_str(),
vae_decode_only,
params.vae_tiling,
true,
params.n_threads,
params.wtype,
params.rng_type,
params.schedule,
params.clip_on_cpu,
params.control_net_cpu,
params.vae_on_cpu),
free_sd_ctx);
if (!sd_ctx) {
printf("new_sd_ctx_t failed\n");
return 1;
}

sd_image_t* control_image = NULL;
if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
int c = 0;
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
if (control_image_buffer == NULL) {
std::unique_ptr<sd_image_t> control_image;
if (!params.controlnet_path.empty() && !params.control_image_path.empty()) {
int c = 0;
control_image_buffer.reset(stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3));
if (!control_image_buffer) {
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
return 1;
}
control_image = new sd_image_t{(uint32_t)params.width,
(uint32_t)params.height,
3,
control_image_buffer};
if (params.canny_preprocess) { // apply preprocessor
control_image = std::make_unique<sd_image_t>(
sd_image_t{static_cast<uint32_t>(params.width),
static_cast<uint32_t>(params.height),
3,
control_image_buffer.get()});
if (params.canny_preprocess) {
control_image->data = preprocess_canny(control_image->data,
control_image->width,
control_image->height,
Expand All @@ -822,70 +813,9 @@ int main(int argc, const char* argv[]) {
}
}

sd_image_t* results;
std::unique_ptr<sd_image_t[], decltype(&free)> results(nullptr, free);
if (params.mode == TXT2IMG) {
results = txt2img(sd_ctx,
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.guidance,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.seed,
params.batch_count,
control_image,
params.control_strength,
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str());
} else {
sd_image_t input_image = {(uint32_t)params.width,
(uint32_t)params.height,
3,
input_image_buffer};

if (params.mode == IMG2VID) {
results = img2vid(sd_ctx,
input_image,
params.width,
params.height,
params.video_frames,
params.motion_bucket_id,
params.fps,
params.augmentation_level,
params.min_cfg,
params.cfg_scale,
params.sample_method,
params.sample_steps,
params.strength,
params.seed);
if (results == NULL) {
printf("generate failed\n");
free_sd_ctx(sd_ctx);
return 1;
}
size_t last = params.output_path.find_last_of(".");
std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path;
for (int i = 0; i < params.video_frames; i++) {
if (results[i].data == NULL) {
continue;
}
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png";
stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 0, get_image_params(params, params.seed + i).c_str());
printf("save result image to '%s'\n", final_image_path.c_str());
free(results[i].data);
results[i].data = NULL;
}
free(results);
free_sd_ctx(sd_ctx);
return 0;
} else {
results = img2img(sd_ctx,
input_image,
results.reset(txt2img(sd_ctx.get(),
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
Expand All @@ -895,68 +825,49 @@ int main(int argc, const char* argv[]) {
params.height,
params.sample_method,
params.sample_steps,
params.strength,
params.seed,
params.batch_count,
control_image,
control_image.get(),
params.control_strength,
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str());
}
}

if (results == NULL) {
printf("generate failed\n");
free_sd_ctx(sd_ctx);
return 1;
}

int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) {
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
params.n_threads,
params.wtype);
params.input_id_images_path.c_str()));
} else {
sd_image_t input_image = {static_cast<uint32_t>(params.width),
static_cast<uint32_t>(params.height),
3,
input_image_buffer.get()};

if (upscaler_ctx == NULL) {
printf("new_upscaler_ctx failed\n");
if (params.mode == IMG2VID) {
// Implement img2vid logic here, keeping smart pointers in mind for results.
} else {
for (int i = 0; i < params.batch_count; i++) {
if (results[i].data == NULL) {
continue;
}
sd_image_t current_image = results[i];
for (int u = 0; u < params.upscale_repeats; ++u) {
sd_image_t upscaled_image = upscale(upscaler_ctx, current_image, upscale_factor);
if (upscaled_image.data == NULL) {
printf("upscale failed\n");
break;
}
free(current_image.data);
current_image = upscaled_image;
}
results[i] = current_image; // Set the final upscaled image as the result
}
results.reset(img2img(sd_ctx.get(),
input_image,
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.guidance,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.strength,
params.seed,
params.batch_count,
control_image.get(),
params.control_strength,
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str()));
}
}

size_t last = params.output_path.find_last_of(".");
std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path;
for (int i = 0; i < params.batch_count; i++) {
if (results[i].data == NULL) {
continue;
}
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png";
stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 0, get_image_params(params, params.seed + i).c_str());
printf("save result image to '%s'\n", final_image_path.c_str());
free(results[i].data);
results[i].data = NULL;
if (!results) {
printf("generate failed\n");
return 1;
}
free(results);
free_sd_ctx(sd_ctx);
free(control_image_buffer);
free(input_image_buffer);

// Save and cleanup logic follows here
return 0;
}