Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse baroclinic_instability_model in sharding script and add test #116

Merged
merged 45 commits into from
Mar 30, 2025
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
39d2d93
Add new reusable workflow for simulation setup
glwagner Mar 28, 2025
3b5e424
Fix indenting
glwagner Mar 28, 2025
4560ccc
Bugfix maybe
glwagner Mar 28, 2025
1aac3db
Another try
glwagner Mar 28, 2025
42f92fe
Restore CompileOrRun
glwagner Mar 28, 2025
f7b7958
Generalize CompileOrRun for sharding
glwagner Mar 28, 2025
218ed6e
Dont forget existing env
glwagner Mar 28, 2025
fdcbcf2
Single quotes
glwagner Mar 28, 2025
05b0cc5
Typo
glwagner Mar 28, 2025
454ad35
Old name
glwagner Mar 28, 2025
b2a8808
Sharding not required
glwagner Mar 28, 2025
14673f7
Add sharding sim
glwagner Mar 28, 2025
aedb2cd
Small update to baro instability
glwagner Mar 28, 2025
9ce7e45
Fix typo
glwagner Mar 28, 2025
55b3c5b
Add simulation to script name
glwagner Mar 28, 2025
ad856fd
Rm
glwagner Mar 28, 2025
aedb06f
Simplify inputs
glwagner Mar 28, 2025
f295f32
Add grid name to files
glwagner Mar 28, 2025
3251743
small bugfix
simone-silvestri Mar 28, 2025
7adda32
add an mpi init
simone-silvestri Mar 28, 2025
eb5c685
add MPI to project
simone-silvestri Mar 28, 2025
1c97e7a
fix path
simone-silvestri Mar 28, 2025
5218af0
just run with dt=1 for now
simone-silvestri Mar 28, 2025
254831c
Update CompileOrRun.yml
glwagner Mar 28, 2025
54f42c6
Merge remote-tracking branch 'origin/main' into glw/groomed-sharding
glwagner Mar 28, 2025
a579c52
Pass grid name into CompileOrRun
glwagner Mar 28, 2025
a4337b3
Rm common ref
glwagner Mar 28, 2025
43bda96
Fix bug in sharding grid spec
glwagner Mar 28, 2025
495b8ad
bugfix
glwagner Mar 29, 2025
2db7488
fix
glwagner Mar 29, 2025
e16ccda
Dont use parse
glwagner Mar 29, 2025
d9f40e2
Merge branch 'main' into glw/groomed-sharding
glwagner Mar 29, 2025
f41accd
Update Run.yml
glwagner Mar 29, 2025
e9ba0ca
Update Compile.yml
glwagner Mar 29, 2025
b7f7d14
Merge branch 'main' into glw/groomed-sharding
giordano Mar 29, 2025
2b5a5f1
Update .github/workflows/Compile.yml
glwagner Mar 29, 2025
ead10d8
Update .github/workflows/Run.yml
glwagner Mar 29, 2025
d3e6843
grid_name -> grid_type
glwagner Mar 29, 2025
dfd0e46
Merge remote-tracking branch 'origin/main' into glw/groomed-sharding
glwagner Mar 30, 2025
6806704
Forget about Gaussian islands for a sec
glwagner Mar 30, 2025
1281ffe
Update sharded_baroclinic_instability_simulation_run.jl
glwagner Mar 30, 2025
83bfb53
Update model_utils.jl
glwagner Mar 30, 2025
7175b98
Update model_utils.jl
glwagner Mar 30, 2025
3f704c4
grid -> grid_type in precompile extensions
glwagner Mar 30, 2025
c7a03aa
Update sharding/sharded_baroclinic_instability_simulation_run.jl
glwagner Mar 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/Compile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ jobs:
uses: ./.github/workflows/CompileOrRun.yml
with:
sim_type: ${{ matrix.sim_type }}
grid_name: ${{ matrix.sim_type == 'baroclinic_instability' && 'simple_lat_lon' || 'gaussian_islands' }}
julia_version: ${{ matrix.julia_version }}
os: ${{ matrix.os }}
xla_runtime: ${{ matrix.xla_runtime }}
Expand Down
33 changes: 24 additions & 9 deletions .github/workflows/CompileOrRun.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ on:
required: true
default: ''
type: string
sharded:
description: 'Whether we are using sharding'
required: false
default: false
type: boolean
grid_name:
description: 'What type of grid to use'
required: false
default: ''
type: string
xla_runtime:
description: 'The XLA runtime'
required: true
Expand Down Expand Up @@ -58,7 +68,10 @@ jobs:
- uses: julia-actions/cache@v2
if: ${{ startsWith(inputs.os, 'ubuntu') }}
with:
cache-name: julia-cache;workflow=${{ inputs.julia_version }}-${{ inputs.sim_type }}-${{ inputs.os }}-${{ inputs.xla_runtime }}-${{ inputs.compile_or_run }};job=${{ github.job }}
cache-name: |
julia-cache;
workflow=${{ inputs.julia_version }}-${{ inputs.sim_type }}-${{ inputs.os }}-${{ inputs.xla_runtime }}-${{ inputs.compile_or_run }};
job=${{ github.job }}
- name: Collect Workflow Telemetry
uses: catchpoint/workflow-telemetry-action@v2
with:
Expand Down Expand Up @@ -88,19 +101,21 @@ jobs:
timeout-minutes: 10
if: ${{ always() }}
with:
name: 'environment-${{ inputs.sim_type }}-${{ inputs.julia_version }}-${{ inputs.os }}-${{ inputs.xla_runtime }}-${{ inputs.compile_or_run }}'
name: 'environment-${{ inputs.sim_type }}-${{ inputs.grid_name }}-${{ inputs.julia_version }}-${{ inputs.os }}-${{ inputs.xla_runtime }}-${{ inputs.compile_or_run }}'
path: |
Manifest.toml
Project.toml
retention-days: 90
overwrite: false
- name: Ocean climate simulation
timeout-minutes: 180
env:
grid_name: ${{ inputs.grid_name }}
XLA_FLAGS: ${{ inputs.sharded && '--xla_force_host_platform_device_count=4 --xla_dump_to=xla_dump' || '--xla_dump_to=xla_dump' }}
RUNDIR: ${{ inputs.sharded && 'sharding' || 'simulations' }}
run: |
earlyoom -m ${{ inputs.earlyoom_threshold }} -s 100 -r 120 --prefer 'julia' &
julia --color=yes --project -O${{ inputs.julia_optlevel }} simulations/${{ inputs.sim_type }}_simulation_${{ inputs.compile_or_run }}.jl
env:
XLA_FLAGS: "--xla_dump_to=xla_dump"
julia --color=yes --project -O${{ inputs.julia_optlevel }} $RUNDIR/${{ inputs.sim_type }}_simulation_${{ inputs.compile_or_run }}.jl
- name: Show remaining jit calls
if: ${{ inputs.compile_or_run == 'compile' }}
timeout-minutes: 10
Expand All @@ -121,7 +136,7 @@ jobs:
timeout-minutes: 10
if: ${{ always() }}
with:
name: 'simulation-mlir-${{ inputs.sim_type }}-${{ inputs.julia_version }}-${{ inputs.os }}-${{ inputs.xla_runtime }}'
name: 'simulation-mlir-${{ inputs.sim_type }}-${{ inputs.grid_name }}-${{ inputs.julia_version }}-${{ inputs.os }}-${{ inputs.xla_runtime }}'
path: '**/*.mlir'
retention-days: 90
overwrite: false
Expand All @@ -130,7 +145,7 @@ jobs:
timeout-minutes: 10
if: ${{ always() && inputs.compile_or_run == 'compile' }}
with:
name: 'simulation-julia-profile-${{ inputs.sim_type }}-${{ inputs.julia_version }}-${{ inputs.os }}-${{ inputs.xla_runtime }}'
name: 'simulation-julia-profile-${{ inputs.sim_type }}-${{ inputs.grid_name }}-${{ inputs.julia_version }}-${{ inputs.os }}-${{ inputs.xla_runtime }}'
path: |
**/profile_*.txt
**/profile_*.dat
Expand All @@ -141,7 +156,7 @@ jobs:
timeout-minutes: 10
if: ${{ always() && inputs.compile_or_run == 'run' }}
with:
name: 'simulation-xla-dump-${{ inputs.sim_type }}-${{ inputs.julia_version }}-${{ inputs.os }}-${{ inputs.xla_runtime }}'
name: 'simulation-xla-dump-${{ inputs.sim_type }}-${{ inputs.grid_name }}-${{ inputs.julia_version }}-${{ inputs.os }}-${{ inputs.xla_runtime }}'
path: '**/xla_dump'
retention-days: 90
overwrite: false
Expand All @@ -150,7 +165,7 @@ jobs:
timeout-minutes: 10
if: ${{ always() && inputs.compile_or_run == 'run' }}
with:
name: 'simulation-xla-profile-${{ inputs.sim_type }}-${{ inputs.julia_version }}-${{ inputs.os }}-${{ inputs.xla_runtime }}'
name: 'simulation-xla-profile-${{ inputs.sim_type }}-${{ inputs.grid_name }}-${{ inputs.julia_version }}-${{ inputs.os }}-${{ inputs.xla_runtime }}'
path: '**/plugins'
retention-days: 90
overwrite: false
40 changes: 27 additions & 13 deletions .github/workflows/Run.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,15 @@ concurrency:

