55// http://opensource.org/licenses/MIT>, at your option. You may not use this file except in
66// accordance with one or both of these licenses.
77
8- use std:: collections:: { hash_map , HashMap } ;
8+ use std:: collections:: HashMap ;
99use std:: ops:: Deref ;
1010use std:: sync:: { Arc , Mutex } ;
1111
@@ -83,34 +83,38 @@ where
8383
8484 pub ( crate ) async fn insert_or_update ( & self , object : SO ) -> Result < bool , Error > {
8585 let _guard = self . mutation_lock . lock ( ) . await ;
86- let ( updated, data_to_persist) = {
87- let mut locked_objects = self . objects . lock ( ) . expect ( "lock" ) ;
88- match locked_objects. entry ( object. id ( ) ) {
89- hash_map:: Entry :: Occupied ( mut e) => {
90- let update = object. to_update ( ) ;
91- let updated = e. get_mut ( ) . update ( update) ;
92- let data_to_persist =
93- if updated { Some ( Self :: encode_object ( e. get ( ) ) ) } else { None } ;
94- ( updated, data_to_persist)
95- } ,
96- hash_map:: Entry :: Vacant ( e) => {
97- let data_to_persist = Self :: encode_object ( & object) ;
98- e. insert ( object) ;
99- ( true , Some ( data_to_persist) )
100- } ,
86+
87+ let id = object. id ( ) ;
88+ let data_to_persist = {
89+ let locked_objects = self . objects . lock ( ) . expect ( "lock" ) ;
90+ if let Some ( existing_object) = locked_objects. get ( & id) {
91+ let mut updated_object = existing_object. clone ( ) ;
92+ let updated = updated_object. update ( object. to_update ( ) ) ;
93+ if updated {
94+ Some ( updated_object)
95+ } else {
96+ None
97+ }
98+ } else {
99+ Some ( object)
101100 }
102101 } ;
103102
104- if let Some ( ( store_key, data) ) = data_to_persist {
105- self . persist_encoded ( store_key, data) . await ?;
103+ match data_to_persist {
104+ Some ( updated_object) => {
105+ self . persist ( & updated_object) . await ?;
106+ let mut locked_objects = self . objects . lock ( ) . expect ( "lock" ) ;
107+ locked_objects. insert ( id, updated_object) ;
108+ Ok ( true )
109+ } ,
110+ None => Ok ( false ) ,
106111 }
107- Ok ( updated)
108112 }
109113
110114 pub ( crate ) async fn remove ( & self , id : & SO :: Id ) -> Result < ( ) , Error > {
111115 let _guard = self . mutation_lock . lock ( ) . await ;
112- let removed = { self . objects . lock ( ) . expect ( "lock" ) . remove ( id) . is_some ( ) } ;
113- if removed {
116+ let should_remove = { self . objects . lock ( ) . expect ( "lock" ) . contains_key ( id) } ;
117+ if should_remove {
114118 let store_key = id. encode_to_hex_str ( ) ;
115119 KVStore :: remove (
116120 & * self . kv_store ,
@@ -131,45 +135,46 @@ where
131135 ) ;
132136 Error :: PersistenceFailed
133137 } ) ?;
138+ self . objects . lock ( ) . expect ( "lock" ) . remove ( id) ;
134139 }
135140 Ok ( ( ) )
136141 }
137142
138143 /// Returns the current in-memory object for `id`.
139144 ///
140145 /// The async mutation lock serializes writers, but this synchronous reader cannot wait on it.
141- /// Until store reads are async, callers may temporarily see in-memory state that is either
142- /// still being persisted or has not yet caught up to a write in progress.
146+ /// Until store reads are async, callers may temporarily see in-memory state that has not yet
147+ /// caught up to a write in progress.
143148 pub ( crate ) fn get ( & self , id : & SO :: Id ) -> Option < SO > {
144149 self . objects . lock ( ) . expect ( "lock" ) . get ( id) . cloned ( )
145150 }
146151
147152 pub ( crate ) async fn update ( & self , update : SO :: Update ) -> Result < DataStoreUpdateResult , Error > {
148153 let _guard = self . mutation_lock . lock ( ) . await ;
149- let ( res, data_to_persist) = {
150- let mut locked_objects = self . objects . lock ( ) . expect ( "lock" ) ;
151- if let Some ( object) = locked_objects. get_mut ( & update. id ( ) ) {
152- let updated = object. update ( update) ;
153- if updated {
154- ( DataStoreUpdateResult :: Updated , Some ( Self :: encode_object ( object) ) )
155- } else {
156- ( DataStoreUpdateResult :: Unchanged , None )
157- }
158- } else {
159- ( DataStoreUpdateResult :: NotFound , None )
154+ let id = update. id ( ) ;
155+ let updated_object = {
156+ let locked_objects = self . objects . lock ( ) . expect ( "lock" ) ;
157+ let Some ( object) = locked_objects. get ( & id) else {
158+ return Ok ( DataStoreUpdateResult :: NotFound ) ;
159+ } ;
160+ let mut updated_object = object. clone ( ) ;
161+ if !updated_object. update ( update) {
162+ return Ok ( DataStoreUpdateResult :: Unchanged ) ;
160163 }
164+ updated_object
161165 } ;
162- if let Some ( ( store_key, data) ) = data_to_persist {
163- self . persist_encoded ( store_key, data) . await ?;
164- }
165- Ok ( res)
166+
167+ self . persist ( & updated_object) . await ?;
168+ let mut locked_objects = self . objects . lock ( ) . expect ( "lock" ) ;
169+ locked_objects. insert ( id, updated_object) ;
170+ Ok ( DataStoreUpdateResult :: Updated )
166171 }
167172
168173 /// Returns in-memory objects matching `f`.
169174 ///
170175 /// The async mutation lock serializes writers, but this synchronous reader cannot wait on it.
171- /// Until store reads are async, callers may temporarily see in-memory state that is either
172- /// still being persisted or has not yet caught up to a write in progress.
176+ /// Until store reads are async, callers may temporarily see in-memory state that has not yet
177+ /// caught up to a write in progress.
173178 pub ( crate ) fn list_filter < F : FnMut ( & & SO ) -> bool > ( & self , f : F ) -> Vec < SO > {
174179 self . objects . lock ( ) . expect ( "lock" ) . values ( ) . filter ( f) . cloned ( ) . collect :: < Vec < SO > > ( )
175180 }
@@ -209,8 +214,8 @@ where
209214 /// Returns whether the in-memory store contains `id`.
210215 ///
211216 /// The async mutation lock serializes writers, but this synchronous reader cannot wait on it.
212- /// Until store reads are async, callers may temporarily see in-memory state that is either
213- /// still being persisted or has not yet caught up to a write in progress.
217+ /// Until store reads are async, callers may temporarily see in-memory state that has not yet
218+ /// caught up to a write in progress.
214219 pub ( crate ) fn contains_key ( & self , id : & SO :: Id ) -> bool {
215220 self . objects . lock ( ) . expect ( "lock" ) . contains_key ( id)
216221 }
@@ -219,6 +224,7 @@ where
219224#[ cfg( test) ]
220225mod tests {
221226 use lightning:: impl_writeable_tlv_based;
227+ use lightning:: io;
222228 use lightning:: util:: test_utils:: TestLogger ;
223229
224230 use super :: * ;
@@ -281,6 +287,46 @@ mod tests {
281287 ( 2 , data, required) ,
282288 } ) ;
283289
290+ struct FailingStore ;
291+
292+ impl KVStore for FailingStore {
293+ fn read (
294+ & self , _primary_namespace : & str , _secondary_namespace : & str , _key : & str ,
295+ ) -> impl std:: future:: Future < Output = Result < Vec < u8 > , io:: Error > > + ' static + Send {
296+ async { Err ( io:: Error :: new ( io:: ErrorKind :: Other , "read failed" ) ) }
297+ }
298+
299+ fn write (
300+ & self , _primary_namespace : & str , _secondary_namespace : & str , _key : & str , _buf : Vec < u8 > ,
301+ ) -> impl std:: future:: Future < Output = Result < ( ) , io:: Error > > + ' static + Send {
302+ async { Err ( io:: Error :: new ( io:: ErrorKind :: Other , "write failed" ) ) }
303+ }
304+
305+ fn remove (
306+ & self , _primary_namespace : & str , _secondary_namespace : & str , _key : & str , _lazy : bool ,
307+ ) -> impl std:: future:: Future < Output = Result < ( ) , io:: Error > > + ' static + Send {
308+ async { Err ( io:: Error :: new ( io:: ErrorKind :: Other , "remove failed" ) ) }
309+ }
310+
311+ fn list (
312+ & self , _primary_namespace : & str , _secondary_namespace : & str ,
313+ ) -> impl std:: future:: Future < Output = Result < Vec < String > , io:: Error > > + ' static + Send {
314+ async { Err ( io:: Error :: new ( io:: ErrorKind :: Other , "list failed" ) ) }
315+ }
316+ }
317+
318+ fn new_failing_data_store ( objects : Vec < TestObject > ) -> DataStore < TestObject , Arc < TestLogger > > {
319+ let store: Arc < DynStore > = Arc :: new ( DynStoreWrapper ( FailingStore ) ) ;
320+ let logger = Arc :: new ( TestLogger :: new ( ) ) ;
321+ DataStore :: new (
322+ objects,
323+ "datastore_test_primary" . to_string ( ) ,
324+ "datastore_test_secondary" . to_string ( ) ,
325+ store,
326+ logger,
327+ )
328+ }
329+
284330 #[ tokio:: test]
285331 async fn data_is_persisted ( ) {
286332 let store: Arc < DynStore > = Arc :: new ( DynStoreWrapper ( InMemoryStore :: new ( ) ) ) ;
@@ -346,4 +392,54 @@ mod tests {
346392 new_iou_object. data [ 0 ] += 1 ;
347393 assert_eq ! ( Ok ( true ) , data_store. insert_or_update( new_iou_object) . await ) ;
348394 }
395+
396+ #[ tokio:: test]
397+ async fn insert_or_update_does_not_mutate_memory_if_persist_fails ( ) {
398+ let existing_id = TestObjectId { id : [ 42u8 ; 4 ] } ;
399+ let existing_object = TestObject { id : existing_id, data : [ 23u8 ; 3 ] } ;
400+ let data_store = new_failing_data_store ( vec ! [ existing_object] ) ;
401+
402+ let updated_object = TestObject { id : existing_id, data : [ 24u8 ; 3 ] } ;
403+ assert_eq ! (
404+ Err ( Error :: PersistenceFailed ) ,
405+ data_store. insert_or_update( updated_object) . await
406+ ) ;
407+ assert_eq ! ( Some ( existing_object) , data_store. get( & existing_id) ) ;
408+
409+ let new_id = TestObjectId { id : [ 55u8 ; 4 ] } ;
410+ let new_object = TestObject { id : new_id, data : [ 34u8 ; 3 ] } ;
411+ assert_eq ! ( Err ( Error :: PersistenceFailed ) , data_store. insert_or_update( new_object) . await ) ;
412+ assert ! ( data_store. get( & new_id) . is_none( ) ) ;
413+ }
414+
415+ #[ tokio:: test]
416+ async fn insert_does_not_mutate_memory_if_persist_fails ( ) {
417+ let id = TestObjectId { id : [ 42u8 ; 4 ] } ;
418+ let object = TestObject { id, data : [ 23u8 ; 3 ] } ;
419+ let data_store = new_failing_data_store ( vec ! [ ] ) ;
420+
421+ assert_eq ! ( Err ( Error :: PersistenceFailed ) , data_store. insert( object) . await ) ;
422+ assert ! ( data_store. get( & id) . is_none( ) ) ;
423+ }
424+
425+ #[ tokio:: test]
426+ async fn update_does_not_mutate_memory_if_persist_fails ( ) {
427+ let id = TestObjectId { id : [ 42u8 ; 4 ] } ;
428+ let object = TestObject { id, data : [ 23u8 ; 3 ] } ;
429+ let data_store = new_failing_data_store ( vec ! [ object] ) ;
430+
431+ let update = TestObjectUpdate { id, data : [ 24u8 ; 3 ] } ;
432+ assert_eq ! ( Err ( Error :: PersistenceFailed ) , data_store. update( update) . await ) ;
433+ assert_eq ! ( Some ( object) , data_store. get( & id) ) ;
434+ }
435+
436+ #[ tokio:: test]
437+ async fn remove_does_not_mutate_memory_if_persist_fails ( ) {
438+ let id = TestObjectId { id : [ 42u8 ; 4 ] } ;
439+ let object = TestObject { id, data : [ 23u8 ; 3 ] } ;
440+ let data_store = new_failing_data_store ( vec ! [ object] ) ;
441+
442+ assert_eq ! ( Err ( Error :: PersistenceFailed ) , data_store. remove( & id) . await ) ;
443+ assert_eq ! ( Some ( object) , data_store. get( & id) ) ;
444+ }
349445}
0 commit comments