Skip to content

Commit d970e20

Browse files
committed
Add prototype of QLoRA
1 parent 0b2ee91 commit d970e20

File tree

7 files changed

+112
-0
lines changed

7 files changed

+112
-0
lines changed

paddleslim/lc/__init__.py

Whitespace-only changes.

paddleslim/lc/layers/__init__.py

Whitespace-only changes.

paddleslim/lc/layers/linear.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import paddle
2+
import paddle.nn as nn
3+
4+
5+
class WeightQuantizationLinear(nn.Layer):
    """Base layer wrapping an existing ``nn.Linear`` for weight quantization.

    Captures the wrapped layer's shape, dtype and parameter naming so that
    concrete subclasses can register quantized parameters under predictable
    names. Subclasses must implement ``forward`` and ``quantize``.
    """

    def __init__(
            self,
            linear: paddle.nn.Linear, ):
        super().__init__()
        weight = linear.weight
        # Assumes weight is laid out [in_features, out_features], as in
        # Paddle's Linear — TODO(review): confirm for custom linears.
        self.in_features = weight.shape[0]
        self.out_features = weight.shape[1]
        self.dtype = linear.dtype
        # Derive the quantized-parameter name from the original weight name
        # so checkpoints remain traceable to the source layer.
        self.weight_name = weight.name
        self.quant_weight_name = ".".join([self.weight_name, "quant_weight"])

    def forward(self, x):
        """Quantized forward pass; concrete subclasses must override."""
        raise NotImplementedError()

    def quantize(self, weight) -> paddle.Tensor:
        """Quantize ``weight`` in-place into this layer; subclasses override."""
        raise NotImplementedError()

paddleslim/lc/layers/nf4_linear.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import paddle
2+
import paddle.nn as nn
3+
from paddleslim.lc.quantizers import NF4Quantizer
4+
from .linear import WeightQuantizationLinear
5+
6+
7+
class NF4Linear(WeightQuantizationLinear):
    """Linear layer whose weight is stored NF4-quantized.

    PaddlePaddle has no native int4 dtype, so each int8 element of
    ``quant_weight`` packs two int4 values; the packed tensor therefore has
    ``out_features // 2`` rows.
    """
    quant_dtype = "int4"
    weight_dtype = "int8"

    def __init__(
            self,
            linear: nn.Linear,
            block_size=64,
            double_quant=False, ):
        super(NF4Linear, self).__init__(linear)
        self.block_size = block_size
        self.double_quant = double_quant
        self.quantizer = NF4Quantizer(block_size, double_quant)
        # PaddlePaddle doesn't support the int4 data type; one int8 value
        # represents two int4 values, hence the halved leading dimension.
        self.quant_weight = self.create_parameter(
            shape=[self.out_features // 2, self.in_features],
            attr=paddle.ParamAttr(self.quant_weight_name),
            dtype=NF4Linear.weight_dtype,
            is_bias=False, )

        self.quant_scale_name = ".".join([self.weight_name, "quant_scale"])
        self.quant_scale = self.create_parameter(
            shape=[self.out_features],
            attr=paddle.ParamAttr(self.quant_scale_name),
            dtype="float32",  # TODO: use the original weight dtype once fixed
            is_bias=False, )
        # The double-quant scale parameter only exists when double
        # quantization is enabled.
        if self.double_quant:
            self.double_quant_scale_name = ".".join(
                [self.weight_name, "double_quant_scale"])
            self.double_quant_scale = self.create_parameter(
                shape=[self.out_features],
                attr=paddle.ParamAttr(self.double_quant_scale_name),
                dtype="float32",
                is_bias=False, )

    def quantize(self, weight):
        """Quantize ``weight`` and store the packed result and scales.

        Returns the packed quantized weight produced by the quantizer.
        """
        quantized_weight = self.quantizer.quantize(weight)
        self.quant_weight.set_value(quantized_weight)
        self.quant_scale.set_value(self.quantizer.quant_scale)
        if self.double_quant:
            self.double_quant_scale.set_value(
                self.quantizer.double_quant_scale)
        return quantized_weight

    def forward(self, x):
        """Matmul against the (dequantized) packed weight.

        Restores the scales saved by :meth:`quantize` into the quantizer
        before delegating the matmul.
        """
        # BUG FIX: ``state_dict`` is a method, not a mapping — the original
        # ``self.state_dict[...]`` raised TypeError; it must be called.
        state = self.state_dict()
        self.quantizer.quant_scale = state[self.quant_scale_name]
        # BUG FIX: ``double_quant_scale_name`` only exists when double
        # quantization is enabled; the original accessed it unconditionally
        # and raised AttributeError for double_quant=False.
        if self.double_quant:
            self.quantizer.double_quant_scale = state[
                self.double_quant_scale_name]
        return self.quantizer.matmul(x, self.quant_weight)

paddleslim/lc/quantizers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .nf4 import NF4Quantizer
+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import paddle
2+
3+
4+
class BaseQuantizer():
    """Abstract interface for weight quantizers."""

    def quantize(self, x: paddle.Tensor):
        """Quantize tensor ``x``; concrete quantizers must override."""
        raise NotImplementedError()

    def dequantize(self, x: paddle.Tensor):
        """Dequantize tensor ``x``; concrete quantizers must override."""
        raise NotImplementedError()

    def matmul(self, x: paddle.Tensor, y: paddle.Tensor,
               bias: paddle.Tensor=None):
        """Compute ``x @ dequantize(y)`` plus optional ``bias``.

        BUG FIX: ``bias`` is now optional — callers such as
        ``NF4Linear.forward`` invoke ``matmul(x, weight)`` with no bias, so a
        required parameter made the declared interface uncallable there.
        """
        raise NotImplementedError()

paddleslim/lc/quantizers/nf4.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import paddle
2+
from .base_quantizer import BaseQuantizer
3+
4+
5+
class NF4Quantizer(BaseQuantizer):
    """Prototype NF4 (4-bit NormalFloat) quantizer.

    ``quantize``/``dequantize`` are identity placeholders for now; ``matmul``
    multiplies against the dequantized weight and adds an optional bias.
    """
    dtype = "int4"

    def __init__(self, block_size=64, double_quant=False):
        # BUG FIX: the original called ``super(BaseQuantizer, self)``, which
        # skips BaseQuantizer in the MRO; use zero-arg super() instead.
        super().__init__()
        self.block_size = block_size
        self.double_quant = double_quant
        # Scales populated by quantize(); read back by NF4Linear.forward.
        self.quant_scale = None
        self.double_quant_scale = None

    def quantize(self, x: paddle.Tensor):
        # Placeholder: real NF4 packing not implemented yet.
        return x

    def dequantize(self, x: paddle.Tensor):
        # Placeholder: real NF4 unpacking not implemented yet.
        return x

    def matmul(self, x: paddle.Tensor, y: paddle.Tensor,
               bias: paddle.Tensor=None):
        # BUG FIX: NF4Linear.forward calls matmul(x, weight) with no bias,
        # so a required ``bias`` raised TypeError; it is now optional and
        # skipped when absent (adding None would also fail).
        out = x @ self.dequantize(y)
        return out if bias is None else out + bias

0 commit comments

Comments
 (0)