Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement RAG for code understanding (#450) #669

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@ orjson
gevent
gevent-websocket
curl_cffi
chromadb>=0.4.22
sentence-transformers>=2.2.2
22 changes: 16 additions & 6 deletions src/filesystem/read_code.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from typing import Dict, List

from src.config import Config
from src.memory.rag import CodeRAG

"""
TODO: Replace this with `code2prompt` - https://github.com/mufeedvh/code2prompt
Expand All @@ -9,27 +11,35 @@
class ReadCode:
def __init__(self, project_name: str):
config = Config()
project_path = config.get_projects_dir()
self.directory_path = os.path.join(project_path, project_name.lower().replace(" ", "-"))
self.project_name = project_name.lower().replace(" ", "-")
self.directory_path = os.path.join(config.get_projects_dir(), self.project_name)
self.rag = CodeRAG(project_name)

def read_directory(self):
def read_directory(self) -> List[Dict[str, str]]:
files_list = []
for root, _dirs, files in os.walk(self.directory_path):
for file in files:
try:
file_path = os.path.join(root, file)
with open(file_path, 'r') as file_content:
files_list.append({"filename": file_path, "code": file_content.read()})
code = file_content.read()
files_list.append({"filename": file_path, "code": code})
self.rag.add_code(file_path, code)
except:
pass

return files_list

def code_set_to_markdown(self):
def code_set_to_markdown(self) -> str:
code_set = self.read_directory()
markdown = ""
for code in code_set:
markdown += f"### {code['filename']}:\n\n"
summary = self.rag.summarize_code(code['code'])
if summary:
markdown += f"Summary: {summary}\n\n"
markdown += f"```\n{code['code']}\n```\n\n"
markdown += "---\n\n"
return markdown

def get_code_context(self, query: str, n_results: int = 5) -> Dict:
return self.rag.get_context(query, n_results)
109 changes: 108 additions & 1 deletion src/memory/rag.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,110 @@
"""
Vector Search for Code Docs + Docs Loading
"""
Implements RAG (Retrieval Augmented Generation) for code understanding
"""

import os
from typing import List, Dict, Optional
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions

from src.bert.sentence import SentenceBERT
from src.config import Config

class CodeRAG:
def __init__(self, project_name: str):
config = Config()
self.project_name = project_name.lower().replace(" ", "-")
self.db_path = os.path.join(config.get_projects_dir(), ".vector_db")
os.makedirs(self.db_path, exist_ok=True)

# Initialize ChromaDB with persistence
self.client = chromadb.PersistentClient(path=self.db_path)
self.sentence_transformer = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name="all-MiniLM-L6-v2"
)

# Get or create collection for this project
self.collection = self.client.get_or_create_collection(
name=f"code_{self.project_name}",
embedding_function=self.sentence_transformer
)

def chunk_code(self, code: str, chunk_size: int = 1000) -> List[str]:
"""Split code into smaller chunks while preserving context."""
chunks = []
lines = code.split('\n')
current_chunk = []
current_size = 0

for line in lines:
line_size = len(line)
if current_size + line_size > chunk_size and current_chunk:
chunks.append('\n'.join(current_chunk))
current_chunk = []
current_size = 0
current_chunk.append(line)
current_size += line_size

if current_chunk:
chunks.append('\n'.join(current_chunk))
return chunks

def add_code(self, filename: str, code: str):
"""Add code to the vector database with chunking."""
chunks = self.chunk_code(code)

# Generate unique IDs for chunks
chunk_ids = [f"{filename}_{i}" for i in range(len(chunks))]

# Add chunks to collection
self.collection.add(
documents=chunks,
ids=chunk_ids,
metadatas=[{"filename": filename, "chunk": i} for i in range(len(chunks))]
)

def query_similar(self, query: str, n_results: int = 5) -> List[Dict]:
"""Query the vector database for similar code chunks."""
results = self.collection.query(
query_texts=[query],
n_results=n_results
)

