-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathdecorators.py
214 lines (170 loc) · 5.34 KB
/
decorators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
__all__ = ['log_calls', 'log_calls_recursive', 'selective_cache']
from sys import _getframe
from functools import wraps
from itertools import count, starmap, chain
from inspect import signature, Parameter
from .io_helpers import log, rlog, reprint
# Adapted from: https://stackoverflow.com/a/47956089/3889449
def _stack_size(size_hint: int=8) -> int:
'''Get number of call frames for the caller.'''
frame = 0
try:
while True:
frame = _getframe(size_hint)
size_hint *= 2
except ValueError:
if frame:
size_hint //= 2
else:
while not frame:
size_hint = max(2, size_hint // 2)
try:
frame = _getframe(size_hint)
except ValueError:
continue
for size in count(size_hint):
frame = frame.f_back
if not frame:
return size
def log_calls(log_return: bool=True):
    '''Build a decorator that reports every call of the wrapped function.

    Before each call, the function name together with the repr of each
    positional argument and each keyword argument (as name=repr) is written
    to standard error through log(). After the call returns, the repr of the
    result is logged too, unless log_return is false.
    '''
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            # Positional arguments first, then keyword arguments, in call order.
            pieces = [repr(value) for value in args]
            pieces.extend('{}={!r}'.format(key, value)
                          for key, value in kwargs.items())
            log('{}({})\n', fn.__name__, ', '.join(pieces))

            result = fn(*args, **kwargs)
            if log_return:
                log('-> {!r}\n', result)
            return result

        return wrapper

    return decorator
def log_calls_recursive(log_call: bool=True, log_return: bool=True):
    '''Decorate a function logging arguments (if log_call=True) and return value
    (if log_return=True) of every call to standard error, indenting the logs for
    recursive calls according to the recursion level.

    Additionally, add .log() and .eprint() methods to the decorated function
    that will act as log() or eprint(), but add indentation according to the
    current recursion level.

    Usage example:

        @log_calls_recursive()
        def fib(n):
            if n == 0 or n == 1:
                return n
            fib.eprint('hello!')
            return n + fib(n - 1) + fib(n - 2)

        >>> fib(3)
        ┌ fib(3)
        │ hello!
        │ ┌ fib(2)
        │ │ hello!
        │ │ ┌ fib(1)
        │ │ └> 1
        │ │ ┌ fib(0)
        │ │ └> 0
        │ └> 3
        │ ┌ fib(1)
        │ └> 1
        └> 7

    The recursion level is the number of calls of the decorated function that
    are currently in progress, minus one (0 inside the outermost call). It is
    tracked by the wrapper itself, so it does not depend on where the
    outermost call is made from, on which events are logged, or on helper
    functions sitting between the recursive calls. .log()/.eprint() used
    while no call is in progress log at level 0.
    '''
    def decorator(fn):
        # Level of the innermost call of fn currently in progress:
        # -1 when idle, 0 inside the outermost call, 1 one call deeper, ...
        # NOTE: this counter is shared by all threads using the decorated
        # function, so concurrent calls from several threads interleave.
        level = -1

        def logfunc(*a, _user_call=1, **kwa):
            # _user_call is accepted (and ignored) for backward compatibility:
            # the level no longer depends on how many frames are in between.
            rlog(max(level, 0))(*a, **kwa)

        def eprintfunc(*a, _user_call=1, **kwa):
            # See logfunc() for why _user_call is ignored.
            reprint(max(level, 0))(*a, **kwa)

        fn.log = logfunc
        fn.eprint = eprintfunc

        @wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal level
            level += 1
            try:
                if log_call:
                    repr_args = map(repr, args)
                    repr_kwargs = starmap('{}={!r}'.format, kwargs.items())
                    call_repr = ', '.join(chain(repr_args, repr_kwargs))
                    fn.log('{}({})\n', fn.__name__, call_repr,
                           _user_call=0, is_header=True)

                retval = fn(*args, **kwargs)

                # Logged before leaving this level, so the return value lines
                # up with the header of the same call.
                if log_return:
                    fn.eprint(retval, _user_call=0, is_retval=True)
                return retval
            finally:
                # Restore the level even if fn (or logging) raised, so later
                # calls are not indented too deeply.
                level -= 1

        return wrapper

    return decorator
def selective_cache(*arg_names: str):
    '''Memoize results using only arguments with the specified names as key.

    Note: does NOT support functions using *args, **kwargs or default values.

    Example:

        # Cache results using (a, b) as key.
        @selective_cache('a', 'b')
        def func(a, b, c):
            return a + b + c

        >>> func(1, 2, 3)
        6
        >>> func.cache
        {(1, 2): 6}
        >>> func(1, 2, 99)
        6

    The key holds the selected arguments in the order they appear in the
    function's signature (not the order given to selective_cache), so those
    arguments must be hashable. The decorated function may be called with
    positional arguments or, for parameters that are not positional-only,
    keyword arguments.

    func.cache: internal cache.
    func.cache_clear(): clears internal cache.

    Cache size is unbounded! Beware.

    Raises TypeError at decoration time if the function has a parameter that
    is not positional, a parameter with a default value, or if one of
    arg_names is not a parameter of the function. A call with missing or
    unexpected arguments raises TypeError.
    '''
    def decorator(fn):
        sig = signature(fn)
        params = sig.parameters
        key_args_indexes = []

        for i, (name, p) in enumerate(params.items()):
            # We are lazy, supporting every kind of strange Python parameter
            # type is very complex. Detect bad usages here and bail out.
            if p.kind not in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY):
                raise TypeError('can only wrap functions with positional '
                    "parameters, and '{}' is not positional".format(name))
            elif p.default is not Parameter.empty:
                raise TypeError('can only wrap functions without default '
                    "parameter values, and '{}' has a default".format(name))

            if name in arg_names:
                key_args_indexes.append(i)

        # An unknown name would silently be left out of the key, making
        # unrelated calls share cache entries: reject it up front.
        for name in arg_names:
            if name not in params:
                raise TypeError("cannot use '{}' as cache key: {}() has no "
                    'such parameter'.format(name, fn.__name__))

        n_params = len(params)
        cache = {}
        missing = object()  # distinguishes "not cached" from a cached None

        @wraps(fn)
        def wrapper(*args, **kwargs):
            if kwargs or len(args) != n_params:
                # Normalize to positional order. Every parameter is required,
                # so bind() yields all of them (in signature order) or raises
                # TypeError for missing/unexpected arguments.
                args = tuple(sig.bind(*args, **kwargs).arguments.values())

            key = tuple(args[i] for i in key_args_indexes)
            res = cache.get(key, missing)
            if res is missing:
                res = fn(*args)
                cache[key] = res
            return res

        wrapper.cache = cache
        wrapper.cache_clear = cache.clear
        return wrapper

    return decorator