Skip to content

Commit 7763d03

Browse files
authored
Merge pull request #2190 from devitocodes/no-more-dyn-classes-final
dsl: No more dynamic classes for AbstractFunctions
2 parents 67e5779 + 25c856a commit 7763d03

22 files changed

Lines changed: 446 additions & 169 deletions

devito/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
from devito.mpi.routines import mpi_registry
3939
from devito.operator import profiler_registry, operator_registry
4040

41+
# Apply monkey-patching while we wait for our patches to be upstreamed and released
42+
from devito.mpatches import * # noqa
43+
4144

4245
from ._version import get_versions # noqa
4346
__version__ = get_versions()['version']

devito/arch/compiler.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import numpy.ctypeslib as npct
1212
from codepy.jit import compile_from_string
13-
from codepy.toolchain import GCCToolchain, call_capture_output
13+
from codepy.toolchain import GCCToolchain, call_capture_output as _call_capture_output
1414

1515
from devito.arch import (AMDGPUX, Cpu64, M1, NVIDIAX, POWER8, POWER9, GRAVITON,
1616
INTELGPUX, get_nvidia_cc, check_cuda_runtime,
@@ -19,7 +19,7 @@
1919
from devito.logger import debug, warning, error
2020
from devito.parameters import configuration
2121
from devito.tools import (as_list, change_directory, filter_ordered,
22-
memoized_func, memoized_meth, make_tempdir)
22+
memoized_func, make_tempdir)
2323

2424
__all__ = ['sniff_mpi_distro', 'compiler_registry']
2525

@@ -123,6 +123,15 @@ def sniff_mpi_flags(mpicc='mpicc'):
123123
return compile_flags.split(), link_flags.split()
124124

125125

