Skip to content

Commit 7375c6a

Browse files
committed
Refactored general Matt utilities and rps contracts in separate files
1 parent 264657e commit 7375c6a

File tree

3 files changed

+349
-337
lines changed

3 files changed

+349
-337
lines changed

matt.py

+198
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
from btctools import script, key
2+
from btctools.script import CScript, CTxOut, TaprootInfo
3+
4+
# Flags for OP_CHECKCONTRACTVERIFY
5+
CCV_FLAG_CHECK_INPUT: int = 1
6+
CCV_FLAG_IGNORE_OUTPUT_AMOUNT: int = 2
7+
8+
9+
def vch2bn(s: bytes) -> int:
10+
"""Convert bitcoin-specific little endian format to number."""
11+
if len(s) == 0:
12+
return 0
13+
# The most significant bit is the sign bit.
14+
is_negative = s[0] & 0x80 != 0
15+
# Mask off the sign bit.
16+
s_abs = bytes([s[0] & 0x7f]) + s[1:]
17+
v_abs = int.from_bytes(s_abs, 'little')
18+
# Return as negative number if it's negative.
19+
return -v_abs if is_negative else v_abs
20+
21+
22+
class Clause:
23+
def __init__(self, name: str, script: CScript):
24+
self.name = name
25+
self.script = script
26+
27+
def stack_elements_from_args(self, args: dict) -> list[bytes]:
28+
raise NotImplementedError
29+
30+
def args_from_stack_elements(self, elements: list[bytes]) -> dict:
31+
raise NotImplementedError
32+
33+
34+
StandardType = type[int] | type[bytes]
35+
36+
37+
# A StandardClause encodes simple scripts where the witness is exactly
38+
# a list of arguments, always in the same order, and each is either
39+
# an integer or a byte array.
40+
# Other types of generic treatable clauses could be defined (for example, a MiniscriptClause).
41+
class StandardClause(Clause):
42+
def __init__(self, name: str, script: CScript, arg_specs: list[tuple[str, StandardType]]):
43+
super().__init__(name, script)
44+
self.arg_specs = arg_specs
45+
46+
for _, arg_cls in self.arg_specs:
47+
if arg_cls not in [int, bytes]:
48+
raise ValueError(f"Unsupported type: {arg_cls.__name__}")
49+
50+
def stack_elements_from_args(self, args: dict) -> list[bytes]:
51+
result: list[bytes] = []
52+
for arg_name, arg_cls in self.arg_specs:
53+
if arg_name not in args:
54+
raise ValueError(f"Missing argument: {arg_name}")
55+
arg_value = args[arg_name]
56+
if type(arg_value) != arg_cls:
57+
raise ValueError(f"Argument {arg_name} must be of type {arg_cls.__name__}, not {type(arg_value).__name__}")
58+
if arg_cls == int:
59+
result.append(script.bn2vch(arg_value))
60+
elif arg_cls == bytes:
61+
result.append(arg_value)
62+
else:
63+
raise ValueError("Unexpected type") # this should never happen
64+
65+
return result
66+
67+
def args_from_stack_elements(self, elements: list[bytes]) -> dict:
68+
result: dict = {}
69+
if len(elements) != len(self.arg_specs):
70+
raise ValueError(f"Expected {len(self.arg_specs)} elements, not {len(elements)}")
71+
for i, (arg_name, arg_cls) in enumerate(self.arg_specs):
72+
if arg_cls == int:
73+
result[arg_name] = vch2bn(elements[i])
74+
elif arg_cls == bytes:
75+
result[arg_name] = elements[i]
76+
else:
77+
raise ValueError("Unexpected type") # this should never happen
78+
return result
79+
80+
81+
class P2TR:
82+
"""
83+
A class representing a Pay-to-Taproot script.
84+
"""
85+
86+
def __init__(self, internal_pubkey: bytes, scripts: list[tuple[str, CScript]]):
87+
assert len(internal_pubkey) == 32
88+
89+
self.internal_pubkey = internal_pubkey
90+
self.scripts = scripts
91+
self.tr_info = script.taproot_construct(internal_pubkey, scripts)
92+
93+
def get_tr_info(self) -> TaprootInfo:
94+
return self.tr_info
95+
96+
def get_tx_out(self, value: int) -> CTxOut:
97+
return CTxOut(
98+
nValue=value,
99+
scriptPubKey=self.get_tr_info().scriptPubKey
100+
)
101+
102+
103+
class AugmentedP2TR:
104+
"""
105+
An abstract class representing a Pay-to-Taproot script with some embedded data.
106+
While the exact script can only be produced once the embedded data is known,
107+
the scripts and the "naked internal key" are decided in advance.
108+
"""
109+
110+
def __init__(self, naked_internal_pubkey: bytes):
111+
assert len(naked_internal_pubkey) == 32
112+
113+
self.naked_internal_pubkey = naked_internal_pubkey
114+
115+
def get_scripts(self) -> list[tuple[str, CScript]]:
116+
raise NotImplementedError("This must be implemented in subclasses")
117+
118+
def get_taptree(self) -> bytes:
119+
# use dummy data, since it doesn't affect the merkle root
120+
return self.get_tr_info(b'\0'*32).merkle_root
121+
122+
def get_tr_info(self, data: bytes) -> TaprootInfo:
123+
assert len(data) == 32
124+
125+
internal_pubkey, _ = key.tweak_add_pubkey(self.naked_internal_pubkey, data)
126+
127+
return script.taproot_construct(internal_pubkey, self.get_scripts())
128+
129+
def get_tx_out(self, value: int, data: bytes) -> CTxOut:
130+
return CTxOut(nValue=value, scriptPubKey=self.get_tr_info(data).scriptPubKey)
131+
132+
133+
class StandardP2TR(P2TR):
134+
"""
135+
A StandardP2TR where all the transitions are given by a StandardClause.
136+
"""
137+
138+
def __init__(self, internal_pubkey: bytes, clauses: list[StandardClause]):
139+
super().__init__(internal_pubkey, list(map(lambda x: (x.name, x.script), clauses)))
140+
self.clauses = clauses
141+
self._clauses_dict = {clause.name: clause for clause in clauses}
142+
143+
def get_scripts(self) -> list[tuple[str, CScript]]:
144+
return list(map(lambda clause: (clause.name, clause.script), self.clauses))
145+
146+
def encode_args(self, clause_name: str, **args: dict) -> list[bytes]:
147+
return [
148+
*self._clauses_dict[clause_name].stack_elements_from_args(args),
149+
self.get_tr_info().leaves[clause_name].script,
150+
self.get_tr_info().controlblock_for_script_spend(clause_name),
151+
]
152+
153+
def decode_wit_stack(self, stack_elems: list[bytes]) -> tuple[str, dict]:
154+
leaf_hash = stack_elems[-2]
155+
156+
clause_name = None
157+
for clause in self.clauses:
158+
if leaf_hash == self.get_tr_info().leaves[clause.name].script:
159+
clause_name = clause.name
160+
break
161+
if clause_name is None:
162+
raise ValueError("Clause not found")
163+
164+
return clause_name, self._clauses_dict[clause_name].args_from_stack_elements(stack_elems[:-2])
165+
166+
167+
class StandardAugmentedP2TR(AugmentedP2TR):
168+
"""
169+
An AugmentedP2TR where all the transitions are given by a StandardClause.
170+
"""
171+
172+
def __init__(self, naked_internal_pubkey: bytes, clauses: list[StandardClause]):
173+
super().__init__(naked_internal_pubkey)
174+
self.clauses = clauses
175+
self._clauses_dict = {clause.name: clause for clause in clauses}
176+
177+
def get_scripts(self) -> list[tuple[str, CScript]]:
178+
return list(map(lambda clause: (clause.name, clause.script), self.clauses))
179+
180+
def encode_args(self, clause_name: str, data: bytes, **args: dict) -> list[bytes]:
181+
return [
182+
*self._clauses_dict[clause_name].stack_elements_from_args(args),
183+
self.get_tr_info(data).leaves[clause_name].script,
184+
self.get_tr_info(data).controlblock_for_script_spend(clause_name),
185+
]
186+
187+
def decode_wit_stack(self, data: bytes, stack_elems: list[bytes]) -> tuple[str, dict]:
188+
leaf_hash = stack_elems[-2]
189+
190+
clause_name = None
191+
for clause in self.clauses:
192+
if leaf_hash == self.get_tr_info(data).leaves[clause.name].script:
193+
clause_name = clause.name
194+
break
195+
if clause_name is None:
196+
raise ValueError("Clause not found")
197+
198+
return clause_name, self._clauses_dict[clause_name].args_from_stack_elements(stack_elems[:-2])

0 commit comments

Comments
 (0)