Skip to content

Commit 323206d

Browse files
committed
Added various convenience method to Contract Manager; improved the type definitions for the StandardContracts; added a SchnorrSigner type to handle signatures; created an initial test suite for rps, vaults and ram examples
1 parent db2b672 commit 323206d

14 files changed

+510
-66
lines changed

examples/rps/rps.py

+1-32
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import argparse
4141
import socket
4242
import json
43-
import hashlib
4443
import random
4544
import os
4645

@@ -54,7 +53,7 @@
5453
from matt.environment import Environment
5554
from matt import ContractInstance, ContractManager
5655

57-
from rps_contracts import DEFAULT_STAKE, RPSGameS0
56+
from rps_contracts import DEFAULT_STAKE, RPS, RPSGameS0
5857

5958

6059
load_dotenv()
@@ -65,36 +64,6 @@
6564
rpc_port = os.getenv("RPC_PORT", 18443)
6665

6766

68-
class RPS:
69-
@staticmethod
70-
def move_str(move: int) -> str:
71-
assert 0 <= move <= 2
72-
if move == 0:
73-
return "rock"
74-
elif move == 1:
75-
return "paper"
76-
else:
77-
return "scissors"
78-
79-
@staticmethod
80-
def adjudicate(move_alice, move_bob):
81-
assert 0 <= move_alice <= 2 and 0 <= move_bob <= 2
82-
if move_bob == move_alice:
83-
return "tie"
84-
elif (move_bob - move_alice) % 3 == 2:
85-
return "alice_wins"
86-
else:
87-
return "bob_wins"
88-
89-
@staticmethod
90-
def calculate_hash(move: int, r: bytes) -> bytes:
91-
assert 0 <= move <= 2 and len(r) == 32
92-
93-
m = hashlib.sha256()
94-
m.update(script.bn2vch(move) + r)
95-
return m.digest()
96-
97-
9867
class AliceGame:
9968
def __init__(self, env: Environment, args: dict):
10069
self.env = env

examples/rps/rps_contracts.py

+46-26
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,45 @@
1-
from matt.argtypes import BytesType, IntType
2-
from matt.btctools.messages import CTransaction, CTxIn, CTxOut, sha256
1+
import hashlib
2+
3+
from matt.argtypes import BytesType, IntType, SignerType
4+
from matt.btctools.messages import sha256
5+
from matt.btctools import script
36
from matt.btctools.script import OP_ADD, OP_CAT, OP_CHECKCONTRACTVERIFY, OP_CHECKSIG, OP_CHECKTEMPLATEVERIFY, OP_DUP, OP_ENDIF, OP_EQUALVERIFY, OP_FROMALTSTACK, OP_IF, OP_LESSTHAN, OP_OVER, OP_SHA256, OP_SUB, OP_SWAP, OP_TOALTSTACK, OP_VERIFY, OP_WITHIN, CScript, bn2vch
47
from matt import CCV_FLAG_CHECK_INPUT, NUMS_KEY, P2TR, ClauseOutput, StandardClause, StandardP2TR, StandardAugmentedP2TR
8+
from matt.utils import make_ctv_template
59

610
DEFAULT_STAKE: int = 1000 # amount of sats that the players bet
711

812

13+
class RPS:
14+
@staticmethod
15+
def move_str(move: int) -> str:
16+
assert 0 <= move <= 2
17+
if move == 0:
18+
return "rock"
19+
elif move == 1:
20+
return "paper"
21+
else:
22+
return "scissors"
23+
24+
@staticmethod
25+
def adjudicate(move_alice, move_bob):
26+
assert 0 <= move_alice <= 2 and 0 <= move_bob <= 2
27+
if move_bob == move_alice:
28+
return "tie"
29+
elif (move_bob - move_alice) % 3 == 2:
30+
return "alice_wins"
31+
else:
32+
return "bob_wins"
33+
34+
@staticmethod
35+
def calculate_hash(move: int, r: bytes) -> bytes:
36+
assert 0 <= move <= 2 and len(r) == 32
37+
38+
m = hashlib.sha256()
39+
m.update(script.bn2vch(move) + r)
40+
return m.digest()
41+
42+
943
# params:
1044
# - alice_pk
1145
# - bob_pk
@@ -44,7 +78,7 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes, stake: int = DEFA
4478
]),
4579
arg_specs=[
4680
('m_b', IntType()),
47-
('bob_sig', BytesType()),
81+
('bob_sig', SignerType(bob_pk)),
4882
],
4983
next_output_fn=lambda args: [ClauseOutput(n=0, next_contract=S1, next_data=sha256(bn2vch(args['m_b'])))]
5084
)
@@ -116,29 +150,15 @@ def make_script(diff: int, ctv_hash: bytes):
116150
OP_CHECKTEMPLATEVERIFY
117151
])
118152

