from __future__ import annotations
import asyncio
import inspect
import json
import logging
import signal
import uuid
import warnings
from asyncio import Future, Task
from collections import defaultdict
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterable,
    AsyncIterator,
    Awaitable,
    Callable,
    ClassVar,
    DefaultDict,
    Dict,
    List,
    Literal,
    Mapping,
    ParamSpec,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    cast,
)
from autogen_core import (
    JSON_DATA_CONTENT_TYPE,
    PROTOBUF_DATA_CONTENT_TYPE,
    Agent,
    AgentId,
    AgentInstantiationContext,
    AgentMetadata,
    AgentRuntime,
    AgentType,
    CancellationToken,
    MessageContext,
    MessageHandlerContext,
    MessageSerializer,
    Subscription,
    TopicId,
)
from autogen_core._runtime_impl_helpers import SubscriptionManager, get_impl
from autogen_core._serialization import (
    SerializationRegistry,
)
from autogen_core._telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata
from google.protobuf import any_pb2
from opentelemetry.trace import TracerProvider
from typing_extensions import Self
from autogen_ext.runtimes.grpc._utils import subscription_to_proto
from . import _constants
from ._constants import GRPC_IMPORT_ERROR_STR
from ._type_helpers import ChannelArgumentType
from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
try:
    import grpc.aio
except ImportError as e:
    raise ImportError(GRPC_IMPORT_ERROR_STR) from e
if TYPE_CHECKING:
    from .protos.agent_worker_pb2_grpc import AgentRpcAsyncStub
# Module logger used for runtime/connection diagnostics throughout this file.
logger = logging.getLogger("autogen_core")
# Child logger under "autogen_core.events"; not referenced in the code visible here —
# presumably reserved for event-style logging. TODO confirm usage.
event_logger = logging.getLogger("autogen_core.events")
# ParamSpec placeholder; not referenced in the code visible here.
P = ParamSpec("P")
# Agent subtype produced by factories (see register_factory / _invoke_agent_factory).
T = TypeVar("T", bound=Agent)
# Alias for the builtin ``type``: register_factory has a parameter named ``type`` that
# shadows the builtin inside its body, so it calls this alias instead.
type_func_alias = type
class QueueAsyncIterable(AsyncIterator[Any], AsyncIterable[Any]):
    """Adapt an ``asyncio.Queue`` into an endless async iterator.

    Every step of ``async for`` awaits the next item put on the wrapped queue.
    Iteration never finishes on its own; consumers stop by breaking out or by
    cancelling the task that iterates.
    """

    def __init__(self, queue: asyncio.Queue[Any]) -> None:
        self._queue = queue

    def __aiter__(self) -> AsyncIterator[Any]:
        # The object is its own iterator.
        return self

    async def __anext__(self) -> Any:
        next_item = await self._queue.get()
        return next_item