jobs:
run_simulations:
name: Julia ${{ matrix.julia_version }} - ${{ matrix.sim_type }} - ${{ matrix.os }} - ${{ matrix.xla_runtime }}
name: Serial - Julia ${{ matrix.julia_version }} - ${{ matrix.sim_type }} - ${{ matrix.os }} - ${{ matrix.xla_runtime }}
strategy:
fail-fast: false
matrix:
sim_type:
- 'baroclinic_instability'
- 'ocean_climate'
julia_version:
- '1.11'
os:
- ubuntu-24.04
xla_runtime:
- 'PJRT'
- 'IFRT'
earlyoom_threshold:
- 4
sim_type: ['baroclinic_instability', 'ocean_climate']
julia_version: ['1.11']
os: ['ubuntu-24.04']
xla_runtime: ['IFRT', 'PJRT']
earlyoom_threshold: [4]
include:
- os: ubuntu-22.04-arm
julia_version: '1.11'
Expand All @@ -68,9 +61,30 @@ jobs:
uses: ./.github/workflows/CompileOrRun.yml
with:
sim_type: ${{ matrix.sim_type }}
grid_type: ${{ matrix.sim_type == 'baroclinic_instability' && 'simple_lat_lon' || 'gaussian_islands' }}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/PRONTOLab/GB-25/actions/runs/14148328037?pr=116

