@@ -31,6 +31,71 @@ defmodule NxSignal.FiltersTest do
3131 assert NxSignal.Filters . median ( t , opts ) == expected
3232 end
3333
34+ test "performs n-dim median filter" do
35+ t =
36+ Nx . tensor ( [
37+ [
38+ [ 31 , 11 , 17 , 13 , 1 ] ,
39+ [ 1 , 3 , 19 , 23 , 29 ] ,
40+ [ 19 , 5 , 7 , 37 , 2 ]
41+ ] ,
42+ [
43+ [ 19 , 5 , 7 , 37 , 2 ] ,
44+ [ 1 , 3 , 19 , 23 , 29 ] ,
45+ [ 31 , 11 , 17 , 13 , 1 ]
46+ ] ,
47+ [
48+ [ 1 , 3 , 19 , 23 , 29 ] ,
49+ [ 31 , 11 , 17 , 13 , 1 ] ,
50+ [ 19 , 5 , 7 , 37 , 2 ]
51+ ]
52+ ] )
53+
54+ k1 = { 3 , 3 , 1 }
55+ k2 = { 3 , 3 , 3 }
56+
57+ expected1 =
58+ Nx . tensor ( [
59+ [
60+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ] ,
61+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ] ,
62+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ]
63+ ] ,
64+ [
65+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ] ,
66+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ] ,
67+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ]
68+ ] ,
69+ [
70+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ] ,
71+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ] ,
72+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ]
73+ ]
74+ ] )
75+
76+ expected2 =
77+ Nx . tensor ( [
78+ [
79+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ] ,
80+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ] ,
81+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ]
82+ ] ,
83+ [
84+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ] ,
85+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ] ,
86+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ]
87+ ] ,
88+ [
89+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ] ,
90+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ] ,
91+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ]
92+ ]
93+ ] )
94+
95+ assert NxSignal.Filters . median ( t , kernel_shape: k1 ) == expected1
96+ assert NxSignal.Filters . median ( t , kernel_shape: k2 ) == expected2
97+ end
98+
3499 test "raises if kernel_shape is not compatible" do
35100 t1 = Nx . iota ( { 10 } )
36101 opts1 = [ kernel_shape: { 5 , 5 } ]
@@ -50,25 +115,5 @@ defmodule NxSignal.FiltersTest do
50115 fn -> NxSignal.Filters . median ( t2 , opts2 ) end
51116 )
52117 end
53-
54- test "raises if tensor rank is not 1 or 2" do
55- t1 = Nx . tensor ( 1 )
56- opts1 = [ kernel_shape: { 1 } ]
57-
58- assert_raise (
59- ArgumentError ,
60- "tensor must be of rank 1 or 2" ,
61- fn -> NxSignal.Filters . median ( t1 , opts1 ) end
62- )
63-
64- t2 = Nx . iota ( { 5 , 5 , 5 } )
65- opts2 = [ kernel_shape: { 3 , 3 , 3 } ]
66-
67- assert_raise (
68- ArgumentError ,
69- "tensor must be of rank 1 or 2" ,
70- fn -> NxSignal.Filters . median ( t2 , opts2 ) end
71- )
72- end
73118 end
74119end
0 commit comments