@@ -473,13 +473,17 @@ static int CeedOperatorBuildKernelRestriction_Cuda_gen(std::ostringstream &code,
473473 CeedInt comp_stride;
474474
475475 CeedCallBackend (CeedElemRestrictionGetLVectorSize (elem_rstr, &l_size));
476+ code << tab << " if (e < num_elem) {\n " ;
477+ tab.push ();
476478 code << tab << " const CeedInt l_size" << var_suffix << " = " << l_size << " ;\n " ;
477479 CeedCallBackend (CeedElemRestrictionGetCompStride (elem_rstr, &comp_stride));
478- code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n " ;
480+ code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n\n " ;
479481 data->indices .outputs [i] = (CeedInt *)rstr_data->d_offsets ;
480482 code << tab << " WriteLVecStandard" << (is_all_tensor ? max_dim : 1 ) << " d<num_comp" << var_suffix << " , comp_stride" << var_suffix << " , "
481483 << P_name << " >(data, l_size" << var_suffix << " , elem, indices.outputs[" << i << " ], r_e" << var_suffix << " , d" << var_suffix
482484 << " );\n " ;
485+ tab.pop ();
486+ code << tab << " }\n " ;
483487 break ;
484488 }
485489 case CEED_RESTRICTION_STRIDED: {
@@ -493,11 +497,15 @@ static int CeedOperatorBuildKernelRestriction_Cuda_gen(std::ostringstream &code,
493497 if (!has_backend_strides) {
494498 CeedCallBackend (CeedElemRestrictionGetStrides (elem_rstr, strides));
495499 }
500+ code << tab << " if (e < num_elem) {\n " ;
501+ tab.push ();
496502 code << tab << " const CeedInt strides" << var_suffix << " _0 = " << strides[0 ] << " , strides" << var_suffix << " _1 = " << strides[1 ]
497- << " , strides" << var_suffix << " _2 = " << strides[2 ] << " ;\n " ;
503+ << " , strides" << var_suffix << " _2 = " << strides[2 ] << " ;\n\n " ;
498504 code << tab << " WriteLVecStrided" << (is_all_tensor ? max_dim : 1 ) << " d<num_comp" << var_suffix << " , " << P_name << " , strides"
499505 << var_suffix << " _0, strides" << var_suffix << " _1, strides" << var_suffix << " _2>(data, elem, r_e" << var_suffix << " , d" << var_suffix
500506 << " );\n " ;
507+ tab.pop ();
508+ code << tab << " }\n " ;
501509 break ;
502510 }
503511 case CEED_RESTRICTION_POINTS:
@@ -1033,10 +1041,14 @@ static int CeedOperatorBuildKernelQFunction_Cuda_gen(std::ostringstream &code, C
10331041 CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_output_fields[i], &elem_rstr));
10341042 CeedCallBackend (CeedElemRestrictionGetCompStride (elem_rstr, &comp_stride));
10351043 CeedCallBackend (CeedElemRestrictionDestroy (&elem_rstr));
1036- code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n " ;
1044+ code << tab << " if (e < num_elem) {\n " ;
1045+ tab.push ();
1046+ code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n\n " ;
10371047 code << tab << " WritePoint<num_comp" << var_suffix << " , comp_stride" << var_suffix
10381048 << " , max_num_points>(data, elem, i, points.num_per_elem[elem], indices.outputs[" << i << " ]"
10391049 << " , r_s" << var_suffix << " , d" << var_suffix << " );\n " ;
1050+ tab.pop ();
1051+ code << tab << " }\n " ;
10401052 break ;
10411053 }
10421054 case CEED_EVAL_INTERP:
@@ -1482,8 +1494,10 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op, bool *is_good_b
14821494 // Loop over all elements
14831495 code << " \n " << tab << " // Element loop\n " ;
14841496 code << tab << " __syncthreads();\n " ;
1485- code << tab << " for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x*blockDim.z) {\n " ;
1497+ code << tab << " const CeedInt elem_loop_bound = num_elem * ceil(1.0*num_elem/(gridDim.x*blockDim.z));\n\n " ;
1498+ code << tab << " for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < elem_loop_bound; e += gridDim.x*blockDim.z) {\n " ;
14861499 tab.push ();
1500+ code << tab << " const CeedInt elem = e % num_elem;\n\n " ;
14871501
14881502 // -- Compute minimum buffer space needed
14891503 CeedInt max_rstr_buffer_size = 1 ;
@@ -2042,11 +2056,15 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(CeedOperator op, boo
20422056
20432057 CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_output_fields[i], &elem_rstr));
20442058 CeedCallBackend (CeedElemRestrictionGetLVectorSize (elem_rstr, &l_size));
2059+ code << tab << " if (e < num_elem) {\n " ;
2060+ tab.push ();
20452061 code << tab << " const CeedInt l_size" << var_suffix << " = " << l_size << " ;\n " ;
20462062 CeedCallBackend (CeedElemRestrictionGetCompStride (elem_rstr, &comp_stride));
2047- code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n " ;
2063+ code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n\n " ;
20482064 code << tab << " WriteLVecStandard" << max_dim << " d_Assembly<num_comp" << var_suffix << " , comp_stride" << var_suffix << " , P_1d" + var_suffix
20492065 << " >(data, l_size" << var_suffix << " , elem, n, r_e" << var_suffix << " , values_array);\n " ;
2066+ tab.pop ();
2067+ code << tab << " }\n " ;
20502068 CeedCallBackend (CeedElemRestrictionDestroy (&elem_rstr));
20512069 } else {
20522070 std::string var_suffix = " _out_" + std::to_string (i);
@@ -2056,11 +2074,15 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(CeedOperator op, boo
20562074
20572075 CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_output_fields[i], &elem_rstr));
20582076 CeedCallBackend (CeedElemRestrictionGetLVectorSize (elem_rstr, &l_size));
2077+ code << tab << " if (e < num_elem) {\n " ;
2078+ tab.push ();
20592079 code << tab << " const CeedInt l_size" << var_suffix << " = " << l_size << " ;\n " ;
20602080 CeedCallBackend (CeedElemRestrictionGetCompStride (elem_rstr, &comp_stride));
2061- code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n " ;
2081+ code << tab << " const CeedInt comp_stride" << var_suffix << " = " << comp_stride << " ;\n\n " ;
20622082 code << tab << " WriteLVecStandard" << max_dim << " d_Single<num_comp" << var_suffix << " , comp_stride" << var_suffix << " , P_1d" + var_suffix
20632083 << " >(data, l_size" << var_suffix << " , elem, n, indices.outputs[" << i << " ], r_e" << var_suffix << " , values_array);\n " ;
2084+ tab.pop ();
2085+ code << tab << " }\n " ;
20642086 CeedCallBackend (CeedElemRestrictionDestroy (&elem_rstr));
20652087 }
20662088 }
@@ -2642,8 +2664,12 @@ extern "C" int CeedOperatorBuildKernelLinearAssembleQFunction_Cuda_gen(CeedOpera
26422664 // ---- Restriction
26432665 CeedInt field_size;
26442666
2667+ code << tab << " if (e < num_elem) {\n " ;
2668+ tab.push ();
26452669 code << tab << " WriteLVecStandard" << (is_all_tensor ? max_dim : 1 ) << " d_QFAssembly<total_size_out, field_size_out_" << i << " , "
26462670 << (is_all_tensor ? " Q_1d" : " Q" ) << " >(data, num_elem, elem, input_offset + s, " << offset << " , r_q_out_" << i << " , values_array);\n " ;
2671+ tab.pop ();
2672+ code << tab << " }\n " ;
26472673 CeedCallBackend (CeedQFunctionFieldGetSize (qf_output_fields[i], &field_size));
26482674 offset += field_size;
26492675 }
0 commit comments