Skip to content

Capabilities API

Capabilities for pydantic-ai agents.

Wraps summarization processors and context management as pydantic-ai AbstractCapability instances, removing the need for pydantic-ai-middleware.

Example
Python
from pydantic_ai import Agent
from pydantic_ai_summarization.capability import ContextManagerCapability

agent = Agent(
    "openai:gpt-4.1",
    capabilities=[ContextManagerCapability(max_tokens=100_000)],
)

SummarizationCapability dataclass

Bases: AbstractCapability[Any]

Capability that summarizes conversation history when thresholds are reached.

Wraps SummarizationProcessor as a pydantic-ai capability.

Example
Python
from pydantic_ai import Agent
from pydantic_ai_summarization.capability import SummarizationCapability

agent = Agent(
    "openai:gpt-4.1",
    capabilities=[SummarizationCapability(
        trigger=("messages", 50),
        keep=("messages", 10),
    )],
)
Source code in src/pydantic_ai_summarization/capability.py
Python
@dataclass
class SummarizationCapability(AbstractCapability[Any]):
    """Summarize older conversation history once a size threshold is hit.

    Thin capability wrapper around ``SummarizationProcessor``: the processor is
    built once in ``__post_init__`` and applied to the message history before
    every model request.

    Example:
        ```python
        from pydantic_ai import Agent
        from pydantic_ai_summarization.capability import SummarizationCapability

        agent = Agent(
            "openai:gpt-4.1",
            capabilities=[SummarizationCapability(
                trigger=("messages", 50),
                keep=("messages", 10),
            )],
        )
        ```
    """

    trigger: ContextSize = ("messages", 50)
    keep: ContextSize = ("messages", 10)
    model: ModelType = "openai:gpt-4.1-mini"
    token_counter: TokenCounter = field(default=count_tokens_approximately)
    summary_prompt: str = DEFAULT_SUMMARY_PROMPT
    _processor: SummarizationProcessor | None = field(default=None, init=False, repr=False)

    @classmethod
    def get_serialization_name(cls) -> str:
        return "SummarizationCapability"

    def __post_init__(self) -> None:
        # Assemble the wrapped processor once from the dataclass fields.
        settings: dict[str, Any] = {
            "trigger": self.trigger,
            "keep": self.keep,
            "model": self.model,
            "token_counter": self.token_counter,
            "summary_prompt": self.summary_prompt,
        }
        self._processor = SummarizationProcessor(**settings)

    async def before_model_request(
        self,
        ctx: RunContext[Any],
        request_context: Any,
    ) -> Any:
        """Run the summarization processor over the pending message history."""
        processor = self._processor
        assert processor is not None  # set unconditionally in __post_init__
        request_context.messages = await processor(request_context.messages)
        return request_context

SlidingWindowCapability dataclass

Bases: AbstractCapability[Any]

Capability that trims old messages using a sliding window.

Zero-cost alternative to summarization — discards oldest messages.

Example
Python
from pydantic_ai import Agent
from pydantic_ai_summarization.capability import SlidingWindowCapability