The workflow is not valid. .github/workflows/Run.yml (Line: 64, Col: 18): Invalid input, grid_type is not defined in the referenced workflow.

🤔

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, its grid_name

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually I will just change them all to grid_type, its better

julia_version: ${{ matrix.julia_version }}
os: ${{ matrix.os }}
xla_runtime: ${{ matrix.xla_runtime }}
compile_or_run: 'run'
earlyoom_threshold: ${{ matrix.earlyoom_threshold }}
julia_optlevel: 0

run_sharded:
name: Sharded - Julia ${{ matrix.julia_version }} - ${{ matrix.grid_name }} - ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: ['ubuntu-24.04', 'ubuntu-22.04-arm']
grid_name: ['simple_lat_lon', 'gaussian_islands']
uses: ./.github/workflows/CompileOrRun.yml
with:
sim_type: 'sharded_baroclinic_instability'
grid_name: ${{ matrix.grid_name }}
sharded: true
julia_version: '1.11'
os: ${{ matrix.os }}
xla_runtime: 'IFRT'
compile_or_run: 'run'
earlyoom_threshold: 4
julia_optlevel: 0

1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CFTime = "179af706-886a-5703-950a-314cd64e0468"
ClimaOcean = "0376089a-ecfe-4b0e-a64f-9c555d74d754"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
HDF5_jll = "0234f1f7-429e-5d53-9886-15a909be8d59"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Oceananigans = "9e8cae18-63c1-5223-a75c-80ca9d6e9a09"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand Down
124 changes: 0 additions & 124 deletions sharding/sharded_baroclinic_instability.jl

