diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e38a6101..25017999 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -172,7 +172,27 @@ class StableDiffusionGGML { #endif #ifdef SD_USE_VULKAN LOG_DEBUG("Using Vulkan backend"); - for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) { + size_t device = 0; + const int device_count = ggml_backend_vk_get_device_count(); + if (device_count) { + const char* SD_VK_DEVICE = getenv("SD_VK_DEVICE"); + if (SD_VK_DEVICE != nullptr) { + std::string sd_vk_device_str = SD_VK_DEVICE; + try { + device = std::stoull(sd_vk_device_str); + } catch (const std::invalid_argument&) { + LOG_WARN("SD_VK_DEVICE environment variable is not a valid integer (%s). Falling back to device 0.", SD_VK_DEVICE); + device = 0; + } catch (const std::out_of_range&) { + LOG_WARN("SD_VK_DEVICE environment variable value is out of range for `unsigned long long` type (%s). Falling back to device 0.", SD_VK_DEVICE); + device = 0; + } + if (device >= device_count) { + LOG_WARN("Cannot find targeted vulkan device (%llu). Falling back to device 0.", device); + device = 0; + } + } + LOG_INFO("Vulkan: Using device %llu", device); backend = ggml_backend_vk_init(device); } if (!backend) {