agent = Agent(
    "openai:gpt-4.1",
    capabilities=[SlidingWindowCapability(
        trigger=("messages", 100),
        keep=("messages", 50),
        keep_head=("messages", 1),  # preserve system prompt
    )],
)
Source code in src/pydantic_ai_summarization/capability.py
Python
@dataclass
class SlidingWindowCapability(AbstractCapability[Any]):
    """Drop the oldest messages once the history grows past a threshold.

    Zero-cost alternative to summarization — no model call is made; the
    oldest messages are simply discarded, optionally preserving a fixed
    head (e.g. the system prompt) via ``keep_head``.

    Example:
        ```python
        from pydantic_ai import Agent
        from pydantic_ai_summarization.capability import SlidingWindowCapability

        agent = Agent(
            "openai:gpt-4.1",
            capabilities=[SlidingWindowCapability(
                trigger=("messages", 100),
                keep=("messages", 50),
                keep_head=("messages", 1),  # preserve system prompt
            )],
        )
        ```
    """

    trigger: ContextSize = ("messages", 100)
    keep: ContextSize = ("messages", 50)
    keep_head: ContextSize | None = None
    token_counter: TokenCounter = field(default=count_tokens_approximately)
    _processor: SlidingWindowProcessor | None = field(default=None, init=False, repr=False)

    @classmethod
    def get_serialization_name(cls) -> str:
        return "SlidingWindowCapability"

    def __post_init__(self) -> None:
        # Assemble the wrapped processor once from the dataclass fields.
        settings: dict[str, Any] = {
            "trigger": self.trigger,
            "keep": self.keep,
            "keep_head": self.keep_head,
            "token_counter": self.token_counter,
        }
        self._processor = SlidingWindowProcessor(**settings)

    async def before_model_request(
        self,
        ctx: RunContext[Any],
        request_context: Any,
    ) -> Any:
        """Apply the sliding window to the pending message history."""
        processor = self._processor
        assert processor is not None  # set unconditionally in __post_init__
        request_context.messages = await processor(request_context.messages)
        return request_context

LimitWarnerCapability dataclass

Bases: AbstractCapability[Any]

Capability that warns the agent when run limits approach.

Injects a warning as a trailing user message when iteration, context window, or total token limits are near.

Example
Python
from pydantic_ai import Agent
from pydantic_ai_summarization.capability import LimitWarnerCapability

agent = Agent(
    "openai:gpt-4.1",
    capabilities=[LimitWarnerCapability(
        max_iterations=40,
        max_context_tokens=100_000,
    )],
)
Source code in src/pydantic_ai_summarization/capability.py
Python
@dataclass
class LimitWarnerCapability(AbstractCapability[Any]):
    """Warn the agent when run limits are approaching.

    Injects a warning as a trailing user message when iteration, context
    window, or total token limits are near.

    Example:
        ```python
        from pydantic_ai import Agent
        from pydantic_ai_summarization.capability import LimitWarnerCapability

        agent = Agent(
            "openai:gpt-4.1",
            capabilities=[LimitWarnerCapability(
                max_iterations=40,
                max_context_tokens=100_000,
            )],
        )
        ```
    """

    max_iterations: int | None = None
    max_context_tokens: int | None = None
    max_total_tokens: int | None = None
    warn_on: list[WarningOn] | None = None
    warning_threshold: float = 0.7
    critical_remaining_iterations: int = 3
    token_counter: TokenCounter = field(default=count_tokens_approximately)
    _processor: LimitWarnerProcessor | None = field(default=None, init=False, repr=False)

    @classmethod
    def get_serialization_name(cls) -> str:
        return "LimitWarnerCapability"

    def __post_init__(self) -> None:
        # Assemble the wrapped processor once from the dataclass fields.
        settings: dict[str, Any] = {
            "max_iterations": self.max_iterations,
            "max_context_tokens": self.max_context_tokens,
            "max_total_tokens": self.max_total_tokens,
            "warn_on": self.warn_on,
            "warning_threshold": self.warning_threshold,
            "critical_remaining_iterations": self.critical_remaining_iterations,
            "token_counter": self.token_counter,
        }
        self._processor = LimitWarnerProcessor(**settings)

    async def before_model_request(
        self,
        ctx: RunContext[Any],
        request_context: Any,
    ) -> Any:
        """Let the warner inspect the run context and history, possibly
        appending a warning message."""
        processor = self._processor
        assert processor is not None  # set unconditionally in __post_init__
        request_context.messages = await processor(ctx, request_context.messages)
        return request_context

ContextManagerCapability dataclass

Bases: AbstractCapability[Any]

Full context management capability with token tracking, auto-compression, and tool output truncation.

Replaces ContextManagerMiddleware + pydantic-ai-middleware with a native pydantic-ai capability. Uses before_model_request for history processing and after_tool_execute for tool output truncation.

Example
Python
from pydantic_ai import Agent
from pydantic_ai_summarization.capability import ContextManagerCapability

