Skip to content

Commit ae88298

Browse files
committed
Fix regression in batched SSE2 patch.
Signed-off-by: Tuomas Tonteri <[email protected]>
1 parent e0197db commit ae88298

File tree

1 file changed

+74
-31
lines changed

1 file changed

+74
-31
lines changed

src/liboslexec/llvm_util.cpp

Lines changed: 74 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ LLVM_Util::LLVM_Util(const PerThreadInfo& per_thread_info, int debuglevel,
519519
// TODO: why are there casts to the base class llvm::Type *?
520520
m_vector_width = OIIO::floor2(OIIO::clamp(m_vector_width, 4, 16));
521521
m_llvm_type_wide_float = llvm_vector_type(m_llvm_type_float,
522-
m_vector_width);
522+
m_vector_width);
523523
m_llvm_type_wide_double = llvm_vector_type(m_llvm_type_double,
524524
m_vector_width);
525525
m_llvm_type_wide_int = llvm_vector_type(m_llvm_type_int, m_vector_width);
@@ -790,8 +790,8 @@ LLVM_Util::debug_push_inlined_function(OIIO::ustring function_name,
790790
method_scope_line, // Scope Line,
791791
fnFlags,
792792
llvm::DISubprogram::toSPFlags(true /*isLocalToUnit*/,
793-
true /*isDefinition*/,
794-
true /*false*/ /*isOptimized*/));
793+
true /*isDefinition*/,
794+
true /*false*/ /*isOptimized*/));
795795

796796
mLexicalBlocks.push_back(function);
797797
}
@@ -3698,12 +3698,21 @@ LLVM_Util::mask_as_int(llvm::Value* mask)
36983698
// Convert <4 x i1> -> <4 x i32>
36993699
llvm::Value* w4_int_mask = builder().CreateSExt(mask,
37003700
type_wide_int());
3701+
3702+
// Now we will use the horizontal sign extraction intrinsic
3703+
// to build a 32 bit mask value. However the only 256bit
3704+
// version works on floats, so we will cast from int32 to
3705+
// float beforehand
3706+
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
3707+
llvm::Value* w4_float_mask = builder().CreateBitCast(w4_int_mask,
3708+
w4_float_type);
3709+
37013710
// Now we will use the horizontal sign extraction intrinsic
37023711
// to build a 32 bit mask value.
37033712
llvm::Function* func = llvm::Intrinsic::getDeclaration(
3704-
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
3713+
module(), llvm::Intrinsic::x86_sse_movmsk_ps);
37053714

3706-
llvm::Value* args[1] = { w4_int_mask };
3715+
llvm::Value* args[1] = { w4_float_mask };
37073716
llvm::Value* int8_mask;
37083717
int8_mask = builder().CreateCall(func, toArrayRef(args));
37093718
return int8_mask;
@@ -3727,18 +3736,28 @@ LLVM_Util::mask_as_int(llvm::Value* mask)
37273736
auto w4_int_masks = op_quarter_16x(wide_int_mask);
37283737

37293738
// Now we will use the horizontal sign extraction intrinsic
3730-
// to build a 32 bit mask value.
3739+
// to build a 32 bit mask value. However the only 128bit
3740+
// version works on floats, so we will cast from int32 to
3741+
// float beforehand
3742+
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
3743+
std::array<llvm::Value*, 4> w4_float_masks = {
3744+
{ builder().CreateBitCast(w4_int_masks[0], w4_float_type),
3745+
builder().CreateBitCast(w4_int_masks[1], w4_float_type),
3746+
builder().CreateBitCast(w4_int_masks[2], w4_float_type),
3747+
builder().CreateBitCast(w4_int_masks[3], w4_float_type) }
3748+
};
3749+
37313750
llvm::Function* func = llvm::Intrinsic::getDeclaration(
3732-
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
3751+
module(), llvm::Intrinsic::x86_sse_movmsk_ps);
37333752

