-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathonnx-chunk-external-data.py
134 lines (108 loc) · 4.41 KB
/
onnx-chunk-external-data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
##
## tool to chunk onnx external data files
##
## for example: python onnx-chunk-external-data.py -i ../gemma-2-2b-web.0/onnx/model_q4f16.onnx -o onnx/model_q4f16.onnx --threshhold 1 --maxchunks 1
##
##
import argparse
import onnx
import onnx.external_data_helper as ext_data
import os
import itertools
MB = 1024 * 1024  # bytes per megabyte; CLI --size/--threshhold values are given in MB
DEFAULT_MAX_SIZE = 2000 # 2048 fails to fetch in chrome
def recursive_attribute_processor(attribute, func):
    """Apply *func* to any subgraph(s) carried by *attribute*, yielding its results.

    Handles both single-graph (GRAPH) and graph-list (GRAPHS) attributes;
    other attribute types yield nothing.
    """
    attr_type = attribute.type
    if attr_type == onnx.AttributeProto.GRAPH:
        yield from func(attribute.g)
    elif attr_type == onnx.AttributeProto.GRAPHS:
        for subgraph in attribute.graphs:
            yield from func(subgraph)
def get_attribute_tensors_from_graph(graph_or_function):
    """Yield every tensor stored in node attributes, recursing into subgraphs."""
    for node in graph_or_function.node:
        for attr in node.attribute:
            if attr.HasField("t"):
                # single-tensor attribute
                yield attr.t
            # tensor-list attribute (may be empty)
            yield from attr.tensors
            # descend into GRAPH / GRAPHS attributes
            yield from recursive_attribute_processor(attr, get_attribute_tensors_from_graph)
def get_attribute_tensors(model_proto):
    """Yield attribute tensors from the main graph, then from each local function."""
    for graph_like in itertools.chain((model_proto.graph,), model_proto.functions):
        yield from get_attribute_tensors_from_graph(graph_like)
def get_initializer_tensors_from_graph(graph_or_function):
    """Yield a graph's initializer tensors, recursing into attribute subgraphs.

    Only real GraphProtos carry an initializer list (FunctionProtos do not),
    hence the isinstance guard before yielding initializers directly.
    """
    if isinstance(graph_or_function, onnx.GraphProto):
        yield from graph_or_function.initializer
    for node in graph_or_function.node:
        for attr in node.attribute:
            yield from recursive_attribute_processor(attr, get_initializer_tensors_from_graph)
def get_initializer_tensors(model_proto):
    """Yield initializer tensors from the main graph and from local functions' subgraphs.

    Bug fix: the function loop previously called get_attribute_tensors_from_graph,
    which re-yielded every function attribute tensor already produced by
    get_attribute_tensors — so get_all_tensors visited those tensors twice, and
    the second visit in save_external would fail its raw_data assertion after the
    first visit cleared the data.  Walk initializers (via subgraphs) instead.
    """
    yield from get_initializer_tensors_from_graph(model_proto.graph)
    for function in model_proto.functions:
        yield from get_initializer_tensors_from_graph(function)
def get_all_tensors(model_proto):
    """Yield all model tensors: initializers first, then attribute tensors."""
    yield from get_initializer_tensors(model_proto)
    yield from get_attribute_tensors(model_proto)
def save_external(model_proto, external_data_name, max_size, threshhold, maxchunks):
    """Move large tensors out of the model into external data chunk files.

    Tensors of at least ``threshhold`` bytes are appended to chunk files named
    ``{external_data_name}_data``, ``..._data_1``, ... each holding at most
    ``max_size`` bytes.  Once ``maxchunks`` files have been created, remaining
    tensors stay embedded in the .onnx file.  The rewritten model is finally
    saved to ``external_data_name``.

    Args:
        model_proto: onnx.ModelProto, modified in place.
        external_data_name: output path for the .onnx file and the prefix for
            the chunk files written alongside it.
        max_size: maximum chunk file size in bytes.
        threshhold: minimum tensor size in bytes to be externalized.
        maxchunks: maximum number of chunk files to create.

    Fixes vs. the original: both progress messages printed a stray ``)`` after
    "MB"; the unused local ``file_name`` is removed.
    """
    idx = 0

    def open_segment():
        # Returns (file_object, location_name).  Once the chunk budget is
        # exhausted it returns (None, "model"), meaning tensors stay inline.
        if idx >= maxchunks:
            return None, "model"
        suffix = "_data" if idx == 0 else f"_data_{idx}"
        name = f"{external_data_name}{suffix}"
        return open(name, "wb"), os.path.basename(name)

    offset = 0
    f, name = open_segment()
    for tensor in get_all_tensors(model_proto):
        tensor_data = tensor.raw_data
        tensor_size = len(tensor_data)
        # expects every tensor to carry populated raw_data (not already external)
        assert tensor_data and tensor_size > 0
        if tensor_size < threshhold:
            # small tensors stay embedded in the model file
            continue
        if not f:
            # chunk budget exhausted: tensor stays embedded
            continue
        if offset + tensor_size > max_size:
            # current chunk is full; report it and start a new one
            f.close()
            print(f"{name}, size: {offset // MB} MB")
            idx += 1
            offset = 0
            f, name = open_segment()
            if not f:
                continue
        f.write(tensor_data)
        # Rewrite the tensor to reference its slice of the chunk file.
        tensor.ClearField("raw_data")
        del tensor.external_data[:]
        tensor.data_location = onnx.TensorProto.EXTERNAL
        for k, v in {
            "location": name,
            "offset": offset,
            "length": tensor_size,
        }.items():
            entry = tensor.external_data.add()
            entry.key = k
            entry.value = str(v)
        offset += tensor_size
    if f:
        f.close()
        print(f"{name}, size: {offset // MB} MB")
    onnx.save_model(model_proto, external_data_name)
def get_args():
    """Build the CLI parser and return the parsed arguments namespace."""
    parser = argparse.ArgumentParser(description='tool to chunk onnx external data files')
    for flags, kwargs in (
        (("--input", "-i"), dict(required=True, help='input')),
        (("--output", "-o"), dict(help='output')),
        (("--size",), dict(default=DEFAULT_MAX_SIZE, type=int, help='max weight size')),
        (("--threshhold",), dict(default=0, type=int, help='threshhold in MB to be external data')),
        (("--maxchunks",), dict(default=99, type=int, help='maximum number of datafiles')),
    ):
        parser.add_argument(*flags, **kwargs)
    return parser.parse_args()
def main():
    """Load the input model and chunk its external data when an output path is given."""
    opts = get_args()
    model_proto = onnx.load_model(opts.input)
    if not opts.output:
        return
    save_external(model_proto, opts.output, opts.size * MB, opts.threshhold * MB, opts.maxchunks)


if __name__ == '__main__':
    main()