@@ -169,6 +169,7 @@ void MetricsLibrary::release() {
169
169
api = {};
170
170
callbacks = {};
171
171
context = {};
172
+ isWorkloadPartitionEnabled = false ;
172
173
initializationState = ZE_RESULT_ERROR_UNINITIALIZED;
173
174
}
174
175
@@ -193,13 +194,19 @@ bool MetricsLibrary::load() {
193
194
return true ;
194
195
}
195
196
197
// Sets isWorkloadPartitionEnabled to true for this MetricsLibrary instance.
// The flag is copied into the WorkloadPartition client option on the
// sub-device branch of getSubDeviceClientOptions(), and release() resets it
// to false.
+ void MetricsLibrary::enableWorkloadPartition () {
198
+ isWorkloadPartitionEnabled = true ;
199
+ }
200
+
196
201
void MetricsLibrary::getSubDeviceClientOptions (
197
- NEO::Device &neoDevice,
198
202
ClientOptionsData_1_0 &subDevice,
199
203
ClientOptionsData_1_0 &subDeviceIndex,
200
- ClientOptionsData_1_0 &subDeviceCount) {
204
+ ClientOptionsData_1_0 &subDeviceCount,
205
+ ClientOptionsData_1_0 &workloadPartition) {
206
+
207
+ const auto &deviceImp = *static_cast <DeviceImp *>(&metricContext.getDevice ());
201
208
202
- if (!neoDevice. isSubDevice () ) {
209
+ if (!deviceImp. isSubdevice ) {
203
210
204
211
// Root device.
205
212
subDevice.Type = ClientOptionsType::SubDevice;
@@ -209,7 +216,10 @@ void MetricsLibrary::getSubDeviceClientOptions(
209
216
subDeviceIndex.SubDeviceIndex .Index = 0 ;
210
217
211
218
subDeviceCount.Type = ClientOptionsType::SubDeviceCount;
212
- subDeviceCount.SubDeviceCount .Count = std::max (neoDevice.getNumSubDevices (), 1u );
219
+ subDeviceCount.SubDeviceCount .Count = std::max (deviceImp.neoDevice ->getRootDevice ()->getNumSubDevices (), 1u );
220
+
221
+ workloadPartition.Type = ClientOptionsType::WorkloadPartition;
222
+ workloadPartition.WorkloadPartition .Enabled = false ;
213
223
214
224
} else {
215
225
@@ -218,10 +228,13 @@ void MetricsLibrary::getSubDeviceClientOptions(
218
228
subDevice.SubDevice .Enabled = true ;
219
229
220
230
subDeviceIndex.Type = ClientOptionsType::SubDeviceIndex;
221
- subDeviceIndex.SubDeviceIndex .Index = static_cast <NEO::SubDevice *>(& neoDevice)->getSubDeviceIndex ();
231
+ subDeviceIndex.SubDeviceIndex .Index = static_cast <NEO::SubDevice *>(deviceImp. neoDevice )->getSubDeviceIndex ();
222
232
223
233
subDeviceCount.Type = ClientOptionsType::SubDeviceCount;
224
- subDeviceCount.SubDeviceCount .Count = std::max (neoDevice.getRootDevice ()->getNumSubDevices (), 1u );
234
+ subDeviceCount.SubDeviceCount .Count = std::max (deviceImp.neoDevice ->getRootDevice ()->getNumSubDevices (), 1u );
235
+
236
+ workloadPartition.Type = ClientOptionsType::WorkloadPartition;
237
+ workloadPartition.WorkloadPartition .Enabled = isWorkloadPartitionEnabled;
225
238
}
226
239
}
227
240
@@ -230,7 +243,7 @@ bool MetricsLibrary::createContext() {
230
243
const auto &hwHelper = device.getHwHelper ();
231
244
const auto &asyncComputeEngines = hwHelper.getGpgpuEngineInstances (device.getHwInfo ());
232
245
ContextCreateData_1_0 createData = {};
233
- ClientOptionsData_1_0 clientOptions[5 ] = {};
246
+ ClientOptionsData_1_0 clientOptions[6 ] = {};
234
247
ClientData_1_0 clientData = {};
235
248
ClientType_1_0 clientType = {};
236
249
ClientDataLinuxAdapter_1_0 adapter = {};
@@ -259,7 +272,7 @@ bool MetricsLibrary::createContext() {
259
272
clientOptions[1 ].Tbs .Enabled = metricContext.getMetricStreamer () != nullptr ;
260
273
261
274
// Sub device client options #2
262
- getSubDeviceClientOptions (*device. getNEODevice (), clientOptions[2 ], clientOptions[3 ], clientOptions[4 ]);
275
+ getSubDeviceClientOptions (clientOptions[2 ], clientOptions[3 ], clientOptions[4 ], clientOptions[ 5 ]);
263
276
264
277
clientData.Linux .Adapter = &adapter;
265
278
clientData.ClientOptions = clientOptions;
@@ -422,7 +435,7 @@ ze_result_t metricQueryPoolCreate(zet_context_handle_t hContext, zet_device_hand
422
435
const auto &deviceImp = *static_cast <DeviceImp *>(device);
423
436
auto metricPoolImp = new MetricQueryPoolImp (device->getMetricContext (), hMetricGroup, *pDesc);
424
437
425
- if (!deviceImp. isSubdevice && deviceImp .isMultiDeviceCapable ()) {
438
+ if (metricContext .isMultiDeviceCapable ()) {
426
439
427
440
auto emptyMetricGroups = std::vector<zet_metric_group_handle_t >();
428
441
auto &metricGroups = hMetricGroup
@@ -436,12 +449,15 @@ ze_result_t metricQueryPoolCreate(zet_context_handle_t hContext, zet_device_hand
436
449
for (size_t i = 0 ; i < deviceImp.numSubDevices ; ++i) {
437
450
438
451
auto &subDevice = deviceImp.subDevices [i];
452
+ auto &subDeviceMetricContext = subDevice->getMetricContext ();
453
+
454
+ subDeviceMetricContext.getMetricsLibrary ().enableWorkloadPartition ();
439
455
440
456
zet_metric_group_handle_t metricGroupHandle = useMetricGroupSubDevice
441
- ? metricGroups[subDevice-> getMetricContext () .getSubDeviceIndex ()]
457
+ ? metricGroups[subDeviceMetricContext .getSubDeviceIndex ()]
442
458
: hMetricGroup;
443
459
444
- auto metricPoolSubdeviceImp = new MetricQueryPoolImp (subDevice-> getMetricContext () , metricGroupHandle, *pDesc);
460
+ auto metricPoolSubdeviceImp = new MetricQueryPoolImp (subDeviceMetricContext , metricGroupHandle, *pDesc);
445
461
446
462
// Create metric query pool.
447
463
if (!metricPoolSubdeviceImp->create ()) {
@@ -534,7 +550,7 @@ bool MetricQueryPoolImp::allocateGpuMemory() {
534
550
if (description.type == ZET_METRIC_QUERY_POOL_TYPE_PERFORMANCE) {
535
551
// Get allocation size.
536
552
const auto &deviceImp = *static_cast <DeviceImp *>(&metricContext.getDevice ());
537
- allocationSize = (!deviceImp. isSubdevice && deviceImp .isMultiDeviceCapable ())
553
+ allocationSize = (metricContext .isMultiDeviceCapable ())
538
554
? deviceImp.subDevices [0 ]->getMetricContext ().getMetricsLibrary ().getQueryReportGpuSize () * description.count * deviceImp.numSubDevices
539
555
: metricsLibrary.getQueryReportGpuSize () * description.count ;
540
556
@@ -867,7 +883,7 @@ ze_result_t MetricQuery::appendMemoryBarrier(CommandList &commandList) {
867
883
868
884
DeviceImp *pDeviceImp = static_cast <DeviceImp *>(commandList.device );
869
885
870
- if (! pDeviceImp->isSubdevice && pDeviceImp ->isMultiDeviceCapable ()) {
886
+ if (pDeviceImp->metricContext ->isMultiDeviceCapable ()) {
871
887
// Use one of the sub-device contexts to append to command list.
872
888
pDeviceImp = static_cast <DeviceImp *>(pDeviceImp->subDevices [0 ]);
873
889
}
@@ -893,9 +909,10 @@ ze_result_t MetricQuery::appendStreamerMarker(CommandList &commandList,
893
909
894
910
DeviceImp *pDeviceImp = static_cast <DeviceImp *>(commandList.device );
895
911
896
- if (! pDeviceImp->isSubdevice && pDeviceImp ->isMultiDeviceCapable ()) {
912
+ if (pDeviceImp->metricContext ->isMultiDeviceCapable ()) {
897
913
// Use one of the sub-device contexts to append to command list.
898
914
pDeviceImp = static_cast <DeviceImp *>(pDeviceImp->subDevices [0 ]);
915
+ pDeviceImp->metricContext ->getMetricsLibrary ().enableWorkloadPartition ();
899
916
}
900
917
auto &metricContext = pDeviceImp->getMetricContext ();
901
918
auto &metricsLibrary = metricContext.getMetricsLibrary ();
0 commit comments