Skip to content

Commit 0a9c4c6

Browse files
authored
Refactor AddDataMovePass (#822)
* Removed special case for permute >4D since this is fixed in tt-metal * Typo fixes and conditional rearranging * Rewrite _change_layout and _create_aligned_node for clarity * Reduced set of operations used for _reset_to_default_layout This passes lowering tests, but I am less confident in its generality * Moved _call_to_torch_with_meta inside NodeAligner Renamed variables and functions for consistency * Added repeat to LAYOUT_CHANGE_OPS to get lowering tests to pass * Removed _to_torch_with_meta Reduced special case handling Eliminated unneeded from_device calls * Fixes an issue with tilized inputs to ttnn.split by converting to row-major layout * Removed function to check if ttnn.embedding is user of a given node Updated TTNN_LAYOUT_CHANGE_OPS to separate TTNN_ROW_LAYOUT_OPS and TTNN_HOST_ONLY_OPS * Added to_layout call to convert max_pool2d output to tilized Removed debug message that was cluttering logs * Added to_layout calls for functions that might have row major layout from ToTtPass * One more fix for ones, ones_like, zeros, and zeros_like that default to row major layout * Added more special cases to get beit working, handle int inputs * Added special case for GPT2 underflow issue
1 parent 3953290 commit 0a9c4c6

File tree

3 files changed

+146
-149
lines changed

3 files changed

+146
-149
lines changed

torch_ttnn/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def aten_backend(
146146
passes = [
147147
ConstantFoldingPass(),
148148
ToTtPass(option.device, option.use_less_ttnn_op_types),
149-
AddDataMovePass(),
149+
AddDataMovePass(option.device),
150150
EliminateCoreopsPass(),
151151
CSEPass(),
152152
PermuteReshapeTuple(),

0 commit comments

Comments
 (0)