Skip to content

Commit 5b29974

Browse files
authored
implement transform_bounding_boxes for random_shear (#20704)
1 parent 476a664 commit 5b29974

File tree

2 files changed

+266
-2
lines changed

2 files changed

+266
-2
lines changed

keras/src/layers/preprocessing/image_preprocessing/random_shear.py

+140-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
33
BaseImagePreprocessingLayer,
44
)
5+
from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501
6+
clip_to_image_size,
7+
)
8+
from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501
9+
convert_format,
10+
)
511
from keras.src.random.seed_generator import SeedGenerator
12+
from keras.src.utils import backend_utils
613

714

815
@keras_export("keras.layers.RandomShear")
@@ -175,7 +182,7 @@ def get_random_transformation(self, data, training=True, seed=None):
175182
)
176183
* invert
177184
)
178-
return {"shear_factor": shear_factor}
185+
return {"shear_factor": shear_factor, "input_shape": images_shape}
179186

180187
def transform_images(self, images, transformation, training=True):
181188
images = self.backend.cast(images, self.compute_dtype)
@@ -231,13 +238,144 @@ def _get_shear_matrix(self, shear_factors):
231238
def transform_labels(self, labels, transformation, training=True):
    """Return `labels` exactly as received.

    Neither `transformation` nor `training` is consulted; the same object
    that was passed in is handed back.
    """
    unchanged_labels = labels
    return unchanged_labels
233240

241+
def get_transformed_x_y(self, x, y, transform):
    """Map coordinates `(x, y)` through an 8-coefficient transform.

    `transform` is split into eight equal pieces along its last axis, read
    in order as `a0, a1, a2, b0, b1, b2, c0, c1`. The outputs are:

        x' = (a0 * x + a1 * y + a2) / (c0 * x + c1 * y + 1)
        y' = (b0 * x + b1 * y + b2) / (c0 * x + c1 * y + 1)

    Args:
        x: x coordinate(s).
        y: y coordinate(s).
        transform: Tensor whose last axis can be split into 8 parts.

    Returns:
        Tuple `(x_transformed, y_transformed)`.
    """
    coefficients = self.backend.numpy.split(transform, 8, axis=-1)
    a0, a1, a2 = coefficients[0:3]
    b0, b1, b2 = coefficients[3:6]
    c0, c1 = coefficients[6:8]

    # Shared denominator of both output coordinates.
    denominator = c0 * x + c1 * y + 1
    numerator_x = a0 * x + a1 * y + a2
    numerator_y = b0 * x + b1 * y + b2
    return numerator_x / denominator, numerator_y / denominator
250+
251+
def get_shifted_bbox(self, bounding_boxes, w_shift_factor, h_shift_factor):
    """Subtract width/height offsets from the four box columns.

    The last axis of `bounding_boxes["boxes"]` is split into four columns.
    `w_shift_factor` is subtracted from the first and third column and
    `h_shift_factor` from the second and fourth. When the boxes tensor has
    rank 3, both shift factors get a trailing axis appended first.

    Args:
        bounding_boxes: Dict holding a `"boxes"` entry.
        w_shift_factor: Offset removed from columns 0 and 2.
        h_shift_factor: Offset removed from columns 1 and 3.

    Returns:
        The same `bounding_boxes` dict, with `"boxes"` replaced in place.
    """
    np_ops = self.backend.numpy
    boxes = bounding_boxes["boxes"]
    columns = np_ops.split(boxes, 4, axis=-1)
    column_dtype = columns[0].dtype

    # Bring both offsets to the dtype of the box coordinates.
    w_shift = self.backend.convert_to_tensor(
        w_shift_factor, dtype=column_dtype
    )
    h_shift = self.backend.convert_to_tensor(
        h_shift_factor, dtype=column_dtype
    )

    if len(boxes.shape) == 3:
        w_shift = np_ops.expand_dims(w_shift, -1)
        h_shift = np_ops.expand_dims(h_shift, -1)

    # Column order is (w, h, w, h), matching the four split columns.
    shifts = (w_shift, h_shift, w_shift, h_shift)
    bounding_boxes["boxes"] = np_ops.concatenate(
        [column - shift for column, shift in zip(columns, shifts)],
        axis=-1,
    )
    return bounding_boxes
