Skip to content

add ChainRulesCore rules #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@ authors = ["Miles Lucas <[email protected]> and contributors"]
version = "0.3.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
KeywordCalls = "4d827475-d3e4-43d6-abe3-9688362ede9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
ChainRulesCore = "1"
CoordinateTransformations = "0.6"
Distances = "0.10"
KeywordCalls = "0.2"
Expand Down
3 changes: 3 additions & 0 deletions src/PSFModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,12 @@ plot(model, axes(other)) # use axes from other array
"""
module PSFModels

using ChainRulesCore
import ChainRulesCore: frule, rrule
using CoordinateTransformations
using Distances
using KeywordCalls
using LinearAlgebra
using SpecialFunctions
using StaticArrays

Expand Down
3 changes: 3 additions & 0 deletions src/airy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ end
Base.size(a::AiryDisk) = map(length, a.indices)
Base.axes(a::AiryDisk) = a.indices

# short printing
Base.show(io::IO, a::AiryDisk{T}) where {T} = print(io, "AiryDisk{$T}(pos=$(a.pos), fwhm=$(a.fwhm), amp=$(a.amp))")

const rz = 3.8317059702075125 / π

function (a::AiryDisk{T})(point::AbstractVector) where T
Expand Down
48 changes: 48 additions & 0 deletions src/gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ end
Base.size(g::Gaussian) = map(length, g.indices)
Base.axes(g::Gaussian) = g.indices

# short printing
Base.show(io::IO, g::Gaussian{T}) where {T} = print(io, "Gaussian{$T}(pos=$(g.pos), fwhm=$(g.fwhm), amp=$(g.amp))")

# Gaussian pre-factor for normalizing the exponential
const GAUSS_PRE = -4 * log(2)

Expand All @@ -61,3 +64,48 @@ function (g::Gaussian{T,<:Union{Tuple,AbstractVector}})(point::AbstractVector) w
val = g.amp * exp(GAUSS_PRE * Δ)
return convert(T, val)
end

## gradients

# isotropic
function fgrad(g::Gaussian, point::AbstractVector)
f = g(point)

xdiff = first(point) - first(g.pos)
ydiff = last(point) - last(g.pos)
dfdpos = -2 * GAUSS_PRE * f / g.fwhm^2 .* SA[xdiff, ydiff]
dfdfwhm = -2 * GAUSS_PRE * f * (xdiff^2 + ydiff^2) / g.fwhm^3
dfdamp = f / g.amp
return f, dfdpos, dfdfwhm, dfdamp
end

# diagonal
function fgrad(g::Gaussian{T,<:Union{Tuple,AbstractVector}}, point::AbstractVector) where T
f = g(point)

xdiff = first(point) - first(g.pos)
ydiff = last(point) - last(g.pos)
dfdpos = -2 * GAUSS_PRE * f .* SA[xdiff / first(g.fwhm)^2, ydiff / last(g.fwhm)^2]
dfdfwhm = -2 * GAUSS_PRE * f .* SA[xdiff^2 / first(g.fwhm)^3, ydiff^2 / last(g.fwhm)^3]
dfda = f / g.amp
return f, dfdpos, dfdfwhm, dfda
end

function frule((Δpsf, Δp), g::Gaussian, point::AbstractVector)
f, dfdpos, dfdfwhm, dfda = fgrad(g, point)
Δf = dot(dfdpos, Δpsf.pos) + dot(dfdfwhm, Δpsf.fwhm) + dfda * Δpsf.amp
Δf -= dot(dfdpos, Δp)
return f, Δf
end

function rrule(g::G, point::AbstractVector) where {G<:Gaussian}
f, dfdpos, dfdfwhm, dfda = fgrad(g, point)
function Gaussian_pullback(Δf)
∂pos = dfdpos .* Δf
∂fwhm = dfdfwhm .* Δf
∂g = Tangent{G}(pos=∂pos, fwhm=∂fwhm, amp=dfda * Δf, indices=NoTangent())
∂pos = dfdpos .* -Δf
return ∂g, ∂pos
end
return f, Gaussian_pullback
end
4 changes: 4 additions & 0 deletions src/moffat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ end
Base.size(m::Moffat) = map(length, m.indices)
Base.axes(m::Moffat) = m.indices

# short printing
Base.show(io::IO, m::Moffat{T}) where {T} = print(io, "Moffat{$T}(pos=$(m.pos), fwhm=$(m.fwhm), amp=$(m.amp), alpha=$(m.alpha))")


# scalar case
function (m::Moffat{T})(point::AbstractVector) where T
hwhm = m.fwhm / 2
Expand Down
8 changes: 8 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
FiniteDifferences = "0.12"
RecipesBase = "1"
StableRNGs = "1"
StaticArrays = "0.12, 1"
127 changes: 86 additions & 41 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
using ChainRulesCore
using ChainRulesTestUtils
using FiniteDifferences
using PSFModels
using PSFModels: Gaussian, Normal, AiryDisk, Moffat
using StableRNGs
using StaticArrays
using Test

ChainRulesCore.debug_mode() = true

rng = StableRNG(122)

function test_model_interface(K)
# test defaults
m = @inferred K(fwhm=10)
Expand Down Expand Up @@ -68,58 +76,95 @@ function test_model_interface(K)
@test m(m.pos) ≈ BigFloat(1)
end

@testset "Model Interface - $K" for K in (Gaussian, AiryDisk, Moffat)
test_model_interface(K)
end

@testset "Gaussian" begin
m = Gaussian(fwhm=10)
expected = exp(-4 * log(2) * sum(abs2, SA[1, 2]) / 100)
@test m[2, 1] ≈ m(1, 2) ≈ expected

m = Gaussian(fwhm=(10, 9))
wdist = (1/10)^2 + (2/9)^2
expected = exp(-4 * log(2) * wdist)
@test m[2, 1] ≈ m(1, 2) ≈ expected
test_model_interface(Gaussian)

@testset "isotropic" begin
m = Gaussian(fwhm=10)
expected = exp(-4 * log(2) * sum(abs2, SA[1, 2]) / 100)
@test m[2, 1] ≈ m(1, 2) ≈ expected
@test repr(m) == "Gaussian{Float64}(pos=[0, 0], fwhm=10, amp=1.0)"
end

@testset "diagonal" begin
m = Gaussian(fwhm=(10, 9))
wdist = (1/10)^2 + (2/9)^2
expected = exp(-4 * log(2) * wdist)
@test m[2, 1] ≈ m(1, 2) ≈ expected
@test repr(m) == "Gaussian{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0)"
end

# test Normal alias
@test Normal(fwhm=10) === Gaussian(fwhm=10)

@testset "gradients" begin
FiniteDifferences.to_vec(x::Integer) = Bool[], _ -> x
# have to make sure PSFs are all floating point so tangents don't have type issues
psf_iso = Gaussian(fwhm=10.0, pos=zeros(2))
psf_tang = Tangent{Gaussian}(fwhm=rand(rng), pos=rand(rng, 2), amp=rand(rng), indices=NoTangent())
point = Float64[1, 2]
test_frule(psf_iso ⊢ psf_tang, point)
test_rrule(psf_iso ⊢ psf_tang, point)

psf_diag = Gaussian(fwhm=Float64[10, 8], pos=zeros(2))
psf_tang = Tangent{Gaussian}(fwhm=rand(rng, 2), pos=rand(rng, 2), amp=rand(rng), indices=NoTangent())
test_frule(psf_diag ⊢ psf_tang, point)
test_rrule(psf_diag ⊢ psf_tang, point)
end
end


@testset "AiryDisk" begin
m = AiryDisk(fwhm=10)
radius = m.fwhm * 1.18677
# first radius is 0
@test m(radius, 0) ≈ 0 atol=eps(Float64)
@test m(-radius, 0) ≈ 0 atol=eps(Float64)
@test m(0, radius) ≈ 0 atol=eps(Float64)
@test m(0, -radius) ≈ 0 atol=eps(Float64)

m = AiryDisk(fwhm=(10, 9))
r1 = m.fwhm[1] * 1.18677
r2 = m.fwhm[2] * 1.18677
# first radius is 0
@test m(r1, 0) ≈ 0 atol=eps(Float64)
@test m(-r1, 0) ≈ 0 atol=eps(Float64)
@test m(0, r2) ≈ 0 atol=eps(Float64)
@test m(0, -r2) ≈ 0 atol=eps(Float64)
test_model_interface(AiryDisk)

@testset "isotropic" begin
m = AiryDisk(fwhm=10)
radius = m.fwhm * 1.18677
# first radius is 0
@test m(radius, 0) ≈ 0 atol=eps(Float64)
@test m(-radius, 0) ≈ 0 atol=eps(Float64)
@test m(0, radius) ≈ 0 atol=eps(Float64)
@test m(0, -radius) ≈ 0 atol=eps(Float64)
@test repr(m) == "AiryDisk{Float64}(pos=[0, 0], fwhm=10, amp=1.0)"
end

@testset "diagonal" begin
m = AiryDisk(fwhm=(10, 9))
r1 = m.fwhm[1] * 1.18677
r2 = m.fwhm[2] * 1.18677
# first radius is 0
@test m(r1, 0) ≈ 0 atol=eps(Float64)
@test m(-r1, 0) ≈ 0 atol=eps(Float64)
@test m(0, r2) ≈ 0 atol=eps(Float64)
@test m(0, -r2) ≈ 0 atol=eps(Float64)
@test repr(m) == "AiryDisk{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0)"
end
end

@testset "Moffat" begin
m = Moffat(fwhm=10)
expected = inv(1 + sum(abs2, SA[1, 2]) / 25)
@test m[2, 1] ≈ m(1, 2) ≈ expected

m = Moffat(fwhm=(10, 9))
wdist = (1/5)^2 + (2/4.5)^2
expected = inv(1 + wdist)
@test m[2, 1] ≈ m(1, 2) ≈ expected

# different alpha
m = Moffat(fwhm=10, alpha=2)
expected = inv(1 + sum(abs2, SA[1, 2]) / 25)^2
@test m[2, 1] ≈ m(1, 2) ≈ expected
test_model_interface(Moffat)

@testset "isotropic" begin
m = Moffat(fwhm=10)
expected = inv(1 + sum(abs2, SA[1, 2]) / 25)
@test m[2, 1] ≈ m(1, 2) ≈ expected
@test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=10, amp=1.0, alpha=1)"
end

@testset "diagonal" begin
m = Moffat(fwhm=(10, 9))
wdist = (1/5)^2 + (2/4.5)^2
expected = inv(1 + wdist)
@test m[2, 1] ≈ m(1, 2) ≈ expected
@test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0, alpha=1)"
end

@testset "alpha" begin
m = Moffat(fwhm=10, alpha=2)
expected = inv(1 + sum(abs2, SA[1, 2]) / 25)^2
@test m[2, 1] ≈ m(1, 2) ≈ expected
@test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=10, amp=1.0, alpha=2)"
end
end

include("plotting.jl")