Skip to content

Commit 5acddc8

Browse files
authored
feat: Support MedCAT v2 (#25)
* feat: support MedCAT v2 * feat: make MedCAT V2 ontology mappings configurable * fix: add workaround for metrics collection for supervised training * feat: add concept ids to evaluation results and deprecate py38 support * chore: update individual local dev containers * chore: improve type hints * feat: support python 3.12
1 parent 64c8bcc commit 5acddc8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+4812
-7928
lines changed

.github/workflows/main.yaml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,25 @@ jobs:
1212
runs-on: ubuntu-latest
1313
strategy:
1414
matrix:
15-
python-version: [ '3.8', '3.9', '3.10', '3.11' ]
15+
python-version: [
16+
'3.9',
17+
'3.10',
18+
'3.11',
19+
'3.12',
20+
]
1621
max-parallel: 4
1722

1823
steps:
1924
- uses: actions/checkout@v4
2025
- name: Install uv and set Python to ${{ matrix.python-version }}
2126
uses: astral-sh/setup-uv@v6
2227
with:
23-
version: "0.7.20"
28+
version: "0.8.10"
2429
python-version: ${{ matrix.python-version }}
2530
- name: Install dependencies
2631
run: |
27-
uv sync --group dev --group docs
32+
uv sync --extra dev --extra docs --extra vllm
33+
uv run python -m ensurepip
2834
- name: Check types
2935
run: |
3036
uv run mypy app

.github/workflows/release.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ jobs:
1717
- name: Install uv
1818
uses: astral-sh/setup-uv@v5
1919
with:
20-
version: "0.6.10"
20+
version: "0.8.10"
2121
python-version: "3.10"
2222
- name: Install dependencies
2323
run: |
24-
uv sync --group dev --group docs --group vllm
24+
uv sync --extra dev --extra docs --extra vllm
2525
- name: Run unit tests
2626
run: |
2727
uv run pytest -v tests/app --cov --cov-report=html:coverage_reports #--random-order

app/api/api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,9 @@ def _get_app(
211211
tags = TagsStreamable
212212
else:
213213
tags = Tags
214-
tags_metadata = [{ # type: ignore
215-
"name": tag.name, # type: ignore
216-
"description": tag.value # type: ignore
214+
tags_metadata = [{
215+
"name": tag.name,
216+
"description": tag.value
217217
} for tag in tags]
218218
app = FastAPI(
219219
title="CogStack ModelServe",

app/api/auth/db.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,4 @@ async def get_user_db(session: AsyncSession = Depends(_get_async_session)) -> As
5252
SQLAlchemyUserDatabase: A database instance initialised with the given session and the User model.
5353
"""
5454

55-
# TODO: fix this type checking error
5655
yield SQLAlchemyUserDatabase(session, User)

app/api/routers/generative.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene
217217
yield f"data: {json.dumps(data)}\n\n"
218218
yield "data: [DONE]\n\n"
219219

220-
prompt = get_prompt_from_messages(model_service.tokenizer, messages) # type: ignore
220+
assert hasattr(model_service, "tokenizer"), "Model service doesn't have a tokenizer"
221+
prompt = get_prompt_from_messages(model_service.tokenizer, messages)
221222
if stream:
222223
return StreamingResponse(
223224
_stream(prompt, max_tokens, temperature),

app/api/routers/stream.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def __init__(
137137
self,
138138
content: Any,
139139
status_code: int = 200,
140-
max_chunk_size: Optional[int] = 1024,
140+
max_chunk_size: int = 1024,
141141
headers: Optional[Mapping[str, str]] = None,
142142
media_type: Optional[str] = None,
143143
background: Optional[BackgroundTask] = None,
@@ -161,8 +161,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
161161
})
162162
response_started = True
163163
line_bytes = line.encode("utf-8")
164-
for i in range(0, len(line_bytes), self.max_chunk_size): # type: ignore
165-
chunk = line_bytes[i:i + self.max_chunk_size] # type: ignore
164+
for i in range(0, len(line_bytes), self.max_chunk_size):
165+
chunk = line_bytes[i:i + self.max_chunk_size]
166166
await send({
167167
"type": "http.response.body",
168168
"body": chunk,

app/api/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ async def generate_text(
367367
params = SamplingParams(max_tokens=max_tokens)
368368

369369
conversation, _ = parse_chat_messages(messages, model_config, tokenizer, content_format="string") # type: ignore
370-
prompt_tokens = apply_hf_chat_template( # type: ignore
370+
prompt_tokens = apply_hf_chat_template( # type: ignore
371371
tokenizer,
372372
conversation=conversation,
373373
tools=None,

app/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class Settings(BaseSettings): # type: ignore
3636
TRAINING_CACHE_DIR: str = os.path.join(os.path.abspath(os.path.dirname(__file__)), "cms_cache") # the directory to cache the intermediate files created during training
3737
HF_PIPELINE_AGGREGATION_STRATEGY: str = "simple" # the strategy used for aggregating the predictions of the Hugging Face NER model
3838
LOG_PER_CONCEPT_ACCURACIES: str = "false" # if "true", per-concept accuracies will be exposed to the metrics scraper. Switch this on with caution due to the potentially high number of concepts
39+
MEDCAT2_MAPPED_ONTOLOGIES: str = "" # the comma-separated names of ontologies for MedCAT2 to map to
3940
DEBUG: str = "false" # if "true", the debug mode is switched on
4041

4142
class Config:

app/envs/.env

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,8 @@ TRAINING_SAFE_MODEL_SERIALISATION=false
7373
# The strategy used for aggregating the predictions of the Hugging Face NER model
7474
HF_PIPELINE_AGGREGATION_STRATEGY=simple
7575

76+
# The comma-separated names of ontologies for MedCAT2 to map to
77+
MEDCAT2_MAPPED_ONTOLOGIES=opcs4,icd10
78+
7679
# If "true", the debug mode is switched on
7780
DEBUG=false

app/management/tracker_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ast
12
import os
23
import socket
34
import mlflow
@@ -114,7 +115,7 @@ def send_model_stats(stats: Dict, step: int) -> None:
114115
step (int): The current step in the training or evaluation process.
115116
"""
116117

117-
metrics = {key.replace(" ", "_").lower(): val for key, val in stats.items()}
118+
metrics = {key.replace(" ", "_").lower(): val for key, val in stats.items() if isinstance(val, (int, float))}
118119
mlflow.log_metrics(metrics, step)
119120

120121
@staticmethod
@@ -563,6 +564,7 @@ def get_metrics_by_job_id(self, job_id: str) -> List[Dict[str, Any]]:
563564
metrics_history = {}
564565
for metric in run.data.metrics.keys():
565566
metrics_history[metric] = [m.value for m in self.mlflow_client.get_metric_history(run_id=run.info.run_id, key=metric)]
567+
metrics_history["concepts"] = ast.literal_eval(run.data.tags.get("training.entity.classes", "[]"))
566568
metrics.append(metrics_history)
567569
return metrics
568570
except MlflowException as e:

0 commit comments

Comments
 (0)