return [{
"text": doc,
"metadata": meta,
"distance": dist
} for doc, meta, dist in zip(
results["documents"][0],
results["metadatas"][0],
results["distances"][0]
)]

def summarize_code(self, code: str) -> str:
"""Extract key information from code using SentenceBERT."""
bert = SentenceBERT(code)
keywords = bert.extract_keywords(top_n=10)
return ", ".join([kw[0] for kw in keywords])

def get_context(self, query: str, n_results: int = 5) -> Dict:
"""Get relevant code context for a query."""
similar_chunks = self.query_similar(query, n_results)

context = {
"relevant_code": [],
"summary": [],
"files": set()
}

for chunk in similar_chunks:
context["relevant_code"].append({
"code": chunk["text"],
"file": chunk["metadata"]["filename"],
"relevance": 1 - chunk["distance"] # Convert distance to similarity score
})
context["files"].add(chunk["metadata"]["filename"])
context["summary"].append(self.summarize_code(chunk["text"]))

return context
30 changes: 30 additions & 0 deletions tests/test_code_rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import unittest
from src.memory.rag import CodeRAG

class TestCodeRAG(unittest.TestCase):
def setUp(self):
self.test_project = "test_project"
self.rag = CodeRAG(self.test_project)

# Test code sample
self.test_code = '''
def calculate_sum(a: int, b: int) -> int:
"""Calculate the sum of two integers."""
return a + b

def multiply_numbers(x: int, y: int) -> int:
"""Multiply two numbers together."""
return x * y
'''

def test_chunk_code(self):
"""Test code chunking functionality."""
chunks = self.rag.chunk_code(self.test_code, chunk_size=100)
self.assertTrue(len(chunks) > 0)
self.assertTrue(all(len(chunk) <= 100 for chunk in chunks))
# Verify function boundaries are preserved
self.assertTrue(any("calculate_sum" in chunk for chunk in chunks))
self.assertTrue(any("multiply_numbers" in chunk for chunk in chunks))

if __name__ == '__main__':
unittest.main()
59 changes: 59 additions & 0 deletions tests/test_read_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import shutil
import unittest
from src.filesystem.read_code import ReadCode
from src.config import Config

class TestReadCode(unittest.TestCase):
def setUp(self):
self.config = Config()
self.test_project = "test_project"
self.project_dir = os.path.join(self.config.get_projects_dir(), self.test_project)

# Create test project directory and files
os.makedirs(self.project_dir, exist_ok=True)
self.test_code = '''
def calculate_sum(a: int, b: int) -> int:
"""Calculate the sum of two integers."""
return a + b

def multiply_numbers(x: int, y: int) -> int:
"""Multiply two numbers together."""
return x * y
'''
with open(os.path.join(self.project_dir, "test.py"), "w") as f:
f.write(self.test_code)

self.reader = ReadCode(self.test_project)

def test_read_directory(self):
files = self.reader.read_directory()
self.assertEqual(len(files), 1)
self.assertTrue(any(f["filename"].endswith("test.py") for f in files))
self.assertTrue(any("calculate_sum" in f["code"] for f in files))

def test_code_set_to_markdown(self):
markdown = self.reader.code_set_to_markdown()
self.assertIn("test.py", markdown)
self.assertIn("```", markdown)
self.assertIn("calculate_sum", markdown)
self.assertIn("Summary:", markdown)

def test_get_code_context(self):
context = self.reader.get_code_context("How to multiply numbers?")
self.assertTrue("relevant_code" in context)
self.assertTrue("summary" in context)
self.assertTrue(any("multiply" in code["code"].lower()
for code in context["relevant_code"]))

def tearDown(self):
# Clean up test files
if os.path.exists(self.project_dir):
shutil.rmtree(self.project_dir)
# Clean up vector database
vector_db_path = os.path.join(self.config.get_projects_dir(), ".vector_db")
if os.path.exists(vector_db_path):
shutil.rmtree(vector_db_path)

if __name__ == '__main__':
unittest.main()