Skip to content

Commit bef9fab

Browse files
authored
bug(medcat): CU-869b07hr0 Add optional extra for embedding linker (#209)
* CU-869b07hr0: ADd explicit optional extras for embdedded linker. * CU-869b07hr0: Add check for optional extras to embedded linker * CU-869b07hr0: Fix minor typo
1 parent 889eec1 commit bef9fab

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

medcat-v2/medcat/components/linking/embedding_linker.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,25 @@
55
from medcat.tokenizing.tokenizers import BaseTokenizer
66
from typing import Optional, Iterator, Set
77
from medcat.vocab import Vocab
8-
from torch import Tensor
9-
from transformers import AutoTokenizer, AutoModel
108
from medcat.utils.postprocessing import create_main_ann
119
from tqdm import tqdm
1210
from collections import defaultdict
13-
import torch.nn.functional as F
14-
import torch
1511
import logging
1612
import math
1713

14+
from medcat.utils.import_utils import ensure_optional_extras_installed
15+
import medcat
16+
17+
# NOTE: the below needs to be before torch/transformers imports
18+
_EXTRA_NAME = "embed-linker"
19+
ensure_optional_extras_installed(medcat.__name__, _EXTRA_NAME)
20+
21+
# avoid linting issues due to above check
22+
from torch import Tensor # noqa: E402
23+
from transformers import AutoTokenizer, AutoModel # noqa: E402
24+
import torch.nn.functional as F # noqa: E402
25+
import torch # noqa: E402
26+
1827
logger = logging.getLogger(__name__)
1928

2029

medcat-v2/pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ rel_cat = [
113113
"scikit-learn>=1.1.3,<2.0",
114114
"torch>=2.4.0,<3.0",
115115
]
116+
embed_linker = [
117+
"transformers>=4.41.0,<5.0", # avoid major bump
118+
"torch>=2.4.0,<3.0",
119+
]
116120
test = [] # TODO - list
117121

118122
[project.urls]

0 commit comments

Comments
 (0)