@@ -309,7 +309,9 @@ def use_scale_search(model_config, qtype):
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  convert_shape_only=False,
-                                 cpu_embedding=False, prefix_name='',
+                                 cpu_embedding=False,
+                                 disk_embedding=False,
+                                 prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_config=None, torch_dtype=torch.float32,
                                  enable_xetla=False,
@@ -319,7 +321,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  ):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
-    from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
+    from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding
     has_been_replaced = False
 
     for name, module in model.named_children():
@@ -467,48 +469,15 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 model._modules[name].requires_grad_(False)
 
                 module.weight = None
+        # skip user-defined Embedding layer
         elif cpu_embedding and type(module) == nn.Embedding:
-            # skip user-defined Embedding layer
-            model._modules[name] = LLMEmbedding(
-                num_embeddings=module.num_embeddings,
-                embedding_dim=module.embedding_dim,
-                padding_idx=module.padding_idx,
-                max_norm=module.max_norm,
-                norm_type=module.norm_type,
-                scale_grad_by_freq=module.scale_grad_by_freq,
-                sparse=module.sparse,
-                _weight=module.weight.data,
-            )
-        elif type(module) == nn.Embedding and embedding_qtype is not None:
-            if torch_dtype == "auto":
-                torch_dtype = torch.float32
-            q_embedding = LowBitEmbedding(
-                num_embeddings=module.num_embeddings,
-                embedding_dim=module.embedding_dim,
-                padding_idx=module.padding_idx,
-                max_norm=module.max_norm,
-                norm_type=module.norm_type,
-                scale_grad_by_freq=module.scale_grad_by_freq,
-                sparse=module.sparse,
-                _weight=module.weight.data,
-                qtype=embedding_qtype,
-                torch_dtype=torch_dtype
-            )
-            device = module.weight.data.device
-            # Copy the weights
-            paramsLowBit = FP4Params(data=module.weight.data,
-                                     requires_grad=False,
-                                     quantized=False,
-                                     _shape=None,
-                                     convert_shape_only=convert_shape_only,
-                                     qtype=embedding_qtype,
-                                     in_features=module.embedding_dim).to(device)
-            q_embedding._parameters['weight'] = paramsLowBit
-            model._modules[name] = q_embedding
-            # Force requires grad to False to avoid unexpected errors
-            model._modules[name].requires_grad_(False)
-            module.weight = None
-
+            model._modules[name] = CPUEmbedding.from_embedding(module)
+        elif disk_embedding and type(module) == nn.Embedding:
+            model._modules[name] = DiskEmbedding.from_embedding(module)
+        elif embedding_qtype is not None and type(module) == nn.Embedding:
+            model._modules[name] = LowBitEmbedding.from_embedding(module,
+                                                                  convert_shape_only,
+                                                                  embedding_qtype)
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
             _, _flag = _replace_with_low_bit_linear(
@@ -517,6 +486,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 modules_to_not_convert,
                 convert_shape_only,
                 cpu_embedding,
+                disk_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
                 embedding_qtype=embedding_qtype,
@@ -775,7 +745,8 @@ def _optimize_pre(model, qtype=None):
 
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
-                         modules_to_not_convert=None, cpu_embedding=False,
+                         modules_to_not_convert=None,
+                         cpu_embedding=False, disk_embedding=False,
                          lightweight_bmm=False, torch_dtype="auto",
                          imatrix_data=None,
                          embedding_qtype=None,
@@ -817,7 +788,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        convert_shape_only, cpu_embedding,
+        convert_shape_only, cpu_embedding, disk_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
         model_config=model_config,
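
For reference, a minimal sketch (not part of the diff) of how the refactored embedding classes can be exercised on their own. The helper name _swap_embeddings and the standalone recursion are illustrative assumptions; only CPUEmbedding, DiskEmbedding, and their from_embedding constructors come from this change.

# Illustrative sketch mirroring the new elif branches above.
import torch.nn as nn
from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding

def _swap_embeddings(model, cpu_embedding=False, disk_embedding=False):
    # Walk direct children and replace plain nn.Embedding modules in place,
    # the same way _replace_with_low_bit_linear now does.
    for name, module in model.named_children():
        if cpu_embedding and type(module) == nn.Embedding:
            model._modules[name] = CPUEmbedding.from_embedding(module)
        elif disk_embedding and type(module) == nn.Embedding:
            model._modules[name] = DiskEmbedding.from_embedding(module)
        elif len(list(module.children())) > 0:
            _swap_embeddings(module, cpu_embedding, disk_embedding)
    return model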