Skip to content

Commit 52f6faf

Browse files
committed
Cleanup top-level modules
- All chunk modules become submodules of `chunk`. - Modules for reading foreign formats become submodules of `compat`.
1 parent a7fec83 commit 52f6faf

File tree

18 files changed

+338
-338
lines changed

18 files changed

+338
-338
lines changed

src/chunks/io.rs

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
use std::fmt::{self, Display};
2+
use std::fs::File;
3+
use std::io::{BufReader, Read, Seek, Write};
4+
5+
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
6+
7+
use crate::io::{Error, ErrorKind, Result};
8+
9+
const MODEL_VERSION: u32 = 0;
10+
11+
const MAGIC: [u8; 4] = [b'F', b'i', b'F', b'u'];
12+
13+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
14+
#[repr(u32)]
15+
pub enum ChunkIdentifier {
16+
Header = 0,
17+
SimpleVocab = 1,
18+
NdArray = 2,
19+
FinalfusionSubwordVocab = 3,
20+
QuantizedArray = 4,
21+
Metadata = 5,
22+
NdNorms = 6,
23+
FastTextSubwordVocab = 7,
24+
}
25+
26+
impl ChunkIdentifier {
27+
pub fn try_from(identifier: u32) -> Option<Self> {
28+
use self::ChunkIdentifier::*;
29+
30+
match identifier {
31+
1 => Some(SimpleVocab),
32+
2 => Some(NdArray),
33+
3 => Some(FinalfusionSubwordVocab),
34+
4 => Some(QuantizedArray),
35+
5 => Some(Metadata),
36+
6 => Some(NdNorms),
37+
7 => Some(FastTextSubwordVocab),
38+
_ => None,
39+
}
40+
}
41+
42+
/// Read and ensure that the chunk has the given identifier.
43+
pub fn ensure_chunk_type<R>(read: &mut R, identifier: ChunkIdentifier) -> Result<()>
44+
where
45+
R: Read,
46+
{
47+
let chunk_id = read
48+
.read_u32::<LittleEndian>()
49+
.map_err(|e| ErrorKind::io_error("Cannot read chunk identifier", e))?;
50+
let chunk_id = ChunkIdentifier::try_from(chunk_id)
51+
.ok_or_else(|| ErrorKind::Format(format!("Unknown chunk identifier: {}", chunk_id)))
52+
.map_err(Error::from)?;
53+
if chunk_id != identifier {
54+
return Err(ErrorKind::Format(format!(
55+
"Invalid chunk identifier, expected: {}, got: {}",
56+
identifier, chunk_id
57+
))
58+
.into());
59+
}
60+
61+
Ok(())
62+
}
63+
}
64+
65+
impl Display for ChunkIdentifier {
66+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67+
use self::ChunkIdentifier::*;
68+
69+
match self {
70+
Header => write!(f, "Header"),
71+
SimpleVocab => write!(f, "SimpleVocab"),
72+
NdArray => write!(f, "NdArray"),
73+
FastTextSubwordVocab => write!(f, "FastTextSubwordVocab"),
74+
FinalfusionSubwordVocab => write!(f, "FinalfusionSubwordVocab"),
75+
QuantizedArray => write!(f, "QuantizedArray"),
76+
Metadata => write!(f, "Metadata"),
77+
NdNorms => write!(f, "NdNorms"),
78+
}
79+
}
80+
}
81+
82+
pub trait TypeId {
83+
/// Read and ensure that the data type is equal to `Self`.
84+
fn ensure_data_type<R>(read: &mut R) -> Result<()>
85+
where
86+
R: Read;
87+
88+
fn type_id() -> u32;
89+
}
90+
91+
macro_rules! typeid_impl {
92+
($type:ty, $id:expr) => {
93+
impl TypeId for $type {
94+
fn ensure_data_type<R>(read: &mut R) -> Result<()>
95+
where
96+
R: Read,
97+
{
98+
let type_id = read
99+
.read_u32::<LittleEndian>()
100+
.map_err(|e| ErrorKind::io_error("Cannot read type identifier", e))?;
101+
if type_id != Self::type_id() {
102+
return Err(ErrorKind::Format(format!(
103+
"Invalid type, expected: {}, got: {}",
104+
Self::type_id(),
105+
type_id
106+
))
107+
.into());
108+
}
109+
110+
Ok(())
111+
}
112+
113+
fn type_id() -> u32 {
114+
$id
115+
}
116+
}
117+
};
118+
}
119+
120+
typeid_impl!(f32, 10);
121+
typeid_impl!(u8, 1);
122+
123+
pub trait ReadChunk
124+
where
125+
Self: Sized,
126+
{
127+
fn read_chunk<R>(read: &mut R) -> Result<Self>
128+
where
129+
R: Read + Seek;
130+
}
131+
132+
/// Memory-mappable chunks.
133+
pub trait MmapChunk
134+
where
135+
Self: Sized,
136+
{
137+
/// Memory map a chunk.
138+
///
139+
/// The given `File` object should be positioned at the start of the chunk.
140+
fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self>;
141+
}
142+
143+
pub trait WriteChunk {
144+
/// Get the identifier of a chunk.
145+
fn chunk_identifier(&self) -> ChunkIdentifier;
146+
147+
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
148+
where
149+
W: Write + Seek;
150+
}
151+
152+
#[derive(Debug, Eq, PartialEq)]
153+
pub(crate) struct Header {
154+
chunk_identifiers: Vec<ChunkIdentifier>,
155+
}
156+
157+
impl Header {
158+
pub fn new(chunk_identifiers: impl Into<Vec<ChunkIdentifier>>) -> Self {
159+
Header {
160+
chunk_identifiers: chunk_identifiers.into(),
161+
}
162+
}
163+
164+
pub fn chunk_identifiers(&self) -> &[ChunkIdentifier] {
165+
&self.chunk_identifiers
166+
}
167+
}
168+
169+
impl WriteChunk for Header {
170+
fn chunk_identifier(&self) -> ChunkIdentifier {
171+
ChunkIdentifier::Header
172+
}
173+
174+
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
175+
where
176+
W: Write + Seek,
177+
{
178+
write
179+
.write_all(&MAGIC)
180+
.map_err(|e| ErrorKind::io_error("Cannot write magic", e))?;
181+
write
182+
.write_u32::<LittleEndian>(MODEL_VERSION)
183+
.map_err(|e| ErrorKind::io_error("Cannot write model version", e))?;
184+
write
185+
.write_u32::<LittleEndian>(self.chunk_identifiers.len() as u32)
186+
.map_err(|e| ErrorKind::io_error("Cannot write chunk identifiers length", e))?;
187+
188+
for &identifier in &self.chunk_identifiers {
189+
write
190+
.write_u32::<LittleEndian>(identifier as u32)
191+
.map_err(|e| ErrorKind::io_error("Cannot write chunk identifier", e))?;
192+
}
193+
194+
Ok(())
195+
}
196+
}
197+
198+
impl ReadChunk for Header {
199+
fn read_chunk<R>(read: &mut R) -> Result<Self>
200+
where
201+
R: Read + Seek,
202+
{
203+
// Magic and version ceremony.
204+
let mut magic = [0u8; 4];
205+
read.read_exact(&mut magic)
206+
.map_err(|e| ErrorKind::io_error("Cannot read magic", e))?;
207+
208+
if magic != MAGIC {
209+
return Err(ErrorKind::Format(format!(
210+
"Expected 'FiFu' as magic, got: {}",
211+
String::from_utf8_lossy(&magic).into_owned()
212+
))
213+
.into());
214+
}
215+
216+
let version = read
217+
.read_u32::<LittleEndian>()
218+
.map_err(|e| ErrorKind::io_error("Cannot read model version", e))?;
219+
if version != MODEL_VERSION {
220+
return Err(
221+
ErrorKind::Format(format!("Unknown finalfusion version: {}", version)).into(),
222+
);
223+
}
224+
225+
// Read chunk identifiers.
226+
let chunk_identifiers_len = read
227+
.read_u32::<LittleEndian>()
228+
.map_err(|e| ErrorKind::io_error("Cannot read chunk identifiers length", e))?
229+
as usize;
230+
let mut chunk_identifiers = Vec::with_capacity(chunk_identifiers_len);
231+
for _ in 0..chunk_identifiers_len {
232+
let identifier = read
233+
.read_u32::<LittleEndian>()
234+
.map_err(|e| ErrorKind::io_error("Cannot read chunk identifier", e))?;
235+
let chunk_identifier = ChunkIdentifier::try_from(identifier)
236+
.ok_or_else(|| {
237+
ErrorKind::Format(format!("Unknown chunk identifier: {}", identifier))
238+
})
239+
.map_err(Error::from)?;
240+
chunk_identifiers.push(chunk_identifier);
241+
}
242+
243+
Ok(Header { chunk_identifiers })
244+
}
245+
}
246+
247+
#[cfg(test)]
248+
mod tests {
249+
use std::io::{Cursor, Seek, SeekFrom};
250+
251+
use super::{ChunkIdentifier, Header, ReadChunk, WriteChunk};
252+
253+
#[test]
254+
fn header_write_read_roundtrip() {
255+
let check_header =
256+
Header::new(vec![ChunkIdentifier::SimpleVocab, ChunkIdentifier::NdArray]);
257+
let mut cursor = Cursor::new(Vec::new());
258+
check_header.write_chunk(&mut cursor).unwrap();
259+
cursor.seek(SeekFrom::Start(0)).unwrap();
260+
let header = Header::read_chunk(&mut cursor).unwrap();
261+
assert_eq!(header, check_header);
262+
}
263+
}

