Skip to content

Commit 319d781

Browse files
committed
gpu - pad out elem loop for shared/gen
1 parent 2cbb475 commit 319d781

File tree

8 files changed

+723
-222
lines changed

8 files changed

+723
-222
lines changed

backends/cuda-gen/ceed-cuda-gen-operator-build.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -473,13 +473,17 @@ static int CeedOperatorBuildKernelRestriction_Cuda_gen(std::ostringstream &code,
473473
CeedInt comp_stride;
474474

475475
CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
476+
code << tab << "if (e < num_elem) {\n";
477+
tab.push();
476478
code << tab << "const CeedInt l_size" << var_suffix << " = " << l_size << ";\n";
477479
CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
478-
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n";
480+
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n\n";
479481
data->indices.outputs[i] = (CeedInt *)rstr_data->d_offsets;
480482
code << tab << "WriteLVecStandard" << (is_all_tensor ? max_dim : 1) << "d<num_comp" << var_suffix << ", comp_stride" << var_suffix << ", "
481483
<< P_name << ">(data, l_size" << var_suffix << ", elem, indices.outputs[" << i << "], r_e" << var_suffix << ", d" << var_suffix
482484
<< ");\n";
485+
tab.pop();
486+
code << tab << "}\n";
483487
break;
484488
}
485489
case CEED_RESTRICTION_STRIDED: {
@@ -493,11 +497,15 @@ static int CeedOperatorBuildKernelRestriction_Cuda_gen(std::ostringstream &code,
493497
if (!has_backend_strides) {
494498
CeedCallBackend(CeedElemRestrictionGetStrides(elem_rstr, strides));
495499
}
500+
code << tab << "if (e < num_elem) {\n";
501+
tab.push();
496502
code << tab << "const CeedInt strides" << var_suffix << "_0 = " << strides[0] << ", strides" << var_suffix << "_1 = " << strides[1]
497-
<< ", strides" << var_suffix << "_2 = " << strides[2] << ";\n";
503+
<< ", strides" << var_suffix << "_2 = " << strides[2] << ";\n\n";
498504
code << tab << "WriteLVecStrided" << (is_all_tensor ? max_dim : 1) << "d<num_comp" << var_suffix << ", " << P_name << ", strides"
499505
<< var_suffix << "_0, strides" << var_suffix << "_1, strides" << var_suffix << "_2>(data, elem, r_e" << var_suffix << ", d" << var_suffix
500506
<< ");\n";
507+
tab.pop();
508+
code << tab << "}\n";
501509
break;
502510
}
503511
case CEED_RESTRICTION_POINTS:
@@ -1033,10 +1041,14 @@ static int CeedOperatorBuildKernelQFunction_Cuda_gen(std::ostringstream &code, C
10331041
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
10341042
CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
10351043
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
1036-
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n";
1044+
code << tab << "if (e < num_elem) {\n";
1045+
tab.push();
1046+
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n\n";
10371047
code << tab << "WritePoint<num_comp" << var_suffix << ", comp_stride" << var_suffix
10381048
<< ", max_num_points>(data, elem, i, points.num_per_elem[elem], indices.outputs[" << i << "]"
10391049
<< ", r_s" << var_suffix << ", d" << var_suffix << ");\n";
1050+
tab.pop();
1051+
code << tab << "}\n";
10401052
break;
10411053
}
10421054
case CEED_EVAL_INTERP:
@@ -1482,8 +1494,10 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op, bool *is_good_b
14821494
// Loop over all elements
14831495
code << "\n" << tab << "// Element loop\n";
14841496
code << tab << "__syncthreads();\n";
1485-
code << tab << "for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x*blockDim.z) {\n";
1497+
code << tab << "const CeedInt elem_loop_bound = num_elem * ceil(1.0*num_elem/(gridDim.x*blockDim.z));\n\n";
1498+
code << tab << "for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < elem_loop_bound; e += gridDim.x*blockDim.z) {\n";
14861499
tab.push();
1500+
code << tab << "const CeedInt elem = e % num_elem;\n\n";
14871501

14881502
// -- Compute minimum buffer space needed
14891503
CeedInt max_rstr_buffer_size = 1;
@@ -2042,11 +2056,15 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(CeedOperator op, boo
20422056

20432057
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
20442058
CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
2059+
code << tab << "if (e < num_elem) {\n";
2060+
tab.push();
20452061
code << tab << "const CeedInt l_size" << var_suffix << " = " << l_size << ";\n";
20462062
CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
2047-
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n";
2063+
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n\n";
20482064
code << tab << "WriteLVecStandard" << max_dim << "d_Assembly<num_comp" << var_suffix << ", comp_stride" << var_suffix << ", P_1d" + var_suffix
20492065
<< ">(data, l_size" << var_suffix << ", elem, n, r_e" << var_suffix << ", values_array);\n";
2066+
tab.pop();
2067+
code << tab << "}\n";
20502068
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
20512069
} else {
20522070
std::string var_suffix = "_out_" + std::to_string(i);
@@ -2056,11 +2074,15 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(CeedOperator op, boo
20562074

20572075
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
20582076
CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
2077+
code << tab << "if (e < num_elem) {\n";
2078+
tab.push();
20592079
code << tab << "const CeedInt l_size" << var_suffix << " = " << l_size << ";\n";
20602080
CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
2061-
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n";
2081+
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n\n";
20622082
code << tab << "WriteLVecStandard" << max_dim << "d_Single<num_comp" << var_suffix << ", comp_stride" << var_suffix << ", P_1d" + var_suffix
20632083
<< ">(data, l_size" << var_suffix << ", elem, n, indices.outputs[" << i << "], r_e" << var_suffix << ", values_array);\n";
2084+
tab.pop();
2085+
code << tab << "}\n";
20642086
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
20652087
}
20662088
}
@@ -2642,8 +2664,12 @@ extern "C" int CeedOperatorBuildKernelLinearAssembleQFunction_Cuda_gen(CeedOpera
26422664
// ---- Restriction
26432665
CeedInt field_size;
26442666

2667+
code << tab << "if (e < num_elem) {\n";
2668+
tab.push();
26452669
code << tab << "WriteLVecStandard" << (is_all_tensor ? max_dim : 1) << "d_QFAssembly<total_size_out, field_size_out_" << i << ", "
26462670
<< (is_all_tensor ? "Q_1d" : "Q") << ">(data, num_elem, elem, input_offset + s, " << offset << ", r_q_out_" << i << ", values_array);\n";
2671+
tab.pop();
2672+
code << tab << "}\n";
26472673
CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size));
26482674
offset += field_size;
26492675
}

