2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -60,7 +60,7 @@ repos:
args:
[
--verbose,
--fail-under=30,
--fail-under=10,
--ignore-init-module,
--ignore-init-method,
--ignore-module,
10 changes: 7 additions & 3 deletions README.md
@@ -32,13 +32,15 @@ cd aether
```

Then, create a virtual environment (or alternatively via conda):

```bash
# Create venv
python3 -m venv .venv
source .venv/bin/activate
```

Then, install `uv` and use it to install all packages.

```bash
# install uv manager
pip install uv
@@ -59,7 +61,8 @@ Next, create a file in your local repo parent folder `aether/` called `.env` and
```bash
cp .env.example .env
```

Adjust the paths in `.env` to your local system. **At a minimum, you should set PROJECT_ROOT!**.

**Important**: `DATA_DIR` should either point to `aether/data/` (the default setting) OR, if it points to another folder (e.g., `my/local/data/`), the contents of `aether/data/` should be copied into that folder so that the butterfly use case runs with the provided example data. Other data will be downloaded and organised automatically by `pooch` into `DATA_DIR` where possible, or should be copied manually.
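
As an illustrative sketch (paths are placeholders; `.env.example` defines the full set of variables), a minimal `.env` could look like:

```bash
# Placeholder values -- adjust to your local system
PROJECT_ROOT=/home/username/aether
DATA_DIR=/home/username/aether/data
```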

@@ -85,11 +88,12 @@ Data folders should follow the following directory structure within `DATA_DIR`:
### Verify installation:

To verify whether the installation was successful, run the tests in `aether/` using:

```bash
pytest --use-mock -m "not slow"
```

which should pass all tests.

## Training

@@ -129,7 +133,7 @@ logger:
To execute this experiment run (inside your venv):

```bash
python train.py experiment=prediction
python src/train.py experiment=prediction
```

Please see the [Hydra](https://hydra.cc/) and [Hydra-Lightning template](https://github.com/ashleve/lightning-hydra-template) documentation for further examples of how to configure training runs.
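
For instance, individual parameters can also be overridden directly on the command line using standard Hydra syntax (a sketch; the exact keys depend on your configs):

```bash
# Override selected parameters of an experiment run (illustrative values)
python src/train.py experiment=prediction trainer.max_epochs=50 data.batch_size=32
```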
29 changes: 29 additions & 0 deletions configs/experiment/alignment_llm2clip.yaml
@@ -0,0 +1,29 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
- override /model: geoclip_llm2clip_alignment
- override /data: butterfly_coords_text

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["alignment", "geoclip_coords", "llm2clip_text"]

seed: 12345

trainer:
min_epochs: 10
max_epochs: 100

data:
batch_size: 64

logger:
wandb:
tags: ${tags}
group: "alignment"
aim:
experiment: "alignment"
28 changes: 28 additions & 0 deletions configs/model/geoclip_llm2clip_alignment.yaml
@@ -0,0 +1,28 @@
_target_: src.models.text_alignment_model.TextAlignmentModel

eo_encoder:
_target_: src.models.components.eo_encoders.geoclip.GeoClipCoordinateEncoder

text_encoder:
_target_: src.models.components.text_encoders.llm2clip_text_encoder.LLM2CLIPTextEncoder
hf_cache_dir: ${paths.huggingface_cache}
output_normalization: l2

trainable_modules: [text_encoder.projector, loss_fn.log_temp]

optimizer:
_target_: torch.optim.Adam
_partial_: true
lr: 0.001
weight_decay: 0.0

scheduler:
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
_partial_: true
mode: min
factor: 0.1
patience: 10

loss_fn:
_target_: src.models.components.loss_fns.clip_loss.ClipLoss
temperature: 0.07
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -30,6 +30,8 @@ dependencies = [
"torchinfo>=1.8.0",
"transformers==4.57",
"gdown>=5.2.1",
"peft>=0.18.1",
"llm2vec",
]

[project.optional-dependencies]
@@ -66,3 +68,6 @@ exclude_lines = [
"raise NotImplementedError()",
"if __name__ == .__main__.:",
]

[tool.uv.sources]
llm2vec = { git = "https://github.com/gabrieletijunaityte/llm2vec.git", rev = "445a831479748460eddc1e537ab31031cfd1a1e1" }
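
A minimal sketch of how this pin is consumed, assuming the environment is managed through uv's project workflow (the README itself only shows `pip install uv`): uv reads `[tool.uv.sources]` during resolution, so `llm2vec` is installed from the pinned fork rather than from PyPI.

```bash
# Illustrative workflow, assuming uv manages the project environment
pip install uv   # as in the README
uv sync          # resolves llm2vec from the pinned git revision in [tool.uv.sources]
```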
11 changes: 6 additions & 5 deletions src/models/components/eo_encoders/average_encoder.py
@@ -22,6 +22,8 @@ def __init__(
), f"eo_data_name must be one of {list(dict_n_bands_default.keys())}, got {eo_data_name}"
self.eo_data_name = eo_data_name
self.output_normalization = output_normalization
if self.output_normalization not in ["l2", "none"]:
raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")

if output_dim is None or output_dim == dict_n_bands_default[eo_data_name]:
self.output_dim = dict_n_bands_default[eo_data_name]
@@ -56,19 +58,18 @@ def _average_and_project(self, x: torch.Tensor) -> torch.Tensor:
@override
def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
eo_data = batch.get("eo", {})
dtype = self.dtype
if eo_data.dtype != dtype:
eo_data = eo_data.to(dtype)
feats = self.eo_encoder(eo_data[self.eo_data_name])
# n_nans = torch.sum(torch.isnan(feats)).item()
# assert (
# n_nans == 0
# ), f"AverageEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.eo_data_name].min()} and max {eo_data[self.eo_data_name].max()}."
if self.output_normalization == "l2":
feats = F.normalize(feats, p=2, dim=1) # L2 normalization (per feature vector)
elif self.output_normalization == "none":
pass
else:
raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")

return feats
return feats.to(dtype)


if __name__ == "__main__":
14 changes: 14 additions & 0 deletions src/models/components/eo_encoders/base_eo_encoder.py
@@ -15,6 +15,20 @@ def __init__(self) -> None:
def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
pass

@property
def device(self) -> torch.device:
devices = {p.device for p in self.parameters()}
if len(devices) != 1:
raise RuntimeError("EO encoder is on multiple devices")
return devices.pop()

@property
def dtype(self) -> torch.dtype:
dtypes = {p.dtype for p in self.parameters()}
if len(dtypes) != 1:
raise RuntimeError("EO encoder has multiple dtypes")
return dtypes.pop()


if __name__ == "__main__":
_ = BaseEOEncoder(None)
12 changes: 7 additions & 5 deletions src/models/components/eo_encoders/cnn_encoder.py
@@ -44,6 +44,8 @@ def __init__(
), f"input_n_bands must be int >=3, got {self.input_n_bands}"
self.output_dim = output_dim
self.output_normalization = output_normalization
if self.output_normalization not in ["l2", "none"]:
raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")

self.eo_encoder = self.get_backbone()

@@ -142,19 +144,19 @@ def forward(
:return: extracted features
"""
eo_data = batch.get("eo", {})

dtype = self.dtype
if eo_data.dtype != dtype:
eo_data = eo_data.to(dtype)
feats = self.eo_encoder(eo_data[self.eo_data_name])
# n_nans = torch.sum(torch.isnan(feats)).item()
# assert (
# n_nans == 0
# ), f"CNNEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.eo_data_name].min()} and max {eo_data[self.eo_data_name].max()}."
if self.output_normalization == "l2":
feats = F.normalize(feats, p=2, dim=1) # L2 normalization (per feature vector)
elif self.output_normalization == "none":
pass
else:
raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")

return feats
return feats.to(dtype)


if __name__ == "__main__":
16 changes: 11 additions & 5 deletions src/models/components/eo_encoders/geoclip.py
@@ -12,22 +12,28 @@ def __init__(self, output_normalization="l2") -> None:
super().__init__()
self.eo_encoder = LocationEncoder()
self.output_dim = self.eo_encoder.LocEnc0.head[0].out_features

self.output_normalization = output_normalization
if self.output_normalization not in ["l2", "none"]:
raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")

@override
def forward(
self,
batch: Dict[str, torch.Tensor],
) -> torch.Tensor:

coords = batch.get("eo", {}).get("coords")

dtype = self.dtype
if coords.dtype != dtype:
coords = coords.to(dtype)
feats = self.eo_encoder(coords)

if self.output_normalization == "l2":
feats = F.normalize(feats, p=2, dim=-1) # L2 normalization (per feature vector)
elif self.output_normalization == "none":
pass
else:
raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")
return feats

return feats.to(dtype)


if __name__ == "__main__":
19 changes: 18 additions & 1 deletion src/models/components/text_encoders/base_text_encoder.py
@@ -23,5 +23,22 @@ def add_projector(self, projected_dim: int) -> None:

NB: is not used by default, needs to be called explicitly in forward().
"""
self.extra_projector = nn.Linear(self.output_dim, projected_dim)
self.extra_projector = nn.Linear(self.output_dim, projected_dim, dtype=self.dtype)
print(
f"Extra linear projection layer added with mapping dimension {self.output_dim} to {projected_dim}"
)
self.output_dim = projected_dim

@property
def device(self) -> torch.device:
devices = {p.device for p in self.parameters()}
if len(devices) != 1:
raise RuntimeError("Text encoder is on multiple devices")
return devices.pop()

@property
def dtype(self) -> torch.dtype:
dtypes = {p.dtype for p in self.parameters()}
if len(dtypes) != 1:
raise RuntimeError("Text encoder has multiple dtypes")
return dtypes.pop()
7 changes: 3 additions & 4 deletions src/models/components/text_encoders/clip_text_encoder.py
@@ -25,6 +25,9 @@ def __init__(self, hf_cache_dir: str = "../.cache", output_normalization="l2") -

self.projector = GeoCLIP().image_encoder.mlp
self.output_normalization = output_normalization
if self.output_normalization not in ["l2", "none"]:
raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")

self.output_dim = 512

@override
@@ -61,9 +64,5 @@ def forward(self, batch: Dict[str, torch.Tensor], mode: str) -> torch.Tensor:
text_embeds = F.normalize(
text_embeds, p=2, dim=-1
) # L2 normalization (per feature vector)
elif self.output_normalization == "none":
pass
else:
raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")

return text_embeds
Empty file.
71 changes: 71 additions & 0 deletions src/models/components/text_encoders/llm2clip/llama.py
@@ -0,0 +1,71 @@
import importlib.metadata

from packaging import version
from torch import nn
from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaMLP,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)
from transformers.utils import logging
from transformers.utils.import_utils import _is_package_available

logger = logging.get_logger(__name__)


def is_transformers_attn_greater_or_equal_4_56_2():
if not _is_package_available("transformers"):
return False

return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.56.2")


class ModifiedLlamaAttention(LlamaAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False


class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LlamaConfig, layer_idx: int):
GradientCheckpointingLayer.__init__(self)
self.hidden_size = config.hidden_size

self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx)

self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


class LlamaEncoderModel(LlamaModel):
def __init__(self, config):
if not is_transformers_attn_greater_or_equal_4_56_2():
raise ValueError(
"The current implementation of LlamaEncoderModel follows modeling_llama.py of transformers version >= 4.56.2"
)
LlamaPreTrainedModel.__init__(self, config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[
ModifiedLlamaDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self._use_sdpa = config._attn_implementation == "sdpa"
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False

# Initialize weights and apply final processing
self.post_init()
print("Model loaded successfully without Flash Attention!")