"""
|
|
RAG-based context selection using local embeddings.
|
|
Optimized with in-memory embedding cache to avoid recomputation.
|
|
"""
|
|
|
|
import numpy as np
|
|
from sentence_transformers import SentenceTransformer
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from typing import List, Dict, Optional
|
|
import asyncio
|
|
from threading import Lock
|
|
|
|
# Small, fast model - ~100MB, runs on CPU
MODEL_NAME = "all-MiniLM-L6-v2"

# Lazily-loaded shared SentenceTransformer; populated by get_model() on first use.
_model: Optional[SentenceTransformer] = None

# In-memory embedding cache.
# "skills" / "snippets" map row id -> {"embedding": vector, "skill"/"snippet": row};
# "initialized" flips to True once ensure_cache_initialized() has loaded everything.
_embedding_cache: Dict[str, Dict] = {
    "skills": {},
    "snippets": {},
    "initialized": False
}
# Guards cache initialization and reset.
# NOTE(review): this is a synchronous threading.Lock used from async code —
# confirm it is never held across an await (blocking the event loop).
_cache_lock = Lock()
|
|
|
|
|
|
def get_model() -> SentenceTransformer:
    """Return the shared SentenceTransformer, loading it on first call."""
    global _model
    if _model is not None:
        return _model
    _model = SentenceTransformer(MODEL_NAME)
    return _model
|
|
|
|
|
|
def embed_text(text: str) -> np.ndarray:
    """Encode *text* into a unit-length embedding vector."""
    return get_model().encode(text, normalize_embeddings=True)
|
|
|
|
|
|
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Compute cosine similarity between two vectors.

    The original implementation returned a bare dot product, which is only
    a cosine similarity when both inputs are already unit-normalized (true
    for embed_text() output, but not for arbitrary callers). Dividing by
    the norms makes the function correct for any input while producing
    identical results for normalized vectors.

    Returns:
        Cosine similarity in [-1.0, 1.0], or 0.0 if either vector has
        zero norm (similarity is undefined there; 0.0 is a safe neutral).
    """
    denom = float(np.linalg.norm(a)) * float(np.linalg.norm(b))
    if denom == 0.0:
        return 0.0
    return float(np.dot(a, b) / denom)
|
|
|
|
|
|
def _get_skill_text(skill) -> str:
|
|
"""Generate searchable text for a skill"""
|
|
return f"{skill.name} {skill.description or ''} {skill.content[:500]}"
|
|
|
|
|
|
def _get_snippet_text(snippet) -> str:
|
|
"""Generate searchable text for a snippet"""
|
|
return f"{snippet.name} {snippet.content}"
|
|
|
|
|
|
async def ensure_cache_initialized(db: AsyncSession):
    """Populate the in-memory embedding cache from the database (idempotent).

    Loads every Skill and Snippet row, embeds its searchable text, and
    stores (embedding, row) pairs in _embedding_cache. Once "initialized"
    is set, subsequent calls return immediately.

    Fixes two defects in the original:
      1. ``Skill`` and ``Snippet`` were referenced without being imported
         (the sibling selectors import them locally from ``models``; this
         function never did), causing a NameError on first use.
      2. The synchronous ``threading.Lock`` was held across ``await``
         points (DB I/O) and model inference, blocking the event loop.
         Now the lock only guards the check and the final publish.

    Note: two concurrent first calls may both build the cache; the
    double-check before publishing means only one writer wins, and both
    would have computed identical data, so this is harmless.
    """
    from models import Skill, Snippet  # local import, consistent with the other selectors

    # Fast path: already initialized. Lock held only for the flag check.
    with _cache_lock:
        if _embedding_cache["initialized"]:
            return

    # Build embeddings without holding the lock (DB I/O + model inference).
    skills_cache: Dict = {}
    result = await db.execute(select(Skill))
    for skill in result.scalars().all():
        skills_cache[skill.id] = {
            "embedding": embed_text(_get_skill_text(skill)),
            "skill": skill,
        }

    snippets_cache: Dict = {}
    result = await db.execute(select(Snippet))
    for snippet in result.scalars().all():
        snippets_cache[snippet.id] = {
            "embedding": embed_text(_get_snippet_text(snippet)),
            "snippet": snippet,
        }

    # Publish atomically; skip if another caller finished first.
    with _cache_lock:
        if not _embedding_cache["initialized"]:
            _embedding_cache["skills"].update(skills_cache)
            _embedding_cache["snippets"].update(snippets_cache)
            _embedding_cache["initialized"] = True
|
|
|
|
|
|
async def select_relevant_skills(
    query: str,
    db: AsyncSession,
    top_k: int = 3,
    min_score: float = 0.3
) -> List[Dict]:
    """Find the most relevant skills for *query* using semantic search.

    Uses cached embeddings for fast retrieval after the initial load.
    (The original had an unused local ``from models import Skill`` — the
    function only reads the cache — which has been removed.)

    Args:
        query: Free-text description of what the caller needs.
        db: Async session, used only to lazily populate the cache.
        top_k: Maximum number of skills to return.
        min_score: Minimum cosine similarity for a skill to qualify.

    Returns:
        Up to ``top_k`` dicts with id, name, content, and relevance_score,
        sorted by descending relevance; empty list if nothing is cached.
    """
    await ensure_cache_initialized(db)

    if not _embedding_cache["skills"]:
        return []

    query_embedding = embed_text(query)

    # Score every cached skill against the query; keep those above threshold.
    scored = []
    for cached in _embedding_cache["skills"].values():
        score = cosine_similarity(query_embedding, cached["embedding"])
        if score >= min_score:
            scored.append((score, cached["skill"]))

    # Sort by relevance, take top K
    scored.sort(key=lambda pair: pair[0], reverse=True)

    return [
        {
            "id": skill.id,
            "name": skill.name,
            "content": skill.content,
            "relevance_score": score,
        }
        for score, skill in scored[:top_k]
    ]
|
|
|
|
|
|
async def select_relevant_conventions(
    project_path: str,
    db: AsyncSession,
    top_k: int = 2
) -> List[Dict]:
    """Get conventions for a project.

    Exact match on project_path first; falls back to the immediate parent
    path. (No embeddings — small dataset, exact matching is fine.)

    Fix: the parent-path query now orders by ``auto_inject.desc()`` like
    the exact-match query does, so the ``top_k`` truncation is consistent
    between the two paths instead of depending on arbitrary row order.

    Args:
        project_path: Filesystem-style path identifying the project.
        db: Async session used for the lookups.
        top_k: Maximum number of conventions to return.

    Returns:
        Up to ``top_k`` dicts with id, name, and content; empty list if
        neither the path nor its parent has conventions.
    """
    from models import Convention

    result = await db.execute(
        select(Convention)
        .where(Convention.project_path == project_path)
        .order_by(Convention.auto_inject.desc())
    )
    exact_matches = result.scalars().all()

    if exact_matches:
        return [
            {"id": c.id, "name": c.name, "content": c.content}
            for c in exact_matches[:top_k]
        ]

    # Try parent path match
    parent_path = "/".join(project_path.split("/")[:-1])
    if parent_path:
        result = await db.execute(
            select(Convention)
            .where(Convention.project_path == parent_path)
            .order_by(Convention.auto_inject.desc())
        )
        parent_matches = result.scalars().all()
        return [
            {"id": c.id, "name": c.name, "content": c.content}
            for c in parent_matches[:top_k]
        ]

    return []
|
|
|
|
|
|
async def select_relevant_snippets(
    query: str,
    db: AsyncSession,
    top_k: int = 2,
    language: Optional[str] = None,
    min_score: float = 0.25
) -> List[Dict]:
    """Find relevant code snippets using cached embeddings.

    Changes from the original: the hard-coded 0.25 relevance threshold is
    now a ``min_score`` parameter (default unchanged — snippets use a
    lower cutoff than skills), matching select_relevant_skills; and the
    unused local ``from models import Snippet`` was removed.

    Args:
        query: Free-text description of what the caller needs.
        db: Async session, used only to lazily populate the cache.
        top_k: Maximum number of snippets to return.
        language: If given, only snippets whose ``language`` matches.
        min_score: Minimum cosine similarity for a snippet to qualify.

    Returns:
        Up to ``top_k`` dicts with id, name, language, content, and
        relevance_score, sorted by descending relevance.
    """
    await ensure_cache_initialized(db)

    if not _embedding_cache["snippets"]:
        return []

    query_embedding = embed_text(query)

    scored = []
    for cached in _embedding_cache["snippets"].values():
        snippet = cached["snippet"]
        # Language filter is applied before scoring to skip useless work.
        if language and snippet.language != language:
            continue

        score = cosine_similarity(query_embedding, cached["embedding"])
        if score >= min_score:
            scored.append((score, snippet))

    scored.sort(key=lambda pair: pair[0], reverse=True)

    return [
        {
            "id": s.id,
            "name": s.name,
            "language": s.language,
            "content": s.content,
            "relevance_score": score,
        }
        for score, s in scored[:top_k]
    ]
|
|
|
|
|
|
async def build_context_bundle(
    query: str,
    db: AsyncSession,
    project: Optional[str] = None,
    max_skills: int = 3,
    max_conventions: int = 2,
    max_snippets: int = 2
) -> Dict:
    """Build an optimized context bundle with only relevant items.

    This is the main RAG entry point: it gathers relevant skills,
    project conventions (when a project path is given), and code
    snippets, then estimates the total token footprint.

    Fix: the original used ``asyncio.coroutine(lambda: [])()`` as the
    "no project" placeholder — ``asyncio.coroutine`` was deprecated in
    Python 3.8 and removed in 3.11, so this crashed on modern runtimes.
    A plain local ``async def`` serves the same purpose.

    Args:
        query: Free-text description of the task at hand.
        db: Async session passed through to the selectors.
        project: Optional project path for convention lookup.
        max_skills / max_conventions / max_snippets: Per-category caps.

    Returns:
        Dict with "skills", "conventions", "snippets", the estimated
        token count of all included content, and the total item count.
    """
    async def _no_conventions() -> List[Dict]:
        # Awaitable stand-in so asyncio.gather always receives a coroutine.
        return []

    skills, conventions, snippets = await asyncio.gather(
        select_relevant_skills(query, db, top_k=max_skills),
        select_relevant_conventions(project, db, top_k=max_conventions)
        if project
        else _no_conventions(),
        select_relevant_snippets(query, db, top_k=max_snippets),
    )

    # Calculate total tokens across everything we are about to return.
    total_content = "\n".join(
        [s["content"] for s in skills] +
        [c["content"] for c in conventions] +
        [s["content"] for s in snippets]
    )

    from compression import count_tokens
    token_count = count_tokens(total_content)

    return {
        "skills": skills,
        "conventions": conventions,
        "snippets": snippets,
        "estimated_tokens": token_count,
        "items_included": len(skills) + len(conventions) + len(snippets),
    }
|
|
|
|
|
|
def clear_cache():
    """Reset the embedding cache to its empty state (useful for testing).

    Fixes two issues with the original:
      * it rebound the module global without taking ``_cache_lock``,
        racing a concurrent ``ensure_cache_initialized``;
      * rebinding left any code holding a reference to the old dict
        looking at stale state. Mutating in place avoids both.
    """
    with _cache_lock:
        _embedding_cache["skills"] = {}
        _embedding_cache["snippets"] = {}
        _embedding_cache["initialized"] = False
|