backends/hip-gen/ceed-hip-gen-operator-build.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -500,13 +500,17 @@ static int CeedOperatorBuildKernelRestriction_Hip_gen(std::ostringstream &code,
500500
CeedInt comp_stride;
501501

502502
CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
503+
code << tab << "if (e < num_elem) {\n";
504+
tab.push();
503505
code << tab << "const CeedInt l_size" << var_suffix << " = " << l_size << ";\n";
504506
CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
505-
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n";
507+
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n\n";
506508
data->indices.outputs[i] = (CeedInt *)rstr_data->d_offsets;
507509
code << tab << "WriteLVecStandard" << (is_all_tensor ? max_dim : 1) << "d<num_comp" << var_suffix << ", comp_stride" << var_suffix << ", "
508510
<< P_name << ">(data, l_size" << var_suffix << ", elem, indices.outputs[" << i << "], r_e" << var_suffix << ", d" << var_suffix
509511
<< ");\n";
512+
tab.pop();
513+
code << tab << "}\n";
510514
break;
511515
}
512516
case CEED_RESTRICTION_STRIDED: {
@@ -520,11 +524,15 @@ static int CeedOperatorBuildKernelRestriction_Hip_gen(std::ostringstream &code,
520524
if (!has_backend_strides) {
521525
CeedCallBackend(CeedElemRestrictionGetStrides(elem_rstr, strides));
522526
}
527+
code << tab << "if (e < num_elem) {\n";
528+
tab.push();
523529
code << tab << "const CeedInt strides" << var_suffix << "_0 = " << strides[0] << ", strides" << var_suffix << "_1 = " << strides[1]
524-
<< ", strides" << var_suffix << "_2 = " << strides[2] << ";\n";
530+
<< ", strides" << var_suffix << "_2 = " << strides[2] << ";\n\n";
525531
code << tab << "WriteLVecStrided" << (is_all_tensor ? max_dim : 1) << "d<num_comp" << var_suffix << ", " << P_name << ", strides"
526532
<< var_suffix << "_0, strides" << var_suffix << "_1, strides" << var_suffix << "_2>(data, elem, r_e" << var_suffix << ", d" << var_suffix
527533
<< ");\n";
534+
tab.pop();
535+
code << tab << "}\n";
528536
break;
529537
}
530538
case CEED_RESTRICTION_POINTS:
@@ -1060,10 +1068,14 @@ static int CeedOperatorBuildKernelQFunction_Hip_gen(std::ostringstream &code, Ce
10601068
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
10611069
CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
10621070
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
1063-
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n";
1071+
code << tab << "if (e < num_elem) {\n";
1072+
tab.push();
1073+
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n\n";
10641074
code << tab << "WritePoint<num_comp" << var_suffix << ", comp_stride" << var_suffix
10651075
<< ", max_num_points>(data, elem, i, points.num_per_elem[elem], indices.outputs[" << i << "]"
10661076
<< ", r_s" << var_suffix << ", d" << var_suffix << ");\n";
1077+
tab.pop();
1078+
code << tab << "}\n";
10671079
break;
10681080
}
10691081
case CEED_EVAL_INTERP:
@@ -1495,8 +1507,10 @@ extern "C" int CeedOperatorBuildKernel_Hip_gen(CeedOperator op, bool *is_good_bu
14951507
// Loop over all elements
14961508
code << "\n" << tab << "// Element loop\n";
14971509
code << tab << "__syncthreads();\n";
1498-
code << tab << "for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x*blockDim.z) {\n";
1510+
code << tab << "const CeedInt elem_loop_bound = num_elem * ceil(1.0*num_elem/(gridDim.x*blockDim.z));\n\n";
1511+
code << tab << "for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < elem_loop_bound; e += gridDim.x*blockDim.z) {\n";
14991512
tab.push();
1513+
code << tab << "const CeedInt elem = e % num_elem;\n\n";
15001514

15011515
// -- Compute minimum buffer space needed
15021516
CeedInt max_rstr_buffer_size = 1;
@@ -2047,11 +2061,15 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Hip_gen(CeedOperator op, bool
20472061

20482062
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
20492063
CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
2064+
code << tab << "if (e < num_elem) {\n";
2065+
tab.push();
20502066
code << tab << "const CeedInt l_size" << var_suffix << " = " << l_size << ";\n";
20512067
CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
2052-
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n";
2068+
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n\n";
20532069
code << tab << "WriteLVecStandard" << max_dim << "d_Assembly<num_comp" << var_suffix << ", comp_stride" << var_suffix << ", P_1d" + var_suffix
20542070
<< ">(data, l_size" << var_suffix << ", elem, n, r_e" << var_suffix << ", values_array);\n";
2071+
tab.pop();
2072+
code << tab << "}\n";
20552073
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
20562074
} else {
20572075
std::string var_suffix = "_out_" + std::to_string(i);
@@ -2061,11 +2079,15 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Hip_gen(CeedOperator op, bool
20612079

20622080
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
20632081
CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
2082+
code << tab << "if (e < num_elem) {\n";
2083+
tab.push();
20642084
code << tab << "const CeedInt l_size" << var_suffix << " = " << l_size << ";\n";
20652085
CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
2066-
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n";
2086+
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n\n";
20672087
code << tab << "WriteLVecStandard" << max_dim << "d_Single<num_comp" << var_suffix << ", comp_stride" << var_suffix << ", P_1d" + var_suffix
20682088
<< ">(data, l_size" << var_suffix << ", elem, n, indices.outputs[" << i << "], r_e" << var_suffix << ", values_array);\n";
2089+
tab.pop();
2090+
code << tab << "}\n";
20692091
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
20702092
}
20712093
}
@@ -2638,8 +2660,12 @@ extern "C" int CeedOperatorBuildKernelLinearAssembleQFunction_Hip_gen(CeedOperat
26382660
// ---- Restriction
26392661
CeedInt field_size;
26402662

2663+
code << tab << "if (e < num_elem) {\n";
2664+
tab.push();
26412665
code << tab << "WriteLVecStandard" << (is_all_tensor ? max_dim : 1) << "d_QFAssembly<total_size_out, field_size_out_" << i << ", "
26422666
<< (is_all_tensor ? "Q_1d" : "Q") << ">(data, num_elem, elem, input_offset + s, " << offset << ", r_q_out_" << i << ", values_array);\n";
2667+
tab.pop();
2668+
code << tab << "}\n";
26432669
CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size));
26442670
offset += field_size;
26452671
}

0 commit comments

Comments
 (0)