Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions lib/nx_signal.ex
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ defmodule NxSignal do

## Examples

iex> {z, t, f} = NxSignal.stft(Nx.iota({4}), NxSignal.Windows.rectangular(n: 2), overlap_length: 1, fft_length: 2, sampling_rate: 400)
iex> {z, t, f} = NxSignal.stft(Nx.iota({4}), NxSignal.Windows.rectangular(2), overlap_length: 1, fft_length: 2, sampling_rate: 400)
iex> z
#Nx.Tensor<
c64[frames: 3][frequencies: 2]
Expand Down Expand Up @@ -464,7 +464,7 @@ defmodule NxSignal do

iex> fft_length = 16
iex> sampling_rate = 8.0e3
iex> {z, _, _} = NxSignal.stft(Nx.iota({10}), NxSignal.Windows.hann(n: 4), overlap_length: 2, fft_length: fft_length, sampling_rate: sampling_rate, window_padding: :reflect)
iex> {z, _, _} = NxSignal.stft(Nx.iota({10}), NxSignal.Windows.hann(4), overlap_length: 2, fft_length: fft_length, sampling_rate: sampling_rate, window_padding: :reflect)
iex> Nx.axis_size(z, :frequencies)
16
iex> Nx.axis_size(z, :frames)
Expand Down Expand Up @@ -543,7 +543,7 @@ defmodule NxSignal do
of the signal end up being distorted.

iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20])
iex> w = NxSignal.Windows.hann(n: 4)
iex> w = NxSignal.Windows.hann(4)
iex> opts = [sampling_rate: 1, fft_length: 4]
iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
iex> result = NxSignal.istft(z, w, opts)
Expand All @@ -557,7 +557,7 @@ defmodule NxSignal do
For perfect reconstruction, you want to use the same scaling as the STFT:

iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20])
iex> w = NxSignal.Windows.hann(n: 4)
iex> w = NxSignal.Windows.hann(4)
iex> opts = [scaling: :spectrum, sampling_rate: 1, fft_length: 4]
iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
iex> result = NxSignal.istft(z, w, opts)
Expand All @@ -568,7 +568,7 @@ defmodule NxSignal do
>

iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20], type: :f32)
iex> w = NxSignal.Windows.hann(n: 4)
iex> w = NxSignal.Windows.hann(4)
iex> opts = [scaling: :psd, sampling_rate: 1, fft_length: 4]
iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
iex> result = NxSignal.istft(z, w, opts)
Expand Down
138 changes: 76 additions & 62 deletions lib/nx_signal/windows.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,25 @@ defmodule NxSignal.Windows do

## Options

* `:n` - the window length
* `:type` - the output type. Defaults to `s64`

## Examples

iex> NxSignal.Windows.rectangular(n: 5)
iex> NxSignal.Windows.rectangular(5)
#Nx.Tensor<
s64[5]
[1, 1, 1, 1, 1]
>

iex> NxSignal.Windows.rectangular(n: 5, type: :f32)
iex> NxSignal.Windows.rectangular(5, type: :f32)
#Nx.Tensor<
f32[5]
[1.0, 1.0, 1.0, 1.0, 1.0]
>
"""
@doc type: :windowing
defn rectangular(opts \\ []) do
opts = keyword!(opts, [:n, type: :s64])
{n, opts} = pop_window_size(opts)
deftransform rectangular(n, opts \\ []) when is_integer(n) do
opts = Keyword.validate!(opts, type: :s64)
Nx.broadcast(Nx.tensor(1, type: opts[:type]), {n})
end

Expand All @@ -44,22 +42,25 @@ defmodule NxSignal.Windows do

## Options

* `:n` - The window length. Mandatory option.
* `:type` - the output type for the window. Defaults to `{:f, 32}`
* `:name` - the axis name. Defaults to `nil`

## Examples

iex> NxSignal.Windows.bartlett(n: 3)
iex> NxSignal.Windows.bartlett(3)
#Nx.Tensor<
f32[3]
[0.0, 0.6666666865348816, 0.6666666269302368]
>
"""
@doc type: :windowing
defn bartlett(opts \\ []) do
opts = keyword!(opts, [:n, :name, type: {:f, 32}])
{n, opts} = pop_window_size(opts)
deftransform bartlett(n, opts \\ []) when is_integer(n) do
opts = Keyword.validate!(opts, type: {:f, 32})
bartlett_n(Keyword.put(opts, :n, n))
end

defnp bartlett_n(opts) do
n = opts[:n]
name = opts[:name]
type = opts[:type]

Expand Down Expand Up @@ -87,16 +88,20 @@ defmodule NxSignal.Windows do

## Examples

