Skip to content

Commit 1c67443

Browse files
committed
use ScopedValues for TestSets
1 parent 2e9b0bb commit 1c67443

File tree

2 files changed

+121
-157
lines changed

2 files changed

+121
-157
lines changed

stdlib/Test/src/Test.jl

+76-112
Original file line numberDiff line numberDiff line change
@@ -1004,7 +1004,6 @@ end
10041004
A simple fallback test set that throws immediately on a failure.
10051005
"""
10061006
struct FallbackTestSet <: AbstractTestSet end
1007-
fallback_testset = FallbackTestSet()
10081007

10091008
struct FallbackTestSetException <: Exception
10101009
msg::String
@@ -1074,9 +1073,9 @@ mutable struct DefaultTestSet <: AbstractTestSet
10741073
end
10751074
function DefaultTestSet(desc::AbstractString; verbose::Bool = false, showtiming::Bool = true, failfast::Union{Nothing,Bool} = nothing, source = nothing)
10761075
if isnothing(failfast)
1077-
# pass failfast state into child testsets
1078-
parent_ts = get_testset()
1079-
if parent_ts isa DefaultTestSet
1076+
# pass failfast state into child testsets
1077+
if get_testset_depth() != 0
1078+
parent_ts = get_testset()
10801079
failfast = parent_ts.failfast
10811080
else
10821081
failfast = false
@@ -1230,9 +1229,8 @@ function finish(ts::DefaultTestSet; print_results::Bool=TESTSET_PRINT_ENABLE[])
12301229
ts.time_end = time()
12311230
# If we are a nested test set, do not print a full summary
12321231
# now - let the parent test set do the printing
1233-
if get_testset_depth() != 0
1234-
# Attach this test set to the parent test set
1235-
parent_ts = get_testset()
1232+
parent_ts = get_testset()
1233+
if !(parent_ts isa FallbackTestSet)
12361234
record(parent_ts, ts)
12371235
return ts
12381236
end
@@ -1639,22 +1637,11 @@ function testset_context(args, ex, source)
16391637
else
16401638
error("Malformed `let` expression is given")
16411639
end
1642-
reverse!(contexts)
1643-
1644-
test_ex = ex.args[2]
1645-
1646-
ex.args[2] = quote
1647-
$(map(contexts) do context
1648-
:($push_testset($(ContextTestSet)($(QuoteNode(context)), $context; $options...)))
1649-
end...)
1650-
try
1651-
$(test_ex)
1652-
finally
1653-
$(map(_->:($pop_testset()), contexts)...)
1654-
end
1640+
test_ex = esc(ex.args[2])
1641+
for context in reverse(contexts)
1642+
test_ex = :(@with_testset(ContextTestSet($(QuoteNode(context)), $(esc(context)); $(esc(options))...), $test_ex))
16551643
end
1656-
1657-
return esc(ex)
1644+
return test_ex
16581645
end
16591646

16601647
"""
@@ -1672,7 +1659,7 @@ function testset_beginend_call(args, tests, source)
16721659
# If we're at the top level we'll default to DefaultTestSet. Otherwise
16731660
# default to the type of the parent testset
16741661
if testsettype === nothing
1675-
testsettype = :(get_testset_depth() == 0 ? DefaultTestSet : typeof(get_testset()))
1662+
testsettype = :(new_testset_type(get_testset()))
16761663
end
16771664

