@@ -744,29 +744,6 @@ function _reset_interpolated_values!(remapper::Remapper)
744
744
fill! (remapper. _interpolated_values, 0 )
745
745
end
746
746
747
- """
748
- _collect_and_return_interpolated_values!(remapper::Remapper,
749
- num_fields::Int)
750
-
751
- Perform an MPI call to aggregate the interpolated points from all the MPI processes and save
752
- the result in the local state of the `remapper`. Only the root process will return the
753
- interpolated data.
754
-
755
- `_collect_and_return_interpolated_values!` is type-unstable and allocates new return arrays.
756
-
757
- `num_fields` is the number of fields that have been interpolated in this batch.
758
- """
759
- function _collect_and_return_interpolated_values! (
760
- remapper:: Remapper ,
761
- num_fields:: Int ,
762
- )
763
- return ClimaComms. reduce (
764
- remapper. comms_ctx,
765
- remapper. _interpolated_values[remapper. colons... , 1 : num_fields],
766
- + ,
767
- )
768
- end
769
-
770
747
function _collect_interpolated_values! (
771
748
dest,
772
749
remapper:: Remapper ,
@@ -796,19 +773,6 @@ function _collect_interpolated_values!(
796
773
return nothing
797
774
end
798
775
799
- """
800
- batched_ranges(num_fields, buffer_length)
801
-
802
- Partition the indices from 1 to num_fields in such a way that no range is larger than
803
- buffer_length.
804
- """
805
- function batched_ranges (num_fields, buffer_length)
806
- return [
807
- (i * buffer_length + 1 ): (min ((i + 1 ) * buffer_length, num_fields)) for
808
- i in 0 : (div ((num_fields - 1 ), buffer_length))
809
- ]
810
- end
811
-
812
776
"""
813
777
interpolate(remapper::Remapper, fields)
814
778
interpolate!(dest, remapper::Remapper, fields)
@@ -860,58 +824,21 @@ int12 = interpolate(remapper, [field1, field2])
860
824
```
861
825
"""
862
826
function interpolate (remapper:: Remapper , fields)
863
-
827
+ ArrayType = ClimaComms. array_type (remapper. space)
828
+ FT = Spaces. undertype (remapper. space)
864
829
only_one_field = fields isa Fields. Field
865
- if only_one_field
866
- fields = [fields]
867
- end
868
-
869
- for field in fields
870
- axes (field) == remapper. space ||
871
- error (" Field is defined on a different space than remapper" )
872
- end
873
-
874
- isa_vertical_space = remapper. space isa Spaces. FiniteDifferenceSpace
875
-
876
- index_field_begin, index_field_end =
877
- 1 , min (length (fields), remapper. buffer_length)
878
830
879
- # Partition the indices in such a way that nothing is larger than
880
- # buffer_length
881
- index_ranges = batched_ranges (length (fields), remapper. buffer_length)
831
+ interpolated_values_dim... , _buffer_length =
832
+ size (remapper. _interpolated_values)
882
833
883
- cat_fn = (l... ) -> cat (l... , dims = length (remapper. colons) + 1 )
834
+ allocate_extra = only_one_field ? () : (length (fields),)
835
+ dest = ArrayType (zeros (FT, interpolated_values_dim... , allocate_extra... ))
884
836
885
- interpolated_values = mapreduce (cat_fn, index_ranges) do range
886
- num_fields = length (range)
887
-
888
- # Reset interpolated_values. This is needed because we collect distributed results
889
- # with a + reduction.
890
- _reset_interpolated_values! (remapper)
891
- # Perform the interpolations (horizontal and vertical)
892
- _set_interpolated_values! (
893
- remapper,
894
- view (fields, index_field_begin: index_field_end),
895
- )
896
-
897
- if ! isa_vertical_space
898
- # For spaces with an horizontal component, reshape the output so that it is a nice grid.
899
- _apply_mpi_bitmask! (remapper, num_fields)
900
- else
901
- # For purely vertical spaces, just move to _interpolated_values
902
- remapper. _interpolated_values .= remapper. _local_interpolated_values
903
- end
904
-
905
- # Finally, we have to send all the _interpolated_values to root and sum them up to
906
- # obtain the final answer. Only the root will contain something useful.
907
- return _collect_and_return_interpolated_values! (remapper, num_fields)
908
- end
909
-
910
- # Non-root processes
911
- isnothing (interpolated_values) && return nothing
912
-
913
- return only_one_field ? interpolated_values[remapper. colons... , begin ] :
914
- interpolated_values
837
+ # interpolate! has an MPI call, so it is important to return after it is
838
+ # called, not before!
839
+ interpolate! (dest, remapper, fields)
840
+ ClimaComms. iamroot (remapper. comms_ctx) || return nothing
841
+ return dest
915
842
end
916
843
917
844
# dest has to be allowed to be nothing because interpolation happens only on the root
@@ -927,13 +854,18 @@ function interpolate!(
927
854
end
928
855
isa_vertical_space = remapper. space isa Spaces. FiniteDifferenceSpace
929
856
857
+ for field in fields
858
+ axes (field) == remapper. space ||
859
+ error (" Field is defined on a different space than remapper" )
860
+ end
861
+
930
862
if ! isnothing (dest)
931
863
# !isnothing(dest) means that this is the root process, in this case, the size have
932
864
# to match (ignoring the buffer_length)
933
865
dest_size = only_one_field ? size (dest) : size (dest)[1 : (end - 1 )]
934
866
935
867
dest_size == size (remapper. _interpolated_values)[1 : (end - 1 )] || error (
936
- " Destination array is not compatible with remapper (size mismatch)" ,
868
+ " Destination array is not compatible with remapper (size mismatch), $dest_size " ,
937
869
)
938
870
939
871
expected_array_type =
0 commit comments