Skip to content

Commit 8dfe3ef

Browse files
committed
Simplify distributed_remapping.interpolate
I spent many hours tracking down #2108 and could not find the root issue. I decided to take a different approach and simplify redefine `interpolate` in terms of `interpolate!`.
1 parent fdbadd3 commit 8dfe3ef

File tree

3 files changed

+24
-93
lines changed

3 files changed

+24
-93
lines changed

NEWS.md

+7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@ ClimaCore.jl Release Notes
22
========================
33

44
main
5+
-------
6+
7+
- Fixed bug in distributed remapping with CUDA. Sometimes, `ClimaCore` would not
8+
properly fill the output arrays with the correct values. This is now fixed. PR
9+
[2169](https://github.com/CliMA/ClimaCore.jl/pull/2169)
10+
11+
v0.14.24
512
-------
613

714
- A new `Adapt` wrapper was added, `to_device`, which allows users to adapt datalayouts, spaces, fields, and fieldvectors between the cpu and gpu. PR [2159](https://github.com/CliMA/ClimaCore.jl/pull/2159).

src/Remapping/distributed_remapping.jl

+17-85
Original file line numberDiff line numberDiff line change
@@ -744,29 +744,6 @@ function _reset_interpolated_values!(remapper::Remapper)
744744
fill!(remapper._interpolated_values, 0)
745745
end
746746

747-
"""
748-
_collect_and_return_interpolated_values!(remapper::Remapper,
749-
num_fields::Int)
750-
751-
Perform an MPI call to aggregate the interpolated points from all the MPI processes and save
752-
the result in the local state of the `remapper`. Only the root process will return the
753-
interpolated data.
754-
755-
`_collect_and_return_interpolated_values!` is type-unstable and allocates new return arrays.
756-
757-
`num_fields` is the number of fields that have been interpolated in this batch.
758-
"""
759-
function _collect_and_return_interpolated_values!(
760-
remapper::Remapper,
761-
num_fields::Int,
762-
)
763-
return ClimaComms.reduce(
764-
remapper.comms_ctx,
765-
remapper._interpolated_values[remapper.colons..., 1:num_fields],
766-
+,
767-
)
768-
end
769-
770747
function _collect_interpolated_values!(
771748
dest,
772749
remapper::Remapper,
@@ -796,19 +773,6 @@ function _collect_interpolated_values!(
796773
return nothing
797774
end
798775

799-
"""
800-
batched_ranges(num_fields, buffer_length)
801-
802-
Partition the indices from 1 to num_fields in such a way that no range is larger than
803-
buffer_length.
804-
"""
805-
function batched_ranges(num_fields, buffer_length)
806-
return [
807-
(i * buffer_length + 1):(min((i + 1) * buffer_length, num_fields)) for
808-
i in 0:(div((num_fields - 1), buffer_length))
809-
]
810-
end
811-
812776
"""
813777
interpolate(remapper::Remapper, fields)
814778
interpolate!(dest, remapper::Remapper, fields)
@@ -860,58 +824,21 @@ int12 = interpolate(remapper, [field1, field2])
860824
```
861825
"""
862826
function interpolate(remapper::Remapper, fields)
863-
827+
ArrayType = ClimaComms.array_type(remapper.space)
828+
FT = Spaces.undertype(remapper.space)
864829
only_one_field = fields isa Fields.Field
865-
if only_one_field
866-
fields = [fields]
867-
end
868-
869-
for field in fields
870-
axes(field) == remapper.space ||
871-
error("Field is defined on a different space than remapper")
872-
end
873-
874-
isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace
875-
876-
index_field_begin, index_field_end =
877-
1, min(length(fields), remapper.buffer_length)
878830

879-
# Partition the indices in such a way that nothing is larger than
880-
# buffer_length
881-
index_ranges = batched_ranges(length(fields), remapper.buffer_length)
831+
interpolated_values_dim..., _buffer_length =
832+
size(remapper._interpolated_values)
882833

883-
cat_fn = (l...) -> cat(l..., dims = length(remapper.colons) + 1)
834+
allocate_extra = only_one_field ? () : (length(fields),)
835+
dest = ArrayType(zeros(FT, interpolated_values_dim..., allocate_extra...))
884836

885-
interpolated_values = mapreduce(cat_fn, index_ranges) do range
886-
num_fields = length(range)
887-
888-
# Reset interpolated_values. This is needed because we collect distributed results
889-
# with a + reduction.
890-
_reset_interpolated_values!(remapper)
891-
# Perform the interpolations (horizontal and vertical)
892-
_set_interpolated_values!(
893-
remapper,
894-
view(fields, index_field_begin:index_field_end),
895-
)
896-
897-
if !isa_vertical_space
898-
# For spaces with an horizontal component, reshape the output so that it is a nice grid.
899-
_apply_mpi_bitmask!(remapper, num_fields)
900-
else
901-
# For purely vertical spaces, just move to _interpolated_values
902-
remapper._interpolated_values .= remapper._local_interpolated_values
903-
end
904-
905-
# Finally, we have to send all the _interpolated_values to root and sum them up to
906-
# obtain the final answer. Only the root will contain something useful.
907-
return _collect_and_return_interpolated_values!(remapper, num_fields)
908-
end
909-
910-
# Non-root processes
911-
isnothing(interpolated_values) && return nothing
912-
913-
return only_one_field ? interpolated_values[remapper.colons..., begin] :
914-
interpolated_values
837+
# interpolate! has an MPI call, so it is important to return after it is
838+
# called, not before!
839+
interpolate!(dest, remapper, fields)
840+
ClimaComms.iamroot(remapper.comms_ctx) || return nothing
841+
return dest
915842
end
916843

917844
# dest has to be allowed to be nothing because interpolation happens only on the root
@@ -927,13 +854,18 @@ function interpolate!(
927854
end
928855
isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace
929856

857+
for field in fields
858+
axes(field) == remapper.space ||
859+
error("Field is defined on a different space than remapper")
860+
end
861+
930862
if !isnothing(dest)
931863
# !isnothing(dest) means that this is the root process, in this case, the size have
932864
# to match (ignoring the buffer_length)
933865
dest_size = only_one_field ? size(dest) : size(dest)[1:(end - 1)]
934866

935867
dest_size == size(remapper._interpolated_values)[1:(end - 1)] || error(
936-
"Destination array is not compatible with remapper (size mismatch)",
868+
"Destination array is not compatible with remapper (size mismatch), $dest_size",
937869
)
938870

939871
expected_array_type =

test/Remapping/distributed_remapping.jl

-8
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,6 @@ atexit() do
3131
global_logger(prev_logger)
3232
end
3333

34-
@testset "Utils" begin
35-
# batched_ranges(num_fields, buffer_length)
36-
@test Remapping.batched_ranges(1, 1) == [1:1]
37-
@test Remapping.batched_ranges(1, 2) == [1:1]
38-
@test Remapping.batched_ranges(2, 2) == [1:2]
39-
@test Remapping.batched_ranges(3, 2) == [1:2, 3:3]
40-
end
41-
4234
on_gpu = device isa ClimaComms.CUDADevice
4335
with_mpi = context isa ClimaComms.MPICommsContext
4436

0 commit comments

Comments
 (0)