|
14 | 14 | import scipy.sparse |
15 | 15 |
|
16 | 16 |
|
17 | | -def unit_box(name='a', shape=(11, 11), grid=None): |
| 17 | +def unit_box(name='a', shape=(11, 11), grid=None, space_order=1): |
18 | 18 | """Create a field with value 0. to 1. in each dimension""" |
19 | 19 | grid = grid or Grid(shape=shape) |
20 | | - a = Function(name=name, grid=grid) |
| 20 | + a = Function(name=name, grid=grid, space_order=space_order) |
21 | 21 | dims = tuple([np.linspace(0., 1., d) for d in shape]) |
22 | 22 | a.data[:] = np.meshgrid(*dims)[1] |
23 | 23 | return a |
24 | 24 |
|
25 | 25 |
|
26 | | -def unit_box_time(name='a', shape=(11, 11)): |
| 26 | +def unit_box_time(name='a', shape=(11, 11), space_order=1): |
27 | 27 | """Create a field with value 0. to 1. in each dimension""" |
28 | 28 | grid = Grid(shape=shape) |
29 | | - a = TimeFunction(name=name, grid=grid, time_order=1) |
| 29 | + a = TimeFunction(name=name, grid=grid, time_order=1, space_order=space_order) |
30 | 30 | dims = tuple([np.linspace(0., 1., d) for d in shape]) |
31 | 31 | a.data[0, :] = np.meshgrid(*dims)[1] |
32 | 32 | a.data[1, :] = np.meshgrid(*dims)[1] |
@@ -117,16 +117,15 @@ def test_precomputed_interpolation(r): |
117 | 117 | origin = (0, 0) |
118 | 118 |
|
119 | 119 | grid = Grid(shape=shape, origin=origin) |
120 | | - r = 2 # Constant for linear interpolation |
121 | | - # because we interpolate across 2 neighbouring points in each dimension |
122 | 120 |
|
123 | 121 | def init(data): |
| 122 | + # This is data with halo so need to shift to match the m.data expectations |
124 | 123 | for i in range(data.shape[0]): |
125 | 124 | for j in range(data.shape[1]): |
126 | | - data[i, j] = sin(grid.spacing[0]*i) + sin(grid.spacing[1]*j) |
| 125 | + data[i, j] = sin(grid.spacing[0]*(i-r)) + sin(grid.spacing[1]*(j-r)) |
127 | 126 | return data |
128 | 127 |
|
129 | | - m = Function(name='m', grid=grid, initializer=init, space_order=0) |
| 128 | + m = Function(name='m', grid=grid, initializer=init, space_order=r) |
130 | 129 |
|
131 | 130 | gridpoints, interpolation_coeffs = precompute_linear_interpolation(points, |
132 | 131 | grid, origin, |
@@ -154,10 +153,8 @@ def test_precomputed_interpolation_time(r): |
154 | 153 | origin = (0, 0) |
155 | 154 |
|
156 | 155 | grid = Grid(shape=shape, origin=origin) |
157 | | - r = 2 # Constant for linear interpolation |
158 | | - # because we interpolate across 2 neighbouring points in each dimension |
159 | 156 |
|
160 | | - u = TimeFunction(name='u', grid=grid, space_order=0, save=5) |
| 157 | + u = TimeFunction(name='u', grid=grid, space_order=r, save=5) |
161 | 158 | for it in range(5): |
162 | 159 | u.data[it, :] = it |
163 | 160 |
|
@@ -190,11 +187,7 @@ def test_precomputed_injection(r): |
190 | 187 | origin = (0, 0) |
191 | 188 | result = 0.25 |
192 | 189 |
|
193 | | - # Constant for linear interpolation |
194 | | - # because we interpolate across 2 neighbouring points in each dimension |
195 | | - r = 2 |
196 | | - |
197 | | - m = unit_box(shape=shape) |
| 190 | + m = unit_box(shape=shape, space_order=r) |
198 | 191 | m.data[:] = 0. |
199 | 192 |
|
200 | 193 | gridpoints, interpolation_coeffs = precompute_linear_interpolation(coords, |
@@ -228,11 +221,7 @@ def test_precomputed_injection_time(r): |
228 | 221 | result = 0.25 |
229 | 222 | nt = 20 |
230 | 223 |
|
231 | | - # Constant for linear interpolation |
232 | | - # because we interpolate across 2 neighbouring points in each dimension |
233 | | - r = 2 |
234 | | - |
235 | | - m = unit_box_time(shape=shape) |
| 224 | + m = unit_box_time(shape=shape, space_order=r) |
236 | 225 | m.data[:] = 0. |
237 | 226 |
|
238 | 227 | gridpoints, interpolation_coeffs = precompute_linear_interpolation(coords, |
@@ -761,3 +750,16 @@ def test_inject_function(): |
761 | 750 | for i in [0, 1, 3, 4]: |
762 | 751 | for j in [0, 1, 3, 4]: |
763 | 752 | assert u.data[1, i, j] == 0 |
| 753 | + |
| 754 | + |
| 755 | +def test_interpolation_radius(): |
| 756 | + nt = 11 |
| 757 | + |
| 758 | + grid = Grid(shape=(5, 5)) |
| 759 | + u = TimeFunction(name="u", grid=grid, space_order=0) |
| 760 | + src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1) |
| 761 | + try: |
| 762 | + src.interpolate(u) |
| 763 | + assert False |
| 764 | + except ValueError: |
| 765 | + assert True |
0 commit comments