Source code for autogen_ext.runtimes.grpc._worker_runtime_host_servicer
import asyncio
import logging
from _collections_abc import AsyncIterator
from asyncio import Future, Task
from typing import Any, Dict, Sequence, Set, Tuple

from autogen_core import TopicId
from autogen_core._runtime_impl_helpers import SubscriptionManager

from ._constants import GRPC_IMPORT_ERROR_STR
from ._utils import subscription_from_proto

# grpc is an optional dependency: re-raise with the package's own hint message
# while keeping the original ImportError as the cause.
try:
    import grpc
except ImportError as e:
    raise ImportError(GRPC_IMPORT_ERROR_STR) from e

from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2

logger = logging.getLogger("autogen_core")
event_logger = logging.getLogger("autogen_core.events")

# Alias used for the value of the "client-id" metadata entry of a connection.
ClientConnectionId = str


def metadata_to_dict(metadata: Sequence[Tuple[str, str]] | None) -> Dict[str, str]:
    """Convert a sequence of ``(key, value)`` metadata pairs into a dict.

    Args:
        metadata: The metadata pairs, or ``None``.

    Returns:
        A dict built from the pairs (a later pair overwrites an earlier one
        with the same key); an empty dict when ``metadata`` is ``None``.
    """
    pairs: Dict[str, str] = {}
    if metadata is not None:
        for name, content in metadata:
            pairs[name] = content
    return pairs


async def get_client_id_or_abort(context: grpc.aio.ServicerContext[Any, Any]) -> str:  # type: ignore
    """Return the ``client-id`` invocation-metadata value of an RPC.

    When the entry is missing, ``context.abort`` is awaited with
    ``INVALID_ARGUMENT``.

    Args:
        context: The servicer context of the RPC being handled.

    Returns:
        The value stored under the ``"client-id"`` metadata key.
    """
    # The type hint on context.invocation_metadata() is incorrect.
    headers = metadata_to_dict(context.invocation_metadata())  # type: ignore
    client_id = headers.get("client-id")
    if client_id is None:
        # NOTE(review): abort() is expected to raise and end the RPC, so the
        # return below should only be reached with a real id - confirm against grpc.aio.
        await context.abort(grpc.StatusCode.INVALID_ARGUMENT, "client-id metadata not found.")

    return client_id  # type: ignore
