Skip to content

Commit 264657e

Browse files
committed
Add additional helper classes to handle Clauses and witness encoding/decoding for transitions
1 parent edb5afb commit 264657e

File tree

1 file changed

+154
-40
lines changed

1 file changed

+154
-40
lines changed

rps.py

+154-40
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,65 @@ def vch2bn(s: bytes) -> int:
8181
return -v_abs if is_negative else v_abs
8282

8383

84+
class Clause:
85+
def __init__(self, name: str, script: CScript):
86+
self.name = name
87+
self.script = script
88+
89+
def stack_elements_from_args(self, args: dict) -> List[bytes]:
90+
raise NotImplementedError
91+
92+
def args_from_stack_elements(self, elements: List[bytes]) -> dict:
93+
raise NotImplementedError
94+
95+
96+
StandardType = type[int] | type[bytes]
97+
98+
99+
# A StandardClause encodes simple scripts where the witness is exactly
100+
# a list of arguments, always in the same order, and each is either
101+
# an integer or a byte array.
102+
# Other types of generic treatable clauses could be defined (for example, a MiniscriptClause).
103+
class StandardClause(Clause):
104+
def __init__(self, name: str, script: CScript, arg_specs: list[tuple[str, StandardType]]):
105+
super().__init__(name, script)
106+
self.arg_specs = arg_specs
107+
108+
for _, arg_cls in self.arg_specs:
109+
if arg_cls not in [int, bytes]:
110+
raise ValueError(f"Unsupported type: {arg_cls.__name__}")
111+
112+
def stack_elements_from_args(self, args: dict) -> list[bytes]:
113+
result: list[bytes] = []
114+
for arg_name, arg_cls in self.arg_specs:
115+
if arg_name not in args:
116+
raise ValueError(f"Missing argument: {arg_name}")
117+
arg_value = args[arg_name]
118+
if type(arg_value) != arg_cls:
119+
raise ValueError(f"Argument {arg_name} must be of type {arg_cls.__name__}, not {type(arg_value).__name__}")
120+
if arg_cls == int:
121+
result.append(script.bn2vch(arg_value))
122+
elif arg_cls == bytes:
123+
result.append(arg_value)
124+
else:
125+
raise ValueError("Unexpected type") # this should never happen
126+
127+
return result
128+
129+
def args_from_stack_elements(self, elements: List[bytes]) -> dict:
130+
result: dict = {}
131+
if len(elements) != len(self.arg_specs):
132+
raise ValueError(f"Expected {len(self.arg_specs)} elements, not {len(elements)}")
133+
for i, (arg_name, arg_cls) in enumerate(self.arg_specs):
134+
if arg_cls == int:
135+
result[arg_name] = vch2bn(elements[i])
136+
elif arg_cls == bytes:
137+
result[arg_name] = elements[i]
138+
else:
139+
raise ValueError("Unexpected type") # this should never happen
140+
return result
141+
142+
84143
class P2TR:
85144
"""
86145
A class representing a Pay-to-Taproot script.
@@ -133,13 +192,81 @@ def get_tx_out(self, value: int, data: bytes) -> CTxOut:
133192
return CTxOut(nValue=value, scriptPubKey=self.get_tr_info(data).scriptPubKey)
134193

135194

195+
class StandardP2TR(P2TR):
196+
"""
197+
A StandardP2TR where all the transitions are given by a StandardClause.
198+
"""
199+
200+
def __init__(self, internal_pubkey: bytes, clauses: list[StandardClause]):
201+
super().__init__(internal_pubkey, list(map(lambda x: (x.name, x.script), clauses)))
202+
self.clauses = clauses
203+
self._clauses_dict = {clause.name: clause for clause in clauses}
204+
205+
def get_scripts(self) -> List[Tuple[str, CScript]]:
206+
return list(map(lambda clause: (clause.name, clause.script), self.clauses))
207+
208+
def encode_args(self, clause_name: str, **args: dict) -> list[bytes]:
209+
return [
210+
*self._clauses_dict[clause_name].stack_elements_from_args(args),
211+
self.get_tr_info().leaves[clause_name].script,
212+
self.get_tr_info().controlblock_for_script_spend(clause_name),
213+
]
214+
215+
def decode_wit_stack(self, stack_elems: list[bytes]) -> tuple[str, dict]:
216+
leaf_hash = stack_elems[-2]
217+
218+
clause_name = None
219+
for clause in self.clauses:
220+
if leaf_hash == self.get_tr_info().leaves[clause.name].script:
221+
clause_name = clause.name
222+
break
223+
if clause_name is None:
224+
raise ValueError("Clause not found")
225+
226+
return clause_name, self._clauses_dict[clause_name].args_from_stack_elements(stack_elems[:-2])
227+
228+
229+
class StandardAugmentedP2TR(AugmentedP2TR):
230+
"""
231+
An AugmentedP2TR where all the transitions are given by a StandardClause.
232+
"""
233+
234+
def __init__(self, naked_internal_pubkey: bytes, clauses: list[StandardClause]):
235+
super().__init__(naked_internal_pubkey)
236+
self.clauses = clauses
237+
self._clauses_dict = {clause.name: clause for clause in clauses}
238+
239+
def get_scripts(self) -> List[Tuple[str, CScript]]:
240+
return list(map(lambda clause: (clause.name, clause.script), self.clauses))
241+
242+
def encode_args(self, clause_name: str, data: bytes, **args: dict) -> list[bytes]:
243+
return [
244+
*self._clauses_dict[clause_name].stack_elements_from_args(args),
245+
self.get_tr_info(data).leaves[clause_name].script,
246+
self.get_tr_info(data).controlblock_for_script_spend(clause_name),
247+
]
248+
249+
def decode_wit_stack(self, data: bytes, stack_elems: list[bytes]) -> tuple[str, dict]:
250+
leaf_hash = stack_elems[-2]
251+
252+
clause_name = None
253+
for clause in self.clauses:
254+
if leaf_hash == self.get_tr_info(data).leaves[clause.name].script:
255+
clause_name = clause.name
256+
break
257+
if clause_name is None:
258+
raise ValueError("Clause not found")
259+
260+
return clause_name, self._clauses_dict[clause_name].args_from_stack_elements(stack_elems[:-2])
261+
262+
136263
# params:
137264
# - alice_pk
138265
# - bob_pk
139266
# - c_a
140267
# spending conditions:
141268
# - bob_pk (m_b) => RPSGameS0[m_b]
142-
class RPSGameS0(P2TR):
269+
class RPSGameS0(StandardP2TR):
143270
def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes):
144271
assert len(alice_pk) == 32 and len(bob_pk) == 32 and len(c_a) == 32
145272

@@ -148,7 +275,7 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes):
148275
self.c_a = c_a
149276

150277
# witness: <m_b> <bob_sig>
151-
bob_move = (
278+
bob_move = StandardClause(
152279
"bob_move",
153280
CScript([
154281
bob_pk,
@@ -165,7 +292,10 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes):
165292
RPSGameS1(alice_pk, bob_pk, c_a).get_taptree(),
166293
0, # flags
167294
OP_CHECKCONTRACTVERIFY,
168-
])
295+
]), [
296+
('m_b', int),
297+
('bob_sig', bytes),
298+
]
169299
)
170300

171301
super().__init__(NUMS_KEY, [bob_move])
@@ -181,15 +311,12 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes):
181311
# - alice_pk, reveal winning move => ctv(alice wins)
182312
# - alice_pk, reveal losing move => ctv(bob wins)
183313
# - alice_pk, reveal tie move => ctv(tie)
184-
class RPSGameS1(AugmentedP2TR):
314+
class RPSGameS1(StandardAugmentedP2TR):
185315
def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes):
186316
self.alice_pk = alice_pk
187317
self.bob_pk = bob_pk
188318
self.c_a = c_a
189319

190-
super().__init__(NUMS_KEY)
191-
192-
def get_scripts(self):
193320
def make_script(outcome: int, ctv_hash: bytes):
194321
assert 0 <= outcome <= 2
195322
# witness: [<m_b> <m_a> <r_a>]
@@ -260,11 +387,16 @@ def make_ctv_hash(alice_amount, bob_amount) -> bytes:
260387
ctvhash_bob_wins = make_ctv_hash(0, 2*STAKE)
261388
ctvhash_tie = make_ctv_hash(STAKE, STAKE)
262389

263-
alice_wins = ("tie", make_script(0, ctvhash_tie))
264-
bob_wins = ("bob_wins", make_script(1, ctvhash_bob_wins))
265-
tie = ("alice_wins", make_script(2, ctvhash_alice_wins))
390+
arg_specs = [
391+
('m_b', int),
392+
('m_a', int),
393+
('r_a', bytes),
394+
]
395+
alice_wins = StandardClause("tie", make_script(0, ctvhash_tie), arg_specs)
396+
bob_wins = StandardClause("bob_wins", make_script(1, ctvhash_bob_wins), arg_specs)
397+
tie = StandardClause("alice_wins", make_script(2, ctvhash_alice_wins), arg_specs)
266398

267-
return [alice_wins, bob_wins, tie]
399+
super().__init__(NUMS_KEY, [alice_wins, bob_wins, tie])
268400

269401

270402
load_dotenv()
@@ -379,20 +511,11 @@ def start_session(self, m_a: int):
379511

380512
assert len(in_wit.scriptWitness.stack) == 4
381513

382-
[m_b_bytes, bob_sig, leaf_script, control_block] = in_wit.scriptWitness.stack
383-
384-
assert leaf_script == contract_S0.get_tr_info().leaves["bob_move"].script
385-
assert control_block == contract_S0.get_tr_info().controlblock_for_script_spend("bob_move")
386-
387-
assert 0 <= len(m_b_bytes) <= 1
388-
389-
m_b = vch2bn(m_b_bytes)
390-
m_b_hash = sha256(m_b_bytes)
391-
514+
_, args = contract_S0.decode_wit_stack(in_wit.scriptWitness.stack)
515+
m_b: int = args['m_b']
392516
assert 0 <= m_b <= 2
393517

394-
print(m_b_bytes.hex())
395-
print(f"Bob's move: {m_b} ({RPS.move_str(m_b)}). Hash: {m_b_hash.hex()}")
518+
print(f"Bob's move: {m_b} ({RPS.move_str(m_b)}).")
396519

397520
outcome = RPS.adjudicate(m_a, m_b)
398521
print(f"Game result: {outcome}")
@@ -431,14 +554,10 @@ def start_session(self, m_a: int):
431554
)
432555
]
433556

557+
m_b_hash = sha256(script.bn2vch(m_b))
558+
434559
tx_payout.wit.vtxinwit = [CTxInWitness()]
435-
tx_payout.wit.vtxinwit[0].scriptWitness.stack = [
436-
script.bn2vch(m_b),
437-
script.bn2vch(m_a),
438-
r_a,
439-
RPS_S1.get_tr_info(m_b_hash).leaves[outcome].script,
440-
RPS_S1.get_tr_info(m_b_hash).controlblock_for_script_spend(outcome),
441-
]
560+
tx_payout.wit.vtxinwit[0].scriptWitness.stack = RPS_S1.encode_args(outcome, m_b_hash, m_b=m_b, m_a=m_a, r_a=r_a)
442561

443562
self.prompt("Broadcasting adjudication transaction")
444563
txid = rpc.sendrawtransaction(tx_payout.serialize().hex())
@@ -518,12 +637,8 @@ def join_session(self, m_b: int):
518637
bob_sig = key.sign_schnorr(self.priv_key.privkey, sighash)
519638

520639
tx.wit.vtxinwit = [CTxInWitness()]
521-
tx.wit.vtxinwit[0].scriptWitness.stack = [
522-
script.bn2vch(m_b),
523-
bob_sig,
524-
contract_S0.get_tr_info().leaves["bob_move"].script,
525-
contract_S0.get_tr_info().controlblock_for_script_spend("bob_move"),
526-
]
640+
641+
tx.wit.vtxinwit[0].scriptWitness.stack = contract_S0.encode_args('bob_move', m_b=m_b, bob_sig=bob_sig)
527642

528643
txid = tx.rehash()
529644

@@ -544,10 +659,9 @@ def join_session(self, m_b: int):
544659
tx, vin, last_height = wait_for_spending_tx(rpc, contract_S1_outpoint, starting_height=last_height)
545660
in_wit: CTxInWitness = tx.wit.vtxinwit[vin]
546661

547-
leaf_script, control_block = in_wit.scriptWitness.stack[-2:]
548-
for transition_name in map(lambda x: x[0], S1.get_scripts()):
549-
if leaf_script == S1.get_tr_info(m_b_hash).leaves[transition_name].script and control_block == S1.get_tr_info(m_b_hash).controlblock_for_script_spend(transition_name):
550-
print(f"Outcome: {transition_name}")
662+
outcome, _ = S1.decode_wit_stack(m_b_hash, in_wit.scriptWitness.stack)
663+
664+
print(f"Outcome: {outcome}")
551665

552666
s.close()
553667

0 commit comments

Comments
 (0)