@@ -99,6 +99,7 @@ def __init__(
       adapters_dir_path: str,
       hbm_memory_budget: int,
       cpu_memory_budget: int,
+      total_slots: int,
   ):
     """Initializes the AdapterTensorStore."""
     self.engine = engine  # Possibly MaxEngine object
@@ -119,8 +120,27 @@ def __init__(
     self.running_requests: int = (
         0  # Number of async tasks which are in "loading" state
     )
+    self.decoding_adapters_cache: Dict[str, Any] = {}
+
+    # TODO: Make dtype configurable for the scale factor array
+    self.adapters_scale_factor = jnp.empty(1, dtype=jnp.bfloat16)
+
+    self.total_slots = total_slots
     self.lock = asyncio.Lock()  # Use an asyncio Lock for thread safety
 
+  def _get_adapter_scale_factor(self, adapter_id: str):
+    """
+    Internal: Get the LoRA scale_factor using the adapter_id.
+    """
+    adapter_config = self.adapter_registry[adapter_id].config
+    lora_scale_factor = float(1)
+
+    if "r" in adapter_config and "lora_alpha" in adapter_config:
+      lora_rank = int(adapter_config["r"])
+      lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank
+
+    return lora_scale_factor
+
   # --- Unsafe Internal methods which assumes that lock is held ---
   def _unsafe_transfer_to_hbm(self, adapter_id: str):
     """
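
The `_get_adapter_scale_factor` helper added above follows the standard LoRA convention, scale = lora_alpha / r, falling back to 1 when either key is missing from the adapter config. A minimal standalone sketch of the same rule; the config dicts below are hypothetical examples in the usual PEFT key format:

```python
# Minimal sketch of the LoRA scaling rule used by _get_adapter_scale_factor:
# scale = lora_alpha / r, defaulting to 1.0 when either key is absent.
def lora_scale_factor(adapter_config: dict) -> float:
  scale = 1.0
  if "r" in adapter_config and "lora_alpha" in adapter_config:
    scale = float(adapter_config["lora_alpha"]) / int(adapter_config["r"])
  return scale

print(lora_scale_factor({"r": 8, "lora_alpha": 32}))  # 4.0
print(lora_scale_factor({}))  # 1.0
```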
@@ -207,6 +227,90 @@ def _unsafe_unload_adapter(self, adapter_id: str):
     metadata.size_hbm = 0
     metadata.size_cpu = 0
 
+  def _initialize_decoding_adapters_cache(self, adapter_weights):
+    """
+    Create a new PyTree with zero tensors at the paths corresponding to
+    non-None leaves in the input PyTree. The zero tensors have an added
+    leading dimension of size `self.total_slots`.
+    Args:
+      adapter_weights: The input PyTree, whose structure will be mirrored.
+    Returns:
+      A new PyTree with zero Tensors or None values, mirroring the structure
+      of the input PyTree.
+    """
+
+    def create_zero_leaf(leaf):
+      if leaf is not None:
+        original_shape = leaf.shape
+        if not original_shape:  # handle scalar case
+          zero_tensor_shape = (self.total_slots,)
+        else:
+          zero_tensor_shape = (
+              self.total_slots,
+          ) + original_shape  # Prepend a new dimension
+
+        return jnp.zeros(zero_tensor_shape, dtype=leaf.dtype)
+      else:
+        return None  # Maintain None structure for None leaves
+
+    self.adapters_scale_factor = jnp.ones(self.total_slots, dtype=jnp.bfloat16)
+    return jax.tree_util.tree_map(create_zero_leaf, adapter_weights)
+
+  def insert_adapter_in_cache(self, adapter_id: str, slot_id: int):
+    """
+    Insert the given adapter's tensors into a slot of the
+    decoding_adapters_cache.
+    Args:
+      adapter_id: The id of the adapter whose tensors will be inserted.
+      slot_id: The id of the slot, i.e. the index in the
+        decoding_adapters_cache at which the adapter tensors will be inserted.
+    """
+
+    def insert_leaf(dest_leaf, source_leaf):
+      if dest_leaf is not None and source_leaf is not None:
+        return dest_leaf.at[slot_id].set(
+            source_leaf
+        )  # Insert at the specific index
+      elif dest_leaf is not None:
+        return dest_leaf  # If source_leaf is None, keep the zero_leaf as is
+      elif (
+          source_leaf is not None
+      ):  # In this case the adapters have different target modules
+        original_shape = source_leaf.shape
+        if not original_shape:  # Handle scalar case
+          zero_tensor_shape = (self.total_slots,)
+        else:
+          zero_tensor_shape = (self.total_slots,) + original_shape
+        new_dest_leaf = jnp.zeros(zero_tensor_shape, dtype=source_leaf.dtype)
+        return new_dest_leaf.at[slot_id].set(source_leaf)
+      else:
+        return None  # If both are None, return None
+
+    if adapter_id == "":
+      logging.info(
+          "Empty adapter id. No LoRA tensors added to adapter_tensorstore cache"
+      )
+      return
+
+    asyncio.run(self.load_adapter(adapter_id, None, True))
+
+    adapter_weights = self.loaded_adapters_hbm[adapter_id]
+
+    if not self.decoding_adapters_cache:
+      self.decoding_adapters_cache = self._initialize_decoding_adapters_cache(
+          adapter_weights
+      )
+
+    adapter_scale_factor = jnp.bfloat16(
+        self._get_adapter_scale_factor(adapter_id)
+    )
+    self.adapters_scale_factor = self.adapters_scale_factor.at[slot_id].set(
+        adapter_scale_factor
+    )
+    self.decoding_adapters_cache = jax.tree_util.tree_map(
+        insert_leaf, self.decoding_adapters_cache, adapter_weights
+    )
+
   # --- Public Methods (Acquire lock, then call unsafe methods) ---
 
   async def register_adapter(
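
Taken together, `_initialize_decoding_adapters_cache` and `insert_adapter_in_cache` maintain a single batched PyTree whose leaves carry a leading `total_slots` dimension, plus a per-slot scale vector. The standalone sketch below reproduces that zero-init-then-insert pattern on a toy PyTree; the tree layout and shapes are hypothetical, not the actual MaxEngine adapter format:

```python
# Standalone sketch of the per-slot cache pattern: zero tensors with a leading
# `total_slots` dimension, filled one slot at a time. Toy shapes, hypothetical tree.
import jax
import jax.numpy as jnp

total_slots = 4
is_none = lambda x: x is None  # treat None leaves explicitly, as the cache code does

# Toy adapter weights: one LoRA pair plus a module with no adapter (None leaf).
adapter_weights = {
    "layer_0": {"lora_a": jnp.ones((16, 2)), "lora_b": jnp.ones((2, 16))},
    "layer_1": None,
}

def create_zero_leaf(leaf):
  if leaf is None:
    return None
  return jnp.zeros((total_slots,) + leaf.shape, dtype=leaf.dtype)

cache = jax.tree_util.tree_map(create_zero_leaf, adapter_weights, is_leaf=is_none)

def insert_leaf(dest, src, slot_id=2):  # place this adapter's weights in slot 2
  if dest is None or src is None:
    return dest
  return dest.at[slot_id].set(src)

cache = jax.tree_util.tree_map(insert_leaf, cache, adapter_weights, is_leaf=is_none)
print(cache["layer_0"]["lora_a"].shape)  # (4, 16, 2): slot dimension comes first
```

Because unused slots stay zero-initialized, a slot that never receives an adapter contributes no LoRA update, which is consistent with the empty-adapter-id early return above.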
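The consuming side of this cache is outside this diff, but the shape convention suggests how it could be used: each request in a decode batch selects its slot, and the stacked weights plus `adapters_scale_factor` are gathered by slot index. A hypothetical sketch; none of these names or shapes come from the PR:

```python
# Hypothetical sketch of consuming a slot-stacked LoRA cache during batched
# decoding. All array names and shapes are illustrative assumptions.
import jax.numpy as jnp

def apply_batched_lora(x, lora_a_stack, lora_b_stack, scale_per_slot, slot_ids):
  """x: [batch, d_in], lora_a_stack: [slots, d_in, r],
  lora_b_stack: [slots, r, d_out], scale_per_slot: [slots], slot_ids: [batch]."""
  lora_a = lora_a_stack[slot_ids]            # [batch, d_in, r]
  lora_b = lora_b_stack[slot_ids]            # [batch, r, d_out]
  scale = scale_per_slot[slot_ids][:, None]  # [batch, 1]
  delta = jnp.einsum("bi,bir,bro->bo", x, lora_a, lora_b)
  return scale * delta  # added to the base layer's output elsewhere

# Slots left at zero (no adapter inserted) contribute a zero delta.
out = apply_batched_lora(
    jnp.ones((3, 16)),
    jnp.zeros((4, 16, 2)),
    jnp.zeros((4, 2, 16)),
    jnp.ones((4,), dtype=jnp.bfloat16),
    jnp.array([0, 2, 2]),
)
```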