119-
def make_ctv_hash(alice_amount, bob_amount) -> CTransaction:
120-
tmpl = CTransaction()
121-
tmpl.nVersion = 2
122-
tmpl.vin = [CTxIn(nSequence=0)]
123-
if alice_amount > 0:
124-
tmpl.vout.append(
125-
CTxOut(
126-
nValue=alice_amount,
127-
scriptPubKey=P2TR(self.alice_pk, []).get_tr_info().scriptPubKey
128-
)
129-
)
130-
if bob_amount > 0:
131-
tmpl.vout.append(
132-
CTxOut(
133-
nValue=bob_amount,
134-
scriptPubKey=P2TR(self.bob_pk, []).get_tr_info().scriptPubKey
135-
)
136-
)
137-
return tmpl
138-
139-
tmpl_alice_wins = make_ctv_hash(2*self.stake, 0)
140-
tmpl_bob_wins = make_ctv_hash(0, 2*self.stake)
141-
tmpl_tie = make_ctv_hash(self.stake, self.stake)
153+
alice_spk = P2TR(self.alice_pk, []).get_tr_info().scriptPubKey
154+
bob_spk = P2TR(self.bob_pk, []).get_tr_info().scriptPubKey
155+
156+
tmpl_alice_wins = make_ctv_template([(alice_spk, 2*self.stake)])
157+
tmpl_bob_wins = make_ctv_template([(bob_spk, 2*self.stake)])
158+
tmpl_tie = make_ctv_template([
159+
(alice_spk, self.stake),
160+
(bob_spk, self.stake),
161+
])
142162

143163
arg_specs = [
144164
('m_b', IntType()),

examples/vault/vault_contracts.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from matt.argtypes import BytesType, IntType
1+
from matt.argtypes import BytesType, IntType, SignerType
22
from matt.btctools.script import OP_CHECKCONTRACTVERIFY, OP_CHECKSEQUENCEVERIFY, OP_CHECKSIG, OP_CHECKTEMPLATEVERIFY, OP_DROP, OP_DUP, OP_SWAP, OP_TRUE, CScript
33
from matt import CCV_FLAG_CHECK_INPUT, CCV_FLAG_DEDUCT_OUTPUT_AMOUNT, NUMS_KEY, ClauseOutput, ClauseOutputAmountBehaviour, OpaqueP2TR, StandardClause, StandardP2TR, StandardAugmentedP2TR
44

@@ -27,7 +27,7 @@ def __init__(self, alternate_pk: bytes | None, spend_delay: int, recover_pk: byt
2727
OP_CHECKSIG
2828
]),
2929
arg_specs=[
30-
('sig', BytesType()),
30+
('sig', SignerType(unvault_pk)),
3131
('ctv_hash', BytesType()),
3232
('out_i', IntType()),
3333
],
@@ -55,7 +55,7 @@ def __init__(self, alternate_pk: bytes | None, spend_delay: int, recover_pk: byt
5555
OP_CHECKSIG
5656
]),
5757
arg_specs=[
58-
('sig', BytesType()),
58+
('sig', SignerType(unvault_pk)),
5959
('ctv_hash', BytesType()),
6060
('out_i', IntType()),
6161
('revault_out_i', IntType()),
@@ -117,7 +117,7 @@ def __init__(self, alternate_pk: bytes | None, spend_delay: int, recover_pk: byt
117117
OP_CHECKTEMPLATEVERIFY
118118
]),
119119
arg_specs=[
120-
('ctv_hash', bytes)
120+
('ctv_hash', BytesType())
121121
]
122122
)
123123

matt/__init__.py

+98-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from dataclasses import dataclass
22
from enum import Enum
33
from io import BytesIO
4-
from typing import Callable
4+
from typing import Callable, Optional
55

6-
from .argtypes import ArgType
6+
from .argtypes import ArgType, SignerType
77
from .btctools import script, key
88
from .btctools.auth_proxy import AuthServiceProxy
99
from .btctools.messages import COutPoint, CTransaction, CTxIn, CTxInWitness
@@ -245,6 +245,36 @@ def __repr__(self):
245245
return f"{self.__class__.__name__}(naked_internal_pubkey={self.naked_internal_pubkey.hex()})"
246246