agent = Agent(
    "openai:gpt-4.1",
    capabilities=[ContextManagerCapability(max_tokens=100_000)],
)
Source code in src/pydantic_ai_summarization/capability.py
Python
@dataclass
class ContextManagerCapability(AbstractCapability[Any]):
    """Full context management capability with token tracking, auto-compression,
    and tool output truncation.

    Replaces ``ContextManagerMiddleware`` + ``pydantic-ai-middleware`` with a native
    pydantic-ai capability. Uses ``before_model_request`` for history processing
    and ``after_tool_execute`` for tool output truncation.

    Example:
        ```python
        from pydantic_ai import Agent
        from pydantic_ai_summarization.capability import ContextManagerCapability

        agent = Agent(
            "openai:gpt-4.1",
            capabilities=[ContextManagerCapability(max_tokens=100_000)],
        )
        ```
    """

    max_tokens: int | None = None
    """Max token budget. None = auto-detect from model via genai-prices (default 200K fallback)."""

    # Fraction of the token budget at which auto-compression triggers
    # (validated in __post_init__ to be in (0, 1]).
    compress_threshold: float = 0.9
    # How much recent history the summarization processor keeps verbatim.
    keep: ContextSize = ("messages", 0)
    # Model used to produce the summary (independent of the agent's model).
    summarization_model: ModelType = "openai:gpt-4.1-mini"
    token_counter: TokenCounter = field(default=count_tokens_approximately)
    summary_prompt: str = DEFAULT_SUMMARY_PROMPT
    # When set, tool outputs over ~4 chars/token are truncated (see
    # after_tool_execute); None disables truncation.
    max_tool_output_tokens: int | None = None
    tool_output_head_lines: int = 5
    tool_output_tail_lines: int = 5
    # Optional observability hooks.
    on_usage_update: UsageCallback | None = None
    on_before_compress: BeforeCompressCallback | None = None
    on_after_compress: AfterCompressCallback | None = None
    include_compact_tool: bool = False
    """When True, adds a ``compact_conversation`` tool so the agent can trigger compression."""

    # Internal state — excluded from __init__ and repr.
    _compact_requested: bool = field(default=False, init=False, repr=False)
    _compact_focus: str | None = field(default=None, init=False, repr=False)
    _compression_count: int = field(default=0, init=False, repr=False)
    _summarization_processor: SummarizationProcessor | None = field(
        default=None, init=False, repr=False
    )

    # Effective budget: max_tokens, an auto-detected value, or the 200K fallback.
    _resolved_max_tokens: int = field(default=0, init=False, repr=False)

    @classmethod
    def get_serialization_name(cls) -> str:
        return "ContextManagerCapability"

    def __post_init__(self) -> None:
        """Validate configuration and build the summarization processor."""
        if not 0 < self.compress_threshold <= 1:
            raise ValueError(
                f"compress_threshold must be between 0 and 1, got {self.compress_threshold}."
            )
        # NOTE: `or` means max_tokens=0 also falls through to the 200K fallback.
        self._resolved_max_tokens = self.max_tokens or 200_000
        self._summarization_processor = SummarizationProcessor(
            trigger=("fraction", self.compress_threshold),
            keep=self.keep,
            model=self.summarization_model,
            token_counter=self.token_counter,
            summary_prompt=self.summary_prompt,
            max_input_tokens=self._resolved_max_tokens,
        )

    def get_toolset(self) -> Any:
        """Return a toolset with the ``compact_conversation`` tool, or None."""
        if not self.include_compact_tool:
            return None

        # Deferred import: only needed when the opt-in flag is set.
        from pydantic_ai.toolsets import FunctionToolset

        toolset: FunctionToolset[Any] = FunctionToolset(id="context-compact")
        cap_ref = self  # capture for closure

        @toolset.tool_plain(
            description=(
                "Compress the conversation history to free up context space. "
                "Use when the conversation is getting long and you want to preserve "
                "important context while reducing token usage. "
                "Optionally provide a focus topic to prioritize in the summary."
            ),
        )
        async def compact_conversation(focus: str | None = None) -> str:
            """Compress conversation history.

            Args:
                focus: Optional topic to prioritize in the summary.
            """
            # The tool only flags the request; compression itself happens in
            # before_model_request on the next model call.
            cap_ref.request_compact(focus=focus)
            msg = "Conversation compaction requested."
            if focus:
                msg += f" Focus: {focus}"
            msg += " It will be applied before the next model request."
            return msg

        return toolset

    async def for_run(self, ctx: RunContext[Any]) -> ContextManagerCapability:
        """Auto-detect max_tokens from model on first run if not set."""
        # 200_000 is the fallback assigned in __post_init__; auto-detection
        # only runs while the budget is still at that fallback.
        if self.max_tokens is None and self._resolved_max_tokens == 200_000:  # pragma: no cover
            model_id = getattr(ctx.model, "model_id", None)
            if model_id is not None:
                resolved = _resolve_max_tokens(str(model_id))
                if resolved is not None:
                    self._resolved_max_tokens = resolved
                    # Rebuild processor with updated tokens
                    self._summarization_processor = SummarizationProcessor(
                        trigger=("fraction", self.compress_threshold),
                        keep=self.keep,
                        model=self.summarization_model,
                        token_counter=self.token_counter,
                        summary_prompt=self.summary_prompt,
                        max_input_tokens=self._resolved_max_tokens,
                    )
        return self

    @property
    def compression_count(self) -> int:
        """Number of times compression has been triggered."""
        return self._compression_count

    def request_compact(self, focus: str | None = None) -> None:
        """Request manual compaction on the next model request.

        Args:
            focus: Optional focus instructions for the summary.
        """
        # Deferred: before_model_request consumes this flag on the next call.
        self._compact_requested = True
        self._compact_focus = focus

    async def compact(
        self,
        messages: list[ModelMessage],
        _focus: str | None = None,
    ) -> list[ModelMessage]:
        """Directly compact messages. Callable outside agent.run().

        Args:
            messages: Message history to compress.
            _focus: Optional focus instructions for the summary. Currently
                unused — it does not alter the summarization prompt.

        Returns:
            Compressed message list.
        """
        assert self._summarization_processor is not None  # set in __post_init__
        compressed = await self._summarization_processor(messages)
        self._compression_count += 1
        return compressed

    async def before_model_request(
        self,
        ctx: RunContext[Any],
        request_context: Any,
    ) -> Any:
        """Track tokens, auto-compress when threshold reached."""
        messages: list[ModelMessage] = request_context.messages

        max_tok = self._resolved_max_tokens
        total = await async_count_tokens(self.token_counter, messages)
        # Fraction of the budget used; guard against a zero budget.
        pct = total / max_tok if max_tok > 0 else 0.0

        if self.on_usage_update is not None:
            self.on_usage_update(pct, total, max_tok)

        # Compress when over the threshold, or on an explicit request
        # (request_compact / the compact_conversation tool).
        should_compress = pct >= self.compress_threshold or self._compact_requested
        if should_compress:  # pragma: no cover — compression requires LLM call
            self._compact_requested = False
            # NOTE(review): the stored focus is discarded here without being
            # applied to the summarization — confirm this is intentional.
            self._compact_focus = None

            if self.on_before_compress is not None:
                # NOTE(review): second argument looks like a token count but is
                # always 0 here — confirm the callback contract.
                self.on_before_compress(messages, 0)

            assert self._summarization_processor is not None
            messages = await self._summarization_processor(messages)
            self._compression_count += 1

            if self.on_after_compress is not None:
                result = self.on_after_compress(messages)
                # A string return value is appended to the first request as an
                # extra system prompt part.
                if isinstance(result, str) and messages:
                    from pydantic_ai.messages import ModelRequest, SystemPromptPart

                    first = messages[0]
                    if isinstance(first, ModelRequest):
                        messages[0] = ModelRequest(
                            parts=[*first.parts, SystemPromptPart(content=result)],
                            instructions=first.instructions,
                        )

            # Report post-compression usage so observers see the new footprint.
            new_total = await async_count_tokens(self.token_counter, messages)
            new_pct = new_total / max_tok if max_tok > 0 else 0.0
            if self.on_usage_update is not None:
                self.on_usage_update(new_pct, new_total, max_tok)

        request_context.messages = messages
        return request_context

    async def after_tool_execute(
        self,
        ctx: RunContext[Any],
        *,
        call: ToolCallPart,
        tool_def: ToolDefinition,
        args: dict[str, Any],
        result: Any,
    ) -> Any:
        """Truncate large tool outputs."""
        if self.max_tool_output_tokens is None:
            return result

        result_str = str(result) if not isinstance(result, str) else result
        # Rough heuristic: ~4 characters per token.
        char_limit = self.max_tool_output_tokens * 4

        if len(result_str) <= char_limit:
            return result

        return _truncate_tool_output(
            result_str, self.tool_output_head_lines, self.tool_output_tail_lines
        )

