Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix tutorial notebook not using strings for optimizer #8

Merged
merged 5 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ The library provides:
### Pip
In order to use th file `pyproject.toml` it is necessary to guarantee `pip>=21.8`. If necessary upgrade `pip` using `python -m pip install --upgrade pip`.

Install the library with `pip install git+https://github.com/IBM/terratorch.git`
Install the library with `pip install git+https://github.com/IBM/terratorch.git`.

TerraTorch requires gdal to be installed, which can be quite a complex process. If you don't have GDAL set up on your system, we reccomend using a conda environment and installing it with `conda install -c conda-forge gdal`.

To install as a developer (e.g. to extend the library) clone this repo, install dependencies using `pip install -r requirements.txt` and run `pip install -e .`

Expand Down
4 changes: 3 additions & 1 deletion docs/quick_start.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Quick start
We suggest using Python==3.10.
To get started, make sure to have `PyTorch >= 2` [installed](https://pytorch.org/get-started/locally/).
To get started, make sure to have [PyTorch](https://pytorch.org/get-started/locally/) >= 2.0.0 and [GDAL](https://gdal.org/index.html) installed.

Installing GDAL can be quite a complex process. If you don't have GDAL set up on your system, we reccomend using a conda environment and installing it with `conda install -c conda-forge gdal`.

Install TerraTorch with `pip install git+https://github.com/IBM/terratorch.git`

Expand Down
208 changes: 29 additions & 179 deletions examples/notebooks/Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,7 @@
"execution_count": 1,
"id": "5d049232-f4b1-473d-aac3-0b3539905b03",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/cpi/opt/miniconda3/envs/terratorch_os/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"outputs": [],
"source": [
"import torch\n",
"import timm\n",
Expand All @@ -39,7 +30,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "8dcdfa85-8e43-4db0-9ddf-cb11c5544942",
"metadata": {},
"outputs": [
Expand All @@ -61,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "13134d11-c477-47c2-998d-a9acb084e2e7",
"metadata": {},
"outputs": [],
Expand All @@ -88,7 +79,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "c1040a09-fa6c-40e2-9a6f-8a3a60c520b6",
"metadata": {},
"outputs": [
Expand All @@ -115,7 +106,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "97f19710-1bec-4cb1-a5ec-d9a51556dd5e",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -162,7 +153,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "c48e98bf-c748-47c6-b56f-98c96304ed1b",
"metadata": {},
"outputs": [],
Expand All @@ -175,27 +166,6 @@
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5936e883-d6c0-470a-95e0-51fc121274ad",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<HLSBands.RED: 'RED'>, <HLSBands.GREEN: 'GREEN'>, <HLSBands.BLUE: 'BLUE'>, 14]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.model_bands"
]
},
{
"cell_type": "code",
"execution_count": 8,
Expand Down Expand Up @@ -314,12 +284,14 @@
"\n",
"Alternatively, leverage one of our generic data modules.\n",
"\n",
"Datamodules package train, test and validation datasets as well as any transforms done.\n"
"Datamodules package train, test and validation datasets as well as any transforms done.\n",
"\n",
"Below is an example for a datamodule that can be used for pixel-wise regression tasks.\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 11,
"id": "51a2e286-e5b2-4cf4-8aa7-44e874929117",
"metadata": {},
"outputs": [],
Expand All @@ -334,21 +306,27 @@
"]\n",
"\n",
"train_val_test_labels = {\n",
" \"train_label_data_root\": \"/dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/train_labels\",\n",
" \"val_label_data_root\": \"/dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/val_labels\",\n",
" \"test_label_data_root\": \"/dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/test_labels\",\n",
" \"train_label_data_root\": \"<path>/train_labels\",\n",
" \"val_label_data_root\": \"<path>/val_labels\",\n",
" \"test_label_data_root\": \"<path>/test_labels\",\n",
"}\n",
"\n",
"means = [385.88501817, 714.60615207, 658.96267376, 3314.57774238, 2238.71812558, 1250.00982518]\n",
"stds = [264.62872, 355.62848, 504.54855, 898.4953, 947.22894, 828.1297]\n",
"means = [] # float array of means\n",
"stds = [] # float array of stds\n",
"datamodule = GenericNonGeoPixelwiseRegressionDataModule(\n",
" batch_size,\n",
" num_workers,\n",
" *train_val_test,\n",
" means,\n",
" stds,\n",
" **train_val_test_labels,\n",
"\n",
" # if transforms are defined with Albumentations, you can pass them here\n",
" # train_transform=train_transform,\n",
" # val_transform=val_transform,\n",
" # test_transform=test_transform,\n",
"\n",
" # edit the below for your usecase\n",
" dataset_bands=[\n",
" -1,\n",
" HLSBands.BLUE,\n",
Expand Down Expand Up @@ -386,14 +364,14 @@
"\n",
"They build on the model factory we introduced previously and are able to take any. To use a task with a model not supported by a currently existing model factory, simply create your own model factory!\n",
"\n",
"Let's create a Trainer for PixelWise Regression\n",
"Let's create a Trainer for Pixel-Wise Regression\n",
"\n",
"We also show how to use the popular CosineLrDecay scheduler into training"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"id": "1fd3f466-5ee4-4897-8e1e-1cfafe6c4b04",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -430,9 +408,9 @@
" aux_loss={\"fcn_aux_head\": 0.4},\n",
" lr=lr,\n",
" ignore_index=-1,\n",
" optimizer=torch.optim.AdamW,\n",
" optimizer=\"AdamW\",\n",
" optimizer_hparams={\"weight_decay\": 0.05},\n",
" scheduler=OneCycleLR,\n",
" scheduler=\"OneCycleLR\",\n",
" scheduler_hparams={\n",
" \"max_lr\": lr,\n",
" \"epochs\": epochs,\n",
Expand Down Expand Up @@ -460,55 +438,10 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"id": "d4db99b3-72bb-46ba-b149-49493529d714",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">Epoch 0/0 </span> <span style=\"color: #6206e0; text-decoration-color: #6206e0\">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">1502/1502</span> <span style=\"color: #8a8a8a; text-decoration-color: #8a8a8a\">0:01:45 • 0:00:00</span> <span style=\"color: #b2b2b2; text-decoration-color: #b2b2b2\">14.31it/s</span> <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">v_num: 1.000</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[37mEpoch 0/0 \u001b[0m \u001b[38;2;98;6;224m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[37m1502/1502\u001b[0m \u001b[38;5;245m0:01:45 • 0:00:00\u001b[0m \u001b[38;5;249m14.31it/s\u001b[0m \u001b[37mv_num: 1.000\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO: `Trainer.fit` stopped: `max_epochs=1` reached.\n",
"INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"import os\n",
"from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar, LearningRateMonitor\n",
Expand Down Expand Up @@ -539,93 +472,10 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"id": "e977b153-e9e8-4ef0-a3ff-714ff117a94e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">Testing</span> <span style=\"color: #6206e0; text-decoration-color: #6206e0\">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span style=\"color: #c0c0c0; text-decoration-color: #c0c0c0\">187/187</span> <span style=\"color: #8a8a8a; text-decoration-color: #8a8a8a\">0:00:14 • 0:00:00</span> <span style=\"color: #b2b2b2; text-decoration-color: #b2b2b2\">12.77it/s</span> \n",
"</pre>\n"
],
"text/plain": [
"\u001b[37mTesting\u001b[0m \u001b[38;2;98;6;224m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[37m187/187\u001b[0m \u001b[38;5;245m0:00:14 • 0:00:00\u001b[0m \u001b[38;5;249m12.77it/s\u001b[0m \n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test/MAE </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 74.75745391845703 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test/MSE </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 13011.7548828125 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test/RMSE </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 114.06907653808594 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test/decode_head_epoch </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 95.1190185546875 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test/fcn_aux_head_epoch </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 87.39353942871094 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test/loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 130.076416015625 </span>│\n",
"└───────────────────────────┴───────────────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│\u001b[36m \u001b[0m\u001b[36m test/MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 74.75745391845703 \u001b[0m\u001b[35m \u001b[0m│\n",
"│\u001b[36m \u001b[0m\u001b[36m test/MSE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 13011.7548828125 \u001b[0m\u001b[35m \u001b[0m│\n",
"│\u001b[36m \u001b[0m\u001b[36m test/RMSE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 114.06907653808594 \u001b[0m\u001b[35m \u001b[0m│\n",
"│\u001b[36m \u001b[0m\u001b[36m test/decode_head_epoch \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 95.1190185546875 \u001b[0m\u001b[35m \u001b[0m│\n",
"│\u001b[36m \u001b[0m\u001b[36m test/fcn_aux_head_epoch \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 87.39353942871094 \u001b[0m\u001b[35m \u001b[0m│\n",
"│\u001b[36m \u001b[0m\u001b[36m test/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 130.076416015625 \u001b[0m\u001b[35m \u001b[0m│\n",
"└───────────────────────────┴───────────────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[{'test/loss': 130.076416015625,\n",
" 'test/decode_head_epoch': 95.1190185546875,\n",
" 'test/fcn_aux_head_epoch': 87.39353942871094,\n",
" 'test/MAE': 74.75745391845703,\n",
" 'test/MSE': 13011.7548828125,\n",
" 'test/RMSE': 114.06907653808594}]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"trainer.test(model=task, datamodule=datamodule)"
]
Expand Down
Loading