Skip to content

Commit 6d1f127

Browse files
Trying to deal with possible errors during the test step of the segmentation task by invoking tiled inference
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent efec27b commit 6d1f127

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

terratorch/tasks/segmentation_tasks.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from terratorch.tasks.optimizer_factory import optimizer_factory
1919
from terratorch.tasks.tiled_inference import TiledInferenceParameters, tiled_inference
2020
from terratorch.tasks.base_task import TerraTorchTask
21+
from terratorch.models.model import ModelOutput
2122

2223
BATCH_IDX_FOR_VALIDATION_PLOTTING = 10
2324

@@ -266,7 +267,27 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
266267

267268
rest = {k: batch[k] for k in other_keys}
268269

269-
model_output: ModelOutput = self(x, **rest)
270+
271+
def model_forward(x, **kwargs):
272+
return self(x, **kwargs).output
273+
274+
# When the input sample cannot be fit on memory for some reason
275+
# the tiled inference is automatically invoked.
276+
try:
277+
model_output: ModelOutput = self(x, **rest)
278+
except Exception:
279+
if self.tiled_inference_parameters:
280+
y_hat: Tensor = tiled_inference(
281+
model_forward,
282+
x,
283+
self.hparams["model_args"]["num_classes"],
284+
self.tiled_inference_parameters,
285+
**rest,
286+
)
287+
model_output = ModelOutput(mask=y_hat)
288+
else:
289+
raise Exception("You need to define a configuration for the tiled inference.")
290+
270291
if dataloader_idx >= len(self.test_loss_handler):
271292
msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names."
272293
raise ValueError(msg)

0 commit comments

Comments
 (0)