import asyncio
import logging
from _collections_abc import AsyncIterator
from asyncio import Future, Task
from typing import Any, Dict, Sequence, Set, Tuple
from autogen_core import TopicId
from autogen_core._runtime_impl_helpers import SubscriptionManager
from ._constants import GRPC_IMPORT_ERROR_STR
from ._utils import subscription_from_proto
try:
import grpc
except ImportError as e:
raise ImportError(GRPC_IMPORT_ERROR_STR) from e
from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
# Logger used for the runtime host's general diagnostics; registered under the name "autogen_core".
logger = logging.getLogger("autogen_core")
# Logger registered under the name "autogen_core.events"; it is not referenced by the code visible in this file.
event_logger = logging.getLogger("autogen_core.events")
# Alias for readability: a client connection id is a plain string (read from the "client-id" request metadata).
ClientConnectionId = str
def metadata_to_dict(metadata: Sequence[Tuple[str, str]] | None) -> Dict[str, str]:
if metadata is None:
return {}
return {key: value for key, value in metadata}
async def get_client_id_or_abort(context: grpc.aio.ServicerContext[Any, Any]) -> str:  # type: ignore
    """Return the caller's ``client-id`` metadata value, aborting the RPC if it is missing.

    Args:
        context: The gRPC servicer context of the call being handled.

    Returns:
        The value stored under the ``"client-id"`` key of the invocation metadata.

    Note:
        When the key is absent, ``context.abort`` is awaited with
        ``INVALID_ARGUMENT``. That call is expected to raise and terminate the
        RPC, so the final ``return`` should only be reached with a real id
        (NOTE(review): confirm against the grpc.aio version in use).
    """
    # The type hint on context.invocation_metadata() is incorrect.
    invocation_metadata = metadata_to_dict(context.invocation_metadata())  # type: ignore
    client_id = invocation_metadata.get("client-id")
    if client_id is None:
        await context.abort(grpc.StatusCode.INVALID_ARGUMENT, "client-id metadata not found.")
    return client_id  # type: ignore
