Source code for autogen_ext.runtimes.grpc._worker_runtime_host_servicer
from __future__ import annotations

import asyncio
import logging
from abc import ABC, abstractmethod
from asyncio import Future, Task
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Sequence, Set, Tuple, TypeVar

from autogen_core import TopicId
from autogen_core._agent_id import AgentId
from autogen_core._runtime_impl_helpers import SubscriptionManager

from ._constants import GRPC_IMPORT_ERROR_STR
from ._utils import subscription_from_proto, subscription_to_proto

# grpc is an optional dependency: re-raise with the package's own install hint.
try:
    import grpc
except ImportError as import_error:
    raise ImportError(GRPC_IMPORT_ERROR_STR) from import_error

from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2

logger = logging.getLogger("autogen_core")
event_logger = logging.getLogger("autogen_core.events")

# Identifier a client supplies in the "client-id" gRPC metadata entry.
ClientConnectionId = str


def metadata_to_dict(metadata: Sequence[Tuple[str, str]] | None) -> Dict[str, str]:
    """Convert gRPC metadata pairs into a plain dictionary.

    Args:
        metadata: A sequence of ``(key, value)`` pairs, or ``None``.

    Returns:
        A dict built from the pairs (a later pair with the same key replaces
        an earlier one); an empty dict when ``metadata`` is ``None``.
    """
    pairs: Dict[str, str] = {}
    if metadata is not None:
        for name, content in metadata:
            pairs[name] = content
    return pairs


async def get_client_id_or_abort(context: grpc.aio.ServicerContext[Any, Any]) -> str:  # type: ignore
    """Return the ``client-id`` metadata value of the call, aborting the RPC if absent.

    Args:
        context: The servicer context of the RPC being handled.

    Returns:
        The value of the ``client-id`` invocation-metadata entry.
    """
    # The type hint on context.invocation_metadata() is incorrect.
    invocation_metadata = metadata_to_dict(context.invocation_metadata())  # type: ignore
    client_id = invocation_metadata.get("client-id")
    if client_id is None:
        # NOTE(review): abort() is expected to raise and end the RPC, so the
        # return below should only be reached with a real id -- confirm
        # against the grpc.aio documentation.
        await context.abort(grpc.StatusCode.INVALID_ARGUMENT, "client-id metadata not found.")
    return client_id  # type: ignore


SendT = TypeVar("SendT")
ReceiveT = TypeVar("ReceiveT")


class ChannelConnection(ABC, Generic[SendT, ReceiveT]):
    """One bidirectional stream with a client.

    Incoming messages are consumed from ``request_iterator`` by a task created
    in the constructor and passed to :meth:`_handle_message`. Outgoing messages
    are placed on an internal queue with :meth:`send` and handed out by
    asynchronously iterating over the connection itself.
    """

    def __init__(self, request_iterator: AsyncIterator[ReceiveT], client_id: str) -> None:
        self._request_iterator = request_iterator
        self._client_id = client_id
        self._send_queue: asyncio.Queue[SendT] = asyncio.Queue()
        # Start draining the inbound stream immediately; asyncio.create_task
        # needs a running event loop, so instances are created inside coroutines.
        receive_coro = self._receive_messages(client_id, request_iterator)
        self._receiving_task = asyncio.create_task(receive_coro)

    async def _receive_messages(
        self, client_id: ClientConnectionId, request_iterator: AsyncIterator[ReceiveT]
    ) -> None:
        """Log every inbound message and forward it to :meth:`_handle_message`."""
        # Receive messages from the client and process them.
        async for message in request_iterator:
            logger.info(f"Received message from client {client_id}: {message}")
            await self._handle_message(message)

    def __aiter__(self) -> AsyncIterator[SendT]:
        """Return ``self``; the connection is its own async iterator."""
        return self

    async def __anext__(self) -> SendT:
        """Wait for and return the next queued outbound message.

        If waiting on the queue fails, the receiving task is awaited before
        the original exception is re-raised.
        """
        try:
            outgoing = await self._send_queue.get()
        except StopAsyncIteration:
            await self._receiving_task
            raise
        except Exception as e:
            logger.error(f"Failed to get message from send queue: {e}", exc_info=True)
            await self._receiving_task
            raise
        else:
            return outgoing

    @abstractmethod
    async def _handle_message(self, message: ReceiveT) -> None:
        """Process one inbound message; implemented by subclasses."""
        pass

    async def send(self, message: SendT) -> None:
        """Queue ``message`` for delivery to the client."""
        await self._send_queue.put(message)


