diff --git a/main.py b/main.py index 0f12e7a..63d6e4e 100644 --- a/main.py +++ b/main.py @@ -422,7 +422,7 @@ async def get_context_rag( @app.post("/compress", dependencies=[Depends(verify_api_key)]) async def compress_messages_endpoint( - messages: List[dict], + request: CompressionRequest, keep_last_n: Optional[int] = None, max_tokens: Optional[int] = None ): @@ -435,9 +435,9 @@ async def compress_messages_endpoint( max_tokens = max_tokens or config.compression.max_tokens compressed = await compress_conversation( - messages, - max_tokens=max_tokens, - keep_last_n=keep_last_n, + request.messages, + max_tokens=max_tokens or config.compression.max_tokens, + keep_last_n=keep_last_n or config.compression.keep_last_n, strategy=config.compression.strategy, ollama_model=config.compression.ollama_model, ollama_url=config.compression.ollama_url diff --git a/schemas.py b/schemas.py index 418b2b0..f51d5ed 100644 --- a/schemas.py +++ b/schemas.py @@ -67,6 +67,17 @@ class Memory(MemoryBase): from_attributes = True +class Message(BaseModel): + role: str + content: str + + +class CompressionRequest(BaseModel): + messages: List[Message] + keep_last_n: Optional[int] = None + max_tokens: Optional[int] = None + + class ContextBundle(BaseModel): skills: List[Skill] snippets: List[Snippet]