Skip to content

Commit 111079d

Browse files
committed
api: fix interpolate with complex dtype
1 parent 2eafd3e commit 111079d

3 files changed

Lines changed: 18 additions & 2 deletions

File tree

devito/finite_differences/finite_difference.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,6 @@ def first_derivative(expr, dim, fd_order, **kwargs):
157157

158158
def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coefficients,
159159
expand, weights=None):
160-
if deriv_order == 0 and not expr.is_Add:
161-
print(expr, dim, fd_order)
162160
# Always expand time derivatives to avoid issue with buffering and streaming.
163161
# Time derivative are almost always short stencils and won't benefit from
164162
# unexpansion in the rare case the derivative is not evaluated for time stepping.

devito/ir/equations/equation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,19 @@ def __new__(cls, *args, **kwargs):
243243
shift = 0
244244
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
245245

246+
# Merge conditionals when possible. E.g if we have an implicit_dim
247+
# and there is a dimension with the same parent, we ca merged
248+
# its condition
249+
for d in input_expr.implicit_dims:
250+
if d not in conditionals:
251+
continue
252+
for cd in dict(conditionals):
253+
if cd.parent == d.parent and cd != d:
254+
cond = conditionals.pop(d)
255+
mode = cd.relation and d.relation
256+
conditionals[cd] = mode(cond, conditionals[cd])
257+
break
258+
246259
conditionals = frozendict(conditionals)
247260

248261
# Lower all Differentiable operations into SymPy operations

devito/passes/clusters/cse.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,14 @@ def cse(cluster, sregistry=None, options=None, **kwargs):
103103
if cluster.is_fence:
104104
return cluster
105105

106+
<<<<<<< HEAD
106107
def make(e):
107108
edtype = cse_dtype(e.dtype, dtype)
108109
return CTemp(name=sregistry.make_name(), dtype=edtype)
110+
=======
111+
make_dtype = lambda e: cse_dtype(e.dtype, dtype)
112+
make = lambda e: CTemp(name=sregistry.make_name(), dtype=make_dtype(e))
113+
>>>>>>> 54c5e49e2 (api: fix interpolate with complex dtype)
109114

110115
exprs = _cse(cluster, make, min_cost=min_cost, mode=mode)
111116

0 commit comments

Comments
 (0)