max_tokens = None class-attribute instance-attribute

Max token budget. None = auto-detect from model via genai-prices (default 200K fallback).

include_compact_tool = False class-attribute instance-attribute

When True, adds a compact_conversation tool so the agent can trigger compression.

compression_count property

Number of times compression has been triggered.

get_toolset()

Return a toolset with the compact_conversation tool, or None.

Source code in src/pydantic_ai_summarization/capability.py
Python
def get_toolset(self) -> Any:
    """Return a toolset with the ``compact_conversation`` tool, or None."""
    if not self.include_compact_tool:
        return None

    # Deferred import: only needed when the opt-in flag is set.
    from pydantic_ai.toolsets import FunctionToolset

    toolset: FunctionToolset[Any] = FunctionToolset(id="context-compact")
    cap_ref = self  # capture for closure

    @toolset.tool_plain(
        description=(
            "Compress the conversation history to free up context space. "
            "Use when the conversation is getting long and you want to preserve "
            "important context while reducing token usage. "
            "Optionally provide a focus topic to prioritize in the summary."
        ),
    )
    async def compact_conversation(focus: str | None = None) -> str:
        """Compress conversation history.

        Args:
            focus: Optional topic to prioritize in the summary.
        """
        # The tool only flags the request; compression itself happens in
        # before_model_request on the next model call.
        cap_ref.request_compact(focus=focus)
        msg = "Conversation compaction requested."
        if focus:
            msg += f" Focus: {focus}"
        msg += " It will be applied before the next model request."
        return msg

    return toolset

