Skip to content

Commit 3b2648c

Browse files
committed
Fix #38 - support rust_decimal.
1 parent fb899e3 commit 3b2648c

File tree

7 files changed

+102
-12
lines changed

7 files changed

+102
-12
lines changed

Cargo.toml

+4-3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ arrayvec = { version = "0.7", default-features = false, optional = true }
1818
bitcode_derive = { version = "0.6.3", path = "./bitcode_derive", optional = true }
1919
bytemuck = { version = "1.14", features = [ "min_const_generics", "must_cast" ] }
2020
glam = { version = ">=0.21", default-features = false, optional = true }
21+
rust_decimal = { version = "1.36.0", default-features = false, optional = true }
2122
serde = { version = "1.0", default-features = false, features = [ "alloc" ], optional = true }
2223

2324
[dev-dependencies]
@@ -37,8 +38,8 @@ zstd = "0.13.0"
3738

3839
[features]
3940
derive = [ "dep:bitcode_derive" ]
40-
std = [ "serde?/std", "glam?/std", "arrayvec?/std" ]
41-
default = [ "derive", "std" ]
41+
std = [ "serde?/std", "glam?/std", "arrayvec?/std", "rust_decimal?/std" ]
42+
default = [ "derive", "std", "rust_decimal" ]
4243

4344
[package.metadata.docs.rs]
4445
features = [ "derive", "serde", "std" ]
@@ -48,4 +49,4 @@ features = [ "derive", "serde", "std" ]
4849
#lto = true
4950

5051
[lints.rust]
51-
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fuzzing)'] }
52+
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fuzzing)'] }

fuzz/Cargo.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ cargo-fuzz = true
1010

1111
[dependencies]
1212
arrayvec = { version = "0.7", features = ["serde"] }
13-
bitcode = { path = "..", features = [ "arrayvec", "serde" ] }
13+
bitcode = { path = "..", features = [ "arrayvec", "serde", "rust_decimal" ] }
1414
libfuzzer-sys = "0.4"
15+
rust_decimal = "1.36.0"
1516
serde = { version ="1.0", features = [ "derive" ] }
1617

1718
# Prevent this from interfering with workspaces
@@ -22,4 +23,4 @@ members = ["."]
2223
name = "fuzz"
2324
path = "fuzz_targets/fuzz.rs"
2425
test = false
25-
doc = false
26+
doc = false

fuzz/fuzz_targets/fuzz.rs

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use std::collections::{BTreeMap, HashMap};
99
use std::fmt::Debug;
1010
use std::num::NonZeroU32;
1111
use std::time::Duration;
12+
use rust_decimal::Decimal;
1213
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddr, SocketAddrV6};
1314

1415
#[inline(never)]
@@ -209,6 +210,7 @@ fuzz_target!(|data: &[u8]| {
209210
ArrayString<70>,
210211
ArrayVec<u8, 5>,
211212
ArrayVec<u8, 70>,
213+
Decimal,
212214
Duration,
213215
Ipv4Addr,
214216
Ipv6Addr,

src/derive/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use alloc::vec::Vec;
55
use core::num::NonZeroUsize;
66

77
mod array;
8-
mod convert;
8+
pub(crate) mod convert;
99
mod duration;
1010
mod empty;
1111
mod impls;

src/ext/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ mod arrayvec;
33
#[cfg(feature = "glam")]
44
#[rustfmt::skip] // Makes impl_struct! calls way longer.
55
mod glam;
6+
#[cfg(feature = "rust_decimal")]
7+
mod rust_decimal;
68

79
#[allow(unused)]
810
macro_rules! impl_struct {

src/ext/rust_decimal.rs

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
use crate::{
2+
convert::{self, ConvertFrom},
3+
Decode, Encode,
4+
};
5+
use bytemuck::CheckedBitPattern;
6+
use rust_decimal::Decimal;
7+
8+
type DecimalConversion = (u32, u32, u32, Flags);
9+
10+
impl ConvertFrom<&Decimal> for DecimalConversion {
11+
fn convert_from(value: &Decimal) -> Self {
12+
let unpacked = value.unpack();
13+
(
14+
unpacked.lo,
15+
unpacked.mid,
16+
unpacked.hi,
17+
Flags::new(unpacked.scale, unpacked.negative),
18+
)
19+
}
20+
}
21+
22+
impl ConvertFrom<DecimalConversion> for Decimal {
23+
fn convert_from(value: DecimalConversion) -> Self {
24+
Self::from_parts(
25+
value.0,
26+
value.1,
27+
value.2,
28+
value.3.negative(),
29+
value.3.scale(),
30+
)
31+
}
32+
}
33+
34+
impl Encode for Decimal {
35+
type Encoder = convert::ConvertIntoEncoder<DecimalConversion>;
36+
}
37+
impl<'a> Decode<'a> for Decimal {
38+
type Decoder = convert::ConvertFromDecoder<'a, DecimalConversion>;
39+
}
40+
41+
impl ConvertFrom<&Flags> for u8 {
42+
fn convert_from(flags: &Flags) -> Self {
43+
flags.0
44+
}
45+
}
46+
47+
impl Encode for Flags {
48+
type Encoder = convert::ConvertIntoEncoder<u8>;
49+
}
50+
51+
/// A u8 guaranteed to satisfy (flags >> 1) <= 28. Prevents Decimal::from_parts from misbehaving.
52+
#[derive(Copy, Clone)]
53+
#[repr(transparent)]
54+
pub struct Flags(u8);
55+
56+
impl Flags {
57+
#[inline(always)]
58+
fn new(scale: u32, negative: bool) -> Self {
59+
Self((scale as u8) << 1 | negative as u8)
60+
}
61+
62+
#[inline(always)]
63+
fn scale(&self) -> u32 {
64+
(self.0 >> 1) as u32
65+
}
66+
67+
#[inline(always)]
68+
fn negative(&self) -> bool {
69+
self.0 & 1 == 1
70+
}
71+
}
72+
73+
// Safety: u8 and Flags have the same layout since Flags is #[repr(transparent)].
74+
unsafe impl CheckedBitPattern for Flags {
75+
type Bits = u8;
76+
#[inline(always)]
77+
fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
78+
(*bits >> 1) <= 28
79+
}
80+
}
81+
82+
impl<'a> Decode<'a> for Flags {
83+
type Decoder = crate::int::CheckedIntDecoder<'a, Flags, u8>;
84+
}

src/serde/de.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -119,19 +119,19 @@ macro_rules! specify {
119119
#[cold]
120120
fn cold<'de>(decoder: &mut SerdeDecoder<'de>, input: &mut &'de [u8]) -> Result<()> {
121121
let &mut SerdeDecoder::Unspecified { length } = decoder else {
122-
type_changed!()
123-
};
122+
type_changed!()
123+
};
124124
*decoder = SerdeDecoder::$variant(Default::default());
125125
decoder.populate(input, length)
126126
}
127127
cold(&mut *$self.decoder, &mut *$self.input)?;
128128
}
129129
}
130130
let SerdeDecoder::$variant(d) = &mut *$self.decoder else {
131-
// Safety: `cold` gets called when decoder isn't the correct decoder. `cold` either
132-
// errors or sets lazy to the correct decoder.
133-
unsafe { core::hint::unreachable_unchecked() };
134-
};
131+
// Safety: `cold` gets called when decoder isn't the correct decoder. `cold` either
132+
// errors or sets lazy to the correct decoder.
133+
unsafe { core::hint::unreachable_unchecked() };
134+
};
135135
d
136136
}};
137137
}

0 commit comments

Comments
 (0)