126+
@memoized_func
127+
def call_capture_output(cmd):
128+
"""
129+
Memoize calls to codepy's `call_capture_output` to avoid leaking memory due
130+
to some prefork/subprocess voodoo.
131+
"""
132+
return _call_capture_output(cmd)
133+
134+
126135
class Compiler(GCCToolchain):
127136
"""
128137
Base class for all compiler classes.
@@ -220,12 +229,16 @@ def __new_with__(self, **kwargs):
220229
def name(self):
221230
return self.__class__.__name__
222231

223-
@memoized_meth
232+
def get_version(self):
233+
result, stdout, stderr = call_capture_output((self.cc, "--version"))
234+
if result != 0:
235+
raise RuntimeError(f"version query failed: {stderr}")
236+
return stdout
237+
224238
def get_jit_dir(self):
225239
"""A deterministic temporary directory for jit-compiled objects."""
226240
return make_tempdir('jitcache')
227241

228-
@memoized_meth
229242
def get_codepy_dir(self):
230243
"""A deterministic temporary directory for the codepy cache."""
231244
return make_tempdir('codepy')
@@ -729,9 +742,9 @@ def __init__(self, *args, **kwargs):
729742

730743
def get_version(self):
731744
if configuration['mpi']:
732-
cmd = [self.cc, "-cc=%s" % self.CC, "--version"]
745+
cmd = (self.cc, "-cc=%s" % self.CC, "--version")
733746
else:
734-
cmd = [self.cc, "--version"]
747+
cmd = (self.cc, "--version")
735748
result, stdout, stderr = call_capture_output(cmd)
736749
if result != 0:
737750
raise RuntimeError(f"version query failed: {stderr}")

devito/builtins/arithmetic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def sum(f, dims=None):
8383
elif f.is_SparseTimeFunction:
8484
if f.time_dim in dims:
8585
# Sum over time -> SparseFunction
86-
new_coords = f.coordinates._rebuild(name="%ssum_coords" % f.name)
86+
new_coords = f.coordinates._rebuild(
87+
name="%ssum_coords" % f.name, initializer=f.coordinates.initializer
88+
)
8789
out = dv.SparseFunction(name="%ssum" % f.name, grid=f.grid,
8890
dimensions=new_dims, npoint=f.shape[1],
8991
coordinates=new_coords)

devito/finite_differences/derivative.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,18 @@ def __new__(cls, expr, *dims, **kwargs):
106106
obj._deriv_order = orders if skip else DimensionTuple(*orders, getters=obj._dims)
107107
obj._side = kwargs.get("side")
108108
obj._transpose = kwargs.get("transpose", direct)
109-
obj._ppsubs = as_tuple(frozendict(i) for i in
110-
kwargs.get("subs", kwargs.get("_ppsubs", [])))
109+
110+
ppsubs = kwargs.get("subs", kwargs.get("_ppsubs", []))
111+
processed = []
112+
if ppsubs:
113+
for i in ppsubs:
114+
try:
115+
processed.append(frozendict(i))
116+
except AttributeError:
117+
# E.g. `i` is a Transform object
118+
processed.append(i)
119+
obj._ppsubs = tuple(processed)
120+
111121
obj._x0 = frozendict(kwargs.get('x0', {}))
112122
return obj
113123

@@ -207,34 +217,51 @@ def _new_from_self(self, **kwargs):
207217
def func(self, expr, *args, **kwargs):
208218
return self._new_from_self(expr=expr, **kwargs)
209219

210-
def subs(self, *args, **kwargs):
211-
"""
212-
Bypass sympy.Subs as Devito has its own lazy evaluation mechanism.
213-
"""
214-
# Check if we are calling subs(self, old, new, **hint) in which case
215-
# return the standard substitution. Need to check `==` rather than `is`
216-
# because a new derivative could be created i.e `f.dx.subs(f.dx, y)`
217-
if len(args) == 2 and args[0] == self:
218-
return args[1]
219-
try:
220-
rules = dict(*args)
221-
except TypeError:
222-
rules = dict((args,))
223-
kwargs.pop('simultaneous', None)
224-
return self.xreplace(rules, **kwargs)
220+
def _subs(self, old, new, **hints):
221+
# Basic case
222+
if old == self:
223+
return new
224+
# Is it in expr?
225+
if self.expr.has(old):
226+
newexpr = self.expr._subs(old, new, **hints)
227+
try:
228+
return self._new_from_self(expr=newexpr)
229+
except ValueError:
230+
# Expr replacement leads to non-differentiable expression
231+
# e.g `f.dx.subs(f: 1) = 1.dx = 0`
232+
# returning zero
233+
return sympy.S.Zero
234+
235+
# In case `x0` was passed as a substitution instead of `(x0=`
236+
if str(old) == 'x0':
237+
return self._new_from_self(x0={self.dims[0]: new})
238+
239+
# Trying to substitute by another derivative with different metadata
240+
# Only need to check if is a Derivative since one for the cases above would
241+
# have found it
242+
if isinstance(old, Derivative):
243+
return self
244+
245+
# Fall back if we didn't catch any special case
246+
return self.xreplace({old: new}, **hints)
225247

226248
def _xreplace(self, subs):
227249
"""
228250
This is a helper method used internally by SymPy. We exploit it to postpone
229251
substitutions until evaluation.
230252
"""
253+
# Return if no subs
254+
if not subs:
255+
return self, False
256+
231257
# Check if trying to replace the whole expression
232258
if self in subs:
233259
new = subs.pop(self)
234260
try:
235261
return new._xreplace(subs)
236262
except AttributeError:
237263
return new, True
264+
238265
subs = self._ppsubs + (subs,) # Postponed substitutions
239266
return self._new_from_self(subs=subs), True
240267

devito/finite_differences/differentiable.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ def _eval_at(self, func):
121121
return self.func(*[getattr(a, '_eval_at', lambda x: a)(func) for a in self.args])
122122

123123
def _subs(self, old, new, **hints):
124-
if old is self:
124+
if old == self:
125125
return new
126-
if old is new:
126+
if old == new:
127127
return self
128128
args = list(self.args)
129129
for i, arg in enumerate(args):
@@ -613,15 +613,15 @@ def __init_finalize__(self, *args, **kwargs):
613613

614614
def __eq__(self, other):
615615
return (isinstance(other, Weights) and
616-
self.dimension is other.dimension and
617616
self.name == other.name and
617+
self.dimension == other.dimension and
618618
self.indices == other.indices and
619619
self.weights == other.weights)
620620

621621
__hash__ = sympy.Basic.__hash__
622622

623623
def _hashable_content(self):
624-
return super()._hashable_content() + (self.name,) + tuple(self.weights)
624+
return (self.name, self.dimension, hash(tuple(self.weights)))
625625

626626
@property
627627
def dimension(self):

devito/mpatches/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .rationaltools import * # noqa

devito/mpatches/rationaltools.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""Tools for manipulation of rational expressions. """
2+
3+
import importlib
4+
5+
import sympy
6+
from sympy.core import Basic, Add, sympify
7+
from sympy.core.exprtools import gcd_terms
8+
from sympy.utilities import public
9+
from sympy.utilities.iterables import iterable
10+
11+
__all__ = []
12+
13+
14+
@public
15+
def together(expr, deep=False, fraction=True):
16+
"""
17+
Denest and combine rational expressions using symbolic methods.
18+
19+
This function takes an expression or a container of expressions
20+
and puts it (them) together by denesting and combining rational
21+
subexpressions. No heroic measures are taken to minimize degree
22+
of the resulting numerator and denominator. To obtain completely
23+
reduced expression use :func:`~.cancel`. However, :func:`~.together`
24+
can preserve as much as possible of the structure of the input
25+
expression in the output (no expansion is performed).
26+
27+
A wide variety of objects can be put together including lists,
28+
tuples, sets, relational objects, integrals and others. It is
29+
also possible to transform interior of function applications,
30+
by setting ``deep`` flag to ``True``.
31+
32+
By definition, :func:`~.together` is a complement to :func:`~.apart`,
33+
so ``apart(together(expr))`` should return expr unchanged. Note
34+
however, that :func:`~.together` uses only symbolic methods, so
35+
it might be necessary to use :func:`~.cancel` to perform algebraic
36+
simplification and minimize degree of the numerator and denominator.
37+
38+
Examples
39+
========
40+
41+
>>> from sympy import together, exp
42+
>>> from sympy.abc import x, y, z
43+
44+
>>> together(1/x + 1/y)
45+
(x + y)/(x*y)
46+
>>> together(1/x + 1/y + 1/z)
47+
(x*y + x*z + y*z)/(x*y*z)
48+
49+
>>> together(1/(x*y) + 1/y**2)
50+
(x + y)/(x*y**2)
51+
52+
>>> together(1/(1 + 1/x) + 1/(1 + 1/y))
53+
(x*(y + 1) + y*(x + 1))/((x + 1)*(y + 1))
54+
55+
>>> together(exp(1/x + 1/y))
56+
exp(1/y + 1/x)
57+
>>> together(exp(1/x + 1/y), deep=True)
58+
exp((x + y)/(x*y))
59+
60+
>>> together(1/exp(x) + 1/(x*exp(x)))
61+
(x + 1)*exp(-x)/x
62+
63+
>>> together(1/exp(2*x) + 1/(x*exp(3*x)))
64+
(x*exp(x) + 1)*exp(-3*x)/x
65+
66+
"""
67+
def _together(expr):
68+
if isinstance(expr, Basic):
69+
if expr.is_Atom or (expr.is_Function and not deep):
70+
return expr
71+
elif expr.is_Add:
72+
return gcd_terms(list(map(_together, Add.make_args(expr))),
73+
fraction=fraction)
74+
elif expr.is_Pow:
75+
base = _together(expr.base)
76+
77+
if deep:
78+
exp = _together(expr.exp)
79+
else:
80+
exp = expr.exp
81+
82+
return expr.func(base, exp)
83+
else:
84+
return expr.func(*[_together(arg) for arg in expr.args])
85+
elif iterable(expr):
86+
return expr.__class__([_together(ex) for ex in expr])
87+
88+
return expr
89+
90+
return _together(sympify(expr))
91+
92+
93+
# Apply the monkey patch
94+
simplify = importlib.import_module(sympy.simplify.__module__)
95+
simplify.together = together

