Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions lib/nx_signal/filters.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ defmodule NxSignal.Filters do
Common filter functions.
"""
import Nx.Defn
import NxSignal.Convolution

@doc ~S"""
Performs a median filter on a tensor.
Expand Down Expand Up @@ -52,4 +53,83 @@ defmodule NxSignal.Filters do
end

deftransformp kernel_lengths(kernel_shape), do: Tuple.to_list(kernel_shape)

@doc """
Applies a Wiener filter to the given Nx tensor.

## Options

* `:kernel_size` - filter size given either a number or a tuple.
If a number is given, a kernel with the given size, and same number of axes
as the input tensor will be used. Defaults to `3`.
* `:noise` - noise power, given as a scalar. This will be estimated based on the input tensor if `nil`. Defaults to `nil`.

## Examples

iex> t = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
iex> NxSignal.Filters.wiener(t, kernel_size: {2, 2}, noise: 10)
#Nx.Tensor<
f32[3][3]
[
[0.25, 0.75, 1.25],
[1.25, 3.0, 4.0],
[2.75, 6.0, 7.0]
]
>
"""
@doc type: :filters
deftransform wiener(t, opts \\ []) do
# Validate and extract options
opts = Keyword.validate!(opts, noise: nil, kernel_size: 3)

rank = Nx.rank(t)
kernel_size = Keyword.fetch!(opts, :kernel_size)
noise = Keyword.fetch!(opts, :noise)

# Ensure `kernel_size` is a tuple
kernel_size =
cond do
is_integer(kernel_size) -> Tuple.duplicate(kernel_size, rank)
is_tuple(kernel_size) -> kernel_size
true -> raise ArgumentError, "kernel_size must be an integer or tuple"
end

# Convert `nil` noise to `0.0` so it's always a valid tensor
noise_t = if is_nil(noise), do: Nx.tensor(0.0), else: Nx.tensor(noise)

# Compute filter window size
size = Tuple.to_list(kernel_size) |> Enum.reduce(1, &*/2)

# Ensure the kernel is the same size as the filter window
kernel = Nx.broadcast(1.0, kernel_size)

t
|> Nx.as_type(:f64)
|> wiener_n(kernel, noise_t, calculate_noise: is_nil(noise), size: size)
|> Nx.as_type(Nx.type(t))
end

defnp wiener_n(t, kernel, noise, opts) do
size = opts[:size]

# Compute local mean using "same" mode in correlation
l_mean = correlate(t, kernel, mode: :same) / size

# Compute local variance
l_var =
correlate(t ** 2, kernel, mode: :same)
|> Nx.divide(size)
|> Nx.subtract(l_mean ** 2)

# Ensure `noise` is a tensor to avoid `nil` issues in `defnp`
noise =
case opts[:calculate_noise] do
true -> Nx.mean(l_var)
false -> noise
end

# Apply Wiener filter formula
res = (t - l_mean) * (1 - noise / l_var)
Nx.select(l_var < noise, l_mean, res + l_mean)
end
end
125 changes: 125 additions & 0 deletions test/nx_signal/filters_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,129 @@ defmodule NxSignal.FiltersTest do
)
end
end

describe "wiener/2" do
test "performs n-dim wiener filter with calculated noise" do
im =
Nx.tensor(
[
[1.0, 2.0, 3.0, 4.0, 5.0],
[6.0, 7.0, 8.0, 9.0, 10.0],
[11.0, 12.0, 13.0, 14.0, 15.0]
],
type: :f64
)

kernel_size = {3, 3}

expected =
Nx.tensor(
[
[
1.7777777777777777,
3.0,
3.6666666666666665,
4.333333333333333,
3.111111111111111
],
[4.3366520642506305, 7.0, 8.0, 9.0, 7.58637597408283],
[
4.692197051420351,
7.261706150595039,
8.748939779474131,
10.157992415073023,
9.813815742524799
]
],
type: :f64
)

assert NxSignal.Filters.wiener(im, kernel_size: kernel_size) == expected
assert NxSignal.Filters.wiener(im, kernel_size: 3) == expected

assert NxSignal.Filters.wiener(Nx.as_type(im, :f32), kernel_size: kernel_size) ==
Nx.tensor([
[
1.7777777910232544,
3.0,
3.6666667461395264,
4.333333492279053,
3.1111111640930176
],
[4.3366522789001465, 7.0, 8.0, 9.0, 7.586376190185547],
[
4.692196846008301,
7.261706352233887,
8.748939514160156,
10.157992362976074,
9.81381607055664
]
])
end

test "performs n-dim wiener filter with parameterized noise" do
im =
Nx.tensor(
[
[1.0, 2.0, 3.0, 4.0, 5.0],
[6.0, 7.0, 8.0, 9.0, 10.0],
[11.0, 12.0, 13.0, 14.0, 15.0]
],
type: :f64
)

kernel_size = {3, 3}

assert NxSignal.Filters.wiener(im, kernel_size: kernel_size, noise: 10) ==
Nx.tensor(
[
[
1.7777777777777777,
3.0,
3.5882352941176467,
4.238095238095238,
3.7397034596375622
],
[5.193548387096774, 7.0, 8.0, 9.0, 8.829787234042554],
[
7.941747572815534,
9.702702702702702,
10.938931297709924,
12.137254901960784,
12.485549132947977
]
],
type: :f64
)

assert NxSignal.Filters.wiener(Nx.as_type(im, :f32), kernel_size: kernel_size, noise: 10) ==
Nx.tensor([
[
1.7777777910232544,
3.0,
3.588235378265381,
4.238095283508301,
3.739703416824341
],
[5.193548202514648, 7.0, 8.0, 9.0, 8.829787254333496],
[
7.941747665405273,
9.702702522277832,
10.938931465148926,
12.13725471496582,
12.485548973083496
]
])

assert NxSignal.Filters.wiener(im, kernel_size: kernel_size, noise: 0) ==
Nx.tensor(
[
[1.0, 2.0, 3.0, 4.0, 5.0],
[6.0, 7.0, 8.0, 9.0, 10.0],
[11.0, 12.0, 13.0, 14.0, 15.0]
],
type: :f64
)
end
end
end