16781665
# Generate a block of code that initializes a new testset, adds
@@ -1687,35 +1674,34 @@ function testset_beginend_call(args, tests, source)
16871674
else
16881675
$(testsettype)($desc; $options...)
16891676
end
1690-
push_testset(ts)
1691-
# we reproduce the logic of guardseed, but this function
1692-
# cannot be used as it changes slightly the semantic of @testset,
1693-
# by wrapping the body in a function
1694-
local default_rng_orig = copy(default_rng())
1695-
local tls_seed_orig = copy(Random.get_tls_seed())
1696-
try
1697-
# default RNG is reset to its state from last `seed!()` to ease reproduce a failed test
1698-
copy!(Random.default_rng(), tls_seed_orig)
1699-
let
1700-
$(esc(tests))
1701-
end
1702-
catch err
1703-
err isa InterruptException && rethrow()
1704-
# something in the test block threw an error. Count that as an
1705-
# error in this test set
1706-
trigger_test_failure_break(err)
1707-
if err isa FailFastError
1708-
get_testset_depth() > 1 ? rethrow() : failfast_print()
1709-
else
1710-
record(ts, Error(:nontest_error, Expr(:tuple), err, Base.current_exceptions(), $(QuoteNode(source))))
1677+
@with_testset ts begin
1678+
# we reproduce the logic of guardseed, but this function
1679+
# cannot be used as it changes slightly the semantic of @testset,
1680+
# by wrapping the body in a function
1681+
local default_rng_orig = copy(default_rng())
1682+
local tls_seed_orig = copy(Random.get_tls_seed())
1683+
try
1684+
# default RNG is reset to its state from last `seed!()` to ease reproduce a failed test
1685+
copy!(Random.default_rng(), tls_seed_orig)
1686+
let
1687+
$(esc(tests))
1688+
end
1689+
catch err
1690+
err isa InterruptException && rethrow()
1691+
# something in the test block threw an error. Count that as an
1692+
# error in this test set
1693+
trigger_test_failure_break(err)
1694+
if err isa FailFastError
1695+
get_testset_depth() > 1 ? rethrow() : failfast_print()
1696+
else
1697+
record(ts, Error(:nontest_error, Expr(:tuple), err, Base.current_exceptions(), $(QuoteNode(source))))
1698+
end
1699+
finally
1700+
copy!(default_rng(), default_rng_orig)
1701+
copy!(Random.get_tls_seed(), tls_seed_orig)
17111702
end
1712-
finally
1713-
copy!(default_rng(), default_rng_orig)
1714-
copy!(Random.get_tls_seed(), tls_seed_orig)
1715-
pop_testset()
1716-
ret = finish(ts)
17171703
end
1718-
ret
1704+
ts
17191705
end
17201706
# preserve outer location if possible
17211707
if tests isa Expr && tests.head === :block && !isempty(tests.args) && tests.args[1] isa LineNumberNode
@@ -1761,7 +1747,7 @@ function testset_forloop(args, testloop, source)
17611747
end
17621748

17631749
if testsettype === nothing
1764-
testsettype = :(get_testset_depth() == 0 ? DefaultTestSet : typeof(get_testset()))
1750+
testsettype = :(new_testset_type(get_testset()))
17651751
end
17661752

17671753
# Uses a similar block as for `@testset`, except that it is
@@ -1771,52 +1757,38 @@ function testset_forloop(args, testloop, source)
17711757
_check_testset($testsettype, $(QuoteNode(testsettype.args[1])))
17721758
# Trick to handle `break` and `continue` in the test code before
17731759
# they can be handled properly by `finally` lowering.
1774-
if !first_iteration
1775-
pop_testset()
1776-
finish_errored = true
1777-
push!(arr, finish(ts))
1778-
finish_errored = false
1779-
copy!(default_rng(), tls_seed_orig)
1780-
end
17811760
ts = if ($testsettype === $DefaultTestSet) && $(isa(source, LineNumberNode))
17821761
$(testsettype)($desc; source=$(QuoteNode(source.file)), $options...)
17831762
else
17841763
$(testsettype)($desc; $options...)
17851764
end
1786-
push_testset(ts)
1787-
first_iteration = false
1788-
try
1789-
$(esc(tests))
1790-
catch err
1791-
err isa InterruptException && rethrow()
1792-
# Something in the test block threw an error. Count that as an
1793-
# error in this test set
1794-
trigger_test_failure_break(err)
1795-
if !isa(err, FailFastError)
1796-
record(ts, Error(:nontest_error, Expr(:tuple), err, Base.current_exceptions(), $(QuoteNode(source))))
1765+
@with_testset ts begin
1766+
try
1767+
# default RNG is reset to its state from last `seed!()` to ease reproduce a failed test
1768+
copy!(Random.default_rng(), tls_seed_orig)
1769+
$(esc(tests))
1770+
catch err
1771+
err isa InterruptException && rethrow()
1772+
# Something in the test block threw an error. Count that as an
1773+
# error in this test set
1774+
trigger_test_failure_break(err)
1775+
if !isa(err, FailFastError)
1776+
record(ts, Error(:nontest_error, Expr(:tuple), err, Base.current_exceptions(), $(QuoteNode(source))))
1777+
end
1778+
finally
1779+
copy!(default_rng(), default_rng_orig)
1780+
copy!(Random.get_tls_seed(), tls_seed_orig)
1781+
push!(arr, ts)
17971782
end
17981783
end
17991784
end
18001785
quote
18011786
local arr = Vector{Any}()
1802-
local first_iteration = true
1803-
local ts
1804-
local finish_errored = false
18051787
local default_rng_orig = copy(default_rng())
18061788
local tls_seed_orig = copy(Random.get_tls_seed())
1807-
copy!(Random.default_rng(), tls_seed_orig)
1808-
try
1809-
let
1810-
$(Expr(:for, Expr(:block, [esc(v) for v in loopvars]...), blk))
1811-
end
1812-
finally
1813-
# Handle `return` in test body
1814-
if !first_iteration && !finish_errored
1815-
pop_testset()
1816-
push!(arr, finish(ts))
1817-
end
1818-
copy!(default_rng(), default_rng_orig)
1819-
copy!(Random.get_tls_seed(), tls_seed_orig)
1789+
local ts
1790+
let
1791+
$(Expr(:for, Expr(:block, [esc(v) for v in loopvars]...), blk))
18201792
end
18211793
arr
18221794
end
@@ -1855,50 +1827,42 @@ end
18551827
#-----------------------------------------------------------------------
18561828
# Various helper methods for test sets
18571829