276+
234277
def transform_bounding_boxes(
    self,
    bounding_boxes,
    transformation,
    training=True,
):
    """Apply the sampled shear to bounding boxes.

    When `training` is falsy the boxes are returned untouched. Otherwise
    the boxes are converted from `self.bounding_box_format` to
    `"rel_xyxy"`, sheared with `self._shear_bboxes`, clipped to the image
    size, and converted back to `self.bounding_box_format`.

    Args:
        bounding_boxes: Bounding boxes structure handed to
            `convert_format`.
        transformation: Dict with `"shear_factor"` (consumed by
            `_shear_bboxes`) and `"input_shape"` (indexed here to obtain
            the image height and width).
        training: Whether the layer runs in training mode.

    Returns:
        The transformed bounding boxes, or the input unchanged when
        `training` is falsy.
    """
    if not training:
        return bounding_boxes

    # Where height/width sit inside `input_shape` depends on the layout.
    if self.data_format == "channels_first":
        height_axis, width_axis = -2, -1
    else:
        height_axis, width_axis = -3, -2
    input_shape = transformation["input_shape"]
    input_height = input_shape[height_axis]
    input_width = input_shape[width_axis]

    if backend_utils.in_tf_graph():
        self.backend.set_backend("tensorflow")

    # `finally` guarantees the backend is reset even if a conversion,
    # the shear, or the clipping raises; otherwise the layer would stay
    # switched to the tensorflow backend for later calls.
    try:
        bounding_boxes = convert_format(
            bounding_boxes,
            source=self.bounding_box_format,
            target="rel_xyxy",
            height=input_height,
            width=input_width,
            dtype=self.compute_dtype,
        )

        bounding_boxes = self._shear_bboxes(bounding_boxes, transformation)

        bounding_boxes = clip_to_image_size(
            bounding_boxes=bounding_boxes,
            height=input_height,
            width=input_width,
            bounding_box_format="rel_xyxy",
        )

        bounding_boxes = convert_format(
            bounding_boxes,
            source="rel_xyxy",
            target=self.bounding_box_format,
            height=input_height,
            width=input_width,
            dtype=self.compute_dtype,
        )
    finally:
        self.backend.reset()

    return bounding_boxes
332+
333+
def _shear_bboxes(self, bounding_boxes, transformation):
334+
shear_factor = self.backend.cast(
335+
transformation["shear_factor"], dtype=self.compute_dtype
336+
)
337+
shear_x_amount, shear_y_amount = self.backend.numpy.split(
338+
shear_factor, 2, axis=-1
339+
)
340+
341+
x1, y1, x2, y2 = self.backend.numpy.split(
342+
bounding_boxes["boxes"], 4, axis=-1
343+
)
344+
x1 = self.backend.numpy.squeeze(x1, axis=-1)
345+
y1 = self.backend.numpy.squeeze(y1, axis=-1)
346+
x2 = self.backend.numpy.squeeze(x2, axis=-1)
347+
y2 = self.backend.numpy.squeeze(y2, axis=-1)
348+
349+
if shear_x_amount is not None:
350+
x1_top = x1 - (shear_x_amount * y1)
351+
x1_bottom = x1 - (shear_x_amount * y2)
352+
x1 = self.backend.numpy.where(shear_x_amount < 0, x1_top, x1_bottom)
353+
354+
x2_top = x2 - (shear_x_amount * y1)
355+
x2_bottom = x2 - (shear_x_amount * y2)
356+
x2 = self.backend.numpy.where(shear_x_amount < 0, x2_bottom, x2_top)
357+
358+
if shear_y_amount is not None:
359+
y1_left = y1 - (shear_y_amount * x1)
360+
y1_right = y1 - (shear_y_amount * x2)
361+
y1 = self.backend.numpy.where(shear_y_amount > 0, y1_right, y1_left)
362+
363+
y2_left = y2 - (shear_y_amount * x1)
364+
y2_right = y2 - (shear_y_amount * x2)
365+
y2 = self.backend.numpy.where(shear_y_amount > 0, y2_left, y2_right)
366+
367+
boxes = self.backend.numpy.concatenate(
368+
[
369+
self.backend.numpy.expand_dims(x1, axis=-1),
370+
self.backend.numpy.expand_dims(y1, axis=-1),
371+
self.backend.numpy.expand_dims(x2, axis=-1),
372+
self.backend.numpy.expand_dims(y2, axis=-1),
373+
],
374+
axis=-1,
375+
)
376+
bounding_boxes["boxes"] = boxes
377+
378+
return bounding_boxes
241379

