diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index 9849b0b2..118298a0 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -384,12 +384,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T other_keys = batch.keys() - {"image", "mask", "filename"} rest = {k: batch[k] for k in other_keys} - def model_forward(x): + def model_forward(x, **kwargs): return self(x).output if self.tiled_inference_parameters: # TODO: tiled inference does not work with additional input data (**rest) - y_hat: Tensor = tiled_inference(model_forward, x, 1, self.tiled_inference_parameters) + y_hat: Tensor = tiled_inference(model_forward, x, 1, self.tiled_inference_parameters, **rest) else: y_hat: Tensor = self(x, **rest).output return y_hat, file_names diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index 882f5c86..6064cdf8 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -345,18 +345,16 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T rest = {k: batch[k] for k in other_keys} - model_output: ModelOutput = self(x, **rest) - - def model_forward(x): - return self(x).output + def model_forward(x, **kwargs): + return self(x, **kwargs).output if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference( - # TODO: tiled inference does not work with additional input data (**rest) model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters, + **rest, ) else: y_hat: Tensor = self(x, **rest).output diff --git a/terratorch/tasks/tiled_inference.py b/terratorch/tasks/tiled_inference.py index 28d93140..1e130089 100644 --- a/terratorch/tasks/tiled_inference.py +++ b/terratorch/tasks/tiled_inference.py @@ -46,6 +46,7 @@ def tiled_inference( input_batch: torch.Tensor, out_channels: int, inference_parameters: TiledInferenceParameters, + **kwargs ) -> torch.Tensor: """ Like divide an image into (potentially) overlapping tiles and perform inference on them. @@ -163,7 +164,7 @@ def tiled_inference( end = min(len(coordinates_and_inputs), start + process_batch_size) batch = coordinates_and_inputs[start:end] tensor_input = torch.stack([b.input_data for b in batch], dim=0) - output = model_forward(tensor_input) + output = model_forward(tensor_input, **kwargs) output = [output[i] for i in range(len(batch))] for batch_input, predicted in zip(batch, output, strict=True): if batch_input.output_crop is not None: