# ai-skills-api/rag.py
"""
RAG-based context selection using local embeddings.
Optimized with in-memory embedding cache to avoid recomputation.
"""
import numpy as np
from sentence_transformers import SentenceTransformer
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Dict, Optional
import asyncio
# Small, fast model - ~100MB, runs on CPU
MODEL_NAME = "all-MiniLM-L6-v2"
_model: Optional[SentenceTransformer] = None
# In-memory embedding cache
_embedding_cache: Dict[str, Dict] = {
"skills": {},
"snippets": {},
"initialized": False
}
_cache_lock = asyncio.Lock()  # async lock: a threading.Lock would block the event loop
def get_model() -> SentenceTransformer:
"""Lazy-load the embedding model"""
global _model
if _model is None:
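        # First use is slow: SentenceTransformer loads (and downloads, if not
        # already cached locally) the ~100 MB model named above.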
_model = SentenceTransformer(MODEL_NAME)
return _model
def embed_text(text: str) -> np.ndarray:
"""Generate embedding for text"""
model = get_model()
return model.encode(text, normalize_embeddings=True)
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Compute cosine similarity between two vectors"""
return float(np.dot(a, b))
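# Reference sketch (illustrative; unused by the pipeline): the general cosine
# formula for inputs that are NOT unit-length. Because embed_text() passes
# normalize_embeddings=True, the plain dot product above is equivalent here.
def _cosine_similarity_unnormalized(a: np.ndarray, b: np.ndarray) -> float:
    """General cosine similarity; does not assume unit-length inputs"""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))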
def _get_skill_text(skill) -> str:
"""Generate searchable text for a skill"""
return f"{skill.name} {skill.description or ''} {skill.content[:500]}"
def _get_snippet_text(snippet) -> str:
"""Generate searchable text for a snippet"""
return f"{snippet.name} {snippet.content}"
async def ensure_cache_initialized(db: AsyncSession):
    """Load all skills and snippets into memory with their embeddings"""
    from models import Skill, Snippet  # local import, matching the other helpers

    async with _cache_lock:
if _embedding_cache["initialized"]:
return
# Load skills
result = await db.execute(select(Skill))
skills = result.scalars().all()
for skill in skills:
text = _get_skill_text(skill)
embedding = embed_text(text)
_embedding_cache["skills"][skill.id] = {
"embedding": embedding,
"skill": skill
}
# Load snippets
result = await db.execute(select(Snippet))
snippets = result.scalars().all()
for snippet in snippets:
text = _get_snippet_text(snippet)
embedding = embed_text(text)
_embedding_cache["snippets"][snippet.id] = {
"embedding": embedding,
"snippet": snippet
}
_embedding_cache["initialized"] = True
async def select_relevant_skills(
query: str,
db: AsyncSession,
top_k: int = 3,
min_score: float = 0.3
) -> List[Dict]:
"""
Find most relevant skills using semantic search.
Uses cached embeddings for O(1) retrieval after initial load.
"""
from models import Skill
await ensure_cache_initialized(db)
if not _embedding_cache["skills"]:
return []
# Generate query embedding
query_embedding = embed_text(query)
# Score each cached skill
scored = []
for skill_id, cached in _embedding_cache["skills"].items():
skill = cached["skill"]
embedding = cached["embedding"]
score = cosine_similarity(query_embedding, embedding)
if score >= min_score:
scored.append((score, skill))
# Sort by relevance, take top K
scored.sort(key=lambda x: x[0], reverse=True)
top_skills = scored[:top_k]
return [
{
"id": skill.id,
"name": skill.name,
"content": skill.content,
"relevance_score": score
}
for score, skill in top_skills
]
async def select_relevant_conventions(
project_path: str,
db: AsyncSession,
top_k: int = 2
) -> List[Dict]:
"""
Get conventions for a project.
Exact match on project_path, plus fuzzy match on parent paths.
(No embeddings - small dataset, exact matching is fine)
"""
from models import Convention
result = await db.execute(
select(Convention)
.where(Convention.project_path == project_path)
.order_by(Convention.auto_inject.desc())
)
exact_matches = result.scalars().all()
if exact_matches:
return [
{"id": c.id, "name": c.name, "content": c.content}
for c in exact_matches[:top_k]
]
# Try parent path match
parent_path = "/".join(project_path.split("/")[:-1])
if parent_path:
result = await db.execute(
select(Convention)
.where(Convention.project_path == parent_path)
)
parent_matches = result.scalars().all()
return [
{"id": c.id, "name": c.name, "content": c.content}
for c in parent_matches[:top_k]
]
return []
async def select_relevant_snippets(
query: str,
db: AsyncSession,
top_k: int = 2,
language: Optional[str] = None
) -> List[Dict]:
"""Find relevant code snippets using cached embeddings"""
from models import Snippet
await ensure_cache_initialized(db)
if not _embedding_cache["snippets"]:
return []
query_embedding = embed_text(query)
scored = []
for snippet_id, cached in _embedding_cache["snippets"].items():
snippet = cached["snippet"]
if language and snippet.language != language:
continue
embedding = cached["embedding"]
score = cosine_similarity(query_embedding, embedding)
if score >= 0.25: # Lower threshold for snippets
scored.append((score, snippet))
scored.sort(key=lambda x: x[0], reverse=True)
return [
{
"id": s.id,
"name": s.name,
"language": s.language,
"content": s.content,
"relevance_score": score
}
for score, s in scored[:top_k]
]
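# Example (illustrative): select_relevant_snippets("parse YAML config", db,
# language="python") scores only snippets whose language field is "python".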
async def build_context_bundle(
query: str,
db: AsyncSession,
project: Optional[str] = None,
max_skills: int = 3,
max_conventions: int = 2,
max_snippets: int = 2
) -> Dict:
"""
Build optimized context bundle with only relevant items.
This is the main RAG entry point.
"""
    # Run the lookups sequentially: a single AsyncSession is not safe for
    # concurrent queries, and the old asyncio.coroutine() shim was removed
    # in Python 3.11.
    skills = await select_relevant_skills(query, db, top_k=max_skills)
    conventions = (
        await select_relevant_conventions(project, db, top_k=max_conventions)
        if project
        else []
    )
    snippets = await select_relevant_snippets(query, db, top_k=max_snippets)
    # Estimate the token footprint of the combined content
total_content = "\n".join(
[s["content"] for s in skills] +
[c["content"] for c in conventions] +
[s["content"] for s in snippets]
)
from compression import count_tokens
token_count = count_tokens(total_content)
return {
"skills": skills,
"conventions": conventions,
"snippets": snippets,
"estimated_tokens": token_count,
"items_included": len(skills) + len(conventions) + len(snippets)
}
def clear_cache():
"""Clear embedding cache (useful for testing)"""
global _embedding_cache
_embedding_cache = {
"skills": {},
"snippets": {},
"initialized": False
}
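# Usage sketch (illustrative): get_async_session is a hypothetical session
# factory; substitute this project's real one. The first call warms the
# embedding cache implicitly.
#
#     import asyncio
#     from database import get_async_session  # hypothetical helper
#     from rag import build_context_bundle
#
#     async def demo():
#         async with get_async_session() as db:
#             bundle = await build_context_bundle(
#                 "add retry logic to the HTTP client",
#                 db,
#                 project="/home/user/projects/api",
#             )
#             print(bundle["estimated_tokens"], bundle["items_included"])
#
#     asyncio.run(demo())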