242380
def transform_segmentation_masks(
243381
self, segmentation_masks, transformation, training=True

keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py

+126
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import numpy as np
22
import pytest
3+
from absl.testing import parameterized
34
from tensorflow import data as tf_data
45

56
import keras
67
from keras.src import backend
78
from keras.src import layers
89
from keras.src import testing
10+
from keras.src.utils import backend_utils
911

1012

1113
class RandomShearTest(testing.TestCase):
@@ -74,3 +76,127 @@ def test_tf_data_compatibility(self):
7476
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
7577
for output in ds.take(1):
7678
output.numpy()
79+
80+
@parameterized.named_parameters(
    (
        "with_x_shift",
        [[1.0, 0.0]],
        [[[0.0, 1.0, 3.2, 3.0], [1.2, 4.0, 4.8, 6.0]]],
    ),
    (
        "with_y_shift",
        [[0.0, 1.0]],
        [[[2.0, 0.0, 4.0, 0.5], [6.0, 0.0, 8.0, 0.0]]],
    ),
    (
        "with_xy_shift",
        [[1.0, 1.0]],
        [[[0.0, 0.0, 3.2, 3.5], [1.2, 0.0, 4.8, 4.5]]],
    ),
)
def test_random_shear_bounding_boxes(self, translation, expected_boxes):
    """Check `transform_bounding_boxes` against fixed shear factors.

    `translation` is the value placed under the `"shear_factor"` key of
    the transformation dict; `expected_boxes` is the expected `"boxes"`
    output in `xyxy` format for a 10x8 image.
    """
    data_format = backend.config.image_data_format()
    image_shape = (10, 8, 3) if data_format == "channels_last" else (3, 10, 8)
    input_image = np.random.random(image_shape)

    boxes = np.array(
        [
            [2, 1, 4, 3],
            [6, 4, 8, 6],
        ]
    )
    labels = np.array([[1, 2]])
    input_data = {
        "images": input_image,
        "bounding_boxes": {"boxes": boxes, "labels": labels},
    }

    layer = layers.RandomShear(
        x_factor=0.5,
        y_factor=0.5,
        data_format=data_format,
        seed=42,
        bounding_box_format="xyxy",
    )

    # Bypass random sampling by supplying the transformation directly.
    shear_factor = backend_utils.convert_tf_tensor(np.array(translation))
    output = layer.transform_bounding_boxes(
        input_data["bounding_boxes"],
        transformation={
            "shear_factor": shear_factor,
            "input_shape": image_shape,
        },
        training=True,
    )

    self.assertAllClose(output["boxes"], expected_boxes)
135+
136+
@parameterized.named_parameters(
    (
        "with_x_shift",
        [[1.0, 0.0]],
        [[[0.0, 1.0, 3.2, 3.0], [1.2, 4.0, 4.8, 6.0]]],
    ),
    (
        "with_y_shift",
        [[0.0, 1.0]],
        [[[2.0, 0.0, 4.0, 0.5], [6.0, 0.0, 8.0, 0.0]]],
    ),
    (
        "with_xy_shift",
        [[1.0, 1.0]],
        [[[0.0, 0.0, 3.2, 3.5], [1.2, 0.0, 4.8, 4.5]]],
    ),
)
def test_random_shear_tf_data_bounding_boxes(
    self, translation, expected_boxes
):
    """Check `transform_bounding_boxes` inside a `tf.data` `map` call.

    The input dict is sliced with `Dataset.from_tensor_slices`, the layer
    method is mapped over it with a fixed transformation, and the first
    element's `"boxes"` is compared with `expected_boxes`.
    """
    data_format = backend.config.image_data_format()
    if data_format == "channels_last":
        image_shape = (1, 10, 8, 3)
    else:
        image_shape = (1, 3, 10, 8)
    input_image = np.random.random(image_shape)

    boxes = np.array(
        [
            [
                [2, 1, 4, 3],
                [6, 4, 8, 6],
            ]
        ]
    )
    labels = np.array([[1, 2]])
    input_data = {
        "images": input_image,
        "bounding_boxes": {"boxes": boxes, "labels": labels},
    }

    ds = tf_data.Dataset.from_tensor_slices(input_data)
    layer = layers.RandomShear(
        x_factor=0.5,
        y_factor=0.5,
        data_format=data_format,
        seed=42,
        bounding_box_format="xyxy",
    )

    # Bypass random sampling by supplying the transformation directly.
    transformation = {
        "shear_factor": backend_utils.convert_tf_tensor(
            np.array(translation)
        ),
        "input_shape": image_shape,
    }

    def shear_boxes(sample):
        return layer.transform_bounding_boxes(
            sample["bounding_boxes"],
            transformation=transformation,
            training=True,
        )

    ds = ds.map(shear_boxes)

    output = next(iter(ds))
    expected_boxes = np.array(expected_boxes)
    self.assertAllClose(output["boxes"], expected_boxes)

0 commit comments

Comments
 (0)