Skip to content

Commit 024dbf4

Browse files
fix: improve contextual chunk headings (#118)
1 parent d6e8422 commit 024dbf4

7 files changed

Lines changed: 102 additions & 23 deletions

File tree

Dockerfile

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,8 @@ RUN mkdir ~/.history/ && \
2626
echo 'HISTFILE=~/.history/.bash_history' >> ~/.bashrc && \
2727
echo 'bind "\"\e[A\": history-search-backward"' >> ~/.bashrc && \
2828
echo 'bind "\"\e[B\": history-search-forward"' >> ~/.bashrc && \
29-
echo 'eval "$(starship init bash)"' >> ~/.bashrc
29+
echo 'eval "$(starship init bash)"' >> ~/.bashrc
30+
31+
# Explicitly configure compilers for llama-cpp-python.
32+
ENV CMAKE_C_COMPILER=/usr/bin/gcc
33+
ENV CMAKE_CXX_COMPILER=/usr/bin/g++

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ dependencies = [
4444
# CLI:
4545
"typer (>=0.15.1)",
4646
# Model Context Protocol:
47-
"fastmcp (>=0.4.1)",
47+
"fastmcp (>=2.0.0)",
4848
# Utilities:
4949
"packaging (>=23.0)",
5050
]
@@ -82,7 +82,7 @@ dev = [
8282
# Frontend:
8383
chainlit = ["chainlit (>=2.0.0)"]
8484
# Large Language Models:
85-
llama-cpp-python = ["llama-cpp-python (>=0.3.3)"]
85+
llama-cpp-python = ["llama-cpp-python (>=0.3.4)"]
8686
# Markdown conversion:
8787
pandoc = ["pypandoc-binary (>=1.13)"]
8888
# Evaluation:

src/raglite/_chatml_function_calling.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def chatml_function_calling_with_streaming(
290290
# Assistant message
291291
"{% if message.role == 'assistant' %}"
292292
## Regular message
293-
"{% if message.content and message.content | length > 0 %}"
293+
"{% if 'content' in message and message.content %}"
294294
"{% if tool_calls %}"
295295
"message:\n"
296296
"{% endif %}"
@@ -310,7 +310,6 @@ def chatml_function_calling_with_streaming(
310310
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
311311
)
312312
template_renderer = ImmutableSandboxedEnvironment(
313-
autoescape=jinja2.select_autoescape(["html", "xml"]),
314313
undefined=jinja2.StrictUndefined,
315314
).from_string(function_calling_template)
316315

src/raglite/_database.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -119,26 +119,44 @@ def from_body(
119119
id=hash_bytes(f"{document_id}-{index}".encode()),
120120
document_id=document_id,
121121
index=index,
122-
headings=headings,
122+
headings=Chunk.truncate_headings(headings, body),
123123
body=body,
124124
metadata_=kwargs,
125125
)
126126

127-
def extract_headings(self) -> str:
128-
"""Extract Markdown headings from the chunk, starting from the current Markdown headings."""
127+
@staticmethod
def extract_heading_lines(doc: str, leading_only: bool = False) -> list[str]:  # noqa: FBT001,FBT002
    """Extract the leading or final state of the Markdown headings of a document.

    Args:
        doc: The Markdown text to scan.
        leading_only: When true, stop scanning at the first token outside a heading that carries
            non-whitespace content, so only the headings that open the document are reported.
            When false, the whole document is scanned and the last seen heading state is returned.

    Returns:
        A list of exactly six strings, one per heading level (index 0 holds level 1). Each entry
        is either a rendered heading line such as ``"## Title"`` or the empty string when no
        heading of that level is active.
    """
    max_level = 6
    md = MarkdownIt()
    heading_lines = [""] * max_level
    level = None
    for token in md.parse(doc):
        if token.type == "heading_open":
            # The tag looks like "h2"; ignore anything outside the supported range.
            tag_level = int(token.tag[1])
            level = tag_level if 1 <= tag_level <= max_level else None
        elif token.type == "heading_close":
            level = None
        elif level is not None:
            # A token between heading_open and heading_close carries the heading text.
            heading_content = token.content.strip().replace("\n", " ")
            heading_lines[level - 1] = ("#" * level) + " " + heading_content
            # A new heading invalidates every deeper level. Blank exactly as many entries as the
            # slice holds so the list keeps its fixed length and deeper levels stay indexable.
            heading_lines[level:] = [""] * (max_level - level)
        elif leading_only and token.content and not token.content.isspace():
            # First real content outside a heading: the leading run of headings has ended.
            break
    return heading_lines
145+
146+
@staticmethod
def truncate_headings(headings: str, body: str) -> str:
    """Drop contextual headings that are superseded by the headings the body itself opens with.

    Args:
        headings: The contextual Markdown headings that precede the chunk.
        body: The Markdown body of the chunk.

    Returns:
        The remaining contextual heading lines joined with newlines. If the body opens with a
        heading, every contextual heading at that level or deeper is removed.
    """
    context_lines = Chunk.extract_heading_lines(headings)
    body_lines = Chunk.extract_heading_lines(body, leading_only=True)
    # Find the 1-based level of the shallowest heading that opens the body, if any.
    shallowest = None
    for position, line in enumerate(body_lines, start=1):
        if line:
            shallowest = position
            break
    if shallowest:
        # Blank the contextual headings from that level downwards.
        cutoff = shallowest - 1
        context_lines[cutoff:] = [""] * len(context_lines[cutoff:])
    return "\n".join([line for line in context_lines if line])
156+
157+
def extract_headings(self) -> str:
    """Return the heading state at the end of the chunk as newline-joined Markdown headings.

    The contextual headings and the body are scanned as one document, separated by a blank
    line, so headings in the body override contextual headings of the same or deeper level.
    """
    combined = "\n\n".join((self.headings, self.body))
    active_lines = filter(None, self.extract_heading_lines(combined))
    return "\n".join(active_lines)
144162

src/raglite/_mcp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""MCP server for RAGLite."""
22

3-
from typing import Annotated
3+
from typing import Annotated, Any
44

55
from fastmcp import FastMCP
66
from pydantic import Field
@@ -20,9 +20,9 @@
2020
]
2121

2222

23-
def create_mcp_server(server_name: str, *, config: RAGLiteConfig) -> FastMCP:
23+
def create_mcp_server(server_name: str, *, config: RAGLiteConfig) -> FastMCP[Any]:
2424
"""Create a RAGLite MCP server."""
25-
mcp = FastMCP(server_name)
25+
mcp: FastMCP[Any] = FastMCP(server_name)
2626

2727
@mcp.prompt()
2828
def kb(query: Query) -> str:

tests/test_insert.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""Tests for the _insert module."""
2+
3+
from pathlib import Path
4+
5+
from sqlmodel import Session, select
6+
from tqdm import tqdm
7+
8+
from raglite._config import RAGLiteConfig
9+
from raglite._database import Chunk, Document, create_database_engine
10+
from raglite._markdown import document_to_markdown
11+
12+
13+
def test_insert(raglite_test_config: RAGLiteConfig) -> None:
    """Check chunk heading invariants and that the stored chunks reassemble the source document.

    The test reads the first document and its chunks from the database referenced by
    `raglite_test_config` (presumably populated by the fixture; verify against conftest).
    """
    engine = create_database_engine(raglite_test_config)
    with Session(engine) as session:
        # Take the first stored document as the document under test.
        document = session.exec(select(Document)).first()
        assert document is not None, "No document found in the database"

        # Fetch its chunks in index order so that their bodies can be concatenated.
        chunks = session.exec(
            select(Chunk).where(Chunk.document_id == document.id).order_by(Chunk.index)  # type: ignore[arg-type]
        ).all()
        assert len(chunks) > 0, "No chunks found for the document"

        for chunk in tqdm(chunks, desc="Processing chunks"):
            headings = chunk.headings.strip()
            body = chunk.body.strip()

            # The body should not contain the heading string (unless the heading is empty).
            if headings:
                assert headings not in body, (
                    f"Chunk body contains heading: '{headings}'\n"
                    f"Chunk body: '{body}'"
                )

            # A body that starts with a level-1 heading should not have contextual headings.
            if body.startswith("# "):
                assert headings == "", (
                    f"Chunk body starts with a heading: '{body}'\n"
                    f"Chunk headings: '{headings}'"
                )

        # Concatenating the chunk bodies should yield the original document. Newlines are
        # removed on both sides so the comparison ignores line-break differences.
        restored_document = "".join(chunk.body for chunk in chunks)
        restored_document = restored_document.replace("\n", "").strip()

        doc_path = Path(__file__).parent / "specrel.pdf"  # Einstein's special relativity paper.
        doc = document_to_markdown(doc_path)
        doc = doc.replace("\n", "").strip()

        assert restored_document == doc, "Restored document does not match the original input."

tests/test_search.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod)
4242
chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
4343
assert all(isinstance(chunk, Chunk) for chunk in chunks)
4444
assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
45-
assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks)
45+
assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks), (
46+
"Expected 'Definition of Simultaneity' in chunks but got:\n"
47+
+ "\n".join(f"- Chunk {i + 1}:\n{chunk!s}\n{'-' * 80}" for i, chunk in enumerate(chunks))
48+
)
4649
assert all(isinstance(chunk.document, Document) for chunk in chunks)
4750
# Extend the chunks with their neighbours and group them into contiguous segments.
4851
chunk_spans = retrieve_chunk_spans(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)

0 commit comments

Comments
 (0)