247247

248+
# Encapsulates a blind signer for one or more known keys.
249+
# Used byt the ContractManager to sign for the clause arguments of SignerType type.
250+
#
251+
# In the real world, we wouldn't blindly sign a hash, so the `sign` method
252+
# would include other info to help the signer decide (e.g.: the transaction)
253+
# There are no bad people here, though, so we keep it simple for now.
254+
class SchnorrSigner:
255+
def __init__(self, keys: key.ExtendedKey | list[key.ExtendedKey]):
256+
if not isinstance(keys, list):
257+
keys = [keys]
258+
259+
for key in keys:
260+
if not key.is_private:
261+
raise ValueError("The SchnorrSigner needs the private keys")
262+
263+
self.keys = keys
264+
265+
def sign(self, msg: bytes, pubkey: bytes) -> bytes | None:
266+
if len(msg) != 32:
267+
raise ValueError("msg should be 32 bytes long")
268+
if len(pubkey) != 32:
269+
raise ValueError("pubkey should be an x-only pubkey")
270+
271+
for k in self.keys:
272+
if k.pubkey[1:] == pubkey:
273+
return key.sign_schnorr(k.privkey, msg)
274+
275+
return None
276+
277+
248278
class ContractInstanceStatus(Enum):
249279
ABSTRACT = 0
250280
FUNDED = 1
@@ -258,6 +288,8 @@ def __init__(self, contract: StandardP2TR | StandardAugmentedP2TR):
258288

259289
self.data_expanded = None # TODO: figure out a good API for this
260290

291+
self.manager: ContractManager = None
292+
261293
self.last_height = 0
262294

263295
self.status = ContractInstanceStatus.ABSTRACT
@@ -303,6 +335,15 @@ def __repr__(self):
303335
value = self.funding_tx.vout[self.outpoint.n].nValue
304336
return f"{self.__class__.__name__}(contract={self.contract}, data={self.data if self.data is None else self.data.hex()}, value={value}, status={self.status}, outpoint={self.outpoint})"
305337

338+
def __call__(self, clause_name: str, *, signer: Optional[SchnorrSigner] = None, outputs: list[CTxOut] = [], **kwargs) -> list['ContractInstance']:
339+
if self.manager is None:
340+
raise ValueError("Direct invocation is only allowed after adding the instance to a ContractManager")
341+
342+
if self.status != ContractInstanceStatus.FUNDED:
343+
raise ValueError("Only implemented for FUNDED instances")
344+
345+
return self.manager.spend_instance(self, clause_name, kwargs, signer=signer, outputs=outputs)
346+
306347

