
Commit 1fa1db5

feat(native-composite-key): support multiple key parts in KTable (#921)
* feat(native-composite-key): support multiple key parts in KTable
* fix: postgres source key schema
* fix: count for multiple key correctly
* chore: adjust enum style, forward compatibility
* docs: revise docs for multiple key support
1 parent e0a2088 commit 1fa1db5

28 files changed (+565, -381 lines)

docs/docs/core/basics.md

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ Each piece of data has a **data type**, falling into one of the following catego

 * *Basic type*.
 * *Struct type*: a collection of **fields**, each with a name and a type.
-* *Table type*: a collection of **rows**, each of which is a struct with specified schema. A table type can be a *KTable* (which has a key field) or a *LTable* (ordered but without key field).
+* *Table type*: a collection of **rows**, each of which is a struct with specified schema. A table type can be a *KTable* (with key columns that uniquely identify each row) or a *LTable* (rows are ordered but without keys).

 An indexing flow always has a top-level struct, containing all data within and managed by the flow.

docs/docs/core/data_types.mdx

Lines changed: 16 additions & 10 deletions
@@ -148,21 +148,27 @@ We have two specific types of *Table* types: *KTable* and *LTable*.

 #### KTable

-*KTable* is a *Table* type whose first column serves as the key.
+*KTable* is a *Table* type whose one or more columns together serve as the key.
 The row order of a *KTable* is not preserved.
-Type of the first column (key column) must be a [key type](#key-types).
+Each key column must be a [key type](#key-types). When multiple key columns are present, they form a composite key.

-In Python, a *KTable* type is represented by `dict[K, V]`.
-The `K` should be the type binding to a key type,
-and the `V` should be the type binding to a *Struct* type representing the value fields of each row.
-When the specific type annotation is not provided,
-the key type is bound to a tuple with its key parts when it's a *Struct* type, the value type is bound to `dict[str, Any]`.
+In Python, a *KTable* type is represented by `dict[K, V]`.
+`K` represents the key and `V` represents the value for each row:
+
+- `K` can be a Struct type (either a frozen dataclass or a `NamedTuple`) that contains all key parts as fields. This is the general way to model multi-part keys.
+- When there is only a single key part and it is a basic type (e.g. `str`, `int`), you may use that basic type directly as the dictionary key instead of wrapping it in a Struct.
+- `V` should be the type bound to a *Struct* representing the non-key value fields of each row.
+
+When a specific type annotation is not provided:
+- For composite keys (multiple key parts), the key binds to a Python tuple of the key parts, e.g. `tuple[str, str]`.
+- For a single basic key part, the key binds to that basic Python type.
+- The value binds to `dict[str, Any]`.


 For example, you can use `dict[str, Person]` or `dict[str, PersonTuple]` to represent a *KTable*, with 4 columns: key (*Str*), `first_name` (*Str*), `last_name` (*Str*), `dob` (*Date*).
 It's bound to `dict[str, dict[str, Any]]` if you don't annotate the function argument with a specific type.

-Note that if you want to use a *Struct* as the key, you need to ensure its value in Python is immutable. For `dataclass`, annotate it with `@dataclass(frozen=True)`. For `NamedTuple`, immutability is built-in. For example:
+Note that when using a Struct as the key, it must be immutable in Python. For a dataclass, annotate it with `@dataclass(frozen=True)`. For `NamedTuple`, immutability is built-in. For example:

 ```python
 @dataclass(frozen=True)
@@ -175,8 +181,8 @@ class PersonKeyTuple(NamedTuple):
     id: str
 ```

-Then you can use `dict[PersonKey, Person]` or `dict[PersonKeyTuple, PersonTuple]` to represent a KTable keyed by `PersonKey` or `PersonKeyTuple`.
-It's bound to `dict[(str, str), dict[str, Any]]` if you don't annotate the function argument with a specific type.
+Then you can use `dict[PersonKey, Person]` or `dict[PersonKeyTuple, PersonTuple]` to represent a KTable keyed by both `id_kind` and `id`.
+If you don't annotate the function argument with a specific type, it's bound to `dict[tuple[str, str], dict[str, Any]]`.


 #### LTable
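Reviewer note: the revised docs above describe a composite-key KTable as `dict[K, V]` with a frozen Struct key. The following is a minimal, library-free sketch of that binding in plain Python. The `Person` and `PersonKey` field names follow the docs; the sample row values are hypothetical and nothing here calls the cocoindex API.

```python
from dataclasses import asdict, dataclass
from datetime import date
from typing import Any


@dataclass(frozen=True)
class PersonKey:
    """Composite key: `id_kind` and `id` together identify a row."""

    id_kind: str
    id: str


@dataclass
class Person:
    """Non-key value fields of each row."""

    first_name: str
    last_name: str
    dob: date


# Annotated form: a KTable keyed by both id_kind and id (hypothetical sample row).
people: dict[PersonKey, Person] = {
    PersonKey("passport", "X123"): Person("Ada", "Lovelace", date(1815, 12, 10)),
}

# frozen=True makes PersonKey hashable, so an equal key finds the same row.
assert people[PersonKey("passport", "X123")].first_name == "Ada"

# Unannotated form per the docs: a tuple of key parts and a plain dict value.
people_untyped: dict[tuple[str, str], dict[str, Any]] = {
    (key.id_kind, key.id): asdict(value) for key, value in people.items()
}
```

A non-frozen dataclass with the default `eq=True` is unhashable, which is why the docs require `frozen=True` for Struct keys.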

docs/docs/getting_started/quickstart.md

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ Notes:
 * `chunk`, representing each row of `chunks`.

 3. A *data source* extracts data from an external source.
-In this example, the `LocalFile` data source imports local files as a KTable (table with a key field, see [KTable](../core/data_types#ktable) for details), each row has `"filename"` and `"content"` fields.
+In this example, the `LocalFile` data source imports local files as a KTable (table with key columns, see [KTable](../core/data_types#ktable) for details), each row has `"filename"` and `"content"` fields.

 4. After defining the KTable, we extend a new field `"chunks"` to each row by *transforming* the `"content"` field using `SplitRecursively`. The output of the `SplitRecursively` is also a KTable representing each chunk of the document, with `"location"` and `"text"` fields.

examples/postgres_source/main.py

Lines changed: 4 additions & 4 deletions
@@ -97,8 +97,8 @@ def postgres_product_indexing_flow(
     with data_scope["products"].row() as product:
         product["full_description"] = flow_builder.transform(
             make_full_description,
-            product["_key"]["product_category"],
-            product["_key"]["product_name"],
+            product["product_category"],
+            product["product_name"],
             product["description"],
         )
         product["total_value"] = flow_builder.transform(
@@ -112,8 +112,8 @@ def postgres_product_indexing_flow(
             )
         )
         indexed_product.collect(
-            product_category=product["_key"]["product_category"],
-            product_name=product["_key"]["product_name"],
+            product_category=product["product_category"],
+            product_name=product["product_name"],
             description=product["description"],
             price=product["price"],
             amount=product["amount"],

python/cocoindex/convert.py

Lines changed: 46 additions & 15 deletions
@@ -14,7 +14,6 @@
 import numpy as np

 from .typing import (
-    KEY_FIELD_NAME,
     TABLE_TYPES,
     AnalyzedAnyType,
     AnalyzedBasicType,
@@ -96,14 +95,24 @@ def encode_struct_list(value: Any) -> Any:
                 f"Value type for dict is required to be a struct (e.g. dataclass or NamedTuple), got {variant.value_type}. "
                 f"If you want a free-formed dict, use `cocoindex.Json` instead."
             )
+        value_encoder = make_engine_value_encoder(value_type_info)

-        key_encoder = make_engine_value_encoder(analyze_type_info(variant.key_type))
-        value_encoder = make_engine_value_encoder(analyze_type_info(variant.value_type))
+        key_type_info = analyze_type_info(variant.key_type)
+        key_encoder = make_engine_value_encoder(key_type_info)
+        if isinstance(key_type_info.variant, AnalyzedBasicType):
+
+            def encode_row(k: Any, v: Any) -> Any:
+                return [key_encoder(k)] + value_encoder(v)
+
+        else:
+
+            def encode_row(k: Any, v: Any) -> Any:
+                return key_encoder(k) + value_encoder(v)

         def encode_struct_dict(value: Any) -> Any:
             if not value:
                 return []
-            return [[key_encoder(k)] + value_encoder(v) for k, v in value.items()]
+            return [encode_row(k, v) for k, v in value.items()]

         return encode_struct_dict

@@ -234,25 +243,47 @@ def decode(value: Any) -> Any | None:
                         f"declared `{dst_type_info.core_type}`, a dict type expected"
                     )

-                key_field_schema = engine_fields_schema[0]
-                field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
-                key_decoder = make_engine_value_decoder(
-                    field_path,
-                    key_field_schema["type"],
-                    analyze_type_info(key_type),
-                    for_key=True,
-                )
-                field_path.pop()
+                num_key_parts = src_type.get("num_key_parts", 1)
+                key_type_info = analyze_type_info(key_type)
+                key_decoder: Callable[..., Any] | None = None
+                if (
+                    isinstance(
+                        key_type_info.variant, (AnalyzedBasicType, AnalyzedAnyType)
+                    )
+                    and num_key_parts == 1
+                ):
+                    single_key_decoder = make_engine_value_decoder(
+                        field_path,
+                        engine_fields_schema[0]["type"],
+                        key_type_info,
+                        for_key=True,
+                    )
+
+                    def key_decoder(value: list[Any]) -> Any:
+                        return single_key_decoder(value[0])
+
+                else:
+                    key_decoder = make_engine_struct_decoder(
+                        field_path,
+                        engine_fields_schema[0:num_key_parts],
+                        key_type_info,
+                        for_key=True,
+                    )
                 value_decoder = make_engine_struct_decoder(
                     field_path,
-                    engine_fields_schema[1:],
+                    engine_fields_schema[num_key_parts:],
                     analyze_type_info(value_type),
                 )

                 def decode(value: Any) -> Any | None:
                     if value is None:
                         return None
-                    return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
+                    return {
+                        key_decoder(v[0:num_key_parts]): value_decoder(
+                            v[num_key_parts:]
+                        )
+                        for v in value
+                    }

                 return decode
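Reviewer note: the two hunks above change the row layout so that a row is the key parts followed by the value fields, and decoding splits each row at `num_key_parts`. The sketch below illustrates that layout in isolation. `encode_rows` and `decode_rows` are hypothetical helper names, not part of `convert.py`, and the sample product row is invented. It also simplifies the real decoder, which unwraps a single key part only when the declared key type is a basic or `Any` type.

```python
from typing import Any


def encode_rows(table: dict[Any, list[Any]], key_is_basic: bool) -> list[list[Any]]:
    """Flatten each (key, value) pair into one row: key parts first, then value fields."""
    rows = []
    for key, value_fields in table.items():
        # A basic key is wrapped into a one-element list; a composite key is
        # already a sequence of key parts.
        key_parts = [key] if key_is_basic else list(key)
        rows.append(key_parts + value_fields)
    return rows


def decode_rows(rows: list[list[Any]], num_key_parts: int = 1) -> dict[Any, list[Any]]:
    """Split each row back into key and value fields at index num_key_parts."""
    table = {}
    for row in rows:
        key_parts, value_fields = row[:num_key_parts], row[num_key_parts:]
        # Simplification: one key part is unwrapped, several become a tuple.
        key = key_parts[0] if num_key_parts == 1 else tuple(key_parts)
        table[key] = value_fields
    return table


# Hypothetical row keyed by (product_category, product_name).
composite = {("electronics", "laptop"): ["A thin laptop", 999.0]}
rows = encode_rows(composite, key_is_basic=False)
assert rows == [["electronics", "laptop", "A thin laptop", 999.0]]
assert decode_rows(rows, num_key_parts=2) == composite
```

Defaulting `num_key_parts` to 1 mirrors `src_type.get("num_key_parts", 1)` in the diff, so a schema without that field still decodes as a single-key table.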

python/cocoindex/typing.py

Lines changed: 37 additions & 22 deletions
@@ -330,35 +330,50 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:

 def _encode_struct_schema(
     struct_type: type, key_type: type | None = None
-) -> dict[str, Any]:
+) -> tuple[dict[str, Any], int | None]:
     fields = []

-    def add_field(name: str, t: Any) -> None:
+    def add_field(name: str, analyzed_type: AnalyzedTypeInfo) -> None:
         try:
-            type_info = encode_enriched_type_info(analyze_type_info(t))
+            type_info = encode_enriched_type_info(analyzed_type)
         except ValueError as e:
             e.add_note(
                 f"Failed to encode annotation for field - "
-                f"{struct_type.__name__}.{name}: {t}"
+                f"{struct_type.__name__}.{name}: {analyzed_type.core_type}"
             )
             raise
         type_info["name"] = name
         fields.append(type_info)

+    def add_fields_from_struct(struct_type: type) -> None:
+        if dataclasses.is_dataclass(struct_type):
+            for field in dataclasses.fields(struct_type):
+                add_field(field.name, analyze_type_info(field.type))
+        elif is_namedtuple_type(struct_type):
+            for name, field_type in struct_type.__annotations__.items():
+                add_field(name, analyze_type_info(field_type))
+        else:
+            raise ValueError(f"Unsupported struct type: {struct_type}")
+
+    result: dict[str, Any] = {}
+    num_key_parts = None
     if key_type is not None:
-        add_field(KEY_FIELD_NAME, key_type)
+        key_type_info = analyze_type_info(key_type)
+        if isinstance(key_type_info.variant, AnalyzedBasicType):
+            add_field(KEY_FIELD_NAME, key_type_info)
+            num_key_parts = 1
+        elif isinstance(key_type_info.variant, AnalyzedStructType):
+            add_fields_from_struct(key_type_info.variant.struct_type)
+            num_key_parts = len(fields)
+        else:
+            raise ValueError(f"Unsupported key type: {key_type}")

-    if dataclasses.is_dataclass(struct_type):
-        for field in dataclasses.fields(struct_type):
-            add_field(field.name, field.type)
-    elif is_namedtuple_type(struct_type):
-        for name, field_type in struct_type.__annotations__.items():
-            add_field(name, field_type)
+    add_fields_from_struct(struct_type)

-    result: dict[str, Any] = {"fields": fields}
+    result["fields"] = fields
     if doc := inspect.getdoc(struct_type):
         result["description"] = doc
-    return result
+    return result, num_key_parts


 def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
@@ -374,7 +389,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
         return {"kind": variant.kind}

     if isinstance(variant, AnalyzedStructType):
-        encoded_type = _encode_struct_schema(variant.struct_type)
+        encoded_type, _ = _encode_struct_schema(variant.struct_type)
         encoded_type["kind"] = "Struct"
         return encoded_type

@@ -384,10 +399,8 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
         if isinstance(elem_type_info.variant, AnalyzedStructType):
             if variant.vector_info is not None:
                 raise ValueError("LTable type must not have a vector info")
-            return {
-                "kind": "LTable",
-                "row": _encode_struct_schema(elem_type_info.variant.struct_type),
-            }
+            row_type, _ = _encode_struct_schema(elem_type_info.variant.struct_type)
+            return {"kind": "LTable", "row": row_type}
         else:
             vector_info = variant.vector_info
             return {
@@ -402,12 +415,14 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
             raise ValueError(
                 f"KTable value must have a Struct type, got {value_type_info.core_type}"
             )
+        row_type, num_key_parts = _encode_struct_schema(
+            value_type_info.variant.struct_type,
+            variant.key_type,
+        )
         return {
             "kind": "KTable",
-            "row": _encode_struct_schema(
-                value_type_info.variant.struct_type,
-                variant.key_type,
-            ),
+            "row": row_type,
+            "num_key_parts": num_key_parts,
         }

     if isinstance(variant, AnalyzedUnionType):
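Reviewer note: after this change the encoded KTable schema carries the key fields first in `row.fields`, plus a `num_key_parts` count. The sketch below shows the resulting shape for dataclass inputs only. `sketch_ktable_schema` is a hypothetical helper, field types are reduced to plain type names rather than the real enriched type info, and the value `"_key"` for `KEY_FIELD_NAME` is an assumption inferred from the `product["_key"]` accesses removed in the example above.

```python
import dataclasses
from typing import Any

KEY_FIELD_NAME = "_key"  # assumed value of the constant in typing.py


def sketch_ktable_schema(key_type: type, value_type: type) -> dict[str, Any]:
    """Build a simplified KTable schema: key fields first, then value fields."""
    fields: list[dict[str, str]] = []

    def add_struct_fields(struct_type: type) -> None:
        for f in dataclasses.fields(struct_type):
            # Simplification: record only the type's name.
            fields.append({"name": f.name, "type": getattr(f.type, "__name__", str(f.type))})

    if dataclasses.is_dataclass(key_type):
        # Struct key: each field of the key struct becomes one key part.
        add_struct_fields(key_type)
        num_key_parts = len(fields)
    else:
        # Basic key: a single key column under the reserved name.
        fields.append({"name": KEY_FIELD_NAME, "type": key_type.__name__})
        num_key_parts = 1

    add_struct_fields(value_type)
    return {"kind": "KTable", "row": {"fields": fields}, "num_key_parts": num_key_parts}


@dataclasses.dataclass(frozen=True)
class ProductKey:
    product_category: str
    product_name: str


@dataclasses.dataclass
class Product:
    description: str
    price: float


schema = sketch_ktable_schema(ProductKey, Product)
assert schema["num_key_parts"] == 2
assert [f["name"] for f in schema["row"]["fields"]][:2] == ["product_category", "product_name"]
```

Counting `len(fields)` right after the key fields are added, and before the value fields, is what lets the decoder later split each row at the same index.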
