diff --git a/lib/nx_signal.ex b/lib/nx_signal.ex index 077c095..d47b79f 100644 --- a/lib/nx_signal.ex +++ b/lib/nx_signal.ex @@ -43,7 +43,7 @@ defmodule NxSignal do ## Examples - iex> {z, t, f} = NxSignal.stft(Nx.iota({4}), NxSignal.Windows.rectangular(n: 2), overlap_length: 1, fft_length: 2, sampling_rate: 400) + iex> {z, t, f} = NxSignal.stft(Nx.iota({4}), NxSignal.Windows.rectangular(2), overlap_length: 1, fft_length: 2, sampling_rate: 400) iex> z #Nx.Tensor< c64[frames: 3][frequencies: 2] @@ -464,7 +464,7 @@ defmodule NxSignal do iex> fft_length = 16 iex> sampling_rate = 8.0e3 - iex> {z, _, _} = NxSignal.stft(Nx.iota({10}), NxSignal.Windows.hann(n: 4), overlap_length: 2, fft_length: fft_length, sampling_rate: sampling_rate, window_padding: :reflect) + iex> {z, _, _} = NxSignal.stft(Nx.iota({10}), NxSignal.Windows.hann(4), overlap_length: 2, fft_length: fft_length, sampling_rate: sampling_rate, window_padding: :reflect) iex> Nx.axis_size(z, :frequencies) 16 iex> Nx.axis_size(z, :frames) @@ -543,7 +543,7 @@ defmodule NxSignal do of the signal end up being distorted. iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20]) - iex> w = NxSignal.Windows.hann(n: 4) + iex> w = NxSignal.Windows.hann(4) iex> opts = [sampling_rate: 1, fft_length: 4] iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts) iex> result = NxSignal.istft(z, w, opts) @@ -557,7 +557,7 @@ defmodule NxSignal do For perfect reconstruction, you want to use the same scaling as the STFT: iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20]) - iex> w = NxSignal.Windows.hann(n: 4) + iex> w = NxSignal.Windows.hann(4) iex> opts = [scaling: :spectrum, sampling_rate: 1, fft_length: 4] iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts) iex> result = NxSignal.istft(z, w, opts) @@ -568,7 +568,7 @@ defmodule NxSignal do > iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20], type: :f32) - iex> w = NxSignal.Windows.hann(n: 4) + iex> w = NxSignal.Windows.hann(4) iex> opts = [scaling: :psd, sampling_rate: 1, fft_length: 4] iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts) iex> result = NxSignal.istft(z, w, opts) diff --git a/lib/nx_signal/windows.ex b/lib/nx_signal/windows.ex index 1673ed4..2cc8762 100644 --- a/lib/nx_signal/windows.ex +++ b/lib/nx_signal/windows.ex @@ -13,27 +13,25 @@ defmodule NxSignal.Windows do ## Options - * `:n` - the window length * `:type` - the output type. Defaults to `s64` ## Examples - iex> NxSignal.Windows.rectangular(n: 5) + iex> NxSignal.Windows.rectangular(5) #Nx.Tensor< s64[5] [1, 1, 1, 1, 1] > - iex> NxSignal.Windows.rectangular(n: 5, type: :f32) + iex> NxSignal.Windows.rectangular(5, type: :f32) #Nx.Tensor< f32[5] [1.0, 1.0, 1.0, 1.0, 1.0] > """ @doc type: :windowing - defn rectangular(opts \\ []) do - opts = keyword!(opts, [:n, type: :s64]) - {n, opts} = pop_window_size(opts) + deftransform rectangular(n, opts \\ []) when is_integer(n) do + opts = Keyword.validate!(opts, type: :s64) Nx.broadcast(Nx.tensor(1, type: opts[:type]), {n}) end @@ -44,22 +42,25 @@ defmodule NxSignal.Windows do ## Options - * `:n` - The window length. Mandatory option. * `:type` - the output type for the window. Defaults to `{:f, 32}` * `:name` - the axis name. Defaults to `nil` ## Examples - iex> NxSignal.Windows.bartlett(n: 3) + iex> NxSignal.Windows.bartlett(3) #Nx.Tensor< f32[3] [0.0, 0.6666666865348816, 0.6666666269302368] > """ @doc type: :windowing - defn bartlett(opts \\ []) do - opts = keyword!(opts, [:n, :name, type: {:f, 32}]) - {n, opts} = pop_window_size(opts) + deftransform bartlett(n, opts \\ []) when is_integer(n) do + opts = Keyword.validate!(opts, type: {:f, 32}) + bartlett_n(Keyword.put(opts, :n, n)) + end + + defnp bartlett_n(opts) do + n = opts[:n] name = opts[:name] type = opts[:type] @@ -87,16 +88,20 @@ defmodule NxSignal.Windows do ## Examples - iex> NxSignal.Windows.triangular(n: 3) + iex> NxSignal.Windows.triangular(3) #Nx.Tensor< f32[3] [0.5, 1.0, 0.5] > """ @doc type: :windowing - defn triangular(opts \\ []) do - opts = keyword!(opts, [:n, :name, type: {:f, 32}]) - {n, opts} = pop_window_size(opts) + deftransform triangular(n, opts \\ []) when is_integer(n) do + opts = Keyword.validate!(opts, [:name, type: {:f, 32}]) + triangular_n(Keyword.put(opts, :n, n)) + end + + defnp triangular_n(opts) do + n = opts[:n] name = opts[:name] type = opts[:type] @@ -126,7 +131,6 @@ defmodule NxSignal.Windows do ## Options - * `:n` - The window length. Mandatory option. * `:is_periodic` - If `true`, produces a periodic window, otherwise produces a symmetric window. Defaults to `true` * `:type` - the output type for the window. Defaults to `{:f, 32}` @@ -134,40 +138,45 @@ defmodule NxSignal.Windows do ## Examples - iex> NxSignal.Windows.blackman(n: 5, is_periodic: false) + iex> NxSignal.Windows.blackman(5, is_periodic: false) #Nx.Tensor< f32[5] [-1.4901161193847656e-8, 0.3400000333786011, 0.9999999403953552, 0.3400000333786011, -1.4901161193847656e-8] > - iex> NxSignal.Windows.blackman(n: 5, is_periodic: true) + iex> NxSignal.Windows.blackman(5, is_periodic: true) #Nx.Tensor< f32[5] [-1.4901161193847656e-8, 0.20077012479305267, 0.8492299318313599, 0.8492299318313599, 0.20077012479305267] > - iex> NxSignal.Windows.blackman(n: 6, is_periodic: true, type: {:f, 32}) + iex> NxSignal.Windows.blackman(6, is_periodic: true, type: {:f, 32}) #Nx.Tensor< f32[6] [-1.4901161193847656e-8, 0.12999999523162842, 0.6299999952316284, 0.9999999403953552, 0.6299999952316284, 0.12999999523162842] > """ @doc type: :windowing - defn blackman(opts \\ []) do - opts = keyword!(opts, [:n, :name, is_periodic: true, type: {:f, 32}]) - {l, opts} = pop_window_size(opts) + deftransform blackman(n, opts \\ []) when is_integer(n) do + opts = Keyword.validate!(opts, [:name, is_periodic: true, type: {:f, 32}]) + blackman_n(Keyword.put(opts, :n, n)) + end + + defnp blackman_n(opts) do + n = opts[:n] name = opts[:name] type = opts[:type] is_periodic = opts[:is_periodic] l = if is_periodic do - l + 1 + n + 1 else - l + n end - m = div_ceil(l, 2) + m = + integer_div_ceil(l, 2) n = Nx.iota({m}, names: [name], type: type) @@ -194,7 +203,6 @@ defmodule NxSignal.Windows do ## Options - * `:n` - The window length. Mandatory option. * `:is_periodic` - If `true`, produces a periodic window, otherwise produces a symmetric window. Defaults to `true` * `:type` - the output type for the window. Defaults to `{:f, 32}` @@ -202,30 +210,34 @@ defmodule NxSignal.Windows do ## Examples - iex> NxSignal.Windows.hamming(n: 5, is_periodic: true) + iex> NxSignal.Windows.hamming(5, is_periodic: true) #Nx.Tensor< f32[5] [0.08000001311302185, 0.39785221219062805, 0.9121478796005249, 0.9121478199958801, 0.3978521227836609] > - iex> NxSignal.Windows.hamming(n: 5, is_periodic: false) + iex> NxSignal.Windows.hamming(5, is_periodic: false) #Nx.Tensor< f32[5] [0.08000001311302185, 0.5400000214576721, 1.0, 0.5400000214576721, 0.08000001311302185] > """ @doc type: :windowing - defn hamming(opts \\ []) do - opts = keyword!(opts, [:n, :name, is_periodic: true, type: {:f, 32}]) - {l, opts} = pop_window_size(opts) + deftransform hamming(n, opts \\ []) when is_integer(n) do + opts = Keyword.validate!(opts, [:name, is_periodic: true, type: {:f, 32}]) + hamming_n(Keyword.put(opts, :n, n)) + end + + defnp hamming_n(opts) do + n = opts[:n] name = opts[:name] type = opts[:type] is_periodic = opts[:is_periodic] l = if is_periodic do - l + 1 + n + 1 else - l + n end n = Nx.iota({l}, names: [name], type: type) @@ -244,7 +256,6 @@ defmodule NxSignal.Windows do ## Options - * `:n` - The window length. Mandatory option. * `:is_periodic` - If `true`, produces a periodic window, otherwise produces a symmetric window. Defaults to `true` * `:type` - the output type for the window. Defaults to `{:f, 32}` @@ -252,30 +263,34 @@ defmodule NxSignal.Windows do ## Examples - iex> NxSignal.Windows.hann(n: 5, is_periodic: false) + iex> NxSignal.Windows.hann(5, is_periodic: false) #Nx.Tensor< f32[5] [0.0, 0.5, 1.0, 0.5, 0.0] > - iex> NxSignal.Windows.hann(n: 5, is_periodic: true) + iex> NxSignal.Windows.hann(5, is_periodic: true) #Nx.Tensor< f32[5] [0.0, 0.34549152851104736, 0.9045085310935974, 0.9045084714889526, 0.3454914391040802] > """ @doc type: :windowing - defn hann(opts \\ []) do - opts = keyword!(opts, [:n, :name, is_periodic: true, type: {:f, 32}]) - {l, opts} = pop_window_size(opts) + deftransform hann(n, opts \\ []) when is_integer(n) do + opts = Keyword.validate!(opts, [:name, is_periodic: true, type: {:f, 32}]) + hann_n(Keyword.put(opts, :n, n)) + end + + defnp hann_n(opts) do + n = opts[:n] name = opts[:name] type = opts[:type] is_periodic = opts[:is_periodic] l = if is_periodic do - l + 1 + n + 1 else - l + n end n = Nx.iota({l}, names: [name], type: type) @@ -296,7 +311,6 @@ defmodule NxSignal.Windows do ## Options - * `:n` - The window length. Mandatory option. * `:is_periodic` - If `true`, produces a periodic window, otherwise produces a symmetric window. Defaults to `true` * `:type` - the output type for the window. Defaults to `{:f, 32}` @@ -305,46 +319,50 @@ defmodule NxSignal.Windows do * `:axis_name` - the axis name. Defaults to `nil` ## Examples - iex> NxSignal.Windows.kaiser(n: 4, beta: 12.0, is_periodic: true) + iex> NxSignal.Windows.kaiser(4, beta: 12.0, is_periodic: true) #Nx.Tensor< f32[4] [5.2776191296288744e-5, 0.21566666662693024, 1.0, 0.21566666662693024] > - iex> NxSignal.Windows.kaiser(n: 5, beta: 12.0, is_periodic: true) + iex> NxSignal.Windows.kaiser(5, beta: 12.0, is_periodic: true) #Nx.Tensor< f32[5] [5.2776191296288744e-5, 0.10171464085578918, 0.7929369807243347, 0.7929369807243347, 0.10171464085578918] > - iex> NxSignal.Windows.kaiser(n: 4, beta: 12.0, is_periodic: false) + iex> NxSignal.Windows.kaiser(4, beta: 12.0, is_periodic: false) #Nx.Tensor< f32[4] [5.2776191296288744e-5, 0.5188394784927368, 0.5188390612602234, 5.2776191296288744e-5] > """ @doc type: :windowing - defn kaiser(opts \\ []) do + deftransform kaiser(n, opts \\ []) when is_integer(n) do opts = - keyword!(opts, [:n, :axis_name, eps: 1.0e-7, beta: 12.0, is_periodic: true, type: {:f, 32}]) + Keyword.validate!(opts, [:name, eps: 1.0e-7, beta: 12.0, is_periodic: true, type: {:f, 32}]) + + kaiser_n(Keyword.put(opts, :n, n)) + end - {l, opts} = pop_window_size(opts) - name = opts[:axis_name] + defnp kaiser_n(opts) do + n = opts[:n] + name = opts[:name] type = opts[:type] beta = opts[:beta] eps = opts[:eps] is_periodic = opts[:is_periodic] - window_length = if is_periodic, do: l + 1, else: l + window_length = if is_periodic, do: n + 1, else: n - ratio = Nx.linspace(-1, 1, n: window_length, endpoint: true, type: type) |> Nx.rename([name]) + ratio = Nx.linspace(-1, 1, n: window_length, endpoint: true, type: type, name: name) sqrt_arg = Nx.max(1 - ratio ** 2, eps) r = beta * Nx.sqrt(sqrt_arg) window = kaiser_bessel_i0(r) / kaiser_bessel_i0(beta) if is_periodic do - Nx.slice(window, [0], [l]) + Nx.slice(window, [0], [n]) else window end @@ -367,17 +385,13 @@ defmodule NxSignal.Windows do Nx.select(abs_x < 3.75, small_x_result, large_x_result) end - deftransformp pop_window_size(opts) do - {n, opts} = Keyword.pop(opts, :n) + deftransformp integer_div_ceil(num, den) when is_integer(num) and is_integer(den) do + rem = rem(num, den) - if !n do - raise "missing :n option" + if rem == 0 do + div(num, den) + else + div(num, den) + 1 end - - {n, opts} - end - - deftransformp div_ceil(num, den) do - ceil(num / den) end end