forked from EricLBuehler/candle-lora
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathblip.rs
127 lines (109 loc) · 4.28 KB
/
blip.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use candle_lora::LoraConfig;
use clap::Parser;
use candle_core::{DType, Device, Result, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_lora_transformers::{blip, varbuilder_utils::from_mmaped_safetensors};
use tokenizers::Tokenizer;
// TODO: Maybe add support for the conditional prompt.
/// Command-line options of the BLIP image-captioning example.
#[derive(Parser)]
struct Args {
/// Local model weights file; fetched from the Hugging Face hub when omitted.
#[arg(long)]
model: Option<String>,
/// Local tokenizer file; `tokenizer.json` is fetched from the Hugging Face hub when omitted.
#[arg(long)]
tokenizer: Option<String>,
/// Path of the image to caption.
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Use the quantized version of the model.
#[arg(long)]
quantized: bool,
}
/// Id of the separator token: the generation loop in `main` stops as soon as this id is sampled.
const SEP_TOKEN_ID: u32 = 102;
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 384, 384). OpenAI normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle_core::Error::wrap)?
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
let mean =
Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
.reshape((3, 1, 1))?;
(data.to_dtype(candle_core::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
/// Captions the image given on the command line with a LoRA-enabled BLIP model and streams the
/// generated text to stdout.
///
/// Steps:
/// 1. Resolve the weights and tokenizer files (local paths, or downloads from the Hugging Face
///    hub when the corresponding option is omitted).
/// 2. Load and normalize the image, then run the vision model once to get the image embeddings.
/// 3. Sample tokens from the text decoder until `SEP_TOKEN_ID` is produced or the token budget
///    is exhausted, printing each decoded piece as soon as it is available.
///
/// # Errors
///
/// Returns an error when `--quantized` is requested without `--model` (the hub checkpoint for
/// that mode is a GGUF file, which the safetensors loader used here cannot read), when a
/// download, file load or tensor operation fails, or when writing to stdout fails.
pub fn main() -> anyhow::Result<()> {
    use std::io::Write;

    // First token fed to the text decoder — presumably the tokenizer's BOS id; confirm against
    // the model's tokenizer config.
    const START_TOKEN_ID: u32 = 30522;
    // Upper bound on the number of sampled tokens.
    const MAX_NEW_TOKENS: usize = 1000;

    let args = Args::parse();
    let model_file = match args.model {
        None if args.quantized => {
            // The weights below are always read through `from_mmaped_safetensors`, so
            // downloading the GGUF checkpoint could only end in an opaque parse failure.
            anyhow::bail!(
                "--quantized is not supported by this example: weights are loaded as \
                 safetensors and the quantized checkpoint is a GGUF file; omit --quantized \
                 or pass --model <file.safetensors>"
            )
        }
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.repo(hf_hub::Repo::with_revision(
                "Salesforce/blip-image-captioning-large".to_string(),
                hf_hub::RepoType::Model,
                "refs/pr/18".to_string(),
            ));
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };
    let tokenizer = match args.tokenizer {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("Salesforce/blip-image-captioning-large".to_string());
            api.get("tokenizer.json")?
        }
        Some(file) => file.into(),
    };
    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
    let mut tokenizer = TokenOutputStream::new(tokenizer);

    // Fixed seed, no temperature and no top-p passed to the sampler.
    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    let config = blip::Config::image_captioning_large();
    let device = candle_examples::device(args.cpu)?;

    let image = load_image(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let vb = from_mmaped_safetensors(&[model_file], DType::F32, &device, false)?;
    // NOTE(review): the arguments look like (rank = 1, alpha = 1.0, no dropout) — confirm
    // against `LoraConfig::new`.
    let loraconfig = LoraConfig::new(1, 1., None);
    let mut model = blip::BlipForConditionalGeneration::new(&config, vb, true, loraconfig)?;

    // The image is encoded once; every decoding step reuses these embeddings.
    let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;

    let mut token_ids = vec![START_TOKEN_ID];
    for index in 0..MAX_NEW_TOKENS {
        // The first step feeds the whole prompt, later steps only the newest token
        // (presumably the text decoder caches earlier positions — confirm).
        let context_size = if index > 0 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;

        let logits = model.text_decoder().forward(&input_ids, &image_embeds)?;
        let logits = logits.squeeze(0)?;
        // Keep only the logits of the last position.
        let logits = logits.get(logits.dim(0)? - 1)?;

        let token = logits_processor.sample(&logits)?;
        if token == SEP_TOKEN_ID {
            break;
        }
        token_ids.push(token);

        if let Some(t) = tokenizer.next_token(token)? {
            print!("{t}");
            // Flush so the caption appears incrementally rather than at the next newline.
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    // Terminate the caption line so it is flushed and the shell prompt starts on a new line.
    println!();
    Ok(())
}