@@ -1,5 +1,5 @@
-use candle_core::quantized::{gguf_file, k_quants, QTensor};
-use candle_core::{Device, Result, Tensor};
+use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
+use candle_core::{Device, Result};
 use clap::{Parser, Subcommand, ValueEnum};
 use rayon::prelude::*;
@@ -11,22 +11,17 @@ enum QuantizationMode {
 }
 
 impl QuantizationMode {
-    fn quantize(
-        &self,
-        name: &str,
-        tensor: QTensor,
-        default: fn(&Tensor) -> Result<QTensor>,
-    ) -> Result<QTensor> {
+    fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
         match self {
             Self::Llama => {
                 // Same behavior as the llama.cpp quantization.
                 let should_quantize = name.ends_with(".weight") && tensor.rank() == 2;
                 if should_quantize {
                     let tensor = tensor.dequantize(&Device::Cpu)?;
                     if name == "output.weight" {
-                        QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
+                        QTensor::quantize(&tensor, GgmlDType::Q6K)
                     } else {
-                        default(&tensor)
+                        QTensor::quantize(&tensor, dtype)
                     }
                 } else {
                     Ok(tensor)
@@ -60,6 +55,27 @@ enum Quantization {
     F32,
 }
 
+impl Quantization {
+    fn dtype(&self) -> GgmlDType {
+        match self {
+            Quantization::Q4_0 => GgmlDType::Q4_0,
+            Quantization::Q4_1 => GgmlDType::Q4_1,
+            Quantization::Q5_0 => GgmlDType::Q5_0,
+            Quantization::Q5_1 => GgmlDType::Q5_1,
+            Quantization::Q8_0 => GgmlDType::Q8_0,
+            Quantization::Q8_1 => GgmlDType::Q8_1,
+            Quantization::Q2k => GgmlDType::Q2K,
+            Quantization::Q3k => GgmlDType::Q3K,
+            Quantization::Q4k => GgmlDType::Q4K,
+            Quantization::Q5k => GgmlDType::Q5K,
+            Quantization::Q6k => GgmlDType::Q6K,
+            Quantization::Q8k => GgmlDType::Q8K,
+            Quantization::F16 => GgmlDType::F16,
+            Quantization::F32 => GgmlDType::F32,
+        }
+    }
+}
+
 #[derive(ValueEnum, Debug, Clone)]
 enum Format {
     Safetensors,
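
As an aside, the effect of this mapping is that every call site can resolve the CLI choice to a GgmlDType once and hand it to the single QTensor::quantize entry point. A minimal sketch of that call, assuming only the APIs visible in this diff plus a placeholder tensor shape chosen to divide the k-quant block size:

// Sketch (not part of the patch): quantize a host tensor with a dtype
// resolved from the CLI enum, e.g. what Quantization::Q4k.dtype() returns.
use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{DType, Device, Result, Tensor};

fn quantize_sketch() -> Result<QTensor> {
    let device = Device::Cpu;
    // 256 columns so the last dimension is a multiple of the Q4K block size.
    let tensor = Tensor::zeros((64, 256), DType::F32, &device)?;
    QTensor::quantize(&tensor, GgmlDType::Q4K)
}
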
@@ -134,7 +150,12 @@ struct Args {
     command: Command,
 }
 
-fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
+fn run_ls(
+    file: &std::path::PathBuf,
+    format: Option<Format>,
+    verbose: bool,
+    device: &Device,
+) -> Result<()> {
     let format = match format {
         Some(format) => format,
         None => match Format::infer(file) {
@@ -200,7 +221,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
         }
         Format::Ggml => {
             let mut file = std::fs::File::open(file)?;
-            let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
+            let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
             let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
             tensors.sort_by(|a, b| a.0.cmp(&b.0));
             for (name, qtensor) in tensors.iter() {
@@ -241,47 +262,18 @@ fn run_quantize_safetensors(
     }
     println!("tensors: {}", tensors.len());
 
-    let quantize_fn = match q {
-        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
-        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
-        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
-        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
-        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
-        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
-        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
-        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
-        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
-        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
-        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
-        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
-        Quantization::F16 => QTensor::quantize::<half::f16>,
-        Quantization::F32 => QTensor::quantize::<f32>,
-    };
-    let block_size = match q {
-        Quantization::Q4_0 => k_quants::QK4_0,
-        Quantization::Q4_1 => k_quants::QK4_1,
-        Quantization::Q5_0 => k_quants::QK5_0,
-        Quantization::Q5_1 => k_quants::QK5_1,
-        Quantization::Q8_0 => k_quants::QK8_0,
-        Quantization::Q8_1 => k_quants::QK8_1,
-        Quantization::Q2k
-        | Quantization::Q3k
-        | Quantization::Q4k
-        | Quantization::Q5k
-        | Quantization::Q6k
-        | Quantization::Q8k => k_quants::QK_K,
-        Quantization::F16 | Quantization::F32 => 1,
-    };
+    let dtype = q.dtype();
+    let block_size = dtype.block_size();
 
     let qtensors = tensors
         .into_par_iter()
         .map(|(name, tensor)| {
             let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
             println!("  quantizing {name} {tensor:?} {should_quantize}");
             let tensor = if should_quantize {
-                quantize_fn(&tensor)?
+                QTensor::quantize(&tensor, dtype)?
             } else {
-                QTensor::quantize::<f32>(&tensor)?
+                QTensor::quantize(&tensor, GgmlDType::F32)?
             };
             Ok((name, tensor))
         })
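
The guard above generalizes to any GgmlDType: only rank-2 tensors whose second dimension is a multiple of dtype.block_size() are quantized, and everything else falls back to F32 storage. An illustrative helper (the name quantize_or_passthrough is hypothetical; the calls are the ones used in the hunk above):

// Illustrative only: mirrors the should_quantize logic of run_quantize_safetensors.
use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{Result, Tensor};

fn quantize_or_passthrough(tensor: &Tensor, dtype: GgmlDType) -> Result<QTensor> {
    let block_size = dtype.block_size();
    let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
    if should_quantize {
        QTensor::quantize(tensor, dtype)
    } else {
        QTensor::quantize(tensor, GgmlDType::F32)
    }
}
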
@@ -294,13 +286,17 @@ fn run_quantize_safetensors(
     Ok(())
 }
 
-fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> {
+fn run_dequantize(
+    in_file: std::path::PathBuf,
+    out_file: std::path::PathBuf,
+    device: &Device,
+) -> Result<()> {
     let mut in_file = std::fs::File::open(in_file)?;
     let content = gguf_file::Content::read(&mut in_file)?;
     let mut tensors = std::collections::HashMap::new();
     for (tensor_name, _) in content.tensor_infos.iter() {
-        let tensor = content.tensor(&mut in_file, tensor_name)?;
-        let tensor = tensor.dequantize(&Device::Cpu)?;
+        let tensor = content.tensor(&mut in_file, tensor_name, device)?;
+        let tensor = tensor.dequantize(device)?;
         tensors.insert(tensor_name.to_string(), tensor);
     }
     candle_core::safetensors::save(&tensors, out_file)?;
@@ -312,6 +308,7 @@ fn run_quantize(
     out_file: std::path::PathBuf,
     q: Quantization,
     qmode: QuantizationMode,
+    device: &Device,
 ) -> Result<()> {
     if in_files.is_empty() {
         candle_core::bail!("no specified input files")
@@ -337,31 +334,15 @@ fn run_quantize(
     let content = gguf_file::Content::read(&mut in_)?;
     println!("tensors: {}", content.tensor_infos.len());
 
-    let quantize_fn = match q {
-        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
-        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
-        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
-        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
-        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
-        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
-        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
-        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
-        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
-        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
-        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
-        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
-        Quantization::F16 => QTensor::quantize::<half::f16>,
-        Quantization::F32 => QTensor::quantize::<f32>,
-    };
-
+    let dtype = q.dtype();
     let qtensors = content
         .tensor_infos
         .par_iter()
         .map(|(name, _)| {
             println!("  quantizing {name}");
             let mut in_file = std::fs::File::open(&in_files[0])?;
-            let tensor = content.tensor(&mut in_file, name)?;
-            let tensor = qmode.quantize(name, tensor, quantize_fn)?;
+            let tensor = content.tensor(&mut in_file, name, device)?;
+            let tensor = qmode.quantize(name, tensor, dtype)?;
             Ok((name, tensor))
         })
         .collect::<Result<Vec<_>>>()?;
@@ -381,6 +362,7 @@ fn run_quantize(
 
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
+    let device = Device::Cpu;
     match args.command {
         Command::Ls {
             files,
@@ -392,16 +374,16 @@ fn main() -> anyhow::Result<()> {
                 if multiple_files {
                     println!("--- {file:?} ---");
                 }
-                run_ls(file, format.clone(), verbose)?
+                run_ls(file, format.clone(), verbose, &device)?
             }
         }
         Command::Quantize {
             in_file,
             out_file,
             quantization,
             mode,
-        } => run_quantize(&in_file, out_file, quantization, mode)?,
-        Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?,
+        } => run_quantize(&in_file, out_file, quantization, mode, &device)?,
+        Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
     }
     Ok(())
 }
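
Taken together, the patch threads a Device from main through every GGML/GGUF read so quantized tensors can be materialized on that device. A minimal sketch of the new read path, assuming only the signatures exercised in this diff (the path model.gguf is a placeholder):

// Sketch (not part of the patch): load one tensor from a GGUF file and
// dequantize it on an explicitly chosen device.
use candle_core::quantized::gguf_file;
use candle_core::{Device, Result, Tensor};

fn read_one_tensor(name: &str) -> Result<Tensor> {
    let device = Device::Cpu; // could be a CUDA or Metal device instead
    let mut file = std::fs::File::open("model.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;
    let qtensor = content.tensor(&mut file, name, &device)?;
    qtensor.dequantize(&device)
}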