@@ -65,7 +65,14 @@ def __init__(
6565 self .ctx = ctx
6666 self .artifact = artifact
6767 self .uri = get_cloud_uri (uri , namespace ) if namespace else uri
68- self ._file_properties = self ._get_file_properties ()
68+ self ._file_properties = {
69+ ModelFileProperties .TILEDB_ML_MODEL_ML_FRAMEWORK .value : self .Name ,
70+ ModelFileProperties .TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION .value : self .Version ,
71+ ModelFileProperties .TILEDB_ML_MODEL_STAGE .value : "STAGING" ,
72+ ModelFileProperties .TILEDB_ML_MODEL_PYTHON_VERSION .value : platform .python_version (),
73+ ModelFileProperties .TILEDB_ML_MODEL_PREVIEW .value : self .preview (),
74+ ModelFileProperties .TILEDB_ML_MODEL_VERSION .value : __version__ ,
75+ }
6976
7077 @abstractmethod
7178 def save (self , * , update : bool = False , meta : Optional [Meta ] = None ) -> None :
@@ -88,34 +95,23 @@ def get_weights(self, timestamp: Optional[Timestamp] = None) -> Weights:
8895 """
8996 Returns model's weights. Works for Tensorflow Keras and PyTorch
9097 """
91- return cast (Weights , self ._get_model_param ("model" , timestamp ))
98+ with tiledb .open (self .uri , ctx = self .ctx , timestamp = timestamp ) as model_array :
99+ return cast (Weights , self ._get_model_param (model_array , "model" ))
92100
93101 def get_optimizer_weights (self , timestamp : Optional [Timestamp ] = None ) -> Weights :
94102 """
95103 Returns optimizer's weights. Works for Tensorflow Keras and PyTorch
96104 """
97- return cast (Weights , self ._get_model_param ("optimizer" , timestamp ))
105+ with tiledb .open (self .uri , ctx = self .ctx , timestamp = timestamp ) as model_array :
106+ return cast (Weights , self ._get_model_param (model_array , "optimizer" ))
98107
99108 @abstractmethod
100109 def preview (self ) -> str :
101110 """
102111 Creates a string representation of a machine learning model.
103112 """
104113
105- def _get_file_properties (self ) -> Mapping [str , str ]:
106- return {
107- ModelFileProperties .TILEDB_ML_MODEL_ML_FRAMEWORK .value : self .Name ,
108- ModelFileProperties .TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION .value : self .Version ,
109- ModelFileProperties .TILEDB_ML_MODEL_STAGE .value : "STAGING" ,
110- ModelFileProperties .TILEDB_ML_MODEL_PYTHON_VERSION .value : platform .python_version (),
111- ModelFileProperties .TILEDB_ML_MODEL_PREVIEW .value : self .preview (),
112- ModelFileProperties .TILEDB_ML_MODEL_VERSION .value : __version__ ,
113- }
114-
115- def _create_array (
116- self ,
117- fields : Sequence [str ],
118- ) -> None :
114+ def _create_array (self , fields : Sequence [str ]) -> None :
119115 """Internal method that creates a TileDB array based on the model's spec."""
120116
121117 # The array will be be 1 dimensional with domain of 0 to max uint64. We use a tile extent of 1024 bytes
@@ -152,101 +148,78 @@ def _create_array(
152148 if self .namespace :
153149 update_file_properties (self .uri , self ._file_properties )
154150
155- def _write_array (self , model_params : Mapping [str , bytes ]) -> None :
156- """
157- Writes machine learning model related data, i.e., model weights, optimizer weights and Tensorboard files, to
158- a dense TileDB array.
159- """
151+ def _write_array (
152+ self ,
153+ model_params : Mapping [str , bytes ],
154+ tensorboard_log_dir : Optional [str ] = None ,
155+ meta : Optional [Meta ] = None ,
156+ ) -> None :
157+ if tensorboard_log_dir :
158+ tensorboard = self ._serialize_tensorboard (tensorboard_log_dir )
159+ else :
160+ tensorboard = b""
161+ model_params = dict (tensorboard = tensorboard , ** model_params )
162+
163+ if meta is None :
164+ meta = {}
165+ if not meta .keys ().isdisjoint (self ._file_properties .keys ()):
166+ raise ValueError (
167+ "Please avoid using file property key names as metadata keys!"
168+ )
160169
161170 with tiledb .open (self .uri , "w" , ctx = self .ctx ) as model_array :
162171 one_d_buffers = {}
163172 max_len = 0
164-
165173 for key , value in model_params .items ():
166174 one_d_buffer = np .frombuffer (value , dtype = np .uint8 )
167175 one_d_buffer_len = len (one_d_buffer )
168176 one_d_buffers [key ] = one_d_buffer
169-
170177 # Write size only in case is greater than 0.
171178 if one_d_buffer_len :
172179 model_array .meta [key + "_size" ] = one_d_buffer_len
173-
174180 if one_d_buffer_len > max_len :
175181 max_len = one_d_buffer_len
176182
177183 model_array [0 :max_len ] = {
178184 key : np .pad (value , (0 , max_len - len (value )))
179185 for key , value in one_d_buffers .items ()
180186 }
181-
182- def _write_model_metadata (self , meta : Meta ) -> None :
183- """
184- Update the metadata in a TileDB model array. File properties also go in the metadata section.
185- :param meta: A mapping with the <key, value> pairs to be inserted in array's metadata.
186- """
187- with tiledb .open (self .uri , "w" , ctx = self .ctx ) as model_array :
188- # Raise ValueError in case users provide metadata with the same keys as file properties.
189- if not meta .keys ().isdisjoint (self ._file_properties .keys ()):
190- raise ValueError (
191- "Please avoid using file property key names as metadata keys!"
192- )
193-
194- for key , value in meta .items ():
195- model_array .meta [key ] = value
196-
197- for key , value in self ._file_properties .items ():
198- model_array .meta [key ] = value
187+ for mapping in meta , self ._file_properties :
188+ for key , value in mapping .items ():
189+ model_array .meta [key ] = value
190+
191+ def _get_model_param (self , model_array : tiledb .Array , key : str ) -> Any :
192+ size_key = key + "_size"
193+ try :
194+ size = model_array .meta [size_key ]
195+ except KeyError :
196+ raise Exception (
197+ f"{ size_key } metadata entry not present in { self .uri } "
198+ f" (existing keys: { set (model_array .meta .keys ())} )"
199+ )
200+ return pickle .loads (model_array .query (attrs = (key ,))[0 :size ][key ].tobytes ())
199201
200202 @staticmethod
201- def _serialize_tensorboard_files (log_dir : str ) -> bytes :
203+ def _serialize_tensorboard (log_dir : str ) -> bytes :
202204 """Serialize all Tensorboard files."""
203-
204205 if not os .path .exists (log_dir ):
205206 raise ValueError (f"{ log_dir } does not exist" )
206-
207- event_files = {}
207+ tensorboard_files = {}
208208 for path in glob .glob (f"{ log_dir } /*tfevents*" ):
209209 with open (path , "rb" ) as f :
210- event_files [path ] = f .read ()
210+ tensorboard_files [path ] = f .read ()
211+ return pickle .dumps (tensorboard_files , protocol = 4 )
211212
212- return pickle .dumps (event_files , protocol = 4 )
213-
214- def _get_model_param (self , key : str , timestamp : Optional [Timestamp ]) -> Any :
215- with tiledb .open (self .uri , ctx = self .ctx , timestamp = timestamp ) as model_array :
216- size_key = key + "_size"
217- try :
218- size = model_array .meta [size_key ]
219- except KeyError :
220- raise Exception (
221- f"{ size_key } metadata entry not present in { self .uri } "
222- f" (existing keys: { set (model_array .meta .keys ())} )"
223- )
224- return pickle .loads (model_array [0 :size ][key ].tobytes ())
225-
226- def _load_tensorboard (self , timestamp : Optional [Timestamp ] = None ) -> None :
213+ def _load_tensorboard (self , model_array : tiledb .Array ) -> None :
227214 """
228- Writes Tensorboard files to directory. Works for Tensorflow-Keras and PyTorch.
215+ Write Tensorboard files to directory. Works for Tensorflow-Keras and PyTorch.
229216 """
230- with tiledb .open (self .uri , ctx = self .ctx , timestamp = timestamp ) as model_array :
231- try :
232- tensorboard_size = model_array .meta ["tensorboard_size" ]
233- except KeyError :
234- raise Exception (
235- f"tensorboard_size metadata entry not present in"
236- f" (existing keys: { set (model_array .meta .keys ())} )"
237- )
238-
239- tb_contents = model_array [0 :tensorboard_size ]["tensorboard" ]
240- tensorboard_files = pickle .loads (tb_contents .tobytes ())
241-
242- for path , file_bytes in tensorboard_files .items ():
243- log_dir = os .path .dirname (path )
244- if not os .path .exists (log_dir ):
245- os .mkdir (log_dir )
246- with open (os .path .join (log_dir , os .path .basename (path )), "wb" ) as f :
247- f .write (file_bytes )
248-
249- def _use_legacy_schema (self , timestamp : Optional [Timestamp ]) -> bool :
217+ tensorboard_files = self ._get_model_param (model_array , "tensorboard" )
218+ for path , file_bytes in tensorboard_files .items ():
219+ os .makedirs (os .path .dirname (path ), exist_ok = True )
220+ with open (path , "wb" ) as f :
221+ f .write (file_bytes )
222+
223+ def _use_legacy_schema (self , model_array : tiledb .Array ) -> bool :
250224 # TODO: Decide based on tiledb-ml version and not on schema characteristics, like "offset".
251- with tiledb .open (self .uri , ctx = self .ctx , timestamp = timestamp ) as model_array :
252- return str (model_array .schema .domain .dim (0 ).name ) != "offset"
225+ return str (model_array .schema .domain .dim (0 ).name ) != "offset"
0 commit comments