Skip to content

Commit fe794de

Browse files
SamGondelmanfacebook-github-bot
authored andcommitted
Ability to set external instance + devices (#11393)
Summary: Adds a new extension to ETVK, set_and_get_external_adapter, which allows clients to share their Vulkan instance/devices with ETVK. This is useful when using volk, which does not support multiple devices when using volkLoadDevice. Differential Revision: D71372344
1 parent e42dafc commit fe794de

18 files changed

+336
-94
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1515

16+
#include <executorch/extension/vulkan/external_adapter_wrapper.h>
17+
1618
#include <executorch/runtime/backend/interface.h>
1719
#include <executorch/runtime/core/error.h>
1820
#include <executorch/runtime/core/evalue.h>
@@ -517,7 +519,9 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
517519
return Error::MemoryAllocationFailed;
518520
}
519521

520-
new (compute_graph) ComputeGraph(get_graph_config(compile_specs));
522+
GraphConfig graph_config = get_graph_config(compile_specs);
523+
graph_config.external_adapter = extension::set_and_get_external_adapter();
524+
new (compute_graph) ComputeGraph(graph_config);
521525

522526
Error err = compileModel(processed->data(), compute_graph);
523527

backends/vulkan/runtime/api/Context.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
namespace vkcompute {
2525
namespace api {
2626

27-
Context::Context(size_t adapter_i, const ContextConfig& config)
27+
Context::Context(vkapi::Adapter* adapter, const ContextConfig& config)
2828
: config_(config),
2929
// Important handles
30-
adapter_p_(vkapi::runtime()->get_adapter_p(adapter_i)),
30+
adapter_p_(adapter),
3131
device_(adapter_p_->device_handle()),
3232
queue_(adapter_p_->request_queue()),
3333
// Resource pools
@@ -256,7 +256,7 @@ Context* context() {
256256
query_pool_config,
257257
};
258258

259-
return new Context(vkapi::runtime()->default_adapter_i(), config);
259+
return new Context(vkapi::runtime()->get_adapter_p(), config);
260260
} catch (...) {
261261
}
262262

backends/vulkan/runtime/api/Context.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct ContextConfig final {
4242

4343
class Context final {
4444
public:
45-
explicit Context(size_t adapter_i, const ContextConfig&);
45+
explicit Context(vkapi::Adapter*, const ContextConfig&);
4646

4747
Context(const Context&) = delete;
4848
Context& operator=(const Context&) = delete;

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ ComputeGraph::ComputeGraph(GraphConfig config)
122122
prepack_descriptor_counts_{},
123123
execute_descriptor_counts_{},
124124
context_{new api::Context(
125-
vkapi::runtime()->default_adapter_i(),
125+
config.external_adapter ? config.external_adapter
126+
: vkapi::runtime()->get_adapter_p(),
126127
config_.context_config)},
127128
shared_objects_{},
128129
values_{},

backends/vulkan/runtime/graph/GraphConfig.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ GraphConfig::GraphConfig() {
6363

6464
enable_local_wg_size_override = false;
6565
local_wg_size_override = {};
66+
67+
external_adapter = nullptr;
6668
}
6769

6870
void GraphConfig::set_storage_type_override(utils::StorageType storage_type) {

backends/vulkan/runtime/graph/GraphConfig.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ struct GraphConfig final {
3333
bool enable_local_wg_size_override;
3434
utils::uvec3 local_wg_size_override;
3535

36+
vkapi::Adapter* external_adapter;
37+
3638
// Generate a default graph config with pre-configured settings
3739
explicit GraphConfig();
3840

backends/vulkan/runtime/vk_api/Adapter.cpp

Lines changed: 106 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,12 @@ namespace vkapi {
1717

1818
namespace {
1919

20-
VkDevice create_logical_device(
20+
void find_compute_queues(
2121
const PhysicalDevice& physical_device,
2222
const uint32_t num_queues_to_create,
23-
std::vector<Adapter::Queue>& queues,
24-
std::vector<uint32_t>& queue_usage) {
25-
// Find compute queues up to the requested number of queues
26-
27-
std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
23+
std::vector<VkDeviceQueueCreateInfo>& queue_create_infos,
24+
std::vector<std::pair<uint32_t, uint32_t>>& queues_to_get) {
2825
queue_create_infos.reserve(num_queues_to_create);
29-
30-
std::vector<std::pair<uint32_t, uint32_t>> queues_to_get;
3126
queues_to_get.reserve(num_queues_to_create);
3227

3328
uint32_t remaining_queues = num_queues_to_create;
@@ -60,12 +55,44 @@ VkDevice create_logical_device(
6055
break;
6156
}
6257
}
58+
}
6359

60+
void populate_queue_info(
61+
const PhysicalDevice& physical_device,
62+
VkDevice logical_device,
63+
const std::vector<std::pair<uint32_t, uint32_t>>& queues_to_get,
64+
std::vector<Adapter::Queue>& queues,
65+
std::vector<uint32_t>& queue_usage) {
6466
queues.reserve(queues_to_get.size());
6567
queue_usage.reserve(queues_to_get.size());
6668

67-
// Create the VkDevice
69+
// Obtain handles for the created queues and initialize queue usage heuristic
70+
71+
for (const std::pair<uint32_t, uint32_t>& queue_idx : queues_to_get) {
72+
VkQueue queue_handle = VK_NULL_HANDLE;
73+
VkQueueFlags flags =
74+
physical_device.queue_families.at(queue_idx.first).queueFlags;
75+
vkGetDeviceQueue(
76+
logical_device, queue_idx.first, queue_idx.second, &queue_handle);
77+
queues.push_back({queue_idx.first, queue_idx.second, flags, queue_handle});
78+
// Initial usage value
79+
queue_usage.push_back(0);
80+
}
81+
}
82+
83+
VkDevice create_logical_device(
84+
const PhysicalDevice& physical_device,
85+
const uint32_t num_queues_to_create,
86+
std::vector<Adapter::Queue>& queues,
87+
std::vector<uint32_t>& queue_usage) {
88+
// Find compute queues up to the requested number of queues
6889

90+
std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
91+
std::vector<std::pair<uint32_t, uint32_t>> queues_to_get;
92+
find_compute_queues(
93+
physical_device, num_queues_to_create, queue_create_infos, queues_to_get);
94+
95+
// Create the VkDevice
6996
std::vector<const char*> requested_device_extensions{
7097
#ifdef VK_KHR_portability_subset
7198
VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME,
@@ -143,19 +170,42 @@ VkDevice create_logical_device(
143170
volkLoadDevice(handle);
144171
#endif /* USE_VULKAN_VOLK */
145172

146-
// Obtain handles for the created queues and initialize queue usage heuristic
173+
populate_queue_info(
174+
physical_device, handle, queues_to_get, queues, queue_usage);
147175

148-
for (const std::pair<uint32_t, uint32_t>& queue_idx : queues_to_get) {
149-
VkQueue queue_handle = VK_NULL_HANDLE;
150-
VkQueueFlags flags =
151-
physical_device.queue_families.at(queue_idx.first).queueFlags;
152-
vkGetDeviceQueue(handle, queue_idx.first, queue_idx.second, &queue_handle);
153-
queues.push_back({queue_idx.first, queue_idx.second, flags, queue_handle});
154-
// Initial usage value
155-
queue_usage.push_back(0);
176+
return handle;
177+
}
178+
179+
bool test_linear_tiling_3d_image_support(VkDevice device) {
180+
// Test creating a 3D image with linear tiling to see if it is supported.
181+
// According to the Vulkan spec, linear tiling may not be supported for 3D
182+
// images.
183+
VkExtent3D image_extents{1u, 1u, 1u};
184+
const VkImageCreateInfo image_create_info{
185+
VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // sType
186+
nullptr, // pNext
187+
0u, // flags
188+
VK_IMAGE_TYPE_3D, // imageType
189+
VK_FORMAT_R32G32B32A32_SFLOAT, // format
190+
image_extents, // extents
191+
1u, // mipLevels
192+
1u, // arrayLayers
193+
VK_SAMPLE_COUNT_1_BIT, // samples
194+
VK_IMAGE_TILING_LINEAR, // tiling
195+
VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_STORAGE_BIT, // usage
196+
VK_SHARING_MODE_EXCLUSIVE, // sharingMode
197+
0u, // queueFamilyIndexCount
198+
nullptr, // pQueueFamilyIndices
199+
VK_IMAGE_LAYOUT_UNDEFINED, // initialLayout
200+
};
201+
VkImage image = VK_NULL_HANDLE;
202+
VkResult res = vkCreateImage(device, &image_create_info, nullptr, &image);
203+
204+
if (res == VK_SUCCESS) {
205+
vkDestroyImage(device, image, nullptr);
156206
}
157207

158-
return handle;
208+
return res == VK_SUCCESS;
159209
}
160210

161211
} // namespace
@@ -186,37 +236,44 @@ Adapter::Adapter(
186236
compute_pipeline_cache_(device_.handle, cache_data_path),
187237
sampler_cache_(device_.handle),
188238
vma_(instance_, physical_device_.handle, device_.handle),
189-
linear_tiling_3d_enabled_{true} {
190-
// Test creating a 3D image with linear tiling to see if it is supported.
191-
// According to the Vulkan spec, linear tiling may not be supported for 3D
192-
// images.
193-
VkExtent3D image_extents{1u, 1u, 1u};
194-
const VkImageCreateInfo image_create_info{
195-
VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // sType
196-
nullptr, // pNext
197-
0u, // flags
198-
VK_IMAGE_TYPE_3D, // imageType
199-
VK_FORMAT_R32G32B32A32_SFLOAT, // format
200-
image_extents, // extents
201-
1u, // mipLevels
202-
1u, // arrayLayers
203-
VK_SAMPLE_COUNT_1_BIT, // samples
204-
VK_IMAGE_TILING_LINEAR, // tiling
205-
VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_STORAGE_BIT, // usage
206-
VK_SHARING_MODE_EXCLUSIVE, // sharingMode
207-
0u, // queueFamilyIndexCount
208-
nullptr, // pQueueFamilyIndices
209-
VK_IMAGE_LAYOUT_UNDEFINED, // initialLayout
210-
};
211-
VkImage image = VK_NULL_HANDLE;
212-
VkResult res =
213-
vkCreateImage(device_.handle, &image_create_info, nullptr, &image);
214-
if (res != VK_SUCCESS) {
215-
linear_tiling_3d_enabled_ = false;
216-
} else {
217-
vkDestroyImage(device_.handle, image, nullptr);
239+
linear_tiling_3d_enabled_{
240+
test_linear_tiling_3d_image_support(device_.handle)},
241+
owns_device_{true} {}
242+
243+
Adapter::Adapter(
244+
VkInstance instance,
245+
VkPhysicalDevice physical_device,
246+
VkDevice logical_device,
247+
const uint32_t num_queues,
248+
const std::string& cache_data_path)
249+
: queue_usage_mutex_{},
250+
physical_device_(physical_device),
251+
queues_{},
252+
queue_usage_{},
253+
queue_mutexes_{},
254+
instance_(instance),
255+
device_(logical_device),
256+
shader_layout_cache_(device_.handle),
257+
shader_cache_(device_.handle),
258+
pipeline_layout_cache_(device_.handle),
259+
compute_pipeline_cache_(device_.handle, cache_data_path),
260+
sampler_cache_(device_.handle),
261+
vma_(instance_, physical_device_.handle, device_.handle),
262+
linear_tiling_3d_enabled_{
263+
test_linear_tiling_3d_image_support(device_.handle)},
264+
owns_device_{false} {
265+
std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
266+
std::vector<std::pair<uint32_t, uint32_t>> queues_to_get;
267+
find_compute_queues(
268+
physical_device_, num_queues, queue_create_infos, queues_to_get);
269+
populate_queue_info(
270+
physical_device_, device_.handle, queues_to_get, queues_, queue_usage_);
271+
}
272+
273+
Adapter::~Adapter() {
274+
if (!owns_device_) {
275+
device_.handle = VK_NULL_HANDLE;
218276
}
219-
return;
220277
}
221278

222279
Adapter::Queue Adapter::request_queue() {

backends/vulkan/runtime/vk_api/Adapter.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,20 @@ class Adapter final {
5656
const uint32_t num_queues,
5757
const std::string& cache_data_path);
5858

59+
explicit Adapter(
60+
VkInstance instance,
61+
VkPhysicalDevice physical_device,
62+
VkDevice logical_device,
63+
const uint32_t num_queues,
64+
const std::string& cache_data_path);
65+
5966
Adapter(const Adapter&) = delete;
6067
Adapter& operator=(const Adapter&) = delete;
6168

6269
Adapter(Adapter&&) = delete;
6370
Adapter& operator=(Adapter&&) = delete;
6471

65-
~Adapter() = default;
72+
~Adapter();
6673

6774
struct Queue {
6875
uint32_t family_index;
@@ -94,6 +101,7 @@ class Adapter final {
94101
Allocator vma_;
95102
// Miscellaneous
96103
bool linear_tiling_3d_enabled_;
104+
bool owns_device_;
97105

98106
public:
99107
// Physical Device metadata

backends/vulkan/runtime/vk_api/Runtime.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
#include <iostream>
1515
#include <sstream>
1616

17+
#ifdef USE_VOLK_HEADER_ONLY
18+
// For volk.h, define this before including volk.h in exactly one CPP file.
19+
#define VOLK_IMPLEMENTATION
20+
#include <volk.h>
21+
#endif /* USE_VOLK_HEADER_ONLY */
22+
1723
namespace vkcompute {
1824
namespace vkapi {
1925

backends/vulkan/runtime/vk_api/Types.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
#include <cstddef>
1818
#include <cstdint>
1919

20+
#if defined(__linux__)
21+
#undef Bool
22+
#endif
23+
2024
#ifdef USE_VULKAN_FP16_INFERENCE
2125
#define VK_FORMAT_FLOAT4 VK_FORMAT_R16G16B16A16_SFLOAT
2226
#else

0 commit comments

Comments
 (0)