for_run(ctx) async

Auto-detect max_tokens from model on first run if not set.

Source code in src/pydantic_ai_summarization/capability.py
Python
async def for_run(self, ctx: RunContext[Any]) -> ContextManagerCapability:
    """Auto-detect max_tokens from model on first run if not set."""
    # 200_000 is the fallback assigned in __post_init__; auto-detection
    # only runs while the budget is still at that fallback.
    if self.max_tokens is None and self._resolved_max_tokens == 200_000:  # pragma: no cover
        model_id = getattr(ctx.model, "model_id", None)
        if model_id is not None:
            resolved = _resolve_max_tokens(str(model_id))
            if resolved is not None:
                self._resolved_max_tokens = resolved
                # Rebuild processor with updated tokens
                self._summarization_processor = SummarizationProcessor(
                    trigger=("fraction", self.compress_threshold),
                    keep=self.keep,
                    model=self.summarization_model,
                    token_counter=self.token_counter,
                    summary_prompt=self.summary_prompt,
                    max_input_tokens=self._resolved_max_tokens,
                )
    return self

request_compact(focus=None)

Request manual compaction on the next model request.

Parameters:

Name Type Description Default
focus str | None

Optional focus instructions for the summary.

None
Source code in src/pydantic_ai_summarization/capability.py
Python
def request_compact(self, focus: str | None = None) -> None:
    """Request manual compaction on the next model request.

    Args:
        focus: Optional focus instructions for the summary.
    """
    # Deferred: before_model_request consumes this flag on the next call.
    self._compact_requested = True
    self._compact_focus = focus

