# rag.py — RAG-based context selection with local embeddings.
# Companion modules: semantic_cache.py (semantic-similarity cache hits) and
# compression.py (conversation-history summarization).
# Endpoints: /cache/semantic-lookup, /cache/semantic-store, /context/rag, /compress.
# Uses sentence-transformers (all-MiniLM-L6-v2) — no external API calls.
# No vector DB needed: cosine similarity over small datasets is fast enough.
# Expected savings: 50-70% token reduction.
"""
|
|
RAG-based context selection using local embeddings.
|
|
No external API calls - runs entirely on your home server.
|
|
"""
|
|
|
|
# Standard library
import asyncio
import os
from typing import Dict, List, Optional

# Third-party
import numpy as np
from sentence_transformers import SentenceTransformer
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

# Small, fast model - ~100MB, runs on CPU
|
|
MODEL_NAME = "all-MiniLM-L6-v2"
|
|
_model: Optional[SentenceTransformer] = None
|
|
|
|
|
|
def get_model() -> SentenceTransformer:
|
|
"""Lazy-load the embedding model"""
|
|
global _model
|
|
if _model is None:
|
|
_model = SentenceTransformer(MODEL_NAME)
|
|
return _model
|
|
|
|
|
|
def embed_text(text: str) -> np.ndarray:
|
|
"""Generate embedding for text"""
|
|
model = get_model()
|
|
return model.encode(text, normalize_embeddings=True)
|
|
|
|
|
|
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
|
"""Compute cosine similarity between two vectors"""
|
|
return float(np.dot(a, b))
|
|
|
|
|
|
async def select_relevant_skills(
|
|
query: str,
|
|
db: AsyncSession,
|
|
top_k: int = 3,
|
|
min_score: float = 0.3
|
|
) -> List[Dict]:
|
|
"""
|
|
Find most relevant skills using semantic search.
|
|
Only returns skills above minimum similarity threshold.
|
|
"""
|
|
from models import Skill
|
|
|
|
# Get all skills (for small datasets, load all - fine for <1000 items)
|
|
result = await db.execute(select(Skill))
|
|
skills = result.scalars().all()
|
|
|
|
if not skills:
|
|
return []
|
|
|
|
# Generate query embedding
|
|
query_embedding = embed_text(query)
|
|
|
|
# Score each skill
|
|
scored = []
|
|
for skill in skills:
|
|
# Use cached embedding if available, else compute
|
|
skill_text = f"{skill.name} {skill.description or ''} {skill.content[:500]}"
|
|
skill_embedding = embed_text(skill_text)
|
|
score = cosine_similarity(query_embedding, skill_embedding)
|
|
|
|
if score >= min_score:
|
|
scored.append((score, skill))
|
|
|
|
# Sort by relevance, take top K
|
|
scored.sort(key=lambda x: x[0], reverse=True)
|
|
top_skills = scored[:top_k]
|
|
|
|
return [
|
|
{
|
|
"id": skill.id,
|
|
"name": skill.name,
|
|
"content": skill.content,
|
|
"relevance_score": score
|
|
}
|
|
for score, skill in top_skills
|
|
]
|
|
|
|
|
|
async def select_relevant_conventions(
|
|
project_path: str,
|
|
db: AsyncSession,
|
|
top_k: int = 2
|
|
) -> List[Dict]:
|
|
"""
|
|
Get conventions for a project.
|
|
Exact match on project_path, plus fuzzy match on parent paths.
|
|
"""
|
|
from models import Convention
|
|
|
|
result = await db.execute(
|
|
select(Convention)
|
|
.where(Convention.project_path == project_path)
|
|
.order_by(Convention.auto_inject.desc())
|
|
)
|
|
exact_matches = result.scalars().all()
|
|
|
|
if exact_matches:
|
|
return [
|
|
{"id": c.id, "name": c.name, "content": c.content}
|
|
for c in exact_matches[:top_k]
|
|
]
|
|
|
|
# Try parent path match
|
|
parent_path = "/".join(project_path.split("/")[:-1])
|
|
if parent_path:
|
|
result = await db.execute(
|
|
select(Convention)
|
|
.where(Convention.project_path == parent_path)
|
|
)
|
|
parent_matches = result.scalars().all()
|
|
return [
|
|
{"id": c.id, "name": c.name, "content": c.content}
|
|
for c in parent_matches[:top_k]
|
|
]
|
|
|
|
return []
|
|
|
|
|
|
async def select_relevant_snippets(
|
|
query: str,
|
|
db: AsyncSession,
|
|
top_k: int = 2,
|
|
language: Optional[str] = None
|
|
) -> List[Dict]:
|
|
"""Find relevant code snippets"""
|
|
from models import Snippet
|
|
|
|
result = await db.execute(select(Snippet))
|
|
snippets = result.scalars().all()
|
|
|
|
if not snippets:
|
|
return []
|
|
|
|
query_embedding = embed_text(query)
|
|
|
|
scored = []
|
|
for snippet in snippets:
|
|
if language and snippet.language != language:
|
|
continue
|
|
|
|
snippet_text = f"{snippet.name} {snippet.content}"
|
|
snippet_embedding = embed_text(snippet_text)
|
|
score = cosine_similarity(query_embedding, snippet_embedding)
|
|
|
|
if score >= 0.25: # Lower threshold for snippets
|
|
scored.append((score, snippet))
|
|
|
|
scored.sort(key=lambda x: x[0], reverse=True)
|
|
|
|
return [
|
|
{
|
|
"id": s.id,
|
|
"name": s.name,
|
|
"language": s.language,
|
|
"content": s.content,
|
|
"relevance_score": score
|
|
}
|
|
for score, s in scored[:top_k]
|
|
]
|
|
|
|
|
|
async def build_context_bundle(
|
|
query: str,
|
|
db: AsyncSession,
|
|
project: Optional[str] = None,
|
|
max_skills: int = 3,
|
|
max_conventions: int = 2,
|
|
max_snippets: int = 2
|
|
) -> Dict:
|
|
"""
|
|
Build optimized context bundle with only relevant items.
|
|
This is the main RAG entry point.
|
|
"""
|
|
skills, conventions, snippets = await asyncio.gather(
|
|
select_relevant_skills(query, db, top_k=max_skills),
|
|
select_relevant_conventions(project, db, top_k=max_conventions) if project else asyncio.coroutine(lambda: [])(),
|
|
select_relevant_snippets(query, db, top_k=max_snippets)
|
|
)
|
|
|
|
# Calculate total tokens
|
|
total_content = "\n".join(
|
|
[s["content"] for s in skills] +
|
|
[c["content"] for c in conventions] +
|
|
[s["content"] for s in snippets]
|
|
)
|
|
|
|
from compression import count_tokens
|
|
token_count = count_tokens(total_content)
|
|
|
|
return {
|
|
"skills": skills,
|
|
"conventions": conventions,
|
|
"snippets": snippets,
|
|
"estimated_tokens": token_count,
|
|
"items_included": len(skills) + len(conventions) + len(snippets)
|
|
}
|
|
|
|
|
|
import asyncio
|