iex> NxSignal.Windows.triangular(n: 3)
iex> NxSignal.Windows.triangular(3)
#Nx.Tensor<
f32[3]
[0.5, 1.0, 0.5]
>
"""
@doc type: :windowing
defn triangular(opts \\ []) do
opts = keyword!(opts, [:n, :name, type: {:f, 32}])
{n, opts} = pop_window_size(opts)
deftransform triangular(n, opts \\ []) when is_integer(n) do
opts = Keyword.validate!(opts, [:name, type: {:f, 32}])
triangular_n(Keyword.put(opts, :n, n))
end

defnp triangular_n(opts) do
n = opts[:n]
name = opts[:name]
type = opts[:type]

Expand Down Expand Up @@ -126,48 +131,52 @@ defmodule NxSignal.Windows do

## Options

* `:n` - The window length. Mandatory option.
* `:is_periodic` - If `true`, produces a periodic window,
otherwise produces a symmetric window. Defaults to `true`
* `:type` - the output type for the window. Defaults to `{:f, 32}`
* `:name` - the axis name. Defaults to `nil`

## Examples

iex> NxSignal.Windows.blackman(n: 5, is_periodic: false)
iex> NxSignal.Windows.blackman(5, is_periodic: false)
#Nx.Tensor<
f32[5]
[-1.4901161193847656e-8, 0.3400000333786011, 0.9999999403953552, 0.3400000333786011, -1.4901161193847656e-8]
>

iex> NxSignal.Windows.blackman(n: 5, is_periodic: true)
iex> NxSignal.Windows.blackman(5, is_periodic: true)
#Nx.Tensor<
f32[5]
[-1.4901161193847656e-8, 0.20077012479305267, 0.8492299318313599, 0.8492299318313599, 0.20077012479305267]
>

iex> NxSignal.Windows.blackman(n: 6, is_periodic: true, type: {:f, 32})
iex> NxSignal.Windows.blackman(6, is_periodic: true, type: {:f, 32})
#Nx.Tensor<
f32[6]
[-1.4901161193847656e-8, 0.12999999523162842, 0.6299999952316284, 0.9999999403953552, 0.6299999952316284, 0.12999999523162842]
>
"""
@doc type: :windowing
defn blackman(opts \\ []) do
opts = keyword!(opts, [:n, :name, is_periodic: true, type: {:f, 32}])
{l, opts} = pop_window_size(opts)
deftransform blackman(n, opts \\ []) when is_integer(n) do
opts = Keyword.validate!(opts, [:name, is_periodic: true, type: {:f, 32}])
blackman_n(Keyword.put(opts, :n, n))
end

defnp blackman_n(opts) do
n = opts[:n]
name = opts[:name]
type = opts[:type]
is_periodic = opts[:is_periodic]

l =
if is_periodic do
l + 1
n + 1
else
l
n
end

m = div_ceil(l, 2)
m =
integer_div_ceil(l, 2)

n = Nx.iota({m}, names: [name], type: type)

Expand All @@ -194,38 +203,41 @@ defmodule NxSignal.Windows do

## Options

* `:n` - The window length. Mandatory option.
* `:is_periodic` - If `true`, produces a periodic window,
otherwise produces a symmetric window. Defaults to `true`
* `:type` - the output type for the window. Defaults to `{:f, 32}`
* `:name` - the axis name. Defaults to `nil`

## Examples

iex> NxSignal.Windows.hamming(n: 5, is_periodic: true)
iex> NxSignal.Windows.hamming(5, is_periodic: true)
#Nx.Tensor<
f32[5]
[0.08000001311302185, 0.39785221219062805, 0.9121478796005249, 0.9121478199958801, 0.3978521227836609]
>
iex> NxSignal.Windows.hamming(n: 5, is_periodic: false)
iex> NxSignal.Windows.hamming(5, is_periodic: false)
#Nx.Tensor<
f32[5]
[0.08000001311302185, 0.5400000214576721, 1.0, 0.5400000214576721, 0.08000001311302185]
>
"""
@doc type: :windowing
defn hamming(opts \\ []) do
opts = keyword!(opts, [:n, :name, is_periodic: true, type: {:f, 32}])
{l, opts} = pop_window_size(opts)
deftransform hamming(n, opts \\ []) when is_integer(n) do
opts = Keyword.validate!(opts, [:name, is_periodic: true, type: {:f, 32}])
hamming_n(Keyword.put(opts, :n, n))
end

defnp hamming_n(opts) do
n = opts[:n]
name = opts[:name]
type = opts[:type]
is_periodic = opts[:is_periodic]

l =
if is_periodic do
l + 1
n + 1
else
l
n
end

n = Nx.iota({l}, names: [name], type: type)
Expand All @@ -244,38 +256,41 @@ defmodule NxSignal.Windows do

## Options

* `:n` - The window length. Mandatory option.
* `:is_periodic` - If `true`, produces a periodic window,
otherwise produces a symmetric window. Defaults to `true`
* `:type` - the output type for the window. Defaults to `{:f, 32}`
* `:name` - the axis name. Defaults to `nil`

## Examples

iex> NxSignal.Windows.hann(n: 5, is_periodic: false)
iex> NxSignal.Windows.hann(5, is_periodic: false)
#Nx.Tensor<
f32[5]
[0.0, 0.5, 1.0, 0.5, 0.0]
>
iex> NxSignal.Windows.hann(n: 5, is_periodic: true)
iex> NxSignal.Windows.hann(5, is_periodic: true)
#Nx.Tensor<
f32[5]
[0.0, 0.34549152851104736, 0.9045085310935974, 0.9045084714889526, 0.3454914391040802]
>
"""
@doc type: :windowing
defn hann(opts \\ []) do
opts = keyword!(opts, [:n, :name, is_periodic: true, type: {:f, 32}])
{l, opts} = pop_window_size(opts)
deftransform hann(n, opts \\ []) when is_integer(n) do
opts = Keyword.validate!(opts, [:name, is_periodic: true, type: {:f, 32}])
hann_n(Keyword.put(opts, :n, n))
end

