@@ -81,6 +81,65 @@ def vch2bn(s: bytes) -> int:
81
81
return - v_abs if is_negative else v_abs
82
82
83
83
84
+ class Clause :
85
+ def __init__ (self , name : str , script : CScript ):
86
+ self .name = name
87
+ self .script = script
88
+
89
+ def stack_elements_from_args (self , args : dict ) -> List [bytes ]:
90
+ raise NotImplementedError
91
+
92
+ def args_from_stack_elements (self , elements : List [bytes ]) -> dict :
93
+ raise NotImplementedError
94
+
95
+
96
+ StandardType = type [int ] | type [bytes ]
97
+
98
+
99
+ # A StandardClause encodes simple scripts where the witness is exactly
100
+ # a list of arguments, always in the same order, and each is either
101
+ # an integer or a byte array.
102
+ # Other types of generic treatable clauses could be defined (for example, a MiniscriptClause).
103
+ class StandardClause (Clause ):
104
+ def __init__ (self , name : str , script : CScript , arg_specs : list [tuple [str , StandardType ]]):
105
+ super ().__init__ (name , script )
106
+ self .arg_specs = arg_specs
107
+
108
+ for _ , arg_cls in self .arg_specs :
109
+ if arg_cls not in [int , bytes ]:
110
+ raise ValueError (f"Unsupported type: { arg_cls .__name__ } " )
111
+
112
+ def stack_elements_from_args (self , args : dict ) -> list [bytes ]:
113
+ result : list [bytes ] = []
114
+ for arg_name , arg_cls in self .arg_specs :
115
+ if arg_name not in args :
116
+ raise ValueError (f"Missing argument: { arg_name } " )
117
+ arg_value = args [arg_name ]
118
+ if type (arg_value ) != arg_cls :
119
+ raise ValueError (f"Argument { arg_name } must be of type { arg_cls .__name__ } , not { type (arg_value ).__name__ } " )
120
+ if arg_cls == int :
121
+ result .append (script .bn2vch (arg_value ))
122
+ elif arg_cls == bytes :
123
+ result .append (arg_value )
124
+ else :
125
+ raise ValueError ("Unexpected type" ) # this should never happen
126
+
127
+ return result
128
+
129
+ def args_from_stack_elements (self , elements : List [bytes ]) -> dict :
130
+ result : dict = {}
131
+ if len (elements ) != len (self .arg_specs ):
132
+ raise ValueError (f"Expected { len (self .arg_specs )} elements, not { len (elements )} " )
133
+ for i , (arg_name , arg_cls ) in enumerate (self .arg_specs ):
134
+ if arg_cls == int :
135
+ result [arg_name ] = vch2bn (elements [i ])
136
+ elif arg_cls == bytes :
137
+ result [arg_name ] = elements [i ]
138
+ else :
139
+ raise ValueError ("Unexpected type" ) # this should never happen
140
+ return result
141
+
142
+
84
143
class P2TR :
85
144
"""
86
145
A class representing a Pay-to-Taproot script.
@@ -133,13 +192,81 @@ def get_tx_out(self, value: int, data: bytes) -> CTxOut:
133
192
return CTxOut (nValue = value , scriptPubKey = self .get_tr_info (data ).scriptPubKey )
134
193
135
194
195
+ class StandardP2TR (P2TR ):
196
+ """
197
+ A StandardP2TR where all the transitions are given by a StandardClause.
198
+ """
199
+
200
+ def __init__ (self , internal_pubkey : bytes , clauses : list [StandardClause ]):
201
+ super ().__init__ (internal_pubkey , list (map (lambda x : (x .name , x .script ), clauses )))
202
+ self .clauses = clauses
203
+ self ._clauses_dict = {clause .name : clause for clause in clauses }
204
+
205
+ def get_scripts (self ) -> List [Tuple [str , CScript ]]:
206
+ return list (map (lambda clause : (clause .name , clause .script ), self .clauses ))
207
+
208
+ def encode_args (self , clause_name : str , ** args : dict ) -> list [bytes ]:
209
+ return [
210
+ * self ._clauses_dict [clause_name ].stack_elements_from_args (args ),
211
+ self .get_tr_info ().leaves [clause_name ].script ,
212
+ self .get_tr_info ().controlblock_for_script_spend (clause_name ),
213
+ ]
214
+
215
+ def decode_wit_stack (self , stack_elems : list [bytes ]) -> tuple [str , dict ]:
216
+ leaf_hash = stack_elems [- 2 ]
217
+
218
+ clause_name = None
219
+ for clause in self .clauses :
220
+ if leaf_hash == self .get_tr_info ().leaves [clause .name ].script :
221
+ clause_name = clause .name
222
+ break
223
+ if clause_name is None :
224
+ raise ValueError ("Clause not found" )
225
+
226
+ return clause_name , self ._clauses_dict [clause_name ].args_from_stack_elements (stack_elems [:- 2 ])
227
+
228
+
229
+ class StandardAugmentedP2TR (AugmentedP2TR ):
230
+ """
231
+ An AugmentedP2TR where all the transitions are given by a StandardClause.
232
+ """
233
+
234
+ def __init__ (self , naked_internal_pubkey : bytes , clauses : list [StandardClause ]):
235
+ super ().__init__ (naked_internal_pubkey )
236
+ self .clauses = clauses
237
+ self ._clauses_dict = {clause .name : clause for clause in clauses }
238
+
239
+ def get_scripts (self ) -> List [Tuple [str , CScript ]]:
240
+ return list (map (lambda clause : (clause .name , clause .script ), self .clauses ))
241
+
242
+ def encode_args (self , clause_name : str , data : bytes , ** args : dict ) -> list [bytes ]:
243
+ return [
244
+ * self ._clauses_dict [clause_name ].stack_elements_from_args (args ),
245
+ self .get_tr_info (data ).leaves [clause_name ].script ,
246
+ self .get_tr_info (data ).controlblock_for_script_spend (clause_name ),
247
+ ]
248
+
249
+ def decode_wit_stack (self , data : bytes , stack_elems : list [bytes ]) -> tuple [str , dict ]:
250
+ leaf_hash = stack_elems [- 2 ]
251
+
252
+ clause_name = None
253
+ for clause in self .clauses :
254
+ if leaf_hash == self .get_tr_info (data ).leaves [clause .name ].script :
255
+ clause_name = clause .name
256
+ break
257
+ if clause_name is None :
258
+ raise ValueError ("Clause not found" )
259
+
260
+ return clause_name , self ._clauses_dict [clause_name ].args_from_stack_elements (stack_elems [:- 2 ])
261
+
262
+
136
263
# params:
137
264
# - alice_pk
138
265
# - bob_pk
139
266
# - c_a
140
267
# spending conditions:
141
268
# - bob_pk (m_b) => RPSGameS0[m_b]
142
- class RPSGameS0 (P2TR ):
269
+ class RPSGameS0 (StandardP2TR ):
143
270
def __init__ (self , alice_pk : bytes , bob_pk : bytes , c_a : bytes ):
144
271
assert len (alice_pk ) == 32 and len (bob_pk ) == 32 and len (c_a ) == 32
145
272
@@ -148,7 +275,7 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes):
148
275
self .c_a = c_a
149
276
150
277
# witness: <m_b> <bob_sig>
151
- bob_move = (
278
+ bob_move = StandardClause (
152
279
"bob_move" ,
153
280
CScript ([
154
281
bob_pk ,
@@ -165,7 +292,10 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes):
165
292
RPSGameS1 (alice_pk , bob_pk , c_a ).get_taptree (),
166
293
0 , # flags
167
294
OP_CHECKCONTRACTVERIFY ,
168
- ])
295
+ ]), [
296
+ ('m_b' , int ),
297
+ ('bob_sig' , bytes ),
298
+ ]
169
299
)
170
300
171
301
super ().__init__ (NUMS_KEY , [bob_move ])
@@ -181,15 +311,12 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes):
181
311
# - alice_pk, reveal winning move => ctv(alice wins)
182
312
# - alice_pk, reveal losing move => ctv(bob wins)
183
313
# - alice_pk, reveal tie move => ctv(tie)
184
- class RPSGameS1 (AugmentedP2TR ):
314
+ class RPSGameS1 (StandardAugmentedP2TR ):
185
315
def __init__ (self , alice_pk : bytes , bob_pk : bytes , c_a : bytes ):
186
316
self .alice_pk = alice_pk
187
317
self .bob_pk = bob_pk
188
318
self .c_a = c_a
189
319
190
- super ().__init__ (NUMS_KEY )
191
-
192
- def get_scripts (self ):
193
320
def make_script (outcome : int , ctv_hash : bytes ):
194
321
assert 0 <= outcome <= 2
195
322
# witness: [<m_b> <m_a> <r_a>]
@@ -260,11 +387,16 @@ def make_ctv_hash(alice_amount, bob_amount) -> bytes:
260
387
ctvhash_bob_wins = make_ctv_hash (0 , 2 * STAKE )
261
388
ctvhash_tie = make_ctv_hash (STAKE , STAKE )
262
389
263
- alice_wins = ("tie" , make_script (0 , ctvhash_tie ))
264
- bob_wins = ("bob_wins" , make_script (1 , ctvhash_bob_wins ))
265
- tie = ("alice_wins" , make_script (2 , ctvhash_alice_wins ))
390
+ arg_specs = [
391
+ ('m_b' , int ),
392
+ ('m_a' , int ),
393
+ ('r_a' , bytes ),
394
+ ]
395
+ alice_wins = StandardClause ("tie" , make_script (0 , ctvhash_tie ), arg_specs )
396
+ bob_wins = StandardClause ("bob_wins" , make_script (1 , ctvhash_bob_wins ), arg_specs )
397
+ tie = StandardClause ("alice_wins" , make_script (2 , ctvhash_alice_wins ), arg_specs )
266
398
267
- return [alice_wins , bob_wins , tie ]
399
+ super (). __init__ ( NUMS_KEY , [alice_wins , bob_wins , tie ])
268
400
269
401
270
402
load_dotenv ()
@@ -379,20 +511,11 @@ def start_session(self, m_a: int):
379
511
380
512
assert len (in_wit .scriptWitness .stack ) == 4
381
513
382
- [m_b_bytes , bob_sig , leaf_script , control_block ] = in_wit .scriptWitness .stack
383
-
384
- assert leaf_script == contract_S0 .get_tr_info ().leaves ["bob_move" ].script
385
- assert control_block == contract_S0 .get_tr_info ().controlblock_for_script_spend ("bob_move" )
386
-
387
- assert 0 <= len (m_b_bytes ) <= 1
388
-
389
- m_b = vch2bn (m_b_bytes )
390
- m_b_hash = sha256 (m_b_bytes )
391
-
514
+ _ , args = contract_S0 .decode_wit_stack (in_wit .scriptWitness .stack )
515
+ m_b : int = args ['m_b' ]
392
516
assert 0 <= m_b <= 2
393
517
394
- print (m_b_bytes .hex ())
395
- print (f"Bob's move: { m_b } ({ RPS .move_str (m_b )} ). Hash: { m_b_hash .hex ()} " )
518
+ print (f"Bob's move: { m_b } ({ RPS .move_str (m_b )} )." )
396
519
397
520
outcome = RPS .adjudicate (m_a , m_b )
398
521
print (f"Game result: { outcome } " )
@@ -431,14 +554,10 @@ def start_session(self, m_a: int):
431
554
)
432
555
]
433
556
557
+ m_b_hash = sha256 (script .bn2vch (m_b ))
558
+
434
559
tx_payout .wit .vtxinwit = [CTxInWitness ()]
435
- tx_payout .wit .vtxinwit [0 ].scriptWitness .stack = [
436
- script .bn2vch (m_b ),
437
- script .bn2vch (m_a ),
438
- r_a ,
439
- RPS_S1 .get_tr_info (m_b_hash ).leaves [outcome ].script ,
440
- RPS_S1 .get_tr_info (m_b_hash ).controlblock_for_script_spend (outcome ),
441
- ]
560
+ tx_payout .wit .vtxinwit [0 ].scriptWitness .stack = RPS_S1 .encode_args (outcome , m_b_hash , m_b = m_b , m_a = m_a , r_a = r_a )
442
561
443
562
self .prompt ("Broadcasting adjudication transaction" )
444
563
txid = rpc .sendrawtransaction (tx_payout .serialize ().hex ())
@@ -518,12 +637,8 @@ def join_session(self, m_b: int):
518
637
bob_sig = key .sign_schnorr (self .priv_key .privkey , sighash )
519
638
520
639
tx .wit .vtxinwit = [CTxInWitness ()]
521
- tx .wit .vtxinwit [0 ].scriptWitness .stack = [
522
- script .bn2vch (m_b ),
523
- bob_sig ,
524
- contract_S0 .get_tr_info ().leaves ["bob_move" ].script ,
525
- contract_S0 .get_tr_info ().controlblock_for_script_spend ("bob_move" ),
526
- ]
640
+
641
+ tx .wit .vtxinwit [0 ].scriptWitness .stack = contract_S0 .encode_args ('bob_move' , m_b = m_b , bob_sig = bob_sig )
527
642
528
643
txid = tx .rehash ()
529
644
@@ -544,10 +659,9 @@ def join_session(self, m_b: int):
544
659
tx , vin , last_height = wait_for_spending_tx (rpc , contract_S1_outpoint , starting_height = last_height )
545
660
in_wit : CTxInWitness = tx .wit .vtxinwit [vin ]
546
661
547
- leaf_script , control_block = in_wit .scriptWitness .stack [- 2 :]
548
- for transition_name in map (lambda x : x [0 ], S1 .get_scripts ()):
549
- if leaf_script == S1 .get_tr_info (m_b_hash ).leaves [transition_name ].script and control_block == S1 .get_tr_info (m_b_hash ).controlblock_for_script_spend (transition_name ):
550
- print (f"Outcome: { transition_name } " )
662
+ outcome , _ = S1 .decode_wit_stack (m_b_hash , in_wit .scriptWitness .stack )
663
+
664
+ print (f"Outcome: { outcome } " )
551
665
552
666
s .close ()
553
667
0 commit comments