Skip to content

Commit

Permalink
MX Updated to_blocked to not call nn.pad (#1762)
Browse files Browse the repository at this point in the history
stack-info: PR: #1762, branch: drisspg/stack/38
  • Loading branch information
drisspg authored Feb 22, 2025
1 parent d370196 commit 2a3fbff
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions torchao/prototype/mx_formats/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

Tensor = torch.Tensor

Expand All @@ -31,14 +30,23 @@ def to_blocked(input_matrix) -> Tensor:
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)

# Pad out and view as tiles of (128, 4)
padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128))
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
# Calculate the padded shape
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4

padded = input_matrix
if (rows, cols) != (padded_rows, padded_cols):
padded = torch.zeros(
(padded_rows, padded_cols),
device=input_matrix.device,
dtype=input_matrix.dtype,
)
padded[:rows, :cols] = input_matrix

# rearrange all tiles
# Rearrange the blocks
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

# Layout rearranged tiles according to second pic
return rearranged.flatten()


Expand Down

0 comments on commit 2a3fbff

Please sign in to comment.