Skip to content

Commit 240c7ed

Browse files
authored
Merge pull request #145 from greyblake/generics-trait-boundaries
WIP: Start working on trait boundaries for generics
2 parents 34f571b + 35ff667 commit 240c7ed

File tree

6 files changed

+356
-76
lines changed

6 files changed

+356
-76
lines changed

README.md

-6
Original file line numberDiff line numberDiff line change
@@ -376,12 +376,6 @@ assert_eq!(name.into_inner(), " boo ");
376376
* IDEs may not be very helpful at giving you hints about proc macros.
377377
* Design of nutype may enforce you to run unnecessary validation (e.g. on loading data from DB), which may have a negative impact if you aim for extreme performance.
378378

379-
## A note about #[derive(...)]
380-
381-
You've got to know that the `#[nutype]` macro intercepts `#[derive(...)]` macro.
382-
It's done on purpose to ensure that anything like `DerefMut` or `BorrowMut`, that can lead to a violation of the validation rules is excluded.
383-
The library takes a conservative approach and it has its downside: deriving traits that are not known to the library is not possible.
384-
385379
## Support Ukrainian military forces
386380

387381
Today I live in Berlin, I have the luxury to live a physically safe life.

dummy/src/main.rs

+19-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
use nutype::nutype;
2-
use std::borrow::Cow;
2+
use std::cmp::Ord;
33

4-
#[nutype(derive(Into))]
5-
struct Clarabelle<'a>(Cow<'a, str>);
4+
#[nutype(
5+
sanitize(with = |mut v| { v.sort(); v }),
6+
validate(predicate = |vec| !vec.is_empty()),
7+
derive(Debug, Deserialize, Serialize),
8+
)]
9+
struct SortedNotEmptyVec<T: Ord>(Vec<T>);
610

711
fn main() {
8-
// let clarabelle = Clarabelle::new(Cow::Borrowed("Clarabelle"));
9-
// assert_eq!(clarabelle.to_string(), "Clarabelle");
12+
{
13+
// Not empty vec is fine
14+
let json = "[3, 1, 5, 2]";
15+
let sv = serde_json::from_str::<SortedNotEmptyVec<i32>>(json).unwrap();
16+
assert_eq!(sv.into_inner(), vec![1, 2, 3, 5]);
17+
}
18+
{
19+
// Empty vec is not allowed
20+
let json = "[]";
21+
let result = serde_json::from_str::<SortedNotEmptyVec<i32>>(json);
22+
assert!(result.is_err());
23+
}
1024
}

examples/any_generics/src/main.rs

+54-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,29 @@
11
use nutype::nutype;
22
use std::borrow::Cow;
3+
use std::cmp::Ord;
34

5+
/// A wrapper around a vector that is guaranteed to be sorted.
6+
#[nutype(
7+
sanitize(with = |mut v| { v.sort(); v }),
8+
derive(Debug, Deserialize, Serialize)
9+
)]
10+
struct SortedVec<T: Ord>(Vec<T>);
11+
12+
/// A wrapper around a vector that is guaranteed to be non-empty.
413
#[nutype(
514
validate(predicate = |vec| !vec.is_empty()),
615
derive(Debug),
716
)]
817
struct NotEmpty<T>(Vec<T>);
918

19+
#[nutype(
20+
sanitize(with = |mut v| { v.sort(); v }),
21+
validate(predicate = |vec| !vec.is_empty()),
22+
derive(Debug, Deserialize, Serialize),
23+
)]
24+
struct SortedNotEmptyVec<T: Ord>(Vec<T>);
25+
26+
/// An example with lifetimes
1027
#[nutype(derive(
1128
Debug,
1229
Display,
@@ -32,15 +49,50 @@ struct NotEmpty<T>(Vec<T>);
3249
struct Clarabelle<'a>(Cow<'a, str>);
3350