1830+
const CURRENT_TESTSET = ScopedValue{AbstractTestSet}(FallbackTestSet())
1831+
const TESTSET_DEPTH = ScopedValue{Int}(0)
1832+
1833+
macro with_testset(ts, expr)
1834+
quote
1835+
ts = $(esc(ts))
1836+
Expr(:tryfinally,
1837+
@with(CURRENT_TESTSET => ts, TESTSET_DEPTH => get_testset_depth() + 1, $(esc(expr))),
1838+
finish(ts)
1839+
)
1840+
end
1841+
end
1842+
18581843
"""
18591844
get_testset()
18601845
18611846
Retrieve the active test set from the task's local storage. If no
18621847
test set is active, use the fallback default test set.
18631848
"""
18641849
function get_testset()
1865-
testsets = get(task_local_storage(), :__BASETESTNEXT__, AbstractTestSet[])
1866-
return isempty(testsets) ? fallback_testset : testsets[end]
1850+
something(Base.ScopedValues.get(CURRENT_TESTSET))
18671851
end
18681852

18691853
"""
1870-
push_testset(ts::AbstractTestSet)
1854+
get_testset_depth()
18711855
1872-
Adds the test set to the `task_local_storage`.
1856+
Return the number of active test sets, not including the default test set
18731857
"""
1874-
function push_testset(ts::AbstractTestSet)
1875-
testsets = get(task_local_storage(), :__BASETESTNEXT__, AbstractTestSet[])
1876-
push!(testsets, ts)
1877-
setindex!(task_local_storage(), testsets, :__BASETESTNEXT__)
1858+
function get_testset_depth()
1859+
something(Base.ScopedValues.get(TESTSET_DEPTH))
18781860
end
18791861

1880-
"""
1881-
pop_testset()
18821862

1883-
Pops the last test set added to the `task_local_storage`. If there are no
1884-
active test sets, returns the fallback default test set.
1885-
"""
1886-
function pop_testset()
1887-
testsets = get(task_local_storage(), :__BASETESTNEXT__, AbstractTestSet[])
1888-
ret = isempty(testsets) ? fallback_testset : pop!(testsets)
1889-
setindex!(task_local_storage(), testsets, :__BASETESTNEXT__)
1890-
return ret
1891-
end
1863+
new_testset_type(::AbstractTestSet) = DefaultTestSet
18921864

1893-
"""
1894-
get_testset_depth()
18951865

1896-
Return the number of active test sets, not including the default test set
1897-
"""
1898-
function get_testset_depth()
1899-
testsets = get(task_local_storage(), :__BASETESTNEXT__, AbstractTestSet[])
1900-
return length(testsets)
1901-
end
19021866

19031867
_args_and_call(args...; kwargs...) = (args[1:end-1], kwargs, args[end](args[1:end-1]...; kwargs...))
19041868
_materialize_broadcasted(f, args...) = Broadcast.materialize(Broadcast.broadcasted(f, args...))

test/runtests.jl

