 from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501
     BaseImagePreprocessingLayer,
 )
+from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501
+    clip_to_image_size,
+)
+from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501
+    convert_format,
+)
 from keras.src.random.seed_generator import SeedGenerator
+from keras.src.utils import backend_utils


 @keras_export("keras.layers.RandomShear")
@@ -175,7 +182,7 @@ def get_random_transformation(self, data, training=True, seed=None):
             )
             * invert
         )
-        return {"shear_factor": shear_factor}
+        return {"shear_factor": shear_factor, "input_shape": images_shape}

     def transform_images(self, images, transformation, training=True):
         images = self.backend.cast(images, self.compute_dtype)
@@ -231,13 +238,144 @@ def _get_shear_matrix(self, shear_factors):
     def transform_labels(self, labels, transformation, training=True):
         return labels

+    def get_transformed_x_y(self, x, y, transform):
+        a0, a1, a2, b0, b1, b2, c0, c1 = self.backend.numpy.split(
+            transform, 8, axis=-1
+        )
+
+        k = c0 * x + c1 * y + 1
+        x_transformed = (a0 * x + a1 * y + a2) / k
+        y_transformed = (b0 * x + b1 * y + b2) / k
+        return x_transformed, y_transformed
+
+    def get_shifted_bbox(self, bounding_boxes, w_shift_factor, h_shift_factor):
+        bboxes = bounding_boxes["boxes"]
+        x1, x2, x3, x4 = self.backend.numpy.split(bboxes, 4, axis=-1)
+
+        w_shift_factor = self.backend.convert_to_tensor(
+            w_shift_factor, dtype=x1.dtype
+        )
+        h_shift_factor = self.backend.convert_to_tensor(
+            h_shift_factor, dtype=x1.dtype
+        )
+
+        if len(bboxes.shape) == 3:
+            w_shift_factor = self.backend.numpy.expand_dims(w_shift_factor, -1)
+            h_shift_factor = self.backend.numpy.expand_dims(h_shift_factor, -1)
+
+        bounding_boxes["boxes"] = self.backend.numpy.concatenate(
+            [
+                x1 - w_shift_factor,
+                x2 - h_shift_factor,
+                x3 - w_shift_factor,
+                x4 - h_shift_factor,
+            ],
+            axis=-1,
+        )
+        return bounding_boxes
+
     def transform_bounding_boxes(
         self,
         bounding_boxes,
         transformation,
         training=True,
     ):
-        raise NotImplementedError
+        def _get_height_width(transformation):
+            if self.data_format == "channels_first":
+                height_axis = -2
+                width_axis = -1
+            else:
+                height_axis = -3
+                width_axis = -2
+            input_height, input_width = (
+                transformation["input_shape"][height_axis],
+                transformation["input_shape"][width_axis],
+            )
+            return input_height, input_width
+
+        if training:
+            if backend_utils.in_tf_graph():
+                self.backend.set_backend("tensorflow")
+
+            input_height, input_width = _get_height_width(transformation)
+
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source=self.bounding_box_format,
+                target="rel_xyxy",
+                height=input_height,
+                width=input_width,
+                dtype=self.compute_dtype,
+            )
+
+            bounding_boxes = self._shear_bboxes(bounding_boxes, transformation)
+
+            bounding_boxes = clip_to_image_size(
+                bounding_boxes=bounding_boxes,
+                height=input_height,
+                width=input_width,
+                bounding_box_format="rel_xyxy",
+            )
+
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source="rel_xyxy",
+                target=self.bounding_box_format,
+                height=input_height,
+                width=input_width,
+                dtype=self.compute_dtype,
+            )
+
+            self.backend.reset()
+
+        return bounding_boxes
+
+    def _shear_bboxes(self, bounding_boxes, transformation):
+        shear_factor = self.backend.cast(
+            transformation["shear_factor"], dtype=self.compute_dtype
+        )
+        shear_x_amount, shear_y_amount = self.backend.numpy.split(
+            shear_factor, 2, axis=-1
+        )
+
+        x1, y1, x2, y2 = self.backend.numpy.split(
+            bounding_boxes["boxes"], 4, axis=-1
+        )
+        x1 = self.backend.numpy.squeeze(x1, axis=-1)
+        y1 = self.backend.numpy.squeeze(y1, axis=-1)
+        x2 = self.backend.numpy.squeeze(x2, axis=-1)
+        y2 = self.backend.numpy.squeeze(y2, axis=-1)
+
+        if shear_x_amount is not None:
+            x1_top = x1 - (shear_x_amount * y1)
+            x1_bottom = x1 - (shear_x_amount * y2)
+            x1 = self.backend.numpy.where(shear_x_amount < 0, x1_top, x1_bottom)
+
+            x2_top = x2 - (shear_x_amount * y1)
+            x2_bottom = x2 - (shear_x_amount * y2)
+            x2 = self.backend.numpy.where(shear_x_amount < 0, x2_bottom, x2_top)
+
+        if shear_y_amount is not None:
+            y1_left = y1 - (shear_y_amount * x1)
+            y1_right = y1 - (shear_y_amount * x2)
+            y1 = self.backend.numpy.where(shear_y_amount > 0, y1_right, y1_left)
+
+            y2_left = y2 - (shear_y_amount * x1)
+            y2_right = y2 - (shear_y_amount * x2)
+            y2 = self.backend.numpy.where(shear_y_amount > 0, y2_left, y2_right)
+
+        boxes = self.backend.numpy.concatenate(
+            [
+                self.backend.numpy.expand_dims(x1, axis=-1),
+                self.backend.numpy.expand_dims(y1, axis=-1),
+                self.backend.numpy.expand_dims(x2, axis=-1),
+                self.backend.numpy.expand_dims(y2, axis=-1),
+            ],
+            axis=-1,
+        )
+        bounding_boxes["boxes"] = boxes
+
+        return bounding_boxes

     def transform_segmentation_masks(
         self, segmentation_masks, transformation, training=True
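The diff above wires `transform_bounding_boxes` into the usual Keras flow: boxes are converted to `rel_xyxy`, sheared with the sampled `shear_factor`, clipped to the image, and converted back to the layer's `bounding_box_format`. The sketch below shows how the layer might be exercised end to end. It is illustrative only: the constructor arguments (`x_factor`, `y_factor`, `bounding_box_format`, `seed`) and the `{"images", "bounding_boxes": {"boxes", "labels"}}` input layout are assumptions based on other Keras image preprocessing layers, not taken from this diff.

```python
import numpy as np
import keras

# Hypothetical usage sketch; argument names and input layout are assumed
# to match other Keras bounding-box-aware preprocessing layers.
layer = keras.layers.RandomShear(
    x_factor=0.3,
    y_factor=0.0,
    bounding_box_format="xyxy",
    seed=42,
)

data = {
    "images": np.random.uniform(size=(2, 64, 64, 3)).astype("float32"),
    "bounding_boxes": {
        # One box per image, in absolute "xyxy" pixel coordinates.
        "boxes": np.array(
            [[[10.0, 10.0, 30.0, 30.0]], [[20.0, 15.0, 50.0, 40.0]]],
            dtype="float32",
        ),
        "labels": np.array([[0], [1]], dtype="int32"),
    },
}

out = layer(data, training=True)
# Boxes come back in the same format, sheared consistently with the images
# and clipped to the image bounds by clip_to_image_size.
print(out["bounding_boxes"]["boxes"].shape)
```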