Skip to content

Commit 7633c08

Browse files
Re-enable CONV test
1 parent c065346 commit 7633c08

File tree

3 files changed

+68
-39
lines changed

3 files changed

+68
-39
lines changed

samples/cooperative_matrix/code/main/cooperative_matrix_shaders.hpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,6 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
264264
265265
void main()
266266
{
267-
if(gl_LocalInvocationIndex == 0)
268-
debugPrintfEXT("\nRunning SPIR-V shader (QCOM version) gl_WorkGroupSize(%d, %d, %d)\n", gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);
269-
270267
const uint32_t block_id_m = gl_GlobalInvocationID.y;
271268
const uint32_t block_id_n = gl_GlobalInvocationID.z;
272269
if ((block_id_m >= TOTAL_M/TILE_M) || (block_id_n >= TOTAL_N/TILE_N)) return;
@@ -287,15 +284,13 @@ void main()
287284
uint32_t subMatrixBStartInElements = layoutB_Kfirst ? col * strideBinElements + step : col + step * strideBinElements;
288285
289286
coopmat<A_TYPE, gl_ScopeSubgroup, TILE_M, TILE_K, gl_MatrixUseA> matA;
290-
#define NEW
291-
#ifdef NEW
287+
292288
uint32_t uvecA[8];
293289
for (int i=0; i<8; i++)
294-
uvecA[i] = floatBitsToInt(inputA.x[subMatrixAStartInElements + gl_GlobalInvocationID.x * strideAinElements + i]);
295-
matA = constructCoopMatA64QCOM(uvecA, gl_Float32QCOM);
296-
#else
297-
coopMatLoad(matA, inputA.x, subMatrixAStartInElements, strideAinElements, int(layoutA_Mfirst));
298-
#endif
290+
uvecA[i] = floatBitsToInt(inputA.x[subMatrixAStartInElements + gl_GlobalInvocationID.x * strideAinElements + i]);
291+
292+
// convert A vector to A matrix
293+
vectorToCoopmatQCOM(uvecA, matA);
299294
300295
coopmat<A_TYPE, gl_ScopeSubgroup, TILE_K, TILE_N, gl_MatrixUseB> matB;
301296
coopMatLoad(matB, inputB.x, subMatrixBStartInElements, strideBinElements, int(layoutB_Kfirst));

samples/cooperative_matrix/code/main/cooperative_matrix_tester.cpp

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -642,23 +642,6 @@ void CooperativeMatrixRunner::RenderUI()
642642
{
643643
ImGui::DragInt("Test Repeats", &m_test_repeats, 1.0f, 0, 100);
644644

645-
// NOTE: Validation (and its transpose option) will be added in a future patch
646-
ImGui::BeginDisabled();
647-
if (m_validate_matrix_result)
648-
{
649-
ImGui::BeginDisabled();
650-
static bool always_true = true;
651-
ImGui::Checkbox("Transpose When Needed", &always_true);
652-
ImGui::EndDisabled();
653-
}
654-
else
655-
{
656-
ImGui::Checkbox("Transpose When Needed", &m_transpose_when_needed);
657-
}
658-
659-
ImGui::Checkbox("Validate Result", &m_validate_matrix_result);
660-
ImGui::EndDisabled();
661-
662645
static const char* test_case_names[] = {
663646
"MxM Basic",
664647
"MxM Vector To Matrix",
@@ -668,31 +651,36 @@ void CooperativeMatrixRunner::RenderUI()
668651
int test_type_current_index = static_cast<int>(m_test_type);
669652
bool changed = false;
670653

654+
ImGui::Text("Note: Not all tests are compatible with all devices!");
655+
ImGui::Text("Check shader instruction set for compatibility if testing other than MXM_BASIC");
656+
671657
if (ImGui::BeginCombo("Test Case", test_case_names[test_type_current_index]))
672658
{
673659
// NOTE: Temporarily disabled other tests, new test template coming on the next patch
674660
for (int i = 0; i < static_cast<int>(TestType::TT_COUNT); ++i)
675661
{
676-
// NOTE: Temporarily disabled other tests, new test template coming on the next patch
677-
ImGui::BeginDisabled(i > 0);
678-
679662
const bool is_selected = (test_type_current_index == i);
680663
if (ImGui::Selectable(test_case_names[i], is_selected))
681664
{
682665
m_test_type = static_cast<TestType>(i);
683666
changed = true;
684667
}
685668

686-
ImGui::EndDisabled();
687-
688669
if (is_selected)
689670
ImGui::SetItemDefaultFocus();
690671
}
691672
ImGui::EndCombo();
692673
}
693674

694-
ImGui::Separator();
675+
ImGui::BeginDisabled(m_test_type != TestType::TT_CONV);
676+
ImGui::DragInt("Conv Width", &m_input_width, 1.0f, 1, 256);
677+
ImGui::DragInt("Conv Height", &m_input_height, 1.0f, 1, 256);
678+
ImGui::Checkbox("Normalize Inputs", &m_normalize_inputs);
679+
ImGui::EndDisabled();
680+
}
695681

682+
if (ImGui::CollapsingHeader("Matrix Configuration", ImGuiTreeNodeFlags_None))
683+
{
696684
static const char* fill_type_labels[] = {
697685
"Fill with Zero",
698686
"Fill with Constants",
@@ -728,6 +716,23 @@ void CooperativeMatrixRunner::RenderUI()
728716
m_matrix_transpose_options[i] = static_cast<MatrixTransposeOption>(current_index);
729717
}
730718
}
719+
720+
// NOTE: Validation (and its transpose option) will be added in a future patch
721+
ImGui::BeginDisabled();
722+
if (m_validate_matrix_result)
723+
{
724+
ImGui::BeginDisabled();
725+
static bool always_true = true;
726+
ImGui::Checkbox("Transpose When Needed", &always_true);
727+
ImGui::EndDisabled();
728+
}
729+
else
730+
{
731+
ImGui::Checkbox("Transpose When Needed", &m_transpose_when_needed);
732+
}
733+
734+
ImGui::Checkbox("Validate Result", &m_validate_matrix_result);
735+
ImGui::EndDisabled();
731736
}
732737

733738
if (ImGui::CollapsingHeader("Device Configuration", 0))
@@ -746,8 +751,6 @@ void CooperativeMatrixRunner::RenderUI()
746751
PrepareTestSession();
747752
}
748753

749-
ImGui::Text("For accurate values, make sure you are using the right device configurations (check 'Device Configuration' tab)");
750-
751754
if (m_is_processing_tests)
752755
{
753756
ImGui::SameLine();
@@ -758,6 +761,8 @@ void CooperativeMatrixRunner::RenderUI()
758761
ImGui::BeginDisabled(disable_ui);
759762
}
760763

764+
ImGui::Text("For accurate values, make sure you are using the right device configurations (check 'Device Configuration' tab)");
765+
761766
if (!m_test_groups.empty())
762767
{
763768
for (int i=0; i< m_test_groups.size(); i++)
@@ -864,6 +869,10 @@ void CooperativeMatrixRunner::RenderUI()
864869

865870
ImGui::Text("[Time]: %.2fus", test_result.time_total);
866871
ImGui::Text("[TOPS]: %.2f", test_result.TOPS);
872+
873+
if (m_test_type == TT_CONV)
874+
ImGui::TextDisabled("WxH = %dx%d", test_description.inputWidth, test_description.inputHeight);
875+
867876
ImVec4 color = GetPercentageColor(test_result.percentage / 100.0f);
868877
ImGui::PushStyleColor(ImGuiCol_Text, color);
869878
ImGui::Text("[%%]: %.2f", test_result.percentage);
@@ -957,8 +966,8 @@ void CooperativeMatrixRunner::PrepareTestSession()
957966
new_test_description.gpu_freq_MHz = m_gpu_freq_MHz;
958967
new_test_description.test_type = m_test_type;
959968

960-
new_test_description.inputWidth = 8;
961-
new_test_description.inputHeight = 8;
969+
new_test_description.inputWidth = m_input_width;
970+
new_test_description.inputHeight = m_input_height;
962971

963972
new_test_description.input_type = test_template_description.input_type;
964973
new_test_description.output_type = test_template_description.output_type;
@@ -1132,6 +1141,27 @@ std::optional<CooperativeMatrixRunner::TestResult> CooperativeMatrixRunner::RunT
11321141
return std::nullopt;
11331142
}
11341143

1144+
if (m_normalize_inputs)
1145+
{
1146+
int required_area = MSizeInBlocks * cooperativeMatrixProps.MSize;
1147+
1148+
// Start with inputWidth as-is, compute height to match required_area
1149+
if (inputWidth <= 0) inputWidth = 1; // safety
1150+
inputHeight = required_area / inputWidth;
1151+
1152+
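// If the division leaves a remainder, retry before falling back to a single-row layout.
// NOTE(review): the retry below repeats the exact same division as above, so it cannot change the result - confirm intent.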
1153+
if (inputWidth * inputHeight != required_area)
1154+
{
1155+
inputHeight = required_area / inputWidth;
1156+
if (inputWidth * inputHeight != required_area)
1157+
{
1158+
// Last resort: set width = required_area, height = 1
1159+
inputWidth = required_area;
1160+
inputHeight = 1;
1161+
}
1162+
}
1163+
}
1164+
11351165
// Set local_size (workgroup size) based on GPU/Tier, and datatype (fp32, fp16, etc)
11361166
// Default for 'unknown' or gpu/tier not recognized is local_size(64,2,2) for all datatypes
11371167
uint32_t local_size_x = 0, local_size_y = 0, local_size_z = 0;

samples/cooperative_matrix/code/main/cooperative_matrix_tester.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ class CooperativeMatrixRunner
5858
bool layoutC_Mfirst = false;
5959
bool layoutR_Mfirst = false;
6060

61-
int inputWidth = 1;
62-
int inputHeight = 1;
61+
int inputWidth = 32;
62+
int inputHeight = 16;
6363
};
6464

6565
struct TestResult
@@ -140,6 +140,10 @@ class CooperativeMatrixRunner
140140
bool m_transpose_when_needed = false;
141141
bool m_validate_matrix_result = false;
142142

143+
bool m_normalize_inputs = true;
144+
int m_input_width = 32;
145+
int m_input_height = 16;
146+
143147
bool m_is_processing_tests = false;
144148
uint32_t m_total_tests = 0;
145149
uint32_t m_total_processed_tests = 0;

0 commit comments

Comments
 (0)