3734-
llvm::Value* args[1] = { w4_int_masks[0] };
3753+
llvm::Value* args[1] = { w4_float_masks[0] };
37353754
std::array<llvm::Value*, 4> int4_masks;
37363755
int4_masks[0] = builder().CreateCall(func, toArrayRef(args));
3737-
args[0] = w4_int_masks[1];
3756+
args[0] = w4_float_masks[1];
37383757
int4_masks[1] = builder().CreateCall(func, toArrayRef(args));
3739-
args[0] = w4_int_masks[2];
3758+
args[0] = w4_float_masks[2];
37403759
int4_masks[2] = builder().CreateCall(func, toArrayRef(args));
3741-
args[0] = w4_int_masks[3];
3760+
args[0] = w4_float_masks[3];
37423761
int4_masks[3] = builder().CreateCall(func, toArrayRef(args));
37433762

37443763
llvm::Value* bits12_15 = op_shl(int4_masks[3], constant(12));
@@ -3759,14 +3778,22 @@ LLVM_Util::mask_as_int(llvm::Value* mask)
37593778
auto w4_int_masks = op_split_8x(wide_int_mask);
37603779

37613780
// Now we will use the horizontal sign extraction intrinsic
3762-
// to build a 32 bit mask value.
3781+
// to build a 32 bit mask value. However the only 128bit
3782+
// version works on floats, so we will cast from int32 to
3783+
// float beforehand
3784+
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
3785+
std::array<llvm::Value*, 2> w4_float_masks = {
3786+
{ builder().CreateBitCast(w4_int_masks[0], w4_float_type),
3787+
builder().CreateBitCast(w4_int_masks[1], w4_float_type) }
3788+
};
3789+
37633790
llvm::Function* func = llvm::Intrinsic::getDeclaration(
3764-
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
3791+
module(), llvm::Intrinsic::x86_sse_movmsk_ps);
37653792

3766-
llvm::Value* args[1] = { w4_int_masks[0] };
3793+
llvm::Value* args[1] = { w4_float_masks[0] };
37673794
std::array<llvm::Value*, 2> int4_masks;
37683795
int4_masks[0] = builder().CreateCall(func, toArrayRef(args));
3769-
args[0] = w4_int_masks[1];
3796+
args[0] = w4_float_masks[1];
37703797
int4_masks[1] = builder().CreateCall(func, toArrayRef(args));
37713798

37723799
llvm::Value* bits4_7 = op_shl(int4_masks[1], constant(4));
@@ -3782,12 +3809,20 @@ LLVM_Util::mask_as_int(llvm::Value* mask)
37823809
llvm::Value* w4_int_mask = builder().CreateSExt(mask,
37833810
type_wide_int());
37843811

3812+
// Now we will use the horizontal sign extraction intrinsic
3813+
// to build a 32 bit mask value. However the only 256bit
3814+
// version works on floats, so we will cast from int32 to
3815+
// float beforehand
3816+
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
3817+
llvm::Value* w4_float_mask = builder().CreateBitCast(w4_int_mask,
3818+
w4_float_type);
3819+
37853820
// Now we will use the horizontal sign extraction intrinsic
37863821
// to build a 32 bit mask value.
37873822
llvm::Function* func = llvm::Intrinsic::getDeclaration(
3788-
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
3823+
module(), llvm::Intrinsic::x86_sse_movmsk_ps);
37893824

3790-
llvm::Value* args[1] = { w4_int_mask };
3825+
llvm::Value* args[1] = { w4_float_mask };
37913826
llvm::Value* int4_mask = builder().CreateCall(func,
37923827
toArrayRef(args));
37933828

@@ -3833,12 +3868,20 @@ LLVM_Util::mask4_as_int8(llvm::Value* mask)
38333868
// Convert <4 x i1> -> <4 x i32>
38343869
llvm::Value* w4_int_mask = builder().CreateSExt(mask, type_wide_int());
38353870

3871+
// Now we will use the horizontal sign extraction intrinsic
3872+
// to build a 32 bit mask value. However the only 256bit
3873+
// version works on floats, so we will cast from int32 to
3874+
// float beforehand
3875+
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
3876+
llvm::Value* w4_float_mask = builder().CreateBitCast(w4_int_mask,
3877+
w4_float_type);
3878+
38363879
// Now we will use the horizontal sign extraction intrinsic
38373880
// to build a 32 bit mask value.
38383881
llvm::Function* func = llvm::Intrinsic::getDeclaration(
3839-
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
3882+
module(), llvm::Intrinsic::x86_sse_movmsk_ps);
38403883

