Skip to content

Commit 35c1f87

Browse files
authored
fix #36230, more efficient lowering of if with a chain of && (#36355)
1 parent 00c41cc commit 35c1f87

File tree

4 files changed

+36
-8
lines changed

4 files changed

+36
-8
lines changed

base/expr.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,9 @@ function findmeta_block(exargs, argsmatch=args->true)
347347
for i = 1:length(exargs)
348348
a = exargs[i]
349349
if isa(a, Expr)
350-
if (a::Expr).head === :meta && argsmatch((a::Expr).args)
350+
if a.head === :meta && argsmatch(a.args)
351351
return i, exargs
352-
elseif (a::Expr).head === :block
352+
elseif a.head === :block
353353
idx, exa = findmeta_block(a.args, argsmatch)
354354
if idx != 0
355355
return idx, exa

base/meta.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,9 @@ true
6262
```
6363
"""
6464
isexpr(@nospecialize(ex), head::Symbol) = isa(ex, Expr) && ex.head === head
65-
isexpr(@nospecialize(ex), heads::Union{Set,Vector,Tuple}) = isa(ex, Expr) && in(ex.head, heads)
66-
isexpr(@nospecialize(ex), heads, n::Int) = isexpr(ex, heads) && length((ex::Expr).args) == n
65+
isexpr(@nospecialize(ex), heads) = isa(ex, Expr) && in(ex.head, heads)
66+
isexpr(@nospecialize(ex), head::Symbol, n::Int) = isa(ex, Expr) && ex.head === head && length(ex.args) == n
67+
isexpr(@nospecialize(ex), heads, n::Int) = isa(ex, Expr) && in(ex.head, heads) && length(ex.args) == n
6768

6869
"""
6970
Meta.show_sexpr([io::IO,], ex)

src/julia-syntax.scm

+21-4
Original file line numberDiff line numberDiff line change
@@ -1895,6 +1895,13 @@
18951895
(else
18961896
(error (string "invalid " syntax-str " \"" (deparse el) "\""))))))))
18971897

1898+
(define (expand-if e)
1899+
(if (and (pair? (cadr e)) (eq? (car (cadr e)) '&&))
1900+
(let ((clauses (cdr (flatten-ex '&& (cadr e)))))
1901+
`(if (&& ,@(map expand-forms clauses))
1902+
,@(map expand-forms (cddr e))))
1903+
(cons (car e) (map expand-forms (cdr e)))))
1904+
18981905
;; move an assignment into the last statement of a block to keep more statements at top level
18991906
(define (sink-assignment lhs rhs)
19001907
(if (and (pair? rhs) (eq? (car rhs) 'block))
@@ -2230,6 +2237,9 @@
22302237
,(expand-forms (cadr e)) ,(expand-forms (caddr e)))
22312238
(map expand-forms e)))
22322239

2240+
'if expand-if
2241+
'elseif expand-if
2242+
22332243
'while
22342244
(lambda (e)
22352245
`(break-block loop-exit
@@ -3709,7 +3719,8 @@ f(x) = yt(x)
37093719
(handler-level 0) ;; exception handler nesting depth
37103720
(catch-token-stack '())) ;; tokens identifying handler enter for current catch blocks
37113721
(define (emit c)
3712-
(set! code (cons c code)))
3722+
(set! code (cons c code))
3723+
c)
37133724
(define (make-label)
37143725
(begin0 label-counter
37153726
(set! label-counter (+ 1 label-counter))))
@@ -3957,15 +3968,21 @@ f(x) = yt(x)
39573968
(compile (cadr e) break-labels value tail)
39583969
#f))
39593970
((if elseif)
3960-
(let ((test `(gotoifnot ,(compile-cond (cadr e) break-labels) _))
3971+
(let ((tests (map (lambda (clause)
3972+
(emit `(gotoifnot ,(compile-cond clause break-labels) _)))
3973+
(if (and (pair? (cadr e)) (eq? (car (cadr e)) '&&))
3974+
(cdadr e)
3975+
(list (cadr e)))))
39613976
(end-jump `(goto _))
39623977
(val (if (and value (not tail)) (new-mutable-var) #f)))
3963-
(emit test)
39643978
(let ((v1 (compile (caddr e) break-labels value tail)))
39653979
(if val (emit-assignment val v1))
39663980
(if (and (not tail) (or (length> e 3) val))
39673981
(emit end-jump))
3968-
(set-car! (cddr test) (make&mark-label))
3982+
(let ((elselabel (make&mark-label)))
3983+
(for-each (lambda (test)
3984+
(set-car! (cddr test) elselabel))
3985+
tests))
39693986
(let ((v2 (if (length> e 3)
39703987
(compile (cadddr e) break-labels value tail)
39713988
'(null))))

test/compiler/inference.jl

+10
Original file line numberDiff line numberDiff line change
@@ -2650,3 +2650,13 @@ end
26502650
f(n) = depth(n, 1)
26512651
end
26522652
@test Base.return_types(TestConstPropRecursion.f, (TestConstPropRecursion.Node,)) == Any[Int]
2653+
2654+
# issue #36230, keeping implications of all conditions in a && chain
2655+
function symcmp36230(vec)
2656+
a, b = vec[1], vec[2]
2657+
if isa(a, Symbol) && isa(b, Symbol)
2658+
return a == b
2659+
end
2660+
return false
2661+
end
2662+
@test Base.return_types(symcmp36230, (Vector{Any},)) == Any[Bool]

0 commit comments

Comments
 (0)