1
+ from btctools import script , key
2
+ from btctools .script import CScript , CTxOut , TaprootInfo
3
+
4
+ # Flags for OP_CHECKCONTRACTVERIFY
5
+ CCV_FLAG_CHECK_INPUT : int = 1
6
+ CCV_FLAG_IGNORE_OUTPUT_AMOUNT : int = 2
7
+
8
+
9
+ def vch2bn (s : bytes ) -> int :
10
+ """Convert bitcoin-specific little endian format to number."""
11
+ if len (s ) == 0 :
12
+ return 0
13
+ # The most significant bit is the sign bit.
14
+ is_negative = s [0 ] & 0x80 != 0
15
+ # Mask off the sign bit.
16
+ s_abs = bytes ([s [0 ] & 0x7f ]) + s [1 :]
17
+ v_abs = int .from_bytes (s_abs , 'little' )
18
+ # Return as negative number if it's negative.
19
+ return - v_abs if is_negative else v_abs
20
+
21
+
22
+ class Clause :
23
+ def __init__ (self , name : str , script : CScript ):
24
+ self .name = name
25
+ self .script = script
26
+
27
+ def stack_elements_from_args (self , args : dict ) -> list [bytes ]:
28
+ raise NotImplementedError
29
+
30
+ def args_from_stack_elements (self , elements : list [bytes ]) -> dict :
31
+ raise NotImplementedError
32
+
33
+
34
+ StandardType = type [int ] | type [bytes ]
35
+
36
+
37
+ # A StandardClause encodes simple scripts where the witness is exactly
38
+ # a list of arguments, always in the same order, and each is either
39
+ # an integer or a byte array.
40
+ # Other types of generic treatable clauses could be defined (for example, a MiniscriptClause).
41
+ class StandardClause (Clause ):
42
+ def __init__ (self , name : str , script : CScript , arg_specs : list [tuple [str , StandardType ]]):
43
+ super ().__init__ (name , script )
44
+ self .arg_specs = arg_specs
45
+
46
+ for _ , arg_cls in self .arg_specs :
47
+ if arg_cls not in [int , bytes ]:
48
+ raise ValueError (f"Unsupported type: { arg_cls .__name__ } " )
49
+
50
+ def stack_elements_from_args (self , args : dict ) -> list [bytes ]:
51
+ result : list [bytes ] = []
52
+ for arg_name , arg_cls in self .arg_specs :
53
+ if arg_name not in args :
54
+ raise ValueError (f"Missing argument: { arg_name } " )
55
+ arg_value = args [arg_name ]
56
+ if type (arg_value ) != arg_cls :
57
+ raise ValueError (f"Argument { arg_name } must be of type { arg_cls .__name__ } , not { type (arg_value ).__name__ } " )
58
+ if arg_cls == int :
59
+ result .append (script .bn2vch (arg_value ))
60
+ elif arg_cls == bytes :
61
+ result .append (arg_value )
62
+ else :
63
+ raise ValueError ("Unexpected type" ) # this should never happen
64
+
65
+ return result
66
+
67
+ def args_from_stack_elements (self , elements : list [bytes ]) -> dict :
68
+ result : dict = {}
69
+ if len (elements ) != len (self .arg_specs ):
70
+ raise ValueError (f"Expected { len (self .arg_specs )} elements, not { len (elements )} " )
71
+ for i , (arg_name , arg_cls ) in enumerate (self .arg_specs ):
72
+ if arg_cls == int :
73
+ result [arg_name ] = vch2bn (elements [i ])
74
+ elif arg_cls == bytes :
75
+ result [arg_name ] = elements [i ]
76
+ else :
77
+ raise ValueError ("Unexpected type" ) # this should never happen
78
+ return result
79
+
80
+
81
+ class P2TR :
82
+ """
83
+ A class representing a Pay-to-Taproot script.
84
+ """
85
+
86
+ def __init__ (self , internal_pubkey : bytes , scripts : list [tuple [str , CScript ]]):
87
+ assert len (internal_pubkey ) == 32
88
+
89
+ self .internal_pubkey = internal_pubkey
90
+ self .scripts = scripts
91
+ self .tr_info = script .taproot_construct (internal_pubkey , scripts )
92
+
93
+ def get_tr_info (self ) -> TaprootInfo :
94
+ return self .tr_info
95
+
96
+ def get_tx_out (self , value : int ) -> CTxOut :
97
+ return CTxOut (
98
+ nValue = value ,
99
+ scriptPubKey = self .get_tr_info ().scriptPubKey
100
+ )
101
+
102
+
103
+ class AugmentedP2TR :
104
+ """
105
+ An abstract class representing a Pay-to-Taproot script with some embedded data.
106
+ While the exact script can only be produced once the embedded data is known,
107
+ the scripts and the "naked internal key" are decided in advance.
108
+ """
109
+
110
+ def __init__ (self , naked_internal_pubkey : bytes ):
111
+ assert len (naked_internal_pubkey ) == 32
112
+
113
+ self .naked_internal_pubkey = naked_internal_pubkey
114
+
115
+ def get_scripts (self ) -> list [tuple [str , CScript ]]:
116
+ raise NotImplementedError ("This must be implemented in subclasses" )
117
+
118
+ def get_taptree (self ) -> bytes :
119
+ # use dummy data, since it doesn't affect the merkle root
120
+ return self .get_tr_info (b'\0 ' * 32 ).merkle_root
121
+
122
+ def get_tr_info (self , data : bytes ) -> TaprootInfo :
123
+ assert len (data ) == 32
124
+
125
+ internal_pubkey , _ = key .tweak_add_pubkey (self .naked_internal_pubkey , data )
126
+
127
+ return script .taproot_construct (internal_pubkey , self .get_scripts ())
128
+
129
+ def get_tx_out (self , value : int , data : bytes ) -> CTxOut :
130
+ return CTxOut (nValue = value , scriptPubKey = self .get_tr_info (data ).scriptPubKey )
131
+
132
+
133
+ class StandardP2TR (P2TR ):
134
+ """
135
+ A StandardP2TR where all the transitions are given by a StandardClause.
136
+ """
137
+
138
+ def __init__ (self , internal_pubkey : bytes , clauses : list [StandardClause ]):
139
+ super ().__init__ (internal_pubkey , list (map (lambda x : (x .name , x .script ), clauses )))
140
+ self .clauses = clauses
141
+ self ._clauses_dict = {clause .name : clause for clause in clauses }
142
+
143
+ def get_scripts (self ) -> list [tuple [str , CScript ]]:
144
+ return list (map (lambda clause : (clause .name , clause .script ), self .clauses ))
145
+
146
+ def encode_args (self , clause_name : str , ** args : dict ) -> list [bytes ]:
147
+ return [
148
+ * self ._clauses_dict [clause_name ].stack_elements_from_args (args ),
149
+ self .get_tr_info ().leaves [clause_name ].script ,
150
+ self .get_tr_info ().controlblock_for_script_spend (clause_name ),
151
+ ]
152
+
153
+ def decode_wit_stack (self , stack_elems : list [bytes ]) -> tuple [str , dict ]:
154
+ leaf_hash = stack_elems [- 2 ]
155
+
156
+ clause_name = None
157
+ for clause in self .clauses :
158
+ if leaf_hash == self .get_tr_info ().leaves [clause .name ].script :
159
+ clause_name = clause .name
160
+ break
161
+ if clause_name is None :
162
+ raise ValueError ("Clause not found" )
163
+
164
+ return clause_name , self ._clauses_dict [clause_name ].args_from_stack_elements (stack_elems [:- 2 ])
165
+
166
+
167
+ class StandardAugmentedP2TR (AugmentedP2TR ):
168
+ """
169
+ An AugmentedP2TR where all the transitions are given by a StandardClause.
170
+ """
171
+
172
+ def __init__ (self , naked_internal_pubkey : bytes , clauses : list [StandardClause ]):
173
+ super ().__init__ (naked_internal_pubkey )
174
+ self .clauses = clauses
175
+ self ._clauses_dict = {clause .name : clause for clause in clauses }
176
+
177
+ def get_scripts (self ) -> list [tuple [str , CScript ]]:
178
+ return list (map (lambda clause : (clause .name , clause .script ), self .clauses ))
179
+
180
+ def encode_args (self , clause_name : str , data : bytes , ** args : dict ) -> list [bytes ]:
181
+ return [
182
+ * self ._clauses_dict [clause_name ].stack_elements_from_args (args ),
183
+ self .get_tr_info (data ).leaves [clause_name ].script ,
184
+ self .get_tr_info (data ).controlblock_for_script_spend (clause_name ),
185
+ ]
186
+
187
+ def decode_wit_stack (self , data : bytes , stack_elems : list [bytes ]) -> tuple [str , dict ]:
188
+ leaf_hash = stack_elems [- 2 ]
189
+
190
+ clause_name = None
191
+ for clause in self .clauses :
192
+ if leaf_hash == self .get_tr_info (data ).leaves [clause .name ].script :
193
+ clause_name = clause .name
194
+ break
195
+ if clause_name is None :
196
+ raise ValueError ("Clause not found" )
197
+
198
+ return clause_name , self ._clauses_dict [clause_name ].args_from_stack_elements (stack_elems [:- 2 ])
0 commit comments