@@ -5,7 +5,6 @@ defmodule NxSignal.Windows do
55 import Nx.Defn
66
77 @ pi :math . pi ( )
8- @ eps 1.0e-7
98
109 @ doc """
1110 Rectangular window.
@@ -302,44 +301,46 @@ defmodule NxSignal.Windows do
302301 otherwise produces a symmetric window. Defaults to `true`
303302 * `:type` - the output type for the window. Defaults to `{:f, 32}`
304303 * `:beta` - Shape parameter for the window. As beta increases, the window becomes more focused in frequency domain. Defaults to 12.0.
305- * `:name ` - the axis name. Defaults to `nil`
304+ * `:axis_name ` - the axis name. Defaults to `nil`
306305
307306 ## Examples
308307 iex> NxSignal.Windows.kaiser(n: 4, beta: 12.0, is_periodic: true)
309308 #Nx.Tensor<
310309 f32[4]
311- [5.2776191296288744e-5, 0.21566684544086456 , 1.0, 0.21566684544086456 ]
310+ [5.2776191296288744e-5, 0.21566666662693024 , 1.0, 0.21566666662693024 ]
312311 >
313312
314313 iex> NxSignal.Windows.kaiser(n: 5, beta: 12.0, is_periodic: true)
315314 #Nx.Tensor<
316315 f32[5]
317- [5.2776191296288744e-5, 0.10171464085578918, 0.7929376363754272 , 0.7929376363754272 , 0.10171464085578918]
316+ [5.2776191296288744e-5, 0.10171464085578918, 0.7929369807243347 , 0.7929369807243347 , 0.10171464085578918]
318317 >
319318
320319 iex> NxSignal.Windows.kaiser(n: 4, beta: 12.0, is_periodic: false)
321320 #Nx.Tensor<
322321 f32[4]
323- [5.2776191296288744e-5, 0.5188400149345398 , 0.5188400149345398 , 5.2776191296288744e-5]
322+ [5.2776191296288744e-5, 0.5188394784927368 , 0.5188390612602234 , 5.2776191296288744e-5]
324323 >
325324 """
326325 @ doc type: :windowing
327326 defn kaiser ( opts \\ [ ] ) do
328- opts = keyword! ( opts , [ :n , :name , beta: 12.0 , is_periodic: true , type: { :f , 32 } ] )
327+ opts =
328+ keyword! ( opts , [ :n , :axis_name , eps: 1.0e-7 , beta: 12.0 , is_periodic: true , type: { :f , 32 } ] )
329+
329330 { l , opts } = pop_window_size ( opts )
330- name = opts [ :name ]
331+ name = opts [ :axis_name ]
331332 type = opts [ :type ]
332333 beta = opts [ :beta ]
334+ eps = opts [ :eps ]
333335 is_periodic = opts [ :is_periodic ]
334336
335337 window_length = if is_periodic , do: l + 1 , else: l
336- alpha = ( window_length - 1 ) / 2.0
337338
338- n = Nx . iota ( { window_length } , names: [ name ] , type: type )
339- ratio = ( n - alpha ) / alpha
340- r = beta * Nx . sqrt ( 1 - Nx . pow ( ratio , 2 ) + @ eps )
339+ ratio = Nx . linspace ( - 1 , 1 , n: window_length , endpoint: true , type: type ) |> Nx . rename ( [ name ] )
340+ sqrt_arg = Nx . max ( 1 - ratio ** 2 , eps )
341+ r = beta * Nx . sqrt ( sqrt_arg )
341342
342- window = i0 ( r ) / i0 ( beta )
343+ window = kaiser_bessel_i0 ( r ) / kaiser_bessel_i0 ( beta )
343344
344345 if is_periodic do
345346 Nx . slice ( window , [ 0 ] , [ l ] )
@@ -348,19 +349,19 @@ defmodule NxSignal.Windows do
348349 end
349350 end
350351
351- defnp i0 ( x ) do
352+ defnp kaiser_bessel_i0 ( x ) do
352353 abs_x = Nx . abs ( x )
353354
354355 small_x_result =
355- 1.0 +
356- Nx . pow ( abs_x , 2 ) / 4.0 +
357- Nx . pow ( abs_x , 4 ) / 64.0 +
358- Nx . pow ( abs_x , 6 ) / 2304.0 +
359- Nx . pow ( abs_x , 8 ) / 147_456.0
356+ 1 +
357+ abs_x ** 2 / 4 +
358+ abs_x ** 4 / 64 +
359+ abs_x ** 6 / 2304 +
360+ abs_x ** 8 / 147_456
360361
361362 large_x_result =
362363 Nx . exp ( abs_x ) / Nx . sqrt ( 2 * Nx.Constants . pi ( ) * abs_x ) *
363- ( 1.0 + 1.0 / ( 8.0 * abs_x ) + 9.0 / ( 128.0 * Nx . pow ( abs_x , 2 ) ) )
364+ ( 1 + 1 / ( 8 * abs_x ) + 9 / ( 128 * Nx . pow ( abs_x , 2 ) ) )
364365
365366 Nx . select ( abs_x < 3.75 , small_x_result , large_x_result )
366367 end
0 commit comments