Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 68 additions & 9 deletions slither/analyses/data_dependency/data_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,68 @@ def add_dependency(lvalue: Variable, function: Function, ir: Operation, is_prote
function.context[KEY_SSA_UNPROTECTED][lvalue].add(v)


def _init_unset_local_ir_vars(function: Function) -> dict:
unset_local_ir_vars = {}
for node in function.nodes:
for ir in node.irs_ssa:
for v in ir.used:
if isinstance(v, LocalIRVariable):
if v not in unset_local_ir_vars:
unset_local_ir_vars[v] = True
return unset_local_ir_vars


def _handle_operation_with_lvalue(function: Function, node, is_protected, unset_local_ir_vars):
for ir in node.irs_ssa:
if isinstance(ir, OperationWithLValue) and ir.lvalue:
if isinstance(ir.lvalue, LocalIRVariable) and ir.lvalue.is_storage:
continue
if isinstance(ir.lvalue, ReferenceVariable):
lvalue = ir.lvalue.points_to
if lvalue:
unset_local_ir_vars[lvalue] = False
add_dependency(lvalue, function, ir, is_protected)
unset_local_ir_vars[ir.lvalue] = False
add_dependency(ir.lvalue, function, ir, is_protected)


def _add_param_dependency_for_variable(
function: Function, v: LocalIRVariable, param_ssa, is_protected, unset_local_ir_vars
):
"""Add parameter dependency for a given variable and parameter."""
if v not in function.context[KEY_SSA]:
function.context[KEY_SSA][v] = set()
function.context[KEY_SSA][v].add(param_ssa)

if not is_protected:
if v not in function.context[KEY_SSA_UNPROTECTED]:
function.context[KEY_SSA_UNPROTECTED][v] = set()
function.context[KEY_SSA_UNPROTECTED][v].add(param_ssa)
unset_local_ir_vars[param_ssa] = False


def _find_matching_parameter(function: Function, v: LocalIRVariable):
"""Find the parameter that matches the given variable's non-SSA version."""
for param_ssa in function.parameters_ssa:
if v.non_ssa_version == param_ssa.non_ssa_version:
return param_ssa
return None


def _add_param_dependency_if_needed(function: Function, node, is_protected, unset_local_ir_vars):
for ir in node.irs_ssa:
for v in ir.used:
if not (isinstance(v, LocalIRVariable) and unset_local_ir_vars.get(v)):
continue

param_ssa = _find_matching_parameter(function, v)
if param_ssa:
_add_param_dependency_for_variable(
function, v, param_ssa, is_protected, unset_local_ir_vars
)
break


def compute_dependency_function(function: Function) -> None:
if KEY_SSA in function.context:
return
Expand All @@ -441,16 +503,13 @@ def compute_dependency_function(function: Function) -> None:
function.context[KEY_SSA_UNPROTECTED] = {}

is_protected = function.is_protected()
unset_local_ir_vars = _init_unset_local_ir_vars(function)

for node in function.nodes:
for ir in node.irs_ssa:
if isinstance(ir, OperationWithLValue) and ir.lvalue:
if isinstance(ir.lvalue, LocalIRVariable) and ir.lvalue.is_storage:
continue
if isinstance(ir.lvalue, ReferenceVariable):
lvalue = ir.lvalue.points_to
if lvalue:
add_dependency(lvalue, function, ir, is_protected)
add_dependency(ir.lvalue, function, ir, is_protected)
_handle_operation_with_lvalue(function, node, is_protected, unset_local_ir_vars)

for node in function.nodes:
_add_param_dependency_if_needed(function, node, is_protected, unset_local_ir_vars)

function.context[KEY_NON_SSA] = convert_to_non_ssa(function.context[KEY_SSA])
function.context[KEY_NON_SSA_UNPROTECTED] = convert_to_non_ssa(
Expand Down
Empty file.
36 changes: 36 additions & 0 deletions tests/unit/analyses/test_data/parameter_dependency_ssa.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
pragma solidity ^0.8.24;

contract TSURUWrapper{
bool private _opened = false;
address public immutable erc721Contract;
uint256 public constant _maxTotalSupply = 431_386_000 * 1e18;
uint256 private constant ERC721_RATIO = 400 * 1e18;
mapping(address owner => uint256) private _balancesOfOwner;
uint256 private _holders;
uint256 private _totalSupply;
function totalSupply() public view virtual returns (uint256) {
return _totalSupply;
}
function onERC721Received(
address,
address from,
uint256,
bytes calldata
) external returns (bytes4) {
require(_opened, "Already yet open.");
require(msg.sender == address(erc721Contract), "Unauthorized token");
_safeMint(from, ERC721_RATIO); // Adjust minting based on the ERC721_RATIO
return this.onERC721Received.selector;
}

function _safeMint(address account, uint256 value) internal {
require(_maxTotalSupply > totalSupply() + value, "Max supply exceeded.");

// _mint(account, value);

if (_balancesOfOwner[account] == 0) {
++_holders;
}
_balancesOfOwner[account] = _balancesOfOwner[account] + value;
}
}
43 changes: 43 additions & 0 deletions tests/unit/analyses/test_data_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from pathlib import Path

from slither import Slither
from slither.analyses.data_dependency.data_dependency import is_dependent_ssa
from slither.slithir.variables import LocalIRVariable

TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data"


def test_param_dependency(solc_binary_path) -> None:
solc_path = solc_binary_path("0.8.24")
slither = Slither(
Path(TEST_DATA_DIR, "parameter_dependency_ssa.sol").as_posix(), solc=solc_path
)

target_function = slither.contracts[0].get_function_from_signature(
"onERC721Received(address,address,uint256,bytes)"
)

# Param is from (from_0 in SSA)
# Local SSA variable is from_1

param_var = None

for param_ssa in target_function.parameters_ssa:
if param_ssa.non_ssa_version.name == "from":
param_var = param_ssa
break

assert param_var is not None, "Param variable not found in SSA"

local_var = None
for ir in target_function.slithir_ssa_operations:
for v in ir.used:
if isinstance(v, LocalIRVariable) and v.non_ssa_version.name == "from" and v.index == 1:
local_var = v
break

assert local_var is not None, "Local variable not found in SSA"

assert is_dependent_ssa(
local_var, param_var, target_function
), "Param and local variable are not dependent"