
Commit 59f692f

Update SplitRecursively to take language and chunk sizes dynamically. (#124)
1 parent 0cc1da3 commit 59f692f

File tree: 8 files changed, +115 −86 lines

8 files changed

+115
-86
lines changed
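The user-facing change, as the example flows below show, is that chunking parameters move out of the SplitRecursively spec and into the transform() call, where they are supplied per invocation. A minimal before/after sketch (field names borrowed from the text_embedding example):

# Before: parameters were fixed in the function spec.
doc["chunks"] = doc["content"].transform(
    cocoindex.functions.SplitRecursively(
        language="markdown", chunk_size=300, chunk_overlap=100))

# After: the spec takes no parameters; language and chunk sizes are
# passed as transform() arguments, so they can vary per call.
doc["chunks"] = doc["content"].transform(
    cocoindex.functions.SplitRecursively(),
    language="markdown", chunk_size=300, chunk_overlap=100)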

examples/code_embedding/code_embedding.py

Lines changed: 2 additions & 2 deletions

@@ -21,8 +21,8 @@ def code_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
 
     with data_scope["files"].row() as file:
         file["chunks"] = file["content"].transform(
-            cocoindex.functions.SplitRecursively(
-                language="javascript", chunk_size=300, chunk_overlap=100))
+            cocoindex.functions.SplitRecursively(),
+            language="javascript", chunk_size=300, chunk_overlap=100)
         with file["chunks"].row() as chunk:
             chunk["embedding"] = chunk["text"].call(code_to_embedding)
             code_embeddings.collect(filename=file["filename"], location=chunk["location"],

examples/pdf_embedding/pdf_embedding.py

Lines changed: 2 additions & 2 deletions

@@ -50,8 +50,8 @@ def pdf_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoinde
     with data_scope["documents"].row() as doc:
         doc["markdown"] = doc["content"].transform(PdfToMarkdown())
         doc["chunks"] = doc["markdown"].transform(
-            cocoindex.functions.SplitRecursively(
-                language="markdown", chunk_size=300, chunk_overlap=100))
+            cocoindex.functions.SplitRecursively(),
+            language="markdown", chunk_size=300, chunk_overlap=100)
 
         with doc["chunks"].row() as chunk:
             chunk["embedding"] = chunk["text"].call(text_to_embedding)

examples/text_embedding/text_embedding.py

Lines changed: 2 additions & 6 deletions

@@ -23,12 +23,8 @@ def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
 
     with data_scope["documents"].row() as doc:
         doc["chunks"] = doc["content"].transform(
-            cocoindex.functions.SplitRecursively(
-                language="markdown", chunk_size=300, chunk_overlap=100))
-
-        doc["chunks"] = flow_builder.call(
-            cocoindex.functions.SplitRecursively(),
-            doc["content"], language="markdown", chunk_size=300, chunk_overlap=100);
+            cocoindex.functions.SplitRecursively(),
+            language="markdown", chunk_size=300, chunk_overlap=100)
 
         with doc["chunks"].row() as chunk:
             chunk["embedding"] = text_to_embedding(chunk["text"])

python/cocoindex/convert.py

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+"""
+Utilities to convert between Python and engine values.
+"""
+import dataclasses
+from typing import Any
+
+def to_engine_value(value: Any) -> Any:
+    """Convert a Python value to an engine value."""
+    if dataclasses.is_dataclass(value):
+        return [to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
+    if isinstance(value, (list, tuple)):
+        return [to_engine_value(v) for v in value]
+    return value
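For illustration, a short sketch of what to_engine_value does; the Point dataclass is a hypothetical example type, not part of this commit:

import dataclasses
from cocoindex.convert import to_engine_value

@dataclasses.dataclass
class Point:  # hypothetical example type
    x: int
    y: int

# Dataclasses are flattened into lists of field values, lists/tuples are
# converted element-wise, and anything else passes through unchanged.
assert to_engine_value(Point(1, 2)) == [1, 2]
assert to_engine_value([Point(1, 2), (3, 4)]) == [[1, 2], [3, 4]]
assert to_engine_value("hello") == "hello"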

python/cocoindex/flow.py

Lines changed: 16 additions & 6 deletions

@@ -162,24 +162,27 @@ def for_each(self, f: Callable[[DataScope], None]) -> None:
         with self.row() as scope:
             f(scope)
 
-    def transform(self, fn_spec: op.FunctionSpec, /, name: str | None = None) -> DataSlice:
+    def transform(self, fn_spec: op.FunctionSpec, *args, **kwargs) -> DataSlice:
         """
         Apply a function to the data slice.
         """
-        args = [(self._state.engine_data_slice, None)]
+        transform_args = [(self._state.engine_data_slice, None)]
+        transform_args += [(self._state.flow_builder_state.get_data_slice(v), None) for v in args]
+        transform_args += [(self._state.flow_builder_state.get_data_slice(v), k)
+                           for (k, v) in kwargs.items()]
+
         flow_builder_state = self._state.flow_builder_state
         return _create_data_slice(
             flow_builder_state,
             lambda target_scope, name:
             flow_builder_state.engine_flow_builder.transform(
                 _spec_kind(fn_spec),
                 _spec_value_dump(fn_spec),
-                args,
+                transform_args,
                 target_scope,
                 flow_builder_state.field_name_builder.build_name(
                     name, prefix=_to_snake_case(_spec_kind(fn_spec))+'_'),
-            ),
-            name)
+            ))
 
     def call(self, func: Callable[[DataSlice], T]) -> T:
         """

@@ -282,6 +285,14 @@ def __init__(self, /, name: str | None = None):
         self.engine_flow_builder = _engine.FlowBuilder(flow_name)
         self.field_name_builder = _NameBuilder()
 
+    def get_data_slice(self, v: Any) -> _engine.DataSlice:
+        """
+        Return a data slice that represents the given value.
+        """
+        if isinstance(v, DataSlice):
+            return v._state.engine_data_slice
+        return self.engine_flow_builder.constant(encode_enriched_type(type(v)), v)
+
 class FlowBuilder:
     """
     A flow builder is used to build a flow.

@@ -313,7 +324,6 @@ def add_source(self, spec: op.SourceSpec, /, name: str | None = None) -> DataSli
             name
         )
 
-
 class Flow:
     """
     A flow describes an indexing pipeline.
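With this change, transform() accepts extra positional and keyword arguments and resolves each one through get_data_slice: a DataSlice is passed through as its engine data slice, while any plain Python value is wrapped as an engine constant with its encoded type. A hedged sketch of what that enables (the file["language"] field is hypothetical, used only to show a DataSlice argument):

file["chunks"] = file["content"].transform(
    cocoindex.functions.SplitRecursively(),
    chunk_size=300,              # plain value -> engine constant
    language=file["language"])   # DataSlice -> its engine data slice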

python/cocoindex/functions.py

Lines changed: 0 additions & 3 deletions

@@ -7,9 +7,6 @@
 
 class SplitRecursively(op.FunctionSpec):
     """Split a document (in string) recursively."""
-    chunk_size: int
-    chunk_overlap: int
-    language: str | None = None
 
 class ExtractByLlm(op.FunctionSpec):
     """Extract information from a text using a LLM."""

python/cocoindex/op.py

Lines changed: 2 additions & 9 deletions

@@ -9,6 +9,7 @@
 from threading import Lock
 
 from .typing import encode_enriched_type, analyze_type_info, COLLECTION_TYPES
+from .convert import to_engine_value
 from . import _engine
 
 

@@ -59,14 +60,6 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
         result_type = executor.analyze(*args, **kwargs)
         return (encode_enriched_type(result_type), executor)
 
-def _to_engine_value(value: Any) -> Any:
-    """Convert a Python value to an engine value."""
-    if dataclasses.is_dataclass(value):
-        return [_to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
-    if isinstance(value, (list, tuple)):
-        return [_to_engine_value(v) for v in value]
-    return value
-
 def _make_engine_struct_value_converter(
     field_path: list[str],
     src_fields: list[dict[str, Any]],

@@ -251,7 +244,7 @@ def __call__(self, *args, **kwargs):
             output = super().__call__(*converted_args, **converted_kwargs)
         else:
             output = super().__call__(*converted_args, **converted_kwargs)
-        return _to_engine_value(output)
+        return to_engine_value(output)
 
     _WrappedClass.__name__ = cls.__name__
 

src/ops/functions/split_recursively.rs

Lines changed: 78 additions & 58 deletions

@@ -5,19 +5,13 @@ use std::{collections::HashMap, sync::Arc};
 use crate::base::field_attrs;
 use crate::{fields_value, ops::sdk::*};
 
-#[derive(Debug, Deserialize)]
-pub struct Spec {
-    #[serde(default)]
-    language: Option<String>,
-
-    chunk_size: usize,
-
-    #[serde(default)]
-    chunk_overlap: usize,
-}
+type Spec = EmptySpec;
 
 pub struct Args {
     text: ResolvedOpArg,
+    chunk_size: ResolvedOpArg,
+    chunk_overlap: Option<ResolvedOpArg>,
+    language: Option<ResolvedOpArg>,
 }
 
 static DEFAULT_SEPARATORS: LazyLock<Vec<Regex>> = LazyLock::new(|| {

@@ -97,36 +91,13 @@ static SEPARATORS_BY_LANG: LazyLock<HashMap<&'static str, Vec<Regex>>> = LazyLoc
         .collect()
 });
 
-struct Executor {
-    spec: Spec,
-    args: Args,
+struct SplitTask {
     separators: &'static [Regex],
+    chunk_size: usize,
+    chunk_overlap: usize,
 }
 
-impl Executor {
-    fn new(spec: Spec, args: Args) -> Result<Self> {
-        let separators = spec
-            .language
-            .as_ref()
-            .and_then(|lang| {
-                SEPARATORS_BY_LANG
-                    .get(lang.to_lowercase().as_str())
-                    .map(|v| v.as_slice())
-            })
-            .unwrap_or(DEFAULT_SEPARATORS.as_slice());
-        Ok(Self {
-            spec,
-            args,
-            separators,
-        })
-    }
-
-    fn add_output<'s>(pos: usize, text: &'s str, output: &mut Vec<(RangeValue, &'s str)>) {
-        if !text.trim().is_empty() {
-            output.push((RangeValue::new(pos, pos + text.len()), text));
-        }
-    }
-
+impl SplitTask {
     fn split_substring<'s>(
         &self,
         s: &'s str,

@@ -135,7 +106,7 @@ impl Executor {
         output: &mut Vec<(RangeValue, &'s str)>,
     ) {
         if next_sep_id >= self.separators.len() {
-            Self::add_output(base_pos, s, output);
+            self.add_output(base_pos, s, output);
             return;
         }
 

@@ -147,17 +118,17 @@ impl Executor {
             let mut start_pos = chunks[0].start;
             for i in 1..chunks.len() - 1 {
                 let chunk = &chunks[i];
-                if chunk.end - start_pos > self.spec.chunk_size {
-                    Self::add_output(base_pos + start_pos, &s[start_pos..chunk.end], output);
+                if chunk.end - start_pos > self.chunk_size {
+                    self.add_output(base_pos + start_pos, &s[start_pos..chunk.end], output);
 
                     // Find the new start position, allowing overlap within the threshold.
                     let mut new_start_idx = i + 1;
                     let next_chunk = &chunks[i + 1];
                     while new_start_idx > 0 {
                         let prev_pos = chunks[new_start_idx - 1].start;
                         if prev_pos <= start_pos
-                            || chunk.end - prev_pos > self.spec.chunk_overlap
-                            || next_chunk.end - prev_pos > self.spec.chunk_size
+                            || chunk.end - prev_pos > self.chunk_overlap
+                            || next_chunk.end - prev_pos > self.chunk_size
                         {
                             break;
                         }

@@ -168,32 +139,49 @@ impl Executor {
             }
 
             let last_chunk = &chunks[chunks.len() - 1];
-            Self::add_output(base_pos + start_pos, &s[start_pos..last_chunk.end], output);
+            self.add_output(base_pos + start_pos, &s[start_pos..last_chunk.end], output);
         };
 
         let mut small_chunks = Vec::new();
-        let mut process_chunk = |start: usize, end: usize| {
-            let chunk = &s[start..end];
-            if chunk.len() <= self.spec.chunk_size {
-                small_chunks.push(RangeValue::new(start, start + chunk.len()));
-            } else {
-                flush_small_chunks(&small_chunks, output);
-                small_chunks.clear();
-                self.split_substring(chunk, base_pos + start, next_sep_id + 1, output);
-            }
-        };
+        let mut process_chunk =
+            |start: usize, end: usize, output: &mut Vec<(RangeValue, &'s str)>| {
+                let chunk = &s[start..end];
+                if chunk.len() <= self.chunk_size {
+                    small_chunks.push(RangeValue::new(start, start + chunk.len()));
+                } else {
+                    flush_small_chunks(&small_chunks, output);
+                    small_chunks.clear();
+                    self.split_substring(chunk, base_pos + start, next_sep_id + 1, output);
+                }
+            };
 
         let mut next_start_pos = 0;
         for cap in self.separators[next_sep_id].find_iter(s) {
-            process_chunk(next_start_pos, cap.start());
+            process_chunk(next_start_pos, cap.start(), output);
             next_start_pos = cap.end();
         }
         if next_start_pos < s.len() {
-            process_chunk(next_start_pos, s.len());
+            process_chunk(next_start_pos, s.len(), output);
         }
 
         flush_small_chunks(&small_chunks, output);
     }
+
+    fn add_output<'s>(&self, pos: usize, text: &'s str, output: &mut Vec<(RangeValue, &'s str)>) {
+        if !text.trim().is_empty() {
+            output.push((RangeValue::new(pos, pos + text.len()), text));
+        }
+    }
+}
+
+struct Executor {
+    args: Args,
+}
+
+impl Executor {
+    fn new(args: Args) -> Result<Self> {
+        Ok(Self { args })
+    }
 }
 
 fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator<Item = &'a mut usize>) {

@@ -229,9 +217,32 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator<Item = &'a mu
 #[async_trait]
 impl SimpleFunctionExecutor for Executor {
     async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
+        let task = SplitTask {
+            separators: self
+                .args
+                .language
+                .value(&input)?
+                .map(|v| v.as_str())
+                .transpose()?
+                .and_then(|lang| {
+                    SEPARATORS_BY_LANG
+                        .get(lang.to_lowercase().as_str())
+                        .map(|v| v.as_slice())
+                })
+                .unwrap_or(DEFAULT_SEPARATORS.as_slice()),
+            chunk_size: self.args.chunk_size.value(&input)?.as_int64()? as usize,
+            chunk_overlap: self
+                .args
+                .chunk_overlap
+                .value(&input)?
+                .map(|v| v.as_int64())
+                .transpose()?
+                .unwrap_or(0) as usize,
+        };
+
         let text = self.args.text.value(&input)?.as_str()?;
         let mut output = Vec::new();
-        self.split_substring(text, 0, 0, &mut output);
+        task.split_substring(text, 0, 0, &mut output);
 
         translate_bytes_to_chars(
             text,

@@ -271,6 +282,15 @@ impl SimpleFunctionFactoryBase for Factory {
             text: args_resolver
                 .next_arg("text")?
                 .expect_type(&ValueType::Basic(BasicValueType::Str))?,
+            chunk_size: args_resolver
+                .next_arg("chunk_size")?
+                .expect_type(&ValueType::Basic(BasicValueType::Int64))?,
+            chunk_overlap: args_resolver
+                .next_optional_arg("chunk_overlap")?
+                .expect_type(&ValueType::Basic(BasicValueType::Int64))?,
+            language: args_resolver
+                .next_optional_arg("language")?
+                .expect_type(&ValueType::Basic(BasicValueType::Str))?,
         };
         let output_schema = make_output_type(CollectionSchema::new(
             CollectionKind::Table,

@@ -288,10 +308,10 @@ impl SimpleFunctionFactoryBase for Factory {
 
     async fn build_executor(
         self: Arc<Self>,
-        spec: Spec,
+        _spec: Spec,
         args: Args,
         _context: Arc<FlowInstanceContext>,
     ) -> Result<Box<dyn SimpleFunctionExecutor>> {
-        Ok(Box::new(Executor::new(spec, args)?))
+        Ok(Box::new(Executor::new(args)?))
     }
 }
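On the Rust side, chunk_size is resolved with next_arg (required), while chunk_overlap and language use next_optional_arg, falling back to an overlap of 0 and the default separators when absent. So the optional arguments can simply be omitted from the Python call; a minimal sketch:

# Only chunk_size is required; chunk_overlap defaults to 0 and language
# falls back to the generic separators.
doc["chunks"] = doc["content"].transform(
    cocoindex.functions.SplitRecursively(),
    chunk_size=500)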