class CallbackChannelConnection(ChannelConnection[SendT, ReceiveT]):
    """A :class:`ChannelConnection` that passes inbound messages to a callback."""

    def __init__(
        self,
        request_iterator: AsyncIterator[ReceiveT],
        client_id: str,
        handle_callback: Callable[[ReceiveT], Awaitable[None]],
    ) -> None:
        # Store the callback first: the base constructor starts the receiving
        # task, which calls _handle_message.
        self._handle_callback = handle_callback
        super().__init__(request_iterator, client_id)

    async def _handle_message(self, message: ReceiveT) -> None:
        """Await the configured callback with ``message``."""
        await self._handle_callback(message)
[docs]classGrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):"""A gRPC servicer that hosts message delivery service for agents."""def__init__(self)->None:self._data_connections:Dict[ClientConnectionId,ChannelConnection[agent_worker_pb2.Message,agent_worker_pb2.Message]]={}self._control_connections:Dict[ClientConnectionId,ChannelConnection[agent_worker_pb2.ControlMessage,agent_worker_pb2.ControlMessage]]={}self._agent_type_to_client_id_lock=asyncio.Lock()self._agent_type_to_client_id:Dict[str,ClientConnectionId]={}self._pending_responses:Dict[ClientConnectionId,Dict[str,Future[Any]]]={}self._background_tasks:Set[Task[Any]]=set()self._subscription_manager=SubscriptionManager()self._client_id_to_subscription_id_mapping:Dict[ClientConnectionId,set[str]]={}
[docs]asyncdefOpenChannel(# type: ignoreself,request_iterator:AsyncIterator[agent_worker_pb2.Message],context:grpc.aio.ServicerContext[agent_worker_pb2.Message,agent_worker_pb2.Message],)->AsyncIterator[agent_worker_pb2.Message]:client_id=awaitget_client_id_or_abort(context)asyncdefhandle_callback(message:agent_worker_pb2.Message)->None:awaitself._receive_message(client_id,message)connection=CallbackChannelConnection[agent_worker_pb2.Message,agent_worker_pb2.Message](request_iterator,client_id,handle_callback=handle_callback)self._data_connections[client_id]=connectionlogger.info(f"Client {client_id} connected.")try:asyncformessageinconnection:yieldmessagefinally:# Clean up the client connection.delself._data_connections[client_id]# Cancel pending requests sent to this client.forfutureinself._pending_responses.pop(client_id,{}).values():future.cancel()# Remove the client id from the agent type to client id mapping.awaitself._on_client_disconnect(client_id)
[docs]asyncdefOpenControlChannel(# type: ignoreself,request_iterator:AsyncIterator[agent_worker_pb2.ControlMessage],context:grpc.aio.ServicerContext[agent_worker_pb2.ControlMessage,agent_worker_pb2.ControlMessage],)->AsyncIterator[agent_worker_pb2.ControlMessage]:client_id=awaitget_client_id_or_abort(context)asyncdefhandle_callback(message:agent_worker_pb2.ControlMessage)->None:awaitself._receive_control_message(client_id,message)connection=CallbackChannelConnection[agent_worker_pb2.ControlMessage,agent_worker_pb2.ControlMessage](request_iterator,client_id,handle_callback=handle_callback)self._control_connections[client_id]=connectionlogger.info(f"Client {client_id} connected.")try:asyncformessageinconnection:yieldmessagefinally:# Clean up the client connection.delself._control_connections[client_id]
asyncdef_on_client_disconnect(self,client_id:ClientConnectionId)->None:asyncwithself._agent_type_to_client_id_lock:agent_types=[agent_typeforagent_type,id_inself._agent_type_to_client_id.items()ifid_==client_id]foragent_typeinagent_types:logger.info(f"Removing agent type {agent_type} from agent type to client id mapping")delself._agent_type_to_client_id[agent_type]forsub_idinself._client_id_to_subscription_id_mapping.get(client_id,set()):logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}")try:awaitself._subscription_manager.remove_subscription(sub_id)# Catch and ignore if the subscription does not exist.exceptValueError:continuelogger.info(f"Client {client_id} disconnected successfully")def_raise_on_exception(self,task:Task[Any])->None:exception=task.exception()ifexceptionisnotNone:raiseexceptionasyncdef_receive_message(self,client_id:ClientConnectionId,message:agent_worker_pb2.Message)->None:logger.info(f"Received message from client {client_id}: {message}")oneofcase=message.WhichOneof("message")matchoneofcase:case"request":request:agent_worker_pb2.RpcRequest=message.requesttask=asyncio.create_task(self._process_request(request,client_id))self._background_tasks.add(task)task.add_done_callback(self._raise_on_exception)task.add_done_callback(self._background_tasks.discard)case"response":response:agent_worker_pb2.RpcResponse=message.responsetask=asyncio.create_task(self._process_response(response,client_id))self._background_tasks.add(task)task.add_done_callback(self._raise_on_exception)task.add_done_callback(self._background_tasks.discard)case"cloudEvent":task=asyncio.create_task(self._process_event(message.cloudEvent))self._background_tasks.add(task)task.add_done_callback(self._raise_on_exception)task.add_done_callback(self._background_tasks.discard)caseNone:logger.warning("Received empty 
message")asyncdef_receive_control_message(self,client_id:ClientConnectionId,message:agent_worker_pb2.ControlMessage)->None:logger.info(f"Received message from client {client_id}: {message}")destination=message.destinationifdestination.startswith("agentid="):agent_id=AgentId.from_str(destination[len("agentid="):])target_client_id=self._agent_type_to_client_id.get(agent_id.type)iftarget_client_idisNone:logger.error(f"Agent client id not found for agent type {agent_id.type}.")returnelifdestination.startswith("clientid="):target_client_id=destination[len("clientid="):]else:logger.error(f"Invalid destination {destination}")returntarget_send_queue=self._control_connections.get(target_client_id)iftarget_send_queueisNone:logger.error(f"Client {target_client_id} not found, failed to deliver message.")returnawaittarget_send_queue.send(message)asyncdef_process_request(self,request:agent_worker_pb2.RpcRequest,client_id:ClientConnectionId)->None:# Deliver the message to a client given the target agent type.asyncwithself._agent_type_to_client_id_lock:target_client_id=self._agent_type_to_client_id.get(request.target.type)iftarget_client_idisNone:logger.error(f"Agent {request.target.type} not found, failed to deliver message.")returntarget_send_queue=self._data_connections.get(target_client_id)iftarget_send_queueisNone:logger.error(f"Client {target_client_id} not found, failed to deliver message.")returnawaittarget_send_queue.send(agent_worker_pb2.Message(request=request))# Create a future to wait for the response from the target.future=asyncio.get_event_loop().create_future()self._pending_responses.setdefault(target_client_id,{})[request.request_id]=future# Create a task to wait for the response and send it back to the 
client.send_response_task=asyncio.create_task(self._wait_and_send_response(future,client_id))self._background_tasks.add(send_response_task)send_response_task.add_done_callback(self._raise_on_exception)send_response_task.add_done_callback(self._background_tasks.discard)asyncdef_wait_and_send_response(self,future:Future[agent_worker_pb2.RpcResponse],client_id:ClientConnectionId)->None:response=awaitfuturemessage=agent_worker_pb2.Message(response=response)send_queue=self._data_connections.get(client_id)ifsend_queueisNone:logger.error(f"Client {client_id} not found, failed to send response message.")returnawaitsend_queue.send(message)asyncdef_process_response(self,response:agent_worker_pb2.RpcResponse,client_id:ClientConnectionId)->None:# Setting the result of the future will send the response back to the original sender.future=self._pending_responses[client_id].pop(response.request_id)future.set_result(response)asyncdef_process_event(self,event:cloudevent_pb2.CloudEvent)->None:topic_id=TopicId(type=event.type,source=event.source)recipients=awaitself._subscription_manager.get_subscribed_recipients(topic_id)# Get the client ids of the recipients.asyncwithself._agent_type_to_client_id_lock:client_ids:Set[ClientConnectionId]=set()forrecipientinrecipients:client_id=self._agent_type_to_client_id.get(recipient.type)ifclient_idisnotNone:client_ids.add(client_id)else:logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.")# Deliver the event to clients.forclient_idinclient_ids:awaitself._data_connections[client_id].send(agent_worker_pb2.Message(cloudEvent=event))
[docs]asyncdefRegisterAgent(# type: ignoreself,request:agent_worker_pb2.RegisterAgentTypeRequest,context:grpc.aio.ServicerContext[agent_worker_pb2.RegisterAgentTypeRequest,agent_worker_pb2.RegisterAgentTypeResponse],)->agent_worker_pb2.RegisterAgentTypeResponse:client_id=awaitget_client_id_or_abort(context)asyncwithself._agent_type_to_client_id_lock:ifrequest.typeinself._agent_type_to_client_id:existing_client_id=self._agent_type_to_client_id[request.type]awaitcontext.abort(grpc.StatusCode.INVALID_ARGUMENT,f"Agent type {request.type} already registered with client {existing_client_id}.",)else:self._agent_type_to_client_id[request.type]=client_idreturnagent_worker_pb2.RegisterAgentTypeResponse()