From 85f715a7edb3f25a9b5b810b76c6869ae5f278a3 Mon Sep 17 00:00:00 2001 From: MarcoH Date: Tue, 23 Sep 2025 15:26:39 +0200 Subject: [PATCH 1/2] feature: new node. Add rules and av energy --- src/nodes/predefined.jl | 1 + src/nodes/predefined/sigmoid.jl | 17 +++++++++++++++++ src/rules/gamma_mixture/a.jl | 2 +- src/rules/sigmoid/in.jl | 9 +++++++++ src/rules/sigmoid/out.jl | 7 +++++++ src/rules/sigmoid/xi.jl | 4 ++++ test/nodes/predefined/sigmoid_tests.jl | 18 ++++++++++++++++++ 7 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 src/nodes/predefined/sigmoid.jl create mode 100644 src/rules/sigmoid/in.jl create mode 100644 src/rules/sigmoid/out.jl create mode 100644 src/rules/sigmoid/xi.jl create mode 100644 test/nodes/predefined/sigmoid_tests.jl diff --git a/src/nodes/predefined.jl b/src/nodes/predefined.jl index 7bc1bf17b..9d57efa7c 100644 --- a/src/nodes/predefined.jl +++ b/src/nodes/predefined.jl @@ -34,6 +34,7 @@ include("predefined/continuous_transition.jl") include("predefined/half_normal.jl") include("predefined/binomial_polya.jl") include("predefined/multinomial_polya.jl") +include("predefined/Sigmoid.jl") include("predefined/flow/flow.jl") include("predefined/delta/delta.jl") diff --git a/src/nodes/predefined/sigmoid.jl b/src/nodes/predefined/sigmoid.jl new file mode 100644 index 000000000..88ba19b11 --- /dev/null +++ b/src/nodes/predefined/sigmoid.jl @@ -0,0 +1,17 @@ +struct Sigmoid end +using StatsFuns: logistic + +@node Sigmoid Stochastic [out, in, ξ] + +@average_energy Sigmoid (q_out::Categorical, q_in::UnivariateNormalDistributionsFamily, q_ξ::PointMass) = begin + + mout = mean(q_out) + m_in, v_in = mean_var(q_in) + + ξ_hat = mean(q_ξ) + + U = m_in * mout + log(logistic(ξ_hat)) - 0.5 * (m_in + ξ_hat) - ((logistic(ξ_hat) - 0.5)/(2*ξ_hat)) * (m_in^2 + v_in - ξ_hat^2) + return U +end + +export Sigmoid \ No newline at end of file diff --git a/src/rules/gamma_mixture/a.jl b/src/rules/gamma_mixture/a.jl index 9623a082e..9c00830e4 100644 --- a/src/rules/gamma_mixture/a.jl +++ b/src/rules/gamma_mixture/a.jl @@ -1,5 +1,5 @@ -@rule GammaMixture((:a, k), Marginalisation) (q_out::Any, q_switch::Any, q_b::GammaDistributionsFamily) = begin +@rule GammaMixture((:a, k), Marginalisation) (q_out::GammaDistributionsFamily, q_switch::Any, q_b::GammaDistributionsFamily) = begin p = probvec(q_switch)[k] β = mean(log, q_out) + mean(log, q_b) γ = p * β diff --git a/src/rules/sigmoid/in.jl b/src/rules/sigmoid/in.jl new file mode 100644 index 000000000..0d77f61fe --- /dev/null +++ b/src/rules/sigmoid/in.jl @@ -0,0 +1,9 @@ + +@rule Sigmoid(:in, Marginalisation) (q_out::Categorical, q_ξ::PointMass) = begin + + mout = mean(q_out) + ξ_hat = mean(q_ξ) + w = 2 * ((logistic(ξ_hat) - 0.5)/(2*ξ_hat)) + mout_w = (mout - 0.5)* w + return NormalWeightedMeanPrecision(mout_w, w) +end \ No newline at end of file diff --git a/src/rules/sigmoid/out.jl b/src/rules/sigmoid/out.jl new file mode 100644 index 000000000..ba4799980 --- /dev/null +++ b/src/rules/sigmoid/out.jl @@ -0,0 +1,7 @@ +using StatsFuns: logistic +@rule sigmoid(:out, Marginalisation) (q_in::UnivariateNormalDistributionsFamily,q_ξ::PointMass) = begin + m_in = mean(q_in) + p = logistic(m_in) + return Categorical(p, 1 - p) +end + diff --git a/src/rules/sigmoid/xi.jl b/src/rules/sigmoid/xi.jl new file mode 100644 index 000000000..e9258595a --- /dev/null +++ b/src/rules/sigmoid/xi.jl @@ -0,0 +1,4 @@ +@rule Sigmoid(:ξ, Marginalisation) (q_out::Any, q_in::UnivariateNormalDistributionsFamily) = begin + m_in, v_in = mean_cov(q_in) + return sqrt(m_in^2 + v_in) +end \ No newline at end of file diff --git a/test/nodes/predefined/sigmoid_tests.jl b/test/nodes/predefined/sigmoid_tests.jl new file mode 100644 index 000000000..bdd945a42 --- /dev/null +++ b/test/nodes/predefined/sigmoid_tests.jl @@ -0,0 +1,18 @@ +@testitem "sigmoidNode" begin + using ReactiveMP, Random, BayesBase, ExponentialFamily + import ReactiveMP: Sigmoid + + @testset "Average energy" begin + q_in = NormalMeanVariance(0.0, 1.0) + for normal_fam in (NormalMeanVariance, NormalMeanPrecision, NormalWeightedMeanPrecision) + q_in_adj = convert(normal_fam, q_in) + @test score( + AverageEnergy(), + Sigmoid, + Val{(:out, :in, :ξ)}(), + (Marginal(Categorical(0.5, 0.5), false, false, nothing), Marginal(q_in_adj, false, false, nothing), Marginal(PointMass(1.0), false, false, nothing)), + nothing + )≈ -0.8132616875182228 + end + end +end From 39d3242f5ee535112b8dde18dbfbae3c24b51ee4 Mon Sep 17 00:00:00 2001 From: MarcoH Date: Tue, 21 Oct 2025 16:25:05 +0200 Subject: [PATCH 2/2] feature: new node (logistic sigmoid). Add rules and av energy --- src/nodes/predefined.jl | 2 +- src/nodes/predefined/sigmoid.jl | 21 +++++++++++---------- src/rules/predefined.jl | 4 ++++ src/rules/sigmoid/in.jl | 26 ++++++++++++++++++-------- src/rules/sigmoid/out.jl | 10 +++++++--- src/rules/sigmoid/xi.jl | 4 ---- src/rules/sigmoid/zeta.jl | 5 +++++ test/nodes/predefined/sigmoid_tests.jl | 4 ++-- test/rules/sigmoid/in_tests.jl | 22 ++++++++++++++++++++++ test/rules/sigmoid/out_tests.jl | 19 +++++++++++++++++++ test/rules/sigmoid/zeta_tests.jl | 19 +++++++++++++++++++ 11 files changed, 108 insertions(+), 28 deletions(-) delete mode 100644 src/rules/sigmoid/xi.jl create mode 100644 src/rules/sigmoid/zeta.jl create mode 100644 test/rules/sigmoid/in_tests.jl create mode 100644 test/rules/sigmoid/out_tests.jl create mode 100644 test/rules/sigmoid/zeta_tests.jl diff --git a/src/nodes/predefined.jl b/src/nodes/predefined.jl index 9d57efa7c..5257b4de3 100644 --- a/src/nodes/predefined.jl +++ b/src/nodes/predefined.jl @@ -34,7 +34,7 @@ include("predefined/continuous_transition.jl") include("predefined/half_normal.jl") include("predefined/binomial_polya.jl") include("predefined/multinomial_polya.jl") -include("predefined/Sigmoid.jl") +include("predefined/sigmoid.jl") include("predefined/flow/flow.jl") include("predefined/delta/delta.jl") diff --git a/src/nodes/predefined/sigmoid.jl b/src/nodes/predefined/sigmoid.jl index 88ba19b11..7c4c80e33 100644 --- a/src/nodes/predefined/sigmoid.jl +++ b/src/nodes/predefined/sigmoid.jl @@ -1,17 +1,18 @@ +using StatsFuns: logistic, softplus +using Distributions: pdf + +export Sigmoid + struct Sigmoid end -using StatsFuns: logistic -@node Sigmoid Stochastic [out, in, ξ] +@node Sigmoid Stochastic [out, in, ζ] -@average_energy Sigmoid (q_out::Categorical, q_in::UnivariateNormalDistributionsFamily, q_ξ::PointMass) = begin - - mout = mean(q_out) +@average_energy Sigmoid (q_out::Categorical, q_in::UnivariateNormalDistributionsFamily, q_ζ::PointMass) = begin + m_out = pdf(q_out, 1) m_in, v_in = mean_var(q_in) - - ξ_hat = mean(q_ξ) - U = m_in * mout + log(logistic(ξ_hat)) - 0.5 * (m_in + ξ_hat) - ((logistic(ξ_hat) - 0.5)/(2*ξ_hat)) * (m_in^2 + v_in - ξ_hat^2) + ζ_hat = mean(q_ζ) + + U = -(m_in * m_out - softplus(-ζ_hat) - (0.5 * (m_in + ζ_hat)) - 0.5 * ((logistic(ζ_hat) - 0.5)/ζ_hat) * (m_in^2 + v_in - ζ_hat^2)) return U end - -export Sigmoid \ No newline at end of file diff --git a/src/rules/predefined.jl b/src/rules/predefined.jl index a406bc2f2..883f4608a 100644 --- a/src/rules/predefined.jl +++ b/src/rules/predefined.jl @@ -198,3 +198,7 @@ include("multinomial_polya/x.jl") include("dirichlet_collection/out.jl") include("dirichlet_collection/marginals.jl") + +include("sigmoid/in.jl") +include("sigmoid/out.jl") +include("sigmoid/zeta.jl") diff --git a/src/rules/sigmoid/in.jl b/src/rules/sigmoid/in.jl index 0d77f61fe..1fe0e92db 100644 --- a/src/rules/sigmoid/in.jl +++ b/src/rules/sigmoid/in.jl @@ -1,9 +1,19 @@ +using Distributions: pdf +using StatsFuns: logistic +@rule Sigmoid(:in, Marginalisation) (q_out::Categorical, q_ζ::PointMass) = begin + m_out = pdf(q_out, 1) + ζ_hat = mean(q_ζ) + w = (logistic(ζ_hat) - 0.5)/ζ_hat + ξ = (m_out - 0.5) * w + T = promote_type(eltype(m_out), eltype(ζ_hat)) + return NormalWeightedMeanPrecision{T}(ξ, w) +end -@rule Sigmoid(:in, Marginalisation) (q_out::Categorical, q_ξ::PointMass) = begin - - mout = mean(q_out) - ξ_hat = mean(q_ξ) - w = 2 * ((logistic(ξ_hat) - 0.5)/(2*ξ_hat)) - mout_w = (mout - 0.5)* w - return NormalWeightedMeanPrecision(mout_w, w) -end \ No newline at end of file +@rule Sigmoid(:in, Marginalisation) (q_out::PointMass, q_ζ::PointMass) = begin + m_out = mean(q_out) + ζ_hat = mean(q_ζ) + w = (logistic(ζ_hat) - 0.5)/ζ_hat + ξ = (m_out - 0.5) * w + T = promote_type(eltype(m_out), eltype(ζ_hat)) + return NormalWeightedMeanPrecision{T}(ξ, w) +end diff --git a/src/rules/sigmoid/out.jl b/src/rules/sigmoid/out.jl index ba4799980..f7af8f1c4 100644 --- a/src/rules/sigmoid/out.jl +++ b/src/rules/sigmoid/out.jl @@ -1,7 +1,11 @@ using StatsFuns: logistic -@rule sigmoid(:out, Marginalisation) (q_in::UnivariateNormalDistributionsFamily,q_ξ::PointMass) = begin +@rule Sigmoid(:out, Marginalisation) (q_in::UnivariateNormalDistributionsFamily, q_ζ::PointMass) = begin m_in = mean(q_in) + ζ_hat = mean(q_ζ) p = logistic(m_in) - return Categorical(p, 1 - p) + T = promote_type(eltype(m_in), eltype(ζ_hat)) + probs = clamp.([p, 1 - p], tiny, 1 - tiny) + probs ./= sum(probs) + probs_T = convert(Vector{T}, probs) + return Categorical(probs_T) end - diff --git a/src/rules/sigmoid/xi.jl b/src/rules/sigmoid/xi.jl deleted file mode 100644 index e9258595a..000000000 --- a/src/rules/sigmoid/xi.jl +++ /dev/null @@ -1,4 +0,0 @@ -@rule Sigmoid(:ξ, Marginalisation) (q_out::Any, q_in::UnivariateNormalDistributionsFamily) = begin - m_in, v_in = mean_cov(q_in) - return sqrt(m_in^2 + v_in) -end \ No newline at end of file diff --git a/src/rules/sigmoid/zeta.jl b/src/rules/sigmoid/zeta.jl new file mode 100644 index 000000000..8d4da1086 --- /dev/null +++ b/src/rules/sigmoid/zeta.jl @@ -0,0 +1,5 @@ +@rule Sigmoid(:ζ, Marginalisation) (q_out::Any, q_in::UnivariateNormalDistributionsFamily) = begin + m_in, v_in = mean_var(q_in) + T = promote_type(eltype(m_in), eltype(v_in)) + return PointMass{T}(sqrt(m_in^2 + v_in)) +end diff --git a/test/nodes/predefined/sigmoid_tests.jl b/test/nodes/predefined/sigmoid_tests.jl index bdd945a42..d5148633f 100644 --- a/test/nodes/predefined/sigmoid_tests.jl +++ b/test/nodes/predefined/sigmoid_tests.jl @@ -9,10 +9,10 @@ @test score( AverageEnergy(), Sigmoid, - Val{(:out, :in, :ξ)}(), + Val{(:out, :in, :ζ)}(), (Marginal(Categorical(0.5, 0.5), false, false, nothing), Marginal(q_in_adj, false, false, nothing), Marginal(PointMass(1.0), false, false, nothing)), nothing - )≈ -0.8132616875182228 + ) ≈ 0.8132616875182228 end end end diff --git a/test/rules/sigmoid/in_tests.jl b/test/rules/sigmoid/in_tests.jl new file mode 100644 index 000000000..ace1e835d --- /dev/null +++ b/test/rules/sigmoid/in_tests.jl @@ -0,0 +1,22 @@ +@testitem "rules:Sigmoid:in" begin + using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + using StatsFuns: logistic + + import ReactiveMP: @test_rules + + @testset "Mean Field: (q_out::Categorical, q_ζ::PointMass) - Float64" begin + @test_rules [check_type_promotion = true, atol = [Float64 => 1e-5]] Sigmoid(:in, Marginalisation) [ + (input = (q_out = Categorical([0.5, 0.5]), q_ζ = PointMass(1.0)), output = NormalWeightedMeanPrecision(0.0, 0.2310585786300049)), + (input = (q_out = Categorical([1.0, 0.0]), q_ζ = PointMass(1.0)), output = NormalWeightedMeanPrecision(0.11552928931500245, 0.2310585786300049)), + (input = (q_out = Categorical([0.0, 1.0]), q_ζ = PointMass(1.0)), output = NormalWeightedMeanPrecision(-0.11552928931500245, 0.2310585786300049)) + ] + end + + @testset "Mean Field: (q_out::PointMass, q_ζ::PointMass) - Float64" begin + @test_rules [check_type_promotion = true, atol = [Float64 => 1e-5]] Sigmoid(:in, Marginalisation) [ + (input = (q_out = PointMass(0.5), q_ζ = PointMass(1.0)), output = NormalWeightedMeanPrecision(0.0, 0.2310585786300049)), + (input = (q_out = PointMass(1.0), q_ζ = PointMass(1.0)), output = NormalWeightedMeanPrecision(0.11552928931500245, 0.2310585786300049)), + (input = (q_out = PointMass(0.0), q_ζ = PointMass(1.0)), output = NormalWeightedMeanPrecision(-0.11552928931500245, 0.2310585786300049)) + ] + end +end diff --git a/test/rules/sigmoid/out_tests.jl b/test/rules/sigmoid/out_tests.jl new file mode 100644 index 000000000..92fde765f --- /dev/null +++ b/test/rules/sigmoid/out_tests.jl @@ -0,0 +1,19 @@ +@testitem "rules:Sigmoid:out" begin + using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + using StatsFuns: logistic + + import ReactiveMP: @test_rules + + @testset "Mean Field: (q_in::UnivariateNormalDistributionsFamily, q_ζ::PointMass)" begin + q_in = [NormalMeanVariance(0.0, 1.0), NormalMeanVariance(-1.0, 1.0), NormalMeanVariance(10.0, 1.0)] + results = [[0.5, 0.5], [0.2689414213699951, 0.7310585786300049], [0.9999546021312976, 4.5397868702390376e-5]] + for (i, result) in enumerate(results) + for normal_fam in (NormalMeanVariance, NormalMeanPrecision, NormalWeightedMeanPrecision) + q_in_adj = convert(normal_fam, q_in[i]) + @test_rules [check_type_promotion = true, atol = [Float64 => 1e-5]] Sigmoid(:out, Marginalisation) [( + input = (q_in = q_in_adj, q_ζ = PointMass(2.0)), output = Categorical(result) + )] + end + end + end +end diff --git a/test/rules/sigmoid/zeta_tests.jl b/test/rules/sigmoid/zeta_tests.jl new file mode 100644 index 000000000..7e8668f3b --- /dev/null +++ b/test/rules/sigmoid/zeta_tests.jl @@ -0,0 +1,19 @@ +@testitem "rules:Sigmoid:zeta" begin + using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + using StatsFuns: logistic + + import ReactiveMP: @test_rules + + @testset "Mean Field: (q_out::Any, q_in::UnivariateNormalDistributionsFamily)" begin + q_in = [NormalMeanVariance(0.0, 1.0), NormalMeanVariance(-1.0, 1.0), NormalMeanVariance(10.0, 1.0)] + results = [1.0, 1.4142135623730951, 10.04987562112089] + for (i, result) in enumerate(results) + for normal_fam in (NormalMeanVariance, NormalMeanPrecision, NormalWeightedMeanPrecision) + q_in_adj = convert(normal_fam, q_in[i]) + @test_rules [check_type_promotion = false, atol = [Float64 => 1e-5]] Sigmoid(:ζ, Marginalisation) [( + input = (q_out = 2.0, q_in = q_in_adj), output = PointMass(result) + )] + end + end + end +end