@@ -584,6 +584,40 @@ where
584584
585585 Ok ( ( ) )
586586 }
587+
588+ fn write_embeddings_len ( & self , offset : u64 ) -> u64 {
589+ let mut len = 0 ;
590+
591+ let mut chunks = match self . metadata {
592+ Some ( ref metadata) => vec ! [ metadata. chunk_identifier( ) ] ,
593+ None => vec ! [ ] ,
594+ } ;
595+
596+ chunks. extend_from_slice ( & [
597+ self . vocab . chunk_identifier ( ) ,
598+ self . storage . chunk_identifier ( ) ,
599+ ] ) ;
600+
601+ if let Some ( ref norms) = self . norms {
602+ chunks. push ( norms. chunk_identifier ( ) ) ;
603+ }
604+
605+ let header = Header :: new ( chunks) ;
606+ len += header. chunk_len ( offset + len) ;
607+
608+ if let Some ( ref metadata) = self . metadata {
609+ len += metadata. chunk_len ( offset + len) ;
610+ }
611+
612+ len += self . vocab . chunk_len ( offset + len) ;
613+ len += self . storage . chunk_len ( offset + len) ;
614+
615+ if let Some ( ref norms) = self . norms {
616+ len += norms. chunk_len ( offset + len) ;
617+ }
618+
619+ len
620+ }
587621}
588622
589623/// Quantizable embedding matrix.
@@ -736,17 +770,36 @@ mod tests {
736770 use crate :: chunks:: metadata:: Metadata ;
737771 use crate :: chunks:: norms:: NdNorms ;
738772 use crate :: chunks:: storage:: { NdArray , Storage , StorageView } ;
739- use crate :: chunks:: vocab:: { SimpleVocab , Vocab } ;
773+ use crate :: chunks:: vocab:: { FastTextSubwordVocab , SimpleVocab , Vocab } ;
740774 use crate :: compat:: fasttext:: ReadFastText ;
741775 use crate :: compat:: word2vec:: ReadWord2VecRaw ;
742776 use crate :: io:: { ReadEmbeddings , WriteEmbeddings } ;
777+ use crate :: prelude:: StorageWrap ;
778+ use crate :: storage:: QuantizedArray ;
743779 use crate :: subword:: Indexer ;
780+ use crate :: vocab:: VocabWrap ;
744781
745782 fn test_embeddings ( ) -> Embeddings < SimpleVocab , NdArray > {
746783 let mut reader = BufReader :: new ( File :: open ( "testdata/similarity.bin" ) . unwrap ( ) ) ;
747784 Embeddings :: read_word2vec_binary_raw ( & mut reader, false ) . unwrap ( )
748785 }
749786
787+ fn test_embeddings_with_metadata ( ) -> Embeddings < SimpleVocab , NdArray > {
788+ let mut embeds = test_embeddings ( ) ;
789+ embeds. set_metadata ( Some ( test_metadata ( ) ) ) ;
790+ embeds
791+ }
792+
793+ fn test_embeddings_fasttext ( ) -> Embeddings < FastTextSubwordVocab , NdArray > {
794+ let mut reader = BufReader :: new ( File :: open ( "testdata/fasttext.bin" ) . unwrap ( ) ) ;
795+ Embeddings :: read_fasttext ( & mut reader) . unwrap ( )
796+ }
797+
798+ fn test_embeddings_quantized ( ) -> Embeddings < SimpleVocab , QuantizedArray > {
799+ let mut reader = BufReader :: new ( File :: open ( "testdata/quantized.fifu" ) . unwrap ( ) ) ;
800+ Embeddings :: read_embeddings ( & mut reader) . unwrap ( )
801+ }
802+
750803 fn test_metadata ( ) -> Metadata {
751804 Metadata :: new ( toml ! {
752805 [ hyperparameters]
@@ -867,12 +920,15 @@ mod tests {
867920 Embeddings :: read_embeddings ( & mut cursor) . unwrap ( ) ;
868921 assert_eq ! ( embeds. storage( ) . view( ) , check_embeds. storage( ) . view( ) ) ;
869922 assert_eq ! ( embeds. vocab( ) , check_embeds. vocab( ) ) ;
923+ assert_eq ! (
924+ cursor. into_inner( ) . len( ) as u64 ,
925+ check_embeds. write_embeddings_len( 0 )
926+ ) ;
870927 }
871928
872929 #[ test]
873930 fn write_read_simple_metadata_roundtrip ( ) {
874- let mut check_embeds = test_embeddings ( ) ;
875- check_embeds. set_metadata ( Some ( test_metadata ( ) ) ) ;
931+ let check_embeds = test_embeddings_with_metadata ( ) ;
876932
877933 let mut cursor = Cursor :: new ( Vec :: new ( ) ) ;
878934 check_embeds. write_embeddings ( & mut cursor) . unwrap ( ) ;
@@ -881,5 +937,31 @@ mod tests {
881937 Embeddings :: read_embeddings ( & mut cursor) . unwrap ( ) ;
882938 assert_eq ! ( embeds. storage( ) . view( ) , check_embeds. storage( ) . view( ) ) ;
883939 assert_eq ! ( embeds. vocab( ) , check_embeds. vocab( ) ) ;
940+ assert_eq ! (
941+ cursor. into_inner( ) . len( ) as u64 ,
942+ check_embeds. write_embeddings_len( 0 )
943+ ) ;
944+ }
945+
946+ #[ test]
947+ fn embeddings_write_length_different_offsets ( ) {
948+ let embeddings: Vec < Embeddings < VocabWrap , StorageWrap > > = vec ! [
949+ test_embeddings( ) . into( ) ,
950+ test_embeddings_with_metadata( ) . into( ) ,
951+ test_embeddings_fasttext( ) . into( ) ,
952+ test_embeddings_quantized( ) . into( ) ,
953+ ] ;
954+
955+ for check_embeddings in & embeddings {
956+ for offset in 0 ..16u64 {
957+ let mut cursor = Cursor :: new ( Vec :: new ( ) ) ;
958+ cursor. seek ( SeekFrom :: Start ( offset) ) . unwrap ( ) ;
959+ check_embeddings. write_embeddings ( & mut cursor) . unwrap ( ) ;
960+ assert_eq ! (
961+ cursor. into_inner( ) . len( ) as u64 - offset,
962+ check_embeddings. write_embeddings_len( offset)
963+ ) ;
964+ }
965+ }
884966 }
885967}
0 commit comments