4646
4747#include " integer_advanced_indexing.hpp"
4848
49- #define INDEXING_MODES 2
50- #define WRAP_MODE 0
51- #define CLIP_MODE 1
52-
5349namespace dpctl
5450{
5551namespace tensor
@@ -62,11 +58,17 @@ namespace td_ns = dpctl::tensor::type_dispatch;
6258using dpctl::tensor::kernels::indexing::put_fn_ptr_t ;
6359using dpctl::tensor::kernels::indexing::take_fn_ptr_t ;
6460
65- static take_fn_ptr_t take_dispatch_table[INDEXING_MODES][td_ns::num_types]
66- [td_ns::num_types];
61+ static take_fn_ptr_t take_wrap_dispatch_table[td_ns::num_types]
62+ [td_ns::num_types];
63+
64+ static take_fn_ptr_t take_clip_dispatch_table[td_ns::num_types]
65+ [td_ns::num_types];
66+
67+ static put_fn_ptr_t put_wrap_dispatch_table[INDEXING_MODES][td_ns::num_types]
68+ [td_ns::num_types];
6769
68- static put_fn_ptr_t put_dispatch_table [INDEXING_MODES][td_ns::num_types]
69- [td_ns::num_types];
70+ static put_fn_ptr_t put_clip_dispatch_table [INDEXING_MODES][td_ns::num_types]
71+ [td_ns::num_types];
7072
7173namespace py = pybind11;
7274
@@ -486,7 +488,8 @@ py_take(const dpctl::tensor::usm_ndarray &src,
486488 std::end (pack_deps));
487489 all_deps.insert (std::end (all_deps), std::begin (depends), std::end (depends));
488490
489- auto fn = take_dispatch_table[mode][src_type_id][ind_type_id];
491+ auto fn = mode ? take_wrap_dispatch_table[src_type_id][ind_type_id]
492+ : take_clip_dispatch_table[src_type_id][ind_type_id];
490493
491494 if (fn == nullptr ) {
492495 sycl::event::wait (host_task_events);
@@ -755,7 +758,8 @@ py_put(const dpctl::tensor::usm_ndarray &dst,
755758 std::end (pack_deps));
756759 all_deps.insert (std::end (all_deps), std::begin (depends), std::end (depends));
757760
758- auto fn = put_dispatch_table[mode][dst_type_id][ind_type_id];
761+ auto fn = mode ? put_wrap_dispatch_table[src_type_id][ind_type_id]
762+ : put_clip_dispatch_table[src_type_id][ind_type_id];
759763
760764 if (fn == nullptr ) {
761765 sycl::event::wait (host_task_events);
@@ -790,20 +794,20 @@ void init_advanced_indexing_dispatch_tables(void)
790794 using dpctl::tensor::kernels::indexing::TakeClipFactory;
791795 DispatchTableBuilder<take_fn_ptr_t , TakeClipFactory, num_types>
792796 dtb_takeclip;
793- dtb_takeclip.populate_dispatch_table (take_dispatch_table[CLIP_MODE] );
797+ dtb_takeclip.populate_dispatch_table (take_clip_dispatch_table );
794798
795799 using dpctl::tensor::kernels::indexing::TakeWrapFactory;
796800 DispatchTableBuilder<take_fn_ptr_t , TakeWrapFactory, num_types>
797801 dtb_takewrap;
798- dtb_takewrap.populate_dispatch_table (take_dispatch_table[WRAP_MODE] );
802+ dtb_takewrap.populate_dispatch_table (take_wrap_dispatch_table );
799803
800804 using dpctl::tensor::kernels::indexing::PutClipFactory;
801805 DispatchTableBuilder<put_fn_ptr_t , PutClipFactory, num_types> dtb_putclip;
802- dtb_putclip.populate_dispatch_table (put_dispatch_table[CLIP_MODE] );
806+ dtb_putclip.populate_dispatch_table (put_clip_dispatch_table );
803807
804808 using dpctl::tensor::kernels::indexing::PutWrapFactory;
805809 DispatchTableBuilder<put_fn_ptr_t , PutWrapFactory, num_types> dtb_putwrap;
806- dtb_putwrap.populate_dispatch_table (put_dispatch_table[WRAP_MODE] );
810+ dtb_putwrap.populate_dispatch_table (put_wrap_dispatch_table );
807811}
808812
809813} // namespace py_internal
0 commit comments