diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b8c47b4..5f5a48f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -60,7 +60,7 @@ repos:
        args:
          [
            --verbose,
-           --fail-under=30,
+           --fail-under=10,
            --ignore-init-module,
            --ignore-init-method,
            --ignore-module,
diff --git a/README.md b/README.md
index c75c605..6ef1978 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,7 @@ cd aether
 ```
 
 Then, create a virtual environment (or alternatively via conda):
+
 ```bash
 # Create venv
 python3 -m venv .venv
@@ -39,6 +40,7 @@ source .venv/bin/activate
 ```
 
 Then, install `uv` and use this to install all packages.
+
 ```bash
 # install uv manager
 pip install uv
@@ -59,7 +61,8 @@ Next, create a file in your local repo parent folder `aether/` called `.env` and
 ```bash
 cp .env.example .env
 ```
-Adjust the paths in `.env` to your local system. **At a minimum, you should set PROJECT_ROOT!**.
+
+Adjust the paths in `.env` to your local system. **At a minimum, you should set PROJECT_ROOT!**.
 
 **Important**: `DATA_DIR` should either point to `aether/data/` (default setting) OR if it points to another folder (e.g., `my/local/data/`) then copy the contents of the `aether/data/` folder to `my/local/data/` to ensure the butterfly use case runs using the provided example data. Other data will automatically be downloaded and organised by `pooch` if possible into `DATA_DIR`, or should be copied manually.
 
@@ -85,11 +88,12 @@ Data folders should follow the following directory structure within `DATA_DIR`:
 ### Verify installation:
 
 To verify whether the installation was successful, run the tests in `aether/` using:
+
 ```bash
 pytest --use-mock -m "not slow"
 ```
-which should pass all tests.
+which should pass all tests.
 
 ## Training
@@ -129,7 +133,7 @@ logger:
 To execute this experiment run (inside your venv):
 
 ```bash
-python train.py experiment=prediction
+python src/train.py experiment=prediction
 ```
 
 Please see the [Hydra](https://hydra.cc/) and [Hydra-Lightning template](https://github.com/ashleve/lightning-hydra-template) documentation for further examples of how to configure training runs.
diff --git a/configs/experiment/alignment_llm2clip.yaml b/configs/experiment/alignment_llm2clip.yaml
new file mode 100644
index 0000000..3311990
--- /dev/null
+++ b/configs/experiment/alignment_llm2clip.yaml
@@ -0,0 +1,29 @@
+# @package _global_
+
+# to execute this experiment run:
+# python src/train.py experiment=alignment_llm2clip
+
+defaults:
+  - override /model: geoclip_llm2clip_alignment
+  - override /data: butterfly_coords_text
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["alignment", "geoclip_coords", "llm2clip_text"]
+
+seed: 12345
+
+trainer:
+  min_epochs: 10
+  max_epochs: 100
+
+data:
+  batch_size: 64
+
+logger:
+  wandb:
+    tags: ${tags}
+    group: "alignment"
+  aim:
+    experiment: "alignment"
diff --git a/configs/model/geoclip_llm2clip_alignment.yaml b/configs/model/geoclip_llm2clip_alignment.yaml
new file mode 100644
index 0000000..425aa2f
--- /dev/null
+++ b/configs/model/geoclip_llm2clip_alignment.yaml
@@ -0,0 +1,28 @@
+_target_: src.models.text_alignment_model.TextAlignmentModel
+
+eo_encoder:
+  _target_: src.models.components.eo_encoders.geoclip.GeoClipCoordinateEncoder
+
+text_encoder:
+  _target_: src.models.components.text_encoders.llm2clip_text_encoder.LLM2CLIPTextEncoder
+  hf_cache_dir: ${paths.huggingface_cache}
+  output_normalization: l2
+
+trainable_modules: [text_encoder.projector, loss_fn.log_temp]
+
+optimizer:
+  _target_: torch.optim.Adam
+  _partial_: true
+  lr: 0.001
+  weight_decay: 0.0
+
+scheduler:
+  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+  _partial_: true
+  mode: min
+  factor: 0.1
+  patience: 10
+
+loss_fn:
+  _target_: src.models.components.loss_fns.clip_loss.ClipLoss
+  temperature: 0.07
diff --git a/pyproject.toml b/pyproject.toml
index 9cca0a5..3666d7a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,8 @@ dependencies = [
     "torchinfo>=1.8.0",
     "transformers==4.57",
     "gdown>=5.2.1",
+    "peft>=0.18.1",
+    "llm2vec",
 ]
 
 [project.optional-dependencies]
@@ -66,3 +68,6 @@ exclude_lines = [
     "raise NotImplementedError()",
     "if __name__ == .__main__.:",
 ]
+
+[tool.uv.sources]
+llm2vec = { git = "https://github.com/gabrieletijunaityte/llm2vec.git", rev = "445a831479748460eddc1e537ab31031cfd1a1e1" }
diff --git a/src/models/components/eo_encoders/average_encoder.py b/src/models/components/eo_encoders/average_encoder.py
index 5e5d7dc..6fe2e67 100644
--- a/src/models/components/eo_encoders/average_encoder.py
+++ b/src/models/components/eo_encoders/average_encoder.py
@@ -22,6 +22,8 @@ def __init__(
         ), f"eo_data_name must be one of {list(dict_n_bands_default.keys())}, got {eo_data_name}"
         self.eo_data_name = eo_data_name
         self.output_normalization = output_normalization
+        if self.output_normalization not in ["l2", "none"]:
+            raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")
 
         if output_dim is None or output_dim == dict_n_bands_default[eo_data_name]:
             self.output_dim = dict_n_bands_default[eo_data_name]
@@ -56,6 +58,9 @@ def _average_and_project(self, x: torch.Tensor) -> torch.Tensor:
 
     @override
     def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        eo_data = batch.get("eo", {})
+        dtype = self.dtype
+        if eo_data[self.eo_data_name].dtype != dtype:
+            eo_data[self.eo_data_name] = eo_data[self.eo_data_name].to(dtype)
         feats = self.eo_encoder(eo_data[self.eo_data_name])
         # n_nans = torch.sum(torch.isnan(feats)).item()
         # assert (
         # ), f"AverageEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.eo_data_name].min()} and max {eo_data[self.eo_data_name].max()}."
         if self.output_normalization == "l2":
             feats = F.normalize(feats, p=2, dim=1)  # L2 normalization (per feature vector)
-        elif self.output_normalization == "none":
-            pass
-        else:
-            raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")
 
-        return feats
+        return feats.to(dtype)
 
 
 if __name__ == "__main__":
diff --git a/src/models/components/eo_encoders/base_eo_encoder.py b/src/models/components/eo_encoders/base_eo_encoder.py
index ee07432..4982652 100644
--- a/src/models/components/eo_encoders/base_eo_encoder.py
+++ b/src/models/components/eo_encoders/base_eo_encoder.py
@@ -15,6 +15,20 @@ def __init__(self) -> None:
     def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         pass
 
+    @property
+    def device(self) -> torch.device:
+        devices = {p.device for p in self.parameters()}
+        if len(devices) != 1:
+            raise RuntimeError("EO encoder is on multiple devices")
+        return devices.pop()
+
+    @property
+    def dtype(self) -> torch.dtype:
+        dtypes = {p.dtype for p in self.parameters()}
+        if len(dtypes) != 1:
+            raise RuntimeError("EO encoder has multiple dtypes")
+        return dtypes.pop()
+
 
 if __name__ == "__main__":
     _ = BaseEOEncoder(None)
diff --git a/src/models/components/eo_encoders/cnn_encoder.py b/src/models/components/eo_encoders/cnn_encoder.py
index bb194e7..a61c4c1 100644
--- a/src/models/components/eo_encoders/cnn_encoder.py
+++ b/src/models/components/eo_encoders/cnn_encoder.py
@@ -44,6 +44,8 @@ def __init__(
         ), f"input_n_bands must be int >=3, got {self.input_n_bands}"
         self.output_dim = output_dim
         self.output_normalization = output_normalization
+        if self.output_normalization not in ["l2", "none"]:
+            raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")
 
         self.eo_encoder = self.get_backbone()
 
@@ -142,6 +144,10 @@ def forward(
         :return: extracted features
         """
         eo_data = batch.get("eo", {})
+
+        dtype = self.dtype
+        if eo_data[self.eo_data_name].dtype != dtype:
+            eo_data[self.eo_data_name] = eo_data[self.eo_data_name].to(dtype)
         feats = self.eo_encoder(eo_data[self.eo_data_name])
         # n_nans = torch.sum(torch.isnan(feats)).item()
         # assert (
         # ), f"CNNEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.eo_data_name].min()} and max {eo_data[self.eo_data_name].max()}."
         if self.output_normalization == "l2":
             feats = F.normalize(feats, p=2, dim=1)  # L2 normalization (per feature vector)
-        elif self.output_normalization == "none":
-            pass
-        else:
-            raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")
 
-        return feats
+        return feats.to(dtype)
 
 
 if __name__ == "__main__":
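A note for reviewers on how the two new configs compose: Hydra layers the experiment file on top of the defaults, then builds the model recursively from the `_target_` entries in `geoclip_llm2clip_alignment.yaml`. A minimal sketch of that composition; the `train` config name and `configs` path follow the lightning-hydra-template conventions referenced in the README and are assumptions, not part of this diff:

```python
# Sketch only - "train" / "configs" are assumed template defaults, not defined in this PR.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(version_base="1.3", config_path="configs"):
    cfg = compose(config_name="train", overrides=["experiment=alignment_llm2clip"])

# _partial_: true entries (optimizer, scheduler) resolve to partial callables,
# so the LightningModule can bind its own parameters later.
model = instantiate(cfg.model)
```

Note that instantiating the model this way already triggers the Hugging Face downloads for the LLM2CLIP checkpoints; in practice the config is exercised through `src/train.py`.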
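The `device`/`dtype` properties added to `BaseEOEncoder` (and mirrored on the text-encoder base class below) deliberately fail loudly instead of guessing when submodules disagree. A toy illustration of that guard, not project code:

```python
# Toy module (not project code) showing why the dtype property raises on mixed precision:
# if one submodule is kept in bfloat16 while another stays float32, there is no single
# dtype to cast inputs to, so the property refuses to pick one silently.
import torch
from torch import nn


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(4, 4)                     # float32 parameters
        self.b = nn.Linear(4, 4).to(torch.bfloat16)  # bfloat16 parameters

    @property
    def dtype(self) -> torch.dtype:
        dtypes = {p.dtype for p in self.parameters()}
        if len(dtypes) != 1:
            raise RuntimeError("module has multiple dtypes")
        return dtypes.pop()


try:
    Toy().dtype
except RuntimeError as err:
    print(err)  # "module has multiple dtypes"
```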
diff --git a/src/models/components/eo_encoders/geoclip.py b/src/models/components/eo_encoders/geoclip.py
index 5f8df8c..16936d1 100644
--- a/src/models/components/eo_encoders/geoclip.py
+++ b/src/models/components/eo_encoders/geoclip.py
@@ -12,22 +12,28 @@ def __init__(self, output_normalization="l2") -> None:
         super().__init__()
         self.eo_encoder = LocationEncoder()
         self.output_dim = self.eo_encoder.LocEnc0.head[0].out_features
+        self.output_normalization = output_normalization
+        if self.output_normalization not in ["l2", "none"]:
+            raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")
 
     @override
     def forward(
         self,
         batch: Dict[str, torch.Tensor],
     ) -> torch.Tensor:
+        coords = batch.get("eo", {}).get("coords")
+
+        dtype = self.dtype
+        if coords.dtype != dtype:
+            coords = coords.to(dtype)
         feats = self.eo_encoder(coords)
+
         if self.output_normalization == "l2":
             feats = F.normalize(feats, p=2, dim=-1)  # L2 normalization (per feature vector)
-        elif self.output_normalization == "none":
-            pass
-        else:
-            raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")
 
-        return feats
+
+        return feats.to(dtype)
 
 
 if __name__ == "__main__":
diff --git a/src/models/components/text_encoders/base_text_encoder.py b/src/models/components/text_encoders/base_text_encoder.py
index 9ac5b0d..d934aab 100644
--- a/src/models/components/text_encoders/base_text_encoder.py
+++ b/src/models/components/text_encoders/base_text_encoder.py
@@ -23,5 +23,22 @@ def add_projector(self, projected_dim: int) -> None:
 
         NB: is not used by default, needs to be called explicitly in forward().
""" - self.extra_projector = nn.Linear(self.output_dim, projected_dim) + self.extra_projector = nn.Linear(self.output_dim, projected_dim, dtype=self.dtype) + print( + f"Extra linear projection layer added with mapping dimension {self.output_dim} to {projected_dim}" + ) self.output_dim = projected_dim + + @property + def device(self) -> torch.device: + devices = {p.device for p in self.parameters()} + if len(devices) != 1: + raise RuntimeError("Text encoder is on multiple devices") + return devices.pop() + + @property + def dtype(self) -> torch.dtype: + dtypes = {p.dtype for p in self.parameters()} + if len(dtypes) != 1: + raise RuntimeError("Text encoder has multiple dtypes") + return dtypes.pop() diff --git a/src/models/components/text_encoders/clip_text_encoder.py b/src/models/components/text_encoders/clip_text_encoder.py index 5a0ac5b..46801b8 100644 --- a/src/models/components/text_encoders/clip_text_encoder.py +++ b/src/models/components/text_encoders/clip_text_encoder.py @@ -25,6 +25,9 @@ def __init__(self, hf_cache_dir: str = "../.cache", output_normalization="l2") - self.projector = GeoCLIP().image_encoder.mlp self.output_normalization = output_normalization + if self.output_normalization not in ["l2", "none"]: + raise ValueError(f"Unsupported output_normalization: {self.output_normalization}") + self.output_dim = 512 @override @@ -61,9 +64,5 @@ def forward(self, batch: Dict[str, torch.Tensor], mode: str) -> torch.Tensor: text_embeds = F.normalize( text_embeds, p=2, dim=-1 ) # L2 normalization (per feature vector) - elif self.output_normalization == "none": - pass - else: - raise ValueError(f"Unsupported output_normalization: {self.output_normalization}") return text_embeds diff --git a/src/models/components/text_encoders/llm2clip/__init__.py b/src/models/components/text_encoders/llm2clip/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/components/text_encoders/llm2clip/llama.py b/src/models/components/text_encoders/llm2clip/llama.py new file mode 100644 index 0000000..0f98b1d --- /dev/null +++ b/src/models/components/text_encoders/llm2clip/llama.py @@ -0,0 +1,71 @@ +import importlib.metadata + +from packaging import version +from torch import nn +from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from transformers.utils import logging +from transformers.utils.import_utils import _is_package_available + +logger = logging.get_logger(__name__) + + +def is_transformers_attn_greater_or_equal_4_56_2(): + if not _is_package_available("transformers"): + return False + + return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.56.2") + + +class ModifiedLlamaAttention(LlamaAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + + +class ModifiedLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig, layer_idx: int): + GradientCheckpointingLayer.__init__(self) + self.hidden_size = config.hidden_size + + self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class 
diff --git a/src/models/components/text_encoders/llm2clip_text_encoder.py b/src/models/components/text_encoders/llm2clip_text_encoder.py
new file mode 100644
index 0000000..dc06401
--- /dev/null
+++ b/src/models/components/text_encoders/llm2clip_text_encoder.py
@@ -0,0 +1,107 @@
+from typing import Dict, override
+
+import torch
+from llm2vec import LLM2Vec
+from torch.nn import functional as F
+from transformers import AutoConfig, AutoModel, AutoTokenizer
+
+from src.models.components.text_encoders.base_text_encoder import BaseTextEncoder
+from src.models.components.text_encoders.llm2clip.llama import LlamaEncoderModel
+
+
+class LLM2CLIPTextEncoder(BaseTextEncoder):
+    def __init__(self, hf_cache_dir: str = "../.cache", output_normalization="l2") -> None:
+        """LLM2CLIP text encoder implementation. Uses LLM2CLIP-Llama-3-8B-Instruct-CC-Finetuned as
+        the LLM and the LLM2CLIP-trained adapter.
+
+        :param hf_cache_dir: huggingface cache directory
+        :param output_normalization: output normalization type
+        """
+        super().__init__()
+
+        # Adapter and image encoder
+        self.projector = AutoModel.from_pretrained(
+            "microsoft/LLM2CLIP-Openai-L-14-224",
+            trust_remote_code=True,
+            dtype=torch.bfloat16,
+            revision="50ed31c5248d8ff124893719e37829d59376be81",  # pin revision for full reproducibility
+            cache_dir=hf_cache_dir,
+        ).eval()
+
+        # TODO: if we want to reuse the vision part, this is the place to change it
+        self.projector.vision_model = None
+        self.projector.visual_projection = None
+
+        # The LLM sentence encoder
+        llm_model_name = "microsoft/LLM2CLIP-Llama-3-8B-Instruct-CC-Finetuned"
+        config = AutoConfig.from_pretrained(
+            llm_model_name,
+            trust_remote_code=True,
+            cache_dir=hf_cache_dir,
+        )
+        config._attn_implementation = "eager"
+
+        llm_model = LlamaEncoderModel.from_pretrained(
+            llm_model_name,
+            config=config,
+            dtype=torch.bfloat16,
+            trust_remote_code=False,  # local code
+            cache_dir=hf_cache_dir,
+        )
+        llm_model.config._name_or_path = (
+            "meta-llama/Meta-Llama-3-8B-Instruct"  # Workaround for LLM2VEC
+        )
+        self.processor = AutoTokenizer.from_pretrained(llm_model_name)
+
+        # Caption to vector with the llama LLM
+        self.model = LLM2Vec(
+            llm_model, self.processor, pooling_mode="mean", max_length=512, doc_max_length=512
+        )
+
+        self.output_normalization = output_normalization
+        if self.output_normalization not in ["l2", "none"]:
+            raise ValueError(f"Unsupported output_normalization: {self.output_normalization}")
+
+        self.output_dim = 1280
+
+    @override
+    def forward(self, batch: Dict[str, torch.Tensor], mode: str) -> torch.Tensor:
+        """Forward pass through text encoder."""
+        # Get text inputs
+        text_input = batch.get("text")
+
+        if mode == "train":
+            text_input = [text_input]
+
+        # Embed text and, if not in train mode, average over all templates
+        avr_embeds = []
+        for captions_per_row in text_input:
+            # LLM is frozen, no gradients needed
+            with torch.no_grad():
+                # Embed
+                text_embeds = self.model.encode(
+                    captions_per_row, convert_to_tensor=True, device=self.device
+                )
+
+            # Change dtype
+            text_embeds = text_embeds.to(
+                dtype=self.projector.dtype, device=self.projector.device
+            )
+
+            # Project to align with ViT in LLM2CLIP
+            text_embeds = self.projector.get_text_features(text_embeds)
+
+            if self.extra_projector is not None:
+                text_embeds = self.extra_projector(text_embeds)
+
+            if mode != "train":
+                avr_embeds.append(text_embeds.mean(dim=0))
+
+        if mode != "train":
+            text_embeds = torch.stack(avr_embeds, dim=0)
+
+        if self.output_normalization == "l2":
+            text_embeds = F.normalize(
+                text_embeds, p=2, dim=-1
+            )  # L2 normalization (per feature vector)
+
+        return text_embeds
diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py
index d1882b1..baef47e 100644
--- a/src/models/text_alignment_model.py
+++ b/src/models/text_alignment_model.py
@@ -66,15 +66,33 @@ def __init__(
             )
             self.prediction_head.configure_nn()
 
+        # Unify dtypes
+        if self.eo_encoder.dtype != self.text_encoder.dtype:
+            self.eo_encoder = self.eo_encoder.to(self.text_encoder.dtype)
+            print(f"EO encoder dtype changed to {self.eo_encoder.dtype}")
+
         # Freezing requested parts
         self.freezer()
 
+        # Normalisation status for cosine similarity
+        if (self.text_encoder.output_normalization == "l2") != (
+            self.eo_encoder.output_normalization == "l2"
+        ):
+            # TODO think of how to make this consistent
+            raise ValueError("Only one modality is normalised")
+
+        self.normalised = self.text_encoder.output_normalization == "l2"
+
     @override
     def forward(
         self,
         batch: Dict[str, torch.Tensor],
         mode: str = "train",
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Model forward logic."""
         # Embed modalities
         eo_feats = self.eo_encoder(batch)
 
@@ -83,6 +101,8 @@ def forward(
 
     @override
     def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Tensor:
+        """Model step logic."""
+
         # Embed
         eo_feats, text_feats = self.forward(batch, mode)
         local_batch_size = eo_feats.size(0)
@@ -125,7 +145,13 @@ def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Tensor:
     def _cos_sim_calc(self, eo_feats, text_feats, mode, log=True):
         """Calculate cosine similarity between eo and text embeddings and logs it."""
         # Similarity matrix
-        cos_sim_matrix = F.cosine_similarity(eo_feats[:, None, :], text_feats[None, :, :], dim=-1)
+        if self.normalised:
+            cos_sim_matrix = eo_feats @ text_feats.T
+        else:
+            cos_sim_matrix = F.cosine_similarity(
+                eo_feats[:, None, :], text_feats[None, :, :], dim=-1
+            )
+
         local_batch_size = eo_feats.size(0)
 
         # Average for positive and negative pairs
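For orientation, this is roughly how the new encoder is exercised by the alignment model below; the `"text"` batch key mirrors the convention in `forward()` above. Illustrative only: constructing the encoder downloads the 8B LLM and the LLM2CLIP adapter, and the captions are made-up examples.

```python
# Illustrative usage (not a test): captions in, (batch, 1280) embeddings out.
from src.models.components.text_encoders.llm2clip_text_encoder import LLM2CLIPTextEncoder

encoder = LLM2CLIPTextEncoder(hf_cache_dir=".cache")  # heavy: pulls the 8B checkpoint
batch = {"text": ["A meadow with butterflies", "Dense tropical forest"]}

embeds = encoder(batch, mode="train")  # shape (2, 1280); L2-normalised by default
print(embeds.shape, embeds.norm(dim=-1))
```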
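The `self.normalised` fast path in `_cos_sim_calc` relies on a standard identity: for rows the encoders have already L2-normalised, the dot-product matrix equals the pairwise cosine-similarity matrix. A quick standalone check with toy tensors, not project code:

```python
# Toy check of the shortcut: for L2-normalised rows, the matrix product equals
# the cosine-similarity matrix computed by broadcasting.
import torch
import torch.nn.functional as F

eo = F.normalize(torch.randn(4, 8), p=2, dim=1)
txt = F.normalize(torch.randn(4, 8), p=2, dim=1)

dense = F.cosine_similarity(eo[:, None, :], txt[None, :, :], dim=-1)
assert torch.allclose(eo @ txt.T, dense, atol=1e-6)
```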
"l2" + @override def forward( self, batch: Dict[str, torch.Tensor], mode: str = "train", ) -> Tuple[torch.Tensor, torch.Tensor]: + """Model forward logic.""" # Embed modalities eo_feats = self.eo_encoder(batch) @@ -83,6 +101,8 @@ def forward( @override def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Tensor: + """Model step logic.""" + # Embed eo_feats, text_feats = self.forward(batch, mode) local_batch_size = eo_feats.size(0) @@ -125,7 +145,13 @@ def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Te def _cos_sim_calc(self, eo_feats, text_feats, mode, log=True): """Calculate cosine similarity between eo and text embeddings and logs it.""" # Similarity matrix - cos_sim_matrix = F.cosine_similarity(eo_feats[:, None, :], text_feats[None, :, :], dim=-1) + if self.normalised: + cos_sim_matrix = eo_feats @ text_feats.T + else: + cos_sim_matrix = F.cosine_similarity( + eo_feats[:, None, :], text_feats[None, :, :], dim=-1 + ) + local_batch_size = eo_feats.size(0) # Average for positive and negative pairs diff --git a/tests/test_text_encoders.py b/tests/test_text_encoders.py index fb098ff..0d09714 100644 --- a/tests/test_text_encoders.py +++ b/tests/test_text_encoders.py @@ -7,12 +7,16 @@ from src.models.components.text_encoders.base_text_encoder import BaseTextEncoder from src.models.components.text_encoders.clip_text_encoder import ClipTextEncoder +from src.models.components.text_encoders.llm2clip_text_encoder import ( + LLM2CLIPTextEncoder, +) -# @pytest.mark.slow +# Initialisation of text encoders involve downloading the large models +@pytest.mark.slow def test_text_encoder_generic_properties(create_butterfly_dataset): """This test checks that all text encoders implement the basic properties and methods.""" - list_text_encoders = [ClipTextEncoder] + list_text_encoders = [ClipTextEncoder, LLM2CLIPTextEncoder] ds, dm = create_butterfly_dataset batch = next(iter(dm.train_dataloader())) text_input = batch.get("text")