Skip to content

Commit 727da63

Browse files
mhaurupenelopeysm
andauthored
Fix merge_metadata for cases where the dimension of the variable changes (#781)
* Add test merging VarInfos with different dimensions for a variable * Fix merge_metadata for differing dimensions * Bump patch version to 0.34.1. * Fix test * Fix test more * Pin KernelAbstractions to v0.9.31 * Make KernelAbstractions version bound an upper bound Co-authored-by: Penelope Yong <[email protected]> * Fix syntax --------- Co-authored-by: Penelope Yong <[email protected]>
1 parent 938a69d commit 727da63

File tree

3 files changed

+32
-66
lines changed

3 files changed

+32
-66
lines changed

Project.toml

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.34.0"
3+
version = "0.34.1"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -15,6 +15,9 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1515
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1616
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1717
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
18+
# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
19+
# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
20+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1821
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1922
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
2023
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
@@ -55,6 +58,9 @@ Compat = "4"
5558
ConstructionBase = "1.5.4"
5659
Distributions = "0.25"
5760
DocStringExtensions = "0.9"
61+
# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
62+
# for why KernelAbstractions is pinned like this.
63+
KernelAbstractions = "< 0.9.32"
5864
EnzymeCore = "0.6 - 0.8"
5965
ForwardDiff = "0.10"
6066
JET = "0.9"

src/varinfo.jl

+14-65
Original file line numberDiff line numberDiff line change
@@ -521,73 +521,22 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
521521
offset = 0
522522

523523
for (idx, vn) in enumerate(vns_both)
524-
# `idcs`
525524
idcs[vn] = idx
526-
# `vns`
527525
push!(vns, vn)
528-
if vn in vns_left && vn in vns_right
529-
# `vals`: only valid if they're the length.
530-
vals_left = getindex_internal(metadata_left, vn)
531-
vals_right = getindex_internal(metadata_right, vn)
532-
@assert length(vals_left) == length(vals_right)
533-
append!(vals, vals_right)
534-
# `ranges`
535-
r = (offset + 1):(offset + length(vals_left))
536-
push!(ranges, r)
537-
offset = r[end]
538-
# `dists`: only valid if they're the same.
539-
dist_right = getdist(metadata_right, vn)
540-
# Give precedence to `metadata_right`.
541-
push!(dists, dist_right)
542-
gid = metadata_right.gids[getidx(metadata_right, vn)]
543-
push!(gids, gid)
544-
# `orders`: giving precedence to `metadata_right`
545-
push!(orders, getorder(metadata_right, vn))
546-
# `flags`
547-
for k in keys(flags)
548-
# Using `metadata_right`; should we?
549-
push!(flags[k], is_flagged(metadata_right, vn, k))
550-
end
551-
elseif vn in vns_left
552-
# Just extract the metadata from `metadata_left`.
553-
# `vals`
554-
vals_left = getindex_internal(metadata_left, vn)
555-
append!(vals, vals_left)
556-
# `ranges`
557-
r = (offset + 1):(offset + length(vals_left))
558-
push!(ranges, r)
559-
offset = r[end]
560-
# `dists`
561-
dist_left = getdist(metadata_left, vn)
562-
push!(dists, dist_left)
563-
gid = metadata_left.gids[getidx(metadata_left, vn)]
564-
push!(gids, gid)
565-
# `orders`
566-
push!(orders, getorder(metadata_left, vn))
567-
# `flags`
568-
for k in keys(flags)
569-
push!(flags[k], is_flagged(metadata_left, vn, k))
570-
end
571-
else
572-
# Just extract the metadata from `metadata_right`.
573-
# `vals`
574-
vals_right = getindex_internal(metadata_right, vn)
575-
append!(vals, vals_right)
576-
# `ranges`
577-
r = (offset + 1):(offset + length(vals_right))
578-
push!(ranges, r)
579-
offset = r[end]
580-
# `dists`
581-
dist_right = getdist(metadata_right, vn)
582-
push!(dists, dist_right)
583-
gid = metadata_right.gids[getidx(metadata_right, vn)]
584-
push!(gids, gid)
585-
# `orders`
586-
push!(orders, getorder(metadata_right, vn))
587-
# `flags`
588-
for k in keys(flags)
589-
push!(flags[k], is_flagged(metadata_right, vn, k))
590-
end
526+
metadata_for_vn = vn in vns_right ? metadata_right : metadata_left
527+
528+
val = getindex_internal(metadata_for_vn, vn)
529+
append!(vals, val)
530+
r = (offset + 1):(offset + length(val))
531+
push!(ranges, r)
532+
offset = r[end]
533+
dist = getdist(metadata_for_vn, vn)
534+
push!(dists, dist)
535+
gid = metadata_for_vn.gids[getidx(metadata_for_vn, vn)]
536+
push!(gids, gid)
537+
push!(orders, getorder(metadata_for_vn, vn))
538+
for k in keys(flags)
539+
push!(flags[k], is_flagged(metadata_for_vn, vn, k))
591540
end
592541
end
593542

test/varinfo.jl

+11
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,17 @@ end
869869
@test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
870870
@test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
871871
end
872+
873+
# The below used to error, testing to avoid regression.
874+
@testset "merge different dimensions" begin
875+
vn = @varname(x)
876+
vi_single = VarInfo()
877+
vi_single = push!!(vi_single, vn, 1.0, Normal())
878+
vi_double = VarInfo()
879+
vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0))
880+
@test merge(vi_single, vi_double)[vn] == [0.5, 0.6]
881+
@test merge(vi_double, vi_single)[vn] == 1.0
882+
end
872883
end
873884

874885
@testset "VarInfo with selectors" begin

0 commit comments

Comments
 (0)