[docs]classGrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):"""A gRPC servicer that hosts message delivery service for agents."""def__init__(self)->None:self._send_queues:Dict[ClientConnectionId,asyncio.Queue[agent_worker_pb2.Message]]={}self._agent_type_to_client_id_lock=asyncio.Lock()self._agent_type_to_client_id:Dict[str,ClientConnectionId]={}self._pending_responses:Dict[ClientConnectionId,Dict[str,Future[Any]]]={}self._background_tasks:Set[Task[Any]]=set()self._subscription_manager=SubscriptionManager()self._client_id_to_subscription_id_mapping:Dict[ClientConnectionId,set[str]]={}
[docs]asyncdefOpenChannel(# type: ignoreself,request_iterator:AsyncIterator[agent_worker_pb2.Message],context:grpc.aio.ServicerContext[agent_worker_pb2.Message,agent_worker_pb2.Message],)->AsyncIterator[agent_worker_pb2.Message]:client_id=awaitget_client_id_or_abort(context)# Register the client with the server and create a send queue for the client.send_queue:asyncio.Queue[agent_worker_pb2.Message]=asyncio.Queue()self._send_queues[client_id]=send_queuelogger.info(f"Client {client_id} connected.")try:# Concurrently handle receiving messages from the client and sending messages to the client.# This task will receive messages from the client.receiving_task=asyncio.create_task(self._receive_messages(client_id,request_iterator))# Return an async generator that will yield messages from the send queue to the client.whileTrue:message=awaitsend_queue.get()# Yield the message to the client.try:yieldmessageexceptExceptionase:logger.error(f"Failed to send message to client {client_id}: {e}",exc_info=True)breaklogger.info(f"Sent message to client {client_id}: {message}")# Wait for the receiving task to finish.awaitreceiving_taskfinally:# Clean up the client connection.delself._send_queues[client_id]# Cancel pending requests sent to this client.forfutureinself._pending_responses.pop(client_id,{}).values():future.cancel()# Remove the client id from the agent type to client id mapping.awaitself._on_client_disconnect(client_id)
asyncdef_on_client_disconnect(self,client_id:ClientConnectionId)->None:asyncwithself._agent_type_to_client_id_lock:agent_types=[agent_typeforagent_type,id_inself._agent_type_to_client_id.items()ifid_==client_id]foragent_typeinagent_types:logger.info(f"Removing agent type {agent_type} from agent type to client id mapping")delself._agent_type_to_client_id[agent_type]forsub_idinself._client_id_to_subscription_id_mapping.get(client_id,set()):logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}")awaitself._subscription_manager.remove_subscription(sub_id)logger.info(f"Client {client_id} disconnected successfully")def_raise_on_exception(self,task:Task[Any])->None:exception=task.exception()ifexceptionisnotNone:raiseexceptionasyncdef_receive_messages(self,client_id:ClientConnectionId,request_iterator:AsyncIterator[agent_worker_pb2.Message])->None:# Receive messages from the client and process them.asyncformessageinrequest_iterator:logger.info(f"Received message from client {client_id}: {message}")oneofcase=message.WhichOneof("message")matchoneofcase:case"request":request:agent_worker_pb2.RpcRequest=message.requesttask=asyncio.create_task(self._process_request(request,client_id))self._background_tasks.add(task)task.add_done_callback(self._raise_on_exception)task.add_done_callback(self._background_tasks.discard)case"response":response:agent_worker_pb2.RpcResponse=message.responsetask=asyncio.create_task(self._process_response(response,client_id))self._background_tasks.add(task)task.add_done_callback(self._raise_on_exception)task.add_done_callback(self._background_tasks.discard)case"cloudEvent":task=asyncio.create_task(self._process_event(message.cloudEvent))self._background_tasks.add(task)task.add_done_callback(self._raise_on_exception)task.add_done_callback(self._background_tasks.discard)caseNone:logger.warning("Received empty 
message")asyncdef_process_request(self,request:agent_worker_pb2.RpcRequest,client_id:ClientConnectionId)->None:# Deliver the message to a client given the target agent type.asyncwithself._agent_type_to_client_id_lock:target_client_id=self._agent_type_to_client_id.get(request.target.type)iftarget_client_idisNone:logger.error(f"Agent {request.target.type} not found, failed to deliver message.")returntarget_send_queue=self._send_queues.get(target_client_id)iftarget_send_queueisNone:logger.error(f"Client {target_client_id} not found, failed to deliver message.")returnawaittarget_send_queue.put(agent_worker_pb2.Message(request=request))# Create a future to wait for the response from the target.future=asyncio.get_event_loop().create_future()self._pending_responses.setdefault(target_client_id,{})[request.request_id]=future# Create a task to wait for the response and send it back to the client.send_response_task=asyncio.create_task(self._wait_and_send_response(future,client_id))self._background_tasks.add(send_response_task)send_response_task.add_done_callback(self._raise_on_exception)send_response_task.add_done_callback(self._background_tasks.discard)asyncdef_wait_and_send_response(self,future:Future[agent_worker_pb2.RpcResponse],client_id:ClientConnectionId)->None:response=awaitfuturemessage=agent_worker_pb2.Message(response=response)send_queue=self._send_queues.get(client_id)ifsend_queueisNone:logger.error(f"Client {client_id} not found, failed to send response message.")returnawaitsend_queue.put(message)asyncdef_process_response(self,response:agent_worker_pb2.RpcResponse,client_id:ClientConnectionId)->None:# Setting the result of the future will send the response back to the original 
sender.future=self._pending_responses[client_id].pop(response.request_id)future.set_result(response)asyncdef_process_event(self,event:cloudevent_pb2.CloudEvent)->None:topic_id=TopicId(type=event.type,source=event.source)recipients=awaitself._subscription_manager.get_subscribed_recipients(topic_id)# Get the client ids of the recipients.asyncwithself._agent_type_to_client_id_lock:client_ids:Set[ClientConnectionId]=set()forrecipientinrecipients:client_id=self._agent_type_to_client_id.get(recipient.type)ifclient_idisnotNone:client_ids.add(client_id)else:logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.")# Deliver the event to clients.forclient_idinclient_ids:awaitself._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event))
[docs]asyncdefRegisterAgent(# type: ignoreself,request:agent_worker_pb2.RegisterAgentTypeRequest,context:grpc.aio.ServicerContext[agent_worker_pb2.RegisterAgentTypeRequest,agent_worker_pb2.RegisterAgentTypeResponse],)->agent_worker_pb2.RegisterAgentTypeResponse:client_id=awaitget_client_id_or_abort(context)asyncwithself._agent_type_to_client_id_lock:ifrequest.typeinself._agent_type_to_client_id:existing_client_id=self._agent_type_to_client_id[request.type]awaitcontext.abort(grpc.StatusCode.INVALID_ARGUMENT,f"Agent type {request.type} already registered with client {existing_client_id}.",)else:self._agent_type_to_client_id[request.type]=client_idreturnagent_worker_pb2.RegisterAgentTypeResponse()
[docs]asyncdefRemoveSubscription(# type: ignoreself,request:agent_worker_pb2.RemoveSubscriptionRequest,context:grpc.aio.ServicerContext[agent_worker_pb2.RemoveSubscriptionRequest,agent_worker_pb2.RemoveSubscriptionResponse],)->agent_worker_pb2.RemoveSubscriptionResponse:_client_id=awaitget_client_id_or_abort(context)raiseNotImplementedError("Method not implemented.")
[docs]asyncdefGetSubscriptions(# type: ignoreself,request:agent_worker_pb2.GetSubscriptionsRequest,context:grpc.aio.ServicerContext[agent_worker_pb2.GetSubscriptionsRequest,agent_worker_pb2.GetSubscriptionsResponse],)->agent_worker_pb2.GetSubscriptionsResponse:_client_id=awaitget_client_id_or_abort(context)raiseNotImplementedError("Method not implemented.")
[docs]asyncdefGetState(# type: ignoreself,request:agent_worker_pb2.AgentId,context:grpc.aio.ServicerContext[agent_worker_pb2.AgentId,agent_worker_pb2.GetStateResponse],)->agent_worker_pb2.GetStateResponse:raiseNotImplementedError("Method not implemented!")
[docs]asyncdefSaveState(# type: ignoreself,request:agent_worker_pb2.AgentState,context:grpc.aio.ServicerContext[agent_worker_pb2.AgentId,agent_worker_pb2.SaveStateResponse],)->agent_worker_pb2.SaveStateResponse:raiseNotImplementedError("Method not implemented!")