Skip to content

Commit ae0fcc2

Browse files
committed
Refactor imatrix api, fix build shared libs
1 parent 91a7a66 commit ae0fcc2

File tree

3 files changed

+28
-21
lines changed

3 files changed

+28
-21
lines changed

examples/cli/main.cpp

+4 -13
Original file line number | Diff line number | Diff line change
@@ -22,10 +22,6 @@
2222
#define STB_IMAGE_RESIZE_STATIC
2323
#include "stb_image_resize.h"
2424

25-
#define IMATRIX_IMPL
26-
#include "imatrix.hpp"
27-
static IMatrixCollector g_collector;
28-
2925
const char* rng_type_to_str[] = {
3026
"std_default",
3127
"cuda",
@@ -663,7 +659,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
663659
}
664660
}
665661

666-
if (params.imatrix_out.size() > 0 && file_exists(params.imatrix_out)) {
662+
if (params.imatrix_out.size() > 0 && std::ifstream(params.imatrix_out).good()) {
667663
// imatrix file already exists
668664
if (std::find(params.imatrix_in.begin(), params.imatrix_in.end(), params.imatrix_out) == params.imatrix_in.end()) {
669665
printf("\n IMPORTANT: imatrix file %s already exists, but wasn't found in the imatrix inputs.\n", params.imatrix_out.c_str());
@@ -823,10 +819,6 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
823819
fflush(out_stream);
824820
}
825821

826-
static bool collect_imatrix(struct ggml_tensor* t, bool ask, void* user_data) {
827-
return g_collector.collect_imatrix(t, ask, user_data);
828-
}
829-
830822
int main(int argc, const char* argv[]) {
831823
SDParams params;
832824

@@ -840,13 +832,12 @@ int main(int argc, const char* argv[]) {
840832
}
841833

842834
if (params.imatrix_out != "") {
843-
sd_set_backend_eval_callback((sd_graph_eval_callback_t)collect_imatrix, &params);
835+
enableImatrixCollection();
844836
}
845837
if (params.imatrix_out != "" || params.mode == CONVERT || params.wtype != SD_TYPE_COUNT) {
846-
setConvertImatrixCollector((void*)&g_collector);
847838
for (const auto& in_file : params.imatrix_in) {
848839
printf("loading imatrix from '%s'\n", in_file.c_str());
849-
if (!g_collector.load_imatrix(in_file.c_str())) {
840+
if (!loadImatrix(in_file.c_str())) {
850841
printf("Failed to load %s\n", in_file.c_str());
851842
}
852843
}
@@ -1165,7 +1156,7 @@ int main(int argc, const char* argv[]) {
11651156
results[i].data = NULL;
11661157
}
11671158
if (params.imatrix_out != "") {
1168-
g_collector.save_imatrix(params.imatrix_out);
1159+
saveImatrix(params.imatrix_out.c_str());
11691160
}
11701161
free(results);
11711162
free_sd_ctx(sd_ctx);

model.cpp

+19 -7
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929

3030
#define ST_HEADER_SIZE_LEN 8
3131

32-
static IMatrixCollector* imatrix_collector = NULL;
32+
static IMatrixCollector imatrix_collector;
3333

3434
uint64_t read_u64(uint8_t* buffer) {
3535
// little endian
@@ -1842,7 +1842,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
18421842

18431843
auto processed_name = convert_tensor_name(tensor_storage.name);
18441844
// LOG_DEBUG("%s",processed_name.c_str());
1845-
std::vector<float> imatrix = imatrix_collector ? imatrix_collector->get_values(processed_name) : std::vector<float>{};
1845+
std::vector<float> imatrix = imatrix_collector.get_values(processed_name);
18461846

18471847
convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
18481848
dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0], imatrix);
@@ -1869,7 +1869,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
18691869
// convert first, then copy to device memory
18701870
auto processed_name = convert_tensor_name(tensor_storage.name);
18711871
// LOG_DEBUG("%s",processed_name.c_str());
1872-
std::vector<float> imatrix = imatrix_collector ? imatrix_collector->get_values(processed_name) : std::vector<float>{};
1872+
std::vector<float> imatrix = imatrix_collector.get_values(processed_name);
18731873

18741874
convert_buffer.resize(ggml_nbytes(dst_tensor));
18751875
convert_tensor((void*)read_buffer.data(), tensor_storage.type,
@@ -2069,10 +2069,6 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
20692069
return mem_size;
20702070
}
20712071

2072-
void setConvertImatrixCollector(void* collector) {
2073-
imatrix_collector = ((IMatrixCollector*)collector);
2074-
}
2075-
20762072
bool convert(const char* model_path, const char* clip_l_path, const char* clip_g_path, const char* t5xxl_path, const char* diffusion_model_path, const char* vae_path, const char* output_path, sd_type_t output_type) {
20772073
ModelLoader model_loader;
20782074

@@ -2120,3 +2116,19 @@ bool convert(const char* model_path, const char* clip_l_path, const char* clip_g
21202116
bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type);
21212117
return success;
21222118
}
2119+
2120+
bool loadImatrix(const char* imatrix_path) {
2121+
return imatrix_collector.load_imatrix(imatrix_path);
2122+
}
2123+
void saveImatrix(const char* imatrix_path) {
2124+
imatrix_collector.save_imatrix(imatrix_path);
2125+
}
2126+
static bool collect_imatrix(struct ggml_tensor* t, bool ask, void* user_data) {
2127+
return imatrix_collector.collect_imatrix(t, ask, user_data);
2128+
}
2129+
void enableImatrixCollection() {
2130+
sd_set_backend_eval_callback((sd_graph_eval_callback_t)collect_imatrix, NULL);
2131+
}
2132+
void disableImatrixCollection() {
2133+
sd_set_backend_eval_callback(NULL, NULL);
2134+
}

stable-diffusion.h

+5 -1
Original file line number | Diff line number | Diff line change
@@ -230,7 +230,6 @@ SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
230230

231231
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
232232

233-
SD_API void setConvertImatrixCollector(void * collector);
234233
SD_API bool convert(const char* model_path, const char* clip_l_path, const char* clip_g_path, const char* t5xxl_path, const char* diffusion_model_path, const char* vae_path, const char* output_path, enum sd_type_t output_type);
235234

236235
SD_API uint8_t* preprocess_canny(uint8_t* img,
@@ -242,6 +241,11 @@ SD_API uint8_t* preprocess_canny(uint8_t* img,
242241
float strong,
243242
bool inverse);
244243

244+
SD_API bool loadImatrix(const char * imatrix_path);
245+
SD_API void saveImatrix(const char * imatrix_path);
246+
SD_API void enableImatrixCollection();
247+
SD_API void disableImatrixCollection();
248+
245249
#ifdef __cplusplus
246250
}
247251
#endif

0 commit comments

Comments
 (0)