Skip to content

Commit edb4972

Browse files
authored
Merge pull request #144 from JuliaSIMD/checkforbroadcastop
Check for broadcast op and format. Fixes #143.
2 parents 0b2d48a + 3fffaef commit edb4972

File tree

3 files changed

+57
-26
lines changed

3 files changed

+57
-26
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Polyester"
22
uuid = "f517fe37-dbe3-4b94-8317-1923a5111588"
33
authors = ["Chris Elrod <[email protected]> and contributors"]
4-
version = "0.7.12"
4+
version = "0.7.13"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/closure.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,9 @@ function extractargs!(
112112

113113
startind = 1
114114
if head === :call
115-
if args[1] isa Symbol
116-
startind = isdefined(mod, args[1]) ? 2 : 1
115+
arg1 = args[1]
116+
if arg1 isa Symbol && (first(string(arg1)) != '.')
117+
startind = isdefined(mod, arg1) ? 2 : 1
117118
else
118119
startind = 2
119120
end

test/runtests.jl

+53-23
Original file line numberDiff line numberDiff line change
@@ -114,21 +114,21 @@ end
114114

115115
function issue108!(y::Vector{T1}, x::Vector{T2}) where {T1,T2}
116116
@batch for i in eachindex(y)
117-
y[i] = sum(x[j] for j in 2i-oneunit(i):2i)
117+
y[i] = sum(x[j] for j = 2i-oneunit(i):2i)
118118
end
119119
end
120120

121121
function issue108_comment!(data::Vector{T}, functions) where {T}
122122
@batch for i in eachindex(data)
123123
for f in functions
124-
data[i] += f(data[i])
124+
data[i] += f(data[i])
125125
end
126126
end
127127
end
128128

129129
function issue116!(y::Vector{T}, x::Vector{T}) where {T}
130-
@batch for i in 1:length(x)
131-
y[i] = exp(x[i] + one(T))
130+
@batch for i = 1:length(x)
131+
y[i] = exp(x[i] + one(T))
132132
end
133133
end
134134

@@ -280,15 +280,15 @@ end
280280
x = collect(1:12)
281281
y = zeros(6)
282282
issue108!(y, x)
283-
@test y == [sum(x[j] for j in 2i-oneunit(i):2i) for i in 1:6]
284-
285-
functions = [x -> n*x for n in 1:3]
283+
@test y == [sum(x[j] for j = 2i-oneunit(i):2i) for i = 1:6]
284+
285+
functions = [x -> n * x for n = 1:3]
286286
data = rand(100)
287287
data1 = deepcopy(data)
288288
issue108_comment!(data, functions)
289289
for i in eachindex(data1)
290290
for f in functions
291-
data1[i] += f(data1[i])
291+
data1[i] += f(data1[i])
292292
end
293293
end
294294
@test data == data1
@@ -467,7 +467,7 @@ end
467467
end
468468
local7, local8 = let
469469
red = 0
470-
@batch minbatch = 100 stride = true reduction = (+,red) threadlocal = red for i = 0:9
470+
@batch minbatch = 100 stride = true reduction = (+, red) threadlocal = red for i = 0:9
471471
red += 1
472472
threadlocal += 1
473473
end
@@ -480,11 +480,19 @@ end
480480
end
481481
red
482482
end
483-
@test local1==local2==local3==local4==local5==local6==local7==local8==localsr
483+
@test local1 ==
484+
local2 ==
485+
local3 ==
486+
local4 ==
487+
local5 ==
488+
local6 ==
489+
local7 ==
490+
local8 ==
491+
localsr
484492
# check different operations
485493
local9 = let
486494
red = 1.0
487-
@batch reduction = (*,red) for i = 1:100
495+
@batch reduction = (*, red) for i = 1:100
488496
red *= 4i^2 / (4i^2 - 1)
489497
end
490498
2red
@@ -495,7 +503,7 @@ end
495503
red1 = 0
496504
red2 = 0
497505
red3 = 0
498-
@batch reduction = ((+,red1), (+,red2), (+,red3)) for i = 0:9
506+
@batch reduction = ((+, red1), (+, red2), (+, red3)) for i = 0:9
499507
red1 += 1
500508
red2 += 1
501509
red3 -= 1
@@ -507,13 +515,19 @@ end
507515
function f()
508516
n = 1000
509517
threadlocal = 0
510-
@batch minbatch = 10 reduction = (+,threadlocal) for i = 1:n
518+
@batch minbatch = 10 reduction = (+, threadlocal) for i = 1:n
511519
threadlocal += 1
512520
end
513521
return threadlocal
514522
end
515523
allocated(f::F) where {F} = @allocated f()
516-
inferred(f::F) where {F} = try @inferred f(); true catch; false end
524+
inferred(f::F) where {F} =
525+
try
526+
@inferred f()
527+
true
528+
catch
529+
false
530+
end
517531
allocated(f)
518532
@test allocated(f) == 0
519533
@test inferred(f) == true
@@ -524,16 +538,20 @@ end
524538
red2 = false
525539
red3 = typemax(eltype(arr))
526540
red4 = typemin(eltype(arr))
527-
@batch reduction = ((&,red1), (|,red2), (min,red3), (max,red4)) for x in arr
528-
red1 &= x > 0.5
529-
red2 |= x > 0.5
530-
red3 = min(red3, x)
531-
red4 = max(red4, x)
541+
@batch reduction = ((&, red1), (|, red2), (min, red3), (max, red4)) for x in arr
542+
red1 &= x > 0.5
543+
red2 |= x > 0.5
544+
red3 = min(red3, x)
545+
red4 = max(red4, x)
532546
end
533547
red1, red2, red3, red4
534548
end
535-
@test (local13, local14, local15, local16) ==
536-
(mapreduce(x->x>0.5, &, arr), mapreduce(x->x>0.5, |, arr), minimum(arr), maximum(arr))
549+
@test (local13, local14, local15, local16) == (
550+
mapreduce(x -> x > 0.5, &, arr),
551+
mapreduce(x -> x > 0.5, |, arr),
552+
minimum(arr),
553+
maximum(arr),
554+
)
537555
end
538556

539557
@testset "locks and refvalues" begin
@@ -747,14 +765,26 @@ end
747765
return any(find_call_to_nthreads, expr.args)
748766
end
749767

750-
expr = @macroexpand @batch for i in 1:100
768+
expr = @macroexpand @batch for i = 1:100
751769
a[i] = i
752770
end
753771

754772
@test find_call_to_nthreads(expr)
755773
end
756774

775+
776+
function dummy_broadcast!(x)
777+
@batch for i = 1:2
778+
a = (1,) .+ (1,)
779+
x[i] = only(a)
780+
end
781+
end
782+
let x = Vector{Float64}(undef, 2)
783+
dummy_broadcast!(x)
784+
@test x == fill(2.0, 2)
785+
end
786+
757787
if VERSION v"1.6"
758788
println("Package tests complete. Running `Aqua` checks.")
759-
Aqua.test_all(Polyester; deps_compat = (check_extras=false,))
789+
Aqua.test_all(Polyester; deps_compat = (check_extras = false,))
760790
end

0 commit comments

Comments
 (0)