Skip to content

Commit 37173cc

Browse files
authored
feat: add chain.get_balance() method (#2520)
1 parent 7e5e462 commit 37173cc

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

src/ape/managers/chain.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
APINotImplementedError,
2323
BlockNotFoundError,
2424
ChainError,
25+
ConversionError,
2526
ProviderNotConnectedError,
2627
QueryEngineError,
2728
TransactionNotFoundError,
@@ -937,7 +938,39 @@ def mine(
937938
self.pending_timestamp += deltatime
938939
self.provider.mine(num_blocks)
939940

940-
def set_balance(self, account: Union[BaseAddress, AddressType], amount: Union[int, str]):
941+
def get_balance(
942+
self, address: Union[BaseAddress, AddressType, str], block_id: Optional["BlockID"] = None
943+
) -> int:
944+
"""
945+
Get the balance of the given address. If ``ape-ens`` is installed,
946+
you can pass ENS names.
947+
948+
Args:
949+
address (BaseAddress, AddressType | str): An address, ENS, or account/contract.
950+
block_id (:class:`~ape.types.BlockID` | None): The block ID. Defaults to latest.
951+
952+
Returns:
953+
int: The account balance.
954+
"""
955+
if (isinstance(address, str) and not address.startswith("0x")) or not isinstance(
956+
address, str
957+
):
958+
try:
959+
address = self.conversion_manager.convert(address, AddressType)
960+
except ConversionError:
961+
# Try to get the balance anyway; maybe the provider can handle it.
962+
address = address
963+
964+
return self.provider.get_balance(address, block_id=block_id)
965+
966+
def set_balance(self, account: Union[BaseAddress, AddressType, str], amount: Union[int, str]):
967+
"""
968+
Set an account balance, only works on development chains.
969+
970+
Args:
971+
account (BaseAddress, AddressType | str): The account.
972+
amount (int | str): The new balance.
973+
"""
941974
if isinstance(account, BaseAddress):
942975
account = account.address
943976

tests/functional/test_chain.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ def test_set_pending_timestamp_failure(chain):
173173
)
174174

175175

176+
def test_get_balance(chain, owner):
177+
assert chain.get_balance(owner) == owner.balance
178+
assert chain.get_balance(owner.address) == owner.balance
179+
180+
176181
def test_set_balance(chain, owner):
177182
with pytest.raises(APINotImplementedError):
178183
chain.set_balance(owner, "1000 ETH")

0 commit comments

Comments
 (0)