File tree Expand file tree Collapse file tree 1 file changed +6
-6
lines changed
checkpoint/orbax/checkpoint/_src/arrays Expand file tree Collapse file tree 1 file changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -50,7 +50,7 @@ def _construct_maximal_sharding(
5050 mesh_shape = []
5151 mesh_axes = []
5252
53- current_parition_axis = 0
53+ current_partition_axis = 0
5454 # Max to min.
5555 for i in np .argsort (shape )[::- 1 ]:
5656 assert available_device_dim > 0
@@ -64,15 +64,15 @@ def _construct_maximal_sharding(
6464 available_device_dim //= gcd
6565 mesh_shape .append (gcd )
6666
67- current_parition_axis_name = _partition_axis_name (current_parition_axis )
68- partition_axes [i ] = current_parition_axis_name
69- mesh_axes .append (current_parition_axis_name )
70- current_parition_axis += 1
67+ current_partition_axis_name = _partition_axis_name (current_partition_axis )
68+ partition_axes [i ] = current_partition_axis_name
69+ mesh_axes .append (current_partition_axis_name )
70+ current_partition_axis += 1
7171
7272 # Still have some partition dimension left over.
7373 if available_device_dim > 1 :
7474 mesh_shape .append (available_device_dim )
75- mesh_axes .append (_partition_axis_name (current_parition_axis ))
75+ mesh_axes .append (_partition_axis_name (current_partition_axis ))
7676
7777 logging .info (
7878 'Constructed sharding for array with shape: %s, mesh_shape: %s,'
You can’t perform that action at this time.
0 commit comments