Skip to content

Commit 76b7a8b

Browse files
authored
Merge pull request #75 from itan1/make-flip-and-zoom-nd-compatible
Make Flips ND compatible and fix ND ScaleKeepAspect, ScaleFixed and PinOrigin
2 parents 880b5ef + 65eb33c commit 76b7a8b

File tree

8 files changed

+139
-34
lines changed

8 files changed

+139
-34
lines changed

docs/src/projective/gallery.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,14 @@ tfms = [
135135
showgrid([apply(tfm, (image, bbox)) for tfm in tfms]; ncol=6, npad=8)
136136
```
137137

138-
## [`FlipX`](@ref), [`FlipY`](@ref), [`Reflect`](@ref)
138+
## [`FlipX`](@ref), [`FlipY`](@ref), [`FlipZ`](@ref), [`Reflect`](@ref)
139139

140140
Flip the data on the horizontally and vertically, respectively. More generally, reflect around an angle from the x-axis.
141141

142142
```@example deps
143143
tfms = [
144-
FlipX(),
145-
FlipY(),
144+
FlipX{2}(),
145+
FlipY{2}(),
146146
Reflect(30),
147147
]
148148
showgrid([apply(tfm, (image, bbox)) for tfm in tfms]; ncol=6, npad=8)

docs/src/projective/intro.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ We can break down most augmentation used in practive into a single (possibly sto
77
As an example, consider an image augmentation pipeline: A random horizontal flip, followed by a random resized crop. The latter resizes and crops (irregularly sized) images to a common size without distorting the aspect ratio.
88

99
```julia
10-
Maybe(FlipX()) |> RandomResizeCrop((h, w))
10+
Maybe(FlipX{2}()) |> RandomResizeCrop((h, w))
1111
```
1212

1313
Let's pull apart the steps involved.

docs/src/ref.md

+2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ BoundingBox
66
CenterCrop
77
CenterResizeCrop
88
Crop
9+
FlipDim
910
FlipX
1011
FlipY
12+
FlipZ
1113
Image
1214
Keypoints
1315
MaskBinary

docs/src/transformations.md

+4-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using DataAugmentation
66
Using transformations is easy. Simply `compose` them:
77

88
```@example tsm
9-
tfm = Rotate(10) |> ScaleRatio((0.7,0.1,1.2)) |> FlipX() |> Crop((128, 128))
9+
tfm = Rotate(10) |> ScaleRatio((0.7,0.1,1.2)) |> FlipX{2}() |> Crop((128, 128))
1010
```
1111

1212
# Projective transformations
@@ -26,8 +26,10 @@ Projective transformations include:
2626
Affine transformations are a subgroup of projective transformations that can be composed very efficiently: composing two affine transformations results in another affine transformation. Affine transformations can represent translation, scaling, reflection and rotation. Available `Transform`s are:
2727

2828
```@docs; canonical=false
29+
FlipDim
2930
FlipX
3031
FlipY
32+
FlipZ
3133
Reflect
3234
Rotate
3335
RotateX
@@ -73,7 +75,7 @@ Let's say we have an image classification dataset. For most datasets, horizontal
7375
```@example
7476
using DataAugmentation, TestImages
7577
item = Image(testimage("lighthouse"))
76-
tfm = Maybe(FlipX())
78+
tfm = Maybe(FlipX{2}())
7779
titems = [apply(tfm, item) for _ in 1:8]
7880
showgrid(titems; ncol = 4, npad = 16)
7981
```

src/DataAugmentation.jl

+2
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,10 @@ export Item,
8484
apply,
8585
Reflect,
8686
WarpAffine,
87+
FlipDim,
8788
FlipX,
8889
FlipY,
90+
FlipZ,
8991
PinOrigin,
9092
AdjustBrightness,
9193
AdjustContrast,

src/projective/affine.jl

+65-18
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@ struct ScaleKeepAspect{N} <: ProjectiveTransform
4343
end
4444

4545

46-
function getprojection(scale::ScaleKeepAspect{N}, bounds; randstate = nothing) where N
46+
function getprojection(scale::ScaleKeepAspect{N}, bounds::Bounds{N}; randstate = nothing) where N
4747
# If no scaling needs to be done, return a noop transform
4848
scale.minlengths == length.(bounds.rs) && return IdentityTransformation()
4949

5050
# Offset `minlengths` by 1 to avoid black border on one side
5151
ratio = maximum((scale.minlengths .+ 1) ./ length.(bounds.rs))
5252
upperleft = SVector{N, Float32}(minimum.(bounds.rs)) .- 0.5
5353
P = scaleprojection(Tuple(ratio for _ in 1:N))
54-
if upperleft != SVector(0, 0)
54+
if any(upperleft .!= 0)
5555
P = P Translation((Float32.(P(upperleft)) .+ 0.5f0))
5656
end
5757
return P
@@ -79,11 +79,11 @@ struct ScaleFixed{N} <: ProjectiveTransform
7979
end
8080

8181

82-
function getprojection(scale::ScaleFixed, bounds; randstate = nothing)
82+
function getprojection(scale::ScaleFixed, bounds::Bounds{N}; randstate = nothing) where N
8383
ratios = (scale.sizes .+ 1) ./ length.(bounds.rs)
84-
upperleft = SVector{2, Float32}(minimum.(bounds.rs)) .- 1
84+
upperleft = SVector{N, Float32}(minimum.(bounds.rs)) .- 1
8585
P = scaleprojection(ratios)
86-
if upperleft != SVector(0, 0)
86+
if any(upperleft .!= 0)
8787
P = P Translation(-upperleft)
8888
end
8989
return P
@@ -92,7 +92,7 @@ end
9292

9393
function projectionbounds(tfm::ScaleFixed{N}, P, bounds::Bounds{N}; randstate = nothing) where N
9494
bounds_ = transformbounds(bounds, P)
95-
return offsetcropbounds(tfm.sizes, bounds_, (1., 1.))
95+
return offsetcropbounds(tfm.sizes, bounds_, ntuple(_ -> 1., N))
9696
end
9797

9898
"""
@@ -230,7 +230,7 @@ struct Reflect <: ProjectiveTransform
230230
end
231231

232232

233-
function getprojection(tfm::Reflect, bounds; randstate = getrandstate(tfm))
233+
function getprojection(tfm::Reflect, bounds::Bounds{2}; randstate = getrandstate(tfm))
234234
r = tfm.γ / 360 * 2pi
235235
return centered(LinearMap(reflectionmatrix(r)), bounds)
236236
end
@@ -241,26 +241,73 @@ end
241241
Transform `P` so that is applied around the center of `bounds`
242242
instead of the origin
243243
"""
244-
function centered(P, bounds::Bounds{2})
244+
function centered(P, bounds::Bounds{N}) where N
245245
upperleft = minimum.(bounds.rs)
246246
bottomright = maximum.(bounds.rs)
247247

248-
midpoint = SVector{2, Float32}((bottomright .- upperleft) ./ 2) .+ SVector{2, Float32}(.5, .5)
248+
midpoint = SVector{N, Float32}((bottomright .- upperleft) ./ 2) .+ .5f0
249249
return recenter(P, midpoint)
250250
end
251251

252+
253+
function reflectionmatrix(r)
254+
A = SMatrix{2, 2, Float32}(cos(2r), sin(2r), sin(2r), -cos(2r))
255+
return round.(A; digits = 12)
256+
end
257+
258+
259+
"""
260+
FlipDim{N}(dim)
261+
262+
Reflect `N` dimensional data along the axis of dimension `dim`. Must satisfy 1 <= `dim` <= `N`.
263+
264+
## Examples
265+
266+
```julia
267+
tfm = FlipDim{2}(1)
268+
```
269+
"""
270+
struct FlipDim{N} <: ProjectiveTransform
271+
dim::Int
272+
FlipDim{N}(dim) where N = 1 <= dim <= N ? new{N}(dim) : error("invalid dimension")
273+
end
274+
252275
"""
253-
Reflect(180)
276+
FlipX{N}()
277+
278+
Flip `N` dimensional data along the x-axis. 2D images use (r, c) = (y, x)
279+
convention such that x-axis flips occur along the second dimension. For N >= 3,
280+
x-axis flips occur along the first dimension.
254281
"""
255-
FlipX() = Reflect(180)
282+
struct FlipX{N}
283+
FlipX{N}() where N = FlipDim{N}(N==2 ? 2 : 1)
284+
end
285+
256286
"""
257-
Reflect(90)
287+
FlipY{N}()
288+
289+
Flip `N` dimensional data along the y-axis. 2D images use (r, c) = (y, x)
290+
convention such that y-axis flips occur along the first dimension. For N >= 3,
291+
y-axis flips occur along the second dimension.
258292
"""
259-
FlipY() = Reflect(90)
293+
struct FlipY{N}
294+
FlipY{N}() where N = FlipDim{N}(N==2 ? 1 : 2)
295+
end
260296

261-
function reflectionmatrix(r)
262-
A = SMatrix{2, 2, Float32}(cos(2r), sin(2r), sin(2r), -cos(2r))
263-
return round.(A; digits = 12)
297+
"""
298+
FlipZ{N}()
299+
300+
Flip `N` dimensional data along the z-axis.
301+
"""
302+
struct FlipZ{N}
303+
FlipZ{N}() where N = FlipDim{N}(3)
304+
end
305+
306+
function getprojection(tfm::FlipDim{N}, bounds::Bounds{N}; randstate = nothing) where N
307+
arr = 1I(N)
308+
arr[tfm.dim, tfm.dim] = -1
309+
M = SMatrix{N, N, Float32}(arr)
310+
return DataAugmentation.centered(LinearMap(M), bounds)
264311
end
265312

266313

@@ -281,8 +328,8 @@ at one.
281328
"""
282329
struct PinOrigin <: ProjectiveTransform end
283330

284-
function getprojection(::PinOrigin, bounds; randstate = nothing)
285-
p = (-SVector{2, Float32}(minimum.(bounds.rs))) .+ 1
331+
function getprojection(::PinOrigin, bounds::Bounds{N}; randstate = nothing) where N
332+
p = (-SVector{N, Float32}(minimum.(bounds.rs))) .+ 1
286333
P = Translation(p)
287334
return P
288335
end

src/projective/compose.jl

+3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ compose(composed::ComposedProjectiveTransform, tfm::ProjectiveTransform) =
2626
compose(tfm::ProjectiveTransform, composed::ComposedProjectiveTransform) =
2727
ComposedProjectiveTransform(tfm, composed.tfms...)
2828

29+
compose(composed1::ComposedProjectiveTransform, composed2::ComposedProjectiveTransform) =
30+
ComposedProjectiveTransform(composed1.tfms..., composed2.tfms...)
31+
2932

3033
# The random state is collected from the transformations that make up the
3134
# `ComposedProjectiveTransform`:

test/projective/affine.jl

+59-10
Original file line numberDiff line numberDiff line change
@@ -192,16 +192,60 @@ include("../imports.jl")
192192
@test_nowarn apply!(buffer, tfm, image2)
193193
end
194194

195-
@testset ExtendedTestSet "`RandomCrop` correct indices" begin
196-
# Flipping and cropping should be the same as reverse-indexing
197-
# the flipped dimension
198-
tfm = FlipX() |> RandomCrop((64, 64)) |> PinOrigin()
199-
img = rand(RGB, 64, 64)
195+
196+
@testset ExtendedTestSet "FlipX 2D correct indices" begin
197+
tfm = FlipX{2}() |> RandomCrop((10,10)) |> PinOrigin()
198+
img = rand(RGB, 10, 10)
199+
item = Image(img)
200+
@test_nowarn titem = apply(tfm, item)
201+
titem = apply(tfm, item)
202+
@test itemdata(titem) == img[:, end:-1:1]
203+
end
204+
205+
@testset ExtendedTestSet "FlipY 2D correct indices" begin
206+
tfm = FlipY{2}() |> RandomCrop((10,10)) |> PinOrigin()
207+
img = rand(RGB, 10, 10)
208+
item = Image(img)
209+
@test_nowarn titem = apply(tfm, item)
210+
titem = apply(tfm, item)
211+
@test itemdata(titem) == img[end:-1:1, :]
212+
end
213+
214+
215+
@testset ExtendedTestSet "FlipX 3D correct indices" begin
216+
tfm = FlipX{3}() |> RandomCrop((10,10,10)) |> PinOrigin()
217+
img = rand(RGB, 10, 10, 10)
200218
item = Image(img)
219+
@test_nowarn titem = apply(tfm, item)
201220
titem = apply(tfm, item)
202-
timg = itemdata(titem)
203-
rimg = img[:, end:-1:1]
204-
@test titem.data == rimg
221+
@test itemdata(titem) == img[end:-1:1, :, :]
222+
end
223+
224+
@testset ExtendedTestSet "FlipY 3D correct indices" begin
225+
tfm = FlipY{3}() |> RandomCrop((10,10,10)) |> PinOrigin()
226+
img = rand(RGB, 10, 10, 10)
227+
item = Image(img)
228+
@test_nowarn titem = apply(tfm, item)
229+
titem = apply(tfm, item)
230+
@test itemdata(titem) == img[:, end:-1:1, :]
231+
end
232+
233+
@testset ExtendedTestSet "FlipZ 3D correct indices" begin
234+
tfm = FlipZ{3}() |> RandomCrop((10,10,10)) |> PinOrigin()
235+
img = rand(RGB, 10, 10, 10)
236+
item = Image(img)
237+
@test_nowarn titem = apply(tfm, item)
238+
titem = apply(tfm, item)
239+
@test itemdata(titem) == img[:, :, end:-1:1]
240+
end
241+
242+
@testset ExtendedTestSet "Double flip is identity" begin
243+
tfm = FlipZ{3}() |> FlipZ{3}() |> RandomCrop((10,10,10)) |> PinOrigin()
244+
img = rand(RGB, 10, 10, 10)
245+
item = Image(img)
246+
@test_nowarn titem = apply(tfm, item)
247+
titem = apply(tfm, item)
248+
@test itemdata(titem) == img
205249
end
206250
end
207251

@@ -210,8 +254,8 @@ end
210254
@testset ExtendedTestSet "2D" begin
211255
tfms = compose(
212256
Rotate(10),
213-
FlipX(),
214-
FlipY(),
257+
FlipX{2}(),
258+
FlipY{2}(),
215259
ScaleRatio((.8, .8)),
216260
WarpAffine(0.1),
217261
Zoom((1., 1.2)),
@@ -230,9 +274,14 @@ end
230274
)
231275

232276
tfms = compose(
277+
FlipX{3}(),
278+
FlipY{3}(),
279+
FlipZ{3}(),
280+
ScaleFixed((30, 40, 50)),
233281
Rotate(10, 20, 30),
234282
ScaleRatio((.8, .8, .8)),
235283
ScaleKeepAspect((12, 10, 10)),
284+
Zoom((1., 1.2)),
236285
RandomCrop((10, 10, 10))
237286
)
238287
testprojective(tfms, items)

0 commit comments

Comments
 (0)