3451
fn main() {
52+
// SortedVec
53+
//
3554
{
36-
let v = NotEmpty::new(vec![1, 2, 3]).unwrap();
37-
assert_eq!(v.into_inner(), vec![1, 2, 3]);
55+
let v = SortedVec::new(vec![3, 0, 2]);
56+
assert_eq!(v.into_inner(), vec![0, 2, 3]);
57+
}
58+
{
59+
let sv = SortedVec::new(vec![4i32, 2, 8, 5]);
60+
let json = serde_json::to_string(&sv).unwrap();
61+
assert_eq!(json, "[2,4,5,8]");
3862
}
3963
{
64+
let json = "[5,3,7]";
65+
let sv = serde_json::from_str::<SortedVec<i32>>(json).unwrap();
66+
assert_eq!(sv.into_inner(), vec![3, 5, 7]);
67+
}
68+
69+
// NotEmpty
70+
//
71+
{
72+
let v = NotEmpty::new(vec![1, 2, 3]).unwrap();
73+
assert_eq!(v.into_inner(), vec![1, 2, 3]);
74+
4075
let err = NotEmpty::<i32>::new(vec![]).unwrap_err();
4176
assert_eq!(err, NotEmptyError::PredicateViolated);
4277
}
4378

79+
// SortedNotEmptyVec
80+
//
81+
{
82+
// Not empty vec is fine
83+
let json = "[3, 1, 5, 2]";
84+
let snev = serde_json::from_str::<SortedNotEmptyVec<i32>>(json).unwrap();
85+
assert_eq!(snev.into_inner(), vec![1, 2, 3, 5]);
86+
}
87+
{
88+
// Empty vec is not allowed
89+
let json = "[]";
90+
let result = serde_json::from_str::<SortedNotEmptyVec<i32>>(json);
91+
assert!(result.is_err());
92+
}
93+
94+
// Clarabelle (Cow)
95+
//
4496
{
4597
let muu = Clarabelle::new(Cow::Borrowed("Muu"));
4698
assert_eq!(muu.to_string(), "Muu");

nutype_macros/src/common/gen/mod.rs

+43-3
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,9 @@ pub fn gen_impl_into_inner(
138138
generics: &Generics,
139139
inner_type: impl ToTokens,
140140
) -> TokenStream {
141+
let generics_without_bounds = strip_trait_bounds_on_generics(generics);
141142
quote! {
142-
impl #generics #type_name #generics {
143+
impl #generics #type_name #generics_without_bounds {
143144
#[inline]
144145
pub fn into_inner(self) -> #inner_type {
145146
self.0
@@ -148,6 +149,43 @@ pub fn gen_impl_into_inner(
148149
}
149150
}
150151

152+
/// Remove trait bounds from generics.
153+
///
154+
/// Input:
155+
/// <T: Display + Debug, U: Clone>
156+
///
157+
/// Output:
158+
/// <T, U>
159+
fn strip_trait_bounds_on_generics(original: &Generics) -> Generics {
160+
let mut generics = original.clone();
161+
for param in &mut generics.params {
162+
if let syn::GenericParam::Type(syn::TypeParam { bounds, .. }) = param {
163+
*bounds = syn::punctuated::Punctuated::new();
164+
}
165+
}
166+
generics
167+
}
168+
169+
/// Add a bound to all generics types.
170+
///
171+
/// Input:
172+
/// <T, U>
173+
/// Serialize
174+
///
175+
/// Output:
176+
/// <T: Serialize, U: Serialize>
177+
fn add_bound_to_all_type_params(generics: &Generics, bound: TokenStream) -> Generics {
178+
let mut generics = generics.clone();
179+
let parsed_bound: syn::TypeParamBound =
180+
syn::parse2(bound).expect("Failed to parse TypeParamBound");
181+
for param in &mut generics.params {
182+
if let syn::GenericParam::Type(syn::TypeParam { bounds, .. }) = param {
183+
bounds.push(parsed_bound.clone());
184+
}
185+
}
186+
generics
187+
}
188+
151189
pub trait GenerateNewtype {
152190
type Sanitizer;
153191
type Validator;
@@ -197,6 +235,7 @@ pub trait GenerateNewtype {
197235
sanitizers: &[Self::Sanitizer],
198236
validators: &[Self::Validator],
199237
) -> TokenStream {
238+
let generics_without_bounds = strip_trait_bounds_on_generics(generics);
200239
let fn_sanitize = Self::gen_fn_sanitize(inner_type, sanitizers);
201240
let validation_error = Self::gen_validation_error_type(type_name, validators);
202241
let error_type_name = gen_error_type_name(type_name);
@@ -214,7 +253,7 @@ pub trait GenerateNewtype {
214253
quote!(
215254
#validation_error
216255

217-
impl #generics #type_name #generics {
256+
impl #generics #type_name #generics_without_bounds {
218257
pub fn new(raw_value: #input_type) -> ::core::result::Result<Self, #error_type_name> {
219258
#convert_raw_value_if_necessary
220259

@@ -237,6 +276,7 @@ pub trait GenerateNewtype {
237276
inner_type: &Self::InnerType,
238277
sanitizers: &[Self::Sanitizer],
239278
) -> TokenStream {
279+
let generics_without_bounds = strip_trait_bounds_on_generics(generics);
240280
let fn_sanitize = Self::gen_fn_sanitize(inner_type, sanitizers);
241281

242282
let (input_type, convert_raw_value_if_necessary) = if Self::NEW_CONVERT_INTO_INNER_TYPE {
@@ -249,7 +289,7 @@ pub trait GenerateNewtype {
249289
};
250290

251291
quote!(
252-
impl #generics #type_name #generics {
292+
impl #generics #type_name #generics_without_bounds {
253293
pub fn new(raw_value: #input_type) -> Self {
254294
#convert_raw_value_if_necessary
255295
Self(Self::__sanitize__(raw_value))

nutype_macros/src/common/gen/traits.rs

+25-7
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ use proc_macro2::TokenStream;
44
use quote::{quote, ToTokens};
55
use syn::Generics;
66

7-
use crate::common::models::{ErrorTypeName, InnerType, TypeName};
7+
use crate::common::{
8+
gen::{add_bound_to_all_type_params, strip_trait_bounds_on_generics},
9+
models::{ErrorTypeName, InnerType, TypeName},
10+
};
811

912
use super::parse_error::{gen_def_parse_error, gen_parse_error_name};
1013

@@ -106,8 +109,11 @@ pub fn gen_impl_trait_deref(
106109
}
107110

108111
pub fn gen_impl_trait_display(type_name: &TypeName, generics: &Generics) -> TokenStream {
112+
let generics_without_bounds = strip_trait_bounds_on_generics(generics);
113+
let generics_with_display_bound =
114+
add_bound_to_all_type_params(generics, syn::parse_quote!(::core::fmt::Display));
109115
quote! {
110-
impl #generics ::core::fmt::Display for #type_name #generics {
116+
impl #generics_with_display_bound ::core::fmt::Display for #type_name #generics_without_bounds {
111117
#[inline]
112118
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
113119
// A tiny wrapper function with trait boundary that improves error reporting.
@@ -239,9 +245,15 @@ pub fn gen_impl_trait_from_str(
239245
}
240246

241247
pub fn gen_impl_trait_serde_serialize(type_name: &TypeName, generics: &Generics) -> TokenStream {
248+
let generics_without_bounds = strip_trait_bounds_on_generics(generics);
249+
250+
// Turn `<T>` into `<T: Serialize>`
251+
let all_generics_with_serialize_bound =
252+
add_bound_to_all_type_params(generics, syn::parse_quote!(::serde::Serialize));
253+
242254
let type_name_str = type_name.to_string();
243255
quote! {
244-
impl #generics ::serde::Serialize for #type_name #generics {
256+
impl #all_generics_with_serialize_bound ::serde::Serialize for #type_name #generics_without_bounds {
245257
fn serialize<S>(&self, serializer: S) -> ::core::result::Result<S::Ok, S::Error>
246258
where
247259
S: ::serde::Serializer
@@ -283,17 +295,23 @@ pub fn gen_impl_trait_serde_deserialize(
283295
all_generics.params.push(syn::parse_quote!('de));
284296
all_generics
285297
};
298+
let all_generics_without_bounds = strip_trait_bounds_on_generics(&all_generics);
299+
let type_generics_without_bounds = strip_trait_bounds_on_generics(type_generics);
300+
301+
// Turn `<'de, T>` into `<'de, T: Deserialize<'de>>`
302+
let all_generics_with_deserialize_bound =
303+
add_bound_to_all_type_params(&all_generics, syn::parse_quote!(::serde::Deserialize<'de>));
286304

287305
quote! {
288-
impl #all_generics ::serde::Deserialize<'de> for #type_name #type_generics {
306+
impl #all_generics_with_deserialize_bound ::serde::Deserialize<'de> for #type_name #type_generics_without_bounds {
289307
fn deserialize<D: ::serde::Deserializer<'de>>(deserializer: D) -> ::core::result::Result<Self, D::Error> {
290308
struct __Visitor #all_generics {
291-
marker: ::std::marker::PhantomData<#type_name #type_generics>,
309+
marker: ::std::marker::PhantomData<#type_name #type_generics_without_bounds>,
292310
lifetime: ::std::marker::PhantomData<&'de ()>,
293311
}
294312

295-
impl #all_generics ::serde::de::Visitor<'de> for __Visitor #all_generics {
296-
type Value = #type_name #type_generics;
313+
impl #all_generics_with_deserialize_bound ::serde::de::Visitor<'de> for __Visitor #all_generics_without_bounds {
314+
type Value = #type_name #type_generics_without_bounds;
297315

298316
fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
299317
write!(formatter, #expecting_str)

0 commit comments

Comments
 (0)