1+ use std:: convert:: TryFrom ;
12use std:: fmt:: { self , Display } ;
23use std:: fs:: File ;
34use std:: io:: { BufReader , Read , Seek , Write } ;
45
56use byteorder:: { LittleEndian , ReadBytesExt , WriteBytesExt } ;
67
7- use crate :: io :: { Error , ErrorKind , Result } ;
8+ use crate :: error :: { Error , Result } ;
89
910const MODEL_VERSION : u32 = 0 ;
1011
@@ -25,39 +26,20 @@ pub enum ChunkIdentifier {
2526}
2627
2728impl ChunkIdentifier {
28- pub fn try_from ( identifier : u32 ) -> Option < Self > {
29- use self :: ChunkIdentifier :: * ;
30-
31- match identifier {
32- 1 => Some ( SimpleVocab ) ,
33- 2 => Some ( NdArray ) ,
34- 3 => Some ( BucketSubwordVocab ) ,
35- 4 => Some ( QuantizedArray ) ,
36- 5 => Some ( Metadata ) ,
37- 6 => Some ( NdNorms ) ,
38- 7 => Some ( FastTextSubwordVocab ) ,
39- 8 => Some ( ExplicitSubwordVocab ) ,
40- _ => None ,
41- }
42- }
43-
4429 /// Read and ensure that the chunk has the given identifier.
4530 pub fn ensure_chunk_type < R > ( read : & mut R , identifier : ChunkIdentifier ) -> Result < ( ) >
4631 where
4732 R : Read ,
4833 {
4934 let chunk_id = read
5035 . read_u32 :: < LittleEndian > ( )
51- . map_err ( |e| ErrorKind :: io_error ( "Cannot read chunk identifier" , e) ) ?;
52- let chunk_id = ChunkIdentifier :: try_from ( chunk_id)
53- . ok_or_else ( || ErrorKind :: Format ( format ! ( "Unknown chunk identifier: {}" , chunk_id) ) )
54- . map_err ( Error :: from) ?;
36+ . map_err ( |e| Error :: io_error ( "Cannot read chunk identifier" , e) ) ?;
37+ let chunk_id = ChunkIdentifier :: try_from ( chunk_id) ?;
5538 if chunk_id != identifier {
56- return Err ( ErrorKind :: Format ( format ! (
39+ return Err ( Error :: Format ( format ! (
5740 "Invalid chunk identifier, expected: {}, got: {}" ,
5841 identifier, chunk_id
59- ) )
60- . into ( ) ) ;
42+ ) ) ) ;
6143 }
6244
6345 Ok ( ( ) )
@@ -82,6 +64,26 @@ impl Display for ChunkIdentifier {
8264 }
8365}
8466
67+ impl TryFrom < u32 > for ChunkIdentifier {
68+ type Error = Error ;
69+
70+ fn try_from ( identifier : u32 ) -> Result < Self > {
71+ use self :: ChunkIdentifier :: * ;
72+
73+ match identifier {
74+ 1 => Ok ( SimpleVocab ) ,
75+ 2 => Ok ( NdArray ) ,
76+ 3 => Ok ( BucketSubwordVocab ) ,
77+ 4 => Ok ( QuantizedArray ) ,
78+ 5 => Ok ( Metadata ) ,
79+ 6 => Ok ( NdNorms ) ,
80+ 7 => Ok ( FastTextSubwordVocab ) ,
81+ 8 => Ok ( ExplicitSubwordVocab ) ,
82+ unknown => Err ( Error :: UnknownChunkIdentifier ( unknown) ) ,
83+ }
84+ }
85+ }
86+
8587/// Trait defining identifiers for data types.
8688pub trait TypeId {
8789 /// Read and ensure that the data type is equal to `Self`.
@@ -102,14 +104,13 @@ macro_rules! typeid_impl {
102104 {
103105 let type_id = read
104106 . read_u32:: <LittleEndian >( )
105- . map_err( |e| ErrorKind :: io_error( "Cannot read type identifier" , e) ) ?;
107+ . map_err( |e| Error :: io_error( "Cannot read type identifier" , e) ) ?;
106108 if type_id != Self :: type_id( ) {
107- return Err ( ErrorKind :: Format ( format!(
109+ return Err ( Error :: Format ( format!(
108110 "Invalid type, expected: {}, got: {}" ,
109111 Self :: type_id( ) ,
110112 type_id
111- ) )
112- . into( ) ) ;
113+ ) ) ) ;
113114 }
114115
115116 Ok ( ( ) )
@@ -183,18 +184,18 @@ impl WriteChunk for Header {
183184 {
184185 write
185186 . write_all ( & MAGIC )
186- . map_err ( |e| ErrorKind :: io_error ( "Cannot write magic" , e) ) ?;
187+ . map_err ( |e| Error :: io_error ( "Cannot write magic" , e) ) ?;
187188 write
188189 . write_u32 :: < LittleEndian > ( MODEL_VERSION )
189- . map_err ( |e| ErrorKind :: io_error ( "Cannot write model version" , e) ) ?;
190+ . map_err ( |e| Error :: io_error ( "Cannot write model version" , e) ) ?;
190191 write
191192 . write_u32 :: < LittleEndian > ( self . chunk_identifiers . len ( ) as u32 )
192- . map_err ( |e| ErrorKind :: io_error ( "Cannot write chunk identifiers length" , e) ) ?;
193+ . map_err ( |e| Error :: io_error ( "Cannot write chunk identifiers length" , e) ) ?;
193194
194195 for & identifier in & self . chunk_identifiers {
195196 write
196197 . write_u32 :: < LittleEndian > ( identifier as u32 )
197- . map_err ( |e| ErrorKind :: io_error ( "Cannot write chunk identifier" , e) ) ?;
198+ . map_err ( |e| Error :: io_error ( "Cannot write chunk identifier" , e) ) ?;
198199 }
199200
200201 Ok ( ( ) )
@@ -209,40 +210,36 @@ impl ReadChunk for Header {
209210 // Magic and version ceremony.
210211 let mut magic = [ 0u8 ; 4 ] ;
211212 read. read_exact ( & mut magic)
212- . map_err ( |e| ErrorKind :: io_error ( "Cannot read magic" , e) ) ?;
213+ . map_err ( |e| Error :: io_error ( "Cannot read magic" , e) ) ?;
213214
214215 if magic != MAGIC {
215- return Err ( ErrorKind :: Format ( format ! (
216+ return Err ( Error :: Format ( format ! (
216217 "Expected 'FiFu' as magic, got: {}" ,
217218 String :: from_utf8_lossy( & magic) . into_owned( )
218- ) )
219- . into ( ) ) ;
219+ ) ) ) ;
220220 }
221221
222222 let version = read
223223 . read_u32 :: < LittleEndian > ( )
224- . map_err ( |e| ErrorKind :: io_error ( "Cannot read model version" , e) ) ?;
224+ . map_err ( |e| Error :: io_error ( "Cannot read model version" , e) ) ?;
225225 if version != MODEL_VERSION {
226- return Err (
227- ErrorKind :: Format ( format ! ( "Unknown finalfusion version: {}" , version) ) . into ( ) ,
228- ) ;
226+ return Err ( Error :: Format ( format ! (
227+ "Unknown finalfusion version: {}" ,
228+ version
229+ ) ) ) ;
229230 }
230231
231232 // Read chunk identifiers.
232233 let chunk_identifiers_len = read
233234 . read_u32 :: < LittleEndian > ( )
234- . map_err ( |e| ErrorKind :: io_error ( "Cannot read chunk identifiers length" , e) ) ?
235+ . map_err ( |e| Error :: io_error ( "Cannot read chunk identifiers length" , e) ) ?
235236 as usize ;
236237 let mut chunk_identifiers = Vec :: with_capacity ( chunk_identifiers_len) ;
237238 for _ in 0 ..chunk_identifiers_len {
238239 let identifier = read
239240 . read_u32 :: < LittleEndian > ( )
240- . map_err ( |e| ErrorKind :: io_error ( "Cannot read chunk identifier" , e) ) ?;
241- let chunk_identifier = ChunkIdentifier :: try_from ( identifier)
242- . ok_or_else ( || {
243- ErrorKind :: Format ( format ! ( "Unknown chunk identifier: {}" , identifier) )
244- } )
245- . map_err ( Error :: from) ?;
241+ . map_err ( |e| Error :: io_error ( "Cannot read chunk identifier" , e) ) ?;
242+ let chunk_identifier = ChunkIdentifier :: try_from ( identifier) ?;
246243 chunk_identifiers. push ( chunk_identifier) ;
247244 }
248245
0 commit comments