
Fix distributed remapping bug #2169

Merged (3 commits) on Feb 15, 2025

35 changes: 5 additions & 30 deletions .github/workflows/JuliaFormatter.yml
@@ -7,34 +7,9 @@ on:

jobs:
  format:
-    runs-on: ubuntu-24.04
-    timeout-minutes: 30
+    runs-on: ubuntu-latest
    steps:
-      - name: Cancel Previous Runs
-        uses: styfle/[email protected]
-        with:
-          access_token: ${{ github.token }}

-      - uses: actions/checkout@v4

-      - uses: dorny/[email protected]
-        id: filter
-        with:
-          filters: |
-            julia_file_change:
-              - added|modified: '**.jl'

-      - uses: julia-actions/setup-julia@v2
-        if: steps.filter.outputs.julia_file_change == 'true'
-        with:
-          version: '1.10'

-      - name: Apply JuliaFormatter
-        if: steps.filter.outputs.julia_file_change == 'true'
-        run: |
-          julia --color=yes --project=.dev .dev/climaformat.jl --verbose .

-      - name: Check formatting diff
-        if: steps.filter.outputs.julia_file_change == 'true'
-        run: |
-          git diff --color=always --exit-code
+      - uses: julia-actions/julia-format@v3
+        with:
+          version: '1' # Set `version` to '1.0.54' if you need to use JuliaFormatter.jl v1.0.54 (default: '1')
+          suggestion-label: 'format-suggest' # leave this unset or empty to show suggestions for all PRs
6 changes: 4 additions & 2 deletions NEWS.md
@@ -5,9 +5,11 @@ main
-------

- Prior to this version, `CommonSpaces` could not be created with
-  `ClimaComms.MPICommContext`. This is now fixed with PR
+  `ClimaComms.MPICommsContext`. This is now fixed with PR
  [2176](https://github.com/CliMA/ClimaCore.jl/pull/2176).

+- Fixed a bug in distributed remapping with CUDA: `ClimaCore` would sometimes
+  not fill the output arrays with the correct values. PR
+  [2169](https://github.com/CliMA/ClimaCore.jl/pull/2169).
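
As an illustration of the call pattern this entry refers to (editorial note, not part of the changelog): a minimal sketch, assuming `remapper` and the `fields` to be remapped already exist; with an MPI context, only the root process receives the filled output array.

```julia
using ClimaComms
using ClimaCore: Remapping

# Remap `fields` with a (possibly distributed) `remapper`. `interpolate`
# returns the filled output array on the root process and `nothing` elsewhere;
# the explicit `iamroot` check just makes that visible at the call site.
function remap_on_root(remapper, fields, context)
    output = Remapping.interpolate(remapper, fields)
    return ClimaComms.iamroot(context) ? output : nothing
end

# Usage sketch (names assumed): remap_on_root(remapper, [field1, field2], ClimaComms.context())
```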

v0.14.24
-------
125 changes: 29 additions & 96 deletions src/Remapping/distributed_remapping.jl
@@ -744,29 +744,6 @@ function _reset_interpolated_values!(remapper::Remapper)
    fill!(remapper._interpolated_values, 0)
end

-"""
-    _collect_and_return_interpolated_values!(remapper::Remapper,
-                                             num_fields::Int)
-
-Perform an MPI call to aggregate the interpolated points from all the MPI processes and save
-the result in the local state of the `remapper`. Only the root process will return the
-interpolated data.
-
-`_collect_and_return_interpolated_values!` is type-unstable and allocates new return arrays.
-
-`num_fields` is the number of fields that have been interpolated in this batch.
-"""
-function _collect_and_return_interpolated_values!(
-    remapper::Remapper,
-    num_fields::Int,
-)
-    return ClimaComms.reduce(
-        remapper.comms_ctx,
-        remapper._interpolated_values[remapper.colons..., 1:num_fields],
-        +,
-    )
-end
-
function _collect_interpolated_values!(
    dest,
    remapper::Remapper,
@@ -777,38 +754,26 @@ function _collect_interpolated_values!(
    if only_one_field
        ClimaComms.reduce!(
            remapper.comms_ctx,
-            remapper._interpolated_values[remapper.colons..., begin],
+            view(remapper._interpolated_values, remapper.colons..., 1),
            dest,
            +,
        )
        return nothing
-    else
-        num_fields = 1 + index_field_end - index_field_begin
-        ClimaComms.reduce!(
-            remapper.comms_ctx,
-            view(
-                remapper._interpolated_values,
-                remapper.colons...,
-                1:num_fields,
-            ),
-            view(dest, remapper.colons..., index_field_begin:index_field_end),
-            +,
-        )
    end

+    num_fields = 1 + index_field_end - index_field_begin

+    ClimaComms.reduce!(
+        remapper.comms_ctx,
+        view(remapper._interpolated_values, remapper.colons..., 1:num_fields),
+        view(dest, remapper.colons..., index_field_begin:index_field_end),
+        +,
+    )

    return nothing
end
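
For readers unfamiliar with `ClimaComms.reduce!`: below is a standalone sketch (not part of the PR, with plain `Array`s standing in for the remapper buffers) of the in-place pattern used above, where each process' local batch is summed directly into a view of the destination, so no intermediate array is allocated.

```julia
using ClimaComms

context = ClimaComms.context()
ClimaComms.init(context)

nfields_dest = 4     # columns available in the caller-provided output
nfields_batch = 2    # fields "interpolated" in this batch
dest = zeros(10, nfields_dest)            # stands in for `dest`
local_values = ones(10, nfields_batch)    # stands in for `remapper._interpolated_values`

# Sum every process' batch directly into the matching slice of `dest`.
# Points a process did not fill stay 0 locally, so the `+` reduction is safe.
ClimaComms.reduce!(
    context,
    view(local_values, :, 1:nfields_batch),
    view(dest, :, 1:nfields_batch),
    +,
)

# Only the root process is guaranteed to hold the reduced result.
ClimaComms.iamroot(context) && @show dest[1, 1]
```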

"""
batched_ranges(num_fields, buffer_length)

Partition the indices from 1 to num_fields in such a way that no range is larger than
buffer_length.
"""
function batched_ranges(num_fields, buffer_length)
return [
(i * buffer_length + 1):(min((i + 1) * buffer_length, num_fields)) for
i in 0:(div((num_fields - 1), buffer_length))
]
end
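
Since `batched_ranges` disappears with this PR, a worked example of the partitioning it performed may help; the sketch below re-implements the same logic under a hypothetical name.

```julia
# Hypothetical re-implementation of the removed helper: split 1:num_fields
# into consecutive ranges, none longer than buffer_length.
function batched_ranges_demo(num_fields, buffer_length)
    return [
        (i * buffer_length + 1):min((i + 1) * buffer_length, num_fields) for
        i in 0:div(num_fields - 1, buffer_length)
    ]
end

batched_ranges_demo(5, 2)  # [1:2, 3:4, 5:5], every range fits in the buffer
```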

"""
interpolate(remapper::Remapper, fields)
interpolate!(dest, remapper::Remapper, fields)
@@ -860,58 +825,21 @@ int12 = interpolate(remapper, [field1, field2])
```
"""
function interpolate(remapper::Remapper, fields)

    ArrayType = ClimaComms.array_type(remapper.space)
    FT = Spaces.undertype(remapper.space)
    only_one_field = fields isa Fields.Field
    if only_one_field
        fields = [fields]
    end

-    for field in fields
-        axes(field) == remapper.space ||
-            error("Field is defined on a different space than remapper")
-    end
    interpolated_values_dim..., _buffer_length =
        size(remapper._interpolated_values)

-    isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace

-    index_field_begin, index_field_end =
-        1, min(length(fields), remapper.buffer_length)

-    # Partition the indices in such a way that nothing is larger than
-    # buffer_length
-    index_ranges = batched_ranges(length(fields), remapper.buffer_length)
    allocate_extra = only_one_field ? () : (length(fields),)
    dest = ArrayType(zeros(FT, interpolated_values_dim..., allocate_extra...))

-    cat_fn = (l...) -> cat(l..., dims = length(remapper.colons) + 1)

-    interpolated_values = mapreduce(cat_fn, index_ranges) do range
-        num_fields = length(range)

-        # Reset interpolated_values. This is needed because we collect distributed results
-        # with a + reduction.
-        _reset_interpolated_values!(remapper)
-        # Perform the interpolations (horizontal and vertical)
-        _set_interpolated_values!(
-            remapper,
-            view(fields, index_field_begin:index_field_end),
-        )

-        if !isa_vertical_space
-            # For spaces with an horizontal component, reshape the output so that it is a nice grid.
-            _apply_mpi_bitmask!(remapper, num_fields)
-        else
-            # For purely vertical spaces, just move to _interpolated_values
-            remapper._interpolated_values .= remapper._local_interpolated_values
-        end

-        # Finally, we have to send all the _interpolated_values to root and sum them up to
-        # obtain the final answer. Only the root will contain something useful.
-        return _collect_and_return_interpolated_values!(remapper, num_fields)
-    end

-    # Non-root processes
-    isnothing(interpolated_values) && return nothing

-    return only_one_field ? interpolated_values[remapper.colons..., begin] :
-           interpolated_values
+    # interpolate! has an MPI call, so it is important to return after it is
+    # called, not before!
+    interpolate!(dest, remapper, fields)
+    ClimaComms.iamroot(remapper.comms_ctx) || return nothing
+    return dest
end
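
A standalone sketch (assuming an already-built `remapper`, `fields`, and a distributed `context`) of the ordering constraint the new comment points out: `interpolate!` contains a collective MPI reduction, so every process must call it before the non-root processes return.

```julia
using ClimaComms
using ClimaCore: Remapping

function interpolate_on_root(dest, remapper, fields, context)
    # WRONG (would hang): bailing out on non-root processes first means they
    # never reach the collective MPI call inside `interpolate!`:
    #   ClimaComms.iamroot(context) || return nothing
    #   Remapping.interpolate!(dest, remapper, fields)

    # RIGHT: every process participates in the collective call, and only then
    # do the non-root processes return.
    Remapping.interpolate!(dest, remapper, fields)
    ClimaComms.iamroot(context) || return nothing
    return dest  # meaningful only on the root process
end
```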

# dest has to be allowed to be nothing because interpolation happens only on the root
@@ -927,6 +855,11 @@ function interpolate!(
    end
    isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace

+    for field in fields
+        axes(field) == remapper.space ||
+            error("Field is defined on a different space than remapper")
+    end

    if !isnothing(dest)
        # !isnothing(dest) means that this is the root process, in this case, the size have
        # to match (ignoring the buffer_length)
15 changes: 2 additions & 13 deletions test/Remapping/distributed_remapping.jl
@@ -31,14 +31,6 @@ atexit() do
    global_logger(prev_logger)
end

-@testset "Utils" begin
-    # batched_ranges(num_fields, buffer_length)
-    @test Remapping.batched_ranges(1, 1) == [1:1]
-    @test Remapping.batched_ranges(1, 2) == [1:1]
-    @test Remapping.batched_ranges(2, 2) == [1:2]
-    @test Remapping.batched_ranges(3, 2) == [1:2, 3:3]
-end

with_mpi = context isa ClimaComms.MPICommsContext

@testset "2D extruded" begin
@@ -161,10 +153,7 @@

    quad = Quadratures.GLL{4}()
    horzmesh = Meshes.RectilinearMesh(horzdomain, 10, 10)
-    horztopology = Topologies.Topology2D(
-        ClimaComms.SingletonCommsContext(device),
-        horzmesh,
-    )
+    horztopology = Topologies.Topology2D(context, horzmesh)
    horzspace = Spaces.SpectralElementSpace2D(horztopology, quad)

    hv_center_space =
@@ -330,7 +319,7 @@
    quad = Quadratures.GLL{4}()
    horzmesh = Meshes.RectilinearMesh(horzdomain, 10, 10)
    horztopology = Topologies.Topology2D(
-        ClimaComms.SingletonCommsContext(device),
+        context,
        horzmesh,
        Topologies.spacefillingcurve(horzmesh),
    )