This file was deleted.

74 changes: 74 additions & 0 deletions sharding/sharded_baroclinic_instability_simulation_run.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
using GordonBell25
using GordonBell25: first_time_step!, time_step!, loop!
using Oceananigans
using Oceananigans.Units
using Oceananigans.Architectures: ReactantState
using Random
using Printf
using Reactant
using MPI

# Need this for sharding with non-openMPI implementations?
# (GHA uses MPICH)
MPI.Init()

Reactant.Distributed.initialize(; single_gpu_per_process=false)

@show Ngpu = length(Reactant.devices())

if Ngpu == 1
rank = 0
arch = Oceananigans.ReactantState()
elseif Ngpu == 2
rank = Reactant.Distributed.local_rank()

arch = Oceananigans.Distributed(
Oceananigans.ReactantState();
partition = Partition(1, 2, 1)
)
else
Rx = floor(Int, sqrt(Ngpu))
Ry = Ngpu ÷ Rx
rank = Reactant.Distributed.local_rank()

arch = Oceananigans.Distributed(
Oceananigans.ReactantState();
partition = Partition(Rx, Ry, 1)
)
end

using Dates
@info "[$rank] Generating model..." now(UTC)

grid_str = get(ENV, "grid_name", "simple_lat_lon")
resolution_fraction_str = get(ENV, "resolution_fraction", "2")
time_step_str = get(ENV, "time_step", "60")
Nz_str = get(ENV, "Nz", "10")

@show grid_type = Symbol(grid_str)
@show resolution_fraction = parse(Float64, resolution_fraction_str)
@show time_step_str = parse(Float64, time_step_str)
@show Nz = parse(Int, Nz_str)

model = GordonBell25.baroclinic_instability_model(arch; grid_type, Δt=1, Nz,
resolution=1/resolution_fraction)

@info "[$rank] Compiling first_time_step!..."
rfirst! = @compile first_time_step!(model)

@info "[$rank] Compiling loop..."
rstep! = @compile time_step!(model)

@time "[$rank] Running first_time_step!..." rfirst!(model)
@time "[$rank] Warming up..." rstep!(model)

rstep!(model)
rstep!(model)
rstep!(model)

@time "[$rank] Running loop..." begin
for n = 1:10
rstep!(model)
end
end

4 changes: 2 additions & 2 deletions sharding/sharded_tripolar_instability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ using Dates
@info "[$rank] Generating model..." now(UTC)
resolution_fraction_str = get(ENV, "resolution_fraction", "4")
@show resolution_fraction = parse(Float64, resolution_fraction_str)
grid = :gaussian_islands # an idealized TripolarGrid with Gaussian-shaped islands
model = GordonBell25.baroclinic_instability_model(arch; resolution=1/resolution_fraction, Δt=1, grid)
grid_type = :gaussian_islands # an idealized TripolarGrid with Gaussian-shaped islands
model = GordonBell25.baroclinic_instability_model(arch; resolution=1/resolution_fraction, Δt=1, grid_type)

@info "[$rank] Compiling first_time_step!..."
rfirst! = @compile first_time_step!(model)
Expand Down
Loading
Loading