diff --git a/lib/nx_signal/filters.ex b/lib/nx_signal/filters.ex index 79591fe..b4b7cfe 100644 --- a/lib/nx_signal/filters.ex +++ b/lib/nx_signal/filters.ex @@ -3,6 +3,7 @@ defmodule NxSignal.Filters do Common filter functions. """ import Nx.Defn + import NxSignal.Convolution @doc ~S""" Performs a median filter on a tensor. @@ -52,4 +53,83 @@ defmodule NxSignal.Filters do end deftransformp kernel_lengths(kernel_shape), do: Tuple.to_list(kernel_shape) + + @doc """ + Applies a Wiener filter to the given Nx tensor. + + ## Options + + * `:kernel_size` - filter size given either a number or a tuple. + If a number is given, a kernel with the given size, and same number of axes + as the input tensor will be used. Defaults to `3`. + * `:noise` - noise power, given as a scalar. This will be estimated based on the input tensor if `nil`. Defaults to `nil`. + + ## Examples + + iex> t = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + iex> NxSignal.Filters.wiener(t, kernel_size: {2, 2}, noise: 10) + #Nx.Tensor< + f32[3][3] + [ + [0.25, 0.75, 1.25], + [1.25, 3.0, 4.0], + [2.75, 6.0, 7.0] + ] + > + """ + @doc type: :filters + deftransform wiener(t, opts \\ []) do + # Validate and extract options + opts = Keyword.validate!(opts, noise: nil, kernel_size: 3) + + rank = Nx.rank(t) + kernel_size = Keyword.fetch!(opts, :kernel_size) + noise = Keyword.fetch!(opts, :noise) + + # Ensure `kernel_size` is a tuple + kernel_size = + cond do + is_integer(kernel_size) -> Tuple.duplicate(kernel_size, rank) + is_tuple(kernel_size) -> kernel_size + true -> raise ArgumentError, "kernel_size must be an integer or tuple" + end + + # Convert `nil` noise to `0.0` so it's always a valid tensor + noise_t = if is_nil(noise), do: Nx.tensor(0.0), else: Nx.tensor(noise) + + # Compute filter window size + size = Tuple.to_list(kernel_size) |> Enum.reduce(1, &*/2) + + # Ensure the kernel is the same size as the filter window + kernel = Nx.broadcast(1.0, kernel_size) + + t + |> Nx.as_type(:f64) + |> wiener_n(kernel, noise_t, calculate_noise: is_nil(noise), size: size) + |> Nx.as_type(Nx.type(t)) + end + + defnp wiener_n(t, kernel, noise, opts) do + size = opts[:size] + + # Compute local mean using "same" mode in correlation + l_mean = correlate(t, kernel, mode: :same) / size + + # Compute local variance + l_var = + correlate(t ** 2, kernel, mode: :same) + |> Nx.divide(size) + |> Nx.subtract(l_mean ** 2) + + # Ensure `noise` is a tensor to avoid `nil` issues in `defnp` + noise = + case opts[:calculate_noise] do + true -> Nx.mean(l_var) + false -> noise + end + + # Apply Wiener filter formula + res = (t - l_mean) * (1 - noise / l_var) + Nx.select(l_var < noise, l_mean, res + l_mean) + end end diff --git a/test/nx_signal/filters_test.exs b/test/nx_signal/filters_test.exs index a439305..167a3b9 100644 --- a/test/nx_signal/filters_test.exs +++ b/test/nx_signal/filters_test.exs @@ -116,4 +116,129 @@ defmodule NxSignal.FiltersTest do ) end end + + describe "wiener/2" do + test "performs n-dim wiener filter with calculated noise" do + im = + Nx.tensor( + [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [6.0, 7.0, 8.0, 9.0, 10.0], + [11.0, 12.0, 13.0, 14.0, 15.0] + ], + type: :f64 + ) + + kernel_size = {3, 3} + + expected = + Nx.tensor( + [ + [ + 1.7777777777777777, + 3.0, + 3.6666666666666665, + 4.333333333333333, + 3.111111111111111 + ], + [4.3366520642506305, 7.0, 8.0, 9.0, 7.58637597408283], + [ + 4.692197051420351, + 7.261706150595039, + 8.748939779474131, + 10.157992415073023, + 9.813815742524799 + ] + ], + type: :f64 + ) + + assert NxSignal.Filters.wiener(im, kernel_size: kernel_size) == expected + assert NxSignal.Filters.wiener(im, kernel_size: 3) == expected + + assert NxSignal.Filters.wiener(Nx.as_type(im, :f32), kernel_size: kernel_size) == + Nx.tensor([ + [ + 1.7777777910232544, + 3.0, + 3.6666667461395264, + 4.333333492279053, + 3.1111111640930176 + ], + [4.3366522789001465, 7.0, 8.0, 9.0, 7.586376190185547], + [ + 4.692196846008301, + 7.261706352233887, + 8.748939514160156, + 10.157992362976074, + 9.81381607055664 + ] + ]) + end + + test "performs n-dim wiener filter with parameterized noise" do + im = + Nx.tensor( + [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [6.0, 7.0, 8.0, 9.0, 10.0], + [11.0, 12.0, 13.0, 14.0, 15.0] + ], + type: :f64 + ) + + kernel_size = {3, 3} + + assert NxSignal.Filters.wiener(im, kernel_size: kernel_size, noise: 10) == + Nx.tensor( + [ + [ + 1.7777777777777777, + 3.0, + 3.5882352941176467, + 4.238095238095238, + 3.7397034596375622 + ], + [5.193548387096774, 7.0, 8.0, 9.0, 8.829787234042554], + [ + 7.941747572815534, + 9.702702702702702, + 10.938931297709924, + 12.137254901960784, + 12.485549132947977 + ] + ], + type: :f64 + ) + + assert NxSignal.Filters.wiener(Nx.as_type(im, :f32), kernel_size: kernel_size, noise: 10) == + Nx.tensor([ + [ + 1.7777777910232544, + 3.0, + 3.588235378265381, + 4.238095283508301, + 3.739703416824341 + ], + [5.193548202514648, 7.0, 8.0, 9.0, 8.829787254333496], + [ + 7.941747665405273, + 9.702702522277832, + 10.938931465148926, + 12.13725471496582, + 12.485548973083496 + ] + ]) + + assert NxSignal.Filters.wiener(im, kernel_size: kernel_size, noise: 0) == + Nx.tensor( + [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [6.0, 7.0, 8.0, 9.0, 10.0], + [11.0, 12.0, 13.0, 14.0, 15.0] + ], + type: :f64 + ) + end + end end