307348
class ContractManager:
308349
def __init__(self, contract_instances: list[ContractInstance], rpc: AuthServiceProxy, *, poll_interval: float = 1, mine_automatically: bool = False):
@@ -324,6 +365,10 @@ def _check_instance(self, instance: ContractInstance, exp_statuses: None | Contr
324365
raise ValueError("Instance not in this manager")
325366

326367
def add_instance(self, instance: ContractInstance):
368+
if instance.manager is not None:
369+
raise ValueError("The instance can only be added to one ContractManager")
370+
371+
instance.manager = self
327372
self.instances.append(instance)
328373

329374
def wait_for_outpoint(self, instance: ContractInstance, txid: str | None = None):
@@ -544,3 +589,54 @@ def wait_for_spend(self, instances: ContractInstance | list[ContractInstance]) -
544589
for instance in result:
545590
self.add_instance(instance)
546591
return result
592+
593+
def fund_instance(self, contract: StandardP2TR | StandardAugmentedP2TR, amount: int, data: Optional[bytes] = None) -> ContractInstance:
594+
"""
595+
Convenience method to create an instance of a contract, add it to the ContractManager,
596+
and send a transaction to fund it with a certain amount.
597+
"""
598+
instance = ContractInstance(contract)
599+
600+
if isinstance(contract, StandardP2TR) and data is not None:
601+
raise ValueError("The data must None for a contract with no embedded data")
602+
603+
if isinstance(contract, StandardAugmentedP2TR):
604+
if data is None:
605+
raise ValueError("The data must be provided for an augmented P2TR contract instance")
606+
instance.data = data
607+
self.add_instance(instance)
608+
txid = self.rpc.sendtoaddress(instance.get_address(), amount/100_000_000)
609+
self.wait_for_outpoint(instance, txid)
610+
return instance
611+
612+
def spend_instance(self, instance: ContractInstance, clause_name: str, args: dict, *, signer: Optional[SchnorrSigner], outputs: Optional[list[CTxOut]] = None) -> list[ContractInstance]:
613+
"""
614+
Creates and broadcasts a transaction that spends a contract instance using a specified clause and arguments.
615+
616+
:param instance: The ContractInstance to spend from.
617+
:param clause_name: The name of the clause to be executed in the contract.
618+
:param args: A dictionary of arguments required for the clause.
619+
:param outputs: if not None, a list of CTxOut to add at the end of the list of
620+
outputs generated by the clause.
621+
:return: A list of ContractInstances resulting from the spend transaction.
622+
"""
623+
spend_tx, sighashes = self.get_spend_tx((instance, clause_name, args))
624+
625+
assert len(sighashes) == 1
626+
627+
sighash = sighashes[0]
628+
629+
if outputs is not None:
630+
spend_tx.vout.extend(outputs)
631+
632+
clause = instance.contract._clauses_dict[clause_name] # TODO: refactor, accessing private member
633+
for arg_name, arg_type in clause.arg_specs:
634+
if isinstance(arg_type, SignerType):
635+
if signer is None:
636+
raise ValueError("No signer was provided, but the witness requires signatures")
637+
args[arg_name] = signer.sign(sighash, arg_type.pubkey)
638+
639+
spend_tx.wit.vtxinwit = [self.get_spend_wit(instance, clause_name, args)]
640+
result = self.spend_and_wait(instance, spend_tx)
641+
642+
return result

matt/argtypes.py

+12
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,18 @@ def deserialize_from_wit(self, wit_stack: list[bytes]) -> Tuple[int, bytes]:
6565
return 1, wit_stack[0]
6666

6767

68+
class SignerType(BytesType):
69+
"""
70+
This is a special for arguments that represent signatures in tapscripts.
71+
It is encoded as bytes, but labeling it allows the ContractManager to get the correct
72+
signatures by calling SchnorrSigner object instances.
73+
"""
74+
def __init__(self, pubkey: bytes):
75+
if len(pubkey) != 32:
76+
raise ValueError("pubkey must be an x-only pubkey")
77+
self.pubkey = pubkey
78+
79+
6880
class MerkleProofType(ArgType):
6981
def __init__(self, depth: int):
7082
self.depth = depth

matt/utils.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import time
44

55
from .btctools.auth_proxy import AuthServiceProxy, JSONRPCException
6-
from .btctools.messages import COutPoint, CTransaction
6+
from .btctools.messages import COutPoint, CTransaction, CTxIn, CTxOut
77
from .btctools.script import CScript, CScriptNum, bn2vch
8+
from .btctools.segwit_addr import decode_segwit_address
89

910

1011
def vch2bn(s: bytes) -> int:
@@ -155,4 +156,34 @@ def print_tx(tx: CTransaction, title: str):
155156
156157
</details>
157158
158-
''')
159+
''')
160+
161+
162+
def addr_to_script(addr: str) -> bytes:
163+
# only for segwit/taproot on regtest
164+
# TODO: generalize to other address types, and other networks (currently, it assumes regtest)
165+
166+
wit_ver, wit_prog = decode_segwit_address("bcrt", addr)
167+
168+
if wit_ver is None or wit_prog is None:
169+
raise ValueError(f"Invalid segwit address (or wrong network): {addr}")
170+
171+
return bytes([
172+
wit_ver + (0x50 if wit_ver > 0 else 0),
173+
len(wit_prog),
174+
*wit_prog
175+
])
176+
177+
178+
def make_ctv_template(outputs: list[(bytes|str, int)], *, nVersion: int = 2, nSequence: int = 0) -> CTransaction:
179+
tmpl = CTransaction()
180+
tmpl.nVersion = nVersion
181+
tmpl.vin = [CTxIn(nSequence=nSequence)]
182+
for dest, amount in outputs:
183+
tmpl.vout.append(
184+
CTxOut(
185+
nValue=amount,
186+
scriptPubKey=dest if isinstance(dest, bytes) else addr_to_script(dest)
187+
)
188+
)
189+
return tmpl

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ requires-python = ">=3.7"
1414
keywords = ["covenant", "smart contracts", "bitcoin"]
1515
license = { file = "LICENSE" }
1616
dependencies = []
17+
18+
[tool.poetry.dev-dependencies]
19+
pytest = "^6.2.5"

0 commit comments

Comments
 (0)