diff --git a/.env.example b/.env.example index 8502ad5..33713cb 100644 --- a/.env.example +++ b/.env.example @@ -1,13 +1,14 @@ # Adjust this file for storing private and user specific environment variables, like keys or system paths. # rename it to ".env" (excluded from version control by default) -PROJECT_ROOT="path/to/aether" +PROJECT_ROOT="path/to/aether/" # path to your local aether repo TRAINER_PROFILE="gpu" # cpu/gpu/mps/ddp -HF_HOME="/path/to/huggingface/cache" # set or will default to './.cache/huggingface/' -DATA_DIR="../data/" # set orwill default to './data/' + #---------------------------- # OPTIONALS #---------------------------- +HF_HOME="${PROJECT_ROOT}/.cache/huggingface/" # set or will default to './.cache/huggingface/' +DATA_DIR="${PROJECT_ROOT}/data/" # set to your local data folder (for aether), or will default to '${PROJECT_ROOT}/data/' # Working directories # STORAGE_MODE=# or "shared" diff --git a/.gitignore b/.gitignore index 7f9682a..0bec17c 100644 --- a/.gitignore +++ b/.gitignore @@ -227,3 +227,4 @@ uv.lock notebooks/01-TvdP-tmp.ipynb */source/* *.tif # for now +..env.swp diff --git a/README.md b/README.md index 180bd29..c75c605 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ This project develops an EO embedding/language model that can be used for explai ### Virtual environment -First, install dependencies in a venv using [uv](https://docs.astral.sh/uv/getting-started/installation/) +To install the dependencies in a venv using [uv](https://docs.astral.sh/uv/getting-started/installation/), first, clone the repo: ```bash # clone project @@ -31,12 +31,14 @@ git clone https://github.com/WUR-AI/aether cd aether ``` +Then, create a virtual environment (or alternatively via conda): ```bash # Create venv python3 -m venv .venv source .venv/bin/activate ``` +Then, install `uv` and use this to install all packages. ```bash # install uv manager pip install uv @@ -52,9 +54,16 @@ Note, running `uv sync` in the venv will always update the package to the most u ### Set paths -Next, create a file in your local repo parent folder `aether/` called `.env`. Copy the contents of `aether/env.example` and adjust the paths to your local system. **Important**: `DATA_DIR` should either point to `aether/data/` OR if it points to another folder (e.g., `my/local/data/`) then copy the contents of `aether/data/` to `my/local/data/` to ensure the butterfly use case runs using the provided example data. Other data will automatically be downloaded and organised by `pooch` if possible, or should be copied manually. +Next, create a file in your local repo parent folder `aether/` called `.env` and copy the contents of `aether/.env.example`: -Data folders should follow the following directory structure: +```bash +cp .env.example .env +``` +Adjust the paths in `.env` to your local system. **At a minimum, you should set PROJECT_ROOT!**. + +**Important**: `DATA_DIR` should either point to `aether/data/` (default setting) OR if it points to another folder (e.g., `my/local/data/`) then copy the contents of the `aether/data/` folder to `my/local/data/` to ensure the butterfly use case runs using the provided example data. Other data will automatically be downloaded and organised by `pooch` if possible into `DATA_DIR`, or should be copied manually. + +Data folders should follow the following directory structure within `DATA_DIR`: ``` ├── registry.txt <- Pooch config file, don't change. @@ -73,7 +82,18 @@ Data folders should follow the following directory structure: ├── other_dataset/ ``` -### Training +### Verify installation: + +To verify whether the installation was successful, run the tests in `aether/` using: +```bash +pytest --use-mock -m "not slow" +``` +which should pass all tests. + + +## Training + +Currently, we have implemented 2 models: a prediction model (that predicts target variables from EO data) and an alignment model (that aligns EO embeddings with text embeddings). Experiment configurations (such as choosing data, encoders, hyperparameters etc.) are managed through [Hydra](https://hydra.cc/) configurations. Define your experiment configurations in `configs/experiments/experiment_name.yaml`, for example to train predictive model with GeoCLIP coordinate encoder for the Butterfly data using `configs/experiments/prediction.yaml` (copied below) @@ -112,6 +132,8 @@ To execute this experiment run (inside your venv): python train.py experiment=prediction ``` +Please see the [Hydra](https://hydra.cc/) and [Hydra-Lightning template](https://github.com/ashleve/lightning-hydra-template) documentation for further examples of how to configure training runs. + ## Directory structure We follow the directory structure from the [Hydra-Lightning template](https://github.com/ashleve/lightning-hydra-template), which looks like: @@ -136,7 +158,7 @@ We follow the directory structure from the [Hydra-Lightning template](https://gi │ ├── eval.yaml <- Main config for evaluation │ └── train.yaml <- Main config for training │ -├── data <- Project data +├── data <- Project data (for aether, this can also be elsewhere, see environment paths). │ ├── logs <- Logs generated by hydra and lightning loggers │ diff --git a/data/registry.txt b/data/registry.txt index 84cb8d7..093244a 100644 --- a/data/registry.txt +++ b/data/registry.txt @@ -1,5 +1,5 @@ # S2BMS dataset (butterfly, ecology UC) -S2BMS.zip md5:af98bf3d1d0c4645c3c5787d49f59a70 doi:10.5281/zenodo.15198883 +S2BMS.zip md5:af98bf3d1d0c4645c3c5787d49f59a70 https://zenodo.org/records/15198884/files/S2BMS.zip?download=1 # Satbird (birds, ecology UC) Kenya.zip None https://drive.google.com/uc?id=19PSNaKQn1papoT-jN5FkzTp7Xf4M1juD diff --git a/data/s2bms/caption_templates/v1.json b/data/s2bms/caption_templates/v1.json new file mode 100644 index 0000000..128ae6f --- /dev/null +++ b/data/s2bms/caption_templates/v1.json @@ -0,0 +1,15 @@ +[ + "Location with , and .", + "Area with and .", + "Site with and , with and .", + "Location with and , with and .", + "Area with and , with and .", + "Site with and , with and .", + "Location with and , with , and .", + "Area with , with and .", + "Site with , with and .", + "Location with , with and .", + "Area with , with and .", + "Site with , with , and ." + +] diff --git a/pyproject.toml b/pyproject.toml index 5268d97..9cca0a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "pre-commit>=4.5.1", "pooch>=1.8.2", "torchinfo>=1.8.0", + "transformers==4.57", "gdown>=5.2.1", ] diff --git a/src/data/base_dataset.py b/src/data/base_dataset.py index 17efdee..b49c014 100644 --- a/src/data/base_dataset.py +++ b/src/data/base_dataset.py @@ -78,6 +78,7 @@ def __init__( self.use_target_data: bool = use_target_data self.use_aux_data: bool = use_aux_data self.records: dict[str, Any] = self.get_records() + self.pooch_cli = None @final def get_records(self) -> dict[str, Any]: @@ -93,7 +94,12 @@ def get_records(self) -> dict[str, Any]: columns.extend(["lat", "lon"]) else: # Add paths - self.add_modality_paths_to_df(modality, params["format"]) + self.add_modality_paths_to_df( + modality, + params.get( + "format", KeyError(f"{modality} modality is missing format parameter") + ), + ) columns.append(f"{modality}_path") # Include targets @@ -218,7 +224,7 @@ def pooch_setup(self) -> None: # Initialise pooch client self.pooch_cli = pooch.create( - path=os.path.join(self.cache_dir, self.data_dir), + path=self.cache_dir, base_url="", registry=None, ) diff --git a/src/data/butterfly_dataset.py b/src/data/butterfly_dataset.py index 900d598..d6b704c 100644 --- a/src/data/butterfly_dataset.py +++ b/src/data/butterfly_dataset.py @@ -2,6 +2,7 @@ from typing import Any, Dict, override import numpy as np +import pooch import torch import src.data_preprocessing.data_utils as du @@ -53,7 +54,7 @@ def setup(self): return elif mod == "s2": self.setup_s2bms() - if self.modalities["s2"].get("preprocessing", "") == "zcored": + if self.modalities["s2"].get("preprocessing", "") == "zscored": self.init_norm_stats() elif mod == "tessera": self.setup_tessera() @@ -69,7 +70,8 @@ def setup_s2bms(self) -> None: # If data does not exist or is empty → full download if not os.path.exists(dst_dir) or len(os.listdir(dst_dir)) == 0: - import pooch + if self.pooch_cli is None: + self.pooch_setup() os.makedirs(dst_dir, exist_ok=True) fnames = self.pooch_cli.fetch("S2BMS.zip", processor=pooch.Unzip()) @@ -81,7 +83,7 @@ def setup_s2bms(self) -> None: # Move files to data dir rename_s2bms(dst_dir, fnames) - with open(os.path.join(dst_dir, "meta.tx"), "w") as f: + with open(os.path.join(dst_dir, "meta.txt"), "w") as f: f.writelines("Data from S2BMS study\n") f.writelines("Containing 4 channel S2 256x256px imagery.\n") # TODO: add more diff --git a/src/data/satbird_dataset.py b/src/data/satbird_dataset.py index 67d70e2..1364691 100644 --- a/src/data/satbird_dataset.py +++ b/src/data/satbird_dataset.py @@ -32,8 +32,9 @@ def __init__( :param study_site: study site name [Kenya, USA_summer, USA_winter] :param mock: whether to mock csv file """ - # assert study_site in ["Kenya", "USA_summer", "USA_winter"] - assert study_site in ["Kenya"] + assert study_site in ["Kenya", "USA-summer", "USA-winter"] + # assert study_site in ["Kenya"] + self.study_site = study_site super().__init__( data_dir=data_dir, @@ -47,14 +48,11 @@ def __init__( mock=mock, ) - self.study_site = study_site - @override def setup(self): """Setups the whole dataset, makes available data of requested modalities.""" # Set up each requested modality - for mod in self.modalities.keys(): if mod == "coords" and len(self.modalities.keys()) == 1: return @@ -95,8 +93,10 @@ def __getitem__(self, idx): if modality in ["coords"]: formatted_row["eo"][modality] = torch.tensor([row["lat"], row["lon"]]) elif modality in ["s2", "s2rgb"]: - formatted_row["eo"][modality] = self.load_s2(row[f"{modality}_path"]) + s2 = self.load_s2(row[f"{modality}_path"]) # TODO: augmentations + s2 = v2.CenterCrop(self.modalities[modality].get("size", 256))(s2) + formatted_row["eo"][modality] = s2 elif modality == "tessera": formatted_row["eo"][modality] = self.load_npy(row["tessera_path"]) # TODO any normalisation needed diff --git a/src/data_preprocessing/pooch_helpers.py b/src/data_preprocessing/pooch_helpers.py index 88683e3..d8e9137 100644 --- a/src/data_preprocessing/pooch_helpers.py +++ b/src/data_preprocessing/pooch_helpers.py @@ -1,10 +1,18 @@ -import os +import sys import gdown def drive_downloader(url, output_file, pooch_obj): - if os.path.exists(output_file): - print(f"{output_file} already exists, skipping.") - return - gdown.download(url, str(output_file), quiet=False) + """Downloader callback for pooch that uses gdown to fetch files from Google Drive. + + Uses fuzzy=True to handle Google Drive's virus scanning page and use_cookies=True to handle + access restrictions. + """ + gdown.download( + url, + str(output_file), + quiet=False, + fuzzy=True, + use_cookies=True, + ) diff --git a/src/data_preprocessing/satbird.py b/src/data_preprocessing/satbird.py index a19f89d..6c5dc2f 100644 --- a/src/data_preprocessing/satbird.py +++ b/src/data_preprocessing/satbird.py @@ -73,8 +73,8 @@ def pooch_satbird_downloader( conf = { "Kenya": ("Kenya.zip", pooch.Unzip), - "USA_summer": ("USA_summer.tar.gz", pooch.Untar), - "USA_winter": ("USA_winter.tar.gz", pooch.Untar), + "USA-summer": ("USA_summer.tar.gz", pooch.Untar), + "USA-winter": ("USA_winter.tar.gz", pooch.Untar), } fnames = pooch_cli.fetch( @@ -86,7 +86,7 @@ def pooch_satbird_downloader( extract_satbird_data(data_dir, fnames, study_site) # Delete the unzipped dir at the end - if False: + if True: unzip_dir = os.path.join(cache_dir, f"{study_site}.zip.unzip") for name in os.listdir(unzip_dir): path = os.path.join(unzip_dir, name) @@ -123,7 +123,6 @@ def extract_satbird_data(data_dir: str, fnames: list[str], study_site: str) -> N # Iterate through all file names from pooch for fname in fnames: - # get the base name base = os.path.basename(fname) dst = None @@ -137,7 +136,7 @@ def extract_satbird_data(data_dir: str, fnames: list[str], study_site: str) -> N elif "environmental" in fname: dst = os.path.join(env_dir, f"environmental_{base}") elif "images_visual" in fname: - base = base.replace("_visual", " ") + base = base.replace("_visual", "") dst = os.path.join(s2rgb_dir, f"s2rgb_{base}") elif "images" in fname: dst = os.path.join(s2_dir, f"s2_{base}") @@ -145,7 +144,7 @@ def extract_satbird_data(data_dir: str, fnames: list[str], study_site: str) -> N splits_file.append(fname) if dst is not None and not os.path.exists(dst): - shutil.copy(fname, dst) + shutil.move(fname, dst) print(f"Moving {base} to {dst}") # Compile model ready csv and split file @@ -232,3 +231,15 @@ def make_model_ready_csv( df_joined.rename(columns=rename_col, inplace=True) df_joined.to_csv(model_ready_csv_path, index=False) print(f"Model ready csv saved {model_ready_csv_path}") + + +if __name__ == "__main__": + print(os.getcwd()) + study_site = "USA-winter" + + setup_satbird_from_pooch( + f"data/satbird-{study_site}/", + cache_dir="data/cache", + study_site=study_site, + registry_file="data/registry.txt", + ) diff --git a/src/models/components/eo_encoders/average_encoder.py b/src/models/components/eo_encoders/average_encoder.py new file mode 100644 index 0000000..5e5d7dc --- /dev/null +++ b/src/models/components/eo_encoders/average_encoder.py @@ -0,0 +1,75 @@ +from typing import Dict, override + +import torch +import torch.nn.functional as F +from torch import nn + +from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder + + +class AverageEncoder(BaseEOEncoder): + def __init__( + self, + output_dim: int | None = None, + eo_data_name="aef", + output_normalization="l2", + ) -> None: + super().__init__() + + dict_n_bands_default = {"s2": 4, "aef": 64, "tessera": 128} + assert ( + eo_data_name in dict_n_bands_default + ), f"eo_data_name must be one of {list(dict_n_bands_default.keys())}, got {eo_data_name}" + self.eo_data_name = eo_data_name + self.output_normalization = output_normalization + + if output_dim is None or output_dim == dict_n_bands_default[eo_data_name]: + self.output_dim = dict_n_bands_default[eo_data_name] + self.extra_projector = None + self.eo_encoder = self._average + else: + assert ( + type(output_dim) is int and output_dim > 0 + ), f"output_dim must be positive int, got {output_dim}" + self.output_dim = output_dim + self.extra_projector = nn.Linear(dict_n_bands_default[eo_data_name], output_dim) + self.eo_encoder = self._average_and_project + + def _average(self, x: torch.Tensor) -> torch.Tensor: + """Averages the input tensor over spatial dimensions. + + :param x: input tensor of shape (B, C, H, W) + :return: averaged tensor of shape (B, C) + """ + return x.mean(dim=(-2, -1)) + + def _average_and_project(self, x: torch.Tensor) -> torch.Tensor: + """Averages the input tensor over spatial dimensions and projects to output_dim. + + :param x: input tensor of shape (B, C, H, W) + :return: projected tensor of shape (B, output_dim) + """ + x_avg = x.mean(dim=(-2, -1)) + x_proj = self.extra_projector(x_avg) + return x_proj + + @override + def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + eo_data = batch.get("eo", {}) + feats = self.eo_encoder(eo_data[self.eo_data_name]) + # n_nans = torch.sum(torch.isnan(feats)).item() + # assert ( + # n_nans == 0 + # ), f"AverageEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.eo_data_name].min()} and max {eo_data[self.eo_data_name].max()}." + if self.output_normalization == "l2": + feats = F.normalize(feats, p=2, dim=1) # L2 normalization (per feature vector) + elif self.output_normalization == "none": + pass + else: + raise ValueError(f"Unsupported output_normalization: {self.output_normalization}") + + return feats + + +if __name__ == "__main__": + _ = AverageEncoder(None, None) diff --git a/src/models/components/eo_encoders/cnn_encoder.py b/src/models/components/eo_encoders/cnn_encoder.py index 9017b14..bb194e7 100644 --- a/src/models/components/eo_encoders/cnn_encoder.py +++ b/src/models/components/eo_encoders/cnn_encoder.py @@ -142,13 +142,11 @@ def forward( :return: extracted features """ eo_data = batch.get("eo", {}) - assert self.eo_data_name in eo_data, f"eo['{self.eo_data_name}'] not found in batch" - # assert not torch.any(torch.isnan(eo_data[self.eo_data_name])), f"EO data for modality {self.eo_data_name} contains NaNs in the batch." feats = self.eo_encoder(eo_data[self.eo_data_name]) - n_nans = torch.sum(torch.isnan(feats)).item() - assert ( - n_nans == 0 - ), f"CNNEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.eo_data_name].min()} and max {eo_data[self.eo_data_name].max()}." + # n_nans = torch.sum(torch.isnan(feats)).item() + # assert ( + # n_nans == 0 + # ), f"CNNEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.eo_data_name].min()} and max {eo_data[self.eo_data_name].max()}." if self.output_normalization == "l2": feats = F.normalize(feats, p=2, dim=1) # L2 normalization (per feature vector) elif self.output_normalization == "none": diff --git a/src/models/components/loss_fns/top_k_accuracy.py b/src/models/components/loss_fns/top_k_accuracy.py index 3720eb4..0a53e7a 100644 --- a/src/models/components/loss_fns/top_k_accuracy.py +++ b/src/models/components/loss_fns/top_k_accuracy.py @@ -29,14 +29,8 @@ def forward(self, pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: tmp_pred_greater_th[row, inds_sorted_preds[row, :k]] = 1 tmp_label_greater_th[row, inds_sorted_target[row, :k]] = 1 - assert tmp_pred_greater_th.sum() <= k * len_batch, tmp_pred_greater_th.sum() - assert tmp_label_greater_th.sum() <= k * len_batch, tmp_label_greater_th.sum() tmp_joint = tmp_pred_greater_th * tmp_label_greater_th n_present = torch.sum(tmp_joint, dim=1) # sum per batch sample - - for n in n_present: - assert n <= k, n_present - top_k_acc = n_present.float() / k # accuracy per batch sample accs[k] = top_k_acc.mean() diff --git a/tests/test_eo_encoders.py b/tests/test_eo_encoders.py index 4d2ef27..e91d9da 100644 --- a/tests/test_eo_encoders.py +++ b/tests/test_eo_encoders.py @@ -5,6 +5,7 @@ import pytest import torch +from src.models.components.eo_encoders.average_encoder import AverageEncoder from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder from src.models.components.eo_encoders.cnn_encoder import CNNEncoder from src.models.components.eo_encoders.geoclip import GeoClipCoordinateEncoder @@ -13,7 +14,11 @@ # @pytest.mark.slow def test_eo_encoder_generic_properties(create_butterfly_dataset): """This test checks that all EO encoders implement the basic properties and methods.""" - dict_eo_encoders = {"geoclip_coords": GeoClipCoordinateEncoder, "cnn": CNNEncoder} + dict_eo_encoders = { + "geoclip_coords": GeoClipCoordinateEncoder, + "cnn": CNNEncoder, + "average": AverageEncoder, + } ds, dm = create_butterfly_dataset batch = next(iter(dm.train_dataloader())) @@ -34,6 +39,7 @@ def test_eo_encoder_generic_properties(create_butterfly_dataset): ), f"'forward' is not callable in {eo_encoder_class.__name__}." if eo_encoder_name == "geoclip_coords": + # TODO: try more EO encoders when (mock) test data also includes images. feats = eo_encoder.forward(batch) assert isinstance( feats, torch.Tensor