from__future__importannotationsimportinspectimportwarningsfromtypingimportAny,Awaitable,Callable,List,Literal,Mapping,Protocol,Sequence,TypeVar,get_type_hintsfrom._agent_idimportAgentIdfrom._agent_instantiationimportAgentInstantiationContextfrom._agent_metadataimportAgentMetadatafrom._agent_runtimeimportAgentRuntimefrom._agent_typeimportAgentTypefrom._base_agentimportBaseAgentfrom._cancellation_tokenimportCancellationTokenfrom._message_contextimportMessageContextfrom._serializationimporttry_get_known_serializers_for_typefrom._subscriptionimportSubscriptionfrom._subscription_contextimportSubscriptionInstantiationContextfrom._topicimportTopicIdfrom._type_helpersimportget_typesfrom.exceptionsimportCantHandleExceptionT=TypeVar("T")ClosureAgentType=TypeVar("ClosureAgentType",bound="ClosureAgent")defget_handled_types_from_closure(closure:Callable[[ClosureAgent,T,MessageContext],Awaitable[Any]],)->Sequence[type]:args=inspect.getfullargspec(closure)[0]iflen(args)!=3:raiseAssertionError("Closure must have 4 arguments")message_arg_name=args[1]type_hints=get_type_hints(closure)if"return"notintype_hints:raiseAssertionError("return not found in function signature")# Get the type of the message parametertarget_types=get_types(type_hints[message_arg_name])iftarget_typesisNone:raiseAssertionError("Message type not found")# print(type_hints)return_types=get_types(type_hints["return"])ifreturn_typesisNone:raiseAssertionError("Return type not found")returntarget_types
class ClosureAgent(BaseAgent, ClosureContext):
    """An agent whose message handling is delegated to a user-supplied closure.

    Instances are meant to be created by the runtime through the factory that
    ``register_closure`` installs. The closure is called as
    ``closure(self, message, ctx)``.
    """

    def __init__(
        self,
        description: str,
        closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
        *,
        unknown_type_policy: Literal["error", "warn", "ignore"] = "warn",
    ) -> None:
        """Create the agent from the current agent-instantiation context.

        Args:
            description: Human-readable description of the agent.
            closure: Handler invoked for each accepted message.
            unknown_type_policy: How to treat messages whose type is not
                declared by the closure (see ``on_message_impl``).

        Raises:
            RuntimeError: If the runtime or agent id cannot be read from
                AgentInstantiationContext, for example when constructed directly.
            AssertionError: If get_handled_types_from_closure rejects the closure.
        """
        try:
            runtime = AgentInstantiationContext.current_runtime()
            agent_id = AgentInstantiationContext.current_agent_id()
        except Exception as e:
            raise RuntimeError(
                "ClosureAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
            ) from e

        self._runtime: AgentRuntime = runtime
        self._id: AgentId = agent_id
        self._description = description
        # Types accepted without consulting unknown_type_policy.
        self._expected_types = get_handled_types_from_closure(closure)
        self._closure = closure
        self._unknown_type_policy = unknown_type_policy
        super().__init__(description)

    @property
    def metadata(self) -> AgentMetadata:
        """Metadata built from this agent's id and description."""
        assert self._id is not None
        return AgentMetadata(
            key=self._id.key,
            type=self._id.type,
            description=self._description,
        )

    @property
    def id(self) -> AgentId:
        """The agent id captured from the instantiation context."""
        return self._id

    @property
    def runtime(self) -> AgentRuntime:
        """The runtime captured from the instantiation context."""
        return self._runtime

    async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any:
        """Dispatch ``message`` to the closure after applying ``unknown_type_policy``.

        The check uses the exact type, ``type(message)``, not ``isinstance``.
        When that type is not among the closure's declared types:

        - ``"warn"``: emit a warning and drop the message (return ``None``).
        - ``"error"``: raise CantHandleException.
        - ``"ignore"``: no warning; the message is still passed to the closure.

        Returns:
            Whatever the closure returns, or ``None`` when the message is dropped.
        """
        if type(message) not in self._expected_types:
            if self._unknown_type_policy == "warn":
                warnings.warn(
                    f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'error' to raise an exception, or 'ignore' to suppress this warning.",
                    stacklevel=1,
                )
                return None
            elif self._unknown_type_policy == "error":
                raise CantHandleException(
                    f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'warn' to suppress this exception, or 'ignore' to suppress this warning."
                )

        return await self._closure(self, message, ctx)

    async def save_state(self) -> Mapping[str, Any]:
        """Closure agents do not have state. So this method always returns an empty dictionary."""
        return {}

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Closure agents do not have state. So this method does nothing."""
        pass

    @classmethod
    async def register_closure(
        cls,
        runtime: AgentRuntime,
        type: str,
        closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
        *,
        unknown_type_policy: Literal["error", "warn", "ignore"] = "warn",
        skip_direct_message_subscription: bool = False,
        description: str = "",
        subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None,
    ) -> AgentType:
        """The closure agent allows you to define an agent using a closure, or function, without needing to define a class. It allows values to be extracted out of the runtime.

        The closure can define the type of message which is expected, or `Any` can be used to accept any type of message.

        Example:

        .. code-block:: python

            import asyncio
            from autogen_core import SingleThreadedAgentRuntime, MessageContext, ClosureAgent, ClosureContext
            from dataclasses import dataclass

            from autogen_core._default_subscription import DefaultSubscription
            from autogen_core._default_topic import DefaultTopicId


            @dataclass
            class MyMessage:
                content: str


            async def main():
                queue = asyncio.Queue[MyMessage]()

                async def output_result(_ctx: ClosureContext, message: MyMessage, ctx: MessageContext) -> None:
                    await queue.put(message)

                runtime = SingleThreadedAgentRuntime()
                await ClosureAgent.register_closure(
                    runtime, "output_result", output_result, subscriptions=lambda: [DefaultSubscription()]
                )

                runtime.start()
                await runtime.publish_message(MyMessage("Hello, world!"), DefaultTopicId())
                await runtime.stop_when_idle()

                result = await queue.get()
                print(result)


            asyncio.run(main())

        Args:
            runtime (AgentRuntime): Runtime to register the agent to.
            type (str): Agent type of the registered agent.
            closure (Callable[[ClosureContext, T, MessageContext], Awaitable[Any]]): Closure to handle messages.
            unknown_type_policy (Literal["error", "warn", "ignore"], optional): What to do if a type is encountered that does not match the closure type. Defaults to "warn".
            skip_direct_message_subscription (bool, optional): Do not add a direct message subscription for this agent. Defaults to False.
            description (str, optional): Description of what the agent does. Defaults to "".
            subscriptions (Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None, optional): Callable (sync or async) returning the subscriptions for this closure agent. Defaults to None.

        Returns:
            AgentType: Type of the agent that was registered.

        Raises:
            AssertionError: If get_handled_types_from_closure rejects the closure,
                or if the class defines class-level subscriptions.
        """
        # Validate the closure before touching the runtime. An invalid signature then
        # fails without leaving a registered agent type or subscriptions behind.
        handled_types = get_handled_types_from_closure(closure)

        def factory() -> ClosureAgent:
            return ClosureAgent(description=description, closure=closure, unknown_type_policy=unknown_type_policy)

        # Raise explicitly rather than via `assert` so the check survives `python -O`.
        if len(cls._unbound_subscriptions()) != 0:
            raise AssertionError("Closure agents are expected to have no class subscriptions")

        agent_type = await cls.register(
            runtime=runtime,
            type=type,
            factory=factory,  # type: ignore
            # There should be no need to process class subscriptions, as the closure agent does not have any subscriptions.
            skip_class_subscriptions=True,
            skip_direct_message_subscription=skip_direct_message_subscription,
        )

        subscriptions_list: List[Subscription] = []
        if subscriptions is not None:
            # Await inside the context so an async factory runs while the context is
            # still populated (presumably it reads the agent type from it; confirm).
            with SubscriptionInstantiationContext.populate_context(agent_type):
                subscriptions_list_result = subscriptions()
                if inspect.isawaitable(subscriptions_list_result):
                    subscriptions_list.extend(await subscriptions_list_result)
                else:
                    # just ignore mypy here
                    subscriptions_list.extend(subscriptions_list_result)  # type: ignore

        for subscription in subscriptions_list:
            await runtime.add_subscription(subscription)

        for message_type in handled_types:
            # TODO: support custom serializers
            serializer = try_get_known_serializers_for_type(message_type)
            runtime.add_message_serializer(serializer)

        return agent_type