compact(messages, _focus=None) async

Directly compact messages. Callable outside agent.run().

Parameters:

Name Type Description Default
messages list[ModelMessage]

Message history to compress.

required
_focus str | None

Optional focus instructions for the summary (currently unused by the implementation).

None

Returns:

Type Description
list[ModelMessage]

Compressed message list.

Source code in src/pydantic_ai_summarization/capability.py
Python
async def compact(
    self,
    messages: list[ModelMessage],
    _focus: str | None = None,
) -> list[ModelMessage]:
    """Directly compact messages. Callable outside agent.run().

    Args:
        messages: Message history to compress.
        _focus: Optional focus instructions for the summary. Currently
            unused — it does not alter the summarization prompt.

    Returns:
        Compressed message list.
    """
    assert self._summarization_processor is not None  # set in __post_init__
    compressed = await self._summarization_processor(messages)
    self._compression_count += 1
    return compressed

before_model_request(ctx, request_context) async

Track tokens, auto-compress when threshold reached.

Source code in src/pydantic_ai_summarization/capability.py
Python
async def before_model_request(
    self,
    ctx: RunContext[Any],
    request_context: Any,
) -> Any:
    """Track tokens, auto-compress when threshold reached."""
    messages: list[ModelMessage] = request_context.messages

    max_tok = self._resolved_max_tokens
    total = await async_count_tokens(self.token_counter, messages)
    # Fraction of the budget used; guard against a zero budget.
    pct = total / max_tok if max_tok > 0 else 0.0

    if self.on_usage_update is not None:
        self.on_usage_update(pct, total, max_tok)

    # Compress when over the threshold, or on an explicit request
    # (request_compact / the compact_conversation tool).
    should_compress = pct >= self.compress_threshold or self._compact_requested
    if should_compress:  # pragma: no cover — compression requires LLM call
        self._compact_requested = False
        # NOTE(review): the stored focus is discarded here without being
        # applied to the summarization — confirm this is intentional.
        self._compact_focus = None

        if self.on_before_compress is not None:
            # NOTE(review): second argument looks like a token count but is
            # always 0 here — confirm the callback contract.
            self.on_before_compress(messages, 0)

        assert self._summarization_processor is not None
        messages = await self._summarization_processor(messages)
        self._compression_count += 1

        if self.on_after_compress is not None:
            result = self.on_after_compress(messages)
            # A string return value is appended to the first request as an
            # extra system prompt part.
            if isinstance(result, str) and messages:
                from pydantic_ai.messages import ModelRequest, SystemPromptPart

                first = messages[0]
                if isinstance(first, ModelRequest):
                    messages[0] = ModelRequest(
                        parts=[*first.parts, SystemPromptPart(content=result)],
                        instructions=first.instructions,
                    )

        # Report post-compression usage so observers see the new footprint.
        new_total = await async_count_tokens(self.token_counter, messages)
        new_pct = new_total / max_tok if max_tok > 0 else 0.0
        if self.on_usage_update is not None:
            self.on_usage_update(new_pct, new_total, max_tok)

    request_context.messages = messages
    return request_context

after_tool_execute(ctx, *, call, tool_def, args, result) async

Truncate large tool outputs.

Source code in src/pydantic_ai_summarization/capability.py
Python
async def after_tool_execute(
    self,
    ctx: RunContext[Any],
    *,
    call: ToolCallPart,
    tool_def: ToolDefinition,
    args: dict[str, Any],
    result: Any,
) -> Any:
    """Truncate large tool outputs."""
    if self.max_tool_output_tokens is None:
        return result

    result_str = str(result) if not isinstance(result, str) else result
    # Rough heuristic: ~4 characters per token.
    char_limit = self.max_tool_output_tokens * 4

    if len(result_str) <= char_limit:
        return result

    return _truncate_tool_output(
        result_str, self.tool_output_head_lines, self.tool_output_tail_lines
    )