"""
RAG-based context selection using local embeddings.

No external API calls - runs entirely on your home server.
"""
import asyncio
import os
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
from sentence_transformers import SentenceTransformer
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

# Small, fast model - ~100MB, runs on CPU
MODEL_NAME = "all-MiniLM-L6-v2"

# Shared model instance, created lazily by get_model().
_model: Optional[SentenceTransformer] = None


def get_model() -> SentenceTransformer:
    """Return the shared embedding model, loading it on first use."""
    global _model
    if _model is None:
        _model = SentenceTransformer(MODEL_NAME)
    return _model


def embed_text(text: str) -> np.ndarray:
    """Return the L2-normalised embedding vector for a single text.

    NOTE(review): ``encode`` runs synchronously, so calling this from a
    coroutine blocks the event loop while the model works.
    """
    return get_model().encode(text, normalize_embeddings=True)


def embed_texts(texts: Sequence[str]) -> np.ndarray:
    """Embed several texts with a single batched model call.

    Returns one L2-normalised embedding row per input text, in input order.
    Batching avoids invoking the model once per database row.
    """
    return get_model().encode(list(texts), normalize_embeddings=True)


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of two vectors.

    Correct for un-normalised input as well; for unit vectors (as produced by
    ``embed_text``) the result equals ``np.dot(a, b)``. Returns 0.0 when
    either vector has zero length instead of dividing by zero.
    """
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    if denom == 0.0:
        return 0.0
    return float(np.dot(a, b) / denom)


def _rank(
    query: str,
    items: Sequence[Any],
    texts: Sequence[str],
    top_k: int,
    min_score: float,
) -> List[Tuple[float, Any]]:
    """Score ``items`` against ``query`` and keep the best matches.

    ``texts[i]`` is the text embedded for ``items[i]``. Returns
    ``(score, item)`` pairs with ``score >= min_score``, highest score first
    (ties keep input order), at most ``top_k`` of them. A non-positive
    ``top_k`` or empty ``items`` returns ``[]`` without loading the model.
    """
    if not items or top_k <= 0:
        return []
    query_embedding = embed_text(query)
    # Both sides are unit-normalised, so one matrix-vector product yields the
    # cosine similarity of every item with the query.
    scores = embed_texts(texts) @ query_embedding
    # Convert to plain floats before comparing/returning so thresholds are
    # applied in float64 and results are JSON-serialisable.
    scored = [
        (score, item)
        for score, item in zip(map(float, scores), items)
        if score >= min_score
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:top_k]


async def select_relevant_skills(
    query: str,
    db: AsyncSession,
    top_k: int = 3,
    min_score: float = 0.3,
) -> List[Dict]:
    """
    Find the most relevant skills using semantic search.

    Every skill is loaded (fine for <1000 items) and embedded from its name,
    description and first 500 characters of content. Only skills whose
    similarity to ``query`` is at least ``min_score`` are returned, best
    first, at most ``top_k`` of them.
    """
    from models import Skill

    result = await db.execute(select(Skill))
    skills = result.scalars().all()

    # Embeddings are recomputed on every call; nothing is cached.
    texts = [
        f"{skill.name} {skill.description or ''} {(skill.content or '')[:500]}"
        for skill in skills
    ]
    return [
        {
            "id": skill.id,
            "name": skill.name,
            "content": skill.content,
            "relevance_score": score,
        }
        for score, skill in _rank(query, skills, texts, top_k, min_score)
    ]


async def select_relevant_conventions(
    project_path: str,
    db: AsyncSession,
    top_k: int = 2,
) -> List[Dict]:
    """
    Get conventions for a project.

    Returns conventions stored for exactly ``project_path``. If there are
    none, falls back to those stored for its immediate parent path (the path
    with its last "/"-separated component removed). Both lookups are ordered
    by ``auto_inject`` descending, and at most ``top_k`` results are returned.
    """
    from models import Convention

    async def fetch(path: str) -> List[Any]:
        # Same ordering for both lookups so top_k truncation is deterministic.
        result = await db.execute(
            select(Convention)
            .where(Convention.project_path == path)
            .order_by(Convention.auto_inject.desc())
        )
        return result.scalars().all()

    matches = await fetch(project_path)
    if not matches:
        parent_path = "/".join(project_path.split("/")[:-1])
        if parent_path:
            matches = await fetch(parent_path)

    return [
        {"id": c.id, "name": c.name, "content": c.content}
        for c in matches[:top_k]
    ]


async def select_relevant_snippets(
    query: str,
    db: AsyncSession,
    top_k: int = 2,
    language: Optional[str] = None,
    min_score: float = 0.25,
) -> List[Dict]:
    """
    Find relevant code snippets using semantic search.

    When ``language`` is given, only snippets whose ``language`` equals it are
    considered. ``min_score`` defaults to 0.25, a lower bar than the skills
    threshold. Returns at most ``top_k`` snippets, best first.
    """
    from models import Snippet

    result = await db.execute(select(Snippet))
    snippets = [
        snippet
        for snippet in result.scalars().all()
        if not language or snippet.language == language
    ]
    texts = [f"{snippet.name} {snippet.content}" for snippet in snippets]
    return [
        {
            "id": s.id,
            "name": s.name,
            "language": s.language,
            "content": s.content,
            "relevance_score": score,
        }
        for score, s in _rank(query, snippets, texts, top_k, min_score)
    ]


async def build_context_bundle(
    query: str,
    db: AsyncSession,
    project: Optional[str] = None,
    max_skills: int = 3,
    max_conventions: int = 2,
    max_snippets: int = 2,
) -> Dict:
    """
    Build optimized context bundle with only relevant items.

    This is the main RAG entry point. Conventions are only looked up when
    ``project`` is given. Returns the selected skills, conventions and
    snippets plus the token count of their combined content.
    """
    # The three lookups share one AsyncSession, which does not support
    # concurrent operations, so they are awaited in sequence rather than
    # passed to asyncio.gather().
    skills = await select_relevant_skills(query, db, top_k=max_skills)
    conventions = (
        await select_relevant_conventions(project, db, top_k=max_conventions)
        if project
        else []
    )
    snippets = await select_relevant_snippets(query, db, top_k=max_snippets)

    items = [*skills, *conventions, *snippets]
    total_content = "\n".join(item["content"] for item in items)

    from compression import count_tokens

    token_count = count_tokens(total_content)

    return {
        "skills": skills,
        "conventions": conventions,
        "snippets": snippets,
        "estimated_tokens": token_count,
        "items_included": len(items),
    }