import asyncio
import logging
from _collections_abc import AsyncIterator, Iterator
from asyncio import Future, Task
from typing import Any, Dict, Set, cast
from autogen_core import Subscription, TopicId, TypePrefixSubscription, TypeSubscription
from autogen_core._runtime_impl_helpers import SubscriptionManager
from ._constants import GRPC_IMPORT_ERROR_STR
# Fail with a dedicated message (GRPC_IMPORT_ERROR_STR) when grpc cannot be imported,
# chaining the original ImportError for context.
try:
    import grpc
except ImportError as e:
    raise ImportError(GRPC_IMPORT_ERROR_STR) from e

from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2

# Logger shared with the rest of the "autogen_core" logging namespace.
logger = logging.getLogger("autogen_core")
# Logger for the "autogen_core.events" channel (a child of the logger above).
event_logger = logging.getLogger("autogen_core.events")
[docs]
class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
"""A gRPC servicer that hosts message delivery service for agents."""
def __init__(self) -> None:
self._client_id = 0
self._client_id_lock = asyncio.Lock()
self._send_queues: Dict[int, asyncio.Queue[agent_worker_pb2.Message]] = {}
self._agent_type_to_client_id_lock = asyncio.Lock()
self._agent_type_to_client_id: Dict[str, int] = {}
self._pending_responses: Dict[int, Dict[str, Future[Any]]] = {}
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._client_id_to_subscription_id_mapping: Dict[int, set[str]] = {}
[docs]
async def OpenChannel( # type: ignore
self,
request_iterator: AsyncIterator[agent_worker_pb2.Message],
context: grpc.aio.ServicerContext[agent_worker_pb2.Message, agent_worker_pb2.Message],
) -> Iterator[agent_worker_pb2.Message] | AsyncIterator[agent_worker_pb2.Message]: # type: ignore
# Aquire the lock to get a new client id.
async with self._client_id_lock:
self._client_id += 1
client_id = self._client_id
# Register the client with the server and create a send queue for the client.
send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue()
self._send_queues[client_id] = send_queue
logger.info(f"Client {client_id} connected.")
try:
# Concurrently handle receiving messages from the client and sending messages to the client.
# This task will receive messages from the client.
receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator))
# Return an async generator that will yield messages from the send queue to the client.
while True:
message = await send_queue.get()
# Yield the message to the client.
try:
yield message
except Exception as e:
logger.error(f"Failed to send message to client {client_id}: {e}", exc_info=True)
break
logger.info(f"Sent message to client {client_id}: {message}")
# Wait for the receiving task to finish.
await receiving_task
finally:
# Clean up the client connection.
del self._send_queues[client_id]
# Cancel pending requests sent to this client.
for future in self._pending_responses.pop(client_id, {}).values():
future.cancel()
# Remove the client id from the agent type to client id mapping.
await self._on_client_disconnect(client_id)
async def _on_client_disconnect(self, client_id: int) -> None:
async with self._agent_type_to_client_id_lock:
agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id]
for agent_type in agent_types:
logger.info(f"Removing agent type {agent_type} from agent type to client id mapping")
del self._agent_type_to_client_id[agent_type]
for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, set()):
logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}")
await self._subscription_manager.remove_subscription(sub_id)
logger.info(f"Client {client_id} disconnected successfully")
def _raise_on_exception(self, task: Task[Any]) -> None:
exception = task.exception()
if exception is not None:
raise exception
async def _receive_messages(
self, client_id: int, request_iterator: AsyncIterator[agent_worker_pb2.Message]
) -> None:
# Receive messages from the client and process them.
async for message in request_iterator:
logger.info(f"Received message from client {client_id}: {message}")
oneofcase = message.WhichOneof("message")
match oneofcase:
case "request":
request: agent_worker_pb2.RpcRequest = message.request
task = asyncio.create_task(self._process_request(request, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "response":
response: agent_worker_pb2.RpcResponse = message.response
task = asyncio.create_task(self._process_response(response, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "cloudEvent":
# The proto typing doesnt resolve this one
event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore
task = asyncio.create_task(self._process_event(event))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "registerAgentTypeRequest":
register_agent_type: agent_worker_pb2.RegisterAgentTypeRequest = message.registerAgentTypeRequest
task = asyncio.create_task(
self._process_register_agent_type_request(register_agent_type, client_id)
)
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "addSubscriptionRequest":
add_subscription: agent_worker_pb2.AddSubscriptionRequest = message.addSubscriptionRequest
task = asyncio.create_task(self._process_add_subscription_request(add_subscription, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "registerAgentTypeResponse" | "addSubscriptionResponse":
logger.warning(f"Received unexpected message type: {oneofcase}")
case None:
logger.warning("Received empty message")
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
# Deliver the message to a client given the target agent type.
async with self._agent_type_to_client_id_lock:
target_client_id = self._agent_type_to_client_id.get(request.target.type)
if target_client_id is None:
logger.error(f"Agent {request.target.type} not found, failed to deliver message.")
return
target_send_queue = self._send_queues.get(target_client_id)
if target_send_queue is None:
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
return
await target_send_queue.put(agent_worker_pb2.Message(request=request))
# Create a future to wait for the response from the target.
future = asyncio.get_event_loop().create_future()
self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future
# Create a task to wait for the response and send it back to the client.
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id))
self._background_tasks.add(send_response_task)
send_response_task.add_done_callback(self._raise_on_exception)
send_response_task.add_done_callback(self._background_tasks.discard)
async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None:
response = await future
message = agent_worker_pb2.Message(response=response)
send_queue = self._send_queues.get(client_id)
if send_queue is None:
logger.error(f"Client {client_id} not found, failed to send response message.")
return
await send_queue.put(message)
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None:
# Setting the result of the future will send the response back to the original sender.
future = self._pending_responses[client_id].pop(response.request_id)
future.set_result(response)
async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
topic_id = TopicId(type=event.type, source=event.source)
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
# Get the client ids of the recipients.
async with self._agent_type_to_client_id_lock:
client_ids: Set[int] = set()
for recipient in recipients:
client_id = self._agent_type_to_client_id.get(recipient.type)
if client_id is not None:
client_ids.add(client_id)
else:
logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.")
# Deliver the event to clients.
for client_id in client_ids:
await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event))
async def _process_register_agent_type_request(
self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: int
) -> None:
# Register the agent type with the host runtime.
async with self._agent_type_to_client_id_lock:
if register_agent_type_req.type in self._agent_type_to_client_id:
existing_client_id = self._agent_type_to_client_id[register_agent_type_req.type]
logger.error(
f"Agent type {register_agent_type_req.type} already registered with client {existing_client_id}."
)
success = False
error = f"Agent type {register_agent_type_req.type} already registered."
else:
self._agent_type_to_client_id[register_agent_type_req.type] = client_id
success = True
error = None
# Send a response back to the client.
await self._send_queues[client_id].put(
agent_worker_pb2.Message(
registerAgentTypeResponse=agent_worker_pb2.RegisterAgentTypeResponse(
request_id=register_agent_type_req.request_id, success=success, error=error
)
)
)
async def _process_add_subscription_request(
self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: int
) -> None:
oneofcase = add_subscription_req.subscription.WhichOneof("subscription")
subscription: Subscription | None = None
match oneofcase:
case "typeSubscription":
type_subscription_msg: agent_worker_pb2.TypeSubscription = (
add_subscription_req.subscription.typeSubscription
)
subscription = TypeSubscription(
topic_type=type_subscription_msg.topic_type, agent_type=type_subscription_msg.agent_type
)
case "typePrefixSubscription":
type_prefix_subscription_msg: agent_worker_pb2.TypePrefixSubscription = (
add_subscription_req.subscription.typePrefixSubscription
)
subscription = TypePrefixSubscription(
topic_type_prefix=type_prefix_subscription_msg.topic_type_prefix,
agent_type=type_prefix_subscription_msg.agent_type,
)
case None:
logger.warning("Received empty subscription message")
if subscription is not None:
try:
await self._subscription_manager.add_subscription(subscription)
subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
subscription_ids.add(subscription.id)
success = True
error = None
except ValueError as e:
success = False
error = str(e)
# Send a response back to the client.
await self._send_queues[client_id].put(
agent_worker_pb2.Message(
addSubscriptionResponse=agent_worker_pb2.AddSubscriptionResponse(
request_id=add_subscription_req.request_id, success=success, error=error
)
)
)
[docs]
async def GetState( # type: ignore
self,
request: agent_worker_pb2.AgentId,
context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.GetStateResponse],
) -> agent_worker_pb2.GetStateResponse: # type: ignore
raise NotImplementedError("Method not implemented!")
[docs]
async def SaveState( # type: ignore
self,
request: agent_worker_pb2.AgentState,
context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.SaveStateResponse],
) -> agent_worker_pb2.SaveStateResponse: # type: ignore
raise NotImplementedError("Method not implemented!")