Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions src/helpers/macrohelpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,81 @@ macro test_inferred(T, expression)
end)
end

function check_rule_interfaces(macrotype, fform, lambda, ifaces, on_type, m_names, q_names; mod = __MODULE__)
# skip rules like (typeof(+))(:in1_in2) for which interfaces returns nothing
if ifaces === nothing
return nothing
end
names_expected = valof_set(ifaces, mod)
onames = valof_set(on_type, mod)
mnames = valof_set(m_names, mod)
qnames = valof_set(q_names, mod)
names_used = union(onames, mnames, qnames)

names_unknown = setdiff(names_expected, names_used)
if !isempty(names_unknown)
missing_list = join(sort(collect(names_unknown)), ", ")
expected_list = join(sort(collect(names_expected)), ", ")
provided_list = join(sort(collect(names_used)), ", ")

throw(ArgumentError("""
Interface mismatch for $(macrotype) $(fform) $(lambda):
Expected symbols: $expected_list
Provided symbols: $provided_list
Missing symbols: $missing_list
"""))
end

names_extra = setdiff(names_used, names_expected)
if !isempty(names_extra)
extras_list = join(sort(collect(names_extra)), ", ")
expected_list = join(sort(collect(names_expected)), ", ")
provided_list = join(sort(collect(names_used)), ", ")

throw(ArgumentError("""
Interface mismatch for $(macrotype) $(fform) $(lambda):
Expected symbols: $expected_list
Provided symbols: $provided_list
Extra symbols: $extras_list
"""))
end
end

function valof_set(x, mod::Module)
s = Set{Symbol}()

if x === nothing || x === :Nothing
return s
elseif x isa Symbol
# Split joint message symbol by underscores
for part in split(string(x), '_')
push!(s, Symbol(part))
end
return s
elseif x isa Val
return valof_set(typeof(x).parameters[1], mod)
elseif x isa DataType && x <: Val
return valof_set(x.parameters[1], mod)
elseif x isa DataType && x <: Tuple
# Handle tuple types like Tuple{Val{:inputs}, Int}
for p in x.parameters
if p <: Integer
continue
end
s = union(s, valof_set(p, mod))
end
return s
elseif x isa Tuple
# Handle **tuple values** (instances)
for xi in x
s = union(s, valof_set(xi, mod))
end
return s
elseif x isa Expr
return valof_set(Core.eval(mod, x), mod)
else
return s
end
end

end
16 changes: 16 additions & 0 deletions src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,14 @@ macro rule(fform, lambda)
m_names, m_types, m_init_block = rule_macro_parse_fn_args(inputs; specname = :messages, prefix = :m_, proxy = :(ReactiveMP.Message))
q_names, q_types, q_init_block = rule_macro_parse_fn_args(inputs; specname = :marginals, prefix = :q_, proxy = :(ReactiveMP.Marginal))

fexpr = fform.head == :call ? fform.args[1] : fform
if fexpr isa Expr && fexpr.head == :curly
fexpr = fexpr.args[1]
end
fform_type = Core.eval(__module__, fexpr)
ifaces = ReactiveMP.interfaces(fform_type)
MacroHelpers.check_rule_interfaces("@rule", fform, lambda, ifaces, on_type, m_names, q_names; mod = __module__)

output = quote
$(
rule_function_expression(fuppertype, on_type, vconstraint, m_names, m_types, q_names, q_types, metatype, whereargs) do
Expand Down Expand Up @@ -602,6 +610,14 @@ macro marginalrule(fform, lambda)
m_names, m_types, m_init_block = rule_macro_parse_fn_args(inputs; specname = :messages, prefix = :m_, proxy = :(ReactiveMP.Message))
q_names, q_types, q_init_block = rule_macro_parse_fn_args(inputs; specname = :marginals, prefix = :q_, proxy = :(ReactiveMP.Marginal))

fexpr = fform.head == :call ? fform.args[1] : fform
if fexpr isa Expr && fexpr.head == :curly
fexpr = fexpr.args[1]
end
fform_type = Core.eval(__module__, fexpr)
ifaces = ReactiveMP.interfaces(fform_type)
MacroHelpers.check_rule_interfaces("@marginalrule", fform, lambda, ifaces, on_type, m_names, q_names; mod = __module__)

output = quote
$(
marginalrule_function_expression(fuppertype, on_type, m_names, m_types, q_names, q_types, metatype, whereargs) do
Expand Down
4 changes: 4 additions & 0 deletions src/score/score.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ macro average_energy(fformtype, lambda)

q_names, q_types, q_init_block = rule_macro_parse_fn_args(inputs; specname = :marginals, prefix = :q_, proxy = :Marginal)

fform_type = Core.eval(__module__, fformtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not use eval, is very dangerous especially when macro generating the code (you have eval inside of eval). Is using the fuppertype (defined above) not sufficient?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think that the tests are failing exactly because of the eval, since it evaluates inside the ReactiveMP, but the structure in tests is defined outside of ReactiveMP

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also add a test, that actually checks that the error is being thrown by defining a rule with wrong interfaces

ifaces = ReactiveMP.interfaces(fform_type)
MacroHelpers.check_rule_interfaces("@average_energy", fformtype, lambda, ifaces, nothing, nothing, q_names; mod = __module__)

result = quote
function ReactiveMP.score(::AverageEnergy, fform::$(fuppertype), marginals_names::$(q_names), marginals::$(q_types), meta::$(metatype)) where {$(whereargs...)}
$(q_init_block...)
Expand Down
Loading