
Commit 403680f

Narsil and Nicolas Patry authored
Quantized GGUF style (#1523)
* Metal quantized modifications proposal. - Add a device param, wherever needed. - Create new QMetal storage thing that implements QuantizedType. - Update everywhere needed. Fix Python. Fixing examples. Fix: fmt + clippy + stub. Moving everything around. Only missing the actual implems. Fixing everything + adding dequantized kernels. More work. Fixing matmul. Fmt + Clippy Some clippy fixes. Working state. Q2K Metal -> Bugged (also present in GGML). Q4K CPU -> Bugged (present previously, new test catch it). Q5K CPU -> Bugged (present previously). Q8_1 Both -> Never really implemented it seems Q8K metal -> Never implemented in metal Fixing Q2K bug (present in ggml). * Cleanup. * Fix the rebase. * Removing the fences speeds everything up and *is* correct this time... * Cleanup the fence. * After rebase. * Bad code removal. * Rebase after phi2 merge + fix replit default to CPU. * Making the CI happy. * More happy tests. --------- Co-authored-by: Nicolas Patry <[email protected]>
1 parent 5270224 commit 403680f
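
The heart of the change: quantization formats move from compile-time type parameters (k_quants::BlockQ6K and friends) to a runtime GgmlDType value, and the quantized loaders take an explicit Device. Below is a minimal sketch of the new call shape, using only signatures that appear in the diff that follows; the tensor shape and zero-fill are illustrative, not taken from the commit.

use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // 256 columns so each row divides evenly into Q6K quantization blocks.
    let weight = Tensor::zeros((4, 256), DType::F32, &device)?;

    // Before this commit: QTensor::quantize::<k_quants::BlockQ6K>(&weight)
    // After: the target format is a plain runtime value.
    let q = QTensor::quantize(&weight, GgmlDType::Q6K)?;

    // Per-format block sizes also come from the dtype now, replacing the
    // hand-written match over k_quants::QK_* constants.
    let _block_size = GgmlDType::Q6K.block_size();

    // Dequantization takes the target device explicitly.
    let back = q.dequantize(&device)?;
    println!("{:?}", back.shape());
    Ok(())
}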

File tree

31 files changed: +6447 -516 lines changed

candle-core/examples/tensor-tools.rs

Lines changed: 52 additions & 70 deletions
@@ -1,5 +1,5 @@
-use candle_core::quantized::{gguf_file, k_quants, QTensor};
-use candle_core::{Device, Result, Tensor};
+use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
+use candle_core::{Device, Result};
 use clap::{Parser, Subcommand, ValueEnum};
 use rayon::prelude::*;

@@ -11,22 +11,17 @@ enum QuantizationMode {
 }

 impl QuantizationMode {
-    fn quantize(
-        &self,
-        name: &str,
-        tensor: QTensor,
-        default: fn(&Tensor) -> Result<QTensor>,
-    ) -> Result<QTensor> {
+    fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
         match self {
             Self::Llama => {
                 // Same behavior as the llama.cpp quantization.
                 let should_quantize = name.ends_with(".weight") && tensor.rank() == 2;
                 if should_quantize {
                     let tensor = tensor.dequantize(&Device::Cpu)?;
                     if name == "output.weight" {
-                        QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
+                        QTensor::quantize(&tensor, GgmlDType::Q6K)
                     } else {
-                        default(&tensor)
+                        QTensor::quantize(&tensor, dtype)
                     }
                 } else {
                     Ok(tensor)
@@ -60,6 +55,27 @@ enum Quantization {
     F32,
 }

+impl Quantization {
+    fn dtype(&self) -> GgmlDType {
+        match self {
+            Quantization::Q4_0 => GgmlDType::Q4_0,
+            Quantization::Q4_1 => GgmlDType::Q4_1,
+            Quantization::Q5_0 => GgmlDType::Q5_0,
+            Quantization::Q5_1 => GgmlDType::Q5_1,
+            Quantization::Q8_0 => GgmlDType::Q8_0,
+            Quantization::Q8_1 => GgmlDType::Q8_1,
+            Quantization::Q2k => GgmlDType::Q2K,
+            Quantization::Q3k => GgmlDType::Q3K,
+            Quantization::Q4k => GgmlDType::Q4K,
+            Quantization::Q5k => GgmlDType::Q5K,
+            Quantization::Q6k => GgmlDType::Q6K,
+            Quantization::Q8k => GgmlDType::Q8K,
+            Quantization::F16 => GgmlDType::F16,
+            Quantization::F32 => GgmlDType::F32,
+        }
+    }
+}
+
 #[derive(ValueEnum, Debug, Clone)]
 enum Format {
     Safetensors,
@@ -134,7 +150,12 @@ struct Args {
     command: Command,
 }

-fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
+fn run_ls(
+    file: &std::path::PathBuf,
+    format: Option<Format>,
+    verbose: bool,
+    device: &Device,
+) -> Result<()> {
     let format = match format {
         Some(format) => format,
         None => match Format::infer(file) {
@@ -200,7 +221,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
         }
         Format::Ggml => {
             let mut file = std::fs::File::open(file)?;
-            let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
+            let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
             let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
             tensors.sort_by(|a, b| a.0.cmp(&b.0));
             for (name, qtensor) in tensors.iter() {
@@ -241,47 +262,18 @@ fn run_quantize_safetensors(
     }
     println!("tensors: {}", tensors.len());

-    let quantize_fn = match q {
-        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
-        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
-        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
-        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
-        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
-        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
-        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
-        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
-        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
-        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
-        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
-        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
-        Quantization::F16 => QTensor::quantize::<half::f16>,
-        Quantization::F32 => QTensor::quantize::<f32>,
-    };
-    let block_size = match q {
-        Quantization::Q4_0 => k_quants::QK4_0,
-        Quantization::Q4_1 => k_quants::QK4_1,
-        Quantization::Q5_0 => k_quants::QK5_0,
-        Quantization::Q5_1 => k_quants::QK5_1,
-        Quantization::Q8_0 => k_quants::QK8_0,
-        Quantization::Q8_1 => k_quants::QK8_1,
-        Quantization::Q2k
-        | Quantization::Q3k
-        | Quantization::Q4k
-        | Quantization::Q5k
-        | Quantization::Q6k
-        | Quantization::Q8k => k_quants::QK_K,
-        Quantization::F16 | Quantization::F32 => 1,
-    };
+    let dtype = q.dtype();
+    let block_size = dtype.block_size();

     let qtensors = tensors
         .into_par_iter()
         .map(|(name, tensor)| {
             let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
             println!("  quantizing {name} {tensor:?} {should_quantize}");
             let tensor = if should_quantize {
-                quantize_fn(&tensor)?
+                QTensor::quantize(&tensor, dtype)?
             } else {
-                QTensor::quantize::<f32>(&tensor)?
+                QTensor::quantize(&tensor, GgmlDType::F32)?
             };
             Ok((name, tensor))
         })
@@ -294,13 +286,17 @@ fn run_quantize_safetensors(
     Ok(())
 }

-fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> {
+fn run_dequantize(
+    in_file: std::path::PathBuf,
+    out_file: std::path::PathBuf,
+    device: &Device,
+) -> Result<()> {
     let mut in_file = std::fs::File::open(in_file)?;
     let content = gguf_file::Content::read(&mut in_file)?;
     let mut tensors = std::collections::HashMap::new();
     for (tensor_name, _) in content.tensor_infos.iter() {
-        let tensor = content.tensor(&mut in_file, tensor_name)?;
-        let tensor = tensor.dequantize(&Device::Cpu)?;
+        let tensor = content.tensor(&mut in_file, tensor_name, device)?;
+        let tensor = tensor.dequantize(device)?;
         tensors.insert(tensor_name.to_string(), tensor);
     }
     candle_core::safetensors::save(&tensors, out_file)?;
@@ -312,6 +308,7 @@ fn run_quantize(
     out_file: std::path::PathBuf,
     q: Quantization,
     qmode: QuantizationMode,
+    device: &Device,
 ) -> Result<()> {
     if in_files.is_empty() {
         candle_core::bail!("no specified input files")
@@ -337,31 +334,15 @@ fn run_quantize(
     let content = gguf_file::Content::read(&mut in_)?;
     println!("tensors: {}", content.tensor_infos.len());

-    let quantize_fn = match q {
-        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
-        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
-        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
-        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
-        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
-        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
-        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
-        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
-        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
-        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
-        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
-        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
-        Quantization::F16 => QTensor::quantize::<half::f16>,
-        Quantization::F32 => QTensor::quantize::<f32>,
-    };
-
+    let dtype = q.dtype();
     let qtensors = content
         .tensor_infos
         .par_iter()
         .map(|(name, _)| {
             println!("  quantizing {name}");
             let mut in_file = std::fs::File::open(&in_files[0])?;
-            let tensor = content.tensor(&mut in_file, name)?;
-            let tensor = qmode.quantize(name, tensor, quantize_fn)?;
+            let tensor = content.tensor(&mut in_file, name, device)?;
+            let tensor = qmode.quantize(name, tensor, dtype)?;
             Ok((name, tensor))
         })
         .collect::<Result<Vec<_>>>()?;
@@ -381,6 +362,7 @@ fn run_quantize(

 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
+    let device = Device::Cpu;
     match args.command {
         Command::Ls {
             files,
@@ -392,16 +374,16 @@ fn main() -> anyhow::Result<()> {
                 if multiple_files {
                     println!("--- {file:?} ---");
                 }
-                run_ls(file, format.clone(), verbose)?
+                run_ls(file, format.clone(), verbose, &device)?
             }
         }
         Command::Quantize {
             in_file,
             out_file,
             quantization,
             mode,
-        } => run_quantize(&in_file, out_file, quantization, mode)?,
-        Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?,
+        } => run_quantize(&in_file, out_file, quantization, mode, &device)?,
+        Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
     }
     Ok(())
 }
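
The same device-threading shows up in the GGUF read path above: gguf_file::Content::read still parses header and tensor metadata without a device, but materializing and dequantizing a tensor now lands on the device you pass in. A minimal sketch of that flow, mirroring run_dequantize; the function name, path handling, and printing are illustrative, while the candle_core calls are the ones visible in the diff.

use candle_core::quantized::gguf_file;
use candle_core::{Device, Result};

fn list_dequantized(path: &std::path::Path, device: &Device) -> Result<()> {
    let mut file = std::fs::File::open(path)?;
    // Reading the metadata needs no device yet.
    let content = gguf_file::Content::read(&mut file)?;
    for (name, _) in content.tensor_infos.iter() {
        // Materializing the quantized tensor takes the target device...
        let qtensor = content.tensor(&mut file, name, device)?;
        // ...and dequantize produces a plain tensor on that same device.
        let tensor = qtensor.dequantize(device)?;
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}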
