diff --git a/.github/workflows/ci_linux.yaml b/.github/workflows/ci_linux.yaml index 116de865..660d6933 100644 --- a/.github/workflows/ci_linux.yaml +++ b/.github/workflows/ci_linux.yaml @@ -50,5 +50,5 @@ jobs: run: cargo test --features single_threaded_async - name: Build docs - run: cargo doc + run: cargo doc --all-features diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 12fa2bb3..c8627251 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -16,3 +16,4 @@ proc-macro = true [dependencies] syn = "2.0" quote = "1.0" +proc-macro2 = "1.0.93" diff --git a/macros/src/buffer.rs b/macros/src/buffer.rs new file mode 100644 index 00000000..b231da09 --- /dev/null +++ b/macros/src/buffer.rs @@ -0,0 +1,493 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{ + parse_quote, Field, Generics, Ident, ImplGenerics, ItemStruct, Type, TypeGenerics, TypePath, + Visibility, WhereClause, +}; + +use crate::Result; + +const JOINED_ATTR_TAG: &'static str = "joined"; +const KEY_ATTR_TAG: &'static str = "key"; + +pub(crate) fn impl_joined_value(input_struct: &ItemStruct) -> Result { + let struct_ident = &input_struct.ident; + let (impl_generics, ty_generics, where_clause) = input_struct.generics.split_for_impl(); + let StructConfig { + buffer_struct_name: buffer_struct_ident, + } = StructConfig::from_data_struct(&input_struct, &JOINED_ATTR_TAG); + let buffer_struct_vis = &input_struct.vis; + + let (field_ident, _, field_config) = + get_fields_map(&input_struct.fields, FieldSettings::for_joined())?; + let buffer: Vec<&Type> = field_config.iter().map(|config| &config.buffer).collect(); + let noncopy = field_config.iter().any(|config| config.noncopy); + + let buffer_struct: ItemStruct = generate_buffer_struct( + &buffer_struct_ident, + buffer_struct_vis, + &impl_generics, + &where_clause, + &field_ident, + &buffer, + ); + + let impl_buffer_clone = impl_buffer_clone( + &buffer_struct_ident, + &impl_generics, + &ty_generics, + &where_clause, + &field_ident, + noncopy, + ); + + let impl_select_buffers = impl_select_buffers( + struct_ident, + &buffer_struct_ident, + buffer_struct_vis, + &impl_generics, + &ty_generics, + &where_clause, + &field_ident, + &buffer, + ); + + let impl_buffer_map_layout = + impl_buffer_map_layout(&buffer_struct, &field_ident, &field_config)?; + let impl_joined = impl_joined(&buffer_struct, &input_struct, &field_ident)?; + + let gen = quote! { + impl #impl_generics ::bevy_impulse::Joined for #struct_ident #ty_generics #where_clause { + type Buffers = #buffer_struct_ident #ty_generics; + } + + #buffer_struct + + #impl_buffer_clone + + #impl_select_buffers + + #impl_buffer_map_layout + + #impl_joined + }; + + Ok(gen.into()) +} + +pub(crate) fn impl_buffer_key_map(input_struct: &ItemStruct) -> Result { + let struct_ident = &input_struct.ident; + let (impl_generics, ty_generics, where_clause) = input_struct.generics.split_for_impl(); + let StructConfig { + buffer_struct_name: buffer_struct_ident, + } = StructConfig::from_data_struct(&input_struct, &KEY_ATTR_TAG); + let buffer_struct_vis = &input_struct.vis; + + let (field_ident, field_type, field_config) = + get_fields_map(&input_struct.fields, FieldSettings::for_key())?; + let buffer: Vec<&Type> = field_config.iter().map(|config| &config.buffer).collect(); + let noncopy = field_config.iter().any(|config| config.noncopy); + + let buffer_struct: ItemStruct = generate_buffer_struct( + &buffer_struct_ident, + buffer_struct_vis, + &impl_generics, + &where_clause, + &field_ident, + &buffer, + ); + + let impl_buffer_clone = impl_buffer_clone( + &buffer_struct_ident, + &impl_generics, + &ty_generics, + &where_clause, + &field_ident, + noncopy, + ); + + let impl_select_buffers = impl_select_buffers( + struct_ident, + &buffer_struct_ident, + buffer_struct_vis, + &impl_generics, + &ty_generics, + &where_clause, + &field_ident, + &buffer, + ); + + let impl_buffer_map_layout = + impl_buffer_map_layout(&buffer_struct, &field_ident, &field_config)?; + let impl_accessed = impl_accessed(&buffer_struct, &input_struct, &field_ident, &field_type)?; + + let gen = quote! { + impl #impl_generics ::bevy_impulse::Accessor for #struct_ident #ty_generics #where_clause { + type Buffers = #buffer_struct_ident #ty_generics; + } + + #buffer_struct + + #impl_buffer_clone + + #impl_select_buffers + + #impl_buffer_map_layout + + #impl_accessed + }; + + Ok(gen.into()) +} + +/// Code that are currently unused but could be used in the future, move them out of this mod if +/// they are ever used. +#[allow(unused)] +mod _unused { + use super::*; + + /// Converts a list of generics to a [`PhantomData`] TypePath. + /// e.g. `::std::marker::PhantomData` + fn to_phantom_data(generics: &Generics) -> TypePath { + let lifetimes: Vec = generics + .lifetimes() + .map(|lt| { + let lt = <.lifetime; + let ty: Type = parse_quote! { & #lt () }; + ty + }) + .collect(); + let ty_params: Vec<&Ident> = generics.type_params().map(|ty| &ty.ident).collect(); + parse_quote! { ::std::marker::PhantomData } + } +} + +struct StructConfig { + buffer_struct_name: Ident, +} + +impl StructConfig { + fn from_data_struct(data_struct: &ItemStruct, attr_tag: &str) -> Self { + let mut config = Self { + buffer_struct_name: format_ident!("__bevy_impulse_{}_Buffers", data_struct.ident), + }; + + let attr = data_struct + .attrs + .iter() + .find(|attr| attr.path().is_ident(attr_tag)); + + if let Some(attr) = attr { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("buffers_struct_name") { + config.buffer_struct_name = meta.value()?.parse()?; + } + Ok(()) + }) + // panic if attribute is malformed, this will result in a compile error which is intended. + .unwrap(); + } + + config + } +} + +struct FieldSettings { + default_buffer: fn(&Type) -> Type, + attr_tag: &'static str, +} + +impl FieldSettings { + fn for_joined() -> Self { + Self { + default_buffer: Self::default_field_for_joined, + attr_tag: JOINED_ATTR_TAG, + } + } + + fn for_key() -> Self { + Self { + default_buffer: Self::default_field_for_key, + attr_tag: KEY_ATTR_TAG, + } + } + + fn default_field_for_joined(ty: &Type) -> Type { + parse_quote! { ::bevy_impulse::Buffer<#ty> } + } + + fn default_field_for_key(ty: &Type) -> Type { + parse_quote! { <#ty as ::bevy_impulse::BufferKeyLifecycle>::TargetBuffer } + } +} + +struct FieldConfig { + buffer: Type, + noncopy: bool, +} + +impl FieldConfig { + fn from_field(field: &Field, settings: &FieldSettings) -> Self { + let ty = &field.ty; + let mut config = Self { + buffer: (settings.default_buffer)(ty), + noncopy: false, + }; + + for attr in field + .attrs + .iter() + .filter(|attr| attr.path().is_ident(settings.attr_tag)) + { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("buffer") { + config.buffer = meta.value()?.parse()?; + } + if meta.path.is_ident("noncopy_buffer") { + config.noncopy = true; + } + Ok(()) + }) + // panic if attribute is malformed, this will result in a compile error which is intended. + .unwrap(); + } + + config + } +} + +fn get_fields_map( + fields: &syn::Fields, + settings: FieldSettings, +) -> Result<(Vec<&Ident>, Vec<&Type>, Vec)> { + match fields { + syn::Fields::Named(data) => { + let mut idents = Vec::new(); + let mut types = Vec::new(); + let mut configs = Vec::new(); + for field in &data.named { + let ident = field + .ident + .as_ref() + .ok_or("expected named fields".to_string())?; + idents.push(ident); + types.push(&field.ty); + configs.push(FieldConfig::from_field(field, &settings)); + } + Ok((idents, types, configs)) + } + _ => return Err("expected named fields".to_string()), + } +} + +fn generate_buffer_struct( + buffer_struct_ident: &Ident, + buffer_struct_vis: &Visibility, + impl_generics: &ImplGenerics, + where_clause: &Option<&WhereClause>, + field_ident: &Vec<&Ident>, + buffer: &Vec<&Type>, +) -> ItemStruct { + parse_quote! { + #[allow(non_camel_case_types, unused)] + #buffer_struct_vis struct #buffer_struct_ident #impl_generics #where_clause { + #( + #buffer_struct_vis #field_ident: #buffer, + )* + } + } +} + +fn impl_select_buffers( + struct_ident: &Ident, + buffer_struct_ident: &Ident, + buffer_struct_vis: &Visibility, + impl_generics: &ImplGenerics, + ty_generics: &TypeGenerics, + where_clause: &Option<&WhereClause>, + field_ident: &Vec<&Ident>, + buffer: &Vec<&Type>, +) -> TokenStream { + quote! { + impl #impl_generics #struct_ident #ty_generics #where_clause { + #buffer_struct_vis fn select_buffers( + #( + #field_ident: #buffer, + )* + ) -> #buffer_struct_ident #ty_generics { + #buffer_struct_ident { + #( + #field_ident, + )* + } + } + } + } + .into() +} + +fn impl_buffer_clone( + buffer_struct_ident: &Ident, + impl_generics: &ImplGenerics, + ty_generics: &TypeGenerics, + where_clause: &Option<&WhereClause>, + field_ident: &Vec<&Ident>, + noncopy: bool, +) -> TokenStream { + if noncopy { + // Clone impl for structs with a buffer that is not copyable + quote! { + impl #impl_generics ::std::clone::Clone for #buffer_struct_ident #ty_generics #where_clause { + fn clone(&self) -> Self { + Self { + #( + #field_ident: self.#field_ident.clone(), + )* + } + } + } + } + } else { + // Clone and copy impl for structs with buffers that are all copyable + quote! { + impl #impl_generics ::std::clone::Clone for #buffer_struct_ident #ty_generics #where_clause { + fn clone(&self) -> Self { + *self + } + } + + impl #impl_generics ::std::marker::Copy for #buffer_struct_ident #ty_generics #where_clause {} + } + } +} + +/// Params: +/// buffer_struct: The struct to implement `BufferMapLayout`. +/// item_struct: The struct which `buffer_struct` is derived from. +/// settings: [`FieldSettings`] to use when parsing the field attributes +fn impl_buffer_map_layout( + buffer_struct: &ItemStruct, + field_ident: &Vec<&Ident>, + field_config: &Vec, +) -> Result { + let struct_ident = &buffer_struct.ident; + let (impl_generics, ty_generics, where_clause) = buffer_struct.generics.split_for_impl(); + let buffer: Vec<&Type> = field_config.iter().map(|config| &config.buffer).collect(); + let map_key: Vec = field_ident.iter().map(|v| v.to_string()).collect(); + + Ok(quote! { + impl #impl_generics ::bevy_impulse::BufferMapLayout for #struct_ident #ty_generics #where_clause { + fn try_from_buffer_map(buffers: &::bevy_impulse::BufferMap) -> Result { + let mut compatibility = ::bevy_impulse::IncompatibleLayout::default(); + #( + let #field_ident = compatibility.require_buffer_for_identifier::<#buffer>(#map_key, buffers); + )* + + // Unwrap the Ok after inspecting every field so that the + // IncompatibleLayout error can include all information about + // which fields were incompatible. + #( + let Ok(#field_ident) = #field_ident else { + return Err(compatibility); + }; + )* + + Ok(Self { + #( + #field_ident, + )* + }) + } + } + + impl #impl_generics ::bevy_impulse::BufferMapStruct for #struct_ident #ty_generics #where_clause { + fn buffer_list(&self) -> ::smallvec::SmallVec<[AnyBuffer; 8]> { + use smallvec::smallvec; + smallvec![#( + ::bevy_impulse::AsAnyBuffer::as_any_buffer(&self.#field_ident), + )*] + } + } + } + .into()) +} + +/// Params: +/// joined_struct: The struct to implement `Joining`. +/// item_struct: The associated `Item` type to use for the `Joining` implementation. +fn impl_joined( + joined_struct: &ItemStruct, + item_struct: &ItemStruct, + field_ident: &Vec<&Ident>, +) -> Result { + let struct_ident = &joined_struct.ident; + let item_struct_ident = &item_struct.ident; + let (impl_generics, ty_generics, where_clause) = item_struct.generics.split_for_impl(); + + Ok(quote! { + impl #impl_generics ::bevy_impulse::Joining for #struct_ident #ty_generics #where_clause { + type Item = #item_struct_ident #ty_generics; + + fn pull(&self, session: ::bevy_impulse::re_exports::Entity, world: &mut ::bevy_impulse::re_exports::World) -> Result { + #( + let #field_ident = self.#field_ident.pull(session, world)?; + )* + + Ok(Self::Item {#( + #field_ident, + )*}) + } + } + }.into()) +} + +fn impl_accessed( + accessed_struct: &ItemStruct, + key_struct: &ItemStruct, + field_ident: &Vec<&Ident>, + field_type: &Vec<&Type>, +) -> Result { + let struct_ident = &accessed_struct.ident; + let key_struct_ident = &key_struct.ident; + let (impl_generics, ty_generics, where_clause) = key_struct.generics.split_for_impl(); + + Ok(quote! { + impl #impl_generics ::bevy_impulse::Accessing for #struct_ident #ty_generics #where_clause { + type Key = #key_struct_ident #ty_generics; + + fn add_accessor( + &self, + accessor: ::bevy_impulse::re_exports::Entity, + world: &mut ::bevy_impulse::re_exports::World, + ) -> ::bevy_impulse::OperationResult { + #( + ::bevy_impulse::Accessing::add_accessor(&self.#field_ident, accessor, world)?; + )* + Ok(()) + } + + fn create_key(&self, builder: &::bevy_impulse::BufferKeyBuilder) -> Self::Key { + Self::Key {#( + // TODO(@mxgrey): This currently does not have good support for the user + // substituting in a different key type than what the BufferKeyLifecycle expects. + // We could consider adding a .clone().into() to help support that use case, but + // this would be such a niche use case that I think we can ignore it for now. + #field_ident: <#field_type as ::bevy_impulse::BufferKeyLifecycle>::create_key(&self.#field_ident, builder), + )*} + } + + fn deep_clone_key(key: &Self::Key) -> Self::Key { + Self::Key {#( + #field_ident: ::bevy_impulse::BufferKeyLifecycle::deep_clone(&key.#field_ident), + )*} + } + + fn is_key_in_use(key: &Self::Key) -> bool { + false + #( + || ::bevy_impulse::BufferKeyLifecycle::is_in_use(&key.#field_ident) + )* + } + } + }.into()) +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index d40c9309..58873049 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -15,9 +15,12 @@ * */ +mod buffer; +use buffer::{impl_buffer_key_map, impl_joined_value}; + use proc_macro::TokenStream; use quote::quote; -use syn::DeriveInput; +use syn::{parse_macro_input, DeriveInput, ItemStruct}; #[proc_macro_derive(Stream)] pub fn simple_stream_macro(item: TokenStream) -> TokenStream { @@ -58,3 +61,30 @@ pub fn delivery_label_macro(item: TokenStream) -> TokenStream { } .into() } + +/// The result error is the compiler error message to be displayed. +type Result = std::result::Result; + +#[proc_macro_derive(Joined, attributes(joined))] +pub fn derive_joined_value(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as ItemStruct); + match impl_joined_value(&input) { + Ok(tokens) => tokens.into(), + Err(msg) => quote! { + compile_error!(#msg); + } + .into(), + } +} + +#[proc_macro_derive(Accessor, attributes(key))] +pub fn derive_buffer_key_map(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as ItemStruct); + match impl_buffer_key_map(&input) { + Ok(tokens) => tokens.into(), + Err(msg) => quote! { + compile_error!(#msg); + } + .into(), + } +} diff --git a/src/buffer.rs b/src/buffer.rs index bb0fe2b8..1d33c28d 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -17,28 +17,37 @@ use bevy_ecs::{ change_detection::Mut, - prelude::{Commands, Entity, Query}, + prelude::{Commands, Entity, Query, World}, query::QueryEntityError, - system::SystemParam, + system::{SystemParam, SystemState}, }; use std::{ops::RangeBounds, sync::Arc}; +use thiserror::Error as ThisError; + use crate::{ Builder, Chain, Gate, GateState, InputSlot, NotifyBufferUpdate, OnNewBufferValue, UnusedTarget, }; +mod any_buffer; +pub use any_buffer::*; + mod buffer_access_lifecycle; +pub use buffer_access_lifecycle::BufferKeyLifecycle; pub(crate) use buffer_access_lifecycle::*; mod buffer_key_builder; -pub(crate) use buffer_key_builder::*; +pub use buffer_key_builder::*; + +mod buffer_map; +pub use buffer_map::*; mod buffer_storage; pub(crate) use buffer_storage::*; -mod buffered; -pub use buffered::*; +mod buffering; +pub use buffering::*; mod bufferable; pub use bufferable::*; @@ -46,12 +55,16 @@ pub use bufferable::*; mod manage_buffer; pub use manage_buffer::*; +#[cfg(feature = "diagram")] +mod json_buffer; +#[cfg(feature = "diagram")] +pub use json_buffer::*; + /// A buffer is a special type of node within a workflow that is able to store /// and release data. When a session is finished, the buffered data from the /// session will be automatically cleared. pub struct Buffer { - pub(crate) scope: Entity, - pub(crate) source: Entity, + pub(crate) location: BufferLocation, pub(crate) _ignore: std::marker::PhantomData, } @@ -61,11 +74,11 @@ impl Buffer { &self, builder: &'b mut Builder<'w, 's, 'a>, ) -> Chain<'w, 's, 'a, 'b, ()> { - assert_eq!(self.scope, builder.scope); + assert_eq!(self.scope(), builder.scope); let target = builder.commands.spawn(UnusedTarget).id(); builder .commands - .add(OnNewBufferValue::new(self.source, target)); + .add(OnNewBufferValue::new(self.id(), target)); Chain::new(target, builder) } @@ -77,24 +90,86 @@ impl Buffer { T: Clone, { CloneFromBuffer { - scope: self.scope, - source: self.source, + location: self.location, _ignore: Default::default(), } } /// Get an input slot for this buffer. pub fn input_slot(self) -> InputSlot { - InputSlot::new(self.scope, self.source) + InputSlot::new(self.scope(), self.id()) + } + + /// Get the entity ID of the buffer. + pub fn id(&self) -> Entity { + self.location.source + } + + /// Get the ID of the workflow that the buffer is associated with. + pub fn scope(&self) -> Entity { + self.location.scope + } + + /// Get general information about the buffer. + pub fn location(&self) -> BufferLocation { + self.location + } +} + +impl Clone for Buffer { + fn clone(&self) -> Self { + *self } } +impl Copy for Buffer {} + +/// The general identifying information for a buffer to locate it within the +/// world. This does not indicate anything about the type of messages that the +/// buffer can contain. +#[derive(Clone, Copy, Debug)] +pub struct BufferLocation { + /// The entity ID of the buffer. + pub scope: Entity, + /// The ID of the workflow that the buffer is associated with. + pub source: Entity, +} + +#[derive(Clone)] pub struct CloneFromBuffer { - pub(crate) scope: Entity, - pub(crate) source: Entity, + pub(crate) location: BufferLocation, pub(crate) _ignore: std::marker::PhantomData, } +// +impl Copy for CloneFromBuffer {} + +impl CloneFromBuffer { + /// Get the entity ID of the buffer. + pub fn id(&self) -> Entity { + self.location.source + } + + /// Get the ID of the workflow that the buffer is associated with. + pub fn scope(&self) -> Entity { + self.location.scope + } + + /// Get general information about the buffer. + pub fn location(&self) -> BufferLocation { + self.location + } +} + +impl From> for Buffer { + fn from(value: CloneFromBuffer) -> Self { + Buffer { + location: value.location, + _ignore: Default::default(), + } + } +} + /// Settings to describe the behavior of a buffer. #[derive(Default, Clone, Copy)] pub struct BufferSettings { @@ -157,44 +232,22 @@ impl Default for RetentionPolicy { } } -impl Clone for Buffer { - fn clone(&self) -> Self { - *self - } -} - -impl Copy for Buffer {} - -impl Clone for CloneFromBuffer { - fn clone(&self) -> Self { - *self - } -} - -impl Copy for CloneFromBuffer {} - /// This key can unlock access to the contents of a buffer by passing it into /// [`BufferAccess`] or [`BufferAccessMut`]. /// /// To obtain a `BufferKey`, use [`Chain::with_access`][1], or [`listen`][2]. /// /// [1]: crate::Chain::with_access -/// [2]: crate::Bufferable::listen +/// [2]: crate::Accessible::listen pub struct BufferKey { - buffer: Entity, - session: Entity, - accessor: Entity, - lifecycle: Option>, + tag: BufferKeyTag, _ignore: std::marker::PhantomData, } impl Clone for BufferKey { fn clone(&self) -> Self { Self { - buffer: self.buffer, - session: self.session, - accessor: self.accessor, - lifecycle: self.lifecycle.as_ref().map(Arc::clone), + tag: self.tag.clone(), _ignore: Default::default(), } } @@ -202,28 +255,67 @@ impl Clone for BufferKey { impl BufferKey { /// The buffer ID of this key. - pub fn id(&self) -> Entity { - self.buffer + pub fn buffer(&self) -> Entity { + self.tag.buffer } /// The session that this key belongs to. pub fn session(&self) -> Entity { - self.session + self.tag.session + } + + pub fn tag(&self) -> &BufferKeyTag { + &self.tag + } +} + +impl BufferKeyLifecycle for BufferKey { + type TargetBuffer = Buffer; + + fn create_key(buffer: &Self::TargetBuffer, builder: &BufferKeyBuilder) -> Self { + BufferKey { + tag: builder.make_tag(buffer.id()), + _ignore: Default::default(), + } + } + + fn is_in_use(&self) -> bool { + self.tag.is_in_use() } - pub(crate) fn is_in_use(&self) -> bool { + fn deep_clone(&self) -> Self { + Self { + tag: self.tag.deep_clone(), + _ignore: Default::default(), + } + } +} + +impl std::fmt::Debug for BufferKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufferKey") + .field("message_type_name", &std::any::type_name::()) + .field("tag", &self.tag) + .finish() + } +} + +/// The identifying information for a buffer key. This does not indicate +/// anything about the type of messages that the buffer can contain. +#[derive(Clone)] +pub struct BufferKeyTag { + pub buffer: Entity, + pub session: Entity, + pub accessor: Entity, + pub lifecycle: Option>, +} + +impl BufferKeyTag { + pub fn is_in_use(&self) -> bool { self.lifecycle.as_ref().is_some_and(|l| l.is_in_use()) } - // We do a deep clone of the key when distributing it to decouple the - // lifecycle of the keys that we send out from the key that's held by the - // accessor node. - // - // The key instance held by the accessor node will never be dropped until - // the session is cleaned up, so the keys that we send out into the workflow - // need to have their own independent lifecycles or else we won't detect - // when the workflow has dropped them. - pub(crate) fn deep_clone(&self) -> Self { + pub fn deep_clone(&self) -> Self { let mut deep = self.clone(); deep.lifecycle = self .lifecycle @@ -233,6 +325,17 @@ impl BufferKey { } } +impl std::fmt::Debug for BufferKeyTag { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufferKeyTag") + .field("buffer", &self.buffer) + .field("session", &self.session) + .field("accessor", &self.accessor) + .field("in_use", &self.is_in_use()) + .finish() + } +} + /// This system parameter lets you get read-only access to a buffer that exists /// within a workflow. Use a [`BufferKey`] to unlock the access. /// @@ -247,9 +350,9 @@ where impl<'w, 's, T: 'static + Send + Sync> BufferAccess<'w, 's, T> { pub fn get<'a>(&'a self, key: &BufferKey) -> Result, QueryEntityError> { - let session = key.session; + let session = key.session(); self.query - .get(key.buffer) + .get(key.buffer()) .map(|(storage, gate)| BufferView { storage, gate, @@ -276,9 +379,9 @@ where T: 'static + Send + Sync, { pub fn get<'a>(&'a self, key: &BufferKey) -> Result, QueryEntityError> { - let session = key.session; + let session = key.session(); self.query - .get(key.buffer) + .get(key.buffer()) .map(|(storage, gate)| BufferView { storage, gate, @@ -290,15 +393,76 @@ where &'a mut self, key: &BufferKey, ) -> Result, QueryEntityError> { - let buffer = key.buffer; - let session = key.session; - let accessor = key.accessor; - self.query.get_mut(key.buffer).map(|(storage, gate)| { + let buffer = key.buffer(); + let session = key.session(); + let accessor = key.tag.accessor; + self.query.get_mut(key.buffer()).map(|(storage, gate)| { BufferMut::new(storage, gate, buffer, session, accessor, &mut self.commands) }) } } +/// This trait allows [`World`] to give you access to any buffer using a [`BufferKey`] +pub trait BufferWorldAccess { + /// Call this to get read-only access to a buffer from a [`World`]. + /// + /// Alternatively you can use [`BufferAccess`] as a regular bevy system parameter, + /// which does not need direct world access. + fn buffer_view(&self, key: &BufferKey) -> Result, BufferError> + where + T: 'static + Send + Sync; + + /// Call this to get mutable access to a buffer. + /// + /// Pass in a callback that will receive [`BufferMut`], allowing it to view + /// and modify the contents of the buffer. + fn buffer_mut( + &mut self, + key: &BufferKey, + f: impl FnOnce(BufferMut) -> U, + ) -> Result + where + T: 'static + Send + Sync; +} + +impl BufferWorldAccess for World { + fn buffer_view(&self, key: &BufferKey) -> Result, BufferError> + where + T: 'static + Send + Sync, + { + let buffer_ref = self + .get_entity(key.tag.buffer) + .ok_or(BufferError::BufferMissing)?; + let storage = buffer_ref + .get::>() + .ok_or(BufferError::BufferMissing)?; + let gate = buffer_ref + .get::() + .ok_or(BufferError::BufferMissing)?; + Ok(BufferView { + storage, + gate, + session: key.tag.session, + }) + } + + fn buffer_mut( + &mut self, + key: &BufferKey, + f: impl FnOnce(BufferMut) -> U, + ) -> Result + where + T: 'static + Send + Sync, + { + let mut state = SystemState::>::new(self); + let mut buffer_access_mut = state.get_mut(self); + let buffer_mut = buffer_access_mut + .get_mut(key) + .map_err(|_| BufferError::BufferMissing)?; + Ok(f(buffer_mut)) + } +} + /// Access to view a buffer that exists inside a workflow. pub struct BufferView<'a, T> where @@ -424,7 +588,7 @@ where self.len() == 0 } - /// Check whether the gate of this buffer is open or closed + /// Check whether the gate of this buffer is open or closed. pub fn gate(&self) -> Gate { self.gate .map @@ -467,7 +631,7 @@ where self.storage.drain(self.session, range) } - /// Pull the oldest item from the buffer + /// Pull the oldest item from the buffer. pub fn pull(&mut self) -> Option { self.modified = true; self.storage.pull(self.session) @@ -500,7 +664,7 @@ where // continuous systems with BufferAccessMut from running at the same time no // matter what the buffer type is. - /// Tell the buffer [`Gate`] to open + /// Tell the buffer [`Gate`] to open. pub fn open_gate(&mut self) { if let Some(gate) = self.gate.map.get_mut(&self.session) { if *gate != Gate::Open { @@ -510,7 +674,7 @@ where } } - /// Tell the buffer [`Gate`] to close + /// Tell the buffer [`Gate`] to close. pub fn close_gate(&mut self) { if let Some(gate) = self.gate.map.get_mut(&self.session) { *gate = Gate::Closed; @@ -519,7 +683,7 @@ where } } - /// Perform an action on the gate of the buffer + /// Perform an action on the gate of the buffer. pub fn gate_action(&mut self, action: Gate) { match action { Gate::Open => self.open_gate(), @@ -569,6 +733,12 @@ where } } +#[derive(ThisError, Debug, Clone)] +pub enum BufferError { + #[error("The key was unable to identify a buffer")] + BufferMissing, +} + #[cfg(test)] mod tests { use crate::{prelude::*, testing::*, Gate}; diff --git a/src/buffer/any_buffer.rs b/src/buffer/any_buffer.rs new file mode 100644 index 00000000..efde9907 --- /dev/null +++ b/src/buffer/any_buffer.rs @@ -0,0 +1,1327 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +// TODO(@mxgrey): Add module-level documentation describing how to use AnyBuffer + +use std::{ + any::{Any, TypeId}, + collections::{hash_map::Entry, HashMap}, + ops::RangeBounds, + sync::{Mutex, OnceLock}, +}; + +use bevy_ecs::{ + prelude::{Commands, Entity, EntityRef, EntityWorldMut, Mut, World}, + system::SystemState, +}; + +use thiserror::Error as ThisError; + +use smallvec::SmallVec; + +use crate::{ + add_listener_to_source, Accessing, Buffer, BufferAccessMut, BufferAccessors, BufferError, + BufferKey, BufferKeyBuilder, BufferKeyLifecycle, BufferKeyTag, BufferLocation, BufferStorage, + Bufferable, Buffering, Builder, DrainBuffer, Gate, GateState, InspectBuffer, Joining, + ManageBuffer, NotifyBufferUpdate, OperationError, OperationResult, OperationRoster, OrBroken, +}; + +/// A [`Buffer`] whose message type has been anonymized. Joining with this buffer +/// type will yield an [`AnyMessageBox`]. +#[derive(Clone, Copy)] +pub struct AnyBuffer { + pub(crate) location: BufferLocation, + pub(crate) interface: &'static (dyn AnyBufferAccessInterface + Send + Sync), +} + +impl AnyBuffer { + /// The buffer ID for this key. + pub fn id(&self) -> Entity { + self.location.source + } + + /// ID of the workflow that this buffer is associated with. + pub fn scope(&self) -> Entity { + self.location.scope + } + + /// Get the type ID of the messages that this buffer supports. + pub fn message_type_id(&self) -> TypeId { + self.interface.message_type_id() + } + + pub fn message_type_name(&self) -> &'static str { + self.interface.message_type_name() + } + + /// Get the [`AnyBufferAccessInterface`] for this specific instance of [`AnyBuffer`]. + pub fn get_interface(&self) -> &'static (dyn AnyBufferAccessInterface + Send + Sync) { + self.interface + } + + /// Get the [`AnyBufferAccessInterface`] for a concrete message type. + pub fn interface_for( + ) -> &'static (dyn AnyBufferAccessInterface + Send + Sync) { + static INTERFACE_MAP: OnceLock< + Mutex>, + > = OnceLock::new(); + let interfaces = INTERFACE_MAP.get_or_init(|| Mutex::default()); + + // SAFETY: This will leak memory exactly once per type, so the leakage is bounded. + // Leaking this allows the interface to be shared freely across all instances. + let mut interfaces_mut = interfaces.lock().unwrap(); + *interfaces_mut + .entry(TypeId::of::()) + .or_insert_with(|| Box::leak(Box::new(AnyBufferAccessImpl::::new()))) + } +} + +impl std::fmt::Debug for AnyBuffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AnyBuffer") + .field("scope", &self.location.scope) + .field("source", &self.location.source) + .field("message_type_name", &self.interface.message_type_name()) + .finish() + } +} + +impl AnyBuffer { + /// Downcast this into a concrete [`Buffer`] for the specified message type. + /// + /// To downcast into a specialized kind of buffer, use [`Self::downcast_buffer`] instead. + pub fn downcast_for_message(&self) -> Option> { + if TypeId::of::() == self.interface.message_type_id() { + Some(Buffer { + location: self.location, + _ignore: Default::default(), + }) + } else { + None + } + } + + /// Downcast this into a different special buffer representation, such as a + /// `JsonBuffer`. + pub fn downcast_buffer(&self) -> Option { + self.interface.buffer_downcast(TypeId::of::())?(self.location) + .downcast::() + .ok() + .map(|x| *x) + } +} + +impl From> for AnyBuffer { + fn from(value: Buffer) -> Self { + let interface = AnyBuffer::interface_for::(); + AnyBuffer { + location: value.location, + interface, + } + } +} + +/// A trait for turning a buffer into an [`AnyBuffer`]. It is expected that all +/// buffer types implement this trait. +pub trait AsAnyBuffer { + /// Convert this buffer into an [`AnyBuffer`]. + fn as_any_buffer(&self) -> AnyBuffer; +} + +impl AsAnyBuffer for AnyBuffer { + fn as_any_buffer(&self) -> AnyBuffer { + *self + } +} + +impl AsAnyBuffer for Buffer { + fn as_any_buffer(&self) -> AnyBuffer { + (*self).into() + } +} + +/// Similar to a [`BufferKey`] except it can be used for any buffer without +/// knowing the buffer's message type at compile time. +/// +/// This can key be used with a [`World`][1] to directly view or manipulate the +/// contents of a buffer through the [`AnyBufferWorldAccess`] interface. +/// +/// [1]: bevy_ecs::prelude::World +#[derive(Clone)] +pub struct AnyBufferKey { + pub(crate) tag: BufferKeyTag, + pub(crate) interface: &'static (dyn AnyBufferAccessInterface + Send + Sync), +} + +impl AnyBufferKey { + /// Downcast this into a concrete [`BufferKey`] for the specified message type. + /// + /// To downcast to a specialized kind of key, use [`Self::downcast_buffer_key`] instead. + pub fn downcast_for_message(self) -> Option> { + if TypeId::of::() == self.interface.message_type_id() { + Some(BufferKey { + tag: self.tag, + _ignore: Default::default(), + }) + } else { + None + } + } + + /// Downcast this into a different special buffer key representation, such + /// as a `JsonBufferKey`. + pub fn downcast_buffer_key(self) -> Option { + self.interface.key_downcast(TypeId::of::())?(self.tag) + .downcast::() + .ok() + .map(|x| *x) + } + + /// The buffer ID of this key. + pub fn id(&self) -> Entity { + self.tag.buffer + } + + /// The session that this key belongs to. + pub fn session(&self) -> Entity { + self.tag.session + } +} + +impl BufferKeyLifecycle for AnyBufferKey { + type TargetBuffer = AnyBuffer; + + fn create_key(buffer: &AnyBuffer, builder: &BufferKeyBuilder) -> Self { + AnyBufferKey { + tag: builder.make_tag(buffer.id()), + interface: buffer.interface, + } + } + + fn is_in_use(&self) -> bool { + self.tag.is_in_use() + } + + fn deep_clone(&self) -> Self { + Self { + tag: self.tag.deep_clone(), + interface: self.interface, + } + } +} + +impl std::fmt::Debug for AnyBufferKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AnyBufferKey") + .field("message_type_name", &self.interface.message_type_name()) + .field("tag", &self.tag) + .finish() + } +} + +impl From> for AnyBufferKey { + fn from(value: BufferKey) -> Self { + let interface = AnyBuffer::interface_for::(); + AnyBufferKey { + tag: value.tag, + interface, + } + } +} + +/// Similar to [`BufferView`][crate::BufferView], but this can be unlocked with +/// an [`AnyBufferKey`], so it can work for any buffer whose message types +/// support serialization and deserialization. +pub struct AnyBufferView<'a> { + storage: Box, + gate: &'a GateState, + session: Entity, +} + +impl<'a> AnyBufferView<'a> { + /// Look at the oldest message in the buffer. + pub fn oldest(&self) -> Option> { + self.storage.any_oldest(self.session) + } + + /// Look at the newest message in the buffer. + pub fn newest(&self) -> Option> { + self.storage.any_newest(self.session) + } + + /// Borrow a message from the buffer. Index 0 is the oldest message in the buffer + /// while the highest index is the newest message in the buffer. + pub fn get(&self, index: usize) -> Option> { + self.storage.any_get(self.session, index) + } + + /// Get how many messages are in this buffer. + pub fn len(&self) -> usize { + self.storage.any_count(self.session) + } + + /// Check if the buffer is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Check whether the gate of this buffer is open or closed. + pub fn gate(&self) -> Gate { + self.gate + .map + .get(&self.session) + .copied() + .unwrap_or(Gate::Open) + } +} + +/// Similar to [`BufferMut`][crate::BufferMut], but this can be unlocked with an +/// [`AnyBufferKey`], so it can work for any buffer regardless of the data type +/// inside. +pub struct AnyBufferMut<'w, 's, 'a> { + storage: Box, + gate: Mut<'a, GateState>, + buffer: Entity, + session: Entity, + accessor: Option, + commands: &'a mut Commands<'w, 's>, + modified: bool, +} + +impl<'w, 's, 'a> AnyBufferMut<'w, 's, 'a> { + /// Same as [BufferMut::allow_closed_loops][1]. + /// + /// [1]: crate::BufferMut::allow_closed_loops + pub fn allow_closed_loops(mut self) -> Self { + self.accessor = None; + self + } + + /// Look at the oldest message in the buffer. + pub fn oldest(&self) -> Option> { + self.storage.any_oldest(self.session) + } + + /// Look at the newest message in the buffer. + pub fn newest(&self) -> Option> { + self.storage.any_newest(self.session) + } + + /// Borrow a message from the buffer. Index 0 is the oldest message in the buffer + /// while the highest index is the newest message in the buffer. + pub fn get(&self, index: usize) -> Option> { + self.storage.any_get(self.session, index) + } + + /// Get how many messages are in this buffer. + pub fn len(&self) -> usize { + self.storage.any_count(self.session) + } + + /// Check if the buffer is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Check whether the gate of this buffer is open or closed. + pub fn gate(&self) -> Gate { + self.gate + .map + .get(&self.session) + .copied() + .unwrap_or(Gate::Open) + } + + /// Modify the oldest message in the buffer. + pub fn oldest_mut(&mut self) -> Option> { + self.modified = true; + self.storage.any_oldest_mut(self.session) + } + + /// Modify the newest message in the buffer. + pub fn newest_mut(&mut self) -> Option> { + self.modified = true; + self.storage.any_newest_mut(self.session) + } + + /// Modify a message in the buffer. Index 0 is the oldest message in the buffer + /// with the highest index being the newest message in the buffer. + pub fn get_mut(&mut self, index: usize) -> Option> { + self.modified = true; + self.storage.any_get_mut(self.session, index) + } + + /// Drain a range of messages out of the buffer. + pub fn drain>(&mut self, range: R) -> DrainAnyBuffer<'_> { + self.modified = true; + DrainAnyBuffer { + interface: self.storage.any_drain(self.session, AnyRange::new(range)), + } + } + + /// Pull the oldest message from the buffer. + pub fn pull(&mut self) -> Option { + self.modified = true; + self.storage.any_pull(self.session) + } + + /// Pull the message that was most recently put into the buffer (instead of the + /// oldest, which is what [`Self::pull`] gives). + pub fn pull_newest(&mut self) -> Option { + self.modified = true; + self.storage.any_pull_newest(self.session) + } + + /// Attempt to push a new value into the buffer. + /// + /// If the input value matches the message type of the buffer, this will + /// return [`Ok`]. If the buffer is at its limit before a successful push, this + /// will return the value that needed to be removed. + /// + /// If the input value does not match the message type of the buffer, this + /// will return [`Err`] and give back the message that you tried to push. + pub fn push(&mut self, value: T) -> Result, T> { + if TypeId::of::() != self.storage.any_message_type() { + return Err(value); + } + + self.modified = true; + + // SAFETY: We checked that T matches the message type for this buffer, + // so pushing and downcasting should not exhibit any errors. + let removed = self + .storage + .any_push(self.session, Box::new(value)) + .unwrap() + .map(|value| *value.downcast::().unwrap()); + + Ok(removed) + } + + /// Attempt to push a new value of any message type into the buffer. + /// + /// If the input value matches the message type of the buffer, this will + /// return [`Ok`]. If the buffer is at its limit before a successful push, this + /// will return the value that needed to be removed. + /// + /// If the input value does not match the message type of the buffer, this + /// will return [`Err`] and give back an error with the message that you + /// tried to push and the type information for the expected message type. + pub fn push_any( + &mut self, + value: AnyMessageBox, + ) -> Result, AnyMessageError> { + self.storage.any_push(self.session, value) + } + + /// Attempt to push a value into the buffer as if it is the oldest value of + /// the buffer. + /// + /// The result follows the same rules as [`Self::push`]. + pub fn push_as_oldest( + &mut self, + value: T, + ) -> Result, T> { + if TypeId::of::() != self.storage.any_message_type() { + return Err(value); + } + + self.modified = true; + + // SAFETY: We checked that T matches the message type for this buffer, + // so pushing and downcasting should not exhibit any errors. + let removed = self + .storage + .any_push_as_oldest(self.session, Box::new(value)) + .unwrap() + .map(|value| *value.downcast::().unwrap()); + + Ok(removed) + } + + /// Attempt to push a value into the buffer as if it is the oldest value of + /// the buffer. + /// + /// The result follows the same rules as [`Self::push_any`]. + pub fn push_any_as_oldest( + &mut self, + value: AnyMessageBox, + ) -> Result, AnyMessageError> { + self.storage.any_push_as_oldest(self.session, value) + } + + /// Tell the buffer [`Gate`] to open. + pub fn open_gate(&mut self) { + if let Some(gate) = self.gate.map.get_mut(&self.session) { + if *gate != Gate::Open { + *gate = Gate::Open; + self.modified = true; + } + } + } + + /// Tell the buffer [`Gate`] to close. + pub fn close_gate(&mut self) { + if let Some(gate) = self.gate.map.get_mut(&self.session) { + *gate = Gate::Closed; + // There is no need to to indicate that a modification happened + // because listeners do not get notified about gates closing. + } + } + + /// Perform an action on the gate of the buffer. + pub fn gate_action(&mut self, action: Gate) { + match action { + Gate::Open => self.open_gate(), + Gate::Closed => self.close_gate(), + } + } + + /// Trigger the listeners for this buffer to wake up even if nothing in the + /// buffer has changed. This could be used for timers or timeout elements + /// in a workflow. + pub fn pulse(&mut self) { + self.modified = true; + } +} + +impl<'w, 's, 'a> Drop for AnyBufferMut<'w, 's, 'a> { + fn drop(&mut self) { + if self.modified { + self.commands.add(NotifyBufferUpdate::new( + self.buffer, + self.session, + self.accessor, + )); + } + } +} + +/// This trait allows [`World`] to give you access to any buffer using an +/// [`AnyBufferKey`]. +pub trait AnyBufferWorldAccess { + /// Call this to get read-only access to any buffer. + /// + /// For technical reasons this requires direct [`World`] access, but you can + /// do other read-only queries on the world while holding onto the + /// [`AnyBufferView`]. + fn any_buffer_view(&self, key: &AnyBufferKey) -> Result, BufferError>; + + /// Call this to get mutable access to any buffer. + /// + /// Pass in a callback that will receive a [`AnyBufferMut`], allowing it to + /// view and modify the contents of the buffer. + fn any_buffer_mut( + &mut self, + key: &AnyBufferKey, + f: impl FnOnce(AnyBufferMut) -> U, + ) -> Result; +} + +impl AnyBufferWorldAccess for World { + fn any_buffer_view(&self, key: &AnyBufferKey) -> Result, BufferError> { + key.interface.create_any_buffer_view(key, self) + } + + fn any_buffer_mut( + &mut self, + key: &AnyBufferKey, + f: impl FnOnce(AnyBufferMut) -> U, + ) -> Result { + let interface = key.interface; + let mut state = interface.create_any_buffer_access_mut_state(self); + let mut access = state.get_any_buffer_access_mut(self); + let buffer_mut = access.as_any_buffer_mut(key)?; + Ok(f(buffer_mut)) + } +} + +trait AnyBufferViewing { + fn any_count(&self, session: Entity) -> usize; + fn any_oldest<'a>(&'a self, session: Entity) -> Option>; + fn any_newest<'a>(&'a self, session: Entity) -> Option>; + fn any_get<'a>(&'a self, session: Entity, index: usize) -> Option>; + fn any_message_type(&self) -> TypeId; +} + +trait AnyBufferManagement: AnyBufferViewing { + fn any_push(&mut self, session: Entity, value: AnyMessageBox) -> AnyMessagePushResult; + fn any_push_as_oldest(&mut self, session: Entity, value: AnyMessageBox) + -> AnyMessagePushResult; + fn any_pull(&mut self, session: Entity) -> Option; + fn any_pull_newest(&mut self, session: Entity) -> Option; + fn any_oldest_mut<'a>(&'a mut self, session: Entity) -> Option>; + fn any_newest_mut<'a>(&'a mut self, session: Entity) -> Option>; + fn any_get_mut<'a>(&'a mut self, session: Entity, index: usize) -> Option>; + fn any_drain<'a>( + &'a mut self, + session: Entity, + range: AnyRange, + ) -> Box; +} + +pub(crate) struct AnyRange { + start_bound: std::ops::Bound, + end_bound: std::ops::Bound, +} + +impl AnyRange { + pub(crate) fn new>(range: T) -> Self { + AnyRange { + start_bound: deref_bound(range.start_bound()), + end_bound: deref_bound(range.end_bound()), + } + } +} + +fn deref_bound(bound: std::ops::Bound<&usize>) -> std::ops::Bound { + match bound { + std::ops::Bound::Included(v) => std::ops::Bound::Included(*v), + std::ops::Bound::Excluded(v) => std::ops::Bound::Excluded(*v), + std::ops::Bound::Unbounded => std::ops::Bound::Unbounded, + } +} + +impl std::ops::RangeBounds for AnyRange { + fn start_bound(&self) -> std::ops::Bound<&usize> { + self.start_bound.as_ref() + } + + fn end_bound(&self) -> std::ops::Bound<&usize> { + self.end_bound.as_ref() + } + + fn contains(&self, item: &U) -> bool + where + usize: PartialOrd, + U: ?Sized + PartialOrd, + { + match self.start_bound { + std::ops::Bound::Excluded(lower) => { + if *item <= lower { + return false; + } + } + std::ops::Bound::Included(lower) => { + if *item < lower { + return false; + } + } + _ => {} + } + + match self.end_bound { + std::ops::Bound::Excluded(upper) => { + if upper <= *item { + return false; + } + } + std::ops::Bound::Included(upper) => { + if upper < *item { + return false; + } + } + _ => {} + } + + return true; + } +} + +pub type AnyMessageRef<'a> = &'a (dyn Any + 'static + Send + Sync); + +impl AnyBufferViewing for &'_ BufferStorage { + fn any_count(&self, session: Entity) -> usize { + self.count(session) + } + + fn any_oldest<'a>(&'a self, session: Entity) -> Option> { + self.oldest(session).map(to_any_ref) + } + + fn any_newest<'a>(&'a self, session: Entity) -> Option> { + self.newest(session).map(to_any_ref) + } + + fn any_get<'a>(&'a self, session: Entity, index: usize) -> Option> { + self.get(session, index).map(to_any_ref) + } + + fn any_message_type(&self) -> TypeId { + TypeId::of::() + } +} + +impl AnyBufferViewing for Mut<'_, BufferStorage> { + fn any_count(&self, session: Entity) -> usize { + self.count(session) + } + + fn any_oldest<'a>(&'a self, session: Entity) -> Option> { + self.oldest(session).map(to_any_ref) + } + + fn any_newest<'a>(&'a self, session: Entity) -> Option> { + self.newest(session).map(to_any_ref) + } + + fn any_get<'a>(&'a self, session: Entity, index: usize) -> Option> { + self.get(session, index).map(to_any_ref) + } + + fn any_message_type(&self) -> TypeId { + TypeId::of::() + } +} + +pub type AnyMessageMut<'a> = &'a mut (dyn Any + 'static + Send + Sync); + +pub type AnyMessageBox = Box; + +#[derive(ThisError, Debug)] +#[error("failed to convert a message")] +pub struct AnyMessageError { + /// The original value provided + pub value: AnyMessageBox, + /// The ID of the type expected by the buffer + pub type_id: TypeId, + /// The name of the type expected by the buffer + pub type_name: &'static str, +} + +pub type AnyMessagePushResult = Result, AnyMessageError>; + +impl AnyBufferManagement for Mut<'_, BufferStorage> { + fn any_push(&mut self, session: Entity, value: AnyMessageBox) -> AnyMessagePushResult { + let value = from_any_message::(value)?; + Ok(self.push(session, value).map(to_any_message)) + } + + fn any_push_as_oldest( + &mut self, + session: Entity, + value: AnyMessageBox, + ) -> AnyMessagePushResult { + let value = from_any_message::(value)?; + Ok(self.push_as_oldest(session, value).map(to_any_message)) + } + + fn any_pull(&mut self, session: Entity) -> Option { + self.pull(session).map(to_any_message) + } + + fn any_pull_newest(&mut self, session: Entity) -> Option { + self.pull_newest(session).map(to_any_message) + } + + fn any_oldest_mut<'a>(&'a mut self, session: Entity) -> Option> { + self.oldest_mut(session).map(to_any_mut) + } + + fn any_newest_mut<'a>(&'a mut self, session: Entity) -> Option> { + self.newest_mut(session).map(to_any_mut) + } + + fn any_get_mut<'a>(&'a mut self, session: Entity, index: usize) -> Option> { + self.get_mut(session, index).map(to_any_mut) + } + + fn any_drain<'a>( + &'a mut self, + session: Entity, + range: AnyRange, + ) -> Box { + Box::new(self.drain(session, range)) + } +} + +fn to_any_ref<'a, T: 'static + Send + Sync + Any>(x: &'a T) -> AnyMessageRef<'a> { + x +} + +fn to_any_mut<'a, T: 'static + Send + Sync + Any>(x: &'a mut T) -> AnyMessageMut<'a> { + x +} + +fn to_any_message(x: T) -> AnyMessageBox { + Box::new(x) +} + +fn from_any_message( + value: AnyMessageBox, +) -> Result +where + T: 'static, +{ + let value = value.downcast::().map_err(|value| AnyMessageError { + value, + type_id: TypeId::of::(), + type_name: std::any::type_name::(), + })?; + + Ok(*value) +} + +pub trait AnyBufferAccessMutState { + fn get_any_buffer_access_mut<'s, 'w: 's>( + &'s mut self, + world: &'w mut World, + ) -> Box + 's>; +} + +impl AnyBufferAccessMutState + for SystemState> +{ + fn get_any_buffer_access_mut<'s, 'w: 's>( + &'s mut self, + world: &'w mut World, + ) -> Box + 's> { + Box::new(self.get_mut(world)) + } +} + +pub trait AnyBufferAccessMut<'w, 's> { + fn as_any_buffer_mut<'a>( + &'a mut self, + key: &AnyBufferKey, + ) -> Result, BufferError>; +} + +impl<'w, 's, T: 'static + Send + Sync + Any> AnyBufferAccessMut<'w, 's> + for BufferAccessMut<'w, 's, T> +{ + fn as_any_buffer_mut<'a>( + &'a mut self, + key: &AnyBufferKey, + ) -> Result, BufferError> { + let BufferAccessMut { query, commands } = self; + let (storage, gate) = query + .get_mut(key.tag.buffer) + .map_err(|_| BufferError::BufferMissing)?; + Ok(AnyBufferMut { + storage: Box::new(storage), + gate, + buffer: key.tag.buffer, + session: key.tag.session, + accessor: Some(key.tag.accessor), + commands, + modified: false, + }) + } +} + +pub trait AnyBufferAccessInterface { + fn message_type_id(&self) -> TypeId; + + fn message_type_name(&self) -> &'static str; + + fn buffered_count(&self, entity: &EntityRef, session: Entity) -> Result; + + fn ensure_session(&self, entity_mut: &mut EntityWorldMut, session: Entity) -> OperationResult; + + fn register_buffer_downcast(&self, buffer_type: TypeId, f: BufferDowncastBox); + + fn buffer_downcast(&self, buffer_type: TypeId) -> Option; + + fn register_key_downcast(&self, key_type: TypeId, f: KeyDowncastBox); + + fn key_downcast(&self, key_type: TypeId) -> Option; + + fn pull( + &self, + entity_mut: &mut EntityWorldMut, + session: Entity, + ) -> Result; + + fn create_any_buffer_view<'a>( + &self, + key: &AnyBufferKey, + world: &'a World, + ) -> Result, BufferError>; + + fn create_any_buffer_access_mut_state( + &self, + world: &mut World, + ) -> Box; +} + +pub type BufferDowncastBox = Box Box + Send + Sync>; +pub type BufferDowncastRef = &'static (dyn Fn(BufferLocation) -> Box + Send + Sync); +pub type KeyDowncastBox = Box Box + Send + Sync>; +pub type KeyDowncastRef = &'static (dyn Fn(BufferKeyTag) -> Box + Send + Sync); + +struct AnyBufferAccessImpl { + buffer_downcasts: Mutex>, + key_downcasts: Mutex>, + _ignore: std::marker::PhantomData, +} + +impl AnyBufferAccessImpl { + fn new() -> Self { + let mut buffer_downcasts: HashMap<_, BufferDowncastRef> = HashMap::new(); + + // SAFETY: These leaks are okay because we will only ever instantiate + // AnyBufferAccessImpl once per generic argument T, which puts a firm + // ceiling on how many of these callbacks will get leaked. + + // Automatically register a downcast into AnyBuffer + buffer_downcasts.insert( + TypeId::of::(), + Box::leak(Box::new(|location| -> Box { + Box::new(AnyBuffer { + location, + interface: AnyBuffer::interface_for::(), + }) + })), + ); + + // Allow downcasting back to the original Buffer + buffer_downcasts.insert( + TypeId::of::>(), + Box::leak(Box::new(|location| -> Box { + Box::new(Buffer:: { + location, + _ignore: Default::default(), + }) + })), + ); + + let mut key_downcasts: HashMap<_, KeyDowncastRef> = HashMap::new(); + + // Automatically register a downcast to AnyBufferKey + key_downcasts.insert( + TypeId::of::(), + Box::leak(Box::new(|tag| -> Box { + Box::new(AnyBufferKey { + tag, + interface: AnyBuffer::interface_for::(), + }) + })), + ); + + Self { + buffer_downcasts: Mutex::new(buffer_downcasts), + key_downcasts: Mutex::new(key_downcasts), + _ignore: Default::default(), + } + } +} + +impl AnyBufferAccessInterface for AnyBufferAccessImpl { + fn message_type_id(&self) -> TypeId { + TypeId::of::() + } + + fn message_type_name(&self) -> &'static str { + std::any::type_name::() + } + + fn buffered_count(&self, entity: &EntityRef, session: Entity) -> Result { + entity.buffered_count::(session) + } + + fn ensure_session(&self, entity_mut: &mut EntityWorldMut, session: Entity) -> OperationResult { + entity_mut.ensure_session::(session) + } + + fn register_buffer_downcast(&self, buffer_type: TypeId, f: BufferDowncastBox) { + let mut downcasts = self.buffer_downcasts.lock().unwrap(); + + if let Entry::Vacant(entry) = downcasts.entry(buffer_type) { + // SAFETY: We only leak this into the register once per type + entry.insert(Box::leak(f)); + } + } + + fn buffer_downcast(&self, buffer_type: TypeId) -> Option { + self.buffer_downcasts + .lock() + .unwrap() + .get(&buffer_type) + .copied() + } + + fn register_key_downcast(&self, key_type: TypeId, f: KeyDowncastBox) { + let mut downcasts = self.key_downcasts.lock().unwrap(); + + if let Entry::Vacant(entry) = downcasts.entry(key_type) { + // SAFTY: We only leak this in to the register once per type + entry.insert(Box::leak(f)); + } + } + + fn key_downcast(&self, key_type: TypeId) -> Option { + self.key_downcasts.lock().unwrap().get(&key_type).copied() + } + + fn pull( + &self, + entity_mut: &mut EntityWorldMut, + session: Entity, + ) -> Result { + entity_mut + .pull_from_buffer::(session) + .map(to_any_message) + } + + fn create_any_buffer_view<'a>( + &self, + key: &AnyBufferKey, + world: &'a World, + ) -> Result, BufferError> { + let buffer_ref = world + .get_entity(key.tag.buffer) + .ok_or(BufferError::BufferMissing)?; + let storage = buffer_ref + .get::>() + .ok_or(BufferError::BufferMissing)?; + let gate = buffer_ref + .get::() + .ok_or(BufferError::BufferMissing)?; + Ok(AnyBufferView { + storage: Box::new(storage), + gate, + session: key.tag.session, + }) + } + + fn create_any_buffer_access_mut_state( + &self, + world: &mut World, + ) -> Box { + Box::new(SystemState::>::new(world)) + } +} + +pub struct DrainAnyBuffer<'a> { + interface: Box, +} + +impl<'a> Iterator for DrainAnyBuffer<'a> { + type Item = AnyMessageBox; + + fn next(&mut self) -> Option { + self.interface.any_next() + } +} + +trait DrainAnyBufferInterface { + fn any_next(&mut self) -> Option; +} + +impl DrainAnyBufferInterface for DrainBuffer<'_, T> { + fn any_next(&mut self) -> Option { + self.next().map(to_any_message) + } +} + +impl Bufferable for AnyBuffer { + type BufferType = Self; + fn into_buffer(self, builder: &mut Builder) -> Self::BufferType { + assert_eq!(self.scope(), builder.scope()); + self + } +} + +impl Buffering for AnyBuffer { + fn verify_scope(&self, scope: Entity) { + assert_eq!(scope, self.scope()); + } + + fn buffered_count(&self, session: Entity, world: &World) -> Result { + let entity_ref = world.get_entity(self.id()).or_broken()?; + self.interface.buffered_count(&entity_ref, session) + } + + fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { + add_listener_to_source(self.id(), listener, world) + } + + fn gate_action( + &self, + session: Entity, + action: Gate, + world: &mut World, + roster: &mut OperationRoster, + ) -> OperationResult { + GateState::apply(self.id(), session, action, world, roster) + } + + fn as_input(&self) -> SmallVec<[Entity; 8]> { + SmallVec::from_iter([self.id()]) + } + + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + let mut entity_mut = world.get_entity_mut(self.id()).or_broken()?; + self.interface.ensure_session(&mut entity_mut, session) + } +} + +impl Joining for AnyBuffer { + type Item = AnyMessageBox; + fn pull(&self, session: Entity, world: &mut World) -> Result { + let mut buffer_mut = world.get_entity_mut(self.id()).or_broken()?; + self.interface.pull(&mut buffer_mut, session) + } +} + +impl Accessing for AnyBuffer { + type Key = AnyBufferKey; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { + world + .get_mut::(self.id()) + .or_broken()? + .add_accessor(accessor); + Ok(()) + } + + fn create_key(&self, builder: &super::BufferKeyBuilder) -> Self::Key { + AnyBufferKey { + tag: builder.make_tag(self.id()), + interface: self.interface, + } + } + + fn deep_clone_key(key: &Self::Key) -> Self::Key { + key.deep_clone() + } + + fn is_key_in_use(key: &Self::Key) -> bool { + key.is_in_use() + } +} + +#[cfg(test)] +mod tests { + use crate::{prelude::*, testing::*}; + use bevy_ecs::prelude::World; + + #[test] + fn test_any_count() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer = builder.create_buffer(BufferSettings::keep_all()); + let push_multiple_times = builder + .commands() + .spawn_service(push_multiple_times_into_buffer.into_blocking_service()); + let count = builder + .commands() + .spawn_service(get_buffer_count.into_blocking_service()); + + scope + .input + .chain(builder) + .with_access(buffer) + .then(push_multiple_times) + .then(count) + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| commands.request(1, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let count = promise.take().available().unwrap(); + assert_eq!(count, 5); + assert!(context.no_unhandled_errors()); + } + + fn push_multiple_times_into_buffer( + In((value, key)): In<(usize, BufferKey)>, + mut access: BufferAccessMut, + ) -> AnyBufferKey { + let mut buffer = access.get_mut(&key).unwrap(); + for _ in 0..5 { + buffer.push(value); + } + + key.into() + } + + fn get_buffer_count(In(key): In, world: &mut World) -> usize { + world.any_buffer_view(&key).unwrap().len() + } + + #[test] + fn test_modify_any_message() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer = builder.create_buffer(BufferSettings::keep_all()); + let push_multiple_times = builder + .commands() + .spawn_service(push_multiple_times_into_buffer.into_blocking_service()); + let modify_content = builder + .commands() + .spawn_service(modify_buffer_content.into_blocking_service()); + let drain_content = builder + .commands() + .spawn_service(pull_each_buffer_item.into_blocking_service()); + + scope + .input + .chain(builder) + .with_access(buffer) + .then(push_multiple_times) + .then(modify_content) + .then(drain_content) + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| commands.request(3, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let values = promise.take().available().unwrap(); + assert_eq!(values, vec![0, 3, 6, 9, 12]); + assert!(context.no_unhandled_errors()); + } + + fn modify_buffer_content(In(key): In, world: &mut World) -> AnyBufferKey { + world + .any_buffer_mut(&key, |mut access| { + for i in 0..access.len() { + access.get_mut(i).map(|value| { + *value.downcast_mut::().unwrap() *= i; + }); + } + }) + .unwrap(); + + key + } + + fn pull_each_buffer_item(In(key): In, world: &mut World) -> Vec { + world + .any_buffer_mut(&key, |mut access| { + let mut values = Vec::new(); + while let Some(value) = access.pull() { + values.push(*value.downcast::().unwrap()); + } + values + }) + .unwrap() + } + + #[test] + fn test_drain_any_message() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer = builder.create_buffer(BufferSettings::keep_all()); + let push_multiple_times = builder + .commands() + .spawn_service(push_multiple_times_into_buffer.into_blocking_service()); + let modify_content = builder + .commands() + .spawn_service(modify_buffer_content.into_blocking_service()); + let drain_content = builder + .commands() + .spawn_service(drain_buffer_contents.into_blocking_service()); + + scope + .input + .chain(builder) + .with_access(buffer) + .then(push_multiple_times) + .then(modify_content) + .then(drain_content) + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| commands.request(3, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let values = promise.take().available().unwrap(); + assert_eq!(values, vec![0, 3, 6, 9, 12]); + assert!(context.no_unhandled_errors()); + } + + fn drain_buffer_contents(In(key): In, world: &mut World) -> Vec { + world + .any_buffer_mut(&key, |mut access| { + access + .drain(..) + .map(|value| *value.downcast::().unwrap()) + .collect() + }) + .unwrap() + } + + #[test] + fn double_any_messages() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = + context.spawn_io_workflow(|scope: Scope<(u32, i32, f32), (u32, i32, f32)>, builder| { + let buffer_u32: AnyBuffer = builder + .create_buffer::(BufferSettings::default()) + .into(); + let buffer_i32: AnyBuffer = builder + .create_buffer::(BufferSettings::default()) + .into(); + let buffer_f32: AnyBuffer = builder + .create_buffer::(BufferSettings::default()) + .into(); + + let (input_u32, input_i32, input_f32) = scope.input.chain(builder).unzip(); + input_u32.chain(builder).map_block(|v| 2 * v).connect( + buffer_u32 + .downcast_for_message::() + .unwrap() + .input_slot(), + ); + + input_i32.chain(builder).map_block(|v| 2 * v).connect( + buffer_i32 + .downcast_for_message::() + .unwrap() + .input_slot(), + ); + + input_f32.chain(builder).map_block(|v| 2.0 * v).connect( + buffer_f32 + .downcast_for_message::() + .unwrap() + .input_slot(), + ); + + (buffer_u32, buffer_i32, buffer_f32) + .join(builder) + .map_block(|(value_u32, value_i32, value_f32)| { + ( + *value_u32.downcast::().unwrap(), + *value_i32.downcast::().unwrap(), + *value_f32.downcast::().unwrap(), + ) + }) + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request((1u32, 2i32, 3f32), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let (v_u32, v_i32, v_f32) = promise.take().available().unwrap(); + assert_eq!(v_u32, 2); + assert_eq!(v_i32, 4); + assert_eq!(v_f32, 6.0); + assert!(context.no_unhandled_errors()); + } +} diff --git a/src/buffer/buffer_access_lifecycle.rs b/src/buffer/buffer_access_lifecycle.rs index d368a484..b7596fac 100644 --- a/src/buffer/buffer_access_lifecycle.rs +++ b/src/buffer/buffer_access_lifecycle.rs @@ -21,7 +21,7 @@ use tokio::sync::mpsc::UnboundedSender as TokioSender; use std::sync::Arc; -use crate::{emit_disposal, ChannelItem, Disposal, OperationRoster}; +use crate::{emit_disposal, BufferKeyBuilder, ChannelItem, Disposal, OperationRoster}; /// This is used as a field inside of [`crate::BufferKey`] which keeps track of /// when a key that was sent out into the world gets fully dropped from use. We @@ -29,7 +29,7 @@ use crate::{emit_disposal, ChannelItem, Disposal, OperationRoster}; /// we would be needlessly doing a reachability check every time the key gets /// cloned. #[derive(Clone)] -pub(crate) struct BufferAccessLifecycle { +pub struct BufferAccessLifecycle { scope: Entity, accessor: Entity, session: Entity, @@ -87,3 +87,29 @@ impl Drop for BufferAccessLifecycle { } } } + +/// This trait is implemented by [`crate::BufferKey`]-like structs so their +/// lifecycles can be managed. +pub trait BufferKeyLifecycle { + /// What kind of buffer this key can unlock. + type TargetBuffer; + + /// Create a new key of this type. + fn create_key(buffer: &Self::TargetBuffer, builder: &BufferKeyBuilder) -> Self; + + /// Check if the key is currently in use. + fn is_in_use(&self) -> bool; + + /// Create a deep clone of the key. The usage tracking of the clone will + /// be unrelated to the usage tracking of the original. + /// + /// We do a deep clone of the key when distributing it to decouple the + /// lifecycle of the keys that we send out from the key that's held by the + /// accessor node. + // + /// The key instance held by the accessor node will never be dropped until + /// the session is cleaned up, so the keys that we send out into the workflow + /// need to have their own independent lifecycles or else we won't detect + /// when the workflow has dropped them. + fn deep_clone(&self) -> Self; +} diff --git a/src/buffer/buffer_key_builder.rs b/src/buffer/buffer_key_builder.rs index 02e4664d..e1866e2e 100644 --- a/src/buffer/buffer_key_builder.rs +++ b/src/buffer/buffer_key_builder.rs @@ -19,7 +19,7 @@ use bevy_ecs::prelude::Entity; use std::sync::Arc; -use crate::{BufferAccessLifecycle, BufferKey, ChannelSender}; +use crate::{BufferAccessLifecycle, BufferKeyTag, ChannelSender}; pub struct BufferKeyBuilder { scope: Entity, @@ -29,8 +29,9 @@ pub struct BufferKeyBuilder { } impl BufferKeyBuilder { - pub(crate) fn build(&self, buffer: Entity) -> BufferKey { - BufferKey { + /// Make a [`BufferKeyTag`] that can be given to a [`crate::BufferKey`]-like struct. + pub fn make_tag(&self, buffer: Entity) -> BufferKeyTag { + BufferKeyTag { buffer, session: self.session, accessor: self.accessor, @@ -44,7 +45,6 @@ impl BufferKeyBuilder { tracker.clone(), )) }), - _ignore: Default::default(), } } diff --git a/src/buffer/buffer_map.rs b/src/buffer/buffer_map.rs new file mode 100644 index 00000000..04fe4ea7 --- /dev/null +++ b/src/buffer/buffer_map.rs @@ -0,0 +1,894 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +use std::{borrow::Cow, collections::HashMap}; + +use thiserror::Error as ThisError; + +use smallvec::SmallVec; + +use bevy_ecs::prelude::{Entity, World}; + +use crate::{ + add_listener_to_source, Accessing, AnyBuffer, AnyBufferKey, AnyMessageBox, AsAnyBuffer, Buffer, + BufferKeyBuilder, BufferKeyLifecycle, Bufferable, Buffering, Builder, Chain, Gate, GateState, + Joining, Node, OperationError, OperationResult, OperationRoster, +}; + +pub use bevy_impulse_derive::{Accessor, Joined}; + +/// Uniquely identify a buffer within a buffer map, either by name or by an +/// index value. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum BufferIdentifier<'a> { + /// Identify a buffer by name + Name(Cow<'a, str>), + /// Identify a buffer by an index value + Index(usize), +} + +impl BufferIdentifier<'static> { + /// Clone a name to use as an identifier. + pub fn clone_name(name: &str) -> Self { + BufferIdentifier::Name(Cow::Owned(name.to_owned())) + } + + /// Borrow a string literal name to use as an identifier. + pub fn literal_name(name: &'static str) -> Self { + BufferIdentifier::Name(Cow::Borrowed(name)) + } + + /// Use an index as an identifier. + pub fn index(index: usize) -> Self { + BufferIdentifier::Index(index) + } +} + +impl<'a> From<&'a str> for BufferIdentifier<'a> { + fn from(value: &'a str) -> Self { + BufferIdentifier::Name(Cow::Borrowed(value)) + } +} + +impl<'a> From for BufferIdentifier<'a> { + fn from(value: usize) -> Self { + BufferIdentifier::Index(value) + } +} + +pub type BufferMap = HashMap, AnyBuffer>; + +/// Extension trait that makes it more convenient to insert buffers into a [`BufferMap`]. +pub trait AddBufferToMap { + /// Convenience function for inserting items into a [`BufferMap`]. This + /// automatically takes care of converting the types. + fn insert_buffer>, B: AsAnyBuffer>( + &mut self, + identifier: I, + buffer: B, + ); +} + +impl AddBufferToMap for BufferMap { + fn insert_buffer>, B: AsAnyBuffer>( + &mut self, + identifier: I, + buffer: B, + ) { + self.insert(identifier.into(), buffer.as_any_buffer()); + } +} + +/// This error is used when the buffers provided for an input are not compatible +/// with the layout. +#[derive(ThisError, Debug, Clone, Default)] +#[error("the incoming buffer map is incompatible with the layout")] +pub struct IncompatibleLayout { + /// Identities of buffers that were missing from the incoming buffer map. + pub missing_buffers: Vec>, + /// Identities of buffers in the incoming buffer map which cannot exist in + /// the target layout. + pub forbidden_buffers: Vec>, + /// Buffers whose expected type did not match the received type. + pub incompatible_buffers: Vec, +} + +impl IncompatibleLayout { + /// Convert this into an error if it has any contents inside. + pub fn as_result(self) -> Result<(), Self> { + if !self.missing_buffers.is_empty() { + return Err(self); + } + + if !self.incompatible_buffers.is_empty() { + return Err(self); + } + + Ok(()) + } + + /// Check whether the buffer associated with the identifier is compatible with + /// the required buffer type. You can pass in a `&static str` or a `usize` + /// directly as the identifier. + /// + /// ``` + /// # use bevy_impulse::prelude::*; + /// + /// let buffer_map = BufferMap::default(); + /// let mut compatibility = IncompatibleLayout::default(); + /// let buffer = compatibility.require_buffer_for_identifier::>("some_field", &buffer_map); + /// assert!(buffer.is_err()); + /// assert!(compatibility.as_result().is_err()); + /// + /// let mut compatibility = IncompatibleLayout::default(); + /// let buffer = compatibility.require_buffer_for_identifier::>(10, &buffer_map); + /// assert!(buffer.is_err()); + /// assert!(compatibility.as_result().is_err()); + /// ``` + pub fn require_buffer_for_identifier( + &mut self, + identifier: impl Into>, + buffers: &BufferMap, + ) -> Result { + let identifier = identifier.into(); + if let Some(buffer) = buffers.get(&identifier) { + if let Some(buffer) = buffer.downcast_buffer::() { + return Ok(buffer); + } else { + self.incompatible_buffers.push(BufferIncompatibility { + identifier, + expected: std::any::type_name::(), + received: buffer.message_type_name(), + }); + } + } else { + self.missing_buffers.push(identifier); + } + + Err(()) + } + + /// Same as [`Self::require_buffer_for_identifier`], but can be used with + /// temporary borrows of a string slice. The string slice will be cloned if + /// an error message needs to be produced. + pub fn require_buffer_for_borrowed_name( + &mut self, + expected_name: &str, + buffers: &BufferMap, + ) -> Result { + let identifier = BufferIdentifier::Name(Cow::Borrowed(expected_name)); + if let Some(buffer) = buffers.get(&identifier) { + if let Some(buffer) = buffer.downcast_buffer::() { + return Ok(buffer); + } else { + self.incompatible_buffers.push(BufferIncompatibility { + identifier: BufferIdentifier::Name(Cow::Owned(expected_name.to_owned())), + expected: std::any::type_name::(), + received: buffer.message_type_name(), + }); + } + } else { + self.missing_buffers + .push(BufferIdentifier::Name(Cow::Owned(expected_name.to_owned()))); + } + + Err(()) + } +} + +/// Difference between the expected and received types of a named buffer. +#[derive(Debug, Clone)] +pub struct BufferIncompatibility { + /// Name of the expected buffer + pub identifier: BufferIdentifier<'static>, + /// The type that was expected for this buffer + pub expected: &'static str, + /// The type that was received for this buffer + pub received: &'static str, + // TODO(@mxgrey): Replace TypeId with TypeInfo +} + +/// This trait can be implemented on structs that represent a layout of buffers. +/// You do not normally have to implement this yourself. Instead you should +/// `#[derive(Joined)]` on a struct that you want a join operation to +/// produce. +pub trait BufferMapLayout: Sized + Clone + 'static + Send + Sync { + /// Try to convert a generic [`BufferMap`] into this specific layout. + fn try_from_buffer_map(buffers: &BufferMap) -> Result; +} + +/// This trait helps auto-generated buffer map structs to implement the Buffering +/// trait. +pub trait BufferMapStruct: Sized + Clone + 'static + Send + Sync { + /// Produce a list of the buffers that exist in this layout. Implementing + /// this function alone is sufficient to implement the entire [`Buffering`] trait. + fn buffer_list(&self) -> SmallVec<[AnyBuffer; 8]>; +} + +impl Bufferable for T { + type BufferType = Self; + + fn into_buffer(self, _: &mut Builder) -> Self::BufferType { + self + } +} + +impl Buffering for T { + fn verify_scope(&self, scope: Entity) { + for buffer in self.buffer_list() { + assert_eq!(buffer.scope(), scope); + } + } + + fn buffered_count(&self, session: Entity, world: &World) -> Result { + let mut min_count = None; + + for buffer in self.buffer_list() { + let count = buffer.buffered_count(session, world)?; + min_count = if min_count.is_some_and(|m| m < count) { + min_count + } else { + Some(count) + }; + } + + Ok(min_count.unwrap_or(0)) + } + + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + for buffer in self.buffer_list() { + buffer.ensure_active_session(session, world)?; + } + + Ok(()) + } + + fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { + for buffer in self.buffer_list() { + add_listener_to_source(buffer.id(), listener, world)?; + } + Ok(()) + } + + fn gate_action( + &self, + session: Entity, + action: Gate, + world: &mut World, + roster: &mut OperationRoster, + ) -> OperationResult { + for buffer in self.buffer_list() { + GateState::apply(buffer.id(), session, action, world, roster)?; + } + Ok(()) + } + + fn as_input(&self) -> SmallVec<[Entity; 8]> { + let mut inputs = SmallVec::new(); + for buffer in self.buffer_list() { + inputs.push(buffer.id()); + } + inputs + } +} + +/// This trait can be implemented for structs that are created by joining together +/// values from a collection of buffers. This allows [`join`][1] to produce arbitrary +/// structs. Structs with this trait can be produced by [`try_join`][2]. +/// +/// Each field in this struct needs to have the trait bounds `'static + Send + Sync`. +/// +/// This does not generally need to be implemented explicitly. Instead you should +/// use `#[derive(Joined)]`: +/// +/// ``` +/// use bevy_impulse::prelude::*; +/// +/// #[derive(Joined)] +/// struct SomeValues { +/// integer: i64, +/// string: String, +/// } +/// ``` +/// +/// The above example would allow you to join a value from an `i64` buffer with +/// a value from a `String` buffer. You can have as many fields in the struct +/// as you'd like. +/// +/// This macro will generate a struct of buffers to match the fields of the +/// struct that it's applied to. The name of that struct is anonymous by default +/// since you don't generally need to use it directly, but if you want to give +/// it a name you can use #[joined(buffers_struct_name = ...)]`: +/// +/// ``` +/// # use bevy_impulse::prelude::*; +/// +/// #[derive(Joined)] +/// #[joined(buffers_struct_name = SomeBuffers)] +/// struct SomeValues { +/// integer: i64, +/// string: String, +/// } +/// ``` +/// +/// By default each field of the generated buffers struct will have a type of +/// [`Buffer`], but you can override this using `#[joined(buffer = ...)]` +/// to specify a special buffer type. For example if your `Joined` struct +/// contains an [`AnyMessageBox`] then by default the macro will use `Buffer`, +/// but you probably really want it to have an [`AnyBuffer`]: +/// +/// ``` +/// # use bevy_impulse::prelude::*; +/// +/// #[derive(Joined)] +/// struct SomeValues { +/// integer: i64, +/// string: String, +/// #[joined(buffer = AnyBuffer)] +/// any: AnyMessageBox, +/// } +/// ``` +/// +/// The above method also works for joining a `JsonMessage` field from a `JsonBuffer`. +/// +/// [1]: crate::Builder::join +/// [2]: crate::Builder::try_join +pub trait Joined: 'static + Send + Sync + Sized { + /// This associated type must represent a buffer map layout that implements + /// the [`Joining`] trait. The message type yielded by [`Joining`] for this + /// associated type must match the [`Joined`] type. + type Buffers: 'static + BufferMapLayout + Joining + Send + Sync; + + /// Used by [`Builder::try_join`] + fn try_join_from<'w, 's, 'a, 'b>( + buffers: &BufferMap, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Result, IncompatibleLayout> { + let buffers: Self::Buffers = Self::Buffers::try_from_buffer_map(buffers)?; + Ok(buffers.join(builder)) + } +} + +/// Trait to describe a set of buffer keys. This allows [listen][1] and [access][2] +/// to work for arbitrary structs of buffer keys. Structs with this trait can be +/// produced by [`try_listen`][3] and [`try_create_buffer_access`][4]. +/// +/// Each field in the struct must be some kind of buffer key. +/// +/// This does not generally need to be implemented explicitly. Instead you should +/// define a struct where all fields are buffer keys and then apply +/// `#[derive(Accessor)]` to it, e.g.: +/// +/// ``` +/// use bevy_impulse::prelude::*; +/// +/// #[derive(Clone, Accessor)] +/// struct SomeKeys { +/// integer: BufferKey, +/// string: BufferKey, +/// any: AnyBufferKey, +/// } +/// ``` +/// +/// The macro will generate a struct of buffers to match the keys. The name of +/// that struct is anonymous by default since you don't generally need to use it +/// directly, but if you want to give it a name you can use `#[key(buffers_struct_name = ...)]`: +/// +/// ``` +/// # use bevy_impulse::prelude::*; +/// +/// #[derive(Clone, Accessor)] +/// #[key(buffers_struct_name = SomeBuffers)] +/// struct SomeKeys { +/// integer: BufferKey, +/// string: BufferKey, +/// any: AnyBufferKey, +/// } +/// ``` +/// +/// [1]: crate::Builder::listen +/// [2]: crate::Builder::create_buffer_access +/// [3]: crate::Builder::try_listen +/// [4]: crate::Builder::try_create_buffer_access +pub trait Accessor: 'static + Send + Sync + Sized + Clone { + type Buffers: 'static + BufferMapLayout + Accessing + Send + Sync; + + fn try_listen_from<'w, 's, 'a, 'b>( + buffers: &BufferMap, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Result, IncompatibleLayout> { + let buffers: Self::Buffers = Self::Buffers::try_from_buffer_map(buffers)?; + Ok(buffers.listen(builder)) + } + + fn try_buffer_access( + buffers: &BufferMap, + builder: &mut Builder, + ) -> Result, IncompatibleLayout> { + let buffers: Self::Buffers = Self::Buffers::try_from_buffer_map(buffers)?; + Ok(buffers.access(builder)) + } +} + +impl BufferMapLayout for BufferMap { + fn try_from_buffer_map(buffers: &BufferMap) -> Result { + Ok(buffers.clone()) + } +} + +impl BufferMapStruct for BufferMap { + fn buffer_list(&self) -> SmallVec<[AnyBuffer; 8]> { + self.values().cloned().collect() + } +} + +impl Joining for BufferMap { + type Item = HashMap, AnyMessageBox>; + + fn pull(&self, session: Entity, world: &mut World) -> Result { + let mut value = HashMap::new(); + for (name, buffer) in self.iter() { + value.insert(name.clone(), buffer.pull(session, world)?); + } + + Ok(value) + } +} + +impl Joined for HashMap, AnyMessageBox> { + type Buffers = BufferMap; +} + +impl Accessing for BufferMap { + type Key = HashMap, AnyBufferKey>; + + fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { + let mut keys = HashMap::new(); + for (name, buffer) in self.iter() { + let key = AnyBufferKey { + tag: builder.make_tag(buffer.id()), + interface: buffer.interface, + }; + keys.insert(name.clone(), key); + } + keys + } + + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { + for buffer in self.values() { + buffer.add_accessor(accessor, world)?; + } + Ok(()) + } + + fn deep_clone_key(key: &Self::Key) -> Self::Key { + let mut cloned_key = HashMap::new(); + for (name, key) in key.iter() { + cloned_key.insert(name.clone(), key.deep_clone()); + } + cloned_key + } + + fn is_key_in_use(key: &Self::Key) -> bool { + for k in key.values() { + if k.is_in_use() { + return true; + } + } + + return false; + } +} + +impl Joined for Vec { + type Buffers = Vec>; +} + +impl BufferMapLayout for Vec { + fn try_from_buffer_map(buffers: &BufferMap) -> Result { + let mut downcast_buffers = Vec::new(); + let mut compatibility = IncompatibleLayout::default(); + for i in 0..buffers.len() { + if let Ok(downcast) = compatibility.require_buffer_for_identifier::(i, buffers) { + downcast_buffers.push(downcast); + } + } + + compatibility.as_result()?; + Ok(downcast_buffers) + } +} + +impl Joined for SmallVec<[T; N]> { + type Buffers = SmallVec<[Buffer; N]>; +} + +impl BufferMapLayout + for SmallVec<[B; N]> +{ + fn try_from_buffer_map(buffers: &BufferMap) -> Result { + let mut downcast_buffers = SmallVec::new(); + let mut compatibility = IncompatibleLayout::default(); + for i in 0..buffers.len() { + if let Ok(downcast) = compatibility.require_buffer_for_identifier::(i, buffers) { + downcast_buffers.push(downcast); + } + } + + compatibility.as_result()?; + Ok(downcast_buffers) + } +} + +#[cfg(test)] +mod tests { + use crate::{prelude::*, testing::*, AddBufferToMap, BufferMap}; + + #[derive(Joined)] + struct TestJoinedValue { + integer: i64, + float: f64, + string: String, + generic: T, + #[joined(buffer = AnyBuffer)] + any: AnyMessageBox, + } + + #[test] + fn test_try_join() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_i64 = builder.create_buffer(BufferSettings::default()); + let buffer_f64 = builder.create_buffer(BufferSettings::default()); + let buffer_string = builder.create_buffer(BufferSettings::default()); + let buffer_generic = builder.create_buffer(BufferSettings::default()); + let buffer_any = builder.create_buffer(BufferSettings::default()); + + let mut buffers = BufferMap::default(); + buffers.insert_buffer("integer", buffer_i64); + buffers.insert_buffer("float", buffer_f64); + buffers.insert_buffer("string", buffer_string); + buffers.insert_buffer("generic", buffer_generic); + buffers.insert_buffer("any", buffer_any); + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffer_i64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_f64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_string.input_slot()), + |chain: Chain<_>| chain.connect(buffer_generic.input_slot()), + |chain: Chain<_>| chain.connect(buffer_any.input_slot()), + )); + + builder.try_join(&buffers).unwrap().connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request( + (5_i64, 3.14_f64, "hello".to_string(), "world", 42_i64), + workflow, + ) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValue<&'static str> = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + assert_eq!(value.string, "hello"); + assert_eq!(value.generic, "world"); + assert_eq!(*value.any.downcast::().unwrap(), 42); + assert!(context.no_unhandled_errors()); + } + + #[test] + fn test_joined_value() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_i64 = builder.create_buffer(BufferSettings::default()); + let buffer_f64 = builder.create_buffer(BufferSettings::default()); + let buffer_string = builder.create_buffer(BufferSettings::default()); + let buffer_generic = builder.create_buffer(BufferSettings::default()); + let buffer_any = builder.create_buffer::(BufferSettings::default()); + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffer_i64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_f64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_string.input_slot()), + |chain: Chain<_>| chain.connect(buffer_generic.input_slot()), + |chain: Chain<_>| chain.connect(buffer_any.input_slot()), + )); + + let buffers = TestJoinedValue::select_buffers( + buffer_i64, + buffer_f64, + buffer_string, + buffer_generic, + buffer_any.into(), + ); + + builder.join(buffers).connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request( + (5_i64, 3.14_f64, "hello".to_string(), "world", 42_i64), + workflow, + ) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValue<&'static str> = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + assert_eq!(value.string, "hello"); + assert_eq!(value.generic, "world"); + assert_eq!(*value.any.downcast::().unwrap(), 42); + assert!(context.no_unhandled_errors()); + } + + #[derive(Clone, Joined)] + #[joined(buffers_struct_name = FooBuffers)] + struct TestDeriveWithConfig {} + + #[test] + fn test_derive_with_config() { + // a compile test to check that the name of the generated struct is correct + fn _check_buffer_struct_name(_: FooBuffers) {} + } + + struct MultiGenericValue { + t: T, + u: U, + } + + #[derive(Joined)] + #[joined(buffers_struct_name = MultiGenericBuffers)] + struct JoinedMultiGenericValue { + #[joined(buffer = Buffer>)] + a: MultiGenericValue, + b: String, + } + + #[test] + fn test_multi_generic_joined_value() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow( + |scope: Scope<(i32, String), JoinedMultiGenericValue>, builder| { + let multi_generic_buffers = MultiGenericBuffers:: { + a: builder.create_buffer(BufferSettings::default()), + b: builder.create_buffer(BufferSettings::default()), + }; + + let copy = multi_generic_buffers; + + scope + .input + .chain(builder) + .map_block(|(integer, string)| { + ( + MultiGenericValue { + t: integer, + u: string.clone(), + }, + string, + ) + }) + .fork_unzip(( + |a: Chain<_>| a.connect(multi_generic_buffers.a.input_slot()), + |b: Chain<_>| b.connect(multi_generic_buffers.b.input_slot()), + )); + + multi_generic_buffers.join(builder).connect(scope.terminate); + copy.join(builder).connect(scope.terminate); + }, + ); + + let mut promise = context.command(|commands| { + commands + .request((5, "hello".to_string()), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value = promise.take().available().unwrap(); + assert_eq!(value.a.t, 5); + assert_eq!(value.a.u, "hello"); + assert_eq!(value.b, "hello"); + assert!(context.no_unhandled_errors()); + } + + /// We create this struct just to verify that it is able to compile despite + /// NonCopyBuffer not being copyable. + #[derive(Joined)] + #[allow(unused)] + struct JoinedValueForNonCopyBuffer { + #[joined(buffer = NonCopyBuffer, noncopy_buffer)] + _a: String, + _b: u32, + } + + #[derive(Clone, Accessor)] + #[key(buffers_struct_name = TestKeysBuffers)] + struct TestKeys { + integer: BufferKey, + float: BufferKey, + string: BufferKey, + generic: BufferKey, + any: AnyBufferKey, + } + #[test] + fn test_listen() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_any = builder.create_buffer::(BufferSettings::default()); + + let buffers = TestKeys::select_buffers( + builder.create_buffer(BufferSettings::default()), + builder.create_buffer(BufferSettings::default()), + builder.create_buffer(BufferSettings::default()), + builder.create_buffer(BufferSettings::default()), + buffer_any.as_any_buffer(), + ); + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffers.integer.input_slot()), + |chain: Chain<_>| chain.connect(buffers.float.input_slot()), + |chain: Chain<_>| chain.connect(buffers.string.input_slot()), + |chain: Chain<_>| chain.connect(buffers.generic.input_slot()), + |chain: Chain<_>| chain.connect(buffer_any.input_slot()), + )); + + builder + .listen(buffers) + .then(join_via_listen.into_blocking_callback()) + .dispose_on_none() + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request( + (5_i64, 3.14_f64, "hello".to_string(), "world", 42_i64), + workflow, + ) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValue<&'static str> = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + assert_eq!(value.string, "hello"); + assert_eq!(value.generic, "world"); + assert_eq!(*value.any.downcast::().unwrap(), 42); + assert!(context.no_unhandled_errors()); + } + + #[test] + fn test_try_listen() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_i64 = builder.create_buffer::(BufferSettings::default()); + let buffer_f64 = builder.create_buffer::(BufferSettings::default()); + let buffer_string = builder.create_buffer::(BufferSettings::default()); + let buffer_generic = builder.create_buffer::<&'static str>(BufferSettings::default()); + let buffer_any = builder.create_buffer::(BufferSettings::default()); + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffer_i64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_f64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_string.input_slot()), + |chain: Chain<_>| chain.connect(buffer_generic.input_slot()), + |chain: Chain<_>| chain.connect(buffer_any.input_slot()), + )); + + let mut buffer_map = BufferMap::new(); + buffer_map.insert_buffer("integer", buffer_i64); + buffer_map.insert_buffer("float", buffer_f64); + buffer_map.insert_buffer("string", buffer_string); + buffer_map.insert_buffer("generic", buffer_generic); + buffer_map.insert_buffer("any", buffer_any); + + builder + .try_listen(&buffer_map) + .unwrap() + .then(join_via_listen.into_blocking_callback()) + .dispose_on_none() + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request( + (5_i64, 3.14_f64, "hello".to_string(), "world", 42_i64), + workflow, + ) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValue<&'static str> = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + assert_eq!(value.string, "hello"); + assert_eq!(value.generic, "world"); + assert_eq!(*value.any.downcast::().unwrap(), 42); + assert!(context.no_unhandled_errors()); + } + + /// This macro is a manual implementation of the join operation that uses + /// the buffer listening mechanism. There isn't any reason to reimplement + /// join here except so we can test that listening is working correctly for + /// Accessor. + fn join_via_listen( + In(keys): In>, + world: &mut World, + ) -> Option> { + if world.buffer_view(&keys.integer).ok()?.is_empty() { + return None; + } + if world.buffer_view(&keys.float).ok()?.is_empty() { + return None; + } + if world.buffer_view(&keys.string).ok()?.is_empty() { + return None; + } + if world.buffer_view(&keys.generic).ok()?.is_empty() { + return None; + } + if world.any_buffer_view(&keys.any).ok()?.is_empty() { + return None; + } + + let integer = world + .buffer_mut(&keys.integer, |mut buffer| buffer.pull()) + .unwrap() + .unwrap(); + let float = world + .buffer_mut(&keys.float, |mut buffer| buffer.pull()) + .unwrap() + .unwrap(); + let string = world + .buffer_mut(&keys.string, |mut buffer| buffer.pull()) + .unwrap() + .unwrap(); + let generic = world + .buffer_mut(&keys.generic, |mut buffer| buffer.pull()) + .unwrap() + .unwrap(); + let any = world + .any_buffer_mut(&keys.any, |mut buffer| buffer.pull()) + .unwrap() + .unwrap(); + + Some(TestJoinedValue { + integer, + float, + string, + generic, + any, + }) + } +} diff --git a/src/buffer/buffer_storage.rs b/src/buffer/buffer_storage.rs index c7465415..e3d9bc88 100644 --- a/src/buffer/buffer_storage.rs +++ b/src/buffer/buffer_storage.rs @@ -44,23 +44,18 @@ pub(crate) struct BufferStorage { } impl BufferStorage { - pub(crate) fn force_push(&mut self, session: Entity, value: T) -> Option { - Self::impl_push( - self.reverse_queues.entry(session).or_default(), - self.settings.retention(), - value, - ) + pub(crate) fn count(&self, session: Entity) -> usize { + self.reverse_queues + .get(&session) + .map(|q| q.len()) + .unwrap_or(0) } - pub(crate) fn push(&mut self, session: Entity, value: T) -> Option { - let Some(reverse_queue) = self.reverse_queues.get_mut(&session) else { - return Some(value); - }; - - Self::impl_push(reverse_queue, self.settings.retention(), value) + pub(crate) fn active_sessions(&self) -> SmallVec<[Entity; 16]> { + self.reverse_queues.keys().copied().collect() } - pub(crate) fn impl_push( + fn impl_push( reverse_queue: &mut SmallVec<[T; 16]>, retention: RetentionPolicy, value: T, @@ -92,6 +87,22 @@ impl BufferStorage { replaced } + pub(crate) fn force_push(&mut self, session: Entity, value: T) -> Option { + Self::impl_push( + self.reverse_queues.entry(session).or_default(), + self.settings.retention(), + value, + ) + } + + pub(crate) fn push(&mut self, session: Entity, value: T) -> Option { + let Some(reverse_queue) = self.reverse_queues.get_mut(&session) else { + return Some(value); + }; + + Self::impl_push(reverse_queue, self.settings.retention(), value) + } + pub(crate) fn push_as_oldest(&mut self, session: Entity, value: T) -> Option { let Some(reverse_queue) = self.reverse_queues.get_mut(&session) else { return Some(value); @@ -147,20 +158,8 @@ impl BufferStorage { self.reverse_queues.remove(&session); } - pub(crate) fn count(&self, session: Entity) -> usize { - self.reverse_queues - .get(&session) - .map(|q| q.len()) - .unwrap_or(0) - } - - pub(crate) fn iter(&self, session: Entity) -> IterBufferView<'_, T> - where - T: 'static + Send + Sync, - { - IterBufferView { - iter: self.reverse_queues.get(&session).map(|q| q.iter().rev()), - } + pub(crate) fn ensure_session(&mut self, session: Entity) { + self.reverse_queues.entry(session).or_default(); } pub(crate) fn iter_mut(&mut self, session: Entity) -> IterBufferMut<'_, T> @@ -175,6 +174,28 @@ impl BufferStorage { } } + pub(crate) fn oldest_mut(&mut self, session: Entity) -> Option<&mut T> { + self.reverse_queues + .get_mut(&session) + .and_then(|q| q.last_mut()) + } + + pub(crate) fn newest_mut(&mut self, session: Entity) -> Option<&mut T> { + self.reverse_queues + .get_mut(&session) + .and_then(|q| q.first_mut()) + } + + pub(crate) fn get_mut(&mut self, session: Entity, index: usize) -> Option<&mut T> { + let reverse_queue = self.reverse_queues.get_mut(&session)?; + let len = reverse_queue.len(); + if len <= index { + return None; + } + + reverse_queue.get_mut(len - index - 1) + } + pub(crate) fn drain(&mut self, session: Entity, range: R) -> DrainBuffer<'_, T> where T: 'static + Send + Sync, @@ -188,6 +209,15 @@ impl BufferStorage { } } + pub(crate) fn iter(&self, session: Entity) -> IterBufferView<'_, T> + where + T: 'static + Send + Sync, + { + IterBufferView { + iter: self.reverse_queues.get(&session).map(|q| q.iter().rev()), + } + } + pub(crate) fn oldest(&self, session: Entity) -> Option<&T> { self.reverse_queues.get(&session).and_then(|q| q.last()) } @@ -199,43 +229,13 @@ impl BufferStorage { pub(crate) fn get(&self, session: Entity, index: usize) -> Option<&T> { let reverse_queue = self.reverse_queues.get(&session)?; let len = reverse_queue.len(); - if len >= index { + if len <= index { return None; } reverse_queue.get(len - index - 1) } - pub(crate) fn oldest_mut(&mut self, session: Entity) -> Option<&mut T> { - self.reverse_queues - .get_mut(&session) - .and_then(|q| q.last_mut()) - } - - pub(crate) fn newest_mut(&mut self, session: Entity) -> Option<&mut T> { - self.reverse_queues - .get_mut(&session) - .and_then(|q| q.first_mut()) - } - - pub(crate) fn get_mut(&mut self, session: Entity, index: usize) -> Option<&mut T> { - let reverse_queue = self.reverse_queues.get_mut(&session)?; - let len = reverse_queue.len(); - if len >= index { - return None; - } - - reverse_queue.get_mut(len - index - 1) - } - - pub(crate) fn active_sessions(&self) -> SmallVec<[Entity; 16]> { - self.reverse_queues.keys().copied().collect() - } - - pub(crate) fn ensure_session(&mut self, session: Entity) { - self.reverse_queues.entry(session).or_default(); - } - pub(crate) fn new(settings: BufferSettings) -> Self { Self { settings, diff --git a/src/buffer/bufferable.rs b/src/buffer/bufferable.rs index 17f8b367..daf55738 100644 --- a/src/buffer/bufferable.rs +++ b/src/buffer/bufferable.rs @@ -19,143 +19,25 @@ use bevy_utils::all_tuples; use smallvec::SmallVec; use crate::{ - AddOperation, Buffer, BufferSettings, Buffered, Builder, Chain, CleanupWorkflowConditions, - CloneFromBuffer, Join, Listen, Output, Scope, ScopeSettings, UnusedTarget, + Accessing, AddOperation, Buffer, BufferSettings, Buffering, Builder, Chain, CloneFromBuffer, + Join, Joining, Output, UnusedTarget, }; -pub type BufferKeys = <::BufferType as Buffered>::Key; -pub type BufferItem = <::BufferType as Buffered>::Item; +pub type BufferKeys = <::BufferType as Accessing>::Key; +pub type JoinedItem = <::BufferType as Joining>::Item; pub trait Bufferable { - type BufferType: Buffered; + type BufferType: Buffering; /// Convert these bufferable workflow elements into buffers if they are not /// buffers already. fn into_buffer(self, builder: &mut Builder) -> Self::BufferType; - - /// Join these bufferable workflow elements. Each time every buffer contains - /// at least one element, this will pull the oldest element from each buffer - /// and join them into a tuple that gets sent to the target. - /// - /// If you need a more general way to get access to one or more buffers, - /// use [`listen`](Self::listen) instead. - fn join<'w, 's, 'a, 'b>( - self, - builder: &'b mut Builder<'w, 's, 'a>, - ) -> Chain<'w, 's, 'a, 'b, BufferItem> - where - Self: Sized, - Self::BufferType: 'static + Send + Sync, - BufferItem: 'static + Send + Sync, - { - let scope = builder.scope(); - let buffers = self.into_buffer(builder); - buffers.verify_scope(scope); - - let join = builder.commands.spawn(()).id(); - let target = builder.commands.spawn(UnusedTarget).id(); - builder.commands.add(AddOperation::new( - Some(scope), - join, - Join::new(buffers, target), - )); - - Output::new(scope, target).chain(builder) - } - - /// Create an operation that will output buffer access keys each time any - /// one of the buffers is modified. This can be used to create a node in a - /// workflow that wakes up every time one or more buffers change, and then - /// operates on those buffers. - /// - /// For an operation that simply joins the contents of two or more outputs - /// or buffers, use [`join`](Self::join) instead. - fn listen<'w, 's, 'a, 'b>( - self, - builder: &'b mut Builder<'w, 's, 'a>, - ) -> Chain<'w, 's, 'a, 'b, BufferKeys> - where - Self: Sized, - Self::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, - { - let scope = builder.scope(); - let buffers = self.into_buffer(builder); - buffers.verify_scope(scope); - - let listen = builder.commands.spawn(()).id(); - let target = builder.commands.spawn(UnusedTarget).id(); - builder.commands.add(AddOperation::new( - Some(scope), - listen, - Listen::new(buffers, target), - )); - - Output::new(scope, target).chain(builder) - } - - /// Alternative way to call [`Builder::on_cleanup`]. - fn on_cleanup( - self, - builder: &mut Builder, - build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, - ) where - Self: Sized, - Self::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, - Settings: Into, - { - builder.on_cleanup(self, build) - } - - /// Alternative way to call [`Builder::on_cancel`]. - fn on_cancel( - self, - builder: &mut Builder, - build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, - ) where - Self: Sized, - Self::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, - Settings: Into, - { - builder.on_cancel(self, build) - } - - /// Alternative way to call [`Builder::on_terminate`]. - fn on_terminate( - self, - builder: &mut Builder, - build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, - ) where - Self: Sized, - Self::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, - Settings: Into, - { - builder.on_terminate(self, build) - } - - /// Alternative way to call [`Builder::on_cleanup_if`]. - fn on_cleanup_if( - self, - builder: &mut Builder, - conditions: CleanupWorkflowConditions, - build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, - ) where - Self: Sized, - Self::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, - Settings: Into, - { - builder.on_cleanup_if(conditions, self, build) - } } impl Bufferable for Buffer { type BufferType = Self; fn into_buffer(self, builder: &mut Builder) -> Self::BufferType { - assert_eq!(self.scope, builder.scope()); + assert_eq!(self.scope(), builder.scope()); self } } @@ -163,7 +45,7 @@ impl Bufferable for Buffer { impl Bufferable for CloneFromBuffer { type BufferType = Self; fn into_buffer(self, builder: &mut Builder) -> Self::BufferType { - assert_eq!(self.scope, builder.scope()); + assert_eq!(self.scope(), builder.scope()); self } } @@ -178,6 +60,70 @@ impl Bufferable for Output { } } +pub trait Joinable: Bufferable { + type Item: 'static + Send + Sync; + + fn join<'w, 's, 'a, 'b>( + self, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Chain<'w, 's, 'a, 'b, Self::Item>; +} + +/// This trait is used to create join operations that pull exactly one value +/// from multiple buffers or outputs simultaneously. +impl Joinable for B +where + B: Bufferable, + B::BufferType: Joining, +{ + type Item = JoinedItem; + + /// Join these bufferable workflow elements. Each time every buffer contains + /// at least one element, this will pull the oldest element from each buffer + /// and join them into a tuple that gets sent to the target. + /// + /// If you need a more general way to get access to one or more buffers, + /// use [`listen`](Accessible::listen) instead. + fn join<'w, 's, 'a, 'b>( + self, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Chain<'w, 's, 'a, 'b, Self::Item> { + self.into_buffer(builder).join(builder) + } +} + +/// This trait is used to create operations that access buffers or outputs. +pub trait Accessible: Bufferable { + type Keys: 'static + Send + Sync; + + /// Create an operation that will output buffer access keys each time any + /// one of the buffers is modified. This can be used to create a node in a + /// workflow that wakes up every time one or more buffers change, and then + /// operates on those buffers. + /// + /// For an operation that simply joins the contents of two or more outputs + /// or buffers, use [`join`](Joinable::join) instead. + fn listen<'w, 's, 'a, 'b>( + self, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Chain<'w, 's, 'a, 'b, Self::Keys>; +} + +impl Accessible for B +where + B: Bufferable, + B::BufferType: Accessing, +{ + type Keys = BufferKeys; + + fn listen<'w, 's, 'a, 'b>( + self, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Chain<'w, 's, 'a, 'b, Self::Keys> { + self.into_buffer(builder).listen(builder) + } +} + macro_rules! impl_bufferable_for_tuple { ($($T:ident),*) => { #[allow(non_snake_case)] @@ -206,8 +152,15 @@ impl Bufferable for [T; N] { } } +impl Bufferable for Vec { + type BufferType = Vec; + fn into_buffer(self, builder: &mut Builder) -> Self::BufferType { + self.into_iter().map(|b| b.into_buffer(builder)).collect() + } +} + pub trait IterBufferable { - type BufferElement: Buffered; + type BufferElement: Buffering + Joining; /// Convert an iterable collection of bufferable workflow elements into /// buffers if they are not buffers already. @@ -224,11 +177,11 @@ pub trait IterBufferable { fn join_vec<'w, 's, 'a, 'b, const N: usize>( self, builder: &'b mut Builder<'w, 's, 'a>, - ) -> Chain<'w, 's, 'a, 'b, SmallVec<[::Item; N]>> + ) -> Chain<'w, 's, 'a, 'b, SmallVec<[::Item; N]>> where Self: Sized, Self::BufferElement: 'static + Send + Sync, - ::Item: 'static + Send + Sync, + ::Item: 'static + Send + Sync, { let buffers = self.into_buffer_vec::(builder); let join = builder.commands.spawn(()).id(); @@ -247,6 +200,7 @@ impl IterBufferable for T where T: IntoIterator, T::Item: Bufferable, + ::BufferType: Joining, { type BufferElement = ::BufferType; diff --git a/src/buffer/buffered.rs b/src/buffer/buffering.rs similarity index 54% rename from src/buffer/buffered.rs rename to src/buffer/buffering.rs index 084acb13..81eea5f1 100644 --- a/src/buffer/buffered.rs +++ b/src/buffer/buffering.rs @@ -16,24 +16,24 @@ */ use bevy_ecs::prelude::{Entity, World}; +use bevy_hierarchy::BuildChildren; use bevy_utils::all_tuples; use smallvec::SmallVec; use crate::{ - Buffer, BufferAccessors, BufferKey, BufferKeyBuilder, BufferStorage, CloneFromBuffer, - ForkTargetStorage, Gate, GateState, InspectBuffer, ManageBuffer, OperationError, - OperationResult, OperationRoster, OrBroken, SingleInputStorage, + AddOperation, BeginCleanupWorkflow, Buffer, BufferAccessors, BufferKey, BufferKeyBuilder, + BufferKeyLifecycle, BufferStorage, Builder, Chain, CleanupWorkflowConditions, CloneFromBuffer, + ForkTargetStorage, Gate, GateState, InputSlot, InspectBuffer, Join, Listen, ManageBuffer, Node, + OperateBufferAccess, OperationError, OperationResult, OperationRoster, OrBroken, Output, Scope, + ScopeSettings, SingleInputStorage, UnusedTarget, }; -pub trait Buffered: Clone { +pub trait Buffering: 'static + Send + Sync + Clone { fn verify_scope(&self, scope: Entity); fn buffered_count(&self, session: Entity, world: &World) -> Result; - type Item; - fn pull(&self, session: Entity, world: &mut World) -> Result; - fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult; fn gate_action( @@ -46,56 +46,177 @@ pub trait Buffered: Clone { fn as_input(&self) -> SmallVec<[Entity; 8]>; - type Key: Clone; - fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult; + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult; +} - fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key; +pub trait Joining: Buffering { + type Item: 'static + Send + Sync; + fn pull(&self, session: Entity, world: &mut World) -> Result; - fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult; + /// Join these bufferable workflow elements. Each time every buffer contains + /// at least one element, this will pull the oldest element from each buffer + /// and join them into a tuple that gets sent to the target. + /// + /// If you need a more general way to get access to one or more buffers, + /// use [`listen`](Accessing::listen) instead. + fn join<'w, 's, 'a, 'b>( + self, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Chain<'w, 's, 'a, 'b, Self::Item> { + let scope = builder.scope(); + self.verify_scope(scope); + + let join = builder.commands.spawn(()).id(); + let target = builder.commands.spawn(UnusedTarget).id(); + builder.commands.add(AddOperation::new( + Some(scope), + join, + Join::new(self, target), + )); + + Output::new(scope, target).chain(builder) + } +} +pub trait Accessing: Buffering { + type Key: 'static + Send + Sync + Clone; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult; + fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key; fn deep_clone_key(key: &Self::Key) -> Self::Key; - fn is_key_in_use(key: &Self::Key) -> bool; + + /// Create an operation that will output buffer access keys each time any + /// one of the buffers is modified. This can be used to create a node in a + /// workflow that wakes up every time one or more buffers change, and then + /// operates on those buffers. + /// + /// For an operation that simply joins the contents of two or more outputs + /// or buffers, use [`join`](Joining::join) instead. + fn listen<'w, 's, 'a, 'b>( + self, + builder: &'b mut Builder<'w, 's, 'a>, + ) -> Chain<'w, 's, 'a, 'b, Self::Key> { + let scope = builder.scope(); + self.verify_scope(scope); + + let listen = builder.commands.spawn(()).id(); + let target = builder.commands.spawn(UnusedTarget).id(); + builder.commands.add(AddOperation::new( + Some(scope), + listen, + Listen::new(self, target), + )); + + Output::new(scope, target).chain(builder) + } + + fn access(self, builder: &mut Builder) -> Node { + let source = builder.commands.spawn(()).id(); + let target = builder.commands.spawn(UnusedTarget).id(); + builder.commands.add(AddOperation::new( + Some(builder.scope), + source, + OperateBufferAccess::::new(self, target), + )); + + Node { + input: InputSlot::new(builder.scope, source), + output: Output::new(builder.scope, target), + streams: (), + } + } + + /// Alternative way to call [`Builder::on_cleanup`]. + fn on_cleanup( + self, + builder: &mut Builder, + build: impl FnOnce(Scope, &mut Builder) -> Settings, + ) where + Settings: Into, + { + self.on_cleanup_if( + builder, + CleanupWorkflowConditions::always_if(true, true), + build, + ) + } + + /// Alternative way to call [`Builder::on_cancel`]. + fn on_cancel( + self, + builder: &mut Builder, + build: impl FnOnce(Scope, &mut Builder) -> Settings, + ) where + Settings: Into, + { + self.on_cleanup_if( + builder, + CleanupWorkflowConditions::always_if(false, true), + build, + ) + } + + /// Alternative way to call [`Builder::on_terminate`]. + fn on_terminate( + self, + builder: &mut Builder, + build: impl FnOnce(Scope, &mut Builder) -> Settings, + ) where + Settings: Into, + { + self.on_cleanup_if( + builder, + CleanupWorkflowConditions::always_if(true, false), + build, + ) + } + + /// Alternative way to call [`Builder::on_cleanup_if`]. + fn on_cleanup_if( + self, + builder: &mut Builder, + conditions: CleanupWorkflowConditions, + build: impl FnOnce(Scope, &mut Builder) -> Settings, + ) where + Settings: Into, + { + let cancelling_scope_id = builder.commands.spawn(()).id(); + let _ = builder.create_scope_impl::( + cancelling_scope_id, + builder.finish_scope_cancel, + build, + ); + + let begin_cancel = builder.commands.spawn(()).set_parent(builder.scope).id(); + self.verify_scope(builder.scope); + builder.commands.add(AddOperation::new( + None, + begin_cancel, + BeginCleanupWorkflow::::new( + builder.scope, + self, + cancelling_scope_id, + conditions.run_on_terminate, + conditions.run_on_cancel, + ), + )); + } } -impl Buffered for Buffer { +impl Buffering for Buffer { fn verify_scope(&self, scope: Entity) { - assert_eq!(scope, self.scope); + assert_eq!(scope, self.scope()); } fn buffered_count(&self, session: Entity, world: &World) -> Result { world - .get_entity(self.source) + .get_entity(self.id()) .or_broken()? .buffered_count::(session) } - type Item = T; - fn pull(&self, session: Entity, world: &mut World) -> Result { - world - .get_entity_mut(self.source) - .or_broken()? - .pull_from_buffer::(session) - } - fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { - let mut targets = world - .get_mut::(self.source) - .or_broken()?; - if !targets.0.contains(&listener) { - targets.0.push(listener); - } - - if let Some(mut input_storage) = world.get_mut::(listener) { - input_storage.add(self.source); - } else { - world - .get_entity_mut(listener) - .or_broken()? - .insert(SingleInputStorage::new(self.source)); - } - - Ok(()) + add_listener_to_source(self.id(), listener, world) } fn gate_action( @@ -105,35 +226,46 @@ impl Buffered for Buffer { world: &mut World, roster: &mut OperationRoster, ) -> OperationResult { - GateState::apply(self.source, session, action, world, roster) + GateState::apply(self.id(), session, action, world, roster) } fn as_input(&self) -> SmallVec<[Entity; 8]> { - SmallVec::from_iter([self.source]) + SmallVec::from_iter([self.id()]) } - type Key = BufferKey; - fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { - let mut accessors = world.get_mut::(self.source).or_broken()?; - - accessors.0.push(accessor); - accessors.0.sort(); - accessors.0.dedup(); + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + world + .get_mut::>(self.id()) + .or_broken()? + .ensure_session(session); Ok(()) } +} - fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { - builder.build(self.source) +impl Joining for Buffer { + type Item = T; + fn pull(&self, session: Entity, world: &mut World) -> Result { + world + .get_entity_mut(self.id()) + .or_broken()? + .pull_from_buffer::(session) } +} - fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { +impl Accessing for Buffer { + type Key = BufferKey; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { world - .get_mut::>(self.source) + .get_mut::(self.id()) .or_broken()? - .ensure_session(session); + .add_accessor(accessor); Ok(()) } + fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { + Self::Key::create_key(&self, builder) + } + fn deep_clone_key(key: &Self::Key) -> Self::Key { key.deep_clone() } @@ -143,44 +275,20 @@ impl Buffered for Buffer { } } -impl Buffered for CloneFromBuffer { +impl Buffering for CloneFromBuffer { fn verify_scope(&self, scope: Entity) { - assert_eq!(scope, self.scope); + assert_eq!(scope, self.scope()); } fn buffered_count(&self, session: Entity, world: &World) -> Result { world - .get_entity(self.source) + .get_entity(self.id()) .or_broken()? .buffered_count::(session) } - type Item = T; - fn pull(&self, session: Entity, world: &mut World) -> Result { - world - .get_entity(self.source) - .or_broken()? - .try_clone_from_buffer(session) - .and_then(|r| r.or_broken()) - } - fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { - let mut targets = world - .get_mut::(self.source) - .or_broken()?; - if !targets.0.contains(&listener) { - targets.0.push(listener); - } - - if let Some(mut input_storage) = world.get_mut::(listener) { - input_storage.add(self.source); - } else { - world - .get_entity_mut(listener) - .or_broken()? - .insert(SingleInputStorage::new(self.source)); - } - Ok(()) + add_listener_to_source(self.id(), listener, world) } fn gate_action( @@ -190,35 +298,46 @@ impl Buffered for CloneFromBuffer { world: &mut World, roster: &mut OperationRoster, ) -> OperationResult { - GateState::apply(self.source, session, action, world, roster) + GateState::apply(self.id(), session, action, world, roster) } fn as_input(&self) -> SmallVec<[Entity; 8]> { - SmallVec::from_iter([self.source]) + SmallVec::from_iter([self.id()]) } - type Key = BufferKey; - fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { - let mut accessors = world.get_mut::(self.source).or_broken()?; - - accessors.0.push(accessor); - accessors.0.sort(); - accessors.0.dedup(); - Ok(()) + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + world + .get_entity_mut(self.id()) + .or_broken()? + .ensure_session::(session) } +} - fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { - builder.build(self.source) +impl Joining for CloneFromBuffer { + type Item = T; + fn pull(&self, session: Entity, world: &mut World) -> Result { + world + .get_entity(self.id()) + .or_broken()? + .try_clone_from_buffer(session) + .and_then(|r| r.or_broken()) } +} - fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { +impl Accessing for CloneFromBuffer { + type Key = BufferKey; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { world - .get_mut::>(self.source) + .get_mut::(self.id()) .or_broken()? - .ensure_session(session); + .add_accessor(accessor); Ok(()) } + fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { + Self::Key::create_key(&(*self).into(), builder) + } + fn deep_clone_key(key: &Self::Key) -> Self::Key { key.deep_clone() } @@ -231,7 +350,7 @@ impl Buffered for CloneFromBuffer { macro_rules! impl_buffered_for_tuple { ($(($T:ident, $K:ident)),*) => { #[allow(non_snake_case)] - impl<$($T: Buffered),*> Buffered for ($($T,)*) + impl<$($T: Buffering),*> Buffering for ($($T,)*) { fn verify_scope(&self, scope: Entity) { let ($($T,)*) = self; @@ -253,18 +372,6 @@ macro_rules! impl_buffered_for_tuple { ].iter().copied().min().unwrap_or(0)) } - type Item = ($($T::Item),*); - fn pull( - &self, - session: Entity, - world: &mut World, - ) -> Result { - let ($($T,)*) = self; - Ok(($( - $T.pull(session, world)?, - )*)) - } - fn add_listener( &self, listener: Entity, @@ -300,6 +407,38 @@ macro_rules! impl_buffered_for_tuple { inputs } + fn ensure_active_session( + &self, + session: Entity, + world: &mut World, + ) -> OperationResult { + let ($($T,)*) = self; + $( + $T.ensure_active_session(session, world)?; + )* + Ok(()) + } + } + + #[allow(non_snake_case)] + impl<$($T: Joining),*> Joining for ($($T,)*) + { + type Item = ($($T::Item),*); + fn pull( + &self, + session: Entity, + world: &mut World, + ) -> Result { + let ($($T,)*) = self; + Ok(($( + $T.pull(session, world)?, + )*)) + } + } + + #[allow(non_snake_case)] + impl<$($T: Accessing),*> Accessing for ($($T,)*) + { type Key = ($($T::Key), *); fn add_accessor( &self, @@ -323,18 +462,6 @@ macro_rules! impl_buffered_for_tuple { )*) } - fn ensure_active_session( - &self, - session: Entity, - world: &mut World, - ) -> OperationResult { - let ($($T,)*) = self; - $( - $T.ensure_active_session(session, world)?; - )* - Ok(()) - } - fn deep_clone_key(key: &Self::Key) -> Self::Key { let ($($K,)*) = key; ($( @@ -352,11 +479,11 @@ macro_rules! impl_buffered_for_tuple { } } -// Implements the `Buffered` trait for all tuples between size 2 and 12 -// (inclusive) made of types that implement `Buffered` +// Implements the `Buffering` trait for all tuples between size 2 and 12 +// (inclusive) made of types that implement `Buffering` all_tuples!(impl_buffered_for_tuple, 2, 12, T, K); -impl Buffered for [T; N] { +impl Buffering for [T; N] { fn verify_scope(&self, scope: Entity) { for buffer in self.iter() { buffer.verify_scope(scope); @@ -375,15 +502,6 @@ impl Buffered for [T; N] { Ok(min_count.unwrap_or(0)) } - // TODO(@mxgrey) We may be able to use [T::Item; N] here instead of SmallVec - // when try_map is stabilized: https://github.com/rust-lang/rust/issues/79711 - type Item = SmallVec<[T::Item; N]>; - fn pull(&self, session: Entity, world: &mut World) -> Result { - self.iter() - .map(|buffer| buffer.pull(session, world)) - .collect() - } - fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { for buffer in self { buffer.add_listener(listener, world)?; @@ -408,6 +526,27 @@ impl Buffered for [T; N] { self.iter().flat_map(|buffer| buffer.as_input()).collect() } + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + for buffer in self { + buffer.ensure_active_session(session, world)?; + } + + Ok(()) + } +} + +impl Joining for [T; N] { + // TODO(@mxgrey) We may be able to use [T::Item; N] here instead of SmallVec + // when try_map is stabilized: https://github.com/rust-lang/rust/issues/79711 + type Item = SmallVec<[T::Item; N]>; + fn pull(&self, session: Entity, world: &mut World) -> Result { + self.iter() + .map(|buffer| buffer.pull(session, world)) + .collect() + } +} + +impl Accessing for [T; N] { type Key = SmallVec<[T::Key; N]>; fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { for buffer in self { @@ -424,14 +563,6 @@ impl Buffered for [T; N] { keys } - fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { - for buffer in self { - buffer.ensure_active_session(session, world)?; - } - - Ok(()) - } - fn deep_clone_key(key: &Self::Key) -> Self::Key { let mut keys = SmallVec::new(); for k in key { @@ -451,7 +582,7 @@ impl Buffered for [T; N] { } } -impl Buffered for SmallVec<[T; N]> { +impl Buffering for SmallVec<[T; N]> { fn verify_scope(&self, scope: Entity) { for buffer in self.iter() { buffer.verify_scope(scope); @@ -470,13 +601,6 @@ impl Buffered for SmallVec<[T; N]> { Ok(min_count.unwrap_or(0)) } - type Item = SmallVec<[T::Item; N]>; - fn pull(&self, session: Entity, world: &mut World) -> Result { - self.iter() - .map(|buffer| buffer.pull(session, world)) - .collect() - } - fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { for buffer in self { buffer.add_listener(listener, world)?; @@ -501,6 +625,25 @@ impl Buffered for SmallVec<[T; N]> { self.iter().flat_map(|buffer| buffer.as_input()).collect() } + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + for buffer in self { + buffer.ensure_active_session(session, world)?; + } + + Ok(()) + } +} + +impl Joining for SmallVec<[T; N]> { + type Item = SmallVec<[T::Item; N]>; + fn pull(&self, session: Entity, world: &mut World) -> Result { + self.iter() + .map(|buffer| buffer.pull(session, world)) + .collect() + } +} + +impl Accessing for SmallVec<[T; N]> { type Key = SmallVec<[T::Key; N]>; fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { for buffer in self { @@ -517,6 +660,68 @@ impl Buffered for SmallVec<[T; N]> { keys } + fn deep_clone_key(key: &Self::Key) -> Self::Key { + let mut keys = SmallVec::new(); + for k in key { + keys.push(T::deep_clone_key(k)); + } + keys + } + + fn is_key_in_use(key: &Self::Key) -> bool { + for k in key { + if T::is_key_in_use(k) { + return true; + } + } + + false + } +} + +impl Buffering for Vec { + fn verify_scope(&self, scope: Entity) { + for buffer in self { + buffer.verify_scope(scope); + } + } + + fn buffered_count(&self, session: Entity, world: &World) -> Result { + let mut min_count = None; + for buffer in self { + let count = buffer.buffered_count(session, world)?; + if !min_count.is_some_and(|min| min < count) { + min_count = Some(count); + } + } + + Ok(min_count.unwrap_or(0)) + } + + fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { + for buffer in self { + buffer.add_listener(listener, world)?; + } + Ok(()) + } + + fn gate_action( + &self, + session: Entity, + action: Gate, + world: &mut World, + roster: &mut OperationRoster, + ) -> OperationResult { + for buffer in self { + buffer.gate_action(session, action, world, roster)?; + } + Ok(()) + } + + fn as_input(&self) -> SmallVec<[Entity; 8]> { + self.iter().flat_map(|buffer| buffer.as_input()).collect() + } + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { for buffer in self { buffer.ensure_active_session(session, world)?; @@ -524,18 +729,45 @@ impl Buffered for SmallVec<[T; N]> { Ok(()) } +} + +impl Joining for Vec { + type Item = Vec; + fn pull(&self, session: Entity, world: &mut World) -> Result { + self.iter() + .map(|buffer| buffer.pull(session, world)) + .collect() + } +} + +impl Accessing for Vec { + type Key = Vec; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { + for buffer in self { + buffer.add_accessor(accessor, world)?; + } + Ok(()) + } + + fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { + let mut keys = Vec::new(); + for buffer in self { + keys.push(buffer.create_key(builder)); + } + keys + } fn deep_clone_key(key: &Self::Key) -> Self::Key { - let mut keys = SmallVec::new(); + let mut keys = Vec::new(); for k in key { - keys.push(T::deep_clone_key(k)); + keys.push(B::deep_clone_key(k)); } keys } fn is_key_in_use(key: &Self::Key) -> bool { for k in key { - if T::is_key_in_use(k) { + if B::is_key_in_use(k) { return true; } } @@ -543,3 +775,25 @@ impl Buffered for SmallVec<[T; N]> { false } } + +pub(crate) fn add_listener_to_source( + source: Entity, + listener: Entity, + world: &mut World, +) -> OperationResult { + let mut targets = world.get_mut::(source).or_broken()?; + if !targets.0.contains(&listener) { + targets.0.push(listener); + } + + if let Some(mut input_storage) = world.get_mut::(listener) { + input_storage.add(source); + } else { + world + .get_entity_mut(listener) + .or_broken()? + .insert(SingleInputStorage::new(source)); + } + + Ok(()) +} diff --git a/src/buffer/json_buffer.rs b/src/buffer/json_buffer.rs new file mode 100644 index 00000000..a3eba1ff --- /dev/null +++ b/src/buffer/json_buffer.rs @@ -0,0 +1,1598 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +// TODO(@mxgrey): Add module-level documentation describing how to use JsonBuffer + +use std::{ + any::TypeId, + collections::HashMap, + ops::RangeBounds, + sync::{Mutex, OnceLock}, +}; + +use bevy_ecs::{ + prelude::{Commands, Entity, EntityRef, EntityWorldMut, Mut, World}, + system::SystemState, +}; + +use serde::{de::DeserializeOwned, Serialize}; + +pub use serde_json::Value as JsonMessage; + +use smallvec::SmallVec; + +use crate::{ + add_listener_to_source, Accessing, AnyBuffer, AnyBufferAccessInterface, AnyBufferKey, AnyRange, + AsAnyBuffer, Buffer, BufferAccessMut, BufferAccessors, BufferError, BufferIdentifier, + BufferKey, BufferKeyBuilder, BufferKeyLifecycle, BufferKeyTag, BufferLocation, BufferMap, + BufferMapLayout, BufferMapStruct, BufferStorage, Bufferable, Buffering, Builder, DrainBuffer, + Gate, GateState, IncompatibleLayout, InspectBuffer, Joined, Joining, ManageBuffer, + NotifyBufferUpdate, OperationError, OperationResult, OrBroken, +}; + +/// A [`Buffer`] whose message type has been anonymized, but which is known to +/// support serialization and deserialization. Joining this buffer type will +/// yield a [`JsonMessage`]. +#[derive(Clone, Copy, Debug)] +pub struct JsonBuffer { + location: BufferLocation, + interface: &'static (dyn JsonBufferAccessInterface + Send + Sync), +} + +impl JsonBuffer { + /// Downcast this into a concerete [`Buffer`] for the specific message type. + /// + /// To downcast this into a specialized kind of buffer, use [`Self::downcast_buffer`] instead. + pub fn downcast_for_message(&self) -> Option> { + if TypeId::of::() == self.interface.any_access_interface().message_type_id() { + Some(Buffer { + location: self.location, + _ignore: Default::default(), + }) + } else { + None + } + } + + /// Downcast this into a different specialized buffer representation. + pub fn downcast_buffer(&self) -> Option { + self.as_any_buffer().downcast_buffer::() + } + + /// Register the ability to cast into [`JsonBuffer`] and [`JsonBufferKey`] + /// for buffers containing messages of type `T`. This only needs to be done + /// once in the entire lifespan of a program. + /// + /// Note that this will take effect automatically any time you create an + /// instance of [`JsonBuffer`] or [`JsonBufferKey`] for a buffer with + /// messages of type `T`. + pub fn register_for() + where + T: 'static + Serialize + DeserializeOwned + Send + Sync, + { + // We just need to ensure that this function gets called so that the + // downcast callback gets registered. Nothing more needs to be done. + JsonBufferAccessImpl::::get_interface(); + } + + /// Get the entity ID of the buffer. + pub fn id(&self) -> Entity { + self.location.source + } + + /// Get the ID of the workflow that the buffer is associated with. + pub fn scope(&self) -> Entity { + self.location.scope + } + + /// Get general information about the buffer. + pub fn location(&self) -> BufferLocation { + self.location + } +} + +impl From> for JsonBuffer { + fn from(value: Buffer) -> Self { + Self { + location: value.location, + interface: JsonBufferAccessImpl::::get_interface(), + } + } +} + +impl From for AnyBuffer { + fn from(value: JsonBuffer) -> Self { + Self { + location: value.location, + interface: value.interface.any_access_interface(), + } + } +} + +impl AsAnyBuffer for JsonBuffer { + fn as_any_buffer(&self) -> AnyBuffer { + (*self).into() + } +} + +/// Similar to a [`BufferKey`] except it can be used for any buffer that supports +/// serialization and deserialization without knowing the buffer's specific +/// message type at compile time. +/// +/// This can key be used with a [`World`][1] to directly view or manipulate the +/// contents of a buffer through the [`JsonBufferWorldAccess`] interface. +/// +/// [1]: bevy_ecs::prelude::World +#[derive(Clone)] +pub struct JsonBufferKey { + tag: BufferKeyTag, + interface: &'static (dyn JsonBufferAccessInterface + Send + Sync), +} + +impl JsonBufferKey { + /// Downcast this into a concrete [`BufferKey`] for the specified message type. + /// + /// To downcast to a specialized kind of key, use [`Self::downcast_buffer_key`] instead. + pub fn downcast_for_message(self) -> Option> { + self.as_any_buffer_key().downcast_for_message() + } + + pub fn downcast_buffer_key(self) -> Option { + self.as_any_buffer_key().downcast_buffer_key() + } + + /// Cast this into an [`AnyBufferKey`] + pub fn as_any_buffer_key(self) -> AnyBufferKey { + self.into() + } +} + +impl BufferKeyLifecycle for JsonBufferKey { + type TargetBuffer = JsonBuffer; + + fn create_key(buffer: &Self::TargetBuffer, builder: &BufferKeyBuilder) -> Self { + Self { + tag: builder.make_tag(buffer.id()), + interface: buffer.interface, + } + } + + fn is_in_use(&self) -> bool { + self.tag.is_in_use() + } + + fn deep_clone(&self) -> Self { + Self { + tag: self.tag.deep_clone(), + interface: self.interface, + } + } +} + +impl std::fmt::Debug for JsonBufferKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("JsonBufferKey") + .field( + "message_type_name", + &self.interface.any_access_interface().message_type_name(), + ) + .field("tag", &self.tag) + .finish() + } +} + +impl From> for JsonBufferKey { + fn from(value: BufferKey) -> Self { + let interface = JsonBufferAccessImpl::::get_interface(); + JsonBufferKey { + tag: value.tag, + interface, + } + } +} + +impl From for AnyBufferKey { + fn from(value: JsonBufferKey) -> Self { + AnyBufferKey { + tag: value.tag, + interface: value.interface.any_access_interface(), + } + } +} + +/// Similar to [`BufferView`][crate::BufferView], but this can be unlocked with +/// a [`JsonBufferKey`], so it can work for any buffer whose message types +/// support serialization and deserialization. +pub struct JsonBufferView<'a> { + storage: Box, + gate: &'a GateState, + session: Entity, +} + +impl<'a> JsonBufferView<'a> { + /// Get a serialized copy of the oldest message in the buffer. + pub fn oldest(&self) -> JsonMessageViewResult { + self.storage.json_oldest(self.session) + } + + /// Get a serialized copy of the newest message in the buffer. + pub fn newest(&self) -> JsonMessageViewResult { + self.storage.json_newest(self.session) + } + + /// Get a serialized copy of a message in the buffer. + pub fn get(&self, index: usize) -> JsonMessageViewResult { + self.storage.json_get(self.session, index) + } + + /// Get how many messages are in this buffer. + pub fn len(&self) -> usize { + self.storage.json_count(self.session) + } + + /// Check if the buffer is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Check whether the gate of this buffer is open or closed. + pub fn gate(&self) -> Gate { + self.gate + .map + .get(&self.session) + .copied() + .unwrap_or(Gate::Open) + } +} + +/// Similar to [`BufferMut`][crate::BufferMut], but this can be unlocked with a +/// [`JsonBufferKey`], so it can work for any buffer whose message types support +/// serialization and deserialization. +pub struct JsonBufferMut<'w, 's, 'a> { + storage: Box, + gate: Mut<'a, GateState>, + buffer: Entity, + session: Entity, + accessor: Option, + commands: &'a mut Commands<'w, 's>, + modified: bool, +} + +impl<'w, 's, 'a> JsonBufferMut<'w, 's, 'a> { + /// Same as [BufferMut::allow_closed_loops][1]. + /// + /// [1]: crate::BufferMut::allow_closed_loops + pub fn allow_closed_loops(mut self) -> Self { + self.accessor = None; + self + } + + /// Get a serialized copy of the oldest message in the buffer. + pub fn oldest(&self) -> JsonMessageViewResult { + self.storage.json_oldest(self.session) + } + + /// Get a serialized copy of the newest message in the buffer. + pub fn newest(&self) -> JsonMessageViewResult { + self.storage.json_newest(self.session) + } + + /// Get a serialized copy of a message in the buffer. + pub fn get(&self, index: usize) -> JsonMessageViewResult { + self.storage.json_get(self.session, index) + } + + /// Get how many messages are in this buffer. + pub fn len(&self) -> usize { + self.storage.json_count(self.session) + } + + /// Check if the buffer is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Check whether the gate of this buffer is open or closed. + pub fn gate(&self) -> Gate { + self.gate + .map + .get(&self.session) + .copied() + .unwrap_or(Gate::Open) + } + + /// Modify the oldest message in the buffer. + pub fn oldest_mut(&mut self) -> Option> { + self.storage + .json_oldest_mut(self.session, &mut self.modified) + } + + /// Modify the newest message in the buffer. + pub fn newest_mut(&mut self) -> Option> { + self.storage + .json_newest_mut(self.session, &mut self.modified) + } + + /// Modify a message in the buffer. + pub fn get_mut(&mut self, index: usize) -> Option> { + self.storage + .json_get_mut(self.session, index, &mut self.modified) + } + + /// Drain a range of messages out of the buffer. + pub fn drain>(&mut self, range: R) -> DrainJsonBuffer<'_> { + self.modified = true; + DrainJsonBuffer { + interface: self.storage.json_drain(self.session, AnyRange::new(range)), + } + } + + /// Pull the oldest message from the buffer as a JSON value. Unlike + /// [`Self::oldest`] this will remove the message from the buffer. + pub fn pull(&mut self) -> JsonMessageViewResult { + self.modified = true; + self.storage.json_pull(self.session) + } + + /// Pull the oldest message from the buffer and attempt to deserialize it + /// into the target type. + pub fn pull_as(&mut self) -> Result, serde_json::Error> { + self.pull()?.map(|m| serde_json::from_value(m)).transpose() + } + + /// Pull the newest message from the buffer as a JSON value. Unlike + /// [`Self::newest`] this will remove the message from the buffer. + pub fn pull_newest(&mut self) -> JsonMessageViewResult { + self.modified = true; + self.storage.json_pull_newest(self.session) + } + + /// Pull the newest message from the buffer and attempt to deserialize it + /// into the target type. + pub fn pull_newest_as(&mut self) -> Result, serde_json::Error> { + self.pull_newest()? + .map(|m| serde_json::from_value(m)) + .transpose() + } + + /// Attempt to push a new value into the buffer. + /// + /// If the input value is compatible with the message type of the buffer, + /// this will return [`Ok`]. If the buffer is at its limit before a successful + /// push, this will return the value that needed to be removed. + /// + /// If the input value does not match the message type of the buffer, this + /// will return [`Err`]. This may also return [`Err`] if the message coming + /// out of the buffer failed to serialize. + // TODO(@mxgrey): Consider having an error type that differentiates the + // various possible error modes. + pub fn push( + &mut self, + value: T, + ) -> Result, serde_json::Error> { + let message = serde_json::to_value(&value)?; + self.modified = true; + self.storage.json_push(self.session, message) + } + + /// Same as [`Self::push`] but no serialization step is needed for the incoming + /// message. + pub fn push_json( + &mut self, + message: JsonMessage, + ) -> Result, serde_json::Error> { + self.modified = true; + self.storage.json_push(self.session, message) + } + + /// Same as [`Self::push`] but the message will be interpreted as the oldest + /// message in the buffer. + pub fn push_as_oldest( + &mut self, + value: T, + ) -> Result, serde_json::Error> { + let message = serde_json::to_value(&value)?; + self.modified = true; + self.storage.json_push_as_oldest(self.session, message) + } + + /// Same as [`Self::push_as_oldest`] but no serialization step is needed for + /// the incoming message. + pub fn push_json_as_oldest( + &mut self, + message: JsonMessage, + ) -> Result, serde_json::Error> { + self.modified = true; + self.storage.json_push_as_oldest(self.session, message) + } + + /// Tell the buffer [`Gate`] to open. + pub fn open_gate(&mut self) { + if let Some(gate) = self.gate.map.get_mut(&self.session) { + if *gate != Gate::Open { + *gate = Gate::Open; + self.modified = true; + } + } + } + + /// Tell the buffer [`Gate`] to close. + pub fn close_gate(&mut self) { + if let Some(gate) = self.gate.map.get_mut(&self.session) { + *gate = Gate::Closed; + // There is no need to to indicate that a modification happened + // because listeners do not get notified about gates closing. + } + } + + /// Perform an action on the gate of the buffer. + pub fn gate_action(&mut self, action: Gate) { + match action { + Gate::Open => self.open_gate(), + Gate::Closed => self.close_gate(), + } + } + + /// Trigger the listeners for this buffer to wake up even if nothing in the + /// buffer has changed. This could be used for timers or timeout elements + /// in a workflow. + pub fn pulse(&mut self) { + self.modified = true; + } +} + +impl<'w, 's, 'a> Drop for JsonBufferMut<'w, 's, 'a> { + fn drop(&mut self) { + if self.modified { + self.commands.add(NotifyBufferUpdate::new( + self.buffer, + self.session, + self.accessor, + )); + } + } +} + +pub trait JsonBufferWorldAccess { + /// Call this to get read-only access to any buffer whose message type is + /// serializable and deserializable. + /// + /// For technical reasons this requires direct [`World`] access, but you can + /// do other read-only queries on the world while holding onto the + /// [`JsonBufferView`]. + fn json_buffer_view(&self, key: &JsonBufferKey) -> Result, BufferError>; + + /// Call this to get mutable access to any buffer whose message type is + /// serializable and deserializable. + /// + /// Pass in a callback that will receive a [`JsonBufferMut`], allowing it to + /// view and modify the contents of the buffer. + fn json_buffer_mut( + &mut self, + key: &JsonBufferKey, + f: impl FnOnce(JsonBufferMut) -> U, + ) -> Result; +} + +impl JsonBufferWorldAccess for World { + fn json_buffer_view(&self, key: &JsonBufferKey) -> Result, BufferError> { + key.interface.create_json_buffer_view(key, self) + } + + fn json_buffer_mut( + &mut self, + key: &JsonBufferKey, + f: impl FnOnce(JsonBufferMut) -> U, + ) -> Result { + let interface = key.interface; + let mut state = interface.create_json_buffer_access_mut_state(self); + let mut access = state.get_json_buffer_access_mut(self); + let buffer_mut = access.as_json_buffer_mut(key)?; + Ok(f(buffer_mut)) + } +} + +/// View or modify a buffer message in terms of JSON values. +pub struct JsonMut<'a> { + interface: &'a mut dyn JsonMutInterface, + modified: &'a mut bool, +} + +impl<'a> JsonMut<'a> { + /// Serialize the message within the buffer into JSON. + /// + /// This new [`JsonMessage`] will be a duplicate of the data of the message + /// inside the buffer, effectively meaning this function clones the data. + pub fn serialize(&self) -> Result { + self.interface.serialize() + } + + /// This will first serialize the message within the buffer into JSON and + /// then attempt to deserialize it into the target type. + /// + /// The target type does not need to match the message type inside the buffer, + /// as long as the target type can be deserialized from a serialized value + /// of the buffer's message type. + /// + /// The returned value will duplicate the data of the message inside the + /// buffer, effectively meaning this function clones the data. + pub fn deserialize_into(&self) -> Result { + serde_json::from_value::(self.serialize()?) + } + + /// Replace the underlying message with new data, and receive its original + /// data as JSON. + #[must_use = "if you are going to discard the returned message, use insert instead"] + pub fn replace(&mut self, message: JsonMessage) -> JsonMessageReplaceResult { + *self.modified = true; + self.interface.replace(message) + } + + /// Insert new data into the underyling message. This is the same as replace + /// except it is more efficient if you don't care about the original data, + /// because it will discard the original data instead of serializing it. + pub fn insert(&mut self, message: JsonMessage) -> Result<(), serde_json::Error> { + *self.modified = true; + self.interface.insert(message) + } + + /// Modify the data of the underlying message. This is equivalent to calling + /// [`Self::serialize`], modifying the value, and then calling [`Self::insert`]. + /// The benefit of this function is that you do not need to remember to + /// insert after you have finished your modifications. + pub fn modify(&mut self, f: impl FnOnce(&mut JsonMessage)) -> Result<(), serde_json::Error> { + let mut message = self.serialize()?; + f(&mut message); + self.insert(message) + } +} + +/// The return type for functions that give a JSON view of a message in a buffer. +/// If an error occurs while attempting to serialize the message, this will return +/// [`Err`]. +/// +/// If this returns [`Ok`] then [`None`] means there was no message available at +/// the requested location while [`Some`] will contain a serialized copy of the +/// message. +pub type JsonMessageViewResult = Result, serde_json::Error>; + +/// The return type for functions that push a new message into a buffer. If an +/// error occurs while deserializing the message into the buffer's message type +/// then this will return [`Err`]. +/// +/// If this returns [`Ok`] then [`None`] means the new message was added and all +/// prior messages have been retained in the buffer. [`Some`] will contain an +/// old message which has now been removed from the buffer. +pub type JsonMessagePushResult = Result, serde_json::Error>; + +/// The return type for functions that replace (swap out) one message with +/// another. If an error occurs while serializing or deserializing either +/// message to/from the buffer's message type then this will return [`Err`]. +/// +/// If this returns [`Ok`] then the message was successfully replaced, and the +/// value inside [`Ok`] is the message that was previously in the buffer. +pub type JsonMessageReplaceResult = Result; + +trait JsonBufferViewing { + fn json_count(&self, session: Entity) -> usize; + fn json_oldest<'a>(&'a self, session: Entity) -> JsonMessageViewResult; + fn json_newest<'a>(&'a self, session: Entity) -> JsonMessageViewResult; + fn json_get<'a>(&'a self, session: Entity, index: usize) -> JsonMessageViewResult; +} + +trait JsonBufferManagement: JsonBufferViewing { + fn json_push(&mut self, session: Entity, value: JsonMessage) -> JsonMessagePushResult; + fn json_push_as_oldest(&mut self, session: Entity, value: JsonMessage) + -> JsonMessagePushResult; + fn json_pull(&mut self, session: Entity) -> JsonMessageViewResult; + fn json_pull_newest(&mut self, session: Entity) -> JsonMessageViewResult; + fn json_oldest_mut<'a>( + &'a mut self, + session: Entity, + modified: &'a mut bool, + ) -> Option>; + fn json_newest_mut<'a>( + &'a mut self, + session: Entity, + modified: &'a mut bool, + ) -> Option>; + fn json_get_mut<'a>( + &'a mut self, + session: Entity, + index: usize, + modified: &'a mut bool, + ) -> Option>; + fn json_drain<'a>( + &'a mut self, + session: Entity, + range: AnyRange, + ) -> Box; +} + +impl JsonBufferViewing for &'_ BufferStorage +where + T: 'static + Send + Sync + Serialize + DeserializeOwned, +{ + fn json_count(&self, session: Entity) -> usize { + self.count(session) + } + + fn json_oldest<'a>(&'a self, session: Entity) -> JsonMessageViewResult { + self.oldest(session).map(serde_json::to_value).transpose() + } + + fn json_newest<'a>(&'a self, session: Entity) -> JsonMessageViewResult { + self.newest(session).map(serde_json::to_value).transpose() + } + + fn json_get<'a>(&'a self, session: Entity, index: usize) -> JsonMessageViewResult { + self.get(session, index) + .map(serde_json::to_value) + .transpose() + } +} + +impl JsonBufferViewing for Mut<'_, BufferStorage> +where + T: 'static + Send + Sync + Serialize + DeserializeOwned, +{ + fn json_count(&self, session: Entity) -> usize { + self.count(session) + } + + fn json_oldest<'a>(&'a self, session: Entity) -> JsonMessageViewResult { + self.oldest(session).map(serde_json::to_value).transpose() + } + + fn json_newest<'a>(&'a self, session: Entity) -> JsonMessageViewResult { + self.newest(session).map(serde_json::to_value).transpose() + } + + fn json_get<'a>(&'a self, session: Entity, index: usize) -> JsonMessageViewResult { + self.get(session, index) + .map(serde_json::to_value) + .transpose() + } +} + +impl JsonBufferManagement for Mut<'_, BufferStorage> +where + T: 'static + Send + Sync + Serialize + DeserializeOwned, +{ + fn json_push(&mut self, session: Entity, value: JsonMessage) -> JsonMessagePushResult { + let value: T = serde_json::from_value(value)?; + self.push(session, value) + .map(serde_json::to_value) + .transpose() + } + + fn json_push_as_oldest( + &mut self, + session: Entity, + value: JsonMessage, + ) -> JsonMessagePushResult { + let value: T = serde_json::from_value(value)?; + self.push(session, value) + .map(serde_json::to_value) + .transpose() + } + + fn json_pull(&mut self, session: Entity) -> JsonMessageViewResult { + self.pull(session).map(serde_json::to_value).transpose() + } + + fn json_pull_newest(&mut self, session: Entity) -> JsonMessageViewResult { + self.pull_newest(session) + .map(serde_json::to_value) + .transpose() + } + + fn json_oldest_mut<'a>( + &'a mut self, + session: Entity, + modified: &'a mut bool, + ) -> Option> { + self.oldest_mut(session).map(|interface| JsonMut { + interface, + modified, + }) + } + + fn json_newest_mut<'a>( + &'a mut self, + session: Entity, + modified: &'a mut bool, + ) -> Option> { + self.newest_mut(session).map(|interface| JsonMut { + interface, + modified, + }) + } + + fn json_get_mut<'a>( + &'a mut self, + session: Entity, + index: usize, + modified: &'a mut bool, + ) -> Option> { + self.get_mut(session, index).map(|interface| JsonMut { + interface, + modified, + }) + } + + fn json_drain<'a>( + &'a mut self, + session: Entity, + range: AnyRange, + ) -> Box { + Box::new(self.drain(session, range)) + } +} + +trait JsonMutInterface { + /// Serialize the underlying message into JSON + fn serialize(&self) -> Result; + /// Replace the underlying message with new data, and receive its original + /// data as JSON + fn replace(&mut self, message: JsonMessage) -> JsonMessageReplaceResult; + /// Insert new data into the underyling message. This is the same as replace + /// except it is more efficient if you don't care about the original data, + /// because it will discard the original data instead of serializing it. + fn insert(&mut self, message: JsonMessage) -> Result<(), serde_json::Error>; +} + +impl JsonMutInterface for T { + fn serialize(&self) -> Result { + serde_json::to_value(self) + } + + fn replace(&mut self, message: JsonMessage) -> JsonMessageReplaceResult { + let new_message: T = serde_json::from_value(message)?; + let old_message = serde_json::to_value(&self)?; + *self = new_message; + Ok(old_message) + } + + fn insert(&mut self, message: JsonMessage) -> Result<(), serde_json::Error> { + let new_message: T = serde_json::from_value(message)?; + *self = new_message; + Ok(()) + } +} + +trait JsonBufferAccessInterface { + fn any_access_interface(&self) -> &'static (dyn AnyBufferAccessInterface + Send + Sync); + + fn buffered_count( + &self, + buffer_ref: &EntityRef, + session: Entity, + ) -> Result; + + fn ensure_session(&self, buffer_mut: &mut EntityWorldMut, session: Entity) -> OperationResult; + + fn pull( + &self, + buffer_mut: &mut EntityWorldMut, + session: Entity, + ) -> Result; + + fn create_json_buffer_view<'a>( + &self, + key: &JsonBufferKey, + world: &'a World, + ) -> Result, BufferError>; + + fn create_json_buffer_access_mut_state( + &self, + world: &mut World, + ) -> Box; +} + +impl<'a> std::fmt::Debug for &'a (dyn JsonBufferAccessInterface + Send + Sync) { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Message Properties") + .field("type", &self.any_access_interface().message_type_name()) + .finish() + } +} + +struct JsonBufferAccessImpl(std::marker::PhantomData); + +impl JsonBufferAccessImpl { + pub(crate) fn get_interface() -> &'static (dyn JsonBufferAccessInterface + Send + Sync) { + // Create and cache the json buffer access interface + static INTERFACE_MAP: OnceLock< + Mutex>, + > = OnceLock::new(); + let interfaces = INTERFACE_MAP.get_or_init(|| Mutex::default()); + + let mut interfaces_mut = interfaces.lock().unwrap(); + *interfaces_mut.entry(TypeId::of::()).or_insert_with(|| { + // Register downcasting for JsonBuffer and JsonBufferKey the + // first time that we retrieve an interface for this type. + let any_interface = AnyBuffer::interface_for::(); + any_interface.register_buffer_downcast( + TypeId::of::(), + Box::new(|location| { + Box::new(JsonBuffer { + location, + interface: Self::get_interface(), + }) + }), + ); + + any_interface.register_key_downcast( + TypeId::of::(), + Box::new(|tag| { + Box::new(JsonBufferKey { + tag, + interface: Self::get_interface(), + }) + }), + ); + + // SAFETY: This will leak memory exactly once per type, so the leakage is bounded. + // Leaking this allows the interface to be shared freely across all instances. + Box::leak(Box::new(JsonBufferAccessImpl::(Default::default()))) + }) + } +} + +impl JsonBufferAccessInterface + for JsonBufferAccessImpl +{ + fn any_access_interface(&self) -> &'static (dyn AnyBufferAccessInterface + Send + Sync) { + AnyBuffer::interface_for::() + } + + fn buffered_count( + &self, + buffer_ref: &EntityRef, + session: Entity, + ) -> Result { + buffer_ref.buffered_count::(session) + } + + fn ensure_session(&self, buffer_mut: &mut EntityWorldMut, session: Entity) -> OperationResult { + buffer_mut.ensure_session::(session) + } + + fn pull( + &self, + buffer_mut: &mut EntityWorldMut, + session: Entity, + ) -> Result { + let value = buffer_mut.pull_from_buffer::(session)?; + serde_json::to_value(value).or_broken() + } + + fn create_json_buffer_view<'a>( + &self, + key: &JsonBufferKey, + world: &'a World, + ) -> Result, BufferError> { + let buffer_ref = world + .get_entity(key.tag.buffer) + .ok_or(BufferError::BufferMissing)?; + let storage = buffer_ref + .get::>() + .ok_or(BufferError::BufferMissing)?; + let gate = buffer_ref + .get::() + .ok_or(BufferError::BufferMissing)?; + Ok(JsonBufferView { + storage: Box::new(storage), + gate, + session: key.tag.session, + }) + } + + fn create_json_buffer_access_mut_state( + &self, + world: &mut World, + ) -> Box { + Box::new(SystemState::>::new(world)) + } +} + +trait JsonBufferAccessMutState { + fn get_json_buffer_access_mut<'s, 'w: 's>( + &'s mut self, + world: &'w mut World, + ) -> Box + 's>; +} + +impl JsonBufferAccessMutState for SystemState> +where + T: 'static + Send + Sync + Serialize + DeserializeOwned, +{ + fn get_json_buffer_access_mut<'s, 'w: 's>( + &'s mut self, + world: &'w mut World, + ) -> Box + 's> { + Box::new(self.get_mut(world)) + } +} + +trait JsonBufferAccessMut<'w, 's> { + fn as_json_buffer_mut<'a>( + &'a mut self, + key: &JsonBufferKey, + ) -> Result, BufferError>; +} + +impl<'w, 's, T> JsonBufferAccessMut<'w, 's> for BufferAccessMut<'w, 's, T> +where + T: 'static + Send + Sync + Serialize + DeserializeOwned, +{ + fn as_json_buffer_mut<'a>( + &'a mut self, + key: &JsonBufferKey, + ) -> Result, BufferError> { + let BufferAccessMut { query, commands } = self; + let (storage, gate) = query + .get_mut(key.tag.buffer) + .map_err(|_| BufferError::BufferMissing)?; + Ok(JsonBufferMut { + storage: Box::new(storage), + gate, + buffer: key.tag.buffer, + session: key.tag.session, + accessor: Some(key.tag.accessor), + commands, + modified: false, + }) + } +} + +pub struct DrainJsonBuffer<'a> { + interface: Box, +} + +impl<'a> Iterator for DrainJsonBuffer<'a> { + type Item = Result; + + fn next(&mut self) -> Option { + self.interface.json_next() + } +} + +trait DrainJsonBufferInterface { + fn json_next(&mut self) -> Option>; +} + +impl DrainJsonBufferInterface for DrainBuffer<'_, T> { + fn json_next(&mut self) -> Option> { + self.next().map(serde_json::to_value) + } +} + +impl Bufferable for JsonBuffer { + type BufferType = Self; + fn into_buffer(self, builder: &mut Builder) -> Self::BufferType { + assert_eq!(self.scope(), builder.scope()); + self + } +} + +impl Buffering for JsonBuffer { + fn verify_scope(&self, scope: Entity) { + assert_eq!(scope, self.scope()); + } + + fn buffered_count(&self, session: Entity, world: &World) -> Result { + let buffer_ref = world.get_entity(self.id()).or_broken()?; + self.interface.buffered_count(&buffer_ref, session) + } + + fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { + add_listener_to_source(self.id(), listener, world) + } + + fn gate_action( + &self, + session: Entity, + action: Gate, + world: &mut World, + roster: &mut crate::OperationRoster, + ) -> OperationResult { + GateState::apply(self.id(), session, action, world, roster) + } + + fn as_input(&self) -> smallvec::SmallVec<[Entity; 8]> { + SmallVec::from_iter([self.id()]) + } + + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + let mut buffer_mut = world.get_entity_mut(self.id()).or_broken()?; + self.interface.ensure_session(&mut buffer_mut, session) + } +} + +impl Joining for JsonBuffer { + type Item = JsonMessage; + fn pull(&self, session: Entity, world: &mut World) -> Result { + let mut buffer_mut = world.get_entity_mut(self.id()).or_broken()?; + self.interface.pull(&mut buffer_mut, session) + } +} + +impl Accessing for JsonBuffer { + type Key = JsonBufferKey; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { + world + .get_mut::(self.id()) + .or_broken()? + .add_accessor(accessor); + Ok(()) + } + + fn create_key(&self, builder: &BufferKeyBuilder) -> Self::Key { + JsonBufferKey { + tag: builder.make_tag(self.id()), + interface: self.interface, + } + } + + fn deep_clone_key(key: &Self::Key) -> Self::Key { + key.deep_clone() + } + + fn is_key_in_use(key: &Self::Key) -> bool { + key.is_in_use() + } +} + +impl Joined for serde_json::Map { + type Buffers = HashMap; +} + +impl BufferMapLayout for HashMap { + fn try_from_buffer_map(buffers: &BufferMap) -> Result { + let mut downcast_buffers = HashMap::new(); + let mut compatibility = IncompatibleLayout::default(); + for name in buffers.keys() { + match name { + BufferIdentifier::Name(name) => { + if let Ok(downcast) = + compatibility.require_buffer_for_borrowed_name::(&name, buffers) + { + downcast_buffers.insert(name.clone().into_owned(), downcast); + } + } + BufferIdentifier::Index(index) => { + compatibility + .forbidden_buffers + .push(BufferIdentifier::Index(*index)); + } + } + } + + compatibility.as_result()?; + Ok(downcast_buffers) + } +} + +impl BufferMapStruct for HashMap { + fn buffer_list(&self) -> SmallVec<[AnyBuffer; 8]> { + self.values().map(|b| b.as_any_buffer()).collect() + } +} + +impl Joining for HashMap { + type Item = serde_json::Map; + fn pull(&self, session: Entity, world: &mut World) -> Result { + self.iter() + .map(|(key, value)| value.pull(session, world).map(|v| (key.clone(), v))) + .collect() + } +} + +#[cfg(test)] +mod tests { + use crate::{prelude::*, testing::*, AddBufferToMap}; + use bevy_ecs::prelude::World; + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] + struct TestMessage { + v_i32: i32, + v_u32: u32, + v_string: String, + } + + impl TestMessage { + fn new() -> Self { + Self { + v_i32: 1, + v_u32: 2, + v_string: "hello".to_string(), + } + } + } + + #[test] + fn test_json_count() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer = builder.create_buffer(BufferSettings::keep_all()); + let push_multiple_times = builder + .commands() + .spawn_service(push_multiple_times_into_buffer.into_blocking_service()); + let count = builder + .commands() + .spawn_service(get_buffer_count.into_blocking_service()); + + scope + .input + .chain(builder) + .with_access(buffer) + .then(push_multiple_times) + .then(count) + .connect(scope.terminate); + }); + + let msg = TestMessage::new(); + let mut promise = + context.command(|commands| commands.request(msg, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let count = promise.take().available().unwrap(); + assert_eq!(count, 5); + assert!(context.no_unhandled_errors()); + } + + fn push_multiple_times_into_buffer( + In((value, key)): In<(TestMessage, BufferKey)>, + mut access: BufferAccessMut, + ) -> JsonBufferKey { + let mut buffer = access.get_mut(&key).unwrap(); + for _ in 0..5 { + buffer.push(value.clone()); + } + + key.into() + } + + fn get_buffer_count(In(key): In, world: &mut World) -> usize { + world.json_buffer_view(&key).unwrap().len() + } + + #[test] + fn test_modify_json_message() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer = builder.create_buffer(BufferSettings::keep_all()); + let push_multiple_times = builder + .commands() + .spawn_service(push_multiple_times_into_buffer.into_blocking_service()); + let modify_content = builder + .commands() + .spawn_service(modify_buffer_content.into_blocking_service()); + let drain_content = builder + .commands() + .spawn_service(pull_each_buffer_item.into_blocking_service()); + + scope + .input + .chain(builder) + .with_access(buffer) + .then(push_multiple_times) + .then(modify_content) + .then(drain_content) + .connect(scope.terminate); + }); + + let msg = TestMessage::new(); + let mut promise = + context.command(|commands| commands.request(msg, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let values = promise.take().available().unwrap(); + assert_eq!(values.len(), 5); + for i in 0..values.len() { + let v_i32 = values[i].get("v_i32").unwrap().as_i64().unwrap(); + assert_eq!(v_i32, i as i64); + } + assert!(context.no_unhandled_errors()); + } + + fn modify_buffer_content(In(key): In, world: &mut World) -> JsonBufferKey { + world + .json_buffer_mut(&key, |mut access| { + for i in 0..access.len() { + access + .get_mut(i) + .unwrap() + .modify(|value| { + let v_i32 = value.get_mut("v_i32").unwrap(); + let modified_v_i32 = i as i64 * v_i32.as_i64().unwrap(); + *v_i32 = modified_v_i32.into(); + }) + .unwrap(); + } + }) + .unwrap(); + + key + } + + fn pull_each_buffer_item(In(key): In, world: &mut World) -> Vec { + world + .json_buffer_mut(&key, |mut access| { + let mut values = Vec::new(); + while let Ok(Some(value)) = access.pull() { + values.push(value); + } + values + }) + .unwrap() + } + + #[test] + fn test_drain_json_message() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer = builder.create_buffer(BufferSettings::keep_all()); + let push_multiple_times = builder + .commands() + .spawn_service(push_multiple_times_into_buffer.into_blocking_service()); + let modify_content = builder + .commands() + .spawn_service(modify_buffer_content.into_blocking_service()); + let drain_content = builder + .commands() + .spawn_service(drain_buffer_contents.into_blocking_service()); + + scope + .input + .chain(builder) + .with_access(buffer) + .then(push_multiple_times) + .then(modify_content) + .then(drain_content) + .connect(scope.terminate); + }); + + let msg = TestMessage::new(); + let mut promise = + context.command(|commands| commands.request(msg, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let values = promise.take().available().unwrap(); + assert_eq!(values.len(), 5); + for i in 0..values.len() { + let v_i32 = values[i].get("v_i32").unwrap().as_i64().unwrap(); + assert_eq!(v_i32, i as i64); + } + assert!(context.no_unhandled_errors()); + } + + fn drain_buffer_contents(In(key): In, world: &mut World) -> Vec { + world + .json_buffer_mut(&key, |mut access| { + access.drain(..).collect::, _>>() + }) + .unwrap() + .unwrap() + } + + #[test] + fn double_json_messages() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_double_u32: JsonBuffer = builder + .create_buffer::(BufferSettings::default()) + .into(); + let buffer_double_i32: JsonBuffer = builder + .create_buffer::(BufferSettings::default()) + .into(); + let buffer_double_string: JsonBuffer = builder + .create_buffer::(BufferSettings::default()) + .into(); + + scope.input.chain(builder).fork_clone(( + |chain: Chain<_>| { + chain + .map_block(|mut msg: TestMessage| { + msg.v_u32 *= 2; + msg + }) + .connect( + buffer_double_u32 + .downcast_for_message::() + .unwrap() + .input_slot(), + ) + }, + |chain: Chain<_>| { + chain + .map_block(|mut msg: TestMessage| { + msg.v_i32 *= 2; + msg + }) + .connect( + buffer_double_i32 + .downcast_for_message::() + .unwrap() + .input_slot(), + ) + }, + |chain: Chain<_>| { + chain + .map_block(|mut msg: TestMessage| { + msg.v_string = msg.v_string.clone() + &msg.v_string; + msg + }) + .connect( + buffer_double_string + .downcast_for_message::() + .unwrap() + .input_slot(), + ) + }, + )); + + (buffer_double_u32, buffer_double_i32, buffer_double_string) + .join(builder) + .connect(scope.terminate); + }); + + let msg = TestMessage::new(); + let mut promise = + context.command(|commands| commands.request(msg, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let (double_u32, double_i32, double_string) = promise.take().available().unwrap(); + assert_eq!(4, double_u32.get("v_u32").unwrap().as_i64().unwrap()); + assert_eq!(2, double_i32.get("v_i32").unwrap().as_i64().unwrap()); + assert_eq!( + "hellohello", + double_string.get("v_string").unwrap().as_str().unwrap() + ); + assert!(context.no_unhandled_errors()); + } + + #[test] + fn test_buffer_downcast() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + // We just need to test that these buffers can be downcast without + // a panic occurring. + JsonBuffer::register_for::(); + let buffer = builder.create_buffer::(BufferSettings::keep_all()); + let any_buffer: AnyBuffer = buffer.into(); + let json_buffer: JsonBuffer = any_buffer.downcast_buffer().unwrap(); + let _original_from_any: Buffer = + any_buffer.downcast_for_message().unwrap(); + let _original_from_json: Buffer = + json_buffer.downcast_for_message().unwrap(); + + scope + .input + .chain(builder) + .with_access(buffer) + .map_block(|(data, key)| { + let any_key: AnyBufferKey = key.clone().into(); + let json_key: JsonBufferKey = any_key.clone().downcast_buffer_key().unwrap(); + let _original_from_any: BufferKey = + any_key.downcast_for_message().unwrap(); + let _original_from_json: BufferKey = + json_key.downcast_for_message().unwrap(); + + data + }) + .connect(scope.terminate); + }); + + let mut promise = context.command(|commands| commands.request(1, workflow).take_response()); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let response = promise.take().available().unwrap(); + assert_eq!(1, response); + assert!(context.no_unhandled_errors()); + } + + #[derive(Clone, Joined)] + #[joined(buffers_struct_name = TestJoinedValueJsonBuffers)] + struct TestJoinedValueJson { + integer: i64, + float: f64, + #[joined(buffer = JsonBuffer)] + json: JsonMessage, + } + + #[test] + fn test_try_join_json() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + JsonBuffer::register_for::(); + + let buffer_i64 = builder.create_buffer(BufferSettings::default()); + let buffer_f64 = builder.create_buffer(BufferSettings::default()); + let buffer_json = builder.create_buffer(BufferSettings::default()); + + let mut buffers = BufferMap::default(); + buffers.insert_buffer("integer", buffer_i64); + buffers.insert_buffer("float", buffer_f64); + buffers.insert_buffer("json", buffer_json); + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffer_i64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_f64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_json.input_slot()), + )); + + builder.try_join(&buffers).unwrap().connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request((5_i64, 3.14_f64, TestMessage::new()), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValueJson = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + let deserialized_json: TestMessage = serde_json::from_value(value.json).unwrap(); + let expected_json = TestMessage::new(); + assert_eq!(deserialized_json, expected_json); + } + + #[test] + fn test_joined_value_json() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + JsonBuffer::register_for::(); + + let json_buffer = builder.create_buffer::(BufferSettings::default()); + let buffers = TestJoinedValueJsonBuffers { + integer: builder.create_buffer(BufferSettings::default()), + float: builder.create_buffer(BufferSettings::default()), + json: json_buffer.into(), + }; + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffers.integer.input_slot()), + |chain: Chain<_>| chain.connect(buffers.float.input_slot()), + |chain: Chain<_>| chain.connect(json_buffer.input_slot()), + )); + + builder.join(buffers).connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request((5_i64, 3.14_f64, TestMessage::new()), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValueJson = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + let deserialized_json: TestMessage = serde_json::from_value(value.json).unwrap(); + let expected_json = TestMessage::new(); + assert_eq!(deserialized_json, expected_json); + } + + #[test] + fn test_select_buffers_json() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_integer = builder.create_buffer::(BufferSettings::default()); + let buffer_float = builder.create_buffer::(BufferSettings::default()); + let buffer_json = + JsonBuffer::from(builder.create_buffer::(BufferSettings::default())); + + let buffers = + TestJoinedValueJson::select_buffers(buffer_integer, buffer_float, buffer_json); + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffers.integer.input_slot()), + |chain: Chain<_>| chain.connect(buffers.float.input_slot()), + |chain: Chain<_>| { + chain.connect(buffers.json.downcast_for_message().unwrap().input_slot()) + }, + )); + + builder.join(buffers).connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request((5_i64, 3.14_f64, TestMessage::new()), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValueJson = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + let deserialized_json: TestMessage = serde_json::from_value(value.json).unwrap(); + let expected_json = TestMessage::new(); + assert_eq!(deserialized_json, expected_json); + } + + #[test] + fn test_join_json_buffer_vec() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_u32 = builder.create_buffer::(BufferSettings::default()); + let buffer_i32 = builder.create_buffer::(BufferSettings::default()); + let buffer_string = builder.create_buffer::(BufferSettings::default()); + let buffer_msg = builder.create_buffer::(BufferSettings::default()); + let buffers: Vec = vec![ + buffer_i32.into(), + buffer_u32.into(), + buffer_string.into(), + buffer_msg.into(), + ]; + + scope + .input + .chain(builder) + .map_block(|msg: TestMessage| (msg.v_u32, msg.v_i32, msg.v_string.clone(), msg)) + .fork_unzip(( + |chain: Chain<_>| chain.connect(buffer_u32.input_slot()), + |chain: Chain<_>| chain.connect(buffer_i32.input_slot()), + |chain: Chain<_>| chain.connect(buffer_string.input_slot()), + |chain: Chain<_>| chain.connect(buffer_msg.input_slot()), + )); + + builder.join(buffers).connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request(TestMessage::new(), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let values = promise.take().available().unwrap(); + assert_eq!(values.len(), 4); + assert_eq!(values[0], serde_json::Value::Number(1.into())); + assert_eq!(values[1], serde_json::Value::Number(2.into())); + assert_eq!(values[2], serde_json::Value::String("hello".to_string())); + assert_eq!(values[3], serde_json::to_value(TestMessage::new()).unwrap()); + } + + // We define this struct just to make sure the Accessor macro successfully + // compiles with JsonBufferKey. + #[derive(Clone, Accessor)] + #[allow(unused)] + struct TestJsonKeyMap { + integer: BufferKey, + string: BufferKey, + json: JsonBufferKey, + any: AnyBufferKey, + } +} diff --git a/src/buffer/manage_buffer.rs b/src/buffer/manage_buffer.rs index f0a6705a..af9cf65b 100644 --- a/src/buffer/manage_buffer.rs +++ b/src/buffer/manage_buffer.rs @@ -89,6 +89,8 @@ pub trait ManageBuffer { ) -> Result, OperationError>; fn clear_buffer(&mut self, session: Entity) -> OperationResult; + + fn ensure_session(&mut self, session: Entity) -> OperationResult; } impl<'w> ManageBuffer for EntityWorldMut<'w> { @@ -114,4 +116,11 @@ impl<'w> ManageBuffer for EntityWorldMut<'w> { .clear_session(session); Ok(()) } + + fn ensure_session(&mut self, session: Entity) -> OperationResult { + self.get_mut::>() + .or_broken()? + .ensure_session(session); + Ok(()) + } } diff --git a/src/builder.rs b/src/builder.rs index 4476d060..45727868 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -16,20 +16,19 @@ */ use bevy_ecs::prelude::{Commands, Entity}; -use bevy_hierarchy::prelude::BuildChildren; use std::future::Future; use smallvec::SmallVec; use crate::{ - AddOperation, AsMap, BeginCleanupWorkflow, Buffer, BufferItem, BufferKeys, BufferSettings, - Bufferable, Buffered, Chain, Collect, ForkClone, ForkCloneOutput, ForkTargetStorage, Gate, - GateRequest, Injection, InputSlot, IntoAsyncMap, IntoBlockingMap, Node, OperateBuffer, - OperateBufferAccess, OperateDynamicGate, OperateScope, OperateSplit, OperateStaticGate, Output, - Provider, RequestOfMap, ResponseOfMap, Scope, ScopeEndpoints, ScopeSettings, - ScopeSettingsStorage, Sendish, Service, SplitOutputs, Splittable, StreamPack, StreamTargetMap, - StreamsOfMap, Trim, TrimBranch, UnusedTarget, + Accessible, Accessing, Accessor, AddOperation, AsMap, Buffer, BufferKeys, BufferLocation, + BufferMap, BufferSettings, Bufferable, Buffering, Chain, Collect, ForkClone, ForkCloneOutput, + ForkTargetStorage, Gate, GateRequest, IncompatibleLayout, Injection, InputSlot, IntoAsyncMap, + IntoBlockingMap, Joinable, Joined, Node, OperateBuffer, OperateDynamicGate, OperateScope, + OperateSplit, OperateStaticGate, Output, Provider, RequestOfMap, ResponseOfMap, Scope, + ScopeEndpoints, ScopeSettings, ScopeSettingsStorage, Sendish, Service, SplitOutputs, + Splittable, StreamPack, StreamTargetMap, StreamsOfMap, Trim, TrimBranch, UnusedTarget, }; pub(crate) mod connect; @@ -165,8 +164,10 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { )); Buffer { - scope: self.scope, - source, + location: BufferLocation { + scope: self.scope(), + source, + }, _ignore: Default::default(), } } @@ -230,27 +231,32 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { ) } - /// Alternative way of calling [`Bufferable::join`] - pub fn join<'b, B: Bufferable>(&'b mut self, buffers: B) -> Chain<'w, 's, 'a, 'b, BufferItem> - where - B::BufferType: 'static + Send + Sync, - BufferItem: 'static + Send + Sync, - { + /// Alternative way of calling [`Joinable::join`] + pub fn join<'b, B: Joinable>(&'b mut self, buffers: B) -> Chain<'w, 's, 'a, 'b, B::Item> { buffers.join(self) } - /// Alternative way of calling [`Bufferable::listen`]. - pub fn listen<'b, B: Bufferable>( + /// Try joining a map of buffers into a single value. + pub fn try_join<'b, J: Joined>( &'b mut self, - buffers: B, - ) -> Chain<'w, 's, 'a, 'b, BufferKeys> - where - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, - { + buffers: &BufferMap, + ) -> Result, IncompatibleLayout> { + J::try_join_from(buffers, self) + } + + /// Alternative way of calling [`Accessible::listen`]. + pub fn listen<'b, B: Accessible>(&'b mut self, buffers: B) -> Chain<'w, 's, 'a, 'b, B::Keys> { buffers.listen(self) } + /// Try listening to a map of buffers. + pub fn try_listen<'b, Keys: Accessor>( + &'b mut self, + buffers: &BufferMap, + ) -> Result, IncompatibleLayout> { + Keys::try_listen_from(buffers, self) + } + /// Create a node that combines its inputs with access to some buffers. You /// must specify one ore more buffers to access. FOr multiple buffers, /// combine then into a tuple or an [`Iterator`]. Tuples of buffers can be @@ -258,27 +264,29 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { /// /// Other [outputs](Output) can also be passed in as buffers. These outputs /// will be transformed into a buffer with default buffer settings. - pub fn create_buffer_access(&mut self, buffers: B) -> Node)> + pub fn create_buffer_access( + &mut self, + buffers: B, + ) -> Node)> where + B::BufferType: Accessing, T: 'static + Send + Sync, - B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, { let buffers = buffers.into_buffer(self); - let source = self.commands.spawn(()).id(); - let target = self.commands.spawn(UnusedTarget).id(); - self.commands.add(AddOperation::new( - Some(self.scope), - source, - OperateBufferAccess::::new(buffers, target), - )); + buffers.access(self) + } - Node { - input: InputSlot::new(self.scope, source), - output: Output::new(self.scope, target), - streams: (), - } + /// Try to create access to some buffers. Same as [`Self::create_buffer_access`] + /// except it will return an error if the buffers in the [`BufferMap`] are not + /// compatible with the keys that are being asked for. + pub fn try_create_buffer_access( + &mut self, + buffers: &BufferMap, + ) -> Result, IncompatibleLayout> + where + T: 'static + Send + Sync, + { + Keys::try_buffer_access(buffers, self) } /// Collect incoming workflow threads into a container. @@ -385,15 +393,10 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, ) where B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, + B::BufferType: Accessing, Settings: Into, { - self.on_cleanup_if( - CleanupWorkflowConditions::always_if(true, true), - from_buffers, - build, - ) + from_buffers.into_buffer(self).on_cleanup(self, build); } /// Define a cleanup workflow that only gets run if the scope was cancelled. @@ -415,15 +418,10 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, ) where B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, + B::BufferType: Accessing, Settings: Into, { - self.on_cleanup_if( - CleanupWorkflowConditions::always_if(false, true), - from_buffers, - build, - ) + from_buffers.into_buffer(self).on_cancel(self, build); } /// Define a cleanup workflow that only gets run if the scope was successfully @@ -439,15 +437,10 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, ) where B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, + B::BufferType: Accessing, Settings: Into, { - self.on_cleanup_if( - CleanupWorkflowConditions::always_if(true, false), - from_buffers, - build, - ) + from_buffers.into_buffer(self).on_terminate(self, build); } /// Define a sub-workflow that will be run when this workflow is being cleaned @@ -460,31 +453,12 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { build: impl FnOnce(Scope, (), ()>, &mut Builder) -> Settings, ) where B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, + B::BufferType: Accessing, Settings: Into, { - let cancelling_scope_id = self.commands.spawn(()).id(); - let _ = self.create_scope_impl::, (), (), Settings>( - cancelling_scope_id, - self.finish_scope_cancel, - build, - ); - - let begin_cancel = self.commands.spawn(()).set_parent(self.scope).id(); - let buffers = from_buffers.into_buffer(self); - buffers.verify_scope(self.scope); - self.commands.add(AddOperation::new( - None, - begin_cancel, - BeginCleanupWorkflow::::new( - self.scope, - buffers, - cancelling_scope_id, - conditions.run_on_terminate, - conditions.run_on_cancel, - ), - )); + from_buffers + .into_buffer(self) + .on_cleanup_if(self, conditions, build); } /// Create a node that trims (cancels) other nodes in the workflow when it @@ -525,7 +499,6 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { pub fn create_gate(&mut self, buffers: B) -> Node, T> where B: Bufferable, - B::BufferType: 'static + Send + Sync, T: 'static + Send + Sync, { let buffers = buffers.into_buffer(self); @@ -555,7 +528,6 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { pub fn create_gate_action(&mut self, action: Gate, buffers: B) -> Node where B: Bufferable, - B::BufferType: 'static + Send + Sync, T: 'static + Send + Sync, { let buffers = buffers.into_buffer(self); @@ -582,7 +554,6 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { pub fn create_gate_open(&mut self, buffers: B) -> Node where B: Bufferable, - B::BufferType: 'static + Send + Sync, T: 'static + Send + Sync, { self.create_gate_action(Gate::Open, buffers) @@ -594,7 +565,6 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { pub fn create_gate_close(&mut self, buffers: B) -> Node where B: Bufferable, - B::BufferType: 'static + Send + Sync, T: 'static + Send + Sync, { self.create_gate_action(Gate::Closed, buffers) @@ -696,8 +666,8 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> { /// later without breaking API. #[derive(Clone)] pub struct CleanupWorkflowConditions { - run_on_terminate: bool, - run_on_cancel: bool, + pub(crate) run_on_terminate: bool, + pub(crate) run_on_cancel: bool, } impl CleanupWorkflowConditions { diff --git a/src/chain.rs b/src/chain.rs index 7fa5ffa5..fbdee9f6 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -24,12 +24,12 @@ use smallvec::SmallVec; use std::error::Error; use crate::{ - make_option_branching, make_result_branching, AddOperation, AsMap, Buffer, BufferKey, - BufferKeys, Bufferable, Buffered, Builder, Collect, CreateCancelFilter, CreateDisposalFilter, - ForkTargetStorage, Gate, GateRequest, InputSlot, IntoAsyncMap, IntoBlockingCallback, - IntoBlockingMap, Node, Noop, OperateBufferAccess, OperateDynamicGate, OperateSplit, - OperateStaticGate, Output, ProvideOnce, Provider, Scope, ScopeSettings, Sendish, Service, - Spread, StreamOf, StreamPack, StreamTargetMap, Trim, TrimBranch, UnusedTarget, + make_option_branching, make_result_branching, Accessing, AddOperation, AsMap, Buffer, + BufferKey, BufferKeys, Bufferable, Buffering, Builder, Collect, CreateCancelFilter, + CreateDisposalFilter, ForkTargetStorage, Gate, GateRequest, InputSlot, IntoAsyncMap, + IntoBlockingCallback, IntoBlockingMap, Node, Noop, OperateBufferAccess, OperateDynamicGate, + OperateSplit, OperateStaticGate, Output, ProvideOnce, Provider, Scope, ScopeSettings, Sendish, + Service, Spread, StreamOf, StreamPack, StreamTargetMap, Trim, TrimBranch, UnusedTarget, }; pub mod fork_clone_builder; @@ -298,12 +298,11 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { /// will be transformed into a buffer with default buffer settings. /// /// To obtain a set of buffer keys each time a buffer is modified, use - /// [`listen`](crate::Bufferable::listen). + /// [`listen`](crate::Accessible::listen). pub fn with_access(self, buffers: B) -> Chain<'w, 's, 'a, 'b, (T, BufferKeys)> where B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, + B::BufferType: Accessing, { let buffers = buffers.into_buffer(self.builder); buffers.verify_scope(self.builder.scope); @@ -324,8 +323,7 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { pub fn then_access(self, buffers: B) -> Chain<'w, 's, 'a, 'b, BufferKeys> where B: Bufferable, - B::BufferType: 'static + Send + Sync, - BufferKeys: 'static + Send + Sync, + B::BufferType: Accessing, { self.with_access(buffers).map_block(|(_, key)| key) } @@ -393,7 +391,7 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { /// The return values of the individual chain builders will be zipped into /// one tuple return value by this function. If all of the builders return /// [`Output`] then you can easily continue chaining more operations using - /// [`join`](crate::Bufferable::join), or destructure them into individual + /// [`join`](crate::Joinable::join), or destructure them into individual /// outputs that you can continue to build with. pub fn fork_clone>(self, build: Build) -> Build::Outputs where @@ -546,7 +544,7 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { /// If the buffer is broken (e.g. its operation has been despawned) the /// workflow will be cancelled. pub fn then_push(self, buffer: Buffer) -> Chain<'w, 's, 'a, 'b, ()> { - assert_eq!(self.scope(), buffer.scope); + assert_eq!(self.scope(), buffer.scope()); self.with_access(buffer) .then(push_into_buffer.into_blocking_callback()) .cancel_on_err() @@ -557,7 +555,6 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { pub fn then_gate_action(self, action: Gate, buffers: B) -> Chain<'w, 's, 'a, 'b, T> where B: Bufferable, - B::BufferType: 'static + Send + Sync, { let buffers = buffers.into_buffer(self.builder); buffers.verify_scope(self.builder.scope); @@ -578,7 +575,6 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { pub fn then_gate_open(self, buffers: B) -> Chain<'w, 's, 'a, 'b, T> where B: Bufferable, - B::BufferType: 'static + Send + Sync, { self.then_gate_action(Gate::Open, buffers) } @@ -588,7 +584,6 @@ impl<'w, 's, 'a, 'b, T: 'static + Send + Sync> Chain<'w, 's, 'a, 'b, T> { pub fn then_gate_close(self, buffers: B) -> Chain<'w, 's, 'a, 'b, T> where B: Bufferable, - B::BufferType: 'static + Send + Sync, { self.then_gate_action(Gate::Closed, buffers) } diff --git a/src/diagram.rs b/src/diagram.rs index f6b70203..b4822a36 100644 --- a/src/diagram.rs +++ b/src/diagram.rs @@ -375,37 +375,7 @@ pub struct Diagram { } impl Diagram { - /// Spawns a workflow from this diagram. - /// - /// # Examples - /// - /// ``` - /// use bevy_impulse::{Diagram, DiagramError, NodeBuilderOptions, DiagramElementRegistry, RunCommandsOnWorldExt}; - /// - /// let mut app = bevy_app::App::new(); - /// let mut registry = DiagramElementRegistry::new(); - /// registry.register_node_builder(NodeBuilderOptions::new("echo".to_string()), |builder, _config: ()| { - /// builder.create_map_block(|msg: String| msg) - /// }); - /// - /// let json_str = r#" - /// { - /// "version": "0.1.0", - /// "start": "echo", - /// "ops": { - /// "echo": { - /// "type": "node", - /// "builder": "echo", - /// "next": { "builtin": "terminate" } - /// } - /// } - /// } - /// "#; - /// - /// let diagram = Diagram::from_json_str(json_str)?; - /// let workflow = app.world.command(|cmds| diagram.spawn_io_workflow(cmds, ®istry))?; - /// # Ok::<_, DiagramError>(()) - /// ``` + /// Implementation for [Self::spawn_io_workflow]. // TODO(koonpeng): Support streams other than `()` #43. /* pub */ fn spawn_workflow( @@ -447,7 +417,37 @@ impl Diagram { Ok(w) } - /// Wrapper to [spawn_workflow::<()>](Self::spawn_workflow). + /// Spawns a workflow from this diagram. + /// + /// # Examples + /// + /// ``` + /// use bevy_impulse::{Diagram, DiagramError, NodeBuilderOptions, DiagramElementRegistry, RunCommandsOnWorldExt}; + /// + /// let mut app = bevy_app::App::new(); + /// let mut registry = DiagramElementRegistry::new(); + /// registry.register_node_builder(NodeBuilderOptions::new("echo".to_string()), |builder, _config: ()| { + /// builder.create_map_block(|msg: String| msg) + /// }); + /// + /// let json_str = r#" + /// { + /// "version": "0.1.0", + /// "start": "echo", + /// "ops": { + /// "echo": { + /// "type": "node", + /// "builder": "echo", + /// "next": { "builtin": "terminate" } + /// } + /// } + /// } + /// "#; + /// + /// let diagram = Diagram::from_json_str(json_str)?; + /// let workflow = app.world.command(|cmds| diagram.spawn_io_workflow(cmds, ®istry))?; + /// # Ok::<_, DiagramError>(()) + /// ``` pub fn spawn_io_workflow( &self, cmds: &mut Commands, diff --git a/src/diagram/registration.rs b/src/diagram/registration.rs index 66afb59f..d0ce66e9 100644 --- a/src/diagram/registration.rs +++ b/src/diagram/registration.rs @@ -985,7 +985,7 @@ impl DiagramElementRegistry { /// Register a node builder with all the common operations (deserialize the /// request, serialize the response, and clone the response) enabled. /// - /// You will receive a [`RegistrationBuilder`] which you can then use to + /// You will receive a [`NodeRegistrationBuilder`] which you can then use to /// enable more operations around your node, such as fork result, split, /// or unzip. The data types of your node need to be suitable for those /// operations or else the compiler will not allow you to enable them. diff --git a/src/gate.rs b/src/gate.rs index 8d34c757..03797337 100644 --- a/src/gate.rs +++ b/src/gate.rs @@ -23,14 +23,14 @@ pub enum Gate { /// receive a wakeup immediately when a gate switches from closed to open, /// even if none of the data inside the buffer has changed. /// - /// [1]: crate::Bufferable::join + /// [1]: crate::Joinable::join Open, /// Close the buffer gate so that listeners (including [join][1] operations) /// will not be woken up when the data in the buffer gets modified. This /// effectively blocks the workflow nodes that are downstream of the buffer. /// Data will build up in the buffer according to its [`BufferSettings`][2]. /// - /// [1]: crate::Bufferable::join + /// [1]: crate::Joinable::join /// [2]: crate::BufferSettings Closed, } diff --git a/src/lib.rs b/src/lib.rs index d7e27ab8..ceab07f2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,6 +72,8 @@ pub use async_execution::Sendish; pub mod buffer; pub use buffer::*; +pub mod re_exports; + pub mod builder; pub use builder::*; @@ -148,6 +150,8 @@ pub use trim::*; use bevy_app::prelude::{App, Plugin, Update}; use bevy_ecs::prelude::{Entity, In}; +extern crate self as bevy_impulse; + /// Use `BlockingService` to indicate that your system is a blocking [`Service`]. /// /// A blocking service will have exclusive world access while it runs, which @@ -336,8 +340,10 @@ impl Plugin for ImpulsePlugin { pub mod prelude { pub use crate::{ buffer::{ - Buffer, BufferAccess, BufferAccessMut, BufferKey, BufferSettings, Bufferable, Buffered, - IterBufferable, RetentionPolicy, + Accessible, Accessor, AnyBuffer, AnyBufferKey, AnyBufferMut, AnyBufferWorldAccess, + AnyMessageBox, AsAnyBuffer, Buffer, BufferAccess, BufferAccessMut, BufferKey, + BufferMap, BufferMapLayout, BufferSettings, BufferWorldAccess, Bufferable, Buffering, + IncompatibleLayout, IterBufferable, Joinable, Joined, RetentionPolicy, }, builder::Builder, callback::{AsCallback, Callback, IntoAsyncCallback, IntoBlockingCallback}, @@ -362,4 +368,9 @@ pub mod prelude { BlockingCallback, BlockingCallbackInput, BlockingMap, BlockingService, BlockingServiceInput, ContinuousQuery, ContinuousService, ContinuousServiceInput, }; + + #[cfg(feature = "diagram")] + pub use crate::buffer::{ + JsonBuffer, JsonBufferKey, JsonBufferMut, JsonBufferWorldAccess, JsonMessage, + }; } diff --git a/src/operation/cleanup.rs b/src/operation/cleanup.rs index f4231825..8bac864e 100644 --- a/src/operation/cleanup.rs +++ b/src/operation/cleanup.rs @@ -16,7 +16,7 @@ */ use crate::{ - BufferAccessStorage, Buffered, ManageDisposal, ManageInput, MiscellaneousFailure, + Accessing, BufferAccessStorage, ManageDisposal, ManageInput, MiscellaneousFailure, OperationError, OperationResult, OperationRoster, OrBroken, ScopeStorage, UnhandledErrors, }; @@ -98,7 +98,7 @@ impl<'a> OperationCleanup<'a> { pub fn cleanup_buffer_access(&mut self) -> OperationResult where - B: Buffered + 'static + Send + Sync, + B: Accessing + 'static + Send + Sync, B::Key: 'static + Send + Sync, { let scope = self diff --git a/src/operation/join.rs b/src/operation/join.rs index 91d5f3e9..314e64b8 100644 --- a/src/operation/join.rs +++ b/src/operation/join.rs @@ -18,7 +18,7 @@ use bevy_ecs::prelude::{Component, Entity}; use crate::{ - Buffered, FunnelInputStorage, Input, InputBundle, ManageInput, Operation, OperationCleanup, + FunnelInputStorage, Input, InputBundle, Joining, ManageInput, Operation, OperationCleanup, OperationError, OperationReachability, OperationRequest, OperationResult, OperationSetup, OrBroken, ReachabilityResult, SingleInputStorage, SingleTargetStorage, }; @@ -37,7 +37,7 @@ impl Join { #[derive(Component)] struct BufferStorage(Buffers); -impl Operation for Join +impl Operation for Join where Buffers::Item: 'static + Send + Sync, { diff --git a/src/operation/listen.rs b/src/operation/listen.rs index 3a15ca82..378fe9f5 100644 --- a/src/operation/listen.rs +++ b/src/operation/listen.rs @@ -18,7 +18,7 @@ use bevy_ecs::prelude::Entity; use crate::{ - buffer_key_usage, get_access_keys, BufferAccessStorage, BufferKeyUsage, Buffered, + buffer_key_usage, get_access_keys, Accessing, BufferAccessStorage, BufferKeyUsage, FunnelInputStorage, Input, InputBundle, ManageInput, Operation, OperationCleanup, OperationReachability, OperationRequest, OperationResult, OperationSetup, OrBroken, ReachabilityResult, SingleInputStorage, SingleTargetStorage, @@ -37,7 +37,7 @@ impl Listen { impl Operation for Listen where - B: Buffered + 'static + Send + Sync, + B: Accessing + 'static + Send + Sync, B::Key: 'static + Send + Sync, { fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult { diff --git a/src/operation/operate_buffer_access.rs b/src/operation/operate_buffer_access.rs index b2936f01..07208410 100644 --- a/src/operation/operate_buffer_access.rs +++ b/src/operation/operate_buffer_access.rs @@ -25,7 +25,7 @@ use std::{ use smallvec::SmallVec; use crate::{ - BufferKeyBuilder, Buffered, ChannelQueue, Input, InputBundle, ManageInput, Operation, + Accessing, BufferKeyBuilder, ChannelQueue, Input, InputBundle, ManageInput, Operation, OperationCleanup, OperationError, OperationReachability, OperationRequest, OperationResult, OperationSetup, OrBroken, ReachabilityResult, ScopeStorage, SingleInputStorage, SingleTargetStorage, @@ -34,7 +34,7 @@ use crate::{ pub(crate) struct OperateBufferAccess where T: 'static + Send + Sync, - B: Buffered, + B: Accessing, { buffers: B, target: Entity, @@ -44,7 +44,7 @@ where impl OperateBufferAccess where T: 'static + Send + Sync, - B: Buffered, + B: Accessing, { pub(crate) fn new(buffers: B, target: Entity) -> Self { Self { @@ -59,12 +59,12 @@ where pub struct BufferKeyUsage(pub(crate) fn(Entity, Entity, &World) -> ReachabilityResult); #[derive(Component)] -pub(crate) struct BufferAccessStorage { +pub(crate) struct BufferAccessStorage { pub(crate) buffers: B, pub(crate) keys: HashMap, } -impl BufferAccessStorage { +impl BufferAccessStorage { pub(crate) fn new(buffers: B) -> Self { Self { buffers, @@ -76,7 +76,7 @@ impl BufferAccessStorage { impl Operation for OperateBufferAccess where T: 'static + Send + Sync, - B: Buffered + 'static + Send + Sync, + B: Accessing + 'static + Send + Sync, B::Key: 'static + Send + Sync, { fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult { @@ -138,7 +138,7 @@ pub(crate) fn get_access_keys( world: &mut World, ) -> Result where - B: Buffered + 'static + Send + Sync, + B: Accessing + 'static + Send + Sync, B::Key: 'static + Send + Sync, { let scope = world.get::(source).or_broken()?.get(); @@ -180,7 +180,7 @@ pub(crate) fn buffer_key_usage( world: &World, ) -> ReachabilityResult where - B: Buffered + 'static + Send + Sync, + B: Accessing + 'static + Send + Sync, B::Key: 'static + Send + Sync, { let key = world @@ -206,6 +206,12 @@ where pub(crate) struct BufferAccessors(pub(crate) SmallVec<[Entity; 8]>); impl BufferAccessors { + pub(crate) fn add_accessor(&mut self, accessor: Entity) { + self.0.push(accessor); + self.0.sort(); + self.0.dedup(); + } + pub(crate) fn is_reachable(r: &mut OperationReachability) -> ReachabilityResult { let Some(accessors) = r.world.get::(r.source) else { return Ok(false); diff --git a/src/operation/operate_gate.rs b/src/operation/operate_gate.rs index 9a677e37..0a45d449 100644 --- a/src/operation/operate_gate.rs +++ b/src/operation/operate_gate.rs @@ -18,7 +18,7 @@ use bevy_ecs::prelude::{Component, Entity}; use crate::{ - emit_disposal, Buffered, Disposal, Gate, GateRequest, Input, InputBundle, ManageInput, + emit_disposal, Buffering, Disposal, Gate, GateRequest, Input, InputBundle, ManageInput, Operation, OperationCleanup, OperationReachability, OperationRequest, OperationResult, OperationSetup, OrBroken, ReachabilityResult, SingleInputStorage, SingleTargetStorage, }; @@ -48,7 +48,7 @@ impl OperateDynamicGate { impl Operation for OperateDynamicGate where T: 'static + Send + Sync, - B: Buffered + 'static + Send + Sync, + B: Buffering + 'static + Send + Sync, { fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult { world @@ -144,7 +144,7 @@ impl OperateStaticGate { impl Operation for OperateStaticGate where - B: Buffered + 'static + Send + Sync, + B: Buffering + 'static + Send + Sync, T: 'static + Send + Sync, { fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult { diff --git a/src/operation/scope.rs b/src/operation/scope.rs index fe18821f..f757653b 100644 --- a/src/operation/scope.rs +++ b/src/operation/scope.rs @@ -16,15 +16,14 @@ */ use crate::{ - check_reachability, execute_operation, is_downstream_of, AddOperation, Blocker, - BufferKeyBuilder, Buffered, Cancel, Cancellable, Cancellation, Cleanup, CleanupContents, - ClearBufferFn, CollectMarker, DisposalListener, DisposalUpdate, FinalizeCleanup, - FinalizeCleanupRequest, Input, InputBundle, InspectDisposals, ManageCancellation, ManageInput, - Operation, OperationCancel, OperationCleanup, OperationError, OperationReachability, - OperationRequest, OperationResult, OperationRoster, OperationSetup, OrBroken, - ReachabilityResult, ScopeSettings, SingleInputStorage, SingleTargetStorage, Stream, StreamPack, - StreamRequest, StreamTargetMap, StreamTargetStorage, UnhandledErrors, Unreachability, - UnusedTarget, + check_reachability, execute_operation, is_downstream_of, Accessing, AddOperation, Blocker, + BufferKeyBuilder, Cancel, Cancellable, Cancellation, Cleanup, CleanupContents, ClearBufferFn, + CollectMarker, DisposalListener, DisposalUpdate, FinalizeCleanup, FinalizeCleanupRequest, + Input, InputBundle, InspectDisposals, ManageCancellation, ManageInput, Operation, + OperationCancel, OperationCleanup, OperationError, OperationReachability, OperationRequest, + OperationResult, OperationRoster, OperationSetup, OrBroken, ReachabilityResult, ScopeSettings, + SingleInputStorage, SingleTargetStorage, Stream, StreamPack, StreamRequest, StreamTargetMap, + StreamTargetStorage, UnhandledErrors, Unreachability, UnusedTarget, }; use backtrace::Backtrace; @@ -1125,7 +1124,7 @@ impl BeginCleanupWorkflow { impl Operation for BeginCleanupWorkflow where - B: Buffered + 'static + Send + Sync, + B: Accessing + 'static + Send + Sync, B::Key: 'static + Send + Sync, { fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult { diff --git a/src/re_exports.rs b/src/re_exports.rs new file mode 100644 index 00000000..84f22076 --- /dev/null +++ b/src/re_exports.rs @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +//! This module contains symbols that are being re-exported so they can be used +//! by bevy_impulse_derive. + +pub use bevy_ecs::prelude::{Entity, World}; diff --git a/src/testing.rs b/src/testing.rs index 83d21dd9..f7ba4664 100644 --- a/src/testing.rs +++ b/src/testing.rs @@ -19,7 +19,7 @@ use bevy_app::ScheduleRunnerPlugin; pub use bevy_app::{App, Update}; use bevy_core::{FrameCountPlugin, TaskPoolPlugin, TypeRegistrationPlugin}; pub use bevy_ecs::{ - prelude::{Commands, Component, Entity, In, Local, Query, ResMut, Resource}, + prelude::{Commands, Component, Entity, In, Local, Query, ResMut, Resource, World}, system::{CommandQueue, IntoSystem}, }; use bevy_time::TimePlugin; @@ -32,10 +32,12 @@ pub use std::time::{Duration, Instant}; use smallvec::SmallVec; use crate::{ - flush_impulses, AddContinuousServicesExt, AsyncServiceInput, BlockingMap, BlockingServiceInput, - Builder, ContinuousQuery, ContinuousQueueView, ContinuousService, FlushParameters, - GetBufferedSessionsFn, Promise, RunCommandsOnWorldExt, Scope, Service, SpawnWorkflowExt, - StreamOf, StreamPack, UnhandledErrors, WorkflowSettings, + flush_impulses, Accessing, AddContinuousServicesExt, AnyBuffer, AsAnyBuffer, AsyncServiceInput, + BlockingMap, BlockingServiceInput, Buffer, BufferKey, BufferKeyLifecycle, Bufferable, + Buffering, Builder, ContinuousQuery, ContinuousQueueView, ContinuousService, FlushParameters, + GetBufferedSessionsFn, Joining, OperationError, OperationResult, OperationRoster, Promise, + RunCommandsOnWorldExt, Scope, Service, SpawnWorkflowExt, StreamOf, StreamPack, UnhandledErrors, + WorkflowSettings, }; pub struct TestingContext { @@ -478,3 +480,104 @@ pub struct TestComponent; pub struct Integer { pub value: i32, } + +/// This is an ordinary buffer newtype whose only purpose is to test the +/// #[joined(noncopy_buffer)] feature. We intentionally do not implement +/// the Copy trait for it. +pub struct NonCopyBuffer { + inner: Buffer, +} + +impl NonCopyBuffer { + pub fn register_downcast() { + let any_interface = AnyBuffer::interface_for::(); + any_interface.register_buffer_downcast( + std::any::TypeId::of::>(), + Box::new(|location| { + Box::new(NonCopyBuffer:: { + inner: Buffer { + location, + _ignore: Default::default(), + }, + }) + }), + ); + } +} + +impl Clone for NonCopyBuffer { + fn clone(&self) -> Self { + Self { inner: self.inner } + } +} + +impl AsAnyBuffer for NonCopyBuffer { + fn as_any_buffer(&self) -> AnyBuffer { + self.inner.as_any_buffer() + } +} + +impl Bufferable for NonCopyBuffer { + type BufferType = Self; + fn into_buffer(self, _builder: &mut Builder) -> Self::BufferType { + self + } +} + +impl Buffering for NonCopyBuffer { + fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { + self.inner.add_listener(listener, world) + } + + fn as_input(&self) -> smallvec::SmallVec<[Entity; 8]> { + self.inner.as_input() + } + + fn buffered_count(&self, session: Entity, world: &World) -> Result { + self.inner.buffered_count(session, world) + } + + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + self.inner.ensure_active_session(session, world) + } + + fn gate_action( + &self, + session: Entity, + action: crate::Gate, + world: &mut World, + roster: &mut OperationRoster, + ) -> OperationResult { + self.inner.gate_action(session, action, world, roster) + } + + fn verify_scope(&self, scope: Entity) { + self.inner.verify_scope(scope); + } +} + +impl Joining for NonCopyBuffer { + type Item = T; + fn pull(&self, session: Entity, world: &mut World) -> Result { + self.inner.pull(session, world) + } +} + +impl Accessing for NonCopyBuffer { + type Key = BufferKey; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { + self.inner.add_accessor(accessor, world) + } + + fn create_key(&self, builder: &crate::BufferKeyBuilder) -> Self::Key { + self.inner.create_key(builder) + } + + fn deep_clone_key(key: &Self::Key) -> Self::Key { + key.deep_clone() + } + + fn is_key_in_use(key: &Self::Key) -> bool { + key.is_in_use() + } +}