Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,24 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.8', '3.9', '3.10', '3.11' ]
python-version: [
'3.9',
'3.10',
'3.11',
]
max-parallel: 4

steps:
- uses: actions/checkout@v4
- name: Install uv and set Python to ${{ matrix.python-version }}
uses: astral-sh/setup-uv@v6
with:
version: "0.7.20"
version: "0.8.10"
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
uv sync --group dev --group docs
uv sync --extra dev --extra docs --extra vllm
uv run python -m ensurepip
- name: Check types
run: |
uv run mypy app
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ jobs:
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
version: "0.6.10"
version: "0.8.10"
python-version: "3.10"
- name: Install dependencies
run: |
uv sync --group dev --group docs --group vllm
uv sync --extra dev --extra docs --extra vllm
- name: Run unit tests
run: |
uv run pytest -v tests/app --cov --cov-report=html:coverage_reports #--random-order
Expand Down
1 change: 0 additions & 1 deletion app/api/auth/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,4 @@ async def get_user_db(session: AsyncSession = Depends(_get_async_session)) -> As
SQLAlchemyUserDatabase: A database instance initialised with the given session and the User model.
"""

# TODO: fix this type checking error
yield SQLAlchemyUserDatabase(session, User)
1 change: 1 addition & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Settings(BaseSettings): # type: ignore
TRAINING_CACHE_DIR: str = os.path.join(os.path.abspath(os.path.dirname(__file__)), "cms_cache") # the directory to cache the intermediate files created during training
HF_PIPELINE_AGGREGATION_STRATEGY: str = "simple" # the strategy used for aggregating the predictions of the Hugging Face NER model
LOG_PER_CONCEPT_ACCURACIES: str = "false" # if "true", per-concept accuracies will be exposed to the metrics scraper. Switch this on with caution due to the potentially high number of concepts
MEDCAT2_MAPPED_ONTOLOGIES: str = "" # the comma-separated names of ontologies for MedCAT2 to map to
DEBUG: str = "false" # if "true", the debug mode is switched on

class Config:
Expand Down
3 changes: 3 additions & 0 deletions app/envs/.env
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,8 @@ TRAINING_SAFE_MODEL_SERIALISATION=false
# The strategy used for aggregating the predictions of the Hugging Face NER model
HF_PIPELINE_AGGREGATION_STRATEGY=simple

# The comma-separated names of ontologies for MedCAT2 to map to
MEDCAT2_MAPPED_ONTOLOGIES=opcs4,icd10

# If "true", the debug mode is switched on
DEBUG=false
4 changes: 3 additions & 1 deletion app/management/tracker_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import os
import socket
import mlflow
Expand Down Expand Up @@ -114,7 +115,7 @@ def send_model_stats(stats: Dict, step: int) -> None:
step (int): The current step in the training or evaluation process.
"""

metrics = {key.replace(" ", "_").lower(): val for key, val in stats.items()}
metrics = {key.replace(" ", "_").lower(): val for key, val in stats.items() if isinstance(val, (int, float))}
mlflow.log_metrics(metrics, step)

@staticmethod
Expand Down Expand Up @@ -563,6 +564,7 @@ def get_metrics_by_job_id(self, job_id: str) -> List[Dict[str, Any]]:
metrics_history = {}
for metric in run.data.metrics.keys():
metrics_history[metric] = [m.value for m in self.mlflow_client.get_metric_history(run_id=run.info.run_id, key=metric)]
metrics_history["concepts"] = ast.literal_eval(run.data.tags.get("training.entity.classes", "[]"))
metrics.append(metrics_history)
return metrics
except MlflowException as e:
Expand Down
58 changes: 31 additions & 27 deletions app/model_services/medcat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd

