Add masked_fill under Tensor #2346

Open · wants to merge 2 commits into base: main
7 changes: 7 additions & 0 deletions candle-core/src/tensor.rs
@@ -2454,6 +2454,13 @@ impl Tensor {
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.broadcast_mul(&self.log()?)?.exp()
}

pub fn masked_fill(&self, rhs: &Tensor, value: f32) -> Result<Self> {
rhs.where_cond(
&Tensor::new(value, self.device())?.broadcast_as(rhs.shape().dims())?,
self,
)
}
}

macro_rules! bin_trait {
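A minimal usage sketch of the new method, assuming the crate is consumed as `candle_core` and run on the CPU (the `scores`, `mask`, and `filled` names are illustrative, not part of the change). As with `where_cond`, the mask uses an integer dtype and its non-zero entries select the fill value:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Values to partially overwrite.
    let scores = Tensor::ones((2, 3), DType::F32, &device)?;
    // Mask with the same shape; non-zero entries mark the positions to fill.
    let mask = Tensor::new(&[[0u8, 1, 0], [1, 0, 0]], &device)?;
    // Positions where `mask` is non-zero become NEG_INFINITY; the rest keep `scores`.
    let filled = scores.masked_fill(&mask, f32::NEG_INFINITY)?;
    println!("{filled}");
    Ok(())
}
```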
18 changes: 18 additions & 0 deletions candle-core/tests/tensor_tests.rs
@@ -1345,3 +1345,21 @@ fn pow() -> Result<()> {
);
Ok(())
}

#[test]
fn masked_fill() -> Result<()> {
let lhs = Tensor::zeros((5, 5), DType::F32, &Device::Cpu)?;
let rhs = Tensor::eye(5, DType::I64, &Device::Cpu)?;
let res = lhs.masked_fill(&rhs, f32::NEG_INFINITY)?;
assert_eq!(
res.to_vec2::<f32>()?,
[
[f32::NEG_INFINITY, 0.0, 0.0, 0.0, 0.0],
[0.0, f32::NEG_INFINITY, 0.0, 0.0, 0.0],
[0.0, 0.0, f32::NEG_INFINITY, 0.0, 0.0],
[0.0, 0.0, 0.0, f32::NEG_INFINITY, 0.0],
[0.0, 0.0, 0.0, 0.0, f32::NEG_INFINITY],
]
);
Ok(())
}
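The test covers a mask with the same shape as the input; the call sites updated below instead broadcast a smaller causal mask up to the score shape first. A rough sketch of that pattern with made-up shapes, using `broadcast_left` as the phi/mixformer changes do:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // (batch * heads, seq, seq) attention scores, all zeros here for brevity.
    let scores = Tensor::zeros((4, 3, 3), DType::F32, &device)?;
    // (seq, seq) causal mask: non-zero strictly above the diagonal.
    let mask = Tensor::new(&[[0u8, 1, 1], [0, 0, 1], [0, 0, 0]], &device)?;
    // Broadcast the mask over the leading dimension, then fill the masked positions.
    let masked = scores.masked_fill(&mask.broadcast_left(4)?, f32::NEG_INFINITY)?;
    assert_eq!(masked.dims(), &[4, 3, 3]);
    Ok(())
}
```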
10 changes: 1 addition & 9 deletions candle-transformers/src/models/chatglm.rs
@@ -107,13 +107,6 @@ struct CoreAttention {
norm_factor: f64,
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}

impl CoreAttention {
fn new(layer_number: usize, cfg: &Config) -> Result<Self> {
let norm_factor = (cfg.kv_channels as f64).sqrt();
@@ -152,8 +145,7 @@ impl CoreAttention {
Some(coeff) => (matmul_result * coeff)?,
};
let attention_scores = match attention_mask {
Some(mask) => masked_fill(
&matmul_result,
Some(mask) => matmul_result.masked_fill(
&mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
f32::NEG_INFINITY,
)?,
11 changes: 3 additions & 8 deletions candle-transformers/src/models/distilbert.rs
@@ -5,13 +5,6 @@ use serde::Deserialize;

pub const DTYPE: DType = DType::F32;

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum HiddenAct {
@@ -180,7 +173,9 @@ impl MultiHeadSelfAttention {
let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
let mask = attention_mask.broadcast_as(scores.shape())?;

let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
let scores = scores
.to_dtype(DType::F32)?
.masked_fill(&mask, f32::NEG_INFINITY)?;
let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;

let context = weights.matmul(&v.contiguous()?)?;
13 changes: 3 additions & 10 deletions candle-transformers/src/models/falcon.rs
@@ -177,15 +177,6 @@ impl FalconRotaryEmbedding {
}
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?
.to_dtype(on_false.dtype())?
.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}

#[derive(Debug, Clone)]
struct FalconAttention {
query_key_value: Linear,
@@ -298,7 +289,9 @@ impl FalconAttention {
let attention_scores = match mask {
None => attention_scores,
Some(mask) => {
let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?
let mask = mask
.to_dtype(DType::F32)?
.masked_fill(mask, -1e9)?
.to_dtype(query.dtype())?;
attention_scores.broadcast_add(&mask.squeeze(1)?)?
}
9 changes: 1 addition & 8 deletions candle-transformers/src/models/llama.rs
@@ -255,7 +255,7 @@ impl CausalSelfAttention {
att
} else {
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
masked_fill(&att, &mask, f32::NEG_INFINITY)?
att.masked_fill(&mask, f32::NEG_INFINITY)?
};
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
@@ -295,13 +295,6 @@ impl CausalSelfAttention {
}
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}

#[derive(Debug, Clone)]
struct Mlp {
c_fc1: Linear,
9 changes: 1 addition & 8 deletions candle-transformers/src/models/llama2_c.rs
@@ -198,7 +198,7 @@ impl CausalSelfAttention {
att
} else {
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
masked_fill(&att, &mask, f32::NEG_INFINITY)?
att.masked_fill(&mask, f32::NEG_INFINITY)?
};
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
@@ -242,13 +242,6 @@ impl CausalSelfAttention {
}
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}

#[derive(Debug, Clone)]
struct Mlp {
c_fc1: Linear,
14 changes: 2 additions & 12 deletions candle-transformers/src/models/mpt.rs
@@ -118,11 +118,8 @@ impl GroupedQueryAttention {
let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
let attn_weights = match mask {
None => attn_weights,
Some(mask) => masked_fill(
&attn_weights,
&mask.broadcast_as(attn_weights.shape())?,
f32::NEG_INFINITY,
)?,
Some(mask) => attn_weights
.masked_fill(&mask.broadcast_as(attn_weights.shape())?, f32::NEG_INFINITY)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights
@@ -281,10 +278,3 @@ pub(crate) fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
.collect();
Tensor::from_slice(&mask, (size, size), device)
}

pub(crate) fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
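For context, the `get_mask` helper kept in this file builds a strictly upper-triangular mask; a hypothetical sketch of pairing it with the new `Tensor::masked_fill` (the `apply_causal_mask` name, the shapes, and the F32 `scores` assumption are illustrative):

```rust
use candle_core::{Device, Result, Tensor};

// Hypothetical helper: mark future positions (j > i) and push them to -inf,
// mirroring how `get_mask` above is combined with `masked_fill` at the call site.
fn apply_causal_mask(scores: &Tensor, seq_len: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
        .collect();
    let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
    // Broadcast the (seq, seq) mask up to the score shape, e.g. (batch, heads, seq, seq);
    // `scores` is assumed to be F32 so the -inf fill value matches its dtype.
    scores.masked_fill(&mask.broadcast_as(scores.shape())?, f32::NEG_INFINITY)
}
```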
10 changes: 1 addition & 9 deletions candle-transformers/src/models/phi.rs
@@ -126,13 +126,6 @@ fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
Tensor::from_slice(&mask, (size, size), device)
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}

impl Attention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let num_heads = cfg.num_attention_heads;
@@ -233,8 +226,7 @@ impl Attention {
* self.softmax_scale)?;
let attn_weights = match mask {
None => attn_weights,
Some(mask) => masked_fill(
&attn_weights,
Some(mask) => attn_weights.masked_fill(
&mask.broadcast_left((b_size, self.num_heads))?,
f32::NEG_INFINITY,
)?,
13 changes: 1 addition & 12 deletions candle-transformers/src/models/quantized_llama.rs
@@ -138,19 +138,12 @@ struct LayerWeights {
head_dim: usize,
cos: Tensor,
sin: Tensor,
neg_inf: Tensor,
kv_cache: Option<(Tensor, Tensor)>,
span_attn: tracing::Span,
span_rot: tracing::Span,
span_mlp: tracing::Span,
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
let shape = mask.shape();
let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
Ok(m)
}

impl LayerWeights {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter();
@@ -214,7 +207,7 @@ impl LayerWeights {
None => att,
Some(mask) => {
let mask = mask.broadcast_as(att.shape())?;
masked_fill(&att, &mask, &self.neg_inf)?
att.masked_fill(&mask, f32::NEG_INFINITY)?
}
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
@@ -260,7 +253,6 @@ impl ModelWeights {
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?;
@@ -300,7 +292,6 @@ impl ModelWeights {
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
cos: cos.clone(),
sin: sin.clone(),
neg_inf: neg_inf.clone(),
kv_cache: None,
span_attn,
span_rot,
@@ -349,7 +340,6 @@ impl ModelWeights {
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;

let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
@@ -420,7 +410,6 @@ impl ModelWeights {
head_dim: embedding_length / head_count,
cos: cos.clone(),
sin: sin.clone(),
neg_inf: neg_inf.clone(),
kv_cache: None,
span_attn,
span_rot,
9 changes: 1 addition & 8 deletions candle-transformers/src/models/quantized_llama2_c.rs
@@ -75,7 +75,7 @@ impl CausalSelfAttention {
att
} else {
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
masked_fill(&att, &mask, f32::NEG_INFINITY)?
att.masked_fill(&mask, f32::NEG_INFINITY)?
};
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
@@ -119,13 +119,6 @@ impl CausalSelfAttention {
}
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}

#[derive(Debug, Clone)]
struct Mlp {
c_fc1: Linear,
10 changes: 1 addition & 9 deletions candle-transformers/src/models/quantized_mixformer.rs
@@ -32,13 +32,6 @@ fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
Tensor::from_slice(&mask, (size, size), device)
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
@@ -219,8 +212,7 @@ impl MHA {
// scores = scores + causal_mask.to(dtype=scores.dtype)
let attn_weights = match mask {
None => attn_weights,
Some(mask) => masked_fill(
&attn_weights,
Some(mask) => attn_weights.masked_fill(
&mask.broadcast_left(b_size * self.n_head)?,
f32::NEG_INFINITY,
)?,
7 changes: 2 additions & 5 deletions candle-transformers/src/models/quantized_mpt.rs
@@ -85,11 +85,8 @@ impl GroupedQueryAttention {
let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
let attn_weights = match mask {
None => attn_weights,
Some(mask) => super::mpt::masked_fill(
&attn_weights,
&mask.broadcast_as(attn_weights.shape())?,
f32::NEG_INFINITY,
)?,
Some(mask) => attn_weights
.masked_fill(&mask.broadcast_as(attn_weights.shape())?, f32::NEG_INFINITY)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights
11 changes: 1 addition & 10 deletions candle-transformers/src/models/quantized_phi.rs
@@ -61,18 +61,11 @@ struct LayerWeights {
cos: Tensor,
sin: Tensor,
rope_dim: usize,
neg_inf: Tensor,
kv_cache: Option<(Tensor, Tensor)>,
span_attn: tracing::Span,
span_rot: tracing::Span,
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
let shape = mask.shape();
let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
Ok(m)
}

impl LayerWeights {
fn apply_rotary_emb(&self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter();
@@ -131,7 +124,7 @@ impl LayerWeights {
None => att,
Some(mask) => {
let mask = mask.broadcast_as(att.shape())?;
masked_fill(&att, &mask, &self.neg_inf)?
att.masked_fill(&mask, f32::NEG_INFINITY)?
}
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
@@ -199,7 +192,6 @@ impl ModelWeights {
let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize;
let ln_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64;
let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;

let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
@@ -233,7 +225,6 @@ impl ModelWeights {
cos: cos.clone(),
sin: sin.clone(),
rope_dim,
neg_inf: neg_inf.clone(),
kv_cache: None,
span_attn,
span_rot,