devito/operations/solve.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ def linsolve(expr, target, **kwargs):
7373
The symbol w.r.t. which the equation is rearranged. May be a `Function`
7474
or any other symbolic object.
7575
"""
76-
c = factorize_target(expr, target)
76+
c, expr = factorize_target(expr, target)
7777
if c != 0:
78-
return -expr.xreplace({target: 0})/c
78+
return -expr/c
7979
raise SolveError("No linear solution found")
8080

8181

@@ -102,27 +102,39 @@ def _(expr):
102102

103103
@singledispatch
104104
def factorize_target(expr, target):
105-
return 1 if expr is target else 0
105+
return (1, 0) if expr == target else (0, expr)
106106

107107

108108
@factorize_target.register(Add)
109109
@factorize_target.register(EvalDerivative)
110110
def _(expr, target):
111111
c = 0
112112
if not expr.has(target):
113-
return c
113+
return c, expr
114114

115+
args = []
115116
for a in expr.args:
116-
c += factorize_target(a, target)
117-
return c
117+
c1, a1 = factorize_target(a, target)
118+
c += c1
119+
args.append(a1)
120+
121+
return c, expr.func(*args, evaluate=False)
118122

119123

120124
@factorize_target.register(Mul)
121125
def _(expr, target):
122126
if not expr.has(target):
123-
return 0
127+
return 0, expr
124128

125129
c = 1
130+
args = []
126131
for a in expr.args:
127-
c *= a if not a.has(target) else factorize_target(a, target)
128-
return c
132+
if not a.has(target):
133+
c *= a
134+
args.append(a)
135+
else:
136+
c1, a1 = factorize_target(a, target)
137+
c *= c1
138+
args.append(a1)
139+
140+
return c, expr.func(*args, evaluate=False)

devito/symbolics/printer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from devito.arch.compiler import AOMPCompiler
1515
from devito.symbolics.inspection import has_integer_args
16+
from devito.types.basic import AbstractFunction
1617

1718
__all__ = ['ccode']
1819

@@ -44,10 +45,12 @@ def parenthesize(self, item, level, strict=False):
4445
return super().parenthesize(item, level, strict=strict)
4546

4647
def _print_Function(self, expr):
47-
# There exist no unknown Functions
48-
if expr.func.__name__ not in self.known_functions:
49-
self.known_functions[expr.func.__name__] = expr.func.__name__
50-
return super()._print_Function(expr)
48+
if isinstance(expr, AbstractFunction):
49+
return str(expr)
50+
else:
51+
if expr.func.__name__ not in self.known_functions:
52+
self.known_functions[expr.func.__name__] = expr.func.__name__
53+
return super()._print_Function(expr)
5154

5255
def _print_CondEq(self, expr):
5356
return "%s == %s" % (self._print(expr.lhs), self._print(expr.rhs))

devito/symbolics/queries.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ def q_function(expr):
4343
return isinstance(expr, DiscreteFunction)
4444

4545

46+
def q_derivative(expr):
47+
from devito.finite_differences.derivative import Derivative
48+
return isinstance(expr, Derivative)
49+
50+
4651
def q_terminal(expr):
4752
return (expr.is_Symbol or
4853
expr.is_Indexed or

0 commit comments

Comments
 (0)