# Source code for autogen_ext.runtimes.grpc._worker_runtime
from__future__importannotationsimportasyncioimportinspectimportjsonimportloggingimportsignalimportuuidimportwarningsfromasyncioimportFuture,Taskfromcollectionsimportdefaultdictfromtypingimport(TYPE_CHECKING,Any,AsyncIterable,AsyncIterator,Awaitable,Callable,ClassVar,DefaultDict,Dict,List,Literal,Mapping,ParamSpec,Sequence,Set,Tuple,Type,TypeVar,cast,)fromautogen_coreimport(JSON_DATA_CONTENT_TYPE,PROTOBUF_DATA_CONTENT_TYPE,Agent,AgentId,AgentInstantiationContext,AgentMetadata,AgentRuntime,AgentType,CancellationToken,MessageContext,MessageHandlerContext,MessageSerializer,Subscription,TopicId,)fromautogen_core._runtime_impl_helpersimportSubscriptionManager,get_implfromautogen_core._serializationimport(SerializationRegistry,)fromautogen_core._telemetryimportMessageRuntimeTracingConfig,TraceHelper,get_telemetry_grpc_metadatafromgoogle.protobufimportany_pb2fromopentelemetry.traceimportTracerProviderfromtyping_extensionsimportSelffromautogen_ext.runtimes.grpc._utilsimportsubscription_to_protofrom.import_constantsfrom._constantsimportGRPC_IMPORT_ERROR_STRfrom._type_helpersimportChannelArgumentTypefrom.protosimportagent_worker_pb2,agent_worker_pb2_grpc,cloudevent_pb2try:importgrpc.aioexceptImportErrorase:raiseImportError(GRPC_IMPORT_ERROR_STR)fromeifTYPE_CHECKING:from.protos.agent_worker_pb2_grpcimportAgentRpcAsyncStublogger=logging.getLogger("autogen_core")event_logger=logging.getLogger("autogen_core.events")P=ParamSpec("P")T=TypeVar("T",bound=Agent)type_func_alias=typeclassQueueAsyncIterable(AsyncIterator[Any],AsyncIterable[Any]):def__init__(self,queue:asyncio.Queue[Any])->None:self._queue=queueasyncdef__anext__(self)->Any:returnawaitself._queue.get()def__aiter__(self)->AsyncIterator[Any]:returnselfclassHostConnection:DEFAULT_GRPC_CONFIG:ClassVar[ChannelArgumentType]=[("grpc.service_config",json.dumps({"methodConfig":[{"name":[{}],"retryPolicy":{"maxAttempts":3,"initialBackoff":"0.01s","maxBackoff":"5s","backoffMultiplier":2,"retryableStatusCodes":["UNAVAILABLE"],},}],}),)
]def__init__(self,channel:grpc.aio.Channel,stub:Any)->None:# type: ignoreself._channel=channelself._send_queue=asyncio.Queue[agent_worker_pb2.Message]()self._recv_queue=asyncio.Queue[agent_worker_pb2.Message]()self._connection_task:Task[None]|None=Noneself._stub:AgentRpcAsyncStub=stubself._client_id=str(uuid.uuid4())@propertydefstub(self)->Any:returnself._stub@propertydefmetadata(self)->Sequence[Tuple[str,str]]:return[("client-id",self._client_id)]@classmethodasyncdeffrom_host_address(cls,host_address:str,extra_grpc_config:ChannelArgumentType=DEFAULT_GRPC_CONFIG)->Self:logger.info("Connecting to %s",host_address)# Always use DEFAULT_GRPC_CONFIG and override it with provided grpc_configmerged_options=[(k,v)fork,vin{**dict(HostConnection.DEFAULT_GRPC_CONFIG),**dict(extra_grpc_config)}.items()]channel=grpc.aio.insecure_channel(host_address,options=merged_options,)stub:AgentRpcAsyncStub=agent_worker_pb2_grpc.AgentRpcStub(channel)# type: ignoreinstance=cls(channel,stub)instance._connection_task=awaitinstance._connect(stub,instance._send_queue,instance._recv_queue,instance._client_id)returninstanceasyncdefclose(self)->None:ifself._connection_taskisNone:raiseRuntimeError("Connection is not open.")awaitself._channel.close()awaitself._connection_task@staticmethodasyncdef_connect(stub:Any,# AgentRpcAsyncStubsend_queue:asyncio.Queue[agent_worker_pb2.Message],receive_queue:asyncio.Queue[agent_worker_pb2.Message],client_id:str,)->Task[None]:fromgrpc.aioimportStreamStreamCall# TODO: where do exceptions from reading the iterable go? 
How do we recover from those?stream:StreamStreamCall[agent_worker_pb2.Message,agent_worker_pb2.Message]=stub.OpenChannel(# type: ignoreQueueAsyncIterable(send_queue),metadata=[("client-id",client_id)])awaitstream.wait_for_connection()asyncdefread_loop()->None:whileTrue:logger.info("Waiting for message from host")message=cast(agent_worker_pb2.Message,awaitstream.read())# type: ignoreifmessage==grpc.aio.EOF:# type: ignorelogger.info("EOF")breaklogger.info(f"Received a message from host: {message}")awaitreceive_queue.put(message)logger.info("Put message in receive queue")returnasyncio.create_task(read_loop())asyncdefsend(self,message:agent_worker_pb2.Message)->None:logger.info(f"Send message to host: {message}")awaitself._send_queue.put(message)logger.info("Put message in send queue")asyncdefrecv(self)->agent_worker_pb2.Message:logger.info("Getting message from queue")returnawaitself._recv_queue.get()# TODO: Lots of types need to have protobuf equivalents:# Core:# - FunctionCall, CodeResult, possibly CodeBlock# - All the types in https://github.com/microsoft/autogen/blob/main/python/packages/autogen-core/src/autogen_core/models/_types.py## Agentchat:# - All the types in https://github.com/microsoft/autogen/blob/main/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py to protobufs.## Ext --# CodeExecutor:# - CommandLineCodeResult
class GrpcWorkerAgentRuntime(AgentRuntime):
    """An agent runtime for running remote or cross-language agents.

    Agent messaging uses protobufs from `agent_worker.proto`_ and ``CloudEvent`` from `cloudevent.proto`_.

    Cross-language agents will additionally require all agents use shared protobuf schemas for any message types that are sent between agents.

    Args:
        host_address (str): Address of the gRPC host. The connection is opened by :meth:`start`.
        tracer_provider (TracerProvider | None): Optional OpenTelemetry tracer provider used for message tracing.
        extra_grpc_config (ChannelArgumentType | None): Extra gRPC channel options. They override
            the connection's default options that have the same name.
        payload_serialization_format (str): Content type used for published message payloads.
            Must be ``JSON_DATA_CONTENT_TYPE`` or ``PROTOBUF_DATA_CONTENT_TYPE``.

    Raises:
        ValueError: If ``payload_serialization_format`` is not a supported content type.

    .. _agent_worker.proto: https://github.com/microsoft/autogen/blob/main/protos/agent_worker.proto

    .. _cloudevent.proto: https://github.com/microsoft/autogen/blob/main/protos/cloudevent.proto
    """

    # TODO: Needs to handle agent close() call
    def __init__(
        self,
        host_address: str,
        tracer_provider: TracerProvider | None = None,
        extra_grpc_config: ChannelArgumentType | None = None,
        payload_serialization_format: str = JSON_DATA_CONTENT_TYPE,
    ) -> None:
        # Where to connect, and tracing for messages that pass through this runtime.
        self._host_address = host_address
        self._trace_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("Worker Runtime"))

        # Agent bookkeeping: factories by type name, and instances created from them.
        self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
        self._agent_factories: Dict[
            str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
        ] = {}
        self._instantiated_agents: Dict[AgentId, Agent] = {}
        self._known_namespaces: set[str] = set()

        # Lifecycle state: the read-loop task and whether the runtime is running.
        self._read_task: None | Task[None] = None
        self._running = False

        # RPC bookkeeping: futures of outstanding send_message calls, keyed by request id.
        self._pending_requests: Dict[str, Future[Any]] = {}
        self._pending_requests_lock = asyncio.Lock()
        self._next_request_id = 0

        # Connection and background work; the connection is created in start().
        self._host_connection: HostConnection | None = None
        self._background_tasks: Set[Task[Any]] = set()

        self._subscription_manager = SubscriptionManager()
        self._serialization_registry = SerializationRegistry()
        self._extra_grpc_config = extra_grpc_config or []

        supported_formats = (JSON_DATA_CONTENT_TYPE, PROTOBUF_DATA_CONTENT_TYPE)
        if payload_serialization_format not in supported_formats:
            raise ValueError(f"Unsupported payload serialization format: {payload_serialization_format}")
        self._payload_serialization_format = payload_serialization_format