from multiprocessing import cpu_count
from typing import Dict, List, Optional, TextIO, Tuple, Any
from typing import Dict, List, Optional, TextIO, Tuple, Any, Set
from medcat.cat import CAT
from app import __version__ as app_version
from app.model_services.base import AbstractModelService
Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(
base_model_file (Optional[str]): The model package file name. Defaults to None.
"""
super().__init__(config)
self._model: CAT = None
self._model: Optional[CAT] = None
self._config = config
self._model_parent_dir = model_parent_dir or os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "model"))
self._model_pack_path = os.path.join(self._model_parent_dir, base_model_file or config.BASE_MODEL_FILE)
Expand All @@ -55,7 +55,7 @@ def __init__(
self.model_name = model_name or "MedCAT model"

@property
def model(self) -> CAT:
def model(self) -> Optional[CAT]:
"""Getter for the MedCAT model."""

return self._model
Expand Down Expand Up @@ -113,7 +113,7 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) ->

model_path = os.path.join(os.path.dirname(model_file_path), get_model_data_package_base_name(model_file_path))
if unpack_model_data_package(model_file_path, model_path):
cat = CAT.load_model_pack(model_file_path.replace(".tar.gz", ".zip"), *args, **kwargs)
cat = CAT.load_model_pack(model_file_path.replace(".tar.gz", ".zip"), **kwargs)
logger.info("Model package loaded from %s", os.path.normpath(model_file_path))
return cat
else:
Expand All @@ -131,18 +131,20 @@ def init_model(self, *args: Any, **kwargs: Any) -> None:
logger.warning("Model service is already initialised and can be initialised only once")
else:
if non_default_device_is_available(get_settings().DEVICE):
self._model = self.load_model(
self._model_pack_path,
meta_cat_config_dict={"general": {"device": get_settings().DEVICE}},
)
self._model.config.general["device"] = get_settings().DEVICE
self._model = self.load_model(self._model_pack_path)
for addon in self._model.get_addons():
addon.config.general.device = get_settings().DEVICE # type: ignore
self._model.config.general.device = get_settings().DEVICE # type: ignore
else:
self._model = self.load_model(self._model_pack_path)
self._set_tuis_filtering()
if self._enable_trainer:
self._supervised_trainer = MedcatSupervisedTrainer(self)
self._unsupervised_trainer = MedcatUnsupervisedTrainer(self)
self._metacat_trainer = MetacatTrainer(self)
self._model.config.general.map_to_other_ontologies = [ # type: ignore # await new MedCAT release
tui.strip() for tui in self._config.MEDCAT2_MAPPED_ONTOLOGIES.split(",")
]

def info(self) -> ModelCard:
"""
Expand All @@ -168,11 +170,9 @@ def annotate(self, text: str) -> List[Annotation]:
List[Annotation]: A list of annotations containing the extracted named entities.
"""

doc = self.model.get_entities(
text,
addl_info=["cui2icd10", "cui2opcs4", "cui2ontologies", "cui2snomed", "cui2athena_ids"],
)
return [load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(doc)]
assert self.model is not None, "Model is not initialised"
doc = self.model.get_entities(text)
return [load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(doc)] # type: ignore

def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:
"""
Expand All @@ -187,17 +187,17 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:

batch_size_chars = 500000

docs = self.model.multiprocessing_batch_char_size(
self._data_iterator(texts),
assert self.model is not None, "Model is not initialised"
docs = {i: result for i, (_, result) in enumerate(self.model.get_entities_multi_texts(
texts,
batch_size_chars=batch_size_chars,
nproc=max(int(cpu_count() / 2), 1),
addl_info=["cui2icd10", "cui2opcs4", "cui2ontologies", "cui2snomed", "cui2athena_ids"],
)
n_process=max(int(cpu_count() / 2), 1),
))}
docs = dict(sorted(docs.items(), key=lambda x: x[0]))
annotations_list = []
for _, doc in docs.items():
annotations_list.append([
load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(doc)
load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(doc) # type: ignore
])
return annotations_list

Expand Down Expand Up @@ -362,9 +362,9 @@ def get_records_from_doc(self, doc: Dict) -> List[Dict]:
if "athena_ids" in row and row["athena_ids"]:
df.loc[idx, "athena_ids"] = [athena_id["code"] for athena_id in row["athena_ids"]]
if self._config.INCLUDE_SPAN_TEXT == "true":
df.rename(columns={"pretty_name": "label_name", "cui": "label_id", "source_value": "text", "types": "categories", "acc": "accuracy", "athena_ids": "athena_ids"}, inplace=True)
df.rename(columns={"pretty_name": "label_name", "cui": "label_id", "source_value": "text", "type_ids": "categories", "acc": "accuracy", "athena_ids": "athena_ids"}, inplace=True)
else:
df.rename(columns={"pretty_name": "label_name", "cui": "label_id", "types": "categories", "acc": "accuracy", "athena_ids": "athena_ids"}, inplace=True)
df.rename(columns={"pretty_name": "label_name", "cui": "label_id", "type_ids": "categories", "acc": "accuracy", "athena_ids": "athena_ids"}, inplace=True)
df = self._retrieve_meta_annotations(df)
records = df.to_dict("records")
return records
Expand All @@ -384,15 +384,19 @@ def _retrieve_meta_annotations(df: pd.DataFrame) -> pd.DataFrame:

def _set_tuis_filtering(self) -> None:
# this patching may not be needed after the base 1.4.x model is fixed in the future
assert self._model is not None, "Model is not initialised"
if self._model.cdb.addl_info.get("type_id2name", {}) == {}:
self._model.cdb.addl_info["type_id2name"] = TYPE_ID_TO_NAME_PATCH

tuis2cuis = self._model.cdb.addl_info.get("type_id2cuis")
model_tuis = set(tuis2cuis.keys())
type_id2info = self._model.cdb.type_id2info
model_tuis = set(type_id2info.keys())
if self._whitelisted_tuis == {""}:
return
assert self._whitelisted_tuis.issubset(model_tuis), f"Unrecognisable Type Unique Identifier(s): {self._whitelisted_tuis - model_tuis}"
whitelisted_cuis = set()
whitelisted_cuis: Set = set()
for tui in self._whitelisted_tuis:
whitelisted_cuis.update(tuis2cuis.get(tui, {}))
self._model.cdb.config.linking.filters = {"cuis": whitelisted_cuis}
type_info = type_id2info.get(tui)
if type_info is None:
continue
whitelisted_cuis.update(type_info.cuis)
self._model.config.components.linking.filters.cuis = whitelisted_cuis
45 changes: 25 additions & 20 deletions app/model_services/medcat_model_deid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import partial
from transformers import pipeline
from medcat.cat import CAT
from medcat.components.types import CoreComponentType
from app import __version__ as app_version
from app.config import Settings
from app.model_services.medcat_model import MedCATModel
Expand Down Expand Up @@ -62,14 +63,15 @@ def info(self) -> ModelCard:
ModelCard: A card containing information about the MedCAT De-Identification (AnonCAT) model.
"""

assert self.model is not None, "Model is not initialised"
model_card = self.model.get_model_card(as_dict=True)
model_card["Basic CDB Stats"]["Average training examples per concept"] = 0
return ModelCard(
model_description=self.model_name,
model_type=ModelType.ANONCAT,
api_version=self.api_version,
model_card=model_card,
labels=self.model.cdb.cui2preferred_name,
model_card=dict(model_card),
labels={cui: info['preferred_name'] for cui, info in self.model.cdb.cui2info.items()},
)

def annotate(self, text: str) -> List[Annotation]:
Expand All @@ -83,12 +85,13 @@ def annotate(self, text: str) -> List[Annotation]:
List[Annotation]: A list of annotations containing the extracted PII entities.
"""

assert self.model is not None, "Model is not initialised"
doc = self.model.get_entities(text)
if doc["entities"]:
for _, entity in doc["entities"].items():
entity["types"] = ["PII"]
entity["type_ids"] = ["PII"]

records = self.get_records_from_doc({"entities": doc["entities"]})
records = self.get_records_from_doc({"entities": doc["entities"]}) # type: ignore
return [load_pydantic_object_from_dict(Annotation, record) for record in records]

def annotate_with_local_chunking(self, text: str) -> List[Annotation]:
Expand All @@ -102,7 +105,8 @@ def annotate_with_local_chunking(self, text: str) -> List[Annotation]:
List[Annotation]: A list of annotations containing the extracted PII entities.
"""

tokenizer = self.model._addl_ner[0].tokenizer.hf_tokenizer
assert self.model is not None, "Model is not initialised"
tokenizer = self.model.pipe.get_component(CoreComponentType.ner)._component.tokenizer.hf_tokenizer # type: ignore
leading_ws_len = len(text) - len(text.lstrip())
text = text.lstrip()
tokenized = self._with_lock(tokenizer, text, return_offsets_mapping=True, add_special_tokens=False)
Expand Down Expand Up @@ -134,7 +138,7 @@ def annotate_with_local_chunking(self, text: str) -> List[Annotation]:
for entity in doc["entities"].values():
entity["start"] += processed_char_len
entity["end"] += processed_char_len
entity["types"] = ["PII"]
entity["type_ids"] = ["PII"]
aggregated_entities[ent_key] = entity
ent_key += 1
processed_char_len = chunk[:window_overlap_start_idx][-1][1][1] + leading_ws_len + 1
Expand All @@ -146,7 +150,7 @@ def annotate_with_local_chunking(self, text: str) -> List[Annotation]:
for entity in doc["entities"].values():
entity["start"] += processed_char_len
entity["end"] += processed_char_len
entity["types"] = ["PII"]
entity["type_ids"] = ["PII"]
aggregated_entities[ent_key] = entity
ent_key += 1
processed_char_len += len(c_text)
Expand All @@ -168,12 +172,13 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:
"""

annotations_list = []
assert self.model is not None, "Model is not initialised"
entities_list = self.model.get_entities_multi_texts(texts)
for entities in entities_list:
for _, entities in entities_list:
for _, entity in entities["entities"].items():
entity["types"] = ["PII"]
entity["type_ids"] = ["PII"] # type: ignore
annotations_list.append([
load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(entities)
load_pydantic_object_from_dict(Annotation, record) for record in self.get_records_from_doc(entities) # type: ignore
])

return annotations_list
Expand All @@ -190,26 +195,26 @@ def init_model(self, *args: Any, **kwargs: Any) -> None:
logger.warning("Model service is already initialised and can be initialised only once")
else:
self._model = self.load_model(self._model_pack_path)
self._model._addl_ner[0].tokenizer.hf_tokenizer._in_target_context_manager = getattr(self._model._addl_ner[0].tokenizer.hf_tokenizer, "_in_target_context_manager", False)
self._model._addl_ner[0].tokenizer.hf_tokenizer.clean_up_tokenization_spaces = getattr(self._model._addl_ner[0].tokenizer.hf_tokenizer, "clean_up_tokenization_spaces", None)
self._model._addl_ner[0].tokenizer.hf_tokenizer.split_special_tokens = getattr(self._model._addl_ner[0].tokenizer.hf_tokenizer, "split_special_tokens", False)
ner = self._model.pipe.get_component(CoreComponentType.ner)._component # type: ignore
ner.tokenizer.hf_tokenizer._in_target_context_manager = getattr(ner.tokenizer.hf_tokenizer, "_in_target_context_manager", False)
ner.tokenizer.hf_tokenizer.clean_up_tokenization_spaces = getattr(ner.tokenizer.hf_tokenizer, "clean_up_tokenization_spaces", None)
ner.tokenizer.hf_tokenizer.split_special_tokens = getattr(ner.tokenizer.hf_tokenizer, "split_special_tokens", False)
if non_default_device_is_available(self._config.DEVICE):
self._model.config.general["device"] = self._config.DEVICE
self._model._addl_ner[0].model.to(torch.device(self._config.DEVICE))
self._model._addl_ner[0].ner_pipe = pipeline(
model=self._model._addl_ner[0].model,
ner.model.to(torch.device(self._config.DEVICE))
ner.ner_pipe = pipeline(
model=ner.model,
framework="pt",
task="ner",
tokenizer=self._model._addl_ner[0].tokenizer.hf_tokenizer,
tokenizer=ner.tokenizer.hf_tokenizer,
device=get_hf_pipeline_device_id(self._config.DEVICE),
aggregation_strategy=self._config.HF_PIPELINE_AGGREGATION_STRATEGY,
)
else:
if self._config.DEVICE != "default":
logger.warning("DEVICE is set to '%s' but it is not available. Using 'default' instead.", self._config.DEVICE)
_save_pretrained = self._model._addl_ner[0].model.save_pretrained
_save_pretrained = ner.model.save_pretrained
if ("safe_serialization" in inspect.signature(_save_pretrained).parameters):
self._model._addl_ner[0].model.save_pretrained = partial(_save_pretrained, safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"))
ner.model.save_pretrained = partial(_save_pretrained, safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"))
if self._enable_trainer:
self._supervised_trainer = MedcatDeIdentificationSupervisedTrainer(self)

Expand Down
Loading