157 lines
5.3 KiB
Python
157 lines
5.3 KiB
Python
"""
|
|
Conversation compression - summarizes old turns to save tokens.
|
|
Supports multiple strategies: extractive summarization or Ollama LLM.
|
|
"""
|
|
|
|
from typing import List, Dict
|
|
import logging
|
|
import tiktoken
|
|
import httpx
|
|
from sumy.parsers.plaintext import PlaintextParser
|
|
from sumy.nlp.tokenizers import Tokenizer
|
|
from sumy.summarizers.lsa import LsaSummarizer
|
|
import asyncio
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
ENCODING = tiktoken.get_encoding("cl100k_base")
|
|
|
|
|
|
def count_tokens(text: str) -> int:
    """Return the number of cl100k_base tokens that *text* encodes to."""
    encoded = ENCODING.encode(text)
    return len(encoded)
|
|
|
|
|
|
def truncate_tool_output(output: str, max_tokens: int = 200) -> str:
    """Cap a tool's output at *max_tokens* tokens to save context budget.

    Output that already fits is returned unchanged; otherwise it is cut at
    the token boundary and a note about the omitted amount is appended.
    """
    encoded = ENCODING.encode(output)
    overflow = len(encoded) - max_tokens
    if overflow <= 0:
        return output

    head = ENCODING.decode(encoded[:max_tokens])
    return f"{head}... [truncated, {overflow} tokens omitted]"
|
|
|
|
|
|
def extractive_summarize(text: str, sentences_count: int = 3) -> str:
    """
    Simple extractive summarization using LSA algorithm.

    Picks the most important sentences from the text.
    No external API calls, fast and deterministic.

    Args:
        text: Text to summarize.
        sentences_count: Number of sentences to keep (also used by the
            fallback path; the old code hard-coded 3 there).

    Returns:
        The LSA summary, or a naive first-sentences fallback if sumy fails
        (e.g. missing nltk data). Empty input returns an empty string.
    """
    if not text:
        # Previously the fallback returned a lone "." for empty input.
        return ""

    try:
        parser = PlaintextParser.from_string(text, Tokenizer("english"))
        summarizer = LsaSummarizer()
        summary_sentences = summarizer(parser.document, sentences_count)
        return " ".join(str(sentence) for sentence in summary_sentences)
    except Exception as e:
        # Fallback: truncate to first few sentences. Log instead of silently
        # swallowing so a broken sumy/nltk setup is visible to operators.
        logging.getLogger(__name__).warning(
            "Extractive summarization failed (%s); using naive sentence fallback", e
        )
        sentences = text.split('. ')[:sentences_count]
        fallback = '. '.join(sentences)
        # Only append a period when one isn't already there, avoiding the
        # old "sentence.." double-period artifact.
        return fallback if fallback.endswith('.') else fallback + '.'
|
|
|
|
|
|
async def ollama_summarize(text: str, model: str = "phi3:mini", url: str = "http://localhost:11434") -> str:
    """
    Summarize using Ollama API.

    Requires Ollama running with the specified model pulled.

    Args:
        text: Conversation text to summarize.
        model: Ollama model name (must already be pulled).
        url: Base URL of the Ollama server.

    Returns:
        The model's 2-3 sentence summary, or an extractive fallback summary
        on any error (server down, model missing, timeout, bad response).
    """
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{url}/api/generate",
                json={
                    "model": model,
                    "prompt": f"Summarize the following conversation in 2-3 sentences, focusing on key decisions and conclusions:\n\n{text}",
                    "stream": False,
                    "options": {
                        # Cap generation so a chatty model can't blow the budget.
                        "num_predict": 200
                    }
                }
            )
            response.raise_for_status()
            result = response.json()
            return result.get("response", "").strip()
    except Exception as e:
        # Fallback to extractive on any error — but log it first so callers
        # can tell "Ollama worked" apart from "Ollama silently failed".
        logging.getLogger(__name__).warning(
            "Ollama summarization via %s failed (%s); falling back to extractive", url, e
        )
        return extractive_summarize(text, sentences_count=3)
|
|
|
|
|
|
async def compress_conversation(
    messages: List[Dict],
    max_tokens: int = 2000,
    keep_last_n: int = 3,
    strategy: str = "extractive",
    ollama_model: str = "phi3:mini",
    ollama_url: str = "http://localhost:11434"
) -> List[Dict]:
    """
    Compress conversation history:
    - Keep last N exchanges in full
    - Summarize everything before using the configured strategy

    Args:
        messages: Full conversation history
        max_tokens: Target token budget
        keep_last_n: Number of recent exchanges to keep uncompressed
            (values < 1 are treated as 1)
        strategy: "extractive", "ollama", or "none"
        ollama_model: Model to use if strategy is "ollama"
        ollama_url: Ollama API endpoint

    Returns compressed message list.
    """
    if strategy == "none":
        return messages

    # Guard: keep_last_n < 1 would make the slices below degenerate
    # (x[-0:] is the whole list), so always keep at least one exchange.
    keep_last_n = max(1, keep_last_n)

    if len(messages) <= keep_last_n * 2:  # *2 for user/assistant pairs
        return messages

    # Keep system message if present
    system_msg = None
    convo_messages = messages
    if messages and messages[0].get("role") == "system":
        system_msg = messages[0]
        convo_messages = messages[1:]

    # Split into old (to compress) and recent (keep full)
    recent = convo_messages[-keep_last_n * 2:]
    old = convo_messages[:-keep_last_n * 2]

    # Create text to summarize from old turns
    old_text = "\n".join(f"{m['role']}: {m['content']}" for m in old)

    # Summarize using selected strategy
    if strategy == "ollama":
        try:
            summary = await ollama_summarize(old_text, ollama_model, ollama_url)
        except Exception as e:
            logger.warning(f"Ollama summarization failed: {e}, falling back to extractive")
            summary = extractive_summarize(old_text, sentences_count=3)
    else:
        # Extractive is synchronous but fast; run in thread pool to avoid
        # blocking. get_running_loop() replaces the deprecated
        # get_event_loop() — we are always inside a running loop here.
        loop = asyncio.get_running_loop()
        summary = await loop.run_in_executor(None, extractive_summarize, old_text, 3)

    # Build compressed messages
    compressed = []
    if system_msg:
        compressed.append(system_msg)

    # Add summary as a user message with clear demarcation
    compressed.append({
        "role": "user",
        "content": f"[CONVERSATION SUMMARY]\n{summary}\n[/CONVERSATION SUMMARY]\n\n---\n\nRecent conversation (most relevant):"
    })

    compressed.extend(recent)

    # Verify we're under limit; if not, drop the summary and keep only the
    # last full user/assistant exchange (plus the system message if any).
    total_tokens = sum(count_tokens(m.get("content", "")) for m in compressed)
    if total_tokens > max_tokens and len(compressed) > 2:
        # BUG FIX: the old code kept only recent[-2] (a single message) when
        # a system prompt was present, dropping the final assistant reply.
        last_exchange = recent[-2:]
        compressed = [system_msg] + last_exchange if system_msg else last_exchange

    return compressed
|