[docs]asyncdefstart(self)->None:"""Start the runtime in a background task."""ifself._running:raiseValueError("Runtime is already running.")logger.info(f"Connecting to host: {self._host_address}")self._host_connection=awaitHostConnection.from_host_address(self._host_address,extra_grpc_config=self._extra_grpc_config)logger.info("Connection established")ifself._read_taskisNone:self._read_task=asyncio.create_task(self._run_read_loop())self._running=True
def_raise_on_exception(self,task:Task[Any])->None:exception=task.exception()ifexceptionisnotNone:raiseexceptionasyncdef_run_read_loop(self)->None:logger.info("Starting read loop")assertself._host_connectionisnotNone# TODO: catch exceptions and reconnectwhileself._running:try:message=awaitself._host_connection.recv()oneofcase=agent_worker_pb2.Message.WhichOneof(message,"message")matchoneofcase:case"request":task=asyncio.create_task(self._process_request(message.request))self._background_tasks.add(task)task.add_done_callback(self._raise_on_exception)task.add_done_callback(self._background_tasks.discard)case"response":task=asyncio.create_task(self._process_response(message.response))self._background_tasks.add(task)task.add_done_callback(self._raise_on_exception)task.add_done_callback(self._background_tasks.discard)case"cloudEvent":task=asyncio.create_task(self._process_event(message.cloudEvent))self._background_tasks.add(task)task.add_done_callback(self._raise_on_exception)task.add_done_callback(self._background_tasks.discard)caseNone:logger.warning("No message")exceptExceptionase:logger.error("Error in read loop",exc_info=e)
[docs]asyncdefstop(self)->None:"""Stop the runtime immediately."""ifnotself._running:raiseRuntimeError("Runtime is not running.")self._running=False# Wait for all background tasks to finish.final_tasks_results=awaitasyncio.gather(*self._background_tasks,return_exceptions=True)fortask_resultinfinal_tasks_results:ifisinstance(task_result,Exception):logger.error("Error in background task",exc_info=task_result)# Close the host connection.ifself._host_connectionisnotNone:try:awaitself._host_connection.close()exceptasyncio.CancelledError:pass# Cancel the read task.ifself._read_taskisnotNone:self._read_task.cancel()try:awaitself._read_taskexceptasyncio.CancelledError:pass
[docs]asyncdefstop_when_signal(self,signals:Sequence[signal.Signals]=(signal.SIGTERM,signal.SIGINT))->None:"""Stop the runtime when a signal is received."""loop=asyncio.get_running_loop()shutdown_event=asyncio.Event()defsignal_handler()->None:logger.info("Received exit signal, shutting down gracefully...")shutdown_event.set()forsiginsignals:loop.add_signal_handler(sig,signal_handler)# Wait for the signal to trigger the shutdown event.awaitshutdown_event.wait()# Stop the runtime.awaitself.stop()
@propertydef_known_agent_names(self)->Set[str]:returnset(self._agent_factories.keys())asyncdef_send_message(self,runtime_message:agent_worker_pb2.Message,send_type:Literal["send","publish"],recipient:AgentId|TopicId,telemetry_metadata:Mapping[str,str],)->None:ifself._host_connectionisNone:raiseRuntimeError("Host connection is not set.")withself._trace_helper.trace_block(send_type,recipient,parent=telemetry_metadata):awaitself._host_connection.send(runtime_message)
[docs]asyncdefsend_message(self,message:Any,recipient:AgentId,*,sender:AgentId|None=None,cancellation_token:CancellationToken|None=None,message_id:str|None=None,)->Any:# TODO: use message_idifnotself._running:raiseValueError("Runtime must be running when sending message.")ifself._host_connectionisNone:raiseRuntimeError("Host connection is not set.")data_type=self._serialization_registry.type_name(message)withself._trace_helper.trace_block("create",recipient,parent=None,extraAttributes={"message_type":data_type}):# create a new future for the resultfuture=asyncio.get_event_loop().create_future()request_id=awaitself._get_new_request_id()self._pending_requests[request_id]=futureserialized_message=self._serialization_registry.serialize(message,type_name=data_type,data_content_type=JSON_DATA_CONTENT_TYPE)telemetry_metadata=get_telemetry_grpc_metadata()runtime_message=agent_worker_pb2.Message(request=agent_worker_pb2.RpcRequest(request_id=request_id,target=agent_worker_pb2.AgentId(type=recipient.type,key=recipient.key),source=agent_worker_pb2.AgentId(type=sender.type,key=sender.key)ifsenderisnotNoneelseNone,metadata=telemetry_metadata,payload=agent_worker_pb2.Payload(data_type=data_type,data=serialized_message,data_content_type=JSON_DATA_CONTENT_TYPE,),))# TODO: Find a way to handle timeouts/errorstask=asyncio.create_task(self._send_message(runtime_message,"send",recipient,telemetry_metadata))self._background_tasks.add(task)task.add_done_callback(self._raise_on_exception)task.add_done_callback(self._background_tasks.discard)returnawaitfuture
[docs]asyncdefpublish_message(self,message:Any,topic_id:TopicId,*,sender:AgentId|None=None,cancellation_token:CancellationToken|None=None,message_id:str|None=None,)->None:ifnotself._running:raiseValueError("Runtime must be running when publishing message.")ifself._host_connectionisNone:raiseRuntimeError("Host connection is not set.")ifmessage_idisNone:message_id=str(uuid.uuid4())message_type=self._serialization_registry.type_name(message)withself._trace_helper.trace_block("create",topic_id,parent=None,extraAttributes={"message_type":message_type}):serialized_message=self._serialization_registry.serialize(message,type_name=message_type,data_content_type=self._payload_serialization_format)sender_id=senderorAgentId("unknown","unknown")attributes={_constants.DATA_CONTENT_TYPE_ATTR:cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=self._payload_serialization_format),_constants.DATA_SCHEMA_ATTR:cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=message_type),_constants.AGENT_SENDER_TYPE_ATTR:cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=sender_id.type),_constants.AGENT_SENDER_KEY_ATTR:cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=sender_id.key),_constants.MESSAGE_KIND_ATTR:cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=_constants.MESSAGE_KIND_VALUE_PUBLISH),}# If sending JSON we fill text_data with the serialized message# If sending Protobuf we fill proto_data with the serialized message# TODO: add an encoding field for serializerifself._payload_serialization_format==JSON_DATA_CONTENT_TYPE:runtime_message=agent_worker_pb2.Message(cloudEvent=cloudevent_pb2.CloudEvent(id=message_id,spec_version="1.0",type=topic_id.type,source=topic_id.source,attributes=attributes,# TODO: use text, or proto fields appropriatelybinary_data=serialized_message,))else:# We need to unpack the serialized proto back into an Any# TODO: find a way to prevent the roundtrip 
serializationany_proto=any_pb2.Any()any_proto.ParseFromString(serialized_message)runtime_message=agent_worker_pb2.Message(cloudEvent=cloudevent_pb2.CloudEvent(id=message_id,spec_version="1.0",type=topic_id.type,source=topic_id.source,attributes=attributes,proto_data=any_proto,))telemetry_metadata=get_telemetry_grpc_metadata()task=asyncio.create_task(self._send_message(runtime_message,"publish",topic_id,telemetry_metadata))self._background_tasks.add(task)task.add_done_callback(self._raise_on_exception)task.add_done_callback(self._background_tasks.discard)
[docs]asyncdefsave_state(self)->Mapping[str,Any]:raiseNotImplementedError("Saving state is not yet implemented.")
[docs]asyncdefload_state(self,state:Mapping[str,Any])->None:raiseNotImplementedError("Loading state is not yet implemented.")
[docs]asyncdefagent_metadata(self,agent:AgentId)->AgentMetadata:raiseNotImplementedError("Agent metadata is not yet implemented.")
[docs]asyncdefagent_save_state(self,agent:AgentId)->Mapping[str,Any]:raiseNotImplementedError("Agent save_state is not yet implemented.")
[docs]asyncdefagent_load_state(self,agent:AgentId,state:Mapping[str,Any])->None:raiseNotImplementedError("Agent load_state is not yet implemented.")
asyncdef_get_new_request_id(self)->str:asyncwithself._pending_requests_lock:self._next_request_id+=1returnstr(self._next_request_id)asyncdef_process_request(self,request:agent_worker_pb2.RpcRequest)->None:assertself._host_connectionisnotNonerecipient=AgentId(request.target.type,request.target.key)sender:AgentId|None=Noneifrequest.HasField("source"):sender=AgentId(request.source.type,request.source.key)logging.info(f"Processing request from {sender} to {recipient}")else:logging.info(f"Processing request from unknown source to {recipient}")# Deserialize the message.message=self._serialization_registry.deserialize(request.payload.data,type_name=request.payload.data_type,data_content_type=request.payload.data_content_type,)# Get the receiving agent and prepare the message context.rec_agent=awaitself._get_agent(recipient)message_context=MessageContext(sender=sender,topic_id=None,is_rpc=True,cancellation_token=CancellationToken(),message_id=request.request_id,)# Call the receiving agent.try:withMessageHandlerContext.populate_context(rec_agent.id):withself._trace_helper.trace_block("process",rec_agent.id,parent=request.metadata,attributes={"request_id":request.request_id},extraAttributes={"message_type":request.payload.data_type},):result=awaitrec_agent.on_message(message,ctx=message_context)exceptBaseExceptionase:response_message=agent_worker_pb2.Message(response=agent_worker_pb2.RpcResponse(request_id=request.request_id,error=str(e),metadata=get_telemetry_grpc_metadata(),),)# Send the error response.awaitself._host_connection.send(response_message)return# Serialize the result.result_type=self._serialization_registry.type_name(result)serialized_result=self._serialization_registry.serialize(result,type_name=result_type,data_content_type=JSON_DATA_CONTENT_TYPE)# Create the response 
message.response_message=agent_worker_pb2.Message(response=agent_worker_pb2.RpcResponse(request_id=request.request_id,payload=agent_worker_pb2.Payload(data_type=result_type,data=serialized_result,data_content_type=JSON_DATA_CONTENT_TYPE,),metadata=get_telemetry_grpc_metadata(),))# Send the response.awaitself._host_connection.send(response_message)asyncdef_process_response(self,response:agent_worker_pb2.RpcResponse)->None:withself._trace_helper.trace_block("ack",None,parent=response.metadata,attributes={"request_id":response.request_id},extraAttributes={"message_type":response.payload.data_type},):# Deserialize the result.result=self._serialization_registry.deserialize(response.payload.data,type_name=response.payload.data_type,data_content_type=response.payload.data_content_type,)# Get the future and set the result.future=self._pending_requests.pop(response.request_id)iflen(response.error)>0:future.set_exception(Exception(response.error))else:future.set_result(result)asyncdef_process_event(self,event:cloudevent_pb2.CloudEvent)->None:event_attributes=event.attributessender:AgentId|None=Noneif(_constants.AGENT_SENDER_TYPE_ATTRinevent_attributesand_constants.AGENT_SENDER_KEY_ATTRinevent_attributes):sender=AgentId(event_attributes[_constants.AGENT_SENDER_TYPE_ATTR].ce_string,event_attributes[_constants.AGENT_SENDER_KEY_ATTR].ce_string,)topic_id=TopicId(event.type,event.source)# Get the recipients for the topic.recipients=awaitself._subscription_manager.get_subscribed_recipients(topic_id)message_content_type=event_attributes[_constants.DATA_CONTENT_TYPE_ATTR].ce_stringmessage_type=event_attributes[_constants.DATA_SCHEMA_ATTR].ce_stringifmessage_content_type==JSON_DATA_CONTENT_TYPE:message=self._serialization_registry.deserialize(event.binary_data,type_name=message_type,data_content_type=message_content_type)elifmessage_content_type==PROTOBUF_DATA_CONTENT_TYPE:# TODO: find a way to prevent the roundtrip 
serializationproto_binary_data=event.proto_data.SerializeToString()message=self._serialization_registry.deserialize(proto_binary_data,type_name=message_type,data_content_type=message_content_type)else:raiseValueError(f"Unsupported message content type: {message_content_type}")# TODO: dont read these values in the runtimetopic_type_suffix=topic_id.type.split(":",maxsplit=1)[1]if":"intopic_id.typeelse""is_rpc=topic_type_suffix==_constants.MESSAGE_KIND_VALUE_RPC_REQUESTis_marked_rpc_type=(_constants.MESSAGE_KIND_ATTRinevent_attributesandevent_attributes[_constants.MESSAGE_KIND_ATTR].ce_string==_constants.MESSAGE_KIND_VALUE_RPC_REQUEST)ifis_rpcandnotis_marked_rpc_type:warnings.warn("Received RPC request with topic type suffix but not marked as RPC request.",stacklevel=2)# Send the message to each recipient.responses:List[Awaitable[Any]]=[]foragent_idinrecipients:ifagent_id==sender:continuemessage_context=MessageContext(sender=sender,topic_id=topic_id,is_rpc=is_rpc,cancellation_token=CancellationToken(),message_id=event.id,)agent=awaitself._get_agent(agent_id)withMessageHandlerContext.populate_context(agent.id):defstringify_attributes(attributes:Mapping[str,cloudevent_pb2.CloudEvent.CloudEventAttributeValue],)->Mapping[str,str]:result:Dict[str,str]={}forkey,valueinattributes.items():item=Nonematchvalue.WhichOneof("attr"):case"ce_boolean":item=str(value.ce_boolean)case"ce_integer":item=str(value.ce_integer)case"ce_string":item=value.ce_stringcase"ce_bytes":item=str(value.ce_bytes)case"ce_uri":item=value.ce_uricase"ce_uri_ref":item=value.ce_uri_refcase"ce_timestamp":item=str(value.ce_timestamp)case_:raiseValueError("Unknown attribute 
kind")result[key]=itemreturnresultasyncdefsend_message(agent:Agent,message_context:MessageContext)->Any:withself._trace_helper.trace_block("process",agent.id,parent=stringify_attributes(event.attributes),extraAttributes={"message_type":message_type},):awaitagent.on_message(message,ctx=message_context)future=send_message(agent,message_context)responses.append(future)# Wait for all responses.try:awaitasyncio.gather(*responses)exceptBaseExceptionase:logger.error("Error handling event",exc_info=e)
[docs]asyncdefregister_factory(self,type:str|AgentType,agent_factory:Callable[[],T|Awaitable[T]],*,expected_class:type[T]|None=None,)->AgentType:ifisinstance(type,str):type=AgentType(type)iftype.typeinself._agent_factories:raiseValueError(f"Agent with type {type} already exists.")ifself._host_connectionisNone:raiseRuntimeError("Host connection is not set.")asyncdeffactory_wrapper()->T:maybe_agent_instance=agent_factory()ifinspect.isawaitable(maybe_agent_instance):agent_instance=awaitmaybe_agent_instanceelse:agent_instance=maybe_agent_instanceifexpected_classisnotNoneandtype_func_alias(agent_instance)!=expected_class:raiseValueError("Factory registered using the wrong type.")returnagent_instanceself._agent_factories[type.type]=factory_wrapper# Send the registration request message to the host.message=agent_worker_pb2.RegisterAgentTypeRequest(type=type.type)_response:agent_worker_pb2.RegisterAgentTypeResponse=awaitself._host_connection.stub.RegisterAgent(message,metadata=self._host_connection.metadata)returntype
asyncdef_invoke_agent_factory(self,agent_factory:Callable[[],T|Awaitable[T]]|Callable[[AgentRuntime,AgentId],T|Awaitable[T]],agent_id:AgentId,)->T:withAgentInstantiationContext.populate_context((self,agent_id)):iflen(inspect.signature(agent_factory).parameters)==0:factory_one=cast(Callable[[],T],agent_factory)agent=factory_one()eliflen(inspect.signature(agent_factory).parameters)==2:warnings.warn("Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",stacklevel=2,)factory_two=cast(Callable[[AgentRuntime,AgentId],T],agent_factory)agent=factory_two(self,agent_id)else:raiseValueError("Agent factory must take 0 or 2 arguments.")ifinspect.isawaitable(agent):returncast(T,awaitagent)returnagentasyncdef_get_agent(self,agent_id:AgentId)->Agent:ifagent_idinself._instantiated_agents:returnself._instantiated_agents[agent_id]ifagent_id.typenotinself._agent_factories:raiseValueError(f"Agent with name {agent_id.type} not found.")agent_factory=self._agent_factories[agent_id.type]agent=awaitself._invoke_agent_factory(agent_factory,agent_id)self._instantiated_agents[agent_id]=agentreturnagent# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
[docs]asyncdeftry_get_underlying_agent_instance(self,id:AgentId,type:Type[T]=Agent)->T:# type: ignore[assignment]ifid.typenotinself._agent_factories:raiseLookupError(f"Agent with name {id.type} not found.")# TODO: check if remoteagent_instance=awaitself._get_agent(id)ifnotisinstance(agent_instance,type):raiseTypeError(f"Agent with name {id.type} is not of type {type.__name__}")returnagent_instance
[docs]asyncdefadd_subscription(self,subscription:Subscription)->None:ifself._host_connectionisNone:raiseRuntimeError("Host connection is not set.")message=agent_worker_pb2.AddSubscriptionRequest(subscription=subscription_to_proto(subscription))_response:agent_worker_pb2.AddSubscriptionResponse=awaitself._host_connection.stub.AddSubscription(message,metadata=self._host_connection.metadata)# Add to local subscription manager.awaitself._subscription_manager.add_subscription(subscription)
[docs]asyncdefremove_subscription(self,id:str)->None:ifself._host_connectionisNone:raiseRuntimeError("Host connection is not set.")message=agent_worker_pb2.RemoveSubscriptionRequest(id=id)_response:agent_worker_pb2.RemoveSubscriptionResponse=awaitself._host_connection.stub.RemoveSubscription(message,metadata=self._host_connection.metadata)awaitself._subscription_manager.remove_subscription(id)