diff --git a/src/dual.jl b/src/dual.jl index 2ca4683f..157daeaa 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -353,7 +353,9 @@ end @inline Base.zero(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(zero(V), zero(Partials{N,V})) @inline Base.one(d::Dual) = one(typeof(d)) -@inline Base.one(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(one(V), zero(Partials{N,V})) +@inline Base.oneunit(d::Dual) = oneunit(typeof(d)) +@inline Base.one(::Type{Dual{T,V,N}}) where {T,V,N} = one(V) +@inline Base.oneunit(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T}(oneunit(V), zero(Partials{N,V})) @inline function Base.Int(d::Dual) all(iszero, partials(d)) || throw(InexactError(:Int, Int, d)) diff --git a/src/partials.jl b/src/partials.jl index a5316e3e..ddec4ebc 100644 --- a/src/partials.jl +++ b/src/partials.jl @@ -7,7 +7,7 @@ end ############################## @generated function single_seed(::Type{Partials{N,V}}, ::Val{i}) where {N,V,i} - ex = Expr(:tuple, [ifelse(i === j, :(one(V)), :(zero(V))) for j in 1:N]...) + ex = Expr(:tuple, [ifelse(i === j, :(oneunit(V)), :(zero(V))) for j in 1:N]...) return :(Partials($(ex))) end diff --git a/test/DualTest.jl b/test/DualTest.jl index 938cd0e6..f48970d3 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -210,10 +210,15 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false @test zero(NESTED_FDNUM) === Dual{TestTag()}(Dual{TestTag()}(zero(PRIMAL), zero(M_PARTIALS)), zero(NESTED_PARTIALS)) @test zero(typeof(NESTED_FDNUM)) === Dual{TestTag()}(Dual{TestTag()}(zero(V), zero(Partials{M,V})), zero(Partials{N,Dual{TestTag(),V,M}})) - @test one(FDNUM) === Dual{TestTag()}(one(PRIMAL), zero(PARTIALS)) - @test one(typeof(FDNUM)) === Dual{TestTag()}(one(V), zero(Partials{N,V})) - @test one(NESTED_FDNUM) === Dual{TestTag()}(Dual{TestTag()}(one(PRIMAL), zero(M_PARTIALS)), zero(NESTED_PARTIALS)) - @test one(typeof(NESTED_FDNUM)) === Dual{TestTag()}(Dual{TestTag()}(one(V), zero(Partials{M,V})), zero(Partials{N,Dual{TestTag(),V,M}})) + @test one(FDNUM) === one(value(FDNUM)) + @test one(typeof(FDNUM)) === one(typeof(value(FDNUM))) + @test one(NESTED_FDNUM) === one(value(NESTED_FDNUM)) + @test one(typeof(NESTED_FDNUM)) === one(typeof(value(NESTED_FDNUM))) + + @test oneunit(FDNUM) === Dual{TestTag()}(one(PRIMAL), zero(PARTIALS)) + @test oneunit(typeof(FDNUM)) === Dual{TestTag()}(one(V), zero(Partials{N,V})) + @test oneunit(NESTED_FDNUM) === Dual{TestTag()}(Dual{TestTag()}(one(PRIMAL), zero(M_PARTIALS)), zero(NESTED_PARTIALS)) + @test oneunit(typeof(NESTED_FDNUM)) === Dual{TestTag()}(Dual{TestTag()}(one(V), zero(Partials{M,V})), zero(Partials{N,Dual{TestTag(),V,M}})) if V <: Integer @test rand(samerng(), FDNUM) == rand(samerng(), value(FDNUM)) @@ -233,11 +238,11 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false #------------# @test ForwardDiff.isconstant(zero(FDNUM)) - @test ForwardDiff.isconstant(one(FDNUM)) + @test ForwardDiff.isconstant(oneunit(FDNUM)) @test ForwardDiff.isconstant(FDNUM) == (N == 0) @test ForwardDiff.isconstant(zero(NESTED_FDNUM)) - @test ForwardDiff.isconstant(one(NESTED_FDNUM)) + @test ForwardDiff.isconstant(oneunit(NESTED_FDNUM)) @test ForwardDiff.isconstant(NESTED_FDNUM) == (N == 0) # Recall that FDNUM = Dual{TestTag()}(PRIMAL, PARTIALS) has N partials, @@ -476,15 +481,15 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false if arity == 1 deriv = DiffRules.diffrule(M, f, :x) modifier = if in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) - one(V) + oneunit(V) elseif in(f, (:log1mexp, :log2mexp)) - -one(V) + -oneunit(V) else zero(V) end @eval begin x = rand() + $modifier - dx = @inferred $M.$f(Dual{TestTag()}(x, one(x))) + dx = @inferred $M.$f(Dual{TestTag()}(x, oneunit(x))) actualval = $M.$f(x) @assert actualval isa Real || actualval isa Complex if actualval isa Real @@ -510,8 +515,8 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false end @eval begin x, y = $x, $y - dx = @inferred $M.$f(Dual{TestTag()}(x, one(x)), y) - dy = @inferred $M.$f(x, Dual{TestTag()}(y, one(y))) + dx = @inferred $M.$f(Dual{TestTag()}(x, oneunit(x)), y) + dy = @inferred $M.$f(x, Dual{TestTag()}(y, oneunit(y))) actualdx = $(derivs[1]) actualdy = $(derivs[2]) actualval = $M.$f(x, y)