3841-
llvm::Value* args[1] = { w4_int_mask };
3884+
llvm::Value* args[1] = { w4_float_mask };
38423885
llvm::Value* int32 = builder().CreateCall(func, toArrayRef(args));
38433886
llvm::Value* i8 = builder().CreateIntCast(int32, type_int8(), true);
38443887

@@ -4685,7 +4728,7 @@ LLVM_Util::op_gather(llvm::Type* src_type, llvm::Value* src_ptr,
46854728

46864729
llvm::Value* unmasked_value = wide_constant(0);
46874730
llvm::Value* args[] = { unmasked_value, void_ptr(src_ptr),
4688-
wide_index, int_mask, constant(4) };
4731+
wide_index, int_mask, constant(4) };
46894732
return builder().CreateCall(func_avx512_gather_pi,
46904733
toArrayRef(args));
46914734
} else if (m_supports_avx2) {
@@ -4705,8 +4748,8 @@ LLVM_Util::op_gather(llvm::Type* src_type, llvm::Value* src_ptr,
47054748
auto w8_int_masks = op_split_16x(wide_int_mask);
47064749
auto w8_int_indices = op_split_16x(wide_index);
47074750
llvm::Value* args[] = { avx2_unmasked_value, void_ptr(src_ptr),
4708-
w8_int_indices[0], w8_int_masks[0],
4709-
constant8((uint8_t)4) };
4751+
w8_int_indices[0], w8_int_masks[0],
4752+
constant8((uint8_t)4) };
47104753
llvm::Value* gather1 = builder().CreateCall(func_avx2_gather_pi,
47114754
toArrayRef(args));
47124755
args[2] = w8_int_indices[1];
@@ -4794,8 +4837,8 @@ LLVM_Util::op_gather(llvm::Type* src_type, llvm::Value* src_ptr,
47944837
toArrayRef(args));
47954838
args[2] = w8_int_indices[1];
47964839
args[3] = builder().CreateBitCast(w8_int_masks[1],
4797-
llvm_vector_type(type_float(),
4798-
8));
4840+
llvm_vector_type(type_float(),
4841+
8));
47994842
llvm::Value* gather2 = builder().CreateCall(func_avx2_gather_ps,
48004843
toArrayRef(args));
48014844
return op_combine_8x_vectors(gather1, gather2);
@@ -4990,8 +5033,8 @@ LLVM_Util::op_gather(llvm::Type* src_type, llvm::Value* src_ptr,
49905033
toArrayRef(args));
49915034
args[2] = w8_int_indices[1];
49925035
args[3] = builder().CreateBitCast(w8_int_masks[1],
4993-
llvm_vector_type(type_float(),
4994-
8));
5036+
llvm_vector_type(type_float(),
5037+
8));
49955038
llvm::Value* gather2 = builder().CreateCall(func_avx2_gather_ps,
49965039
toArrayRef(args));
49975040
return op_combine_8x_vectors(gather1, gather2);
@@ -5092,8 +5135,8 @@ LLVM_Util::op_gather(llvm::Type* src_type, llvm::Value* src_ptr,
50925135
auto w8_int_indices = op_split_16x(
50935136
op_linearize_16x_indices(wide_index));
50945137
llvm::Value* args[] = { avx2_unmasked_value, void_ptr(src_ptr),
5095-
w8_int_indices[0], w8_int_masks[0],
5096-
constant8((uint8_t)4) };
5138+
w8_int_indices[0], w8_int_masks[0],
5139+
constant8((uint8_t)4) };
50975140
llvm::Value* gather1 = builder().CreateCall(func_avx2_gather_pi,
50985141
toArrayRef(args));
50995142
args[2] = w8_int_indices[1];
@@ -5863,9 +5906,9 @@ LLVM_Util::apply_return_to(llvm::Value* existing_mask)
58635906
OSL_ASSERT(masked_function_context().return_count > 0);
58645907

58655908
llvm::Value* loc_of_return_mask = masked_function_context().location_of_mask;
5866-
llvm::Value* rs_mask = op_load_mask(loc_of_return_mask);
5867-
llvm::Value* result = builder().CreateSelect(rs_mask, existing_mask,
5868-
rs_mask);
5909+
llvm::Value* rs_mask = op_load_mask(loc_of_return_mask);
5910+
llvm::Value* result = builder().CreateSelect(rs_mask, existing_mask,
5911+
rs_mask);
58695912
return result;
58705913
}
58715914

0 commit comments

Comments
 (0)