From eb7de41602fbe7a8fe3581c980f37b18a5342c68 Mon Sep 17 00:00:00 2001 From: Miles Lucas Date: Fri, 10 Sep 2021 14:20:52 -1000 Subject: [PATCH 1/9] add chainrulescore and chainrulestestutils --- Project.toml | 1 + test/Project.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index cfa2b44..a2006d5 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Miles Lucas and contributors"] version = "0.3.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" KeywordCalls = "4d827475-d3e4-43d6-abe3-9688362ede9f" diff --git a/test/Project.toml b/test/Project.toml index bbdcf2a..d8d1595 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" From 8e2db5e07d4a60a76e76ea81f153a8ab21cf6e37 Mon Sep 17 00:00:00 2001 From: Miles Lucas Date: Wed, 15 Sep 2021 11:04:47 -1000 Subject: [PATCH 2/9] write up gradients for Gaussian PSF --- src/PSFModels.jl | 3 +++ src/gaussian.jl | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/src/PSFModels.jl b/src/PSFModels.jl index a9f5bbf..c0cda0a 100644 --- a/src/PSFModels.jl +++ b/src/PSFModels.jl @@ -111,9 +111,12 @@ plot(model, axes(other)) # use axes from other array """ module PSFModels +using ChainRulesCore +import ChainRulesCore: frule, rrule using CoordinateTransformations using Distances using KeywordCalls +using LinearAlgebra using SpecialFunctions using StaticArrays diff --git a/src/gaussian.jl b/src/gaussian.jl index 2835e32..509f07d 100644 --- a/src/gaussian.jl +++ b/src/gaussian.jl @@ -61,3 +61,51 @@ function (g::Gaussian{T,<:Union{Tuple,AbstractVector}})(point::AbstractVector) w val = g.amp * exp(GAUSS_PRE * Δ) return convert(T, val) end + +## gradients + +# isotropic +function fgrad(g::Gaussian, point::AbstractVector) + f = g(point) + + xdiff = first(point) - first(g.pos) + ydiff = last(point) - last(g.pos) + dfdpos = -2 * GAUSS_PRE * f / g.fwhm^2 .* SA[xdiff, ydiff] + dfdfwhm = -2 * GAUSS_PRE * f * (xdiff^2 + ydiff^2) / g.fwhm^3 + dfdamp = f / g.amp + return f, dfdpos, dfdfwhm, dfdamp +end + +# short printing +Base.show(io::IO, g::Gaussian{T}) where {T} = print(io, "Gaussian{$T}(pos=$(g.pos), fwhm=$(g.fwhm), amp=$(g.amp))") + +# diagonal +function fgrad(g::Gaussian{T,<:Union{Tuple,AbstractVector}}, point::AbstractVector) where T + f = g(point) + + xdiff = first(point) - first(g.pos) + ydiff = last(point) - last(g.pos) + dfdpos = -2 * GAUSS_PRE * f .* SA[xdiff / first(g.fwhm)^2, ydiff / last(g.fwhm)^2] + dfdfwhm = -2 * GAUSS_PRE * f .* SA[xdiff^2 / first(g.fwhm)^3, ydiff^2 / last(g.fwhm)^3] + dfda = f / g.amp + return f, dfdpos, dfdfwhm, dfda +end + +function frule((Δpsf, Δp), g::Gaussian, point::AbstractVector) + f, dfdpos, dfdfwhm, dfda = fgrad(g, point) + Δf = dot(dfdpos, Δpsf.pos) + dot(dfdfwhm, Δpsf.fwhm) + dfda * Δpsf.amp + Δf -= dot(dfdpos, Δp) + return f, Δf +end + +function rrule(g::G, point::AbstractVector) where {G<:Gaussian} + f, dfdpos, dfdfwhm, dfda = fgrad(g, point) + function Gaussian_pullback(Δf) + ∂pos = dfdpos .* Δf + ∂fwhm = dfdfwhm .* Δf + ∂g = Tangent{G}(pos=∂pos, fwhm=∂fwhm, amp=dfda * Δf, indices=ZeroTangent()) + ∂pos = dfdpos .* -Δf + return ∂g, ∂pos + end + return f, Gaussian_pullback +end From 7ea628d64b308e8a0e2eb6fd221d0453ac4b7745 Mon Sep 17 00:00:00 2001 From: Miles Lucas Date: Wed, 15 Sep 2021 11:06:21 -1000 Subject: [PATCH 3/9] add compat --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index a2006d5..83638ca 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] +ChainRulesCore = "1" CoordinateTransformations = "0.6" Distances = "0.10" KeywordCalls = "0.2" From 1a2546c1bd3684633db6aeb29f521c55d752f72f Mon Sep 17 00:00:00 2001 From: Miles Lucas Date: Wed, 15 Sep 2021 11:06:50 -1000 Subject: [PATCH 4/9] add chainrules testing packages --- test/Project.toml | 3 +++ test/runtests.jl | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index d8d1595..17b3be3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,9 +1,12 @@ [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] +ChainRulesCore = "1" +ChainRulesTestUtils = "1" RecipesBase = "1" StaticArrays = "0.12, 1" diff --git a/test/runtests.jl b/test/runtests.jl index 3d860f9..9b22aaa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,5 @@ +using ChainRulesCore +using ChainRulesTestUtils using PSFModels using PSFModels: Gaussian, Normal, AiryDisk, Moffat using StaticArrays @@ -68,6 +70,10 @@ function test_model_interface(K) @test m(m.pos) ≈ BigFloat(1) end +function test_model_grads(K) + +end + @testset "Model Interface - $K" for K in (Gaussian, AiryDisk, Moffat) test_model_interface(K) end From 6cca53624c085f4adfdd5fabd85fac5c24cd4831 Mon Sep 17 00:00:00 2001 From: Miles Lucas Date: Wed, 15 Sep 2021 11:11:56 -1000 Subject: [PATCH 5/9] add printing and tests for all models --- Project.toml | 1 + src/airy.jl | 3 +++ src/moffat.jl | 4 ++++ test/runtests.jl | 9 ++++++++- 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 83638ca..2a82138 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" KeywordCalls = "4d827475-d3e4-43d6-abe3-9688362ede9f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/src/airy.jl b/src/airy.jl index cb371f2..451dc73 100644 --- a/src/airy.jl +++ b/src/airy.jl @@ -43,6 +43,9 @@ end Base.size(a::AiryDisk) = map(length, a.indices) Base.axes(a::AiryDisk) = a.indices +# short printing +Base.show(io::IO, a::AiryDisk{T}) where {T} = print(io, "AiryDisk{$T}(pos=$(a.pos), fwhm=$(a.fwhm), amp=$(a.amp))") + const rz = 3.8317059702075125 / π function (a::AiryDisk{T})(point::AbstractVector) where T diff --git a/src/moffat.jl b/src/moffat.jl index e4a0563..c780e4a 100644 --- a/src/moffat.jl +++ b/src/moffat.jl @@ -39,6 +39,10 @@ end Base.size(m::Moffat) = map(length, m.indices) Base.axes(m::Moffat) = m.indices +# short printing +Base.show(io::IO, m::Moffat{T}) where {T} = print(io, "Moffat{$T}(pos=$(m.pos), fwhm=$(m.fwhm), amp=$(m.amp), alpha=$(m.alpha))") + + # scalar case function (m::Moffat{T})(point::AbstractVector) where T hwhm = m.fwhm / 2 diff --git a/test/runtests.jl b/test/runtests.jl index 9b22aaa..86c2b48 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -82,11 +82,13 @@ end m = Gaussian(fwhm=10) expected = exp(-4 * log(2) * sum(abs2, SA[1, 2]) / 100) @test m[2, 1] ≈ m(1, 2) ≈ expected - + @test repr(m) == "Gaussian{Float64}(pos=[0, 0], fwhm=10, amp=1.0)" + m = Gaussian(fwhm=(10, 9)) wdist = (1/10)^2 + (2/9)^2 expected = exp(-4 * log(2) * wdist) @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Gaussian{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0)" # test Normal alias @test Normal(fwhm=10) === Gaussian(fwhm=10) @@ -101,6 +103,7 @@ end @test m(-radius, 0) ≈ 0 atol=eps(Float64) @test m(0, radius) ≈ 0 atol=eps(Float64) @test m(0, -radius) ≈ 0 atol=eps(Float64) + @test repr(m) == "AiryDisk{Float64}(pos=[0, 0], fwhm=10, amp=1.0)" m = AiryDisk(fwhm=(10, 9)) r1 = m.fwhm[1] * 1.18677 @@ -110,22 +113,26 @@ end @test m(-r1, 0) ≈ 0 atol=eps(Float64) @test m(0, r2) ≈ 0 atol=eps(Float64) @test m(0, -r2) ≈ 0 atol=eps(Float64) + @test repr(m) == "AiryDisk{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0)" end @testset "Moffat" begin m = Moffat(fwhm=10) expected = inv(1 + sum(abs2, SA[1, 2]) / 25) @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=10, amp=1.0, alpha=1)" m = Moffat(fwhm=(10, 9)) wdist = (1/5)^2 + (2/4.5)^2 expected = inv(1 + wdist) @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0, alpha=1)" # different alpha m = Moffat(fwhm=10, alpha=2) expected = inv(1 + sum(abs2, SA[1, 2]) / 25)^2 @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=10, amp=1.0, alpha=2)" end include("plotting.jl") From 1e5489abe077390dbd5f15ecb4f6a4bc853ead9b Mon Sep 17 00:00:00 2001 From: Miles Lucas Date: Wed, 15 Sep 2021 11:17:48 -1000 Subject: [PATCH 6/9] reorganize tests --- Project.toml | 1 + test/Project.toml | 1 + test/runtests.jl | 121 +++++++++++++++++++++++++++------------------- 3 files changed, 72 insertions(+), 51 deletions(-) diff --git a/Project.toml b/Project.toml index 2a82138..824a7d5 100644 --- a/Project.toml +++ b/Project.toml @@ -20,5 +20,6 @@ Distances = "0.10" KeywordCalls = "0.2" RecipesBase = "1" SpecialFunctions = "0.10, 1" +StableRNGs = "1" StaticArrays = "0.12, 1" julia = "1.5" diff --git a/test/Project.toml b/test/Project.toml index 17b3be3..a309c54 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index 86c2b48..b981da2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,9 +2,12 @@ using ChainRulesCore using ChainRulesTestUtils using PSFModels using PSFModels: Gaussian, Normal, AiryDisk, Moffat +using StableRNGs using StaticArrays using Test +rng = StableRNG(122) + function test_model_interface(K) # test defaults m = @inferred K(fwhm=10) @@ -70,69 +73,85 @@ function test_model_interface(K) @test m(m.pos) ≈ BigFloat(1) end -function test_model_grads(K) - -end - -@testset "Model Interface - $K" for K in (Gaussian, AiryDisk, Moffat) - test_model_interface(K) -end - @testset "Gaussian" begin - m = Gaussian(fwhm=10) - expected = exp(-4 * log(2) * sum(abs2, SA[1, 2]) / 100) - @test m[2, 1] ≈ m(1, 2) ≈ expected - @test repr(m) == "Gaussian{Float64}(pos=[0, 0], fwhm=10, amp=1.0)" + test_model_interface(Gaussian) + + @testset "isotropic" begin + m = Gaussian(fwhm=10) + expected = exp(-4 * log(2) * sum(abs2, SA[1, 2]) / 100) + @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Gaussian{Float64}(pos=[0, 0], fwhm=10, amp=1.0)" + end - m = Gaussian(fwhm=(10, 9)) - wdist = (1/10)^2 + (2/9)^2 - expected = exp(-4 * log(2) * wdist) - @test m[2, 1] ≈ m(1, 2) ≈ expected - @test repr(m) == "Gaussian{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0)" + @testset "diagonal" begin + m = Gaussian(fwhm=(10, 9)) + wdist = (1/10)^2 + (2/9)^2 + expected = exp(-4 * log(2) * wdist) + @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Gaussian{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0)" + end # test Normal alias @test Normal(fwhm=10) === Gaussian(fwhm=10) + + @testset "gradients" begin + psf_iso = Gaussian(fwhm=10) + psf_tang = Tangent{Gaussian}(fwhm=rand(rng), pos=rand(rng, 2), amp=rand(rng), indices=ZeroTangent()) + end end @testset "AiryDisk" begin - m = AiryDisk(fwhm=10) - radius = m.fwhm * 1.18677 - # first radius is 0 - @test m(radius, 0) ≈ 0 atol=eps(Float64) - @test m(-radius, 0) ≈ 0 atol=eps(Float64) - @test m(0, radius) ≈ 0 atol=eps(Float64) - @test m(0, -radius) ≈ 0 atol=eps(Float64) - @test repr(m) == "AiryDisk{Float64}(pos=[0, 0], fwhm=10, amp=1.0)" - - m = AiryDisk(fwhm=(10, 9)) - r1 = m.fwhm[1] * 1.18677 - r2 = m.fwhm[2] * 1.18677 - # first radius is 0 - @test m(r1, 0) ≈ 0 atol=eps(Float64) - @test m(-r1, 0) ≈ 0 atol=eps(Float64) - @test m(0, r2) ≈ 0 atol=eps(Float64) - @test m(0, -r2) ≈ 0 atol=eps(Float64) - @test repr(m) == "AiryDisk{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0)" + test_model_interface(AiryDisk) + + @testset "isotropic" begin + m = AiryDisk(fwhm=10) + radius = m.fwhm * 1.18677 + # first radius is 0 + @test m(radius, 0) ≈ 0 atol=eps(Float64) + @test m(-radius, 0) ≈ 0 atol=eps(Float64) + @test m(0, radius) ≈ 0 atol=eps(Float64) + @test m(0, -radius) ≈ 0 atol=eps(Float64) + @test repr(m) == "AiryDisk{Float64}(pos=[0, 0], fwhm=10, amp=1.0)" + end + + @testset "diagonal" begin + m = AiryDisk(fwhm=(10, 9)) + r1 = m.fwhm[1] * 1.18677 + r2 = m.fwhm[2] * 1.18677 + # first radius is 0 + @test m(r1, 0) ≈ 0 atol=eps(Float64) + @test m(-r1, 0) ≈ 0 atol=eps(Float64) + @test m(0, r2) ≈ 0 atol=eps(Float64) + @test m(0, -r2) ≈ 0 atol=eps(Float64) + @test repr(m) == "AiryDisk{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0)" + end end @testset "Moffat" begin - m = Moffat(fwhm=10) - expected = inv(1 + sum(abs2, SA[1, 2]) / 25) - @test m[2, 1] ≈ m(1, 2) ≈ expected - @test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=10, amp=1.0, alpha=1)" - - m = Moffat(fwhm=(10, 9)) - wdist = (1/5)^2 + (2/4.5)^2 - expected = inv(1 + wdist) - @test m[2, 1] ≈ m(1, 2) ≈ expected - @test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0, alpha=1)" - - # different alpha - m = Moffat(fwhm=10, alpha=2) - expected = inv(1 + sum(abs2, SA[1, 2]) / 25)^2 - @test m[2, 1] ≈ m(1, 2) ≈ expected - @test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=10, amp=1.0, alpha=2)" + test_model_interface(Moffat) + + @testset "isotropic" begin + m = Moffat(fwhm=10) + expected = inv(1 + sum(abs2, SA[1, 2]) / 25) + @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=10, amp=1.0, alpha=1)" + end + + @testset "diagonal" begin + m = Moffat(fwhm=(10, 9)) + wdist = (1/5)^2 + (2/4.5)^2 + expected = inv(1 + wdist) + @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0, alpha=1)" + end + + @testset "alpha" begin + m = Moffat(fwhm=10, alpha=2) + expected = inv(1 + sum(abs2, SA[1, 2]) / 25)^2 + @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=10, amp=1.0, alpha=2)" + end end include("plotting.jl") From 367fd350435631733ae6e9deb0f47c655b0674b9 Mon Sep 17 00:00:00 2001 From: Miles Lucas Date: Wed, 15 Sep 2021 11:19:39 -1000 Subject: [PATCH 7/9] fix project mistype --- Project.toml | 1 - test/Project.toml | 1 + test/runtests.jl | 8 ++++++++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 824a7d5..2a82138 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,5 @@ Distances = "0.10" KeywordCalls = "0.2" RecipesBase = "1" SpecialFunctions = "0.10, 1" -StableRNGs = "1" StaticArrays = "0.12, 1" julia = "1.5" diff --git a/test/Project.toml b/test/Project.toml index a309c54..0481fd7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,4 +10,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ChainRulesCore = "1" ChainRulesTestUtils = "1" RecipesBase = "1" +StableRNGs = "1" StaticArrays = "0.12, 1" diff --git a/test/runtests.jl b/test/runtests.jl index b981da2..6d9ff8a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -97,6 +97,14 @@ end @testset "gradients" begin psf_iso = Gaussian(fwhm=10) psf_tang = Tangent{Gaussian}(fwhm=rand(rng), pos=rand(rng, 2), amp=rand(rng), indices=ZeroTangent()) + point = [1, 2] + test_frule(psf_iso ⊢ psf_tang, point) + test_rrule(psf_iso ⊢ psf_tang, point) + + psf_diag = Gaussian(fwhm=[10, 8]) + psf_tang = Tangent{Gaussian}(fwhm=rand(rng, 2), pos=rand(rng, 2), amp=rand(rng), indices=ZeroTangent()) + test_frule(psf_diag ⊢ psf_tang, point) + test_rrule(psf_diag ⊢ psf_tang, point) end end From 1b026c588f826f2605893f0b38f6efb20ba7cde0 Mon Sep 17 00:00:00 2001 From: Miles Lucas Date: Wed, 15 Sep 2021 11:35:41 -1000 Subject: [PATCH 8/9] add debug mode --- test/runtests.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6d9ff8a..c46dd89 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,8 @@ using StableRNGs using StaticArrays using Test +ChainRulesCore.debug_mode() = true + rng = StableRNG(122) function test_model_interface(K) @@ -95,13 +97,14 @@ end @test Normal(fwhm=10) === Gaussian(fwhm=10) @testset "gradients" begin - psf_iso = Gaussian(fwhm=10) + # have to make sure PSFs are all floating point so tangents don't have type issues + psf_iso = Gaussian(fwhm=10.0, pos=zeros(2)) psf_tang = Tangent{Gaussian}(fwhm=rand(rng), pos=rand(rng, 2), amp=rand(rng), indices=ZeroTangent()) - point = [1, 2] + point = Float64[1, 2] test_frule(psf_iso ⊢ psf_tang, point) test_rrule(psf_iso ⊢ psf_tang, point) - psf_diag = Gaussian(fwhm=[10, 8]) + psf_diag = Gaussian(fwhm=Float64[10, 8], pos=zeros(2)) psf_tang = Tangent{Gaussian}(fwhm=rand(rng, 2), pos=rand(rng, 2), amp=rand(rng), indices=ZeroTangent()) test_frule(psf_diag ⊢ psf_tang, point) test_rrule(psf_diag ⊢ psf_tang, point) From d46340adbaf15be714179e8401f60ee9379d1368 Mon Sep 17 00:00:00 2001 From: Miles Lucas Date: Thu, 16 Sep 2021 13:08:21 -1000 Subject: [PATCH 9/9] add finitedifferences snippet for frule tests --- src/gaussian.jl | 8 ++++---- test/Project.toml | 2 ++ test/runtests.jl | 6 ++++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/gaussian.jl b/src/gaussian.jl index 509f07d..c1522f0 100644 --- a/src/gaussian.jl +++ b/src/gaussian.jl @@ -45,6 +45,9 @@ end Base.size(g::Gaussian) = map(length, g.indices) Base.axes(g::Gaussian) = g.indices +# short printing +Base.show(io::IO, g::Gaussian{T}) where {T} = print(io, "Gaussian{$T}(pos=$(g.pos), fwhm=$(g.fwhm), amp=$(g.amp))") + # Gaussian pre-factor for normalizing the exponential const GAUSS_PRE = -4 * log(2) @@ -76,9 +79,6 @@ function fgrad(g::Gaussian, point::AbstractVector) return f, dfdpos, dfdfwhm, dfdamp end -# short printing -Base.show(io::IO, g::Gaussian{T}) where {T} = print(io, "Gaussian{$T}(pos=$(g.pos), fwhm=$(g.fwhm), amp=$(g.amp))") - # diagonal function fgrad(g::Gaussian{T,<:Union{Tuple,AbstractVector}}, point::AbstractVector) where T f = g(point) @@ -103,7 +103,7 @@ function rrule(g::G, point::AbstractVector) where {G<:Gaussian} function Gaussian_pullback(Δf) ∂pos = dfdpos .* Δf ∂fwhm = dfdfwhm .* Δf - ∂g = Tangent{G}(pos=∂pos, fwhm=∂fwhm, amp=dfda * Δf, indices=ZeroTangent()) + ∂g = Tangent{G}(pos=∂pos, fwhm=∂fwhm, amp=dfda * Δf, indices=NoTangent()) ∂pos = dfdpos .* -Δf return ∂g, ∂pos end diff --git a/test/Project.toml b/test/Project.toml index 0481fd7..6eca7ce 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -9,6 +10,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] ChainRulesCore = "1" ChainRulesTestUtils = "1" +FiniteDifferences = "0.12" RecipesBase = "1" StableRNGs = "1" StaticArrays = "0.12, 1" diff --git a/test/runtests.jl b/test/runtests.jl index c46dd89..a02f5c3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using ChainRulesCore using ChainRulesTestUtils +using FiniteDifferences using PSFModels using PSFModels: Gaussian, Normal, AiryDisk, Moffat using StableRNGs @@ -97,15 +98,16 @@ end @test Normal(fwhm=10) === Gaussian(fwhm=10) @testset "gradients" begin + FiniteDifferences.to_vec(x::Integer) = Bool[], _ -> x # have to make sure PSFs are all floating point so tangents don't have type issues psf_iso = Gaussian(fwhm=10.0, pos=zeros(2)) - psf_tang = Tangent{Gaussian}(fwhm=rand(rng), pos=rand(rng, 2), amp=rand(rng), indices=ZeroTangent()) + psf_tang = Tangent{Gaussian}(fwhm=rand(rng), pos=rand(rng, 2), amp=rand(rng), indices=NoTangent()) point = Float64[1, 2] test_frule(psf_iso ⊢ psf_tang, point) test_rrule(psf_iso ⊢ psf_tang, point) psf_diag = Gaussian(fwhm=Float64[10, 8], pos=zeros(2)) - psf_tang = Tangent{Gaussian}(fwhm=rand(rng, 2), pos=rand(rng, 2), amp=rand(rng), indices=ZeroTangent()) + psf_tang = Tangent{Gaussian}(fwhm=rand(rng, 2), pos=rand(rng, 2), amp=rand(rng), indices=NoTangent()) test_frule(psf_diag ⊢ psf_tang, point) test_rrule(psf_diag ⊢ psf_tang, point) end