""" RAG-based context selection using local embeddings. Optimized with in-memory embedding cache to avoid recomputation. """ import numpy as np from sentence_transformers import SentenceTransformer from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from typing import List, Dict, Optional import asyncio from threading import Lock # Small, fast model - ~100MB, runs on CPU MODEL_NAME = "all-MiniLM-L6-v2" _model: Optional[SentenceTransformer] = None # In-memory embedding cache _embedding_cache: Dict[str, Dict] = { "skills": {}, "snippets": {}, "initialized": False } _cache_lock = Lock() def get_model() -> SentenceTransformer: """Lazy-load the embedding model""" global _model if _model is None: _model = SentenceTransformer(MODEL_NAME) return _model def embed_text(text: str) -> np.ndarray: """Generate embedding for text""" model = get_model() return model.encode(text, normalize_embeddings=True) def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: """Compute cosine similarity between two vectors""" return float(np.dot(a, b)) def _get_skill_text(skill) -> str: """Generate searchable text for a skill""" return f"{skill.name} {skill.description or ''} {skill.content[:500]}" def _get_snippet_text(snippet) -> str: """Generate searchable text for a snippet""" return f"{snippet.name} {snippet.content}" async def ensure_cache_initialized(db: AsyncSession): """Load all skills and snippets into memory with their embeddings""" global _embedding_cache with _cache_lock: if _embedding_cache["initialized"]: return # Load skills result = await db.execute(select(Skill)) skills = result.scalars().all() for skill in skills: text = _get_skill_text(skill) embedding = embed_text(text) _embedding_cache["skills"][skill.id] = { "embedding": embedding, "skill": skill } # Load snippets result = await db.execute(select(Snippet)) snippets = result.scalars().all() for snippet in snippets: text = _get_snippet_text(snippet) embedding = embed_text(text) _embedding_cache["snippets"][snippet.id] = { "embedding": embedding, "snippet": snippet } _embedding_cache["initialized"] = True async def select_relevant_skills( query: str, db: AsyncSession, top_k: int = 3, min_score: float = 0.3 ) -> List[Dict]: """ Find most relevant skills using semantic search. Uses cached embeddings for O(1) retrieval after initial load. """ from models import Skill await ensure_cache_initialized(db) if not _embedding_cache["skills"]: return [] # Generate query embedding query_embedding = embed_text(query) # Score each cached skill scored = [] for skill_id, cached in _embedding_cache["skills"].items(): skill = cached["skill"] embedding = cached["embedding"] score = cosine_similarity(query_embedding, embedding) if score >= min_score: scored.append((score, skill)) # Sort by relevance, take top K scored.sort(key=lambda x: x[0], reverse=True) top_skills = scored[:top_k] return [ { "id": skill.id, "name": skill.name, "content": skill.content, "relevance_score": score } for score, skill in top_skills ] async def select_relevant_conventions( project_path: str, db: AsyncSession, top_k: int = 2 ) -> List[Dict]: """ Get conventions for a project. Exact match on project_path, plus fuzzy match on parent paths. 
async def select_relevant_conventions(
    project_path: str,
    db: AsyncSession,
    top_k: int = 2
) -> List[Dict]:
    """
    Get conventions for a project.

    Exact match on project_path, falling back to an exact match on the
    immediate parent path.
    (No embeddings - small dataset, exact matching is fine)
    """
    from models import Convention

    result = await db.execute(
        select(Convention)
        .where(Convention.project_path == project_path)
        .order_by(Convention.auto_inject.desc())
    )
    exact_matches = result.scalars().all()

    if exact_matches:
        return [
            {"id": c.id, "name": c.name, "content": c.content}
            for c in exact_matches[:top_k]
        ]

    # Try the immediate parent path
    parent_path = "/".join(project_path.split("/")[:-1])
    if parent_path:
        result = await db.execute(
            select(Convention)
            .where(Convention.project_path == parent_path)
        )
        parent_matches = result.scalars().all()
        return [
            {"id": c.id, "name": c.name, "content": c.content}
            for c in parent_matches[:top_k]
        ]

    return []


async def select_relevant_snippets(
    query: str,
    db: AsyncSession,
    top_k: int = 2,
    language: Optional[str] = None
) -> List[Dict]:
    """Find relevant code snippets using cached embeddings"""
    await ensure_cache_initialized(db)

    if not _embedding_cache["snippets"]:
        return []

    query_embedding = embed_text(query)

    scored = []
    for cached in _embedding_cache["snippets"].values():
        snippet = cached["snippet"]
        if language and snippet.language != language:
            continue
        score = cosine_similarity(query_embedding, cached["embedding"])
        if score >= 0.25:  # Lower threshold for snippets
            scored.append((score, snippet))

    scored.sort(key=lambda x: x[0], reverse=True)

    return [
        {
            "id": s.id,
            "name": s.name,
            "language": s.language,
            "content": s.content,
            "relevance_score": score
        }
        for score, s in scored[:top_k]
    ]


async def build_context_bundle(
    query: str,
    db: AsyncSession,
    project: Optional[str] = None,
    max_skills: int = 3,
    max_conventions: int = 2,
    max_snippets: int = 2
) -> Dict:
    """
    Build an optimized context bundle with only relevant items.
    This is the main RAG entry point.
    """
    # A single AsyncSession is not safe for concurrent use, so the three
    # selections run sequentially rather than under asyncio.gather().
    skills = await select_relevant_skills(query, db, top_k=max_skills)
    conventions = (
        await select_relevant_conventions(project, db, top_k=max_conventions)
        if project
        else []
    )
    snippets = await select_relevant_snippets(query, db, top_k=max_snippets)

    # Calculate total tokens
    total_content = "\n".join(
        [s["content"] for s in skills]
        + [c["content"] for c in conventions]
        + [s["content"] for s in snippets]
    )
    from compression import count_tokens
    token_count = count_tokens(total_content)

    return {
        "skills": skills,
        "conventions": conventions,
        "snippets": snippets,
        "estimated_tokens": token_count,
        "items_included": len(skills) + len(conventions) + len(snippets)
    }


def clear_cache():
    """Clear embedding cache (useful for testing)"""
    global _embedding_cache
    _embedding_cache = {
        "skills": {},
        "snippets": {},
        "initialized": False
    }
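
# ---------------------------------------------------------------------------
# Usage sketch (hypothetical): how an application might call into this module.
# Assumes an async SQLAlchemy engine/session factory; `DATABASE_URL` and the
# wiring below are illustrative and not part of this module.
# ---------------------------------------------------------------------------
#
#   from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
#
#   engine = create_async_engine(DATABASE_URL)
#   async_session = async_sessionmaker(engine, expire_on_commit=False)
#
#   async def handle_request(query: str, project: str) -> Dict:
#       # The first call pays the embedding cost for every skill/snippet;
#       # subsequent calls only embed the query and score cached vectors.
#       async with async_session() as db:
#           return await build_context_bundle(query, db, project=project)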