ai-skills-api/compression.py

157 lines
5.3 KiB
Python

"""
Conversation compression - summarizes old turns to save tokens.
Supports multiple strategies: extractive summarization or Ollama LLM.
"""
from typing import List, Dict
import logging
import tiktoken
import httpx
from sumy.parsers.plaintext import PlaintextParser
from sumy.nlp.tokenizers import Tokenizer
from sumy.summarizers.lsa import LsaSummarizer
import asyncio
# Module-level logger shared by all helpers in this file.
logger = logging.getLogger(__name__)
# Shared tiktoken encoder; built once at import time and reused by every
# token-counting call below (encoder construction is comparatively expensive).
ENCODING = tiktoken.get_encoding("cl100k_base")
def count_tokens(text: str) -> int:
    """Return the number of tokens *text* occupies under the module encoding."""
    encoded = ENCODING.encode(text)
    return len(encoded)
def truncate_tool_output(output: str, max_tokens: int = 200) -> str:
    """Cap *output* at *max_tokens* tokens, appending a note of how much was cut.

    Returns the output unchanged when it already fits within the budget.
    """
    encoded = ENCODING.encode(output)
    omitted = len(encoded) - max_tokens
    if omitted <= 0:
        return output
    head = ENCODING.decode(encoded[:max_tokens])
    return f"{head}... [truncated, {omitted} tokens omitted]"
def extractive_summarize(text: str, sentences_count: int = 3) -> str:
    """
    Simple extractive summarization using the LSA algorithm.

    Picks the most important sentences from *text*. No external API calls;
    fast and deterministic. On any failure — or when LSA yields nothing
    (very short inputs) — falls back to the first *sentences_count*
    sentences of the input.

    Args:
        text: Text to summarize.
        sentences_count: Maximum number of sentences in the summary.

    Returns:
        The summary string (never raises; best-effort by design).
    """
    try:
        parser = PlaintextParser.from_string(text, Tokenizer("english"))
        summarizer = LsaSummarizer()
        summary_sentences = summarizer(parser.document, sentences_count)
        summary = " ".join(str(sentence) for sentence in summary_sentences)
        if summary:
            return summary
        # LSA produced no sentences (e.g. input too short) — use the fallback.
    except Exception:
        # Broad catch is deliberate: summarization is best-effort and must
        # never crash the caller. Log so failures are not silently invisible.
        logging.getLogger(__name__).warning(
            "LSA summarization failed; falling back to sentence truncation",
            exc_info=True,
        )
    # Fallback: first few sentences. Only append a period when the last kept
    # sentence lost its terminator to the split (the original code always
    # appended one, doubling the final period for short inputs).
    sentences = text.split('. ')[:sentences_count]
    joined = '. '.join(sentences)
    return joined if joined.endswith('.') else joined + '.'
async def ollama_summarize(text: str, model: str = "phi3:mini", url: str = "http://localhost:11434") -> str:
    """
    Summarize *text* using the Ollama HTTP API.

    Requires Ollama running at *url* with *model* already pulled. On any
    failure (connection error, HTTP error, bad payload) logs a warning and
    falls back to extractive summarization, so this never raises.

    Args:
        text: Conversation text to summarize.
        model: Ollama model name.
        url: Base URL of the Ollama server.

    Returns:
        The summary string (LLM-generated, or extractive on fallback).
    """
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{url}/api/generate",
                json={
                    "model": model,
                    "prompt": f"Summarize the following conversation in 2-3 sentences, focusing on key decisions and conclusions:\n\n{text}",
                    "stream": False,
                    "options": {
                        # Cap generation length to keep the summary short.
                        "num_predict": 200
                    }
                }
            )
            response.raise_for_status()
            result = response.json()
            return result.get("response", "").strip()
    except Exception as e:
        # Best-effort by design: log (the original swallowed this silently)
        # and degrade to the local extractive strategy.
        logger.warning("Ollama summarization failed: %s; falling back to extractive", e)
        return extractive_summarize(text, sentences_count=3)
async def compress_conversation(
    messages: List[Dict],
    max_tokens: int = 2000,
    keep_last_n: int = 3,
    strategy: str = "extractive",
    ollama_model: str = "phi3:mini",
    ollama_url: str = "http://localhost:11434"
) -> List[Dict]:
    """
    Compress conversation history.

    Keeps the last *keep_last_n* exchanges verbatim and replaces everything
    before them with a single summary message produced by the configured
    strategy.

    Args:
        messages: Full conversation history ({"role", "content"} dicts).
        max_tokens: Target token budget for the compressed result.
        keep_last_n: Number of recent exchanges (user/assistant pairs) to
            keep uncompressed.
        strategy: "extractive", "ollama", or "none" (no-op).
        ollama_model: Model to use if strategy is "ollama".
        ollama_url: Ollama API endpoint.

    Returns:
        Compressed message list; the input is returned unchanged when no
        compression is needed.
    """
    if strategy == "none" or not messages:
        return messages
    if len(messages) <= keep_last_n * 2:  # *2 for user/assistant pairs
        return messages

    # Preserve a leading system message separately so it survives compression.
    system_msg = None
    convo_messages = messages
    if messages[0].get("role") == "system":
        system_msg = messages[0]
        convo_messages = messages[1:]

    # Split into old turns (to compress) and recent turns (kept verbatim).
    recent = convo_messages[-keep_last_n * 2:] if keep_last_n > 0 else convo_messages
    old = convo_messages[:-keep_last_n * 2] if keep_last_n > 0 else []
    if not old:
        # Nothing to summarize (e.g. the system message alone caused the
        # length check to pass, or keep_last_n covers everything).
        return messages

    old_text = "\n".join(f"{m['role']}: {m['content']}" for m in old)

    # Summarize the old turns using the selected strategy.
    if strategy == "ollama":
        try:
            summary = await ollama_summarize(old_text, ollama_model, ollama_url)
        except Exception as e:
            logger.warning(f"Ollama summarization failed: {e}, falling back to extractive")
            summary = extractive_summarize(old_text, sentences_count=3)
    else:
        # Extractive is synchronous but CPU-bound enough to matter; run it in
        # the default executor so the event loop is not blocked.
        # (get_running_loop() replaces the deprecated get_event_loop().)
        loop = asyncio.get_running_loop()
        summary = await loop.run_in_executor(None, extractive_summarize, old_text, 3)

    # Rebuild: optional system message, then the summary, then recent turns.
    compressed = []
    if system_msg:
        compressed.append(system_msg)
    compressed.append({
        "role": "user",
        "content": f"[CONVERSATION SUMMARY]\n{summary}\n[/CONVERSATION SUMMARY]\n\n---\n\nRecent conversation (most relevant):"
    })
    compressed.extend(recent)

    # Verify we're under the budget; if not, drop down to the bare minimum.
    # `or ""` guards against an explicit None content value.
    total_tokens = sum(count_tokens(m.get("content") or "") for m in compressed)
    if total_tokens > max_tokens and len(compressed) > 2:
        # Keep the system message plus BOTH halves of the last exchange
        # (recent[-2:], not recent[-2] — the latter silently dropped the
        # final assistant reply).
        if system_msg:
            compressed = [system_msg, *recent[-2:]]
        else:
            compressed = recent[-2:]
    return compressed