class HostConnection:
    """Bidirectional message stream between a worker runtime and the gRPC host.

    Messages passed to :meth:`send` are queued and streamed to the host through the
    ``OpenChannel`` call; messages read from that stream are queued for :meth:`recv`.
    The stream is driven by a background task started in :meth:`from_host_address`.
    """

    # Channel options that are always applied: retry calls failing with UNAVAILABLE,
    # up to 3 attempts with exponential backoff (0.01s initial, x2, capped at 5s).
    DEFAULT_GRPC_CONFIG: ClassVar[ChannelArgumentType] = [
        (
            "grpc.service_config",
            json.dumps(
                {
                    "methodConfig": [
                        {
                            "name": [{}],
                            "retryPolicy": {
                                "maxAttempts": 3,
                                "initialBackoff": "0.01s",
                                "maxBackoff": "5s",
                                "backoffMultiplier": 2,
                                "retryableStatusCodes": ["UNAVAILABLE"],
                            },
                        }
                    ],
                }
            ),
        )
    ]

    def __init__(self, channel: grpc.aio.Channel, stub: Any) -> None:  # type: ignore
        """Wrap an existing channel and ``AgentRpc`` stub.

        This does not start the message stream; use :meth:`from_host_address` to get a
        ready-to-use connection.
        """
        self._channel = channel
        self._send_queue = asyncio.Queue[agent_worker_pb2.Message]()
        self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]()
        self._connection_task: Task[None] | None = None
        self._stub: AgentRpcAsyncStub = stub
        # Random per-connection id, sent as "client-id" metadata on every call.
        self._client_id = str(uuid.uuid4())

    @property
    def stub(self) -> Any:
        """The ``AgentRpc`` stub bound to this connection's channel."""
        return self._stub

    @property
    def metadata(self) -> Sequence[Tuple[str, str]]:
        """Call metadata identifying this client to the host."""
        return [("client-id", self._client_id)]

    @staticmethod
    def _merge_channel_options(base: ChannelArgumentType, extra: ChannelArgumentType) -> List[Tuple[str, Any]]:
        """Return *base* channel options overridden by *extra*.

        Options are keyed by name: a key present in both takes its value from *extra*
        but keeps its position from *base*; keys only in *extra* are appended in order.
        """
        merged = dict(base)
        merged.update(dict(extra))
        return list(merged.items())

    @classmethod
    def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG) -> Self:
        """Open an insecure channel to *host_address* and start the message stream.

        Must be called while an event loop is running, because the stream is driven by
        a task created here.

        Args:
            host_address: Target passed to ``grpc.aio.insecure_channel``.
            extra_grpc_config: Channel options layered over ``DEFAULT_GRPC_CONFIG``;
                entries with the same key override the defaults.

        Returns:
            A connection whose background stream task is already scheduled.
        """
        logger.info("Connecting to %s", host_address)
        # Always start from the class defaults (looked up via ``cls`` so subclasses can
        # override them) and let the caller-provided options win.
        merged_options = cls._merge_channel_options(cls.DEFAULT_GRPC_CONFIG, extra_grpc_config)
        channel = grpc.aio.insecure_channel(
            host_address,
            options=merged_options,
        )
        stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel)  # type: ignore
        instance = cls(channel, stub)
        instance._connection_task = asyncio.create_task(
            instance._connect(stub, instance._send_queue, instance._recv_queue, instance._client_id)
        )
        return instance

    async def close(self) -> None:
        """Close the channel, then wait for the stream task to finish.

        Whatever the stream task raised (for example when closing the channel aborts
        the pending read) propagates to the caller.

        Raises:
            RuntimeError: If the connection was never opened (no stream task exists).
        """
        if self._connection_task is None:
            raise RuntimeError("Connection is not open.")
        await self._channel.close()
        await self._connection_task

    @staticmethod
    async def _connect(
        stub: Any,  # AgentRpcAsyncStub
        send_queue: asyncio.Queue[agent_worker_pb2.Message],
        receive_queue: asyncio.Queue[agent_worker_pb2.Message],
        client_id: str,
    ) -> None:
        """Open the ``OpenChannel`` stream and pump incoming messages into *receive_queue*.

        *send_queue* is consumed by gRPC as the outgoing request stream. Returns when
        the host ends the stream (EOF).
        """
        from grpc.aio import StreamStreamCall

        # TODO: where do exceptions from reading the iterable go? How do we recover from those?
        recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel(  # type: ignore
            QueueAsyncIterable(send_queue), metadata=[("client-id", client_id)]
        )
        while True:
            logger.info("Waiting for message from host")
            message = cast(agent_worker_pb2.Message, await recv_stream.read())  # type: ignore
            if message == grpc.aio.EOF:  # type: ignore
                logger.info("EOF")
                break
            # Lazy %-formatting: the protobuf is only rendered when INFO is enabled.
            logger.info("Received a message from host: %s", message)
            await receive_queue.put(message)
            logger.info("Put message in receive queue")

    async def send(self, message: agent_worker_pb2.Message) -> None:
        """Queue *message* for delivery to the host over the open stream."""
        logger.info("Send message to host: %s", message)
        await self._send_queue.put(message)
        logger.info("Put message in send queue")

    async def recv(self) -> agent_worker_pb2.Message:
        """Wait for and return the next message received from the host."""
        logger.info("Getting message from queue")
        return await self._recv_queue.get()
