Value mismatch in static_pir when mutating base tensor after creating a flatten view #79127

@rookieLiu2018

Description

Describe the Bug

Under the new static_pir infrastructure (paddle.jit.to_static(..., full_graph=True)), the compiler fails to track alias dependencies when an in-place mutation is applied to a base tensor after a view of that tensor (e.g., via paddle.flatten) has been created.

In Eager mode, mutating x also updates y, because y shares x's underlying memory. In static_pir mode, however, an optimization or functionalization pass appears to apply the mutation only to x without propagating it to the alias y. As a result, y keeps its stale, pre-mutation values. This is silent value corruption (a value mismatch): no compilation error or warning is raised.
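The difference between the two modes matches the familiar view-vs-copy distinction. The sketch below uses NumPy purely as a stand-in to illustrate the two semantics on the same data as the repro; it makes no claim about how Paddle implements views internally. Note the naming trap: NumPy's `ndarray.flatten()` always returns a copy, so `np.ravel` is used here for the view case (it returns a view for a contiguous array).

```python
import numpy as np

# Same contents as the Paddle repro: [[0, 1, 2], [3, 4, 5]]
x = np.arange(6, dtype=np.float32).reshape(2, 3)

# Alias semantics (what Eager mode shows): ravel returns a view of a contiguous array
y_view = np.ravel(x)
# Snapshot semantics (what static_pir appears to produce): ndarray.flatten always copies
y_copy = x.flatten()

# In-place mutation on the base array after both results were created
x[0, 0] = 99.0
x[1, 1] = -3.0

print("view:", y_view.tolist())  # reflects the mutation through shared memory
print("copy:", y_copy.tolist())  # keeps the pre-mutation snapshot
```

The "view" line matches the Eager output reported below, and the "copy" line matches the Static output, which is consistent with the static graph treating y as an independent copy rather than an alias of x.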

import paddle

paddle.disable_static()
paddle.set_device("cpu")

def kernel(x):
    y = paddle.flatten(x)
    
    # In-place mutation on the base tensor
    x[0, 0] = 99.0
    x[1, 1] = -3.0
    
    return y

x_eager = paddle.arange(6, dtype="float32").reshape([2, 3])
x_static = x_eager.clone()

# 1. Run Eager Mode
out_eager = kernel(x_eager)

# 2. Run Static PIR Mode
static_fn = paddle.jit.to_static(kernel, full_graph=True)
out_static = static_fn(x_static)

print(f"Eager Output:  {out_eager.numpy().tolist()}")
print(f"Static Output: {out_static.numpy().tolist()}")

Output:

Eager Output:  [99.0, 1.0, 2.0, 3.0, -3.0, 5.0]
Static Output: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]

Additional Supplementary Information

Paddle version: 3.3.0
