It will take the current setup (i.e. the medcat version) into account and
subsequently identify and download the medcat-scripts based on the most
recent applicable tag. So if you've got medcat==2.2.0, it might grab
6- medcat-scripts /v2.2.3 for instance.
6+ medcat/v2.2.3 for instance.
77"""
88import importlib .metadata
99import tempfile
1010import zipfile
1111from pathlib import Path
1212import requests
1313import logging
14+ import argparse
1415
1516
1617logger = logging .getLogger (__name__ )
1718
1819
20+ EXPECTED_TAG_PREFIX = 'medcat/v'
1921GITHUB_REPO = "CogStack/cogstack-nlp"
2022SCRIPTS_PATH = "medcat-scripts/"
2123DOWNLOAD_URL_TEMPLATE = (
@@ -27,7 +29,9 @@ def _get_medcat_version() -> str:
2729 """Return the installed MedCAT version as 'major.minor'."""
2830 version = importlib .metadata .version ("medcat" )
2931 major , minor , * _ = version .split ("." )
30- return f"{ major } .{ minor } "
32+ minor_version = f"{ major } .{ minor } "
33+ logger .debug ("Using medcat minor version of %s" , minor_version )
34+ return minor_version
3135
3236
3337def _find_latest_scripts_tag (major_minor : str ) -> str :
@@ -38,9 +42,10 @@ def _find_latest_scripts_tag(major_minor: str) -> str:
3842 matching = [
3943 t ["name" ]
4044 for t in tags
41- if t ["name" ].startswith (f"medcat-scripts/v{ major_minor } ." )
42- or t ["name" ].startswith (f"v{ major_minor } ." )
45+ if t ["name" ].startswith (f"{ EXPECTED_TAG_PREFIX } { major_minor } ." )
4346 ]
47+ logger .debug ("Found %d matching (out of a total of %d): %s" ,
48+ len (matching ), len (tags ), matching )
4449 if not matching :
4550 raise RuntimeError (
4651 f"No medcat-scripts tags found for MedCAT { major_minor } .x" )
@@ -49,36 +54,42 @@ def _find_latest_scripts_tag(major_minor: str) -> str:
4954 return matching [0 ]
5055
5156
52- def fetch_scripts (destination : str | Path = "." ) -> Path :
53- """Download the latest compatible medcat-scripts folder into.
54-
55- Args:
56- destination (str | Path): The destination path. Defaults to ".".
def _determine_url(overwrite_url: str | None,
                   overwrite_tag: str | None) -> str:
    """Work out which zip archive the scripts should be downloaded from.

    An explicit URL takes precedence over everything else. Otherwise the
    URL is built from a tag: either the explicitly provided one, or the
    one that `_find_latest_scripts_tag` returns for the installed MedCAT
    version.

    Args:
        overwrite_url (str | None): The URL to use as is, if set.
        overwrite_tag (str | None): The tag to use instead of looking
            one up. Ignored if `overwrite_url` is set.

    Returns:
        str: The URL of the zip archive to download.
    """
    if overwrite_url:
        logger.info("Using the overwrite URL instead: %s", overwrite_url)
        return overwrite_url

    # The version is resolved even for an explicit tag as it is logged below
    version = _get_medcat_version()
    if overwrite_tag:
        tag = overwrite_tag
        logger.info("Using overwritten tag '%s'", tag)
    else:
        tag = _find_latest_scripts_tag(version)

    logger.info("Fetching scripts for MedCAT %s → tag %s", version, tag)

    # Download the GitHub auto-generated zipball
    return DOWNLOAD_URL_TEMPLATE.format(tag=tag)