+45-45
Original file line numberDiff line numberDiff line change
@@ -379,55 +379,55 @@ cd(@__DIR__) do
379379
Test.TESTSET_PRINT_ENABLE[] = false
380380
o_ts = Test.DefaultTestSet("Overall")
381381
o_ts.time_end = o_ts.time_start + o_ts_duration # manually populate the timing
382-
Test.push_testset(o_ts)
383-
completed_tests = Set{String}()
384-
for (testname, (resp,), duration) in results
385-
push!(completed_tests, testname)
386-
if isa(resp, Test.DefaultTestSet)
387-
resp.time_end = resp.time_start + duration
388-
Test.push_testset(resp)
389-
Test.record(o_ts, resp)
390-
Test.pop_testset()
391-
elseif isa(resp, Test.TestSetException)
392-
fake = Test.DefaultTestSet(testname)
393-
fake.time_end = fake.time_start + duration
394-
for i in 1:resp.pass
395-
Test.record(fake, Test.Pass(:test, nothing, nothing, nothing, LineNumberNode(@__LINE__, @__FILE__)))
396-
end
397-
for i in 1:resp.broken
398-
Test.record(fake, Test.Broken(:test, nothing))
399-
end
400-
for t in resp.errors_and_fails
401-
Test.record(fake, t)
382+
Test.@with_testset o_ts begin
383+
completed_tests = Set{String}()
384+
for (testname, (resp,), duration) in results
385+
push!(completed_tests, testname)
386+
if isa(resp, Test.DefaultTestSet)
387+
resp.time_end = resp.time_start + duration
388+
Test.@with_testset resp begin
389+
Test.record(o_ts, resp)
390+
end
391+
elseif isa(resp, Test.TestSetException)
392+
fake = Test.DefaultTestSet(testname)
393+
fake.time_end = fake.time_start + duration
394+
for i in 1:resp.pass
395+
Test.record(fake, Test.Pass(:test, nothing, nothing, nothing, LineNumberNode(@__LINE__, @__FILE__)))
396+
end
397+
for i in 1:resp.broken
398+
Test.record(fake, Test.Broken(:test, nothing))
399+
end
400+
for t in resp.errors_and_fails
401+
Test.record(fake, t)
402+
end
403+
Test.@with_testset fake begin
404+
Test.record(o_ts, fake)
405+
end
406+
else
407+
if !isa(resp, Exception)
408+
resp = ErrorException(string("Unknown result type : ", typeof(resp)))
409+
end
410+
# If this test raised an exception that is not a remote testset exception,
411+
# i.e. not a RemoteException capturing a TestSetException that means
412+
# the test runner itself had some problem, so we may have hit a segfault,
413+
# deserialization errors or something similar. Record this testset as Errored.
414+
fake = Test.DefaultTestSet(testname)
415+
fake.time_end = fake.time_start + duration
416+
Test.record(fake, Test.Error(:nontest_error, testname, nothing, Any[(resp, [])], LineNumberNode(1)))
417+
Test.@with_testset fake begin
418+
Test.record(o_ts, fake)
419+
end
402420
end
403-
Test.push_testset(fake)
404-
Test.record(o_ts, fake)
405-
Test.pop_testset()
406-
else
407-
if !isa(resp, Exception)
408-
resp = ErrorException(string("Unknown result type : ", typeof(resp)))
421+
end
422+
for test in all_tests
423+
(test in completed_tests) && continue
424+
fake = Test.DefaultTestSet(test)
425+
Test.record(fake, Test.Error(:test_interrupted, test, nothing, [("skipped", [])], LineNumberNode(1)))
426+
Test.@with_testset fake begin
427+
Test.record(o_ts, fake)
409428
end
410-
# If this test raised an exception that is not a remote testset exception,
411-
# i.e. not a RemoteException capturing a TestSetException that means
412-
# the test runner itself had some problem, so we may have hit a segfault,
413-
# deserialization errors or something similar. Record this testset as Errored.
414-
fake = Test.DefaultTestSet(testname)
415-
fake.time_end = fake.time_start + duration
416-
Test.record(fake, Test.Error(:nontest_error, testname, nothing, Any[(resp, [])], LineNumberNode(1)))
417-
Test.push_testset(fake)
418-
Test.record(o_ts, fake)
419-
Test.pop_testset()
420429
end
421430
end
422-
for test in all_tests
423-
(test in completed_tests) && continue
424-
fake = Test.DefaultTestSet(test)
425-
Test.record(fake, Test.Error(:test_interrupted, test, nothing, [("skipped", [])], LineNumberNode(1)))
426-
Test.push_testset(fake)
427-
Test.record(o_ts, fake)
428-
Test.pop_testset()
429-
end
430-
431431
if Base.get_bool_env("CI", false)
432432
@info "Writing test result data to $(@__DIR__)"
433433
write_testset_json_files(@__DIR__, o_ts)

0 commit comments

Comments
 (0)