66
66
# Directory models are saved to; overridable via the SPARSEZOO_MODELS_PATH
# environment variable, falling back to the default cache directory.
SAVE_DIR = os.getenv("SPARSEZOO_MODELS_PATH", CACHE_DIR)

# Archive file name used for the compressed ONNX model.
COMPRESSED_FILE_NAME = "model.onnx.tar.gz"
69
# Regex for "v1" stubs of the form:
#   zoo:domain/sub_domain/architecture[-sub_architecture]/framework/repo
#       /dataset[-training_scheme]/sparse_tag
# NOTE: the character class previously used the range ``A-z``, which also
# matches the ASCII characters ``[ \ ] ^`` and the backtick that sit between
# ``Z`` and ``a``; tightened to ``A-Za-z``.
STUB_V1_REGEX_EXPR = (
    r"^(zoo:)?"
    r"(?P<domain>[\.A-Za-z0-9_]+)"
    r"/(?P<sub_domain>[\.A-Za-z0-9_]+)"
    r"/(?P<architecture>[\.A-Za-z0-9_]+)(-(?P<sub_architecture>[\.A-Za-z0-9_]+))?"
    r"/(?P<framework>[\.A-Za-z0-9_]+)"
    r"/(?P<repo>[\.A-Za-z0-9_]+)"
    r"/(?P<dataset>[\.A-Za-z0-9_]+)(-(?P<training_scheme>[\.A-Za-z0-9_]+))?"
    r"/(?P<sparse_tag>[\.A-Za-z0-9_-]+)"
)

