{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Custom Agents\n",
    "\n",
    "You may need agents whose behavior does not fit any of the presets.\n",
    "In such cases, you can build a custom agent.\n",
    "\n",
    "All agents in AgentChat inherit from {py:class}`~autogen_agentchat.agents.BaseChatAgent` \n",
    "class and implement the following abstract methods and attributes:\n",
    "\n",
    "- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: The abstract method that defines the behavior of the agent in response to messages. This method is called when the agent is asked to provide a response in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run`. It returns a {py:class}`~autogen_agentchat.base.Response` object.\n",
    "- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_reset`: The abstract method that resets the agent to its initial state. This method is called when the agent is asked to reset itself.\n",
    "- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.produced_message_types`: The list of possible {py:class}`~autogen_agentchat.messages.ChatMessage` message types the agent can produce in its response.\n",
    "\n",
    "Optionally, you can implement the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` method to stream messages as they are generated by the agent. If this method is not implemented, the agent\n",
    "uses the default implementation of {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`\n",
    "that calls the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` method and\n",
    "yields all messages in the response."
   ]
  },
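  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the default streaming behavior concrete, here is a minimal, framework-free sketch. The `Response` dataclass and the `SketchAgent` and `EchoAgent` classes are hypothetical stand-ins written for illustration, not the AgentChat classes. They only mirror the behavior described above: the default stream awaits `on_messages`, then yields the inner messages followed by the final response. Run it as a script, since `asyncio.run` cannot be called from inside a running notebook event loop.\n",
    "\n",
    "```python\n",
    "import asyncio\n",
    "from dataclasses import dataclass, field\n",
    "from typing import AsyncGenerator, List, Sequence, Union\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Response:\n",
    "    # Hypothetical stand-in for autogen_agentchat.base.Response.\n",
    "    chat_message: str\n",
    "    inner_messages: List[str] = field(default_factory=list)\n",
    "\n",
    "\n",
    "class SketchAgent:\n",
    "    # Hypothetical base class illustrating the default streaming behavior.\n",
    "    async def on_messages(self, messages: Sequence[str]) -> Response:\n",
    "        raise NotImplementedError\n",
    "\n",
    "    async def on_messages_stream(self, messages: Sequence[str]) -> AsyncGenerator[Union[str, Response], None]:\n",
    "        # Compute the full response first, then replay it as a stream.\n",
    "        response = await self.on_messages(messages)\n",
    "        for inner in response.inner_messages:\n",
    "            yield inner\n",
    "        yield response\n",
    "\n",
    "\n",
    "class EchoAgent(SketchAgent):\n",
    "    # Implements only on_messages; streaming is inherited from SketchAgent.\n",
    "    async def on_messages(self, messages: Sequence[str]) -> Response:\n",
    "        return Response(chat_message=f'echo: {messages[-1]}', inner_messages=['thinking...'])\n",
    "\n",
    "\n",
    "async def collect(agent: SketchAgent, messages: Sequence[str]) -> List[Union[str, Response]]:\n",
    "    # Drain the stream into a list.\n",
    "    return [item async for item in agent.on_messages_stream(messages)]\n",
    "\n",
    "\n",
    "streamed = asyncio.run(collect(EchoAgent(), ['hello']))\n",
    "print(streamed[0])  # thinking...\n",
    "print(streamed[-1].chat_message)  # echo: hello\n",
    "```\n",
    "\n",
    "Overriding `on_messages_stream` instead, as `CountDownAgent` does below, lets an agent emit messages while it is still working rather than only after `on_messages` returns."
   ]
  },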
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CountDownAgent\n",
    "\n",
    "In this example, we create a simple agent that counts down from a given number to zero,\n",
    "and produces a stream of messages with the current count."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3...\n",
      "2...\n",
      "1...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "from typing import AsyncGenerator, List, Sequence\n",
    "\n",
    "from autogen_agentchat.agents import BaseChatAgent\n",
    "from autogen_agentchat.base import Response\n",
    "from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage\n",
    "from autogen_core import CancellationToken\n",
    "\n",
    "\n",
    "class CountDownAgent(BaseChatAgent):\n",
    "    def __init__(self, name: str, count: int = 3):\n",
    "        super().__init__(name, \"A simple agent that counts down.\")\n",
    "        self._count = count\n",
    "\n",
    "    @property\n",
    "    def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n",
    "        return (TextMessage,)\n",
    "\n",
    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
    "        # Delegate to on_messages_stream and keep the final Response it yields.\n",
    "        response: Response | None = None\n",
    "        async for message in self.on_messages_stream(messages, cancellation_token):\n",
    "            if isinstance(message, Response):\n",
    "                response = message\n",
    "        assert response is not None\n",
    "        return response\n",
    "\n",
    "    async def on_messages_stream(\n",
    "        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n",
    "    ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n",
    "        inner_messages: List[AgentEvent | ChatMessage] = []\n",
    "        for i in range(self._count, 0, -1):\n",
    "            msg = TextMessage(content=f\"{i}...\", source=self.name)\n",
    "            inner_messages.append(msg)\n",
    "            yield msg\n",
    "        # The response is returned at the end of the stream.\n",
    "        # It contains the final message and all the inner messages.\n",
    "        yield Response(chat_message=TextMessage(content=\"Done!\", source=self.name), inner_messages=inner_messages)\n",
    "\n",
    "    async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
    "        pass\n",
    "\n",
    "\n",
    "async def run_countdown_agent() -> None:\n",
    "    # Create a countdown agent.\n",
    "    countdown_agent = CountDownAgent(\"countdown\")\n",
    "\n",
    "    # Run the agent with a given task and stream the response.\n",
    "    async for message in countdown_agent.on_messages_stream([], CancellationToken()):\n",
    "        if isinstance(message, Response):\n",
    "            print(message.chat_message.content)\n",
    "        else:\n",
    "            print(message.content)\n",
    "\n",
    "\n",
    "# Use asyncio.run(run_countdown_agent()) when running in a script.\n",
    "await run_countdown_agent()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ArithmeticAgent\n",
    "\n",
    "In this example, we create an agent class that performs a simple arithmetic operation\n",
    "on a given integer. Then, we use different instances of this agent class\n",
    "in a {py:class}`~autogen_agentchat.teams.SelectorGroupChat`\n",
    "to transform a given integer into another integer by applying a sequence of arithmetic operations.\n",
    "\n",
    "The `ArithmeticAgent` class takes an `operator_func`, a function that applies an arithmetic\n",
    "operation to an integer and returns the resulting integer.\n",
    "In its `on_messages` method, it applies `operator_func` to the integer in the most recent message\n",
    "and returns a response containing the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Callable, List, Sequence\n",
    "\n",
    "from autogen_agentchat.agents import BaseChatAgent\n",
    "from autogen_agentchat.base import Response\n",
    "from autogen_agentchat.conditions import MaxMessageTermination\n",
    "from autogen_agentchat.messages import ChatMessage, TextMessage\n",
    "from autogen_agentchat.teams import SelectorGroupChat\n",
    "from autogen_agentchat.ui import Console\n",
    "from autogen_core import CancellationToken\n",
    "from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
    "\n",
    "\n",
    "class ArithmeticAgent(BaseChatAgent):\n",
    "    def __init__(self, name: str, description: str, operator_func: Callable[[int], int]) -> None:\n",
    "        super().__init__(name, description=description)\n",
    "        self._operator_func = operator_func\n",
    "        self._message_history: List[ChatMessage] = []\n",
    "\n",
    "    @property\n",
    "    def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n",
    "        return (TextMessage,)\n",
    "\n",
    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
    "        # Update the message history.\n",
    "        # NOTE: `messages` may be an empty list, which means the agent was selected previously and is being called again without new messages.\n",
    "        self._message_history.extend(messages)\n",
    "        # Parse the number in the last message.\n",
    "        assert isinstance(self._message_history[-1], TextMessage)\n",
    "        number = int(self._message_history[-1].content)\n",
    "        # Apply the operator function to the number.\n",
    "        result = self._operator_func(number)\n",
    "        # Create a new message with the result.\n",
    "        response_message = TextMessage(content=str(result), source=self.name)\n",
    "        # Update the message history.\n",
    "        self._message_history.append(response_message)\n",
    "        # Return the response.\n",
    "        return Response(chat_message=response_message)\n",
    "\n",
    "    async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
    "        # Return to the initial state by discarding the message history.\n",
    "        self._message_history.clear()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note}\n",
    "The `on_messages` method may be called with an empty list of messages. This means\n",
    "the agent was called previously and is now being called again without any new\n",
    "messages from the caller, so the agent should keep a history of the messages it\n",
    "has received and use that history to generate its response.\n",
    "```"
   ]
  },
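  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The history-keeping pattern from the note can be shown without any framework. In this minimal sketch, `HistoryAgent` is a hypothetical class that stores plain integers instead of `TextMessage` objects. When it is called with an empty list, it falls back to the last entry in its own history, which is its previous reply.\n",
    "\n",
    "```python\n",
    "from typing import Callable, List, Sequence\n",
    "\n",
    "\n",
    "class HistoryAgent:\n",
    "    # Hypothetical, framework-free version of the ArithmeticAgent history logic.\n",
    "    def __init__(self, operator_func: Callable[[int], int]) -> None:\n",
    "        self._operator_func = operator_func\n",
    "        self._history: List[int] = []\n",
    "\n",
    "    def on_messages(self, messages: Sequence[int]) -> int:\n",
    "        # If messages is empty (the agent was selected again), the last entry is its own previous result.\n",
    "        self._history.extend(messages)\n",
    "        result = self._operator_func(self._history[-1])\n",
    "        self._history.append(result)\n",
    "        return result\n",
    "\n",
    "    def on_reset(self) -> None:\n",
    "        # Return to the initial state by forgetting everything seen so far.\n",
    "        self._history.clear()\n",
    "\n",
    "\n",
    "add_one = HistoryAgent(lambda x: x + 1)\n",
    "print(add_one.on_messages([10]))  # 11\n",
    "print(add_one.on_messages([]))  # 12, computed from the agent's own previous reply\n",
    "```\n",
    "\n",
    "Without the stored history, the second call would have no number to work on, which is exactly the situation the note warns about."
   ]
  },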
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can create five instances of `ArithmeticAgent`:\n",
    "\n",
    "- one that adds 1 to the input integer,\n",
    "- one that subtracts 1 from the input integer,\n",
    "- one that multiplies the input integer by 2,\n",
    "- one that divides the input integer by 2 and rounds down to the nearest integer, and\n",
    "- one that returns the input integer unchanged.\n",
    "\n",
    "We then create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with these agents,\n",
    "and set the appropriate selector settings:\n",
    "\n",
    "- allow the same agent to be selected consecutively, so that an operation can be repeated, and\n",
    "- customize the selector prompt to tailor the model's response to the specific task."
   ]
  },
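  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The selector prompt used below is a template with `{roles}`, `{participants}`, and `{history}` placeholders that the group chat fills in before asking the model to choose the next speaker. The sketch below fills the same template with `str.format` to show roughly what the model sees. The agent descriptions and history are hypothetical sample data, and the exact text `SelectorGroupChat` substitutes for each placeholder is an assumption, so treat the rendered prompt as illustrative.\n",
    "\n",
    "```python\n",
    "# Same template text as the selector_prompt passed to SelectorGroupChat below.\n",
    "selector_prompt = (\n",
    "    'Available roles:\\n{roles}\\nTheir job descriptions:\\n{participants}\\n'\n",
    "    'Current conversation history:\\n{history}\\n'\n",
    "    'Please select the most appropriate role for the next message, and only return the role name.'\n",
    ")\n",
    "\n",
    "# Hypothetical sample data mirroring two of the agents and the task messages.\n",
    "agents = {\n",
    "    'add_agent': 'Adds 1 to the number.',\n",
    "    'multiply_agent': 'Multiplies the number by 2.',\n",
    "}\n",
    "history = [('user', 'Apply the operations to turn the given number into 25.'), ('user', '10')]\n",
    "\n",
    "# Assumed formatting: comma-separated roles, one 'name: description' or 'source: content' line each.\n",
    "rendered = selector_prompt.format(\n",
    "    roles=', '.join(agents),\n",
    "    participants='\\n'.join(f'{name}: {desc}' for name, desc in agents.items()),\n",
    "    history='\\n'.join(f'{source}: {content}' for source, content in history),\n",
    ")\n",
    "print(rendered)\n",
    "```\n",
    "\n",
    "Asking for only the role name keeps the model's reply easy to match against the agent names."
   ]
  },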
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---------- user ----------\n",
      "Apply the operations to turn the given number into 25.\n",
      "---------- user ----------\n",
      "10\n",
      "---------- multiply_agent ----------\n",
      "20\n",
      "---------- add_agent ----------\n",
      "21\n",
      "---------- multiply_agent ----------\n",
      "42\n",
      "---------- divide_agent ----------\n",
      "21\n",
      "---------- add_agent ----------\n",
      "22\n",
      "---------- add_agent ----------\n",
      "23\n",
      "---------- add_agent ----------\n",
      "24\n",
      "---------- add_agent ----------\n",
      "25\n",
      "---------- Summary ----------\n",
      "Number of messages: 10\n",
      "Finish reason: Maximum number of messages 10 reached, current message count: 10\n",
      "Total prompt tokens: 0\n",
      "Total completion tokens: 0\n",
      "Duration: 2.40 seconds\n"
     ]
    }
   ],
   "source": [
    "async def run_number_agents() -> None:\n",
    "    # Create agents for number operations.\n",
    "    add_agent = ArithmeticAgent(\"add_agent\", \"Adds 1 to the number.\", lambda x: x + 1)\n",
    "    multiply_agent = ArithmeticAgent(\"multiply_agent\", \"Multiplies the number by 2.\", lambda x: x * 2)\n",
    "    subtract_agent = ArithmeticAgent(\"subtract_agent\", \"Subtracts 1 from the number.\", lambda x: x - 1)\n",
    "    divide_agent = ArithmeticAgent(\"divide_agent\", \"Divides the number by 2 and rounds down.\", lambda x: x // 2)\n",
    "    identity_agent = ArithmeticAgent(\"identity_agent\", \"Returns the number as is.\", lambda x: x)\n",
    "\n",
    "    # The termination condition is to stop after 10 messages.\n",
    "    termination_condition = MaxMessageTermination(10)\n",
    "\n",
    "    # Create a selector group chat.\n",
    "    selector_group_chat = SelectorGroupChat(\n",
    "        [add_agent, multiply_agent, subtract_agent, divide_agent, identity_agent],\n",
    "        model_client=OpenAIChatCompletionClient(model=\"gpt-4o\"),\n",
    "        termination_condition=termination_condition,\n",
    "        allow_repeated_speaker=True,  # Allow the same agent to speak multiple times, necessary for this task.\n",
    "        selector_prompt=(\n",
    "            \"Available roles:\\n{roles}\\nTheir job descriptions:\\n{participants}\\n\"\n",
    "            \"Current conversation history:\\n{history}\\n\"\n",
    "            \"Please select the most appropriate role for the next message, and only return the role name.\"\n",
    "        ),\n",
    "    )\n",
    "\n",
    "    # Run the selector group chat with a given task and stream the response.\n",
    "    task: List[ChatMessage] = [\n",
    "        TextMessage(content=\"Apply the operations to turn the given number into 25.\", source=\"user\"),\n",
    "        TextMessage(content=\"10\", source=\"user\"),\n",
    "    ]\n",
    "    stream = selector_group_chat.run_stream(task=task)\n",
    "    await Console(stream)\n",
    "\n",
    "\n",
    "# Use asyncio.run(run_number_agents()) when running in a script.\n",
    "await run_number_agents()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the output, we can see that the agents have successfully transformed the input integer\n",
    "from 10 to 25 by choosing appropriate agents that apply the arithmetic operations in sequence."
   ]
  },
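  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because every operation is deterministic, a plain breadth-first search can tell us how many steps a minimal plan needs, which is a useful baseline for judging the model-driven selection above. This sketch is for comparison only and is not part of AgentChat. `OPERATIONS` mirrors the five agents' operator functions, and the `limit` bound on intermediate values is an assumption added to keep the search finite.\n",
    "\n",
    "```python\n",
    "from collections import deque\n",
    "from typing import Callable, Deque, Dict, List, Optional, Tuple\n",
    "\n",
    "# Mirrors the operator functions given to the five ArithmeticAgent instances.\n",
    "OPERATIONS: Dict[str, Callable[[int], int]] = {\n",
    "    'add_agent': lambda x: x + 1,\n",
    "    'subtract_agent': lambda x: x - 1,\n",
    "    'multiply_agent': lambda x: x * 2,\n",
    "    'divide_agent': lambda x: x // 2,\n",
    "    'identity_agent': lambda x: x,\n",
    "}\n",
    "\n",
    "\n",
    "def shortest_plan(start: int, target: int, limit: int = 1000) -> Optional[List[str]]:\n",
    "    # Breadth-first search: the first plan that reaches the target is a shortest one.\n",
    "    queue: Deque[Tuple[int, List[str]]] = deque([(start, [])])\n",
    "    seen = {start}\n",
    "    while queue:\n",
    "        value, plan = queue.popleft()\n",
    "        if value == target:\n",
    "            return plan\n",
    "        for name, func in OPERATIONS.items():\n",
    "            nxt = func(value)\n",
    "            # Skip values already visited (identity_agent never adds one) and values outside the bound.\n",
    "            if nxt not in seen and abs(nxt) <= limit:\n",
    "                seen.add(nxt)\n",
    "                queue.append((nxt, plan + [name]))\n",
    "    return None\n",
    "\n",
    "\n",
    "plan = shortest_plan(10, 25)\n",
    "print(len(plan), plan)  # 4 ['add_agent', 'add_agent', 'multiply_agent', 'add_agent']\n",
    "```\n",
    "\n",
    "For 10 → 25 the search finds a four-step plan, while the group chat above used eight operations. The selector model is not guaranteed to find a shortest sequence; in this run it reached the target within the 10-message limit."
   ]
  },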
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Custom Model Clients in Custom Agents\n",
    "\n",
    "One of the key features of the {py:class}`~autogen_agentchat.agents.AssistantAgent` preset in AgentChat is that it takes a `model_client` argument and uses it to respond to messages. However, in some cases you may want your agent to use a model client that is not currently supported (see [supported model clients](https://microsoft.github.io/autogen/dev/user-guide/core-user-guide/components/model-clients.html)) or to customize how the model is called.\n",
    "\n",
    "You can accomplish this with a custom agent that uses *your own model client*.\n",
    "\n",
    "In the example below, we walk through a custom agent that uses the [Google Gemini SDK](https://github.com/googleapis/python-genai) directly to respond to messages.\n",
    "\n",
    "> **Note:** You will need to install the [Google Gemini SDK](https://github.com/googleapis/python-genai) to run this example. You can install it using the following command: \n",
    "\n",
    "```bash\n",
    "pip install google-genai\n",
    "``` "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install google-genai\n",
    "import os\n",
    "from typing import AsyncGenerator, Sequence\n",
    "\n",
    "from autogen_agentchat.agents import BaseChatAgent\n",
    "from autogen_agentchat.base import Response\n",
    "from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage\n",
    "from autogen_core import CancellationToken\n",
    "from autogen_core.model_context import UnboundedChatCompletionContext\n",
    "from autogen_core.models import AssistantMessage, RequestUsage, UserMessage\n",
    "from google import genai\n",
    "from google.genai import types\n",
    "\n",
    "\n",
    "class GeminiAssistantAgent(BaseChatAgent):\n",
    "    def __init__(\n",
    "        self,\n",
    "        name: str,\n",
    "        description: str = \"An agent that provides assistance with ability to use tools.\",\n",
    "        model: str = \"gemini-1.5-flash-002\",\n",
    "        api_key: str | None = None,\n",
    "        system_message: str\n",
    "        | None = \"You are a helpful assistant that can respond to messages. Reply with TERMINATE when the task has been completed.\",\n",
    "    ):\n",
    "        super().__init__(name=name, description=description)\n",
    "        self._model_context = UnboundedChatCompletionContext()\n",
    "        # Fall back to the GEMINI_API_KEY environment variable when no key is passed explicitly.\n",
    "        self._model_client = genai.Client(api_key=api_key or os.environ[\"GEMINI_API_KEY\"])\n",
    "        self._system_message = system_message\n",
    "        self._model = model\n",
    "\n",
    "    @property\n",
    "    def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n",
    "        return (TextMessage,)\n",
    "\n",
    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
    "        final_response = None\n",
    "        async for message in self.on_messages_stream(messages, cancellation_token):\n",
    "            if isinstance(message, Response):\n",
    "                final_response = message\n",
    "\n",
    "        if final_response is None:\n",
    "            raise AssertionError(\"The stream should have returned the final result.\")\n",
    "\n",
    "        return final_response\n",
    "\n",
    "    async def on_messages_stream(\n",
    "        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n",
    "    ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n",
    "        # Add messages to the model context\n",
    "        for msg in messages:\n",
    "            await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))\n",
    "\n",
    "        # Flatten the conversation history into plain text, one \"source: content\" line per message.\n",
    "        context_messages = await self._model_context.get_messages()\n",
    "        history = \"\".join(\n",
    "            (msg.source if hasattr(msg, \"source\") else \"system\")\n",
    "            + \": \"\n",
    "            + (msg.content if isinstance(msg.content, str) else \"\")\n",
    "            + \"\\n\"\n",
    "            for msg in context_messages\n",
    "        )\n",
    "        # Generate a response with the SDK's async client so the event loop is not blocked.\n",
    "        response = await self._model_client.aio.models.generate_content(\n",
    "            model=self._model,\n",
    "            contents=f\"History: {history}\\nGiven the history, please provide a response\",\n",
    "            config=types.GenerateContentConfig(\n",
    "                system_instruction=self._system_message,\n",
    "                temperature=0.3,\n",
    "            ),\n",
    "        )\n",
    "\n",
    "        # Create usage metadata\n",
    "        usage = RequestUsage(\n",
    "            prompt_tokens=response.usage_metadata.prompt_token_count,\n",
    "            completion_tokens=response.usage_metadata.candidates_token_count,\n",
    "        )\n",
    "\n",
    "        # Add response to model context\n",
    "        await self._model_context.add_message(AssistantMessage(content=response.text, source=self.name))\n",
    "\n",
    "        # Yield the final response\n",
    "        yield Response(\n",
    "            chat_message=TextMessage(content=response.text, source=self.name, models_usage=usage),\n",
    "            inner_messages=[],\n",
    "        )\n",
    "\n",
    "    async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
    "        \"\"\"Reset the assistant by clearing the model context.\"\"\"\n",
    "        await self._model_context.clear()"
   ]
  },
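  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The agent above flattens its model context into a single prompt string before calling Gemini. The sketch below isolates that step so the exact text sent to the model is easy to inspect. `build_prompt` is a hypothetical helper, and plain `(source, content)` tuples stand in for the `UserMessage` and `AssistantMessage` objects so the sketch runs without AutoGen.\n",
    "\n",
    "```python\n",
    "from typing import List, Tuple, Union\n",
    "\n",
    "\n",
    "def build_prompt(history: List[Tuple[str, Union[str, list]]]) -> str:\n",
    "    # One 'source: content' line per message; non-string content (such as images) becomes empty text.\n",
    "    lines = [f'{source}: ' + (content if isinstance(content, str) else '') for source, content in history]\n",
    "    return 'History:\\n' + '\\n'.join(lines) + '\\nGiven the history, please provide a response'\n",
    "\n",
    "\n",
    "history = [('user', 'What is the capital of New York?'), ('gemini_assistant', 'Albany')]\n",
    "print(build_prompt(history))\n",
    "```\n",
    "\n",
    "Joining the lines into plain text keeps the prompt readable; interpolating a Python list directly into an f-string would instead embed its `repr`, including brackets, quotes, and escaped newlines."
   ]
  },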
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---------- user ----------\n",
      "What is the capital of New York?\n",
      "---------- gemini_assistant ----------\n",
      "Albany\n",
      "TERMINATE\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the capital of New York?', type='TextMessage'), TextMessage(source='gemini_assistant', models_usage=RequestUsage(prompt_tokens=46, completion_tokens=5), content='Albany\\nTERMINATE\\n', type='TextMessage')], stop_reason=None)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gemini_assistant = GeminiAssistantAgent(\"gemini_assistant\")\n",
    "await Console(gemini_assistant.run_stream(task=\"What is the capital of New York?\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the example above, we chose to expose `model`, `api_key`, and `system_message` as constructor arguments. You can add any other arguments that the model client you are using requires or that fit your application design.\n",
    "\n",
    "Now, let us explore how to use this custom agent as part of a team in AgentChat."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---------- user ----------\n",
      "Write a Haiku poem with 4 lines about the fall season.\n",
      "---------- primary ----------\n",
      "Crimson leaves cascade,  \n",
      "Whispering winds sing of change,  \n",
      "Chill wraps the fading,  \n",
      "Nature's quilt, rich and warm.\n",
      "---------- gemini_critic ----------\n",
      "The poem is good, but it has four lines instead of three.  A haiku must have three lines with a 5-7-5 syllable structure.  The content is evocative of autumn, but the form is incorrect.  Please revise to adhere to the haiku's syllable structure.\n",
      "\n",
      "---------- primary ----------\n",
      "Thank you for your feedback! Here’s a revised haiku that follows the 5-7-5 syllable structure:\n",
      "\n",
      "Crimson leaves drift down,  \n",
      "Chill winds whisper through the gold,  \n",
      "Autumn’s breath is near.\n",
      "---------- gemini_critic ----------\n",
      "The revised haiku is much improved.  It correctly follows the 5-7-5 syllable structure and maintains the evocative imagery of autumn.  APPROVE\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write a Haiku poem with 4 lines about the fall season.', type='TextMessage'), TextMessage(source='primary', models_usage=RequestUsage(prompt_tokens=33, completion_tokens=31), content=\"Crimson leaves cascade,  \\nWhispering winds sing of change,  \\nChill wraps the fading,  \\nNature's quilt, rich and warm.\", type='TextMessage'), TextMessage(source='gemini_critic', models_usage=RequestUsage(prompt_tokens=86, completion_tokens=60), content=\"The poem is good, but it has four lines instead of three.  A haiku must have three lines with a 5-7-5 syllable structure.  The content is evocative of autumn, but the form is incorrect.  Please revise to adhere to the haiku's syllable structure.\\n\", type='TextMessage'), TextMessage(source='primary', models_usage=RequestUsage(prompt_tokens=141, completion_tokens=49), content='Thank you for your feedback! Here’s a revised haiku that follows the 5-7-5 syllable structure:\\n\\nCrimson leaves drift down,  \\nChill winds whisper through the gold,  \\nAutumn’s breath is near.', type='TextMessage'), TextMessage(source='gemini_critic', models_usage=RequestUsage(prompt_tokens=211, completion_tokens=32), content='The revised haiku is much improved.  It correctly follows the 5-7-5 syllable structure and maintains the evocative imagery of autumn.  APPROVE\\n', type='TextMessage')], stop_reason=\"Text 'APPROVE' mentioned\")"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from autogen_agentchat.agents import AssistantAgent\n",
    "from autogen_agentchat.conditions import TextMentionTermination\n",
    "from autogen_agentchat.teams import RoundRobinGroupChat\n",
    "from autogen_agentchat.ui import Console\n",
    "\n",
    "# Create the primary agent.\n",
    "primary_agent = AssistantAgent(\n",
    "    \"primary\",\n",
    "    model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
    "    system_message=\"You are a helpful AI assistant.\",\n",
    ")\n",
    "\n",
    "# Create a critic agent based on our new GeminiAssistantAgent.\n",
    "gemini_critic_agent = GeminiAssistantAgent(\n",
    "    \"gemini_critic\",\n",
    "    system_message=\"Provide constructive feedback. Respond with 'APPROVE' when your feedback has been addressed.\",\n",
    ")\n",
    "\n",
    "\n",
    "# Define a termination condition that stops the task if the critic approves or after 10 messages.\n",
    "termination = TextMentionTermination(\"APPROVE\") | MaxMessageTermination(10)\n",
    "\n",
    "# Create a team with the primary and critic agents.\n",
    "team = RoundRobinGroupChat([primary_agent, gemini_critic_agent], termination_condition=termination)\n",
    "\n",
    "await Console(team.run_stream(task=\"Write a Haiku poem with 4 lines about the fall season.\"))"
   ]
  },
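  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The expression `TextMentionTermination(\"APPROVE\") | MaxMessageTermination(10)` combines two conditions so that the run stops as soon as either one is met. The framework-free sketch below illustrates that idea with a hypothetical `Condition` class that overloads `|`, plus hypothetical `text_mention` and `max_messages` helpers. It is not AgentChat's implementation.\n",
    "\n",
    "```python\n",
    "from typing import Callable, List\n",
    "\n",
    "\n",
    "class Condition:\n",
    "    # Hypothetical predicate over the conversation so far, composable with |.\n",
    "    def __init__(self, check: Callable[[List[str]], bool]) -> None:\n",
    "        self.check = check\n",
    "\n",
    "    def __or__(self, other: 'Condition') -> 'Condition':\n",
    "        # The combined condition is met as soon as either side is met.\n",
    "        return Condition(lambda messages: self.check(messages) or other.check(messages))\n",
    "\n",
    "\n",
    "def text_mention(text: str) -> Condition:\n",
    "    # Met when the latest message contains the given text.\n",
    "    return Condition(lambda messages: bool(messages) and text in messages[-1])\n",
    "\n",
    "\n",
    "def max_messages(n: int) -> Condition:\n",
    "    # Met once the conversation has at least n messages.\n",
    "    return Condition(lambda messages: len(messages) >= n)\n",
    "\n",
    "\n",
    "termination = text_mention('APPROVE') | max_messages(10)\n",
    "print(termination.check(['draft', 'needs work']))  # False\n",
    "print(termination.check(['draft', 'Looks good. APPROVE']))  # True\n",
    "print(termination.check(['msg'] * 10))  # True\n",
    "```\n",
    "\n",
    "In the run above, the critic's `APPROVE` ended the conversation after five messages, well before the 10-message limit."
   ]
  },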
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The section above demonstrates two important concepts:\n",
    "- We have developed a custom agent that uses the Google Gemini SDK to respond to messages. \n",
    "- Because it inherits from {py:class}`~autogen_agentchat.agents.BaseChatAgent`, this custom agent can be used as part of the broader AgentChat ecosystem, in this case as a participant in a {py:class}`~autogen_agentchat.teams.RoundRobinGroupChat`.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Making the Custom Agent Declarative \n",
    "\n",
    "AutoGen provides a [Component](https://microsoft.github.io/autogen/dev/user-guide/core-user-guide/framework/component-config.html) interface for making the configuration of components serializable to a declarative format. This is useful for saving and loading configurations, and for sharing them with others.\n",
    "\n",
    "We accomplish this by inheriting from the `Component` class, setting the `component_config_schema` class variable to a Pydantic model that describes the configuration, and implementing the `_from_config` and `_to_config` methods.\n",
    "The `dump_component` method returns a component model that can be serialized to JSON (for example, with `model_dump_json`), and the `load_component` method creates a new instance from that model."
   ]
  },
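  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `_to_config` and `_from_config` pair forms a round trip: the agent exports only the settings needed to rebuild it, and a class method constructs a new instance from them. The minimal sketch below mirrors that shape using only the standard library, with a `dataclass` and `json` in place of Pydantic and `Component`; `AgentConfig` and `ConfigurableAgent` are hypothetical names. As in `GeminiAssistantAgentConfig`, the API key is deliberately left out of the config so that secrets are not serialized.\n",
    "\n",
    "```python\n",
    "import json\n",
    "from dataclasses import asdict, dataclass\n",
    "from typing import Optional\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class AgentConfig:\n",
    "    # Hypothetical config: serializable settings only, no secrets such as API keys.\n",
    "    name: str\n",
    "    model: str = 'gemini-1.5-flash-002'\n",
    "    system_message: Optional[str] = None\n",
    "\n",
    "\n",
    "class ConfigurableAgent:\n",
    "    # Hypothetical agent showing the _to_config / _from_config round trip.\n",
    "    def __init__(self, name: str, model: str = 'gemini-1.5-flash-002', system_message: Optional[str] = None) -> None:\n",
    "        self.name = name\n",
    "        self._model = model\n",
    "        self._system_message = system_message\n",
    "\n",
    "    def _to_config(self) -> AgentConfig:\n",
    "        return AgentConfig(name=self.name, model=self._model, system_message=self._system_message)\n",
    "\n",
    "    @classmethod\n",
    "    def _from_config(cls, config: AgentConfig) -> 'ConfigurableAgent':\n",
    "        return cls(name=config.name, model=config.model, system_message=config.system_message)\n",
    "\n",
    "\n",
    "agent = ConfigurableAgent('gemini_assistant', system_message='Be concise.')\n",
    "as_json = json.dumps(asdict(agent._to_config()))\n",
    "restored = ConfigurableAgent._from_config(AgentConfig(**json.loads(as_json)))\n",
    "print(restored.name, restored._model)  # gemini_assistant gemini-1.5-flash-002\n",
    "```\n",
    "\n",
    "Because `api_key` is not part of the config, a loaded agent has to obtain its key some other way, which is why the Gemini agent reads `GEMINI_API_KEY` from the environment."
   ]
  },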
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from typing import AsyncGenerator, Sequence\n",
    "\n",
    "from autogen_agentchat.agents import BaseChatAgent\n",
    "from autogen_agentchat.base import Response\n",
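    "from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage\n",
    "from autogen_core import CancellationToken, Component\n",
    "from autogen_core.model_context import UnboundedChatCompletionContext\n",
    "from autogen_core.models import AssistantMessage, RequestUsage, UserMessage\n",
    "from google import genai\n",
    "from google.genai import types\n",
    "from pydantic import BaseModel\n",
    "from typing_extensions import Self\n",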
    "\n",
    "\n",
    "class GeminiAssistantAgentConfig(BaseModel):\n",
    "    name: str\n",
    "    description: str = \"An agent that provides assistance with ability to use tools.\"\n",
    "    model: str = \"gemini-1.5-flash-002\"\n",
    "    system_message: str | None = None\n",
    "\n",
    "\n",
    "class GeminiAssistantAgent(BaseChatAgent, Component[GeminiAssistantAgentConfig]):  # type: ignore[no-redef]\n",
    "    component_config_schema = GeminiAssistantAgentConfig\n",
    "    # component_provider_override = \"mypackage.agents.GeminiAssistantAgent\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        name: str,\n",
    "        description: str = \"An agent that provides assistance with ability to use tools.\",\n",
    "        model: str = \"gemini-1.5-flash-002\",\n",
    "        api_key: str | None = None,\n",
    "        system_message: str\n",
    "        | None = \"You are a helpful assistant that can respond to messages. Reply with TERMINATE when the task has been completed.\",\n",
    "    ):\n",
    "        super().__init__(name=name, description=description)\n",
    "        self._model_context = UnboundedChatCompletionContext()\n",
    "        # Fall back to the GEMINI_API_KEY environment variable when no key is passed explicitly.\n",
    "        self._model_client = genai.Client(api_key=api_key or os.environ[\"GEMINI_API_KEY\"])\n",
    "        self._system_message = system_message\n",
    "        self._model = model\n",
    "\n",
    "    @property\n",
    "    def produced_message_types(self) -> Sequence[type[ChatMessage]]:\n",
    "        return (TextMessage,)\n",
    "\n",
    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
    "        final_response = None\n",
    "        async for message in self.on_messages_stream(messages, cancellation_token):\n",
    "            if isinstance(message, Response):\n",
    "                final_response = message\n",
    "\n",
    "        if final_response is None:\n",
    "            raise AssertionError(\"The stream should have returned the final result.\")\n",
    "\n",
    "        return final_response\n",
    "\n",
    "    async def on_messages_stream(\n",
    "        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n",
    "    ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n",
    "        # Add messages to the model context\n",
    "        for msg in messages:\n",
    "            await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))\n",
    "\n",
    "        # Flatten the conversation history into plain text, one \"source: content\" line per message.\n",
    "        context_messages = await self._model_context.get_messages()\n",
    "        history = \"\".join(\n",
    "            (msg.source if hasattr(msg, \"source\") else \"system\")\n",
    "            + \": \"\n",
    "            + (msg.content if isinstance(msg.content, str) else \"\")\n",
    "            + \"\\n\"\n",
    "            for msg in context_messages\n",
    "        )\n",
    "\n",
    "        # Generate a response with the SDK's async client so the event loop is not blocked.\n",
    "        response = await self._model_client.aio.models.generate_content(\n",
    "            model=self._model,\n",
    "            contents=f\"History: {history}\\nGiven the history, please provide a response\",\n",
    "            config=types.GenerateContentConfig(\n",
    "                system_instruction=self._system_message,\n",
    "                temperature=0.3,\n",
    "            ),\n",
    "        )\n",
    "\n",
    "        # Create usage metadata\n",
    "        usage = RequestUsage(\n",
    "            prompt_tokens=response.usage_metadata.prompt_token_count,\n",
    "            completion_tokens=response.usage_metadata.candidates_token_count,\n",
    "        )\n",
    "\n",
    "        # Add response to model context\n",
    "        await self._model_context.add_message(AssistantMessage(content=response.text, source=self.name))\n",
    "\n",
    "        # Yield the final response\n",
    "        yield Response(\n",
    "            chat_message=TextMessage(content=response.text, source=self.name, models_usage=usage),\n",
    "            inner_messages=[],\n",
    "        )\n",
    "\n",
    "    async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
    "        \"\"\"Reset the assistant by clearing the model context.\"\"\"\n",
    "        await self._model_context.clear()\n",
    "\n",
    "    @classmethod\n",
    "    def _from_config(cls, config: GeminiAssistantAgentConfig) -> Self:\n",
    "        return cls(\n",
    "            name=config.name, description=config.description, model=config.model, system_message=config.system_message\n",
    "        )\n",
    "\n",
    "    def _to_config(self) -> GeminiAssistantAgentConfig:\n",
    "        return GeminiAssistantAgentConfig(\n",
    "            name=self.name,\n",
    "            description=self.description,\n",
    "            model=self._model,\n",
    "            system_message=self._system_message,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the required methods implemented, we can dump the custom agent to a JSON-serializable configuration and then load a new agent instance from that configuration.\n",
    "\n",
    "> Note: You should set the `component_provider_override` class variable to the fully qualified import path of the custom agent class (e.g., `mypackage.agents.GeminiAssistantAgent`). The `load_component` method uses this path to determine which class to instantiate. Without the override, the provider reflects where the class was defined, which in this notebook is `__main__.GeminiAssistantAgent`, as the output below shows."
   ]
  },
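  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A provider string such as `mypackage.agents.GeminiAssistantAgent` is a dotted import path: a module path followed by a class name. The sketch below shows one way to resolve such a string with `importlib`, using a standard-library class so it runs anywhere. `resolve_provider` is a hypothetical helper, not AutoGen's loader.\n",
    "\n",
    "```python\n",
    "import importlib\n",
    "\n",
    "\n",
    "def resolve_provider(provider: str) -> type:\n",
    "    # Split 'package.module.ClassName' into the module path and the attribute name.\n",
    "    module_path, _, class_name = provider.rpartition('.')\n",
    "    module = importlib.import_module(module_path)\n",
    "    return getattr(module, class_name)\n",
    "\n",
    "\n",
    "print(resolve_provider('collections.OrderedDict'))  # <class 'collections.OrderedDict'>\n",
    "```\n",
    "\n",
    "This is why the override matters: a provider such as `__main__.GeminiAssistantAgent` can only be resolved in a process whose `__main__` module defines that class."
   ]
  },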
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"provider\": \"__main__.GeminiAssistantAgent\",\n",
      "  \"component_type\": \"agent\",\n",
      "  \"version\": 1,\n",
      "  \"component_version\": 1,\n",
      "  \"description\": null,\n",
      "  \"label\": \"GeminiAssistantAgent\",\n",
      "  \"config\": {\n",
      "    \"name\": \"gemini_assistant\",\n",
      "    \"description\": \"An agent that provides assistance with ability to use tools.\",\n",
      "    \"model\": \"gemini-1.5-flash-002\",\n",
      "    \"system_message\": \"You are a helpful assistant that can respond to messages. Reply with TERMINATE when the task has been completed.\"\n",
      "  }\n",
      "}\n",
      "<__main__.GeminiAssistantAgent object at 0x11a5c5a90>\n"
     ]
    }
   ],
   "source": [
    "gemini_assistant = GeminiAssistantAgent(\"gemini_assistant\")\n",
    "config = gemini_assistant.dump_component()\n",
    "print(config.model_dump_json(indent=2))\n",
    "loaded_agent = GeminiAssistantAgent.load_component(config)\n",
    "print(loaded_agent)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next Steps \n",
    "\n",
    "So far, we have seen how to create custom agents, add custom model clients to agents, and make custom agents declarative. There are a few ways in which this basic sample can be extended:\n",
    "\n",
    "- Extend the Gemini model client to handle [function calling](https://ai.google.dev/gemini-api/docs/function-calling), similar to the {py:class}`~autogen_agentchat.agents.AssistantAgent` class.\n",
    "- Implement a package with a custom agent and experiment with using its declarative format in a tool like [AutoGen Studio](https://microsoft.github.io/autogen/stable/user-guide/autogenstudio-user-guide/index.html)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}