Skip to content

Commit

Permalink
feat: add chain.get_balance() method (#2520)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Feb 19, 2025
1 parent 7e5e462 commit 37173cc
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
35 changes: 34 additions & 1 deletion src/ape/managers/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
APINotImplementedError,
BlockNotFoundError,
ChainError,
ConversionError,
ProviderNotConnectedError,
QueryEngineError,
TransactionNotFoundError,
Expand Down Expand Up @@ -937,7 +938,39 @@ def mine(
self.pending_timestamp += deltatime
self.provider.mine(num_blocks)

def set_balance(self, account: Union[BaseAddress, AddressType], amount: Union[int, str]):
def get_balance(
self, address: Union[BaseAddress, AddressType, str], block_id: Optional["BlockID"] = None
) -> int:
"""
Get the balance of the given address. If ``ape-ens`` is installed,
you can pass ENS names.
Args:
address (BaseAddress, AddressType | str): An address, ENS, or account/contract.
block_id (:class:`~ape.types.BlockID` | None): The block ID. Defaults to latest.
Returns:
int: The account balance.
"""
if (isinstance(address, str) and not address.startswith("0x")) or not isinstance(
address, str
):
try:
address = self.conversion_manager.convert(address, AddressType)
except ConversionError:
# Try to get the balance anyway; maybe the provider can handle it.
address = address

return self.provider.get_balance(address, block_id=block_id)

def set_balance(self, account: Union[BaseAddress, AddressType, str], amount: Union[int, str]):
"""
Set an account balance, only works on development chains.
Args:
account (BaseAddress, AddressType | str): The account.
amount (int | str): The new balance.
"""
if isinstance(account, BaseAddress):
account = account.address

Expand Down
5 changes: 5 additions & 0 deletions tests/functional/test_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ def test_set_pending_timestamp_failure(chain):
)


def test_get_balance(chain, owner):
assert chain.get_balance(owner) == owner.balance
assert chain.get_balance(owner.address) == owner.balance


def test_set_balance(chain, owner):
with pytest.raises(APINotImplementedError):
chain.set_balance(owner, "1000 ETH")

0 comments on commit 37173cc

Please sign in to comment.