@@ -642,23 +642,6 @@ void CooperativeMatrixRunner::RenderUI()
642642 {
643643 ImGui::DragInt (" Test Repeats" , &m_test_repeats, 1 .0f , 0 , 100 );
644644
645- // NOTE: Validation (and its transpose option) will be added in a future path
646- ImGui::BeginDisabled ();
647- if (m_validate_matrix_result)
648- {
649- ImGui::BeginDisabled ();
650- static bool always_true = true ;
651- ImGui::Checkbox (" Transpose When Needed" , &always_true);
652- ImGui::EndDisabled ();
653- }
654- else
655- {
656- ImGui::Checkbox (" Transpose When Needed" , &m_transpose_when_needed);
657- }
658-
659- ImGui::Checkbox (" Validate Result" , &m_validate_matrix_result);
660- ImGui::EndDisabled ();
661-
662645 static const char * test_case_names[] = {
663646 " MxM Basic" ,
664647 " MxM Vector To Matrix" ,
@@ -668,31 +651,36 @@ void CooperativeMatrixRunner::RenderUI()
668651 int test_type_current_index = static_cast <int >(m_test_type);
669652 bool changed = false ;
670653
654+ ImGui::Text (" Note: Not all tests are compatible with all devices!" );
655+ ImGui::Text (" Check shader instruction set for compatibility if testing other than MXM_BASIC" );
656+
671657 if (ImGui::BeginCombo (" Test Case" , test_case_names[test_type_current_index]))
672658 {
673659 // NOTE: Temporarily disabled other tests, new test template coming on the next patch
674660 for (int i = 0 ; i < static_cast <int >(TestType::TT_COUNT); ++i)
675661 {
676- // NOTE: Temporarily disabled other tests, new test template coming on the next patch
677- ImGui::BeginDisabled (i > 0 );
678-
679662 const bool is_selected = (test_type_current_index == i);
680663 if (ImGui::Selectable (test_case_names[i], is_selected))
681664 {
682665 m_test_type = static_cast <TestType>(i);
683666 changed = true ;
684667 }
685668
686- ImGui::EndDisabled ();
687-
688669 if (is_selected)
689670 ImGui::SetItemDefaultFocus ();
690671 }
691672 ImGui::EndCombo ();
692673 }
693674
694- ImGui::Separator ();
675+ ImGui::BeginDisabled (m_test_type != TestType::TT_CONV);
676+ ImGui::DragInt (" Conv Width" , &m_input_width, 1 .0f , 1 , 256 );
677+ ImGui::DragInt (" Conv Height" , &m_input_height, 1 .0f , 1 , 256 );
678+ ImGui::Checkbox (" Normalize Inputs" , &m_normalize_inputs);
679+ ImGui::EndDisabled ();
680+ }
695681
682+ if (ImGui::CollapsingHeader (" Matrix Configuration" , ImGuiTreeNodeFlags_None))
683+ {
696684 static const char * fill_type_labels[] = {
697685 " Fill with Zero" ,
698686 " Fill with Constants" ,
@@ -728,6 +716,23 @@ void CooperativeMatrixRunner::RenderUI()
728716 m_matrix_transpose_options[i] = static_cast <MatrixTransposeOption>(current_index);
729717 }
730718 }
719+
720+ // NOTE: Validation (and its transpose option) will be added in a future path
721+ ImGui::BeginDisabled ();
722+ if (m_validate_matrix_result)
723+ {
724+ ImGui::BeginDisabled ();
725+ static bool always_true = true ;
726+ ImGui::Checkbox (" Transpose When Needed" , &always_true);
727+ ImGui::EndDisabled ();
728+ }
729+ else
730+ {
731+ ImGui::Checkbox (" Transpose When Needed" , &m_transpose_when_needed);
732+ }
733+
734+ ImGui::Checkbox (" Validate Result" , &m_validate_matrix_result);
735+ ImGui::EndDisabled ();
731736 }
732737
733738 if (ImGui::CollapsingHeader (" Device Configuration" , 0 ))
@@ -746,8 +751,6 @@ void CooperativeMatrixRunner::RenderUI()
746751 PrepareTestSession ();
747752 }
748753
749- ImGui::Text (" For accurate values, make sure you are using the right device configurations (check 'Device Configuration' tab)" );
750-
751754 if (m_is_processing_tests)
752755 {
753756 ImGui::SameLine ();
@@ -758,6 +761,8 @@ void CooperativeMatrixRunner::RenderUI()
758761 ImGui::BeginDisabled (disable_ui);
759762 }
760763
764+ ImGui::Text (" For accurate values, make sure you are using the right device configurations (check 'Device Configuration' tab)" );
765+
761766 if (!m_test_groups.empty ())
762767 {
763768 for (int i=0 ; i< m_test_groups.size (); i++)
@@ -864,6 +869,10 @@ void CooperativeMatrixRunner::RenderUI()
864869
865870 ImGui::Text (" [Time]: %.2fus" , test_result.time_total );
866871 ImGui::Text (" [TOPS]: %.2f" , test_result.TOPS );
872+
873+ if (m_test_type == TT_CONV)
874+ ImGui::TextDisabled (" WxH = %dx%d" , test_description.inputWidth , test_description.inputHeight );
875+
867876 ImVec4 color = GetPercentageColor (test_result.percentage / 100 .0f );
868877 ImGui::PushStyleColor (ImGuiCol_Text, color);
869878 ImGui::Text (" [%%]: %.2f" , test_result.percentage );
@@ -957,8 +966,8 @@ void CooperativeMatrixRunner::PrepareTestSession()
957966 new_test_description.gpu_freq_MHz = m_gpu_freq_MHz;
958967 new_test_description.test_type = m_test_type;
959968
960- new_test_description.inputWidth = 8 ;
961- new_test_description.inputHeight = 8 ;
969+ new_test_description.inputWidth = m_input_width ;
970+ new_test_description.inputHeight = m_input_height ;
962971
963972 new_test_description.input_type = test_template_description.input_type ;
964973 new_test_description.output_type = test_template_description.output_type ;
@@ -1132,6 +1141,27 @@ std::optional<CooperativeMatrixRunner::TestResult> CooperativeMatrixRunner::RunT
11321141 return std::nullopt ;
11331142 }
11341143
1144+ if (m_normalize_inputs)
1145+ {
1146+ int required_area = MSizeInBlocks * cooperativeMatrixProps.MSize ;
1147+
1148+ // Start with inputWidth as-is, compute height to match required_area
1149+ if (inputWidth <= 0 ) inputWidth = 1 ; // safety
1150+ inputHeight = required_area / inputWidth;
1151+
1152+ // If division leaves remainder, just force height to match
1153+ if (inputWidth * inputHeight != required_area)
1154+ {
1155+ inputHeight = required_area / inputWidth;
1156+ if (inputWidth * inputHeight != required_area)
1157+ {
1158+ // Last resort: set width = required_area, height = 1
1159+ inputWidth = required_area;
1160+ inputHeight = 1 ;
1161+ }
1162+ }
1163+ }
1164+
11351165 // Set local_size (workgroup size) based on GPU/Tier, and datatype (fp32, fp16, etc)
11361166 // Default for 'unknown' or gpu/tier not recohgnized is local_size(64,2,2) for all datatyes
11371167 uint32_t local_size_x = 0 , local_size_y = 0 , local_size_z = 0 ;
0 commit comments