Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 118 additions & 24 deletions cspuz_rs/src/solver/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::traits::{ArrayShape, Item, Operand, PropagateBinary, PropagateTernary
use crate::items::Arrow;
use crate::solver::traits::BoolArrayLike;
use crate::solver::traits::IntArrayLike;
use crate::solver::BoolExprArray1D;
use std::ops::{Bound, Not, RangeBounds};

use cspuz_core::csp::BoolExpr as CSPBoolExpr;
Expand Down Expand Up @@ -599,46 +600,67 @@ where
Self: Operand<Shape = (usize, usize), Value = CSPBoolExpr>,
{
pub fn conv2d_and(&self, filter: (usize, usize)) -> NdArray<(usize, usize), CSPBoolExpr> {
self.conv2d_impl(filter, CSPBoolExpr::And)
conv2d_impl(self, filter, |parts| {
CSPBoolExpr::And(parts.into_iter().map(Box::new).collect())
})
}

pub fn conv2d_or(&self, filter: (usize, usize)) -> NdArray<(usize, usize), CSPBoolExpr> {
self.conv2d_impl(filter, CSPBoolExpr::Or)
conv2d_impl(self, filter, |parts| {
CSPBoolExpr::Or(parts.into_iter().map(Box::new).collect())
})
}

fn conv2d_impl<F>(&self, filter: (usize, usize), op: F) -> NdArray<(usize, usize), CSPBoolExpr>
where
F: Fn(Vec<Box<CSPBoolExpr>>) -> CSPBoolExpr,
{
let orig = self.as_ndarray();
let (h, w) = orig.shape;
let (fh, fw) = filter;
assert!(h >= fh);
assert!(w >= fw);
pub fn conv2d_count_true(&self, filter: (usize, usize)) -> NdArray<(usize, usize), CSPIntExpr> {
conv2d_impl(self, filter, |parts| {
let array = BoolExprArray1D::from_raw(parts);
array.count_true().data.0
})
}
}

let mut data = vec![];
for y in 0..=(h - fh) {
for x in 0..=(w - fw) {
let mut part = vec![];
for dy in 0..fh {
for dx in 0..fw {
part.push(Box::new(orig.data[(y + dy) * w + (x + dx)].clone()));
}
fn conv2d_impl<A, I, O, F>(
array: &NdArray<(usize, usize), A>,
filter: (usize, usize),
op: F,
) -> NdArray<(usize, usize), O>
where
A: Clone,
I: Clone,
O: Clone,
NdArray<(usize, usize), A>: Operand<Shape = (usize, usize), Value = I>,
F: Fn(Vec<I>) -> O,
{
let orig = array.as_ndarray();
let (h, w) = orig.shape;
let (fh, fw) = filter;
assert!(h >= fh);
assert!(w >= fw);

let mut data = vec![];
for y in 0..=(h - fh) {
for x in 0..=(w - fw) {
let mut part = vec![];
for dy in 0..fh {
for dx in 0..fw {
part.push(orig.data[(y + dy) * w + (x + dx)].clone());
}
data.push(op(part));
}
data.push(op(part));
}
}

NdArray {
shape: (h - fh + 1, w - fw + 1),
data,
}
NdArray {
shape: (h - fh + 1, w - fw + 1),
data,
}
}

#[cfg(test)]
mod tests {
use super::super::Solver;
use cspuz_core::csp::BoolExpr as CSPBoolExpr;
use cspuz_core::csp::IntExpr as CSPIntExpr;

#[test]
fn test_ndarray_add_0d_0d() {
Expand Down Expand Up @@ -938,4 +960,76 @@ mod tests {
assert_eq!(model.get(b), -3);
}
}

#[test]
fn test_ndarray_conv2d_and() {
let mut solver = Solver::new();
let a = &solver.bool_var_2d((4, 5));
let b = a.conv2d_and((2, 2));

assert_eq!(b.shape(), (3, 4));
for y in 0..3 {
for x in 0..4 {
let expected = CSPBoolExpr::And(vec![
Box::new(a.at((y, x)).data.0.expr()),
Box::new(a.at((y, x + 1)).data.0.expr()),
Box::new(a.at((y + 1, x)).data.0.expr()),
Box::new(a.at((y + 1, x + 1)).data.0.expr()),
]);
assert_eq!(&expected, &b.at((y, x)).data.0);
}
}
}

#[test]
fn test_ndarray_conv2d_or() {
let mut solver = Solver::new();
let a = &solver.bool_var_2d((4, 5));
let b = a.conv2d_or((2, 2));

assert_eq!(b.shape(), (3, 4));
for y in 0..3 {
for x in 0..4 {
let expected = CSPBoolExpr::Or(vec![
Box::new(a.at((y, x)).data.0.expr()),
Box::new(a.at((y, x + 1)).data.0.expr()),
Box::new(a.at((y + 1, x)).data.0.expr()),
Box::new(a.at((y + 1, x + 1)).data.0.expr()),
]);
assert_eq!(&expected, &b.at((y, x)).data.0);
}
}
}

#[test]
fn test_ndarray_conv2d_count_true() {
let mut solver = Solver::new();
let a = &solver.bool_var_2d((4, 5));
let b = a.conv2d_count_true((2, 2));

assert_eq!(b.shape(), (3, 4));
for y in 0..3 {
for x in 0..4 {
let expected = {
let mut terms = vec![];
for dy in 0..2 {
for dx in 0..2 {
terms.push((
Box::new(
a.at((y + dy, x + dx))
.data
.0
.expr()
.ite(CSPIntExpr::Const(1), CSPIntExpr::Const(0)),
),
1,
));
}
}
CSPIntExpr::Linear(terms)
};
assert_eq!(&expected, &b.at((y, x)).data.0);
}
}
}
}