# TODO: Lots of types need to have protobuf equivalents:
# Core:
#   - FunctionCall, CodeResult, possibly CodeBlock
#   - All the types in https://github.com/microsoft/autogen/blob/main/python/packages/autogen-core/src/autogen_core/models/_types.py
#
# Agentchat:
#   - All the types in https://github.com/microsoft/autogen/blob/main/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py to protobufs.
#
# Ext --
#   CodeExecutor:
#       - CommandLineCodeResult
class GrpcWorkerAgentRuntime(AgentRuntime):
    """An agent runtime for running remote or cross-language agents.
    Agent messaging uses protobufs from `agent_worker.proto`_ and ``CloudEvent`` from `cloudevent.proto`_.
    Cross-language agents additionally require that all agents use shared protobuf schemas for any message types sent between agents.
    .. _agent_worker.proto: https://github.com/microsoft/autogen/blob/main/protos/agent_worker.proto
    .. _cloudevent.proto: https://github.com/microsoft/autogen/blob/main/protos/cloudevent.proto
    """
    # TODO: Needs to handle agent close() call
    def __init__(
        self,
        host_address: str,
        tracer_provider: TracerProvider | None = None,
        extra_grpc_config: ChannelArgumentType | None = None,
        payload_serialization_format: str = JSON_DATA_CONTENT_TYPE,
    ) -> None:
        """Create a worker runtime for the host at *host_address*.

        No connection is opened here; :meth:`start` does that.

        Args:
            host_address: Address of the gRPC host.
            tracer_provider: Optional OpenTelemetry tracer provider for tracing message flow.
            extra_grpc_config: Extra channel options passed on when the connection is opened.
            payload_serialization_format: Content type used to serialize published messages;
                must be the JSON or the protobuf data content type.

        Raises:
            ValueError: If *payload_serialization_format* is not supported.
        """
        self._host_address = host_address
        tracing_config = MessageRuntimeTracingConfig("Worker Runtime")
        self._trace_helper = TraceHelper(tracer_provider, tracing_config)

        # Agent bookkeeping.
        self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
        self._agent_factories: Dict[
            str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
        ] = {}
        self._instantiated_agents: Dict[AgentId, Agent] = {}
        self._known_namespaces: set[str] = set()

        # Lifecycle state managed by start()/stop().
        self._read_task: None | Task[None] = None
        self._running = False

        # Outstanding RPCs keyed by request id; resolved when the host sends the response.
        self._pending_requests: Dict[str, Future[Any]] = {}
        self._pending_requests_lock = asyncio.Lock()
        self._next_request_id = 0

        # Connection, in-flight tasks, subscriptions and serializers.
        self._host_connection: HostConnection | None = None
        self._background_tasks: Set[Task[Any]] = set()
        self._subscription_manager = SubscriptionManager()
        self._serialization_registry = SerializationRegistry()
        self._extra_grpc_config = extra_grpc_config or []

        supported_formats = {JSON_DATA_CONTENT_TYPE, PROTOBUF_DATA_CONTENT_TYPE}
        if payload_serialization_format not in supported_formats:
            raise ValueError(f"Unsupported payload serialization format: {payload_serialization_format}")
        self._payload_serialization_format = payload_serialization_format
    def start(self) -> None:
        """Start the runtime in a background task.

        Opens the host connection and spawns the read loop. Must be called while an
        event loop is running (both steps create asyncio tasks).

        Raises:
            ValueError: If the runtime is already running.
        """
        if self._running:
            raise ValueError("Runtime is already running.")
        logger.info("Connecting to host: %s", self._host_address)
        self._host_connection = HostConnection.from_host_address(
            self._host_address, extra_grpc_config=self._extra_grpc_config
        )
        logger.info("Connection established")
        # A read task left over from a previous run has already finished (stop() cancels
        # it), so spawn a fresh one; otherwise a restarted runtime would never read from
        # the new connection.
        if self._read_task is None or self._read_task.done():
            self._read_task = asyncio.create_task(self._run_read_loop())
        self._running = True
    def _raise_on_exception(self, task: Task[Any]) -> None:
        """Done-callback for background tasks: re-raise the exception a task failed with.

        A done callback has no caller to propagate to, so the re-raised exception is
        reported through the event loop's exception handler. Cancelled tasks are
        skipped: calling ``Task.exception()`` on them would itself raise
        ``CancelledError`` from inside the callback.
        """
        if task.cancelled():
            return
        exception = task.exception()
        if exception is not None:
            raise exception
    async def _run_read_loop(self) -> None:
        """Read messages from the host and hand each one to a background handler task.

        Requests, responses and CloudEvents are dispatched to ``_process_request``,
        ``_process_response`` and ``_process_event``. Each handler runs in its own task,
        tracked in ``_background_tasks``. Messages with no payload set are logged and
        skipped, and other payload kinds are ignored. Errors while reading or dispatching
        are logged, and the loop keeps going while the runtime is running.
        """
        logger.info("Starting read loop")
        assert self._host_connection is not None
        # TODO: catch exceptions and reconnect
        while self._running:
            try:
                incoming = await self._host_connection.recv()
                kind = incoming.WhichOneof("message")
                if kind is None:
                    logger.warning("No message")
                    continue
                if kind == "request":
                    handler = self._process_request(incoming.request)
                elif kind == "response":
                    handler = self._process_response(incoming.response)
                elif kind == "cloudEvent":
                    handler = self._process_event(incoming.cloudEvent)
                else:
                    continue
                spawned = asyncio.create_task(handler)
                self._background_tasks.add(spawned)
                spawned.add_done_callback(self._raise_on_exception)
                spawned.add_done_callback(self._background_tasks.discard)
            except Exception as e:
                logger.error("Error in read loop", exc_info=e)
    async def stop(self) -> None:
        """Stop the runtime.

        Marks the runtime as stopped, waits for in-flight background tasks (logging any
        exception they raised), closes the host connection and cancels the read loop.

        Raises:
            RuntimeError: If the runtime is not running.
        """
        if not self._running:
            raise RuntimeError("Runtime is not running.")
        self._running = False
        # Wait for all background tasks to finish.
        final_tasks_results = await asyncio.gather(*self._background_tasks, return_exceptions=True)
        for task_result in final_tasks_results:
            if isinstance(task_result, Exception):
                logger.error("Error in background task", exc_info=task_result)
        try:
            # Close the host connection.
            if self._host_connection is not None:
                try:
                    await self._host_connection.close()
                except asyncio.CancelledError:
                    pass
        finally:
            # Always cancel the read task, even if closing the connection failed, so it is
            # not leaked. Clear the handle first so a later start() spawns a fresh loop.
            read_task, self._read_task = self._read_task, None
            if read_task is not None:
                read_task.cancel()
                try:
                    await read_task
                except asyncio.CancelledError:
                    pass
    async def stop_when_signal(self, signals: Sequence[signal.Signals] = (signal.SIGTERM, signal.SIGINT)) -> None:
        """Wait for one of *signals*, then stop the runtime.

        Handlers are installed on the running event loop only for the duration of the
        wait and are removed afterwards. A repeated signal (e.g. a second Ctrl+C during
        shutdown) then gets the default behaviour again instead of being absorbed.

        Args:
            signals: Signals that trigger the shutdown.
        """
        loop = asyncio.get_running_loop()
        shutdown_event = asyncio.Event()

        def signal_handler() -> None:
            logger.info("Received exit signal, shutting down gracefully...")
            shutdown_event.set()

        installed: List[signal.Signals] = []
        try:
            for sig in signals:
                loop.add_signal_handler(sig, signal_handler)
                installed.append(sig)
            # Wait for the signal to trigger the shutdown event.
            await shutdown_event.wait()
        finally:
            # Only remove what was actually installed (registration may fail part-way).
            for sig in installed:
                loop.remove_signal_handler(sig)
        # Stop the runtime.
        await self.stop()
    @property
    def _known_agent_names(self) -> Set[str]:
        """Names of all agent types that have a registered factory (a fresh set each call)."""
        names = set(self._agent_factories)
        return names
    async def _send_message(
        self,
        runtime_message: agent_worker_pb2.Message,
        send_type: Literal["send", "publish"],
        recipient: AgentId | TopicId,
        telemetry_metadata: Mapping[str, str],
    ) -> None:
        """Queue *runtime_message* on the host connection inside a ``send_type`` trace span.

        Raises:
            RuntimeError: If there is no host connection.
        """
        connection = self._host_connection
        if connection is None:
            raise RuntimeError("Host connection is not set.")
        span = self._trace_helper.trace_block(send_type, recipient, parent=telemetry_metadata)
        with span:
            await connection.send(runtime_message)
    async def send_message(
        self,
        message: Any,
        recipient: AgentId,
        *,
        sender: AgentId | None = None,
        cancellation_token: CancellationToken | None = None,
        message_id: str | None = None,
    ) -> Any:
        """Send *message* to *recipient* as an RPC through the host and wait for the reply.

        Args:
            message: Payload; serialized as JSON via the serialization registry.
            recipient: Target agent.
            sender: Optional sending agent, recorded as the request source.
            cancellation_token: If given, cancelling it cancels the wait for the reply.
            message_id: Currently unused.

        Returns:
            The deserialized result returned by the recipient.

        Raises:
            ValueError: If the runtime is not running.
            RuntimeError: If there is no host connection.
            asyncio.CancelledError: If the wait is cancelled (including via *cancellation_token*).
            Exception: Carrying the remote error text when the recipient failed.
        """
        # TODO: use message_id
        if not self._running:
            raise ValueError("Runtime must be running when sending message.")
        if self._host_connection is None:
            raise RuntimeError("Host connection is not set.")
        data_type = self._serialization_registry.type_name(message)
        with self._trace_helper.trace_block(
            "create", recipient, parent=None, extraAttributes={"message_type": data_type}
        ):
            # Create a new future for the result, on the loop this coroutine runs on.
            future: Future[Any] = asyncio.get_running_loop().create_future()
            if cancellation_token is not None:
                # Cancelling the token cancels the future, which aborts the wait below.
                cancellation_token.link_future(future)
            request_id = await self._get_new_request_id()
            self._pending_requests[request_id] = future
            try:
                serialized_message = self._serialization_registry.serialize(
                    message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE
                )
                telemetry_metadata = get_telemetry_grpc_metadata()
                runtime_message = agent_worker_pb2.Message(
                    request=agent_worker_pb2.RpcRequest(
                        request_id=request_id,
                        target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key),
                        source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
                        metadata=telemetry_metadata,
                        payload=agent_worker_pb2.Payload(
                            data_type=data_type,
                            data=serialized_message,
                            data_content_type=JSON_DATA_CONTENT_TYPE,
                        ),
                    )
                )
                # TODO: Find a way to handle timeouts/errors
                task = asyncio.create_task(self._send_message(runtime_message, "send", recipient, telemetry_metadata))
                self._background_tasks.add(task)
                task.add_done_callback(self._raise_on_exception)
                task.add_done_callback(self._background_tasks.discard)
                return await future
            finally:
                # On success the response handler has already removed the entry. On
                # cancellation or error, drop it here so abandoned requests don't accumulate.
                self._pending_requests.pop(request_id, None)
    async def publish_message(
        self,
        message: Any,
        topic_id: TopicId,
        *,
        sender: AgentId | None = None,
        cancellation_token: CancellationToken | None = None,
        message_id: str | None = None,
    ) -> None:
        """Publish *message* to *topic_id* as a CloudEvent sent through the host.

        The payload is encoded in the runtime's configured serialization format. JSON
        payloads are stored in ``binary_data``; protobuf payloads are re-parsed into an
        ``Any`` and stored in ``proto_data``. The send runs in a background task, so this
        returns before delivery.

        Args:
            message: Message to publish; its type must be known to the serialization registry.
            topic_id: Destination topic; its type/source become the CloudEvent ``type``/``source``.
            sender: Optional publishing agent, recorded in the event attributes
                (``AgentId("unknown", "unknown")`` when omitted).
            cancellation_token: Currently unused.
            message_id: CloudEvent id; a random UUID4 string is generated when omitted.

        Raises:
            ValueError: If the runtime is not running.
            RuntimeError: If there is no host connection.
        """
        if not self._running:
            raise ValueError("Runtime must be running when publishing message.")
        if self._host_connection is None:
            raise RuntimeError("Host connection is not set.")
        event_id = message_id if message_id is not None else str(uuid.uuid4())
        message_type = self._serialization_registry.type_name(message)
        with self._trace_helper.trace_block(
            "create", topic_id, parent=None, extraAttributes={"message_type": message_type}
        ):
            payload_format = self._payload_serialization_format
            serialized_message = self._serialization_registry.serialize(
                message, type_name=message_type, data_content_type=payload_format
            )
            sender_id = sender or AgentId("unknown", "unknown")
            make_attr = cloudevent_pb2.CloudEvent.CloudEventAttributeValue
            attributes = {
                _constants.DATA_CONTENT_TYPE_ATTR: make_attr(ce_string=payload_format),
                _constants.DATA_SCHEMA_ATTR: make_attr(ce_string=message_type),
                _constants.AGENT_SENDER_TYPE_ATTR: make_attr(ce_string=sender_id.type),
                _constants.AGENT_SENDER_KEY_ATTR: make_attr(ce_string=sender_id.key),
                _constants.MESSAGE_KIND_ATTR: make_attr(ce_string=_constants.MESSAGE_KIND_VALUE_PUBLISH),
            }
            # JSON payloads travel as raw bytes in binary_data; protobuf payloads are
            # unpacked back into an Any for proto_data.
            # TODO: add an encoding field for serializer
            # TODO: use text, or proto fields appropriately
            if payload_format == JSON_DATA_CONTENT_TYPE:
                payload_field: dict[str, Any] = {"binary_data": serialized_message}
            else:
                # TODO: find a way to prevent the roundtrip serialization
                any_proto = any_pb2.Any()
                any_proto.ParseFromString(serialized_message)
                payload_field = {"proto_data": any_proto}
            runtime_message = agent_worker_pb2.Message(
                cloudEvent=cloudevent_pb2.CloudEvent(
                    id=event_id,
                    spec_version="1.0",
                    type=topic_id.type,
                    source=topic_id.source,
                    attributes=attributes,
                    **payload_field,
                )
            )
            telemetry_metadata = get_telemetry_grpc_metadata()
            send_task = asyncio.create_task(
                self._send_message(runtime_message, "publish", topic_id, telemetry_metadata)
            )
            self._background_tasks.add(send_task)
            send_task.add_done_callback(self._raise_on_exception)
            send_task.add_done_callback(self._background_tasks.discard)
    async def save_state(self) -> Mapping[str, Any]:
        """Save the runtime's state. Not supported by this runtime yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Saving state is not yet implemented.")
    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Load the runtime's state. Not supported by this runtime yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Loading state is not yet implemented.")
    async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
        """Save a single agent's state. Not supported by this runtime yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Agent save_state is not yet implemented.")
    async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
        """Load a single agent's state. Not supported by this runtime yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Agent load_state is not yet implemented.")
    async def _get_new_request_id(self) -> str:
        """Allocate the next RPC request id: an increasing counter rendered as a string."""
        async with self._pending_requests_lock:
            self._next_request_id = self._next_request_id + 1
            allocated = self._next_request_id
        return str(allocated)
    async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
        """Handle an RPC request from the host and send exactly one response back.

        The payload is deserialized, the target agent is fetched (instantiated on first
        use), and the result of its ``on_message`` is serialized as JSON into an
        ``RpcResponse``. Any failure along that path is reported as an ``RpcResponse``
        with ``error`` set, so the requester's pending call is always resolved.
        """
        assert self._host_connection is not None
        recipient = AgentId(request.target.type, request.target.key)
        sender: AgentId | None = None
        # Log through the module logger, not the root logger.
        if request.HasField("source"):
            sender = AgentId(request.source.type, request.source.key)
            logger.info("Processing request from %s to %s", sender, recipient)
        else:
            logger.info("Processing request from unknown source to %s", recipient)
        try:
            # Deserialize the message.
            message = self._serialization_registry.deserialize(
                request.payload.data,
                type_name=request.payload.data_type,
                data_content_type=request.payload.data_content_type,
            )
            # Get the receiving agent and prepare the message context.
            rec_agent = await self._get_agent(recipient)
            message_context = MessageContext(
                sender=sender,
                topic_id=None,
                is_rpc=True,
                cancellation_token=CancellationToken(),
                message_id=request.request_id,
            )
            # Call the receiving agent.
            with MessageHandlerContext.populate_context(rec_agent.id):
                with self._trace_helper.trace_block(
                    "process",
                    rec_agent.id,
                    parent=request.metadata,
                    attributes={"request_id": request.request_id},
                    extraAttributes={"message_type": request.payload.data_type},
                ):
                    result = await rec_agent.on_message(message, ctx=message_context)
            # Serialize the result.
            result_type = self._serialization_registry.type_name(result)
            serialized_result = self._serialization_registry.serialize(
                result, type_name=result_type, data_content_type=JSON_DATA_CONTENT_TYPE
            )
        except BaseException as e:
            response_message = agent_worker_pb2.Message(
                response=agent_worker_pb2.RpcResponse(
                    request_id=request.request_id,
                    # The requester detects failure by a non-empty error, so never send "".
                    error=str(e) or repr(e),
                    metadata=get_telemetry_grpc_metadata(),
                ),
            )
            # Send the error response.
            await self._host_connection.send(response_message)
            if not isinstance(e, Exception):
                # Cancellation / interpreter exit: the requester has been told, but the
                # signal itself must not be swallowed.
                raise
            return
        # Create the response message.
        response_message = agent_worker_pb2.Message(
            response=agent_worker_pb2.RpcResponse(
                request_id=request.request_id,
                payload=agent_worker_pb2.Payload(
                    data_type=result_type,
                    data=serialized_result,
                    data_content_type=JSON_DATA_CONTENT_TYPE,
                ),
                metadata=get_telemetry_grpc_metadata(),
            )
        )
        # Send the response.
        await self._host_connection.send(response_message)
    async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None:
        """Resolve the pending ``send_message`` future that matches *response*.

        An error response fails the future with ``Exception(response.error)``. It is
        checked before deserializing, because error responses carry no payload.
        A payload that cannot be deserialized also fails the future rather than leaving
        the caller waiting forever. Responses for unknown ids, or for requests whose
        future is already done (e.g. cancelled by the caller), are dropped.
        """
        with self._trace_helper.trace_block(
            "ack",
            None,
            parent=response.metadata,
            attributes={"request_id": response.request_id},
            extraAttributes={"message_type": response.payload.data_type},
        ):
            future = self._pending_requests.pop(response.request_id, None)
            if future is None:
                logger.warning("Received response for unknown request id %s", response.request_id)
                return
            if future.done():
                # The caller stopped waiting (e.g. cancelled); the late response has no consumer.
                return
            if response.error:
                future.set_exception(Exception(response.error))
                return
            # Deserialize the result.
            try:
                result = self._serialization_registry.deserialize(
                    response.payload.data,
                    type_name=response.payload.data_type,
                    data_content_type=response.payload.data_content_type,
                )
            except Exception as e:
                future.set_exception(e)
                return
            future.set_result(result)
    async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
        event_attributes = event.attributes
        sender: AgentId | None = None
        if (
            _constants.AGENT_SENDER_TYPE_ATTR in event_attributes
            and _constants.AGENT_SENDER_KEY_ATTR in event_attributes
        ):
            sender = AgentId(
                event_attributes[_constants.AGENT_SENDER_TYPE_ATTR].ce_string,
                event_attributes[_constants.AGENT_SENDER_KEY_ATTR].ce_string,
            )
        topic_id = TopicId(event.type, event.source)
        # Get the recipients for the topic.
        recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
        message_content_type = event_attributes[_constants.DATA_CONTENT_TYPE_ATTR].ce_string
        message_type = event_attributes[_constants.DATA_SCHEMA_ATTR].ce_string
        if message_content_type == JSON_DATA_CONTENT_TYPE:
            message = self._serialization_registry.deserialize(
                event.binary_data, type_name=message_type, data_content_type=message_content_type
            )
        elif message_content_type == PROTOBUF_DATA_CONTENT_TYPE:
            # TODO: find a way to prevent the roundtrip serialization
            proto_binary_data = event.proto_data.SerializeToString()
            message = self._serialization_registry.deserialize(
                proto_binary_data, type_name=message_type, data_content_type=message_content_type
            )
        else:
            raise ValueError(f"Unsupported message content type: {message_content_type}")
        # TODO: dont read these values in the runtime
        topic_type_suffix = topic_id.type.split(":", maxsplit=1)[1] if ":" in topic_id.type else ""
        is_rpc = topic_type_suffix == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST
        is_marked_rpc_type = (
            _constants.MESSAGE_KIND_ATTR in event_attributes
            and event_attributes[_constants.MESSAGE_KIND_ATTR].ce_string == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST
        )
        if is_rpc and not is_marked_rpc_type:
            warnings.warn("Received RPC request with topic type suffix but not marked as RPC request.", stacklevel=2)
        # Send the message to each recipient.
        responses: List[Awaitable[Any]] = []
        for agent_id in recipients:
            if agent_id == sender:
                continue
            message_context = MessageContext(
                sender=sender,
                topic_id=topic_id,
                is_rpc=is_rpc,
                cancellation_token=CancellationToken(),
                message_id=event.id,
            )
            agent = await self._get_agent(agent_id)
            with MessageHandlerContext.populate_context(agent.id):
                def stringify_attributes(
                    attributes: Mapping[str, cloudevent_pb2.CloudEvent.CloudEventAttributeValue],
                ) -> Mapping[str, str]:
                    result: Dict[str, str] = {}
                    for key, value in attributes.items():
                        item = None
                        match value.WhichOneof("attr"):
                            case "ce_boolean":
                                item = str(value.ce_boolean)
                            case "ce_integer":
                                item = str(value.ce_integer)
                            case "ce_string":
                                item = value.ce_string
                            case "ce_bytes":
                                item = str(value.ce_bytes)
                            case "ce_uri":
                                item = value.ce_uri
                            case "ce_uri_ref":
                                item = value.ce_uri_ref
                            case "ce_timestamp":
                                item = str(value.ce_timestamp)
                            case _:
                                raise ValueError("Unknown attribute kind")
                        result[key] = item
                    return result
                async def send_message(agent: Agent, message_context: MessageContext) -> Any:
                    with self._trace_helper.trace_block(
                        "process",
                        agent.id,
                        parent=stringify_attributes(event.attributes),
                        extraAttributes={"message_type": message_type},
                    ):
                        await agent.on_message(message, ctx=message_context)
                future = send_message(agent, message_context)
            responses.append(future)
        # Wait for all responses.
        try:
            await asyncio.gather(*responses)
        except BaseException as e:
            logger.error("Error handling event", exc_info=e)
[docs]
    async def register_factory(
        self,
        type: str | AgentType,
        agent_factory: Callable[[], T | Awaitable[T]],
        *,
        expected_class: type[T] | None = None,
    ) -> AgentType:
        """Register *agent_factory* for agent *type*, locally and with the host.

        Args:
            type: Agent type name (or ``AgentType``) to register.
            agent_factory: Zero-argument callable returning an agent or an awaitable of one.
            expected_class: If given, each instance the factory produces must be exactly
                this class.

        Returns:
            The registered ``AgentType``.

        Raises:
            ValueError: If *type* is already registered. When the factory later runs,
                ValueError is also raised if it produced an instance of the wrong class.
            RuntimeError: If there is no host connection (e.g. ``start()`` not called).
        """
        if isinstance(type, str):
            type = AgentType(type)
        if type.type in self._agent_factories:
            raise ValueError(f"Agent with type {type} already exists.")
        if self._host_connection is None:
            raise RuntimeError("Host connection is not set.")

        async def factory_wrapper() -> T:
            maybe_agent_instance = agent_factory()
            if inspect.isawaitable(maybe_agent_instance):
                agent_instance = await maybe_agent_instance
            else:
                agent_instance = maybe_agent_instance
            # Only enforce the class when the caller asked for it (it defaults to None).
            if expected_class is not None and type_func_alias(agent_instance) != expected_class:
                raise ValueError("Factory registered using the wrong type.")
            return agent_instance

        # Registered locally before the host call; rolled back below if that call fails.
        self._agent_factories[type.type] = factory_wrapper
        # Send the registration request message to the host.
        message = agent_worker_pb2.RegisterAgentTypeRequest(type=type.type)
        try:
            await self._host_connection.stub.RegisterAgent(message, metadata=self._host_connection.metadata)
        except BaseException:
            # Without the rollback a retry would fail with "already exists".
            self._agent_factories.pop(type.type, None)
            raise
        return type
    async def _invoke_agent_factory(
        self,
        agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
        agent_id: AgentId,
    ) -> T:
        with AgentInstantiationContext.populate_context((self, agent_id)):
            if len(inspect.signature(agent_factory).parameters) == 0:
                factory_one = cast(Callable[[], T], agent_factory)
                agent = factory_one()
            elif len(inspect.signature(agent_factory).parameters) == 2:
                warnings.warn(
                    "Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
                    stacklevel=2,
                )
                factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
                agent = factory_two(self, agent_id)
            else:
                raise ValueError("Agent factory must take 0 or 2 arguments.")
            if inspect.isawaitable(agent):
                return cast(T, await agent)
        return agent
    async def _get_agent(self, agent_id: AgentId) -> Agent:
        """Return the local instance for *agent_id*, creating and caching it on first use.

        Raises:
            ValueError: If no factory is registered for ``agent_id.type``.
        """
        cached = self._instantiated_agents.get(agent_id)
        if cached is not None:
            return cached
        factory = self._agent_factories.get(agent_id.type)
        if factory is None:
            raise ValueError(f"Agent with name {agent_id.type} not found.")
        instance = await self._invoke_agent_factory(factory, agent_id)
        self._instantiated_agents[agent_id] = instance
        return instance
    # TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
    async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T:  # type: ignore[assignment]
        """Return the local agent instance for *id*, instantiating it if necessary.

        Raises:
            LookupError: If no factory is registered for ``id.type``.
            TypeError: If the instance is not an instance of *type*.
        """
        if id.type not in self._agent_factories:
            raise LookupError(f"Agent with name {id.type} not found.")
        # TODO: check if remote
        instance = await self._get_agent(id)
        if isinstance(instance, type):
            return instance
        raise TypeError(f"Agent with name {id.type} is not of type {type.__name__}")
    async def add_subscription(self, subscription: Subscription) -> None:
        """Register *subscription* with the host, then with the local subscription manager.

        Raises:
            RuntimeError: If there is no host connection.
        """
        connection = self._host_connection
        if connection is None:
            raise RuntimeError("Host connection is not set.")
        request = agent_worker_pb2.AddSubscriptionRequest(subscription=subscription_to_proto(subscription))
        # The local manager is only updated once the host call has returned.
        await connection.stub.AddSubscription(request, metadata=connection.metadata)
        await self._subscription_manager.add_subscription(subscription)
    async def remove_subscription(self, id: str) -> None:
        """Remove a subscription. Not supported by the distributed runtime yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Subscriptions cannot be removed while using distributed runtime currently.")
    async def get(
        self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True
    ) -> AgentId:
        """Resolve *id_or_type* (plus *key*) to an ``AgentId``.

        Delegates to the shared ``get_impl`` helper, using this runtime's ``_get_agent``
        to obtain instances; *lazy* is forwarded unchanged.
        """
        agent_id = await get_impl(
            instance_getter=self._get_agent,
            id_or_type=id_or_type,
            key=key,
            lazy=lazy,
        )
        return agent_id
    def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
        """Add one or more serializers to this runtime's serialization registry.

        The registry is used to encode and decode RPC payloads and published events.
        """
        self._serialization_registry.add_serializer(serializer)