Source code for autogen_ext.runtimes.grpc._worker_runtime_host_servicer

import asyncio
import logging
from _collections_abc import AsyncIterator, Iterator
from asyncio import Future, Task
from typing import Any, Dict, Set, cast

from autogen_core import Subscription, TopicId, TypePrefixSubscription, TypeSubscription
from autogen_core._runtime_impl_helpers import SubscriptionManager

from ._constants import GRPC_IMPORT_ERROR_STR

try:
    import grpc
except ImportError as e:
    raise ImportError(GRPC_IMPORT_ERROR_STR) from e

from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2

# Module loggers: the shared "autogen_core" logger used for connection and
# message diagnostics below, and its "autogen_core.events" child logger
# (not referenced in the visible part of this module).
logger, event_logger = logging.getLogger("autogen_core"), logging.getLogger("autogen_core.events")


[docs] class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer): """A gRPC servicer that hosts message delivery service for agents.""" def __init__(self) -> None: self._client_id = 0 self._client_id_lock = asyncio.Lock() self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {} self._agent_type_to_client_id_lock = asyncio.Lock() self._agent_type_to_client_id: Dict[str, int] = {} self._pending_responses: Dict[int, Dict[str, Future[Any]]] = {} self._background_tasks: Set[Task[Any]] = set() self._subscription_manager = SubscriptionManager() self._client_id_to_subscription_id_mapping: Dict[int, set[str]] = {}
[docs] async def OpenChannel( # type: ignore self, request_iterator: AsyncIterator[agent_worker_pb2.Message], context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message], ) -> Iterator[agent_worker_pb2.Message] | AsyncIterator[agent_worker_pb2.Message]: # type: ignore # Aquire the lock to get a new client id. async with self._client_id_lock: self._client_id += 1 client_id = self._client_id # Register the client with the server and create a send queue for the client. send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue() self._send_queues[client_id] = send_queue logger.info(f"Client {client_id} connected.") try: # Concurrently handle receiving messages from the client and sending messages to the client. # This task will receive messages from the client. receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator)) # Return an async generator that will yield messages from the send queue to the client. while True: message = await send_queue.get() # Yield the message to the client. try: yield message except Exception as e: logger.error(f"Failed to send message to client {client_id}: {e}", exc_info=True) break logger.info(f"Sent message to client {client_id}: {message}") # Wait for the receiving task to finish. await receiving_task finally: # Clean up the client connection. del self._send_queues[client_id] # Cancel pending requests sent to this client. for future in self._pending_responses.pop(client_id, {}).values(): future.cancel() # Remove the client id from the agent type to client id mapping. await self._on_client_disconnect(client_id)
async def _on_client_disconnect(self, client_id: int) -> None: async with self._agent_type_to_client_id_lock: agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id] for agent_type in agent_types: logger.info(f"Removing agent type {agent_type} from agent type to client id mapping") del self._agent_type_to_client_id[agent_type] for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, set()): logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}") await self._subscription_manager.remove_subscription(sub_id) logger.info(f"Client {client_id} disconnected successfully") def _raise_on_exception(self, task: Task[Any]) -> None: exception = task.exception() if exception is not None: raise exception async def _receive_messages( self, client_id: int, request_iterator: AsyncIterator[agent_worker_pb2.Message] ) -> None: # Receive messages from the client and process them. async for message in request_iterator: logger.info(f"Received message from client {client_id}: {message}") oneofcase = message.WhichOneof("message") match oneofcase: case "request": request: agent_worker_pb2.RpcRequest = message.request task = asyncio.create_task(self._process_request(request, client_id)) self._background_tasks.add(task) task.add_done_callback(self._raise_on_exception) task.add_done_callback(self._background_tasks.discard) case "response": response: agent_worker_pb2.RpcResponse = message.response task = asyncio.create_task(self._process_response(response, client_id)) self._background_tasks.add(task) task.add_done_callback(self._raise_on_exception) task.add_done_callback(self._background_tasks.discard) case "cloudEvent": # The proto typing doesnt resolve this one event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore task = asyncio.create_task(self._process_event(event)) self._background_tasks.add(task) task.add_done_callback(self._raise_on_exception) 
task.add_done_callback(self._background_tasks.discard) case "registerAgentTypeRequest": register_agent_type: agent_worker_pb2.RegisterAgentTypeRequest = message.registerAgentTypeRequest task = asyncio.create_task( self._process_register_agent_type_request(register_agent_type, client_id) ) self._background_tasks.add(task) task.add_done_callback(self._raise_on_exception) task.add_done_callback(self._background_tasks.discard) case "addSubscriptionRequest": add_subscription: agent_worker_pb2.AddSubscriptionRequest = message.addSubscriptionRequest task = asyncio.create_task(self._process_add_subscription_request(add_subscription, client_id)) self._background_tasks.add(task) task.add_done_callback(self._raise_on_exception) task.add_done_callback(self._background_tasks.discard) case "registerAgentTypeResponse" | "addSubscriptionResponse": logger.warning(f"Received unexpected message type: {oneofcase}") case None: logger.warning("Received empty message") async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None: # Deliver the message to a client given the target agent type. async with self._agent_type_to_client_id_lock: target_client_id = self._agent_type_to_client_id.get(request.target.type) if target_client_id is None: logger.error(f"Agent {request.target.type} not found, failed to deliver message.") return target_send_queue = self._send_queues.get(target_client_id) if target_send_queue is None: logger.error(f"Client {target_client_id} not found, failed to deliver message.") return await target_send_queue.put(agent_worker_pb2.Message(request=request)) # Create a future to wait for the response from the target. future = asyncio.get_event_loop().create_future() self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future # Create a task to wait for the response and send it back to the client. 
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id)) self._background_tasks.add(send_response_task) send_response_task.add_done_callback(self._raise_on_exception) send_response_task.add_done_callback(self._background_tasks.discard) async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None: response = await future message = agent_worker_pb2.Message(response=response) send_queue = self._send_queues.get(client_id) if send_queue is None: logger.error(f"Client {client_id} not found, failed to send response message.") return await send_queue.put(message) async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None: # Setting the result of the future will send the response back to the original sender. future = self._pending_responses[client_id].pop(response.request_id) future.set_result(response) async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: topic_id = TopicId(type=event.type, source=event.source) recipients = await self._subscription_manager.get_subscribed_recipients(topic_id) # Get the client ids of the recipients. async with self._agent_type_to_client_id_lock: client_ids: Set[int] = set() for recipient in recipients: client_id = self._agent_type_to_client_id.get(recipient.type) if client_id is not None: client_ids.add(client_id) else: logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.") # Deliver the event to clients. for client_id in client_ids: await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event)) async def _process_register_agent_type_request( self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: int ) -> None: # Register the agent type with the host runtime. 
async with self._agent_type_to_client_id_lock: if register_agent_type_req.type in self._agent_type_to_client_id: existing_client_id = self._agent_type_to_client_id[register_agent_type_req.type] logger.error( f"Agent type {register_agent_type_req.type} already registered with client {existing_client_id}." ) success = False error = f"Agent type {register_agent_type_req.type} already registered." else: self._agent_type_to_client_id[register_agent_type_req.type] = client_id success = True error = None # Send a response back to the client. await self._send_queues[client_id].put( agent_worker_pb2.Message( registerAgentTypeResponse=agent_worker_pb2.RegisterAgentTypeResponse( request_id=register_agent_type_req.request_id, success=success, error=error ) ) ) async def _process_add_subscription_request( self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: int ) -> None: oneofcase = add_subscription_req.subscription.WhichOneof("subscription") subscription: Subscription | None = None match oneofcase: case "typeSubscription": type_subscription_msg: agent_worker_pb2.TypeSubscription = ( add_subscription_req.subscription.typeSubscription ) subscription = TypeSubscription( topic_type=type_subscription_msg.topic_type, agent_type=type_subscription_msg.agent_type ) case "typePrefixSubscription": type_prefix_subscription_msg: agent_worker_pb2.TypePrefixSubscription = ( add_subscription_req.subscription.typePrefixSubscription ) subscription = TypePrefixSubscription( topic_type_prefix=type_prefix_subscription_msg.topic_type_prefix, agent_type=type_prefix_subscription_msg.agent_type, ) case None: logger.warning("Received empty subscription message") if subscription is not None: try: await self._subscription_manager.add_subscription(subscription) subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set()) subscription_ids.add(subscription.id) success = True error = None except ValueError as e: success = False error = str(e) # Send a 
response back to the client. await self._send_queues[client_id].put( agent_worker_pb2.Message( addSubscriptionResponse=agent_worker_pb2.AddSubscriptionResponse( request_id=add_subscription_req.request_id, success=success, error=error ) ) )
[docs] async def GetState( # type: ignore self, request: agent_worker_pb2.AgentId, context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.GetStateResponse], ) -> agent_worker_pb2.GetStateResponse: # type: ignore raise NotImplementedError("Method not implemented!")
[docs] async def SaveState( # type: ignore self, request: agent_worker_pb2.AgentState, context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.SaveStateResponse], ) -> agent_worker_pb2.SaveStateResponse: # type: ignore raise NotImplementedError("Method not implemented!")