6676
67- logger .info ("Fetching scripts for MedCAT %s → tag %s}" ,
68- version , tag )
6977
70- # Download the GitHub auto-generated zipball
71- zip_url = DOWNLOAD_URL_TEMPLATE .format (tag = tag )
def _download_zip(zip_url: str, tmp: tempfile._TemporaryFileWrapper):
    """Stream the archive at the given URL into an open temporary file.

    Args:
        zip_url (str): The URL to download the archive from.
        tmp (tempfile._TemporaryFileWrapper): The open file to write to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    with requests.get(zip_url, stream=True, timeout=30) as response:
        response.raise_for_status()
        # Streamed in 8 KiB pieces so the archive is never fully in memory
        tmp.writelines(response.iter_content(chunk_size=8192))
        # Push buffered data to disk so the file can be read by its name
        tmp.flush()
7884
85+
86+ def _extract_zip (dest : Path , zip_path : Path ):
7987 # Extract only medcat-scripts/ from the archive
88+ wrote_files_num = 0
89+ total_files = 0
8090 with zipfile .ZipFile (zip_path ) as zf :
8191 for m in zf .namelist ():
92+ total_files += 1
8293 if f"/{ SCRIPTS_PATH } " not in m :
8394 continue
8495 # skip repo-hash prefix
@@ -88,14 +99,60 @@ def fetch_scripts(destination: str | Path = ".") -> Path:
8899 else :
89100 with open (target , "wb" ) as f :
90101 f .write (zf .read (m ))
91-
102+ wrote_files_num += 1
103+
104+ logger .debug ("Wrote %d / %d files" , wrote_files_num , total_files )
105+ if not wrote_files_num :
106+ logger .warning (
107+ "Was unable to extract any files from '%s' folder in the zip. "
108+ "The folder doesn't seem to exist in the provided archive." ,
109+ SCRIPTS_PATH )
92110 logger .info ("Scripts extracted to: %s" , dest )
111+
112+
def fetch_scripts(destination: str | Path = ".",
                  overwrite_url: str | None = None,
                  overwrite_tag: str | None = None) -> Path:
    """Download the latest compatible medcat-scripts into `destination`.

    The destination folder (and any missing parents) is created if needed.
    The archive is downloaded to a temporary location that is removed
    again once the scripts have been extracted.

    Args:
        destination (str | Path): The destination path. Defaults to ".".
        overwrite_url (str | None): The overwrite URL. Defaults to None.
        overwrite_tag (str | None): The overwrite tag. Defaults to None.

    Returns:
        Path: The path of the scripts.
    """
    dest = Path(destination).expanduser().resolve()
    dest.mkdir(parents=True, exist_ok=True)

    zip_url = _determine_url(overwrite_url, overwrite_tag)
    # The archive has to be closed before it is re-opened by name for
    # extraction: a NamedTemporaryFile that is still open cannot be opened
    # a second time on Windows. Hence delete=False, with the enclosing
    # temporary directory taking care of the cleanup instead.
    with tempfile.TemporaryDirectory() as tmp_dir:
        with tempfile.NamedTemporaryFile(
                dir=tmp_dir, suffix=".zip", delete=False) as tmp:
            _download_zip(zip_url, tmp)
        _extract_zip(dest, Path(tmp.name))
    return dest
94134
95135
96- def main (destination : str = "." ,
97- log_level : int | str = logging .INFO ):
def main(*in_args: str):
    """Parse the given command line arguments and fetch the scripts.

    Args:
        *in_args (str): The command line arguments to parse.
    """
    cli = argparse.ArgumentParser(
        prog="python -m medcat download-scripts",
        description="Download medcat-scripts",
    )
    cli.add_argument(
        "destination", nargs='?', default=".", type=str,
        help="The destination folder for the scripts")
    cli.add_argument(
        "--overwrite-url", default=None, type=str,
        help=("The URL to download and extract from. "
              "This is expected to refer to a .zip file "
              "that has a `medcat-scripts` folder."))
    cli.add_argument(
        "--overwrite-tag", '-t', default=None, type=str,
        help="The tag to use from GitHub")
    cli.add_argument(
        "--log-level", default='INFO', type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="The log level for fetching")
    parsed = cli.parse_args(in_args)

    logger.setLevel(parsed.log_level)
    # Only attach a handler if there is none yet
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    fetch_scripts(destination=parsed.destination,
                  overwrite_url=parsed.overwrite_url,
                  overwrite_tag=parsed.overwrite_tag)
0 commit comments