Skip to content

Commit 7cb87d5

Browse files
authored
Merge pull request #116 from cscherrer/static-integer
Use StaticInteger instead of StaticInt and add tests for static util functions
2 parents 5bba40f + 4dac4fc commit 7cb87d5

10 files changed

+47
-12
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.14.6"
4+
version = "0.14.7"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/MeasureBase.jl

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ const Pretty = PrettyPrinting
3434
using ChainRulesCore
3535
import FillArrays
3636
using Static
37+
using Static: StaticInteger
3738
using FunctionChains
3839

3940
export

src/density-core.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ end
143143
@generated function _logdensity_rel(
144144
μs::Tμ,
145145
νs::Tν,
146-
::Tuple{StaticInt{M},StaticInt{N}},
146+
::Tuple{<:StaticInteger{M},<:StaticInteger{N}},
147147
x::X,
148148
) where {Tμ,Tν,M,N,X}
149149
= schema(Tμ)

src/standard/stdmeasure.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ end
3434

3535
# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}):
3636

37-
_std_measure(::Type{M}, ::StaticInt{1}) where {M<:StdMeasure} = M()
37+
_std_measure(::Type{M}, ::StaticInteger{1}) where {M<:StdMeasure} = M()
3838
_std_measure(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof
3939
_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ))
4040

src/static.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
MeasureBase.IntegerLike
33
4-
Equivalent to `Union{Integer,Static.StaticInt}`.
4+
Equivalent to `Union{Integer,Static.StaticInteger}`.
55
"""
6-
const IntegerLike = Union{Integer,Static.StaticInt}
6+
const IntegerLike = Union{Integer,Static.StaticInteger}
77

88
"""
99
MeasureBase.one_to(n::IntegerLike)
@@ -14,7 +14,7 @@ Returns an instance of `Base.OneTo` or `Static.SOneTo`, depending
1414
on the type of `n`.
1515
"""
1616
@inline one_to(n::Integer) = Base.OneTo(n)
17-
@inline one_to(::Static.StaticInt{N}) where {N} = Static.SOneTo{N}()
17+
@inline one_to(::Static.StaticInteger{N}) where {N} = Static.SOneTo{N}()
1818

1919
_dynamic(x::Number) = dynamic(x)
2020
_dynamic(::Static.SOneTo{N}) where {N} = Base.OneTo(N)
@@ -49,7 +49,7 @@ Returns the length of `x` as a dynamic or static integer.
4949
"""
5050
maybestatic_length(x) = length(x)
5151
maybestatic_length(x::AbstractUnitRange) = length(x)
52-
function maybestatic_length(::Static.OptionallyStaticUnitRange{StaticInt{A},StaticInt{B}}) where {A,B}
52+
function maybestatic_length(::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}) where {A,B}
5353
StaticInt{B - A + 1}()
5454
end
5555

src/transport.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,14 @@ _origin_depth_pullback(ΔΩ) = NoTangent(), NoTangent()
139139
ChainRulesCore.rrule(::typeof(_origin_depth), ν) = _origin_depth(ν), _origin_depth_pullback
140140

141141
# If both both measures have no origin:
142-
function _transport_between_origins(ν, ::StaticInt{0}, ::StaticInt{0}, μ, x)
142+
function _transport_between_origins(ν, ::StaticInteger{0}, ::StaticInteger{0}, μ, x)
143143
_transport_with_intermediate(ν, _transport_intermediate(ν, μ), μ, x)
144144
end
145145

146146
@generated function _transport_between_origins(
147147
ν,
148-
::StaticInt{n_ν},
149-
::StaticInt{n_μ},
148+
::StaticInteger{n_ν},
149+
::StaticInteger{n_μ},
150150
μ,
151151
x,
152152
) where {n_ν,n_μ}
@@ -188,7 +188,7 @@ end
188188

189189
@inline _transport_intermediate(ν, μ) = _transport_intermediate(getdof(ν), getdof(μ))
190190
@inline _transport_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ
191-
@inline _transport_intermediate(::StaticInt{1}, ::StaticInt{1}) = StdUniform()
191+
@inline _transport_intermediate(::StaticInteger{1}, ::StaticInteger{1}) = StdUniform()
192192

193193
_call_transport_def(ν, μ, x) = transport_def(ν, μ, x)
194194
_call_transport_def(::Any, ::Any, x::NoTransportOrigin) = x

src/utils.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ repeatedly until there's no change. That's what this does.
3333
_rootmeasure(μ, static(n))
3434
end
3535

36-
@generated function _rootmeasure(μ, ::StaticInt{n}) where {n}
36+
@generated function _rootmeasure(μ, ::StaticInteger{n}) where {n}
3737
q = quote end
3838
foreach(1:n) do _
3939
push!(q.args, :(μ = basemeasure(μ)))

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
44
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
55
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
6+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
67
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
78
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

test/runtests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ using MeasureBase: test_interface, test_smf
1010
using Aqua
1111
Aqua.test_all(MeasureBase; ambiguities = false)
1212

13+
include("static.jl")
14+
1315
# Aqua._test_ambiguities(
1416
# Aqua.aspkgids(MeasureBase);
1517
# exclude = [LogarithmicNumbers.Logarithmic],

test/static.jl

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using Test
2+
3+
import MeasureBase
4+
5+
import Static
6+
using Static: static
7+
import FillArrays
8+
9+
@testset "static" begin
10+
@test 2 isa MeasureBase.IntegerLike
11+
@test static(2) isa MeasureBase.IntegerLike
12+
@test true isa MeasureBase.IntegerLike
13+
@test static(true) isa MeasureBase.IntegerLike
14+
15+
@test @inferred(MeasureBase.one_to(7)) isa Base.OneTo
16+
@test @inferred(MeasureBase.one_to(7)) == 1:7
17+
@test @inferred(MeasureBase.one_to(static(7))) isa Static.SOneTo
18+
@test @inferred(MeasureBase.one_to(static(7))) == static(1):static(7)
19+
20+
@test @inferred(MeasureBase.fill_with(4.2, (7,))) == FillArrays.Fill(4.2, 7)
21+
@test @inferred(MeasureBase.fill_with(4.2, (static(7),))) == FillArrays.Fill(4.2, 7)
22+
@test @inferred(MeasureBase.fill_with(4.2, (3, static(7)))) == FillArrays.Fill(4.2, 3, 7)
23+
@test @inferred(MeasureBase.fill_with(4.2, (3:7,))) == FillArrays.Fill(4.2, (3:7,))
24+
@test @inferred(MeasureBase.fill_with(4.2, (static(3):static(7),))) == FillArrays.Fill(4.2, (3:7,))
25+
@test @inferred(MeasureBase.fill_with(4.2, (3:7, static(2):static(5)))) == FillArrays.Fill(4.2, (3:7, 2:5))
26+
27+
@test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) isa Int
28+
@test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) == 7
29+
@test MeasureBase.maybestatic_length(MeasureBase.one_to(static(7))) isa Static.StaticInt
30+
@test MeasureBase.maybestatic_length(MeasureBase.one_to(static(7))) == static(7)
31+
end

0 commit comments

Comments
 (0)