Skip to content

Commit 6e5d352

Browse files
BlaziusMaximusOrbax Authors
authored andcommitted
Fix typo: parition -> partition.
PiperOrigin-RevId: 829103956
1 parent 2344893 commit 6e5d352

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

checkpoint/orbax/checkpoint/_src/arrays/sharding.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _construct_maximal_sharding(
5050
mesh_shape = []
5151
mesh_axes = []
5252

53-
current_parition_axis = 0
53+
current_partition_axis = 0
5454
# Max to min.
5555
for i in np.argsort(shape)[::-1]:
5656
assert available_device_dim > 0
@@ -64,15 +64,15 @@ def _construct_maximal_sharding(
6464
available_device_dim //= gcd
6565
mesh_shape.append(gcd)
6666

67-
current_parition_axis_name = _partition_axis_name(current_parition_axis)
68-
partition_axes[i] = current_parition_axis_name
69-
mesh_axes.append(current_parition_axis_name)
70-
current_parition_axis += 1
67+
current_partition_axis_name = _partition_axis_name(current_partition_axis)
68+
partition_axes[i] = current_partition_axis_name
69+
mesh_axes.append(current_partition_axis_name)
70+
current_partition_axis += 1
7171

7272
# Still have some partition dimension left over.
7373
if available_device_dim > 1:
7474
mesh_shape.append(available_device_dim)
75-
mesh_axes.append(_partition_axis_name(current_parition_axis))
75+
mesh_axes.append(_partition_axis_name(current_partition_axis))
7676

7777
logging.info(
7878
'Constructed sharding for array with shape: %s, mesh_shape: %s,'

0 commit comments

Comments
 (0)