Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WiP] allow users to run iterate using terratorch command #483

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
6 changes: 6 additions & 0 deletions docs/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,9 @@ For inference, execute:
```sh
terratorch predict -c <path_to_config_file> --ckpt_path<path_to_checkpoint> --predict_output_dir <path_to_output_dir> --data.init_args.predict_data_root <path_to_input_dir> --data.init_args.predict_dataset_bands <all bands in the predicted dataset, e.g. [BLUE,GREEN,RED,NIR_NARROW,SWIR_1,SWIR_2,0]>
```

Users that want to optmize hyperparameters or repeat best experiment might be interest in in terratorch-iterate, a terratorch's plugin. For instance, to run terratorch-iterate to optimize hyperparameters, one can run:
```sh
terratorch iterate --hpo --config <path_to_config_file>
```
Please go to https://github.com/IBM/terratorch-iterate for further info.
91 changes: 89 additions & 2 deletions terratorch/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,98 @@

"""Command-line interface to TerraTorch."""

from jsonargparse import ArgumentParser
from terratorch.cli_tools import build_lightning_cli

import sys
from benchmark.benchmark_types import Defaults, Task
from benchmark.backbone_benchmark import benchmark_backbone, rerun_best_from_backbone
from typing import List, Any

def main():
_ = build_lightning_cli()

if sys.argv[1] in ["iterate"]:
print("Running terratorch-iterate...")
del sys.argv[1]
parser = ArgumentParser()

parser.add_argument('--defaults', type=Defaults) # to ignore model
parser.add_argument('--optimization_space', type=dict) # to ignore model
parser.add_argument('--experiment_name', type=str) # to ignore model
parser.add_argument('--run_name', type=str) # to ignore model
parser.add_argument('--save_models', type=bool) # to ignore model
parser.add_argument('--storage_uri', type=str) # to ignore model
parser.add_argument('--ray_storage_path', type=str) # to ignore model
parser.add_argument('--n_trials', type=int) # to ignore model
parser.add_argument('--run_repetitions', type=int) # to ignore model
parser.add_argument('--tasks', type=list[Task])
parser.add_argument("--config", action="config")
parser.add_argument("--hpo", help="optimize hyperparameters", action="store_true")
parser.add_argument("--repeat", help="repeat best experiments", action="store_true")

args = parser.parse_args()
print(args)
paths: List[Any] = args.config
path = paths[0]
print(f"args={args} path={path}")
repeat = args.repeat
hpo = args.hpo
config = parser.parse_path(path)

config_init = parser.instantiate_classes(config)
# validate the objects
experiment_name = config_init.experiment_name
assert isinstance(experiment_name, str), f"Error! {experiment_name=} is not a str"
run_name = config_init.run_name
if run_name is not None:
assert isinstance(run_name, str), f"Error! {run_name=} is not a str"
tasks = config_init.tasks
assert isinstance(tasks, list), f"Error! {tasks=} is not a list"
for t in tasks:
assert isinstance(t, Task), f"Error! {t=} is not a Task"
defaults = config_init.defaults
assert isinstance(defaults, Defaults), f"Error! {defaults=} is not a Defaults"
# defaults.trainer_args["max_epochs"] = 5
storage_uri = config_init.storage_uri
assert isinstance(storage_uri, str), f"Error! {storage_uri=} is not a str"

optimization_space = config_init.optimization_space
assert isinstance(
optimization_space, dict
), f"Error! {optimization_space=} is not a dict"
ray_storage_path = config_init.ray_storage_path
assert isinstance(ray_storage_path, str), f"Error! {ray_storage_path=} is not a str"

n_trials = config_init.n_trials
assert isinstance(n_trials, int) and n_trials > 0, f"Error! {n_trials=} is invalid"
run_repetitions = config_init.run_repetitions
if repeat and not hpo:
rerun_best_from_backbone(
defaults=defaults,
tasks=tasks,
experiment_name=experiment_name,
storage_uri=storage_uri,
optimization_space=optimization_space,
run_repetitions=run_repetitions,
)
else:
if not repeat and hpo:
run_repetitions = 0

# run_repetions is an optional parameter
benchmark_backbone(
defaults=defaults,
tasks=tasks,
experiment_name=experiment_name,
storage_uri=storage_uri,
ray_storage_path=ray_storage_path,
run_name=run_name,
optimization_space=optimization_space,
n_trials=n_trials,
run_repetitions=run_repetitions,
)
else:
print("run terratorch")
_ = build_lightning_cli()


if __name__ == "__main__":
Expand Down