-
Notifications
You must be signed in to change notification settings - Fork 7
Convolutions #22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Convolutions #22
Changes from 61 commits
Commits
Show all changes
76 commits
Select commit
Hold shift + click to select a range
f220c47
add first convolution test
hunterboerner e30062b
stub implementation
hunterboerner c4de7cb
add assert_all_close helper
hunterboerner b8599f5
Attempt at implementing convolve
hunterboerner 597d2ce
Add another test
hunterboerner b32169c
remove nil
hunterboerner db61498
try changing order
hunterboerner 4388184
shape to output
hunterboerner 8de0a00
change reversing
hunterboerner 67ae6f3
full mode
hunterboerner c0cf107
proper reshaping
hunterboerner c317373
Add another test
hunterboerner 8ab92ab
Complex test
hunterboerner 219de78
Zero rank
hunterboerner 60864fa
Add a numpy test
hunterboerner e5e1856
fairly unnecessary unwrite test
hunterboerner d4028c6
Add comment for where tests came from
hunterboerner e502373
Add auto mode
hunterboerner 61cf708
Start implementing fftconvolve
hunterboerner 5cef02b
refactor: simplify reshapes and fix paddings
polvalente 12ac3c7
Merge branch 'convolutions' of github.com:hunterboerner/nx_signal int…
polvalente de67b3a
No auto method
hunterboerner 1d849ff
We have FFT support but the tensors are not ordered correctly
hunterboerner 1119e14
Change ignore to explicit axes
hunterboerner fade865
Try getting FFTConvolve working
hunterboerner 13fb3b6
Fix typo
hunterboerner 6c18faa
Fix the operator (again)
hunterboerner bfeec5a
checkpoint
hunterboerner 45bc1da
fft_nd not working correctly
hunterboerner c060f82
value should be negative
hunterboerner 1066fa2
Add test for fft_nd with padding
hunterboerner ebc61f4
test: explicit assertions
polvalente 23b205e
refactor: split code paths
polvalente 13e70cc
fix: broadcasting rules
polvalente 69dcf20
More tests for convolution
hunterboerner 9226016
New (broken) test
hunterboerner 6c2fc54
Get tests passing modulo valid
hunterboerner 16dcb5a
Invalid arg tests
hunterboerner 41286c4
Try implementing "valid" convolution
hunterboerner d8dc84e
Swap kernel and volume for valid mode
hunterboerner 4875ae8
More convolution tests
hunterboerner afe2c97
Don't upcast to complex numbers test
hunterboerner 5d9620b
mismatched dims tests
hunterboerner 51f23aa
2d test
hunterboerner d1aef2b
More FFT convolution tests
hunterboerner 9ef7165
More FFT convolvution tests
hunterboerner ae7b05a
FFT same mode
hunterboerner 77f815b
Valid mode for FFT
hunterboerner 70ba95b
add brief doc
hunterboerner 22fa94c
First correlation tests and implementation
hunterboerner 6e0cc30
Rank 1 same correlation test
hunterboerner b6de107
Rank 1 full correlation test
hunterboerner bba9819
Complex correlation
hunterboerner 1e74d60
fix: allow doc metadata to not be set
polvalente a9d0b32
Implement test for oaconvolve
hunterboerner ae4b85e
Revert "Implement test for oaconvolve"
hunterboerner 53bca88
refactor: overall improvements
polvalente b9f0c95
Docs. Latex not rendering correctly
hunterboerner 45fc931
Fix LaTeX
hunterboerner 465afde
Correlation docs
hunterboerner c133d41
Doctests for convolution and correlation
hunterboerner 4596b30
Update lib/nx_signal/convolution.ex
hunterboerner bfbb807
Update lib/nx_signal/convolution.ex
hunterboerner 910c4ba
Update lib/nx_signal/convolution.ex
hunterboerner 09a6817
Update lib/nx_signal/convolution.ex
hunterboerner 3592ecd
Don't upcast to complex for correlate
hunterboerner f58f9bb
Add tests for fft_nd and ifft_nd
hunterboerner c68a413
docs: document fftconvolve
polvalente 363e13a
docs: improve docs for fftconvolve
polvalente 33a4372
docs: more doc improvemnets
polvalente c074683
docs: improve docs
polvalente 2ed0736
chore: remove helpers module
polvalente 03eea5a
chore: remove imports
polvalente 0135211
chore: update ci
polvalente b4c4c87
fix: use runner os in cache key
polvalente 12c3505
fix: hash suffix
polvalente File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,329 @@ | ||
| defmodule NxSignal.Convolution do | ||
| @moduledoc """ | ||
| Convolution functions through various methods. | ||
|
|
||
| Follows the `scipy.signal` conventions. | ||
| """ | ||
|
|
||
| import Nx.Defn | ||
| import NxSignal.Transforms | ||
|
|
||
| @doc """ | ||
| Computes the convolution of two tensors. | ||
|
|
||
| Given $f[n] \\in \\mathbb{C}^N$ and $k[n] \\in \\mathbb{C}^{K}$, we define the convolution $f * k$ by | ||
hunterboerner marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| $$ | ||
| g(m) = (f * k)[m] = \\sum_{m=0}^{K-1} \\tilde{f}[n-m]\\tilde{k}[m], | ||
hunterboerner marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| $$ | ||
|
|
||
| where | ||
|
|
||
| $$ | ||
| \\tilde{f}[n] = | ||
| \\begin{cases} | ||
| 0 & n < 0 \\\\\\\\ | ||
| 0 & n \\geq N \\\\\\\\ | ||
| f[n] & \\text{otherwise} | ||
| \\end{cases} | ||
| $$ | ||
|
|
||
| and $\\tilde{k}$ is defined similarly. | ||
hunterboerner marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| ## Options | ||
|
|
||
| * `:method` - One of `:fft` or `:direct`. Defaults to `:direct`. | ||
| * `:mode` - One of `:full`, `:valid`, or `:same`. Defaults to `:full`. | ||
|
|
||
| ## Examples | ||
|
|
||
| iex> NxSignal.Convolution.convolve(Nx.tensor([1,2,3]), Nx.tensor([3,4,5])) | ||
| #Nx.Tensor< | ||
| f32[5] | ||
| [3.0, 10.0, 22.0, 22.0, 15.0] | ||
| > | ||
| """ | ||
| deftransform convolve(in1, in2, opts \\ []) do | ||
| opts = Keyword.validate!(opts, mode: :full, method: :direct) | ||
|
|
||
| if opts[:mode] not in [:full, :same, :valid] do | ||
| raise ArgumentError, | ||
| "expected mode to be one of [:full, :same, :valid], got: #{inspect(opts[:mode])}" | ||
| end | ||
|
|
||
| if opts[:method] not in [:direct, :fft] do | ||
| raise ArgumentError, | ||
| "expected method to be one of [:direct, :fft], got: #{inspect(opts[:method])}" | ||
| end | ||
|
|
||
| case opts[:method] do | ||
| :direct -> | ||
| direct_convolve(in1, in2, opts) | ||
|
|
||
| :fft -> | ||
| fftconvolve(in1, in2, opts) | ||
| end | ||
| end | ||
|
|
||
| @doc """ | ||
| Given $f[n] \\in \\mathbb{C}^N$ and $k[n] \\in \\mathbb{C}^{K}$, we define the correlation $f \\star k$ by | ||
|
|
||
| $$ | ||
| g(m) = (f * k)[m] = \\sum_{m=0}^{K-1} \\overline{\\tilde{f}[m-n]}\\tilde{k}[m], | ||
| $$ | ||
|
|
||
| where | ||
|
|
||
| $$ | ||
| \\tilde{f}[n] = | ||
| \\begin{cases} | ||
| 0 & n < 0 \\\\\\\\ | ||
| 0 & n \\geq N \\\\\\\\ | ||
| f[n] & \\text{otherwise} | ||
| \\end{cases} | ||
| $$ | ||
hunterboerner marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| and $\\tilde{k}$ is defined similarly. | ||
|
|
||
| ## Options | ||
|
|
||
| * `:method` - One of `:fft` or `:direct`. Defaults to `:direct`. | ||
| * `:mode` - One of `:full`, `:valid`, or `:same`. Defaults to `:full`. | ||
|
|
||
| ## Examples | ||
|
|
||
| iex> NxSignal.Convolution.correlate(Nx.tensor([1,2,3]), Nx.tensor([3,4,5])) | ||
| #Nx.Tensor< | ||
| c64[5] | ||
| [5.0-0.0i, 14.0-0.0i, 26.0-0.0i, 18.0-0.0i, 9.0-0.0i] | ||
hunterboerner marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| > | ||
| """ | ||
| defn correlate(in1, in2, opts \\ []) do | ||
polvalente marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| convolve(in1, Nx.conjugate(Nx.reverse(in2)), opts) | ||
| end | ||
|
|
||
| deftransformp direct_convolve(in1, in2, opts) do | ||
| input_rank = | ||
| case {Nx.rank(in1), Nx.rank(in2)} do | ||
| {0, 0} -> | ||
| 0 | ||
|
|
||
| {0, r} -> | ||
| raise ArgumentError, message: "Incompatible ranks: {0, #{r}}" | ||
|
|
||
| {r, 0} -> | ||
| raise ArgumentError, message: "Incompatible ranks: {#{r}, 0}" | ||
|
|
||
| {r, r} -> | ||
| r | ||
|
|
||
| {r1, r2} -> | ||
| raise ArgumentError, | ||
| "NxSignal.convolve/3 requires both inputs to have the same rank or one of them to be a scalar, got #{r1} and #{r2}" | ||
| end | ||
|
|
||
| zipped = Enum.zip(Tuple.to_list(Nx.shape(in1)), Tuple.to_list(Nx.shape(in2))) | ||
|
|
||
| ok1 = Enum.all?(for {i, j} <- zipped, do: i >= j) | ||
| ok2 = Enum.all?(for {i, j} <- zipped, do: i <= j) | ||
|
|
||
| {in1, in2} = | ||
| cond do | ||
| opts[:mode] != :valid -> | ||
| {in1, in2} | ||
|
|
||
| ok1 -> | ||
| {in1, in2} | ||
|
|
||
| ok2 -> | ||
| {in2, in1} | ||
|
|
||
| true -> | ||
| raise ArgumentError, | ||
| message: | ||
| "For :valid mode, one must be at least as large as the other in every dimension" | ||
| end | ||
|
|
||
| kernel = Nx.reverse(in2) | ||
|
|
||
| kernel_shape = | ||
| case Nx.shape(kernel) do | ||
| {} -> {1, 1, 1, 1} | ||
| {n} -> {1, 1, 1, n} | ||
| shape -> List.to_tuple([1, 1 | Tuple.to_list(shape)]) | ||
| end | ||
|
|
||
| kernel = Nx.reshape(kernel, kernel_shape) | ||
|
|
||
| volume_shape = | ||
| case Nx.shape(in1) do | ||
| {} -> {1, 1, 1, 1} | ||
| {n} -> {1, 1, 1, n} | ||
| shape -> List.to_tuple([1, 1 | Tuple.to_list(shape)]) | ||
| end | ||
|
|
||
| volume = Nx.reshape(in1, volume_shape) | ||
|
|
||
| opts = | ||
| case opts[:mode] do | ||
| :same -> | ||
| kernel_spatial_shape = | ||
| Nx.shape(kernel) | ||
| |> Tuple.to_list() | ||
| |> Enum.drop(2) | ||
|
|
||
| padding = | ||
| Enum.map(kernel_spatial_shape, fn k -> | ||
| pad_total = k - 1 | ||
| # integer division for right side | ||
| pad_right = div(pad_total, 2) | ||
| # put the extra padding on the left | ||
| pad_left = pad_total - pad_right | ||
| {pad_left, pad_right} | ||
| end) | ||
|
|
||
| [padding: padding] | ||
|
|
||
| :full -> | ||
| kernel_spatial_shape = | ||
| Nx.shape(kernel) | ||
| |> Tuple.to_list() | ||
| |> Enum.drop(2) | ||
|
|
||
| padding = | ||
| Enum.map(kernel_spatial_shape, fn k -> | ||
| {k - 1, k - 1} | ||
| end) | ||
|
|
||
| [padding: padding] | ||
|
|
||
| :valid -> | ||
| [padding: :valid] | ||
| end | ||
|
|
||
| out = Nx.conv(volume, kernel, opts) | ||
|
|
||
| squeeze_axes = | ||
| case input_rank do | ||
| 0 -> | ||
| [0, 1, 2, 3] | ||
|
|
||
| 1 -> | ||
| [0, 1, 2] | ||
|
|
||
| _ -> | ||
| [0, 1] | ||
| end | ||
|
|
||
| out | ||
| |> Nx.squeeze(axes: squeeze_axes) | ||
| |> clip_valid(Nx.shape(volume), Nx.shape(kernel), opts[:mode]) | ||
| end | ||
|
|
||
| deftransformp clip_valid(out, in1_shape, in2_shape, :valid) do | ||
| select = | ||
| [in1_shape, in2_shape] | ||
| |> Enum.zip_with(fn [i, j] -> | ||
| 0..(i - j) | ||
| end) | ||
|
|
||
| out[select] | ||
| end | ||
|
|
||
| deftransformp clip_valid(out, _, _, _), do: out | ||
|
|
||
| deftransform fftconvolve(in1, in2, opts \\ []) do | ||
| case {Nx.rank(in1), Nx.rank(in2)} do | ||
| {a, b} when a == b -> | ||
| s1 = Nx.shape(in1) |> Tuple.to_list() | ||
| s2 = Nx.shape(in2) |> Tuple.to_list() | ||
|
|
||
| lengths = | ||
| Enum.zip_with(s1, s2, fn ax1, ax2 -> | ||
| ax1 + ax2 - 1 | ||
| end) | ||
|
|
||
| axes = | ||
| [s1, s2, Nx.axes(in1)] | ||
| |> Enum.zip_with(fn [ax1, ax2, axis] -> | ||
| if ax1 != 1 and ax2 != 1 do | ||
| axis | ||
| end | ||
| end) | ||
| |> Enum.filter(& &1) | ||
|
|
||
| lengths = Enum.map(axes, &Enum.fetch!(lengths, &1)) | ||
|
|
||
| sp1 = | ||
| fft_nd(in1, axes: axes, lengths: lengths) | ||
|
|
||
| sp2 = | ||
| fft_nd(in2, axes: axes, lengths: lengths) | ||
|
|
||
| c = Nx.multiply(sp1, sp2) | ||
|
|
||
| out = ifft_nd(c, axes: axes) | ||
|
|
||
| out = | ||
| if Nx.Type.merge(Nx.type(in1), Nx.type(in2)) |> Nx.Type.complex?() do | ||
| out | ||
| else | ||
| Nx.real(out) | ||
| end | ||
|
|
||
| apply_mode(out, s1, s2, opts[:mode]) | ||
|
|
||
| _ -> | ||
| raise ArgumentError, message: "Rank of in1 and in2 must be equal." | ||
| end | ||
| end | ||
|
|
||
| deftransform apply_mode(out, _s1, _s2, :full) do | ||
| out | ||
| end | ||
|
|
||
| deftransform apply_mode(out, s1, _s2, :same) do | ||
| centered(out, s1) | ||
| end | ||
|
|
||
| deftransform apply_mode(out, s1, s2, :valid) do | ||
| {s1, s2} = swap_axes(s1, s2) | ||
|
|
||
| shape_valid = | ||
| for {a, b} <- Enum.zip(s1, s2) do | ||
| a - b + 1 | ||
| end | ||
|
|
||
| centered(out, shape_valid) | ||
| end | ||
|
|
||
| deftransformp centered(out, new_shape) do | ||
| start_indices = | ||
| out | ||
| |> Nx.shape() | ||
| |> Tuple.to_list() | ||
| |> Enum.zip_with(new_shape, fn current, new -> | ||
| div(current - new, 2) | ||
| end) | ||
|
|
||
| Nx.slice(out, start_indices, new_shape) | ||
| end | ||
|
|
||
| defp swap_axes(s1, s2) do | ||
| ok1 = Enum.zip_reduce(s1, s2, true, fn a, b, acc -> acc and a >= b end) | ||
| ok2 = Enum.zip_reduce(s2, s1, true, fn a, b, acc -> acc and a >= b end) | ||
|
|
||
| cond do | ||
| ok1 -> | ||
| {s1, s2} | ||
|
|
||
| ok2 -> | ||
| {s2, s1} | ||
|
|
||
| true -> | ||
| raise ArgumentError, | ||
| message: | ||
| "For 'valid' mode, one must be at least as large as the other in every dimension." | ||
| end | ||
| end | ||
| end | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| defmodule NxSignal.Transforms do | ||
| import Nx.Defn | ||
|
|
||
| deftransform fft_nd(tensor, opts \\ []) do | ||
hunterboerner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| axes = Keyword.get(opts, :axes, [-1]) | ||
| lengths = Keyword.get(opts, :lengths) || List.duplicate(nil, length(axes)) | ||
|
|
||
| Enum.zip_reduce(axes, lengths, tensor, fn axis, len, acc -> | ||
| Nx.fft(acc, axis: axis, length: len) | ||
| end) | ||
| end | ||
|
|
||
| deftransform ifft_nd(tensor, opts \\ []) do | ||
| axes = Keyword.get(opts, :axes, [-1]) | ||
| lengths = Keyword.get(opts, :lengths) || List.duplicate(nil, length(axes)) | ||
|
|
||
| Enum.zip_reduce(axes, lengths, tensor, fn axis, len, acc -> | ||
| Nx.ifft(acc, axis: axis, length: len) | ||
| end) | ||
| end | ||
| end | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.