Source code for autogen_core.components._closure_agent

import inspect
from typing import Any, Awaitable, Callable, List, Mapping, Sequence, TypeVar, get_type_hints

from ..base import (
    Agent,
    AgentId,
    AgentInstantiationContext,
    AgentMetadata,
    AgentRuntime,
    AgentType,
    MessageContext,
    Subscription,
    SubscriptionInstantiationContext,
    try_get_known_serializers_for_type,
)
from ..base._type_helpers import get_types
from ..base.exceptions import CantHandleException

# Type variable standing in for the message parameter's type (the third
# argument) in the closure signatures accepted by this module.
T = TypeVar("T")


def get_handled_types_from_closure(
    closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]],
) -> Sequence[type]:
    """Return the message types that a ClosureAgent closure declares it handles.

    The closure must take exactly four positional parameters, in the order
    ``(runtime, agent_id, message, ctx)``. It must carry a type annotation on
    the third (message) parameter and a return annotation. Both annotations
    are expanded with ``get_types``.

    Args:
        closure: The handler function to inspect.

    Returns:
        The types produced by ``get_types`` for the message parameter's
        annotation.

    Raises:
        AssertionError: If the closure does not have exactly four positional
            parameters, lacks a return annotation, lacks an annotation on the
            message parameter, or if ``get_types`` cannot resolve the message
            or return annotation (it returns ``None``).
    """
    arg_names = inspect.getfullargspec(closure).args
    if len(arg_names) != 4:
        raise AssertionError("Closure must have 4 arguments")

    message_arg_name = arg_names[2]
    type_hints = get_type_hints(closure)

    if "return" not in type_hints:
        raise AssertionError("return not found in function signature")

    # Without this check an unannotated message parameter would surface as a
    # bare KeyError instead of the AssertionError used for every other
    # validation failure in this function.
    if message_arg_name not in type_hints:
        raise AssertionError(f"Message argument '{message_arg_name}' has no type annotation")

    target_types = get_types(type_hints[message_arg_name])
    if target_types is None:
        raise AssertionError("Message type not found")

    # The return annotation is only validated here; its types are not returned.
    if get_types(type_hints["return"]) is None:
        raise AssertionError("Return type not found")

    return target_types


class ClosureAgent(Agent):
    """An agent that delegates every handled message to a single closure.

    Instances must be created while an agent instantiation context is active
    (for example by the factory that :meth:`register` installs on a runtime).
    Constructing one outside such a context raises ``RuntimeError``.
    """

    def __init__(
        self, description: str, closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]]
    ) -> None:
        """Bind the closure to the runtime and agent id of the current instantiation context.

        Args:
            description: Human-readable description exposed via ``metadata``.
            closure: Async handler called as ``closure(runtime, agent_id, message, ctx)``.

        Raises:
            RuntimeError: If no agent instantiation context is active.
            AssertionError: If the closure's signature or annotations are
                invalid (see ``get_handled_types_from_closure``).
        """
        try:
            runtime = AgentInstantiationContext.current_runtime()
            agent_id = AgentInstantiationContext.current_agent_id()
        except Exception as e:
            raise RuntimeError(
                "ClosureAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
            ) from e

        self._runtime: AgentRuntime = runtime
        self._id: AgentId = agent_id
        self._description = description
        # Remembered so on_message can reject messages of unsupported types.
        self._expected_types = get_handled_types_from_closure(closure)
        self._closure = closure

    @property
    def metadata(self) -> AgentMetadata:
        """Metadata built from this agent's id (key and type) and its description."""
        assert self._id is not None
        return AgentMetadata(
            key=self._id.key,
            type=self._id.type,
            description=self._description,
        )

    @property
    def id(self) -> AgentId:
        """The agent id captured from the instantiation context."""
        return self._id

    @property
    def runtime(self) -> AgentRuntime:
        """The runtime captured from the instantiation context."""
        return self._runtime

    async def on_message(self, message: Any, ctx: MessageContext) -> Any:
        """Forward *message* to the closure and return the closure's result.

        Raises:
            CantHandleException: If ``type(message)`` is not exactly one of the
                types declared by the closure. The check compares exact types,
                so instances of subclasses are rejected.
        """
        if type(message) not in self._expected_types:
            raise CantHandleException(
                f"Message type {type(message)} not in target types {self._expected_types} of {self.id}"
            )
        return await self._closure(self._runtime, self._id, message, ctx)

    def save_state(self) -> Mapping[str, Any]:
        """Not supported for closure agents.

        Raises:
            ValueError: Always.
        """
        raise ValueError("save_state not implemented for ClosureAgent")

    def load_state(self, state: Mapping[str, Any]) -> None:
        """Not supported for closure agents.

        Raises:
            ValueError: Always.
        """
        raise ValueError("load_state not implemented for ClosureAgent")

    @classmethod
    async def register(
        cls,
        runtime: AgentRuntime,
        type: str,
        closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]],
        *,
        description: str = "",
        subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None,
    ) -> AgentType:
        """Register a factory for closure agents of the given type on *runtime*.

        Args:
            runtime: The runtime to register with.
            type: Name of the agent type.
            closure: Async handler called as ``closure(runtime, agent_id, message, ctx)``.
            description: Description given to every created agent.
            subscriptions: Optional callable (sync or async) that returns the
                subscriptions to add. It is invoked inside a subscription
                instantiation context for the new agent type.

        Returns:
            The ``AgentType`` returned by ``runtime.register_factory``.

        Raises:
            AssertionError: If the closure's signature or annotations are
                invalid. This is raised before the runtime is modified.
        """
        # Validate the closure before touching the runtime, so an invalid
        # closure cannot leave a half-registered agent type behind.
        handled_types = get_handled_types_from_closure(closure)

        agent_type = AgentType(type)
        subscriptions_list: List[Subscription] = []
        if subscriptions is not None:
            with SubscriptionInstantiationContext.populate_context(agent_type):
                subscriptions_list_result = subscriptions()
                if inspect.isawaitable(subscriptions_list_result):
                    subscriptions_list.extend(await subscriptions_list_result)
                else:
                    # just ignore mypy here
                    subscriptions_list.extend(subscriptions_list_result)  # type: ignore

        agent_type = await runtime.register_factory(
            type=agent_type,
            # Build instances of ``cls`` so that the factory's product agrees
            # with ``expected_class=cls`` when register is called on a subclass.
            agent_factory=lambda: cls(description=description, closure=closure),
            expected_class=cls,
        )
        for subscription in subscriptions_list:
            await runtime.add_subscription(subscription)

        for message_type in handled_types:
            # TODO: support custom serializers
            serializer = try_get_known_serializers_for_type(message_type)
            runtime.add_message_serializer(serializer)

        return agent_type