src/metadata.rs renamed to src/chunks/metadata.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@ use std::io::{Read, Seek, Write};
55
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
66
use toml::Value;
77

8-
use crate::io::{
9-
private::{ChunkIdentifier, Header, ReadChunk, WriteChunk},
10-
Error, ErrorKind, ReadMetadata, Result,
11-
};
8+
use super::io::{ChunkIdentifier, Header, ReadChunk, WriteChunk};
9+
use crate::io::{Error, ErrorKind, ReadMetadata, Result};
1210

1311
/// Embeddings metadata.
1412
///
@@ -101,7 +99,7 @@ mod tests {
10199
use toml::toml;
102100

103101
use super::Metadata;
104-
use crate::io::private::{ReadChunk, WriteChunk};
102+
use crate::chunks::io::{ReadChunk, WriteChunk};
105103

106104
fn read_chunk_size(read: &mut impl Read) -> u64 {
107105
// Skip identifier.

src/chunks/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
//! finalfusion chunks
2+
3+
pub(crate) mod io;
4+
5+
pub mod metadata;
6+
7+
pub mod norms;
8+
9+
pub mod storage;
10+
11+
pub mod vocab;

src/norms.rs renamed to src/chunks/norms.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::mem::size_of;
66
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
77
use ndarray::Array1;
88

9-
use crate::io::private::{ChunkIdentifier, ReadChunk, TypeId, WriteChunk};
9+
use super::io::{ChunkIdentifier, ReadChunk, TypeId, WriteChunk};
1010
use crate::io::{ErrorKind, Result};
1111
use crate::util::padding;
1212

@@ -117,8 +117,8 @@ mod tests {
117117
use byteorder::{LittleEndian, ReadBytesExt};
118118
use ndarray::Array1;
119119

120-
use crate::io::private::{ReadChunk, WriteChunk};
121-
use crate::norms::NdNorms;
120+
use super::NdNorms;
121+
use crate::chunks::io::{ReadChunk, WriteChunk};
122122

123123
const LEN: usize = 100;
124124

src/storage.rs renamed to src/chunks/storage.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use rand::{FromEntropy, Rng};
1111
use rand_xorshift::XorShiftRng;
1212
use reductive::pq::{QuantizeVector, ReconstructVector, TrainPQ, PQ};
1313

14-
use crate::io::private::{ChunkIdentifier, MmapChunk, ReadChunk, TypeId, WriteChunk};
14+
use super::io::{ChunkIdentifier, MmapChunk, ReadChunk, TypeId, WriteChunk};
1515
use crate::io::{Error, ErrorKind, Result};
1616
use crate::util::padding;
1717

@@ -936,8 +936,8 @@ mod tests {
936936
use ndarray::Array2;
937937
use reductive::pq::PQ;
938938

939-
use crate::io::private::{ReadChunk, WriteChunk};
940-
use crate::storage::{NdArray, Quantize, QuantizedArray, StorageView};
939+
use crate::chunks::io::{ReadChunk, WriteChunk};
940+
use crate::chunks::storage::{NdArray, Quantize, QuantizedArray, StorageView};
941941

942942
const N_ROWS: usize = 100;
943943
const N_COLS: usize = 100;

src/vocab.rs renamed to src/chunks/vocab.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ use std::mem::size_of;
66

77
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
88

9-
use crate::fasttext::FastTextIndexer;
10-
use crate::io::private::{ChunkIdentifier, ReadChunk, WriteChunk};
9+
use super::io::{ChunkIdentifier, ReadChunk, WriteChunk};
10+
use crate::compat::fasttext::FastTextIndexer;
1111
use crate::io::{Error, ErrorKind, Result};
1212
use crate::subword::{BucketIndexer, FinalfusionHashIndexer, Indexer, SubwordIndices};
1313

@@ -521,8 +521,8 @@ mod tests {
521521
use byteorder::{LittleEndian, ReadBytesExt};
522522

523523
use super::{FastTextSubwordVocab, FinalfusionSubwordVocab, SimpleVocab, SubwordVocab};
524-
use crate::fasttext::FastTextIndexer;
525-
use crate::io::private::{ReadChunk, WriteChunk};
524+
use crate::chunks::io::{ReadChunk, WriteChunk};
525+
use crate::compat::fasttext::FastTextIndexer;
526526
use crate::subword::{BucketIndexer, FinalfusionHashIndexer};
527527

528528
fn test_fasttext_subword_vocab() -> FastTextSubwordVocab {
File renamed without changes.

src/fasttext/io.rs renamed to src/compat/fasttext/io.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ use ndarray::{s, Array2, ErrorKind as ShapeErrorKind, ShapeError};
55
use serde::Serialize;
66
use toml::Value;
77

8+
use crate::chunks::metadata::Metadata;
9+
use crate::chunks::norms::NdNorms;
10+
use crate::chunks::storage::{NdArray, Storage, StorageViewMut};
11+
use crate::chunks::vocab::{FastTextSubwordVocab, Vocab};
812
use crate::embeddings::Embeddings;
913
use crate::io::{Error, ErrorKind, Result};
10-
use crate::metadata::Metadata;
11-
use crate::norms::NdNorms;
12-
use crate::storage::{NdArray, Storage, StorageViewMut};
1314
use crate::subword::BucketIndexer;
1415
use crate::util::{l2_normalize_array, read_string};
15-
use crate::vocab::{FastTextSubwordVocab, Vocab};
1616

1717
use super::FastTextIndexer;
1818

File renamed without changes.

src/compat/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
//! Readers/writers for other embedding formats.
2+
3+
pub mod fasttext;
4+
5+
pub mod text;
6+
7+
pub mod word2vec;

0 commit comments

Comments
 (0)