Skip to content

Commit 00ccfac

Browse files
horheynm, Danny Guinther, jeanniefinks, dhuangnm, dhuang
authored
Stub v2 (#271)
* Dummy graphql requests module * graphql api request * return Model instances * Update NOTICE (#242) license name change * bump main to 1.4.0 (#246) Co-authored-by: dhuang <[email protected]> * Pin numpy version to <=1.21.6 (#247) search search, download draft draft, successful search and download draft Update: `ModelAnalysis.from_onnx(...)` to additionally work with loaded `ModelProto` (#253) refactor search, download * lint * pass tests * init files * lint * Add dummy test using test-specific subclass * tests * add increment_downloads=False * allow empty arguments * comments * query parser, allow dict as input, add tests for extra functionality * restore models.utils * restore models.utils * v2 stub * comments * change stubs to ones on prod * lint * Update src/sparsezoo/model/utils.py Co-authored-by: Danny Guinther <[email protected]> * Update src/sparsezoo/model/utils.py Co-authored-by: Danny Guinther <[email protected]> * Update src/sparsezoo/api/utils.py Co-authored-by: Danny Guinther <[email protected]> --------- Co-authored-by: Danny Guinther <[email protected]> Co-authored-by: Jeannie Finks <[email protected]> Co-authored-by: dhuangnm <[email protected]> Co-authored-by: dhuang <[email protected]> Co-authored-by: Rahul Tuli <[email protected]> Co-authored-by: Danny Guinther <[email protected]>
1 parent a9718cc commit 00ccfac

File tree

4 files changed

+140
-53
lines changed

4 files changed

+140
-53
lines changed

src/sparsezoo/api/utils.py

+19 -4
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Callable, Dict
15+
from typing import Any, Callable, Dict, List
1616

1717

1818
def to_camel_case(string: str):
@@ -32,7 +32,22 @@ def to_snake_case(string: str):
3232

3333

3434
def map_keys(
35-
dictionary: Dict[str, str], mapper: Callable[[str], str]
35+
dictionary: Dict[str, Any], mapper: Callable[[str], str]
3636
) -> Dict[str, str]:
37-
"""Given a dictionary, update its key to a given mapper callable"""
38-
return {mapper(key): value for key, value in dictionary.items()}
37+
"""
38+
Given a dictionary, update its keys to a given mapper callable.
39+
40+
If the value of the dict is a List of Dict or Dict of Dict, recursively map
41+
its keys
42+
"""
43+
mapped_dict = {}
44+
for key, value in dictionary.items():
45+
if isinstance(value, List) or isinstance(value, Dict):
46+
value_type = type(value)
47+
mapped_dict[mapper(key)] = value_type(
48+
map_keys(dictionary=sub_dict, mapper=mapper) for sub_dict in value
49+
)
50+
else:
51+
mapped_dict[mapper(key)] = value
52+
53+
return mapped_dict

src/sparsezoo/model/model.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
from sparsezoo.model.utils import (
2626
SAVE_DIR,
2727
ZOO_STUB_PREFIX,
28+
is_stub,
2829
load_files_from_directory,
2930
load_files_from_stub,
3031
save_outputs_to_tar,
@@ -78,7 +79,7 @@ def __init__(self, source: str, download_path: Optional[str] = None):
7879
self.source = source
7980
self._stub_params = {}
8081

81-
if self.source.startswith(ZOO_STUB_PREFIX):
82+
if is_stub(self.source):
8283
# initializing the files and params from the stub
8384
_setup_args = self.initialize_model_from_stub(stub=self.source)
8485
files, path, url, validation_results, compressed_size = _setup_args

src/sparsezoo/model/utils.py

+79 -45
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,26 @@
6666
SAVE_DIR = os.getenv("SPARSEZOO_MODELS_PATH", CACHE_DIR)
6767
COMPRESSED_FILE_NAME = "model.onnx.tar.gz"
6868

69+
STUB_V1_REGEX_EXPR = (
70+
r"^(zoo:)?"
71+
r"(?P<domain>[\.A-z0-9_]+)"
72+
r"/(?P<sub_domain>[\.A-z0-9_]+)"
73+
r"/(?P<architecture>[\.A-z0-9_]+)(-(?P<sub_architecture>[\.A-z0-9_]+))?"
74+
r"/(?P<framework>[\.A-z0-9_]+)"
75+
r"/(?P<repo>[\.A-z0-9_]+)"
76+
r"/(?P<dataset>[\.A-z0-9_]+)(-(?P<training_scheme>[\.A-z0-9_]+))?"
77+
r"/(?P<sparse_tag>[\.A-z0-9_-]+)"
78+
)
79+
80+
STUB_V2_REGEX_EXPR = (
81+
r"^(zoo:)?"
82+
r"(?P<architecture>[\.A-z0-9_]+)"
83+
r"(-(?P<sub_architecture>[\.A-z0-9_]+))?"
84+
r"-(?P<source_dataset>[\.A-z0-9_]+)"
85+
r"(-(?P<training_dataset>[\.A-z0-9_]+))?"
86+
r"-(?P<sparse_tag>[\.A-z0-9_]+)"
87+
)
88+
6989

7090
def load_files_from_directory(directory_path: str) -> List[Dict[str, Any]]:
7191
"""
@@ -118,33 +138,44 @@ def load_files_from_stub(
118138
models = api.fetch(
119139
operation_body="models",
120140
arguments=arguments,
121-
fields=["modelId", "modelOnnxSizeCompressedBytes"],
141+
fields=[
142+
"model_id",
143+
"model_onnx_size_compressed_bytes",
144+
"files",
145+
"benchmark_results",
146+
"training_results",
147+
],
122148
)
123149

124-
if len(models):
125-
model_id = models[0]["model_id"]
126-
127-
files = api.fetch(
128-
operation_body="files",
129-
arguments={"model_id": model_id},
150+
matching_models = len(models)
151+
if matching_models == 0:
152+
raise ValueError(
153+
f"No matching models found with stub: {stub}." "Please try another stub"
130154
)
155+
if matching_models > 1:
156+
logging.warning(
157+
f"{len(models)} found from the stub: {stub}"
158+
"Using the first model to obtain metadata."
159+
"Proceed with caution"
160+
)
161+
162+
if matching_models:
163+
model = models[0]
164+
165+
model_id = model["model_id"]
166+
167+
files = model.get("files")
131168
include_file_download_url(files)
132169
files = restructure_request_json(request_json=files)
133170

134171
if params is not None:
135172
files = filter_files(files=files, params=params)
136173

137-
training_results = api.fetch(
138-
operation_body="training_results",
139-
arguments={"model_id": model_id},
140-
)
174+
training_results = model.get("training_results")
141175

142-
benchmark_results = api.fetch(
143-
operation_body="benchmark_results",
144-
arguments={"model_id": model_id},
145-
)
176+
benchmark_results = model.get("benchmark_results")
146177

147-
model_onnx_size_compressed_bytes = models[0]["model_onnx_size_compressed_bytes"]
178+
model_onnx_size_compressed_bytes = model.get("model_onnx_size_compressed_bytes")
148179

149180
throughput_results = [
150181
ThroughputResults(**benchmark_result)
@@ -553,6 +584,38 @@ def include_file_download_url(files: List[Dict]):
553584
)
554585

555586

587+
def get_model_metadata_from_stub(stub: str) -> Dict[str, str]:
588+
"""Return a dictionary of the model metadata from stub"""
589+
590+
matches = re.match(STUB_V1_REGEX_EXPR, stub) or re.match(STUB_V2_REGEX_EXPR, stub)
591+
if not matches:
592+
return {}
593+
594+
if "source_dataset" in matches.groupdict():
595+
return {"repo_name": stub}
596+
597+
if "dataset" in matches.groupdict():
598+
return {
599+
"domain": matches.group("domain"),
600+
"sub_domain": matches.group("sub_domain"),
601+
"architecture": matches.group("architecture"),
602+
"sub_architecture": matches.group("sub_architecture"),
603+
"framework": matches.group("framework"),
604+
"repo": matches.group("repo"),
605+
"dataset": matches.group("dataset"),
606+
"sparse_tag": matches.group("sparse_tag"),
607+
}
608+
609+
return {}
610+
611+
612+
def is_stub(candidate: str) -> bool:
613+
return bool(
614+
re.match(STUB_V1_REGEX_EXPR, candidate)
615+
or re.match(STUB_V2_REGEX_EXPR, candidate)
616+
)
617+
618+
556619
def get_file_download_url(
557620
model_id: str,
558621
file_name: str,
@@ -566,32 +629,3 @@ def get_file_download_url(
566629
download_url += "?increment_download=False"
567630

568631
return download_url
569-
570-
571-
def get_model_metadata_from_stub(stub: str) -> Dict[str, str]:
572-
"""
573-
Return a dictionary of the model metadata from stub
574-
"""
575-
576-
stub_regex_expr = (
577-
r"^(zoo:)?"
578-
r"(?P<domain>[\.A-z0-9_]+)"
579-
r"/(?P<sub_domain>[\.A-z0-9_]+)"
580-
r"/(?P<architecture>[\.A-z0-9_]+)(-(?P<sub_architecture>[\.A-z0-9_]+))?"
581-
r"/(?P<framework>[\.A-z0-9_]+)"
582-
r"/(?P<repo>[\.A-z0-9_]+)"
583-
r"/(?P<dataset>[\.A-z0-9_]+)"
584-
r"/(?P<sparse_tag>[\.A-z0-9_-]+)"
585-
)
586-
matches = re.match(stub_regex_expr, stub)
587-
588-
return {
589-
"domain": matches.group("domain"),
590-
"sub_domain": matches.group("sub_domain"),
591-
"architecture": matches.group("architecture"),
592-
"sub_architecture": matches.group("sub_architecture"),
593-
"framework": matches.group("framework"),
594-
"repo": matches.group("repo"),
595-
"dataset": matches.group("dataset"),
596-
"sparse_tag": matches.group("sparse_tag"),
597-
}

tests/sparsezoo/model/test_model.py

+40 -3
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,16 @@
8282
("checkpoint", "postqat"),
8383
True,
8484
),
85+
(
86+
"biobert-base_cased-jnlpba_pubmed-pruned80.4block_quantized",
87+
("deployment", "default"),
88+
True,
89+
),
90+
(
91+
"resnet_v1-50-imagenet-pruned95",
92+
("checkpoint", "preqat"),
93+
True,
94+
),
8595
],
8696
scope="function",
8797
)
@@ -127,20 +137,47 @@ def _assert_validation_results_exist(model):
127137
"stub, clone_sample_outputs, expected_files",
128138
[
129139
(
130-
"zoo:cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/pruned-moderate", # noqa E501
140+
(
141+
"zoo:"
142+
"cv/classification/mobilenet_v1-1.0/"
143+
"pytorch/sparseml/imagenet/pruned-moderate"
144+
),
131145
True,
132146
files_ic,
133147
),
134148
(
135-
"zoo:nlp/question_answering/distilbert-none/pytorch/huggingface/squad/pruned80_quant-none-vnni", # noqa E501
149+
(
150+
"zoo:"
151+
"nlp/question_answering/distilbert-none/"
152+
"pytorch/huggingface/squad/pruned80_quant-none-vnni"
153+
),
136154
False,
137155
files_nlp,
138156
),
139157
(
140-
"zoo:cv/detection/yolov5-s/pytorch/ultralytics/coco/pruned_quant-aggressive_94", # noqa E501
158+
(
159+
"zoo:"
160+
"cv/detection/yolov5-s/"
161+
"pytorch/ultralytics/coco/pruned_quant-aggressive_94"
162+
),
141163
True,
142164
files_yolo,
143165
),
166+
(
167+
"yolov5-x-coco-pruned70.4block_quantized",
168+
False,
169+
files_yolo,
170+
),
171+
(
172+
"yolov5-n6-voc_coco-pruned55",
173+
False,
174+
files_yolo,
175+
),
176+
(
177+
"resnet_v1-50-imagenet-channel30_pruned90_quantized",
178+
False,
179+
files_yolo,
180+
),
144181
],
145182
scope="function",
146183
)

0 commit comments

Comments
 (0)