Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions src/loading.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
loadleaf!(dst, src, err) = dst
loadleaf!(dst::AbstractArray, src, err) =
loadleaf!(dst, src) = dst
loadleaf!(dst::AbstractArray, src) =
error("Tried to copy $src into an array destination; this is not allowed.")
loadleaf!(dst, src::AbstractArray, err) =
loadleaf!(dst, src::AbstractArray) =
error("Tried to copy an array to $dst; this is not allowed.")
function loadleaf!(dst::AbstractArray, src::Bool, err)

function loadleaf!(dst::AbstractArray, src::Bool)
if iszero(src)
dst .= src
else
error("Cannot copy boolean parameter == true to non-zero parameter.")
end
return dst
end
loadleaf!(dst::Bool, src::AbstractArray, err) = iszero(dst) ? dst :

loadleaf!(dst::Bool, src::AbstractArray) = iszero(dst) ? dst :
error("Cannot copy non-zero parameter to boolean parameter == true.")
function loadleaf!(dst::AbstractArray, src::AbstractArray, err)

function loadleaf!(dst::AbstractArray, src::AbstractArray)
err = DimensionMismatch("Tried to load size $(size(src)) array into size $(size(dst))")
(size(dst) == size(src)) || throw(err)
copyto!(dst, src)
end
Expand All @@ -28,9 +32,6 @@ _tie_check(dst, src) = true

_bool_tie_check(dst, src) = true

_filter_children(f, children::NamedTuple) =
NamedTuple(filter(kv -> f(kv[2]), pairs(children)))
_filter_children(f, children) = filter(f, children)

"""
loadmodel!(dst, src)
Expand Down Expand Up @@ -81,21 +82,22 @@ however, attempting to copy a non-zero array to an inactive parameter will throw
Likewise, copying a `src` value of `false` to any `dst` array is valid,
but copying a `src` value of `true` will error.
"""
function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet())
ldsts = _filter_children(filter, functor(dst)[1])
lsrcs = _filter_children(filter, functor(src)[1])
(keys(ldsts) == keys(lsrcs)) ||
throw(ArgumentError("Tried to load $src into $dst but the structures do not match."))
function loadmodel!(dst, src; cache = Base.IdSet())
ldsts = Functors.children(dst)
lsrcs = Functors.children(src)
kdsts = keys(ldsts)
ksrcs = keys(lsrcs)
(kdsts == ksrcs) ||
throw(ArgumentError("Tried to load $ksrcs into $kdsts but the structures do not match."))

err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.")
foreach(ldsts, lsrcs) do ldst, lsrc
if ldst in cache # we already loaded this parameter before
_tie_check(ldst, lsrc) && return ldst
elseif Functors.isleaf(ldst) # our first time loading this leaf
push!(cache, ldst)
loadleaf!(ldst, lsrc, err)
loadleaf!(ldst, lsrc)
else # this isn't a leaf
loadmodel!(ldst, lsrc; filter = filter, cache = cache)
loadmodel!(ldst, lsrc; cache = cache)
end
end

Expand Down