defnp hann_n(opts) do
n = opts[:n]
name = opts[:name]
type = opts[:type]
is_periodic = opts[:is_periodic]

l =
if is_periodic do
l + 1
n + 1
else
l
n
end

n = Nx.iota({l}, names: [name], type: type)
Expand All @@ -296,7 +311,6 @@ defmodule NxSignal.Windows do

## Options

* `:n` - The window length. Mandatory option.
* `:is_periodic` - If `true`, produces a periodic window,
otherwise produces a symmetric window. Defaults to `true`
* `:type` - the output type for the window. Defaults to `{:f, 32}`
Expand All @@ -305,46 +319,50 @@ defmodule NxSignal.Windows do
* `:axis_name` - the axis name. Defaults to `nil`

## Examples
iex> NxSignal.Windows.kaiser(n: 4, beta: 12.0, is_periodic: true)
iex> NxSignal.Windows.kaiser(4, beta: 12.0, is_periodic: true)
#Nx.Tensor<
f32[4]
[5.2776191296288744e-5, 0.21566666662693024, 1.0, 0.21566666662693024]
>

iex> NxSignal.Windows.kaiser(n: 5, beta: 12.0, is_periodic: true)
iex> NxSignal.Windows.kaiser(5, beta: 12.0, is_periodic: true)
#Nx.Tensor<
f32[5]
[5.2776191296288744e-5, 0.10171464085578918, 0.7929369807243347, 0.7929369807243347, 0.10171464085578918]
>

iex> NxSignal.Windows.kaiser(n: 4, beta: 12.0, is_periodic: false)
iex> NxSignal.Windows.kaiser(4, beta: 12.0, is_periodic: false)
#Nx.Tensor<
f32[4]
[5.2776191296288744e-5, 0.5188394784927368, 0.5188390612602234, 5.2776191296288744e-5]
>
"""
@doc type: :windowing
defn kaiser(opts \\ []) do
deftransform kaiser(n, opts \\ []) when is_integer(n) do
opts =
keyword!(opts, [:n, :axis_name, eps: 1.0e-7, beta: 12.0, is_periodic: true, type: {:f, 32}])
Keyword.validate!(opts, [:name, eps: 1.0e-7, beta: 12.0, is_periodic: true, type: {:f, 32}])

kaiser_n(Keyword.put(opts, :n, n))
end

{l, opts} = pop_window_size(opts)
name = opts[:axis_name]
defnp kaiser_n(opts) do
n = opts[:n]
name = opts[:name]
type = opts[:type]
beta = opts[:beta]
eps = opts[:eps]
is_periodic = opts[:is_periodic]

window_length = if is_periodic, do: l + 1, else: l
window_length = if is_periodic, do: n + 1, else: n

ratio = Nx.linspace(-1, 1, n: window_length, endpoint: true, type: type) |> Nx.rename([name])
ratio = Nx.linspace(-1, 1, n: window_length, endpoint: true, type: type, name: name)
sqrt_arg = Nx.max(1 - ratio ** 2, eps)
r = beta * Nx.sqrt(sqrt_arg)

window = kaiser_bessel_i0(r) / kaiser_bessel_i0(beta)

if is_periodic do
Nx.slice(window, [0], [l])
Nx.slice(window, [0], [n])
else
window
end
Expand All @@ -367,17 +385,13 @@ defmodule NxSignal.Windows do
Nx.select(abs_x < 3.75, small_x_result, large_x_result)
end

deftransformp pop_window_size(opts) do
{n, opts} = Keyword.pop(opts, :n)
deftransformp integer_div_ceil(num, den) when is_integer(num) and is_integer(den) do
rem = rem(num, den)

if !n do
raise "missing :n option"
if rem == 0 do
div(num, den)
else
div(num, den) + 1
end

{n, opts}
end

deftransformp div_ceil(num, den) do
ceil(num / den)
end
end