# Regex for "v2" stubs of the form:
#   zoo:architecture[-sub_architecture]-source_dataset[-training_dataset]-sparse_tag
STUB_V2_REGEX_EXPR = (
    r"^(zoo:)?"
    r"(?P<architecture>[\.A-Za-z0-9_]+)"
    r"(-(?P<sub_architecture>[\.A-Za-z0-9_]+))?"
    r"-(?P<source_dataset>[\.A-Za-z0-9_]+)"
    r"(-(?P<training_dataset>[\.A-Za-z0-9_]+))?"
    r"-(?P<sparse_tag>[\.A-Za-z0-9_]+)"
)
88
+
69
89
70
90
def load_files_from_directory (directory_path : str ) -> List [Dict [str , Any ]]:
71
91
"""
@@ -118,33 +138,44 @@ def load_files_from_stub(
118
138
models = api .fetch (
119
139
operation_body = "models" ,
120
140
arguments = arguments ,
121
- fields = ["modelId" , "modelOnnxSizeCompressedBytes" ],
141
+ fields = [
142
+ "model_id" ,
143
+ "model_onnx_size_compressed_bytes" ,
144
+ "files" ,
145
+ "benchmark_results" ,
146
+ "training_results" ,
147
+ ],
122
148
)
123
149
124
- if len (models ):
125
- model_id = models [0 ]["model_id" ]
126
-
127
- files = api .fetch (
128
- operation_body = "files" ,
129
- arguments = {"model_id" : model_id },
150
+ matching_models = len (models )
151
+ if matching_models == 0 :
152
+ raise ValueError (
153
+ f"No matching models found with stub: { stub } ." "Please try another stub"
130
154
)
155
+ if matching_models > 1 :
156
+ logging .warning (
157
+ f"{ len (models )} found from the stub: { stub } "
158
+ "Using the first model to obtain metadata."
159
+ "Proceed with caution"
160
+ )
161
+
162
+ if matching_models :
163
+ model = models [0 ]
164
+
165
+ model_id = model ["model_id" ]
166
+
167
+ files = model .get ("files" )
131
168
include_file_download_url (files )
132
169
files = restructure_request_json (request_json = files )
133
170
134
171
if params is not None :
135
172
files = filter_files (files = files , params = params )
136
173
137
- training_results = api .fetch (
138
- operation_body = "training_results" ,
139
- arguments = {"model_id" : model_id },
140
- )
174
+ training_results = model .get ("training_results" )
141
175
142
- benchmark_results = api .fetch (
143
- operation_body = "benchmark_results" ,
144
- arguments = {"model_id" : model_id },
145
- )
176
+ benchmark_results = model .get ("benchmark_results" )
146
177
147
- model_onnx_size_compressed_bytes = models [ 0 ][ "model_onnx_size_compressed_bytes" ]
178
+ model_onnx_size_compressed_bytes = model . get ( "model_onnx_size_compressed_bytes" )
148
179
149
180
throughput_results = [
150
181
ThroughputResults (** benchmark_result )
@@ -553,6 +584,38 @@ def include_file_download_url(files: List[Dict]):
553
584
)
554
585
555
586
587
def get_model_metadata_from_stub(stub: str) -> Dict[str, str]:
    """Return a dictionary of the model metadata from stub"""

    match = re.match(STUB_V1_REGEX_EXPR, stub) or re.match(STUB_V2_REGEX_EXPR, stub)
    if match is None:
        return {}

    groups = match.groupdict()

    # Only the v2 pattern defines a "source_dataset" group; for v2 stubs the
    # whole stub doubles as the repo name.
    if "source_dataset" in groups:
        return {"repo_name": stub}

    # Only the v1 pattern defines a "dataset" group; it carries the full
    # zoo-path breakdown (optional groups may be None).
    if "dataset" in groups:
        return {
            field: match.group(field)
            for field in (
                "domain",
                "sub_domain",
                "architecture",
                "sub_architecture",
                "framework",
                "repo",
                "dataset",
                "sparse_tag",
            )
        }

    return {}
610
+
611
+
612
def is_stub(candidate: str) -> bool:
    """Return True if ``candidate`` parses as a v1 or v2 model stub."""
    for pattern in (STUB_V1_REGEX_EXPR, STUB_V2_REGEX_EXPR):
        if re.match(pattern, candidate):
            return True
    return False
617
+
618
+
556
619
def get_file_download_url (
557
620
model_id : str ,
558
621
file_name : str ,
@@ -566,32 +629,3 @@ def get_file_download_url(
566
629
download_url += "?increment_download=False"
567
630
568
631
return download_url
569
-
570
-
571
def get_model_metadata_from_stub(stub: str) -> Dict[str, str]:
    """
    Return a dictionary of the model metadata from stub

    :param stub: a model stub of the form
        zoo:domain/sub_domain/architecture[-sub_architecture]/framework
            /repo/dataset/sparse_tag
    :return: mapping of metadata field name to value; an empty dict when
        ``stub`` does not match the expected stub format
    """

    # NOTE: the character class previously used the range ``A-z``, which also
    # matches ``[ \ ] ^`` and the backtick; tightened to ``A-Za-z``.
    stub_regex_expr = (
        r"^(zoo:)?"
        r"(?P<domain>[\.A-Za-z0-9_]+)"
        r"/(?P<sub_domain>[\.A-Za-z0-9_]+)"
        r"/(?P<architecture>[\.A-Za-z0-9_]+)(-(?P<sub_architecture>[\.A-Za-z0-9_]+))?"
        r"/(?P<framework>[\.A-Za-z0-9_]+)"
        r"/(?P<repo>[\.A-Za-z0-9_]+)"
        r"/(?P<dataset>[\.A-Za-z0-9_]+)"
        r"/(?P<sparse_tag>[\.A-Za-z0-9_-]+)"
    )
    matches = re.match(stub_regex_expr, stub)

    # Previously a non-matching stub crashed with AttributeError on
    # ``matches.group(...)``; return an empty dict instead.
    if matches is None:
        return {}

    return {
        "domain": matches.group("domain"),
        "sub_domain": matches.group("sub_domain"),
        "architecture": matches.group("architecture"),
        "sub_architecture": matches.group("sub_architecture"),
        "framework": matches.group("framework"),
        "repo": matches.group("repo"),
        "dataset": matches.group("dataset"),
        "sparse_tag": matches.group("sparse_tag"),
    }
0 commit comments