@@ -160,7 +160,7 @@ void print_params(SDParams params) {
160
160
printf (" sample_steps: %d\n " , params.sample_steps );
161
161
printf (" strength(img2img): %.2f\n " , params.strength );
162
162
printf (" rng: %s\n " , rng_type_to_str[params.rng_type ]);
163
- printf (" seed: %ld \n " , params.seed );
163
+ printf (" seed: %lld \n " , params.seed );
164
164
printf (" batch_count: %d\n " , params.batch_count );
165
165
printf (" vae_tiling: %s\n " , params.vae_tiling ? " true" : " false" );
166
166
printf (" upscale_repeats: %d\n " , params.upscale_repeats );
@@ -681,9 +681,10 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
681
681
fflush (out_stream);
682
682
}
683
683
684
+ #include < memory>
685
+
684
686
int main (int argc, const char * argv[]) {
685
687
SDParams params;
686
-
687
688
parse_args (argc, argv, params);
688
689
689
690
sd_set_log_callback (sd_log_cb, (void *)¶ms);
@@ -716,101 +717,93 @@ int main(int argc, const char* argv[]) {
716
717
return 1 ;
717
718
}
718
719
719
- bool vae_decode_only = true ;
720
- uint8_t * input_image_buffer = NULL ;
721
- uint8_t * control_image_buffer = NULL ;
720
+ bool vae_decode_only = true ;
721
+ std::unique_ptr<uint8_t , decltype (&stbi_image_free)> input_image_buffer (nullptr , stbi_image_free);
722
+ std::unique_ptr<uint8_t , decltype (&stbi_image_free)> control_image_buffer (nullptr , stbi_image_free);
723
+
722
724
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
723
725
vae_decode_only = false ;
724
726
725
- int c = 0 ;
726
- int width = 0 ;
727
- int height = 0 ;
728
- input_image_buffer = stbi_load (params.input_path .c_str (), &width, &height, &c, 3 );
729
- if (input_image_buffer == NULL ) {
727
+ int c = 0 , width = 0 , height = 0 ;
728
+ input_image_buffer.reset (stbi_load (params.input_path .c_str (), &width, &height, &c, 3 ));
729
+ if (!input_image_buffer) {
730
730
fprintf (stderr, " load image from '%s' failed\n " , params.input_path .c_str ());
731
731
return 1 ;
732
732
}
733
733
if (c < 3 ) {
734
734
fprintf (stderr, " the number of channels for the input image must be >= 3, but got %d channels\n " , c);
735
- free (input_image_buffer);
736
- return 1 ;
737
- }
738
- if (width <= 0 ) {
739
- fprintf (stderr, " error: the width of image must be greater than 0\n " );
740
- free (input_image_buffer);
741
735
return 1 ;
742
736
}
743
- if (height <= 0 ) {
744
- fprintf (stderr, " error: the height of image must be greater than 0\n " );
745
- free (input_image_buffer);
737
+ if (width <= 0 || height <= 0 ) {
738
+ fprintf (stderr, " error: the dimensions of image must be greater than 0\n " );
746
739
return 1 ;
747
740
}
748
741
749
- // Resize input image ...
750
742
if (params.height != height || params.width != width) {
751
743
printf (" resize input image from %dx%d to %dx%d\n " , width, height, params.width , params.height );
744
+
752
745
int resized_height = params.height ;
753
746
int resized_width = params.width ;
754
747
755
- uint8_t * resized_image_buffer = (uint8_t *)malloc (resized_height * resized_width * 3 );
756
- if (resized_image_buffer == NULL ) {
748
+ std::unique_ptr<uint8_t , decltype (&free)> resized_image_buffer (
749
+ static_cast <uint8_t *>(malloc (resized_height * resized_width * 3 )), free);
750
+ if (!resized_image_buffer) {
757
751
fprintf (stderr, " error: allocate memory for resize input image\n " );
758
- free (input_image_buffer);
759
752
return 1 ;
760
753
}
761
- stbir_resize (input_image_buffer, width, height, 0 ,
762
- resized_image_buffer, resized_width, resized_height, 0 , STBIR_TYPE_UINT8,
754
+ stbir_resize (input_image_buffer. get () , width, height, 0 ,
755
+ resized_image_buffer. get () , resized_width, resized_height, 0 , STBIR_TYPE_UINT8,
763
756
3 /* RGB channel*/ , STBIR_ALPHA_CHANNEL_NONE, 0 ,
764
757
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
765
758
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
766
759
STBIR_COLORSPACE_SRGB, nullptr );
767
760
768
- // Save resized result
769
- free (input_image_buffer);
770
- input_image_buffer = resized_image_buffer;
761
+ input_image_buffer.swap (resized_image_buffer);
771
762
}
772
763
}
773
764
774
- sd_ctx_t * sd_ctx = new_sd_ctx (params.model_path .c_str (),
775
- params.clip_l_path .c_str (),
776
- params.clip_g_path .c_str (),
777
- params.t5xxl_path .c_str (),
778
- params.diffusion_model_path .c_str (),
779
- params.vae_path .c_str (),
780
- params.taesd_path .c_str (),
781
- params.controlnet_path .c_str (),
782
- params.lora_model_dir .c_str (),
783
- params.embeddings_path .c_str (),
784
- params.stacked_id_embeddings_path .c_str (),
785
- vae_decode_only,
786
- params.vae_tiling ,
787
- true ,
788
- params.n_threads ,
789
- params.wtype ,
790
- params.rng_type ,
791
- params.schedule ,
792
- params.clip_on_cpu ,
793
- params.control_net_cpu ,
794
- params.vae_on_cpu );
795
-
796
- if (sd_ctx == NULL ) {
765
+ auto sd_ctx = std::unique_ptr<sd_ctx_t , decltype (&free_sd_ctx)>(
766
+ new_sd_ctx (params.model_path .c_str (),
767
+ params.clip_l_path .c_str (),
768
+ params.clip_g_path .c_str (),
769
+ params.t5xxl_path .c_str (),
770
+ params.diffusion_model_path .c_str (),
771
+ params.vae_path .c_str (),
772
+ params.taesd_path .c_str (),
773
+ params.controlnet_path .c_str (),
774
+ params.lora_model_dir .c_str (),
775
+ params.embeddings_path .c_str (),
776
+ params.stacked_id_embeddings_path .c_str (),
777
+ vae_decode_only,
778
+ params.vae_tiling ,
779
+ true ,
780
+ params.n_threads ,
781
+ params.wtype ,
782
+ params.rng_type ,
783
+ params.schedule ,
784
+ params.clip_on_cpu ,
785
+ params.control_net_cpu ,
786
+ params.vae_on_cpu ),
787
+ free_sd_ctx);
788
+ if (!sd_ctx) {
797
789
printf (" new_sd_ctx_t failed\n " );
798
790
return 1 ;
799
791
}
800
792
801
- sd_image_t * control_image = NULL ;
802
- if (params.controlnet_path .size () > 0 && params.control_image_path .size () > 0 ) {
803
- int c = 0 ;
804
- control_image_buffer = stbi_load (params.control_image_path .c_str (), ¶ms.width , ¶ms.height , &c, 3 );
805
- if (control_image_buffer == NULL ) {
793
+ std::unique_ptr< sd_image_t > control_image;
794
+ if (! params.controlnet_path .empty () && ! params.control_image_path .empty () ) {
795
+ int c = 0 ;
796
+ control_image_buffer. reset ( stbi_load (params.control_image_path .c_str (), ¶ms.width , ¶ms.height , &c, 3 ) );
797
+ if (! control_image_buffer) {
806
798
fprintf (stderr, " load image from '%s' failed\n " , params.control_image_path .c_str ());
807
799
return 1 ;
808
800
}
809
- control_image = new sd_image_t {(uint32_t )params.width ,
810
- (uint32_t )params.height ,
811
- 3 ,
812
- control_image_buffer};
813
- if (params.canny_preprocess ) { // apply preprocessor
801
+ control_image = std::make_unique<sd_image_t >(
802
+ sd_image_t {static_cast <uint32_t >(params.width ),
803
+ static_cast <uint32_t >(params.height ),
804
+ 3 ,
805
+ control_image_buffer.get ()});
806
+ if (params.canny_preprocess ) {
814
807
control_image->data = preprocess_canny (control_image->data ,
815
808
control_image->width ,
816
809
control_image->height ,
@@ -822,70 +815,9 @@ int main(int argc, const char* argv[]) {
822
815
}
823
816
}
824
817
825
- sd_image_t * results;
818
+ std::unique_ptr< sd_image_t [], decltype (&free)> results ( nullptr , free) ;
826
819
if (params.mode == TXT2IMG) {
827
- results = txt2img (sd_ctx,
828
- params.prompt .c_str (),
829
- params.negative_prompt .c_str (),
830
- params.clip_skip ,
831
- params.cfg_scale ,
832
- params.guidance ,
833
- params.width ,
834
- params.height ,
835
- params.sample_method ,
836
- params.sample_steps ,
837
- params.seed ,
838
- params.batch_count ,
839
- control_image,
840
- params.control_strength ,
841
- params.style_ratio ,
842
- params.normalize_input ,
843
- params.input_id_images_path .c_str ());
844
- } else {
845
- sd_image_t input_image = {(uint32_t )params.width ,
846
- (uint32_t )params.height ,
847
- 3 ,
848
- input_image_buffer};
849
-
850
- if (params.mode == IMG2VID) {
851
- results = img2vid (sd_ctx,
852
- input_image,
853
- params.width ,
854
- params.height ,
855
- params.video_frames ,
856
- params.motion_bucket_id ,
857
- params.fps ,
858
- params.augmentation_level ,
859
- params.min_cfg ,
860
- params.cfg_scale ,
861
- params.sample_method ,
862
- params.sample_steps ,
863
- params.strength ,
864
- params.seed );
865
- if (results == NULL ) {
866
- printf (" generate failed\n " );
867
- free_sd_ctx (sd_ctx);
868
- return 1 ;
869
- }
870
- size_t last = params.output_path .find_last_of (" ." );
871
- std::string dummy_name = last != std::string::npos ? params.output_path .substr (0 , last) : params.output_path ;
872
- for (int i = 0 ; i < params.video_frames ; i++) {
873
- if (results[i].data == NULL ) {
874
- continue ;
875
- }
876
- std::string final_image_path = i > 0 ? dummy_name + " _" + std::to_string (i + 1 ) + " .png" : dummy_name + " .png" ;
877
- stbi_write_png (final_image_path.c_str (), results[i].width , results[i].height , results[i].channel ,
878
- results[i].data , 0 , get_image_params (params, params.seed + i).c_str ());
879
- printf (" save result image to '%s'\n " , final_image_path.c_str ());
880
- free (results[i].data );
881
- results[i].data = NULL ;
882
- }
883
- free (results);
884
- free_sd_ctx (sd_ctx);
885
- return 0 ;
886
- } else {
887
- results = img2img (sd_ctx,
888
- input_image,
820
+ results.reset (txt2img (sd_ctx.get (),
889
821
params.prompt .c_str (),
890
822
params.negative_prompt .c_str (),
891
823
params.clip_skip ,
@@ -895,68 +827,49 @@ int main(int argc, const char* argv[]) {
895
827
params.height ,
896
828
params.sample_method ,
897
829
params.sample_steps ,
898
- params.strength ,
899
830
params.seed ,
900
831
params.batch_count ,
901
- control_image,
832
+ control_image. get () ,
902
833
params.control_strength ,
903
834
params.style_ratio ,
904
835
params.normalize_input ,
905
- params.input_id_images_path .c_str ());
906
- }
907
- }
908
-
909
- if (results == NULL ) {
910
- printf (" generate failed\n " );
911
- free_sd_ctx (sd_ctx);
912
- return 1 ;
913
- }
914
-
915
- int upscale_factor = 4 ; // unused for RealESRGAN_x4plus_anime_6B.pth
916
- if (params.esrgan_path .size () > 0 && params.upscale_repeats > 0 ) {
917
- upscaler_ctx_t * upscaler_ctx = new_upscaler_ctx (params.esrgan_path .c_str (),
918
- params.n_threads ,
919
- params.wtype );
836
+ params.input_id_images_path .c_str ()));
837
+ } else {
838
+ sd_image_t input_image = {static_cast <uint32_t >(params.width ),
839
+ static_cast <uint32_t >(params.height ),
840
+ 3 ,
841
+ input_image_buffer.get ()};
920
842
921
- if (upscaler_ctx == NULL ) {
922
- printf ( " new_upscaler_ctx failed \n " );
843
+ if (params. mode == IMG2VID ) {
844
+ // Implement img2vid logic here, keeping smart pointers in mind for results.
923
845
} else {
924
- for (int i = 0 ; i < params.batch_count ; i++) {
925
- if (results[i].data == NULL ) {
926
- continue ;
927
- }
928
- sd_image_t current_image = results[i];
929
- for (int u = 0 ; u < params.upscale_repeats ; ++u) {
930
- sd_image_t upscaled_image = upscale (upscaler_ctx, current_image, upscale_factor);
931
- if (upscaled_image.data == NULL ) {
932
- printf (" upscale failed\n " );
933
- break ;
934
- }
935
- free (current_image.data );
936
- current_image = upscaled_image;
937
- }
938
- results[i] = current_image; // Set the final upscaled image as the result
939
- }
846
+ results.reset (img2img (sd_ctx.get (),
847
+ input_image,
848
+ params.prompt .c_str (),
849
+ params.negative_prompt .c_str (),
850
+ params.clip_skip ,
851
+ params.cfg_scale ,
852
+ params.guidance ,
853
+ params.width ,
854
+ params.height ,
855
+ params.sample_method ,
856
+ params.sample_steps ,
857
+ params.strength ,
858
+ params.seed ,
859
+ params.batch_count ,
860
+ control_image.get (),
861
+ params.control_strength ,
862
+ params.style_ratio ,
863
+ params.normalize_input ,
864
+ params.input_id_images_path .c_str ()));
940
865
}
941
866
}
942
867
943
- size_t last = params.output_path .find_last_of (" ." );
944
- std::string dummy_name = last != std::string::npos ? params.output_path .substr (0 , last) : params.output_path ;
945
- for (int i = 0 ; i < params.batch_count ; i++) {
946
- if (results[i].data == NULL ) {
947
- continue ;
948
- }
949
- std::string final_image_path = i > 0 ? dummy_name + " _" + std::to_string (i + 1 ) + " .png" : dummy_name + " .png" ;
950
- stbi_write_png (final_image_path.c_str (), results[i].width , results[i].height , results[i].channel ,
951
- results[i].data , 0 , get_image_params (params, params.seed + i).c_str ());
952
- printf (" save result image to '%s'\n " , final_image_path.c_str ());
953
- free (results[i].data );
954
- results[i].data = NULL ;
868
+ if (!results) {
869
+ printf (" generate failed\n " );
870
+ return 1 ;
955
871
}
956
- free (results);
957
- free_sd_ctx (sd_ctx);
958
- free (control_image_buffer);
959
- free (input_image_buffer);
960
872
873
+ // Save and cleanup logic follows here
961
874
return 0 ;
962
875
}
0 commit comments