# [docs]
class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
"""A gRPC servicer that hosts message delivery service for agents."""
def __init__(self) -> None:
self._send_queues: Dict[ClientConnectionId, asyncio.Queue[agent_worker_pb2.Message]] = {}
self._agent_type_to_client_id_lock = asyncio.Lock()
self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {}
self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {}
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._client_id_to_subscription_id_mapping: Dict[ClientConnectionId, set[str]] = {}
[docs]
async def OpenChannel( # type: ignore
self,
request_iterator: AsyncIterator[agent_worker_pb2.Message],
context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message],
) -> AsyncIterator[agent_worker_pb2.Message]:
client_id = await get_client_id_or_abort(context)
# Register the client with the server and create a send queue for the client.
send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue()
self._send_queues[client_id] = send_queue
logger.info(f"Client {client_id} connected.")
try:
# Concurrently handle receiving messages from the client and sending messages to the client.
# This task will receive messages from the client.
receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator))
# Return an async generator that will yield messages from the send queue to the client.
while True:
message = await send_queue.get()
# Yield the message to the client.
try:
yield message
except Exception as e:
logger.error(f"Failed to send message to client {client_id}: {e}", exc_info=True)
break
logger.info(f"Sent message to client {client_id}: {message}")
# Wait for the receiving task to finish.
await receiving_task
finally:
# Clean up the client connection.
del self._send_queues[client_id]
# Cancel pending requests sent to this client.
for future in self._pending_responses.pop(client_id, {}).values():
future.cancel()
# Remove the client id from the agent type to client id mapping.
await self._on_client_disconnect(client_id)
async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None:
async with self._agent_type_to_client_id_lock:
agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id]
for agent_type in agent_types:
logger.info(f"Removing agent type {agent_type} from agent type to client id mapping")
del self._agent_type_to_client_id[agent_type]
for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, set()):
logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}")
await self._subscription_manager.remove_subscription(sub_id)
logger.info(f"Client {client_id} disconnected successfully")
def _raise_on_exception(self, task: Task[Any]) -> None:
exception = task.exception()
if exception is not None:
raise exception
async def _receive_messages(
self, client_id: ClientConnectionId, request_iterator: AsyncIterator[agent_worker_pb2.Message]
) -> None:
# Receive messages from the client and process them.
async for message in request_iterator:
logger.info(f"Received message from client {client_id}: {message}")
oneofcase = message.WhichOneof("message")
match oneofcase:
case "request":
request: agent_worker_pb2.RpcRequest = message.request
task = asyncio.create_task(self._process_request(request, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "response":
response: agent_worker_pb2.RpcResponse = message.response
task = asyncio.create_task(self._process_response(response, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "cloudEvent":
task = asyncio.create_task(self._process_event(message.cloudEvent))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warning("Received empty message")
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None:
# Deliver the message to a client given the target agent type.
async with self._agent_type_to_client_id_lock:
target_client_id = self._agent_type_to_client_id.get(request.target.type)
if target_client_id is None:
logger.error(f"Agent {request.target.type} not found, failed to deliver message.")
return
target_send_queue = self._send_queues.get(target_client_id)
if target_send_queue is None:
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
return
await target_send_queue.put(agent_worker_pb2.Message(request=request))
# Create a future to wait for the response from the target.
future = asyncio.get_event_loop().create_future()
self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future
# Create a task to wait for the response and send it back to the client.
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id))
self._background_tasks.add(send_response_task)
send_response_task.add_done_callback(self._raise_on_exception)
send_response_task.add_done_callback(self._background_tasks.discard)
async def _wait_and_send_response(
self, future: Future[agent_worker_pb2.RpcResponse], client_id: ClientConnectionId
) -> None:
response = await future
message = agent_worker_pb2.Message(response=response)
send_queue = self._send_queues.get(client_id)
if send_queue is None:
logger.error(f"Client {client_id} not found, failed to send response message.")
return
await send_queue.put(message)
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: ClientConnectionId) -> None:
# Setting the result of the future will send the response back to the original sender.
future = self._pending_responses[client_id].pop(response.request_id)
future.set_result(response)
async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
topic_id = TopicId(type=event.type, source=event.source)
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
# Get the client ids of the recipients.
async with self._agent_type_to_client_id_lock:
client_ids: Set[ClientConnectionId] = set()
for recipient in recipients:
client_id = self._agent_type_to_client_id.get(recipient.type)
if client_id is not None:
client_ids.add(client_id)
else:
logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.")
# Deliver the event to clients.
for client_id in client_ids:
await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event))
[docs]
async def RegisterAgent( # type: ignore
self,
request: agent_worker_pb2.RegisterAgentTypeRequest,
context: grpc.aio.ServicerContext[
agent_worker_pb2.RegisterAgentTypeRequest, agent_worker_pb2.RegisterAgentTypeResponse
],
) -> agent_worker_pb2.RegisterAgentTypeResponse:
client_id = await get_client_id_or_abort(context)
async with self._agent_type_to_client_id_lock:
if request.type in self._agent_type_to_client_id:
existing_client_id = self._agent_type_to_client_id[request.type]
await context.abort(
grpc.StatusCode.INVALID_ARGUMENT,
f"Agent type {request.type} already registered with client {existing_client_id}.",
)
else:
self._agent_type_to_client_id[request.type] = client_id
return agent_worker_pb2.RegisterAgentTypeResponse()
[docs]
async def AddSubscription( # type: ignore
self,
request: agent_worker_pb2.AddSubscriptionRequest,
context: grpc.aio.ServicerContext[
agent_worker_pb2.AddSubscriptionRequest, agent_worker_pb2.AddSubscriptionResponse
],
) -> agent_worker_pb2.AddSubscriptionResponse:
client_id = await get_client_id_or_abort(context)
subscription = subscription_from_proto(request.subscription)
try:
await self._subscription_manager.add_subscription(subscription)
subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
subscription_ids.add(subscription.id)
except ValueError as e:
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
return agent_worker_pb2.AddSubscriptionResponse()
[docs]
async def RemoveSubscription( # type: ignore
self,
request: agent_worker_pb2.RemoveSubscriptionRequest,
context: grpc.aio.ServicerContext[
agent_worker_pb2.RemoveSubscriptionRequest, agent_worker_pb2.RemoveSubscriptionResponse
],
) -> agent_worker_pb2.RemoveSubscriptionResponse:
_client_id = await get_client_id_or_abort(context)
raise NotImplementedError("Method not implemented.")
[docs]
async def GetSubscriptions( # type: ignore
self,
request: agent_worker_pb2.GetSubscriptionsRequest,
context: grpc.aio.ServicerContext[
agent_worker_pb2.GetSubscriptionsRequest, agent_worker_pb2.GetSubscriptionsResponse
],
) -> agent_worker_pb2.GetSubscriptionsResponse:
_client_id = await get_client_id_or_abort(context)
raise NotImplementedError("Method not implemented.")
[docs]
async def GetState( # type: ignore
self,
request: agent_worker_pb2.AgentId,
context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.GetStateResponse],
) -> agent_worker_pb2.GetStateResponse:
raise NotImplementedError("Method not implemented!")
[docs]
async def SaveState( # type: ignore
self,
request: agent_worker_pb2.AgentState,
context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.SaveStateResponse],
) -> agent_worker_pb2.SaveStateResponse:
raise NotImplementedError("Method not implemented!")