
Middleware API

pydantic_ai_summarization.middleware

Context manager middleware for real-time conversation context management.

This module requires the hybrid extra:

Text Only
pip install pydantic-ai-summarization[hybrid]

It provides ContextManagerMiddleware, a dual-protocol class that:

  1. Acts as a pydantic-ai history processor (__call__): tracks token usage and auto-compresses when approaching the token limit.
  2. Acts as a pydantic-ai-middleware AgentMiddleware (after_tool_call): optionally truncates large tool outputs inline.

Example:

Text Only
from pydantic_ai import Agent
from pydantic_ai_middleware import MiddlewareAgent
from pydantic_ai_summarization import (
    ContextManagerMiddleware,
    create_context_manager_middleware,
)

middleware = create_context_manager_middleware(
    max_tokens=200_000,
    compress_threshold=0.9,
    on_usage_update=lambda pct, cur, mx: print(f"{pct:.0%} used"),
)

agent = Agent("openai:gpt-4.1", history_processors=[middleware])
wrapped = MiddlewareAgent(agent, middleware=[middleware])

UsageCallback = Callable[[float, int, int], Any] module-attribute

Callback type for usage updates: (percentage, current_tokens, max_tokens).

Supports both sync and async callables. If the callable returns an awaitable, it will be awaited automatically.
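For illustration, both shapes work; this sketch uses asyncio.sleep as a stand-in for real async work:

Python
import asyncio

def on_usage(pct: float, current: int, maximum: int) -> None:
    # Sync callback: invoked before each model request.
    print(f"context: {current:,}/{maximum:,} tokens ({pct:.0%})")

async def on_usage_async(pct: float, current: int, maximum: int) -> None:
    # Async callback: the middleware awaits the returned coroutine.
    await asyncio.sleep(0)  # stand-in for async logging or metrics
    print(f"context: {pct:.0%} used")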

ContextManagerMiddleware dataclass

Bases: AgentMiddleware[Any]

Real-time context management middleware.

Combines token tracking, auto-compression, and optional tool output truncation. Registered both as a pydantic-ai history_processor (for per-model-call context management) and as an AgentMiddleware (for tool output interception).

Attributes:

max_tokens (int): Maximum token budget for the conversation.
compress_threshold (float): Fraction of max_tokens at which auto-compression triggers.
keep (ContextSize): How much context to retain after compression.
summarization_model (ModelType): Model used for generating summaries.
token_counter (TokenCounter): Function to count tokens in messages.
summary_prompt (str): Prompt template for summary generation.
trim_tokens_to_summarize (int): Max tokens to include when generating the summary.
max_input_tokens (int | None): Model max input tokens (for fraction-based keep).
max_tool_output_tokens (int | None): Per-tool-output token limit before truncation.
tool_output_head_lines (int): Lines from the beginning of truncated output.
tool_output_tail_lines (int): Lines from the end of truncated output.
on_usage_update (UsageCallback | None): Callback invoked with usage stats before each model call.
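For example, enabling inline tool-output truncation alongside the compression defaults (a sketch; the limits shown are illustrative):

Python
from pydantic_ai_summarization import ContextManagerMiddleware

middleware = ContextManagerMiddleware(
    max_tokens=200_000,
    compress_threshold=0.9,        # compress at >= 180,000 counted tokens
    keep=("messages", 20),         # retain the 20 most recent messages
    max_tool_output_tokens=2_000,  # truncate tool outputs over ~8,000 chars
    tool_output_head_lines=10,
    tool_output_tail_lines=10,
)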

Source code in src/pydantic_ai_summarization/middleware.py
Python
@dataclass
class ContextManagerMiddleware(AgentMiddleware[Any]):  # type: ignore[misc]
    """Real-time context management middleware.

    Combines token tracking, auto-compression, and optional tool output
    truncation. Registered both as a pydantic-ai ``history_processor``
    (for per-model-call context management) and as an ``AgentMiddleware``
    (for tool output interception).

    Attributes:
        max_tokens: Maximum token budget for the conversation.
        compress_threshold: Fraction of max_tokens at which auto-compression triggers.
        keep: How much context to retain after compression.
        summarization_model: Model used for generating summaries.
        token_counter: Function to count tokens in messages.
        summary_prompt: Prompt template for summary generation.
        trim_tokens_to_summarize: Max tokens to include when generating the summary.
        max_input_tokens: Model max input tokens (for fraction-based keep).
        max_tool_output_tokens: Per-tool-output token limit before truncation.
        tool_output_head_lines: Lines from the beginning of truncated output.
        tool_output_tail_lines: Lines from the end of truncated output.
        on_usage_update: Callback invoked with usage stats before each model call.
    """

    max_tokens: int = 200_000
    """Maximum token budget for the conversation."""

    compress_threshold: float = 0.9
    """Fraction of max_tokens at which auto-compression triggers (0.0, 1.0]."""

    keep: ContextSize = ("messages", 20)
    """How much context to retain after compression."""

    summarization_model: ModelType = "openai:gpt-4.1-mini"
    """Model used for generating summaries.

    Accepts a string model name, a pydantic-ai Model instance, or a KnownModelName literal.
    """

    token_counter: TokenCounter = field(default=count_tokens_approximately)
    """Function to count tokens in messages."""

    summary_prompt: str = DEFAULT_SUMMARY_PROMPT
    """Prompt template for summary generation."""

    trim_tokens_to_summarize: int = 4000
    """Max tokens to include when generating the summary."""

    max_input_tokens: int | None = None
    """Model max input tokens (required for fraction-based keep)."""

    max_tool_output_tokens: int | None = None
    """Per-tool-output token limit before truncation. None disables truncation."""

    tool_output_head_lines: int = 5
    """Lines to show from the beginning of truncated tool output."""

    tool_output_tail_lines: int = 5
    """Lines to show from the end of truncated tool output."""

    on_usage_update: UsageCallback | None = None
    """Callback invoked with ``(percentage, current_tokens, max_tokens)``."""

    _summarization_agent: Agent[None, str] | None = field(default=None, init=False, repr=False)
    _compression_count: int = field(default=0, init=False, repr=False)

    def __post_init__(self) -> None:
        """Validate configuration."""
        if not 0 < self.compress_threshold <= 1:
            raise ValueError(
                f"compress_threshold must be between 0 and 1, got {self.compress_threshold}."
            )
        self.keep = validate_context_size(self.keep, "keep")
        if self.keep[0] == "fraction" and self.max_input_tokens is None:
            raise ValueError("max_input_tokens is required when using fraction-based keep.")

    # -- History processor protocol (pydantic-ai) --

    async def __call__(self, messages: list[ModelMessage]) -> list[ModelMessage]:
        """History processor: track usage and auto-compress.

        Called by pydantic-ai before every model request within a run.

        Args:
            messages: Current message history.

        Returns:
            Potentially compressed message history.
        """
        total = self.token_counter(messages)
        pct = total / self.max_tokens if self.max_tokens > 0 else 0.0

        await self._notify_usage(pct, total, self.max_tokens)

        if pct >= self.compress_threshold:
            messages = await self._compress(messages)
            self._compression_count += 1
            new_total = self.token_counter(messages)
            new_pct = new_total / self.max_tokens if self.max_tokens > 0 else 0.0
            await self._notify_usage(new_pct, new_total, self.max_tokens)

        return messages

    # -- Middleware protocol (pydantic-ai-middleware) --

    async def after_tool_call(
        self,
        tool_name: str,
        tool_args: dict[str, Any],
        result: Any,
        deps: Any | None,
        ctx: Any | None = None,
    ) -> Any:
        """Middleware hook: optionally truncate large tool outputs.

        Args:
            tool_name: Name of the tool that was called.
            tool_args: Arguments passed to the tool.
            result: The tool's return value.
            deps: Agent dependencies.
            ctx: Middleware scoped context.

        Returns:
            Original or truncated result.
        """
        if self.max_tool_output_tokens is None:
            return result

        result_str = str(result) if not isinstance(result, str) else result
        char_limit = self.max_tool_output_tokens * 4  # ~4 chars per token

        if len(result_str) <= char_limit:
            return result

        return _truncate_tool_output(
            result_str, self.tool_output_head_lines, self.tool_output_tail_lines
        )

    # -- Internal methods --

    async def _notify_usage(self, pct: float, current: int, maximum: int) -> None:
        """Call the usage callback if set, handling sync and async."""
        if self.on_usage_update is None:
            return
        result = self.on_usage_update(pct, current, maximum)
        if inspect.isawaitable(result):
            await result

    async def _compress(self, messages: list[ModelMessage]) -> list[ModelMessage]:
        """Compress messages via LLM summarization."""
        cutoff_index = determine_cutoff_index(
            messages,
            self.keep,
            self.token_counter,
            self.max_input_tokens,
            default_keep=20,
        )

        if cutoff_index <= 0:
            return messages

        messages_to_summarize = messages[:cutoff_index]  # pragma: no cover
        preserved_messages = messages[cutoff_index:]  # pragma: no cover

        summary = await self._create_summary(messages_to_summarize)  # pragma: no cover

        summary_message = ModelRequest(  # pragma: no cover
            parts=[
                SystemPromptPart(content=f"Summary of previous conversation:\n\n{summary}"),
            ]
        )

        return [summary_message, *preserved_messages]  # pragma: no cover

    def _get_summarization_agent(self) -> Agent[None, str]:  # pragma: no cover
        """Get or create the summarization agent."""
        if self._summarization_agent is None:
            self._summarization_agent = Agent(
                self.summarization_model,
                instructions=(
                    "You are a context summarization assistant. "
                    "Extract the most important information from conversations."
                ),
            )
        return self._summarization_agent

    async def _create_summary(
        self, messages_to_summarize: list[ModelMessage]
    ) -> str:  # pragma: no cover
        """Generate summary for the given messages."""
        if not messages_to_summarize:
            return "No previous conversation history."

        formatted = format_messages_for_summary(messages_to_summarize)

        if self.trim_tokens_to_summarize and len(formatted) > self.trim_tokens_to_summarize * 4:
            formatted = formatted[-(self.trim_tokens_to_summarize * 4) :]

        prompt = self.summary_prompt.format(messages=formatted)

        try:
            agent = self._get_summarization_agent()
            result = await agent.run(prompt)
            return result.output.strip()
        except Exception as e:
            return f"Error generating summary: {e!s}"

    @property
    def compression_count(self) -> int:
        """Number of times compression has been triggered."""
        return self._compression_count

max_tokens = 200000 class-attribute instance-attribute

Maximum token budget for the conversation.

compress_threshold = 0.9 class-attribute instance-attribute

Fraction of max_tokens at which auto-compression triggers (0.0, 1.0].

keep = ('messages', 20) class-attribute instance-attribute

How much context to retain after compression.

summarization_model = 'openai:gpt-4.1-mini' class-attribute instance-attribute

Model used for generating summaries.

Accepts a string model name, a pydantic-ai Model instance, or a KnownModelName literal.

token_counter = field(default=count_tokens_approximately) class-attribute instance-attribute

Function to count tokens in messages.

summary_prompt = DEFAULT_SUMMARY_PROMPT class-attribute instance-attribute

Prompt template for summary generation.

trim_tokens_to_summarize = 4000 class-attribute instance-attribute

Max tokens to include when generating the summary.

max_input_tokens = None class-attribute instance-attribute

Model max input tokens (required for fraction-based keep).

max_tool_output_tokens = None class-attribute instance-attribute

Per-tool-output token limit before truncation. None disables truncation.

tool_output_head_lines = 5 class-attribute instance-attribute

Lines to show from the beginning of truncated tool output.

tool_output_tail_lines = 5 class-attribute instance-attribute

Lines to show from the end of truncated tool output.

on_usage_update = None class-attribute instance-attribute

Callback invoked with (percentage, current_tokens, max_tokens).

compression_count property

Number of times compression has been triggered.
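A quick way to observe it, inside an async context and assuming the wrapped agent and middleware from the module example above:

Python
result = await wrapped.run("What did we decide earlier?")
print(f"compressions this session: {middleware.compression_count}")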

__post_init__()

Validate configuration.

Source code in src/pydantic_ai_summarization/middleware.py
Python
def __post_init__(self) -> None:
    """Validate configuration."""
    if not 0 < self.compress_threshold <= 1:
        raise ValueError(
            f"compress_threshold must be between 0 and 1, got {self.compress_threshold}."
        )
    self.keep = validate_context_size(self.keep, "keep")
    if self.keep[0] == "fraction" and self.max_input_tokens is None:
        raise ValueError("max_input_tokens is required when using fraction-based keep.")
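Both validation paths can be exercised directly; a quick sketch (the fraction form of keep is assumed to mirror the ("messages", 20) tuple shape):

Python
from pydantic_ai_summarization import ContextManagerMiddleware

try:
    ContextManagerMiddleware(compress_threshold=1.5)
except ValueError as err:
    print(err)  # compress_threshold must be between 0 and 1, got 1.5.

try:
    ContextManagerMiddleware(keep=("fraction", 0.25))
except ValueError as err:
    print(err)  # max_input_tokens is required when using fraction-based keep.

# Fraction-based keep is valid once the model's input window is supplied:
mw = ContextManagerMiddleware(keep=("fraction", 0.25), max_input_tokens=128_000)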

__call__(messages) async

History processor: track usage and auto-compress.

Called by pydantic-ai before every model request within a run.

Parameters:

messages (list[ModelMessage], required): Current message history.

Returns:

list[ModelMessage]: Potentially compressed message history.

Source code in src/pydantic_ai_summarization/middleware.py
Python
async def __call__(self, messages: list[ModelMessage]) -> list[ModelMessage]:
    """History processor: track usage and auto-compress.

    Called by pydantic-ai before every model request within a run.

    Args:
        messages: Current message history.

    Returns:
        Potentially compressed message history.
    """
    total = self.token_counter(messages)
    pct = total / self.max_tokens if self.max_tokens > 0 else 0.0

    await self._notify_usage(pct, total, self.max_tokens)

    if pct >= self.compress_threshold:
        messages = await self._compress(messages)
        self._compression_count += 1
        new_total = self.token_counter(messages)
        new_pct = new_total / self.max_tokens if self.max_tokens > 0 else 0.0
        await self._notify_usage(new_pct, new_total, self.max_tokens)

    return messages
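With the defaults, compression fires at 180,000 counted tokens (200,000 × 0.9). Because the processor is a plain async callable, it can also be driven by hand, e.g. in tests. A sketch (offline, the summary call falls back to the error string seen in _create_summary):

Python
import asyncio
from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai_summarization import ContextManagerMiddleware

middleware = ContextManagerMiddleware(
    max_tokens=100,          # tiny budget so compression triggers immediately
    compress_threshold=0.5,  # compress once usage reaches 50%
    keep=("messages", 2),    # keep only the two most recent messages
)

history = [
    ModelRequest(parts=[UserPromptPart(content="lorem ipsum " * 100)])
    for _ in range(5)
]

compressed = asyncio.run(middleware(history))
print(len(compressed), middleware.compression_count)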

after_tool_call(tool_name, tool_args, result, deps, ctx=None) async

Middleware hook: optionally truncate large tool outputs.

Parameters:

tool_name (str, required): Name of the tool that was called.
tool_args (dict[str, Any], required): Arguments passed to the tool.
result (Any, required): The tool's return value.
deps (Any | None, required): Agent dependencies.
ctx (Any | None, default None): Middleware scoped context.

Returns:

Any: Original or truncated result.

Source code in src/pydantic_ai_summarization/middleware.py
Python
async def after_tool_call(
    self,
    tool_name: str,
    tool_args: dict[str, Any],
    result: Any,
    deps: Any | None,
    ctx: Any | None = None,
) -> Any:
    """Middleware hook: optionally truncate large tool outputs.

    Args:
        tool_name: Name of the tool that was called.
        tool_args: Arguments passed to the tool.
        result: The tool's return value.
        deps: Agent dependencies.
        ctx: Middleware scoped context.

    Returns:
        Original or truncated result.
    """
    if self.max_tool_output_tokens is None:
        return result

    result_str = str(result) if not isinstance(result, str) else result
    char_limit = self.max_tool_output_tokens * 4  # ~4 chars per token

    if len(result_str) <= char_limit:
        return result

    return _truncate_tool_output(
        result_str, self.tool_output_head_lines, self.tool_output_tail_lines
    )

create_context_manager_middleware(max_tokens=200000, compress_threshold=0.9, keep=('messages', 20), summarization_model='openai:gpt-4.1-mini', token_counter=None, summary_prompt=None, max_tool_output_tokens=None, tool_output_head_lines=5, tool_output_tail_lines=5, on_usage_update=None, max_input_tokens=None)

Create a ContextManagerMiddleware with sensible defaults.

Parameters:

max_tokens (int, default 200000): Maximum token budget for the conversation.
compress_threshold (float, default 0.9): Fraction of max_tokens at which auto-compression triggers.
keep (ContextSize, default ('messages', 20)): How much context to retain after compression.
summarization_model (ModelType, default 'openai:gpt-4.1-mini'): Model used for generating summaries.
token_counter (TokenCounter | None, default None): Custom token counter (default: approximate char-based).
summary_prompt (str | None, default None): Custom prompt template for summaries.
max_tool_output_tokens (int | None, default None): Per-tool-output token limit before truncation.
tool_output_head_lines (int, default 5): Lines from start of truncated output.
tool_output_tail_lines (int, default 5): Lines from end of truncated output.
on_usage_update (UsageCallback | None, default None): Callback for usage updates.
max_input_tokens (int | None, default None): Model max input tokens (for fraction-based keep).

Returns:

ContextManagerMiddleware: Configured ContextManagerMiddleware instance.

Source code in src/pydantic_ai_summarization/middleware.py
Python
def create_context_manager_middleware(
    max_tokens: int = 200_000,
    compress_threshold: float = 0.9,
    keep: ContextSize = ("messages", 20),
    summarization_model: ModelType = "openai:gpt-4.1-mini",
    token_counter: TokenCounter | None = None,
    summary_prompt: str | None = None,
    max_tool_output_tokens: int | None = None,
    tool_output_head_lines: int = 5,
    tool_output_tail_lines: int = 5,
    on_usage_update: UsageCallback | None = None,
    max_input_tokens: int | None = None,
) -> ContextManagerMiddleware:
    """Create a :class:`ContextManagerMiddleware` with sensible defaults.

    Args:
        max_tokens: Maximum token budget for the conversation.
        compress_threshold: Fraction of max_tokens at which auto-compression triggers.
        keep: How much context to retain after compression.
        summarization_model: Model used for generating summaries.
        token_counter: Custom token counter (default: approximate char-based).
        summary_prompt: Custom prompt template for summaries.
        max_tool_output_tokens: Per-tool-output token limit before truncation.
        tool_output_head_lines: Lines from start of truncated output.
        tool_output_tail_lines: Lines from end of truncated output.
        on_usage_update: Callback for usage updates.
        max_input_tokens: Model max input tokens (for fraction-based keep).

    Returns:
        Configured ContextManagerMiddleware instance.
    """
    kwargs: dict[str, Any] = {
        "max_tokens": max_tokens,
        "compress_threshold": compress_threshold,
        "keep": keep,
        "summarization_model": summarization_model,
        "max_tool_output_tokens": max_tool_output_tokens,
        "tool_output_head_lines": tool_output_head_lines,
        "tool_output_tail_lines": tool_output_tail_lines,
        "max_input_tokens": max_input_tokens,
    }
    if token_counter is not None:
        kwargs["token_counter"] = token_counter
    if summary_prompt is not None:
        kwargs["summary_prompt"] = summary_prompt
    if on_usage_update is not None:
        kwargs["on_usage_update"] = on_usage_update

    return ContextManagerMiddleware(**kwargs)
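Note the factory only forwards token_counter, summary_prompt, and on_usage_update when they are provided, so the dataclass defaults stay authoritative. A sketch combining an async usage callback with fraction-based keep (values are illustrative):

Python
import asyncio
from pydantic_ai_summarization import create_context_manager_middleware

async def report(pct: float, current: int, maximum: int) -> None:
    # Awaited automatically (see UsageCallback above).
    await asyncio.sleep(0)
    print(f"{current:,}/{maximum:,} tokens ({pct:.0%})")

middleware = create_context_manager_middleware(
    max_tokens=128_000,
    compress_threshold=0.8,    # compress at >= 102,400 counted tokens
    keep=("fraction", 0.25),   # retain a quarter of the input window
    max_input_tokens=128_000,  # required for fraction-based keep
    on_usage_update=report,
)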