
Commit 0209427

Add disk_embedding parameter to support putting the Embedding layer on disk (#11617)
1 parent 2478e2c commit 0209427

4 files changed: +86 −66 lines

python/llm/src/ipex_llm/transformers/convert.py

+16 −45

@@ -309,7 +309,9 @@ def use_scale_search(model_config, qtype):
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  convert_shape_only=False,
-                                 cpu_embedding=False, prefix_name='',
+                                 cpu_embedding=False,
+                                 disk_embedding=False,
+                                 prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_config=None, torch_dtype=torch.float32,
                                  enable_xetla=False,
@@ -319,7 +321,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  ):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
-    from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
+    from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding
     has_been_replaced = False
 
     for name, module in model.named_children():
@@ -467,48 +469,15 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             model._modules[name].requires_grad_(False)
 
             module.weight = None
+        # skip user-defined Embedding layer
         elif cpu_embedding and type(module) == nn.Embedding:
-            # skip user-defined Embedding layer
-            model._modules[name] = LLMEmbedding(
-                num_embeddings=module.num_embeddings,
-                embedding_dim=module.embedding_dim,
-                padding_idx=module.padding_idx,
-                max_norm=module.max_norm,
-                norm_type=module.norm_type,
-                scale_grad_by_freq=module.scale_grad_by_freq,
-                sparse=module.sparse,
-                _weight=module.weight.data,
-            )
-        elif type(module) == nn.Embedding and embedding_qtype is not None:
-            if torch_dtype == "auto":
-                torch_dtype = torch.float32
-            q_embedding = LowBitEmbedding(
-                num_embeddings=module.num_embeddings,
-                embedding_dim=module.embedding_dim,
-                padding_idx=module.padding_idx,
-                max_norm=module.max_norm,
-                norm_type=module.norm_type,
-                scale_grad_by_freq=module.scale_grad_by_freq,
-                sparse=module.sparse,
-                _weight=module.weight.data,
-                qtype=embedding_qtype,
-                torch_dtype=torch_dtype
-            )
-            device = module.weight.data.device
-            # Copy the weights
-            paramsLowBit = FP4Params(data=module.weight.data,
-                                     requires_grad=False,
-                                     quantized=False,
-                                     _shape=None,
-                                     convert_shape_only=convert_shape_only,
-                                     qtype=embedding_qtype,
-                                     in_features=module.embedding_dim).to(device)
-            q_embedding._parameters['weight'] = paramsLowBit
-            model._modules[name] = q_embedding
-            # Force requires grad to False to avoid unexpected errors
-            model._modules[name].requires_grad_(False)
-            module.weight = None
-
+            model._modules[name] = CPUEmbedding.from_embedding(module)
+        elif disk_embedding and type(module) == nn.Embedding:
+            model._modules[name] = DiskEmbedding.from_embedding(module)
+        elif embedding_qtype is not None and type(module) == nn.Embedding:
+            model._modules[name] = LowBitEmbedding.from_embedding(module,
                                                                   convert_shape_only,
                                                                   embedding_qtype)
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
             _, _flag = _replace_with_low_bit_linear(
@@ -517,6 +486,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 modules_to_not_convert,
                 convert_shape_only,
                 cpu_embedding,
+                disk_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
                 embedding_qtype=embedding_qtype,
@@ -775,7 +745,8 @@ def _optimize_pre(model, qtype=None):
 
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
-                         modules_to_not_convert=None, cpu_embedding=False,
+                         modules_to_not_convert=None,
+                         cpu_embedding=False, disk_embedding=False,
                          lightweight_bmm=False, torch_dtype="auto",
                          imatrix_data=None,
                          embedding_qtype=None,
@@ -817,7 +788,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        convert_shape_only, cpu_embedding,
+        convert_shape_only, cpu_embedding, disk_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
         model_config=model_config,

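For readers skimming the hunks above: the refactor replaces the two hand-rolled construction blocks with a single dispatch in which every plain nn.Embedding is handed to a from_embedding classmethod. A minimal, self-contained sketch of that dispatch follows; the helper name replace_embeddings is illustrative only, the real logic lives inside _replace_with_low_bit_linear.

import torch.nn as nn
from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding


def replace_embeddings(model, cpu_embedding=False, disk_embedding=False,
                       convert_shape_only=False, embedding_qtype=None):
    # Illustrative sketch mirroring the new elif chain in _replace_with_low_bit_linear;
    # replace_embeddings itself is not part of the commit.
    for name, module in model.named_children():
        if cpu_embedding and type(module) == nn.Embedding:
            # keep the lookup table in pinned CPU memory
            model._modules[name] = CPUEmbedding.from_embedding(module)
        elif disk_embedding and type(module) == nn.Embedding:
            # keep the lookup table on disk to save memory
            model._modules[name] = DiskEmbedding.from_embedding(module)
        elif embedding_qtype is not None and type(module) == nn.Embedding:
            # quantize the lookup table to the requested low-bit format
            model._modules[name] = LowBitEmbedding.from_embedding(
                module, convert_shape_only, embedding_qtype)
        if len(list(module.children())) > 0:
            # recurse into container modules
            replace_embeddings(module, cpu_embedding, disk_embedding,
                               convert_shape_only, embedding_qtype)
    return model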
python/llm/src/ipex_llm/transformers/embedding.py

+59 −18

@@ -18,7 +18,6 @@
 import numpy
 import torch
 from torch import Tensor
-from torch.nn import functional as F
 from torch.nn import Parameter
 from typing import Optional
 from ipex_llm.transformers.low_bit_linear import FP4Params
@@ -56,7 +55,7 @@ def to(self, *args, **kwargs):
         return super().to(*args, **kwargs)
 
 
-class LLMEmbedding(torch.nn.Embedding):
+class CPUEmbedding(torch.nn.Embedding):
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
@@ -67,15 +66,32 @@ def __init__(self,
                  sparse: bool = False,
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
-                 device=None, dtype=None) -> None:
+                 device=None,
+                 dtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq,
-                         sparse, _weight, _freeze, device, dtype)
-        self.weight = CPUPinnedParam(self.weight.data, requires_grad=not _freeze)
+                         sparse, _weight, True, device, dtype)
+        self.weight = CPUPinnedParam(self.weight.data, requires_grad=False)
 
     def forward(self, x: Tensor):
         return super().forward(x.to('cpu')).to(x.device)
 
+    @classmethod
+    def from_embedding(cls, embedding: torch.nn.Embedding):
+        return cls(
+            embedding.num_embeddings,
+            embedding.embedding_dim,
+            embedding.padding_idx,
+            embedding.max_norm,
+            embedding.norm_type,
+            embedding.scale_grad_by_freq,
+            embedding.sparse,
+            embedding.weight.data,
+            True,
+            embedding.weight.device,
+            embedding.weight.dtype,
+        )
+
 
 class DiskEmbedding(torch.nn.Embedding):
     def __init__(self,
@@ -89,7 +105,7 @@ def __init__(self,
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
                  device=None,
-                 dtype=None):
+                 dtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq,
                          sparse, _weight, True, device, dtype)
@@ -147,30 +163,55 @@ def __init__(self,
                  sparse: bool = False,
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
-                 device=None, dtype=None,
-                 qtype=None,
-                 torch_dtype=torch.float32) -> None:
+                 device=None,
+                 dtype=None,
+                 convert_shape_only=None,
+                 qtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq, sparse,
                          _weight, device, dtype)
-        self.weight = FP4Params(self.weight.data,
-                                requires_grad=False,
-                                quantized=False, _shape=None, qtype=qtype)
+        self.qweight = FP4Params(self.weight.data,
+                                 requires_grad=False,
+                                 quantized=False,
+                                 _shape=None,
+                                 convert_shape_only=convert_shape_only,
+                                 qtype=qtype,
+                                 in_features=embedding_dim)
+        # this dummy_weight is used to record model's dtype and device
+        dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
+        self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)
+
         self.embedding_dim = embedding_dim
         self.num_embeddings = num_embeddings
-        self.torch_dtype = torch_dtype
 
     def forward(self, x: Tensor):
         invalidInputError(x.device.type == "xpu",
                           "`LowBitEmbedding` only supports GPU now.")
         try:
-            import intel_extension_for_pytorch
             import xe_linear
         except ModuleNotFoundError:
             invalidInputError(False,
-                              "Please `pip install bigdl_core_xe` first.")
+                              "Please `pip install bigdl_core_xe_21` first.")
 
-        result = xe_linear.dequantize_rows(x.contiguous(), self.weight.data,
-                                           self.weight.qtype, self.embedding_dim,
+        result = xe_linear.dequantize_rows(x.contiguous(), self.qweight.data,
+                                           self.qweight.qtype, self.embedding_dim,
                                            self.num_embeddings)
-        return result.to(self.torch_dtype)
+        return result.to(self.weight.dtype)
+
+    @classmethod
+    def from_embedding(cls, embedding: torch.nn.Embedding, convert_shape_only, qtype):
+        return cls(
+            embedding.num_embeddings,
+            embedding.embedding_dim,
+            embedding.padding_idx,
+            embedding.max_norm,
+            embedding.norm_type,
+            embedding.scale_grad_by_freq,
+            embedding.sparse,
+            embedding.weight.data,
+            True,
+            embedding.weight.device,
+            embedding.weight.dtype,
+            convert_shape_only,
+            qtype,
        )

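The piece the wrappers now share is a from_embedding classmethod that copies an existing nn.Embedding's hyperparameters and weight into the wrapper and freezes it. A small usage sketch, assuming the classes behave as defined above; the vocabulary and hidden sizes are placeholders.

import torch
import torch.nn as nn
from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding

emb = nn.Embedding(32000, 4096)               # placeholder sizes

cpu_emb = CPUEmbedding.from_embedding(emb)    # weight pinned on CPU, frozen
disk_emb = DiskEmbedding.from_embedding(emb)  # intended to serve the table from disk
                                              # (per the disk_embedding docstring)

ids = torch.tensor([[1, 2, 3]])
out = cpu_emb(ids)                # lookup runs on CPU ...
assert out.device == ids.device   # ... and the result follows the input's device
assert out.shape == (1, 3, 4096)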
python/llm/src/ipex_llm/transformers/low_bit_linear.py

+1 −1

@@ -483,7 +483,7 @@ def to(self, *args, **kwargs):
             return self.quantize(device.type)
         elif (device is not None and device.type == "xpu" and self.data.device.type == "cpu"):
             # enter xpu logic, compile linear_int4 extension at first time
-            self.quantize(device)  # tensor is cpu now
+            self.quantize("cpu")  # tensor is cpu now
             self.data = ggml_q_format_convet_cpu2xpu(self.data,
                                                      reduce(mul, self._shape, 1),
                                                      self.qtype)

python/llm/src/ipex_llm/transformers/model.py

+10 −2

@@ -144,6 +144,8 @@ def from_pretrained(cls,
                Default to be ``False``.
         :param cpu_embedding: Whether to replace the Embedding layer, may need to set it
                to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
+        :param disk_embedding: Whether to put the Embedding layer on disk to save memory.
+               Default to be ``False``.
         :param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
                to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
         :param imatrix: str value, represent filename of importance matrix pretrained on
@@ -435,6 +437,7 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
             warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
                           " please use cpu_embedding instead.", FutureWarning)
             cpu_embedding = True
+        disk_embedding = kwargs.pop("disk_embedding", False)
         lightweight_bmm = kwargs.pop("lightweight_bmm", False)
         quant_config = kwargs.pop("quantization_config", None)
         imatrix_data = kwargs.pop("imatrix_data", None)
@@ -507,7 +510,9 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
             model = model.to("cpu")
         model = ggml_convert_low_bit(model, qtype, optimize_model,
                                      modules_to_not_convert=modules_to_not_convert,
-                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
+                                     cpu_embedding=cpu_embedding,
+                                     disk_embedding=disk_embedding,
+                                     lightweight_bmm=lightweight_bmm,
                                      torch_dtype=kwargs.get("torch_dtype", 'auto'),
                                      imatrix_data=imatrix_data,
                                      embedding_qtype=embedding_qtype,
@@ -563,6 +568,7 @@ def load_low_bit(cls,
             warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
                           " please use cpu_embedding instead.", FutureWarning)
             cpu_embedding = True
+        disk_embedding = kwargs.pop("disk_embedding", False)
         lightweight_bmm = kwargs.pop("lightweight_bmm", False)
         # Autofactory
         trust_remote_code = kwargs.pop("trust_remote_code", None)
@@ -699,7 +705,9 @@ def load_low_bit(cls,
         quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                      modules_to_not_convert=modules_to_not_convert,
-                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
+                                     cpu_embedding=cpu_embedding,
+                                     disk_embedding=disk_embedding,
+                                     lightweight_bmm=lightweight_bmm,
                                      embedding_qtype=embedding_qtype, torch_dtype=torch_dtype)
 
         if is_sharded:

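End to end, the new flag flows from from_pretrained / load_low_bit straight into ggml_convert_low_bit. A hedged usage example; the checkpoint id is a placeholder and the other keyword arguments follow the existing from_pretrained signature.

from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # placeholder checkpoint id
    load_in_4bit=True,                # existing low-bit weight path
    disk_embedding=True,              # new flag: keep the Embedding layer on disk
)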