Skip to content

Commit

Permalink
fix: error in float32 case
Browse files Browse the repository at this point in the history
  • Loading branch information
tk2lab committed Oct 25, 2022
1 parent 1c28220 commit 2526769
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "logbesselk"
version = "2.4.1"
version = "2.4.2"
description = "Provide function to calculate the modified Bessel function of the second kind"
license = "Apache-2.0"
authors = ["TAKEKAWA Takashi <[email protected]>"]
Expand Down
2 changes: 1 addition & 1 deletion src/logbesselk/integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def loop(b, fmax, out):
fmax = tf.where(fmax > ft, fmax, ft)
return b, fmax, out

init = 0, funcb(0), tf.ones(shape, dtype)
init = tf.cast(0, dtype), funcb(0), tf.ones(shape, dtype)
b, fmax, out = tf.while_loop(cond, loop, init)
h = (t1 - t0) / bins
out = tk.log(h) + fmax + tk.log(out)
Expand Down
15 changes: 8 additions & 7 deletions tests/test_logk.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@


@pytest.mark.parametrize(
'func, wrap, data', [
(f, w, d)
'func, wrap, data, dtype', [
(f, w, d, dt)
for dt in [np.float32, np.float64]
for d, f in funcs.items()
for w in [False, True]
])
def test_logk(func, wrap, data):
def test_logk(func, wrap, data, dtype):
if wrap:
func = tf.function(func)
df = pd.read_csv(f'./data/{data}_mathematica.csv')
v = df['v']
x = df['x']
val = df['true']
v = df['v'].to_numpy().astype(dtype)
x = df['x'].to_numpy().astype(dtype)
val = df['true'].to_numpy().astype(dtype)
out = func(v, x).numpy()
assert np.all(np.isclose(out, val))
assert np.allclose(out, val, rtol=5e-3, atol=0)

0 comments on commit 2526769

Please sign in to comment.