diff --git a/slither/analyses/data_dependency/data_dependency.py b/slither/analyses/data_dependency/data_dependency.py index 12e809fa53..cea3082956 100644 --- a/slither/analyses/data_dependency/data_dependency.py +++ b/slither/analyses/data_dependency/data_dependency.py @@ -433,6 +433,68 @@ def add_dependency(lvalue: Variable, function: Function, ir: Operation, is_prote function.context[KEY_SSA_UNPROTECTED][lvalue].add(v) +def _init_unset_local_ir_vars(function: Function) -> dict: + unset_local_ir_vars = {} + for node in function.nodes: + for ir in node.irs_ssa: + for v in ir.used: + if isinstance(v, LocalIRVariable): + if v not in unset_local_ir_vars: + unset_local_ir_vars[v] = True + return unset_local_ir_vars + + +def _handle_operation_with_lvalue(function: Function, node, is_protected, unset_local_ir_vars): + for ir in node.irs_ssa: + if isinstance(ir, OperationWithLValue) and ir.lvalue: + if isinstance(ir.lvalue, LocalIRVariable) and ir.lvalue.is_storage: + continue + if isinstance(ir.lvalue, ReferenceVariable): + lvalue = ir.lvalue.points_to + if lvalue: + unset_local_ir_vars[lvalue] = False + add_dependency(lvalue, function, ir, is_protected) + unset_local_ir_vars[ir.lvalue] = False + add_dependency(ir.lvalue, function, ir, is_protected) + + +def _add_param_dependency_for_variable( + function: Function, v: LocalIRVariable, param_ssa, is_protected, unset_local_ir_vars +): + """Add parameter dependency for a given variable and parameter.""" + if v not in function.context[KEY_SSA]: + function.context[KEY_SSA][v] = set() + function.context[KEY_SSA][v].add(param_ssa) + + if not is_protected: + if v not in function.context[KEY_SSA_UNPROTECTED]: + function.context[KEY_SSA_UNPROTECTED][v] = set() + function.context[KEY_SSA_UNPROTECTED][v].add(param_ssa) + unset_local_ir_vars[param_ssa] = False + + +def _find_matching_parameter(function: Function, v: LocalIRVariable): + """Find the parameter that matches the given variable's non-SSA version.""" + for param_ssa in function.parameters_ssa: + if v.non_ssa_version == param_ssa.non_ssa_version: + return param_ssa + return None + + +def _add_param_dependency_if_needed(function: Function, node, is_protected, unset_local_ir_vars): + for ir in node.irs_ssa: + for v in ir.used: + if not (isinstance(v, LocalIRVariable) and unset_local_ir_vars.get(v)): + continue + + param_ssa = _find_matching_parameter(function, v) + if param_ssa: + _add_param_dependency_for_variable( + function, v, param_ssa, is_protected, unset_local_ir_vars + ) + break + + def compute_dependency_function(function: Function) -> None: if KEY_SSA in function.context: return @@ -441,16 +503,13 @@ def compute_dependency_function(function: Function) -> None: function.context[KEY_SSA_UNPROTECTED] = {} is_protected = function.is_protected() + unset_local_ir_vars = _init_unset_local_ir_vars(function) + for node in function.nodes: - for ir in node.irs_ssa: - if isinstance(ir, OperationWithLValue) and ir.lvalue: - if isinstance(ir.lvalue, LocalIRVariable) and ir.lvalue.is_storage: - continue - if isinstance(ir.lvalue, ReferenceVariable): - lvalue = ir.lvalue.points_to - if lvalue: - add_dependency(lvalue, function, ir, is_protected) - add_dependency(ir.lvalue, function, ir, is_protected) + _handle_operation_with_lvalue(function, node, is_protected, unset_local_ir_vars) + + for node in function.nodes: + _add_param_dependency_if_needed(function, node, is_protected, unset_local_ir_vars) function.context[KEY_NON_SSA] = convert_to_non_ssa(function.context[KEY_SSA]) function.context[KEY_NON_SSA_UNPROTECTED] = convert_to_non_ssa( diff --git a/tests/unit/analyses/__init__.py b/tests/unit/analyses/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/analyses/test_data/parameter_dependency_ssa.sol b/tests/unit/analyses/test_data/parameter_dependency_ssa.sol new file mode 100644 index 0000000000..e63dada3f6 --- /dev/null +++ b/tests/unit/analyses/test_data/parameter_dependency_ssa.sol @@ -0,0 +1,36 @@ +pragma solidity ^0.8.24; + +contract TSURUWrapper{ + bool private _opened = false; + address public immutable erc721Contract; + uint256 public constant _maxTotalSupply = 431_386_000 * 1e18; + uint256 private constant ERC721_RATIO = 400 * 1e18; + mapping(address owner => uint256) private _balancesOfOwner; + uint256 private _holders; + uint256 private _totalSupply; + function totalSupply() public view virtual returns (uint256) { + return _totalSupply; + } + function onERC721Received( + address, + address from, + uint256, + bytes calldata + ) external returns (bytes4) { + require(_opened, "Already yet open."); + require(msg.sender == address(erc721Contract), "Unauthorized token"); + _safeMint(from, ERC721_RATIO); // Adjust minting based on the ERC721_RATIO + return this.onERC721Received.selector; + } + + function _safeMint(address account, uint256 value) internal { + require(_maxTotalSupply > totalSupply() + value, "Max supply exceeded."); + + // _mint(account, value); + + if (_balancesOfOwner[account] == 0) { + ++_holders; + } + _balancesOfOwner[account] = _balancesOfOwner[account] + value; + } +} \ No newline at end of file diff --git a/tests/unit/analyses/test_data_dependencies.py b/tests/unit/analyses/test_data_dependencies.py new file mode 100644 index 0000000000..cfc7d5082f --- /dev/null +++ b/tests/unit/analyses/test_data_dependencies.py @@ -0,0 +1,43 @@ +from pathlib import Path + +from slither import Slither +from slither.analyses.data_dependency.data_dependency import is_dependent_ssa +from slither.slithir.variables import LocalIRVariable + +TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" + + +def test_param_dependency(solc_binary_path) -> None: + solc_path = solc_binary_path("0.8.24") + slither = Slither( + Path(TEST_DATA_DIR, "parameter_dependency_ssa.sol").as_posix(), solc=solc_path + ) + + target_function = slither.contracts[0].get_function_from_signature( + "onERC721Received(address,address,uint256,bytes)" + ) + + # Param is from (from_0 in SSA) + # Local SSA variable is from_1 + + param_var = None + + for param_ssa in target_function.parameters_ssa: + if param_ssa.non_ssa_version.name == "from": + param_var = param_ssa + break + + assert param_var is not None, "Param variable not found in SSA" + + local_var = None + for ir in target_function.slithir_ssa_operations: + for v in ir.used: + if isinstance(v, LocalIRVariable) and v.non_ssa_version.name == "from" and v.index == 1: + local_var = v + break + + assert local_var is not None, "Local variable not found in SSA" + + assert is_dependent_ssa( + local_var, param_var, target_function + ), "Param and local variable are not dependent"