{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Reflection\n", "\n", "Reflection is a design pattern where an LLM generation is followed by a reflection,\n", "which is itself another LLM generation conditioned on the output of the first one.\n", "For example, given a task to write code, the first LLM can generate a code snippet,\n", "and the second LLM can generate a critique of the code snippet.\n", "\n", "In the context of AutoGen and agents, reflection can be implemented as a pair\n", "of agents, where the first agent generates a message and the second agent\n", "generates a response to the message. The two agents continue to interact\n", "until they reach a stopping condition, such as a maximum number of iterations\n", "or an approval from the second agent." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's implement a simple reflection design pattern using AutoGen agents.\n", "There will be two agents: a coder agent and a reviewer agent. The coder agent\n", "will generate a code snippet, and the reviewer agent will generate a critique\n", "of the code snippet.\n", "\n", "## Message Protocol\n", "\n", "Before we define the agents, we first need to define the message protocol for the agents." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from dataclasses import dataclass\n", "\n", "\n", "@dataclass\n", "class CodeWritingTask:\n", "    task: str\n", "\n", "\n", "@dataclass\n", "class CodeWritingResult:\n", "    task: str\n", "    code: str\n", "    review: str\n", "\n", "\n", "@dataclass\n", "class CodeReviewTask:\n", "    session_id: str\n", "    code_writing_task: str\n", "    code_writing_scratchpad: str\n", "    code: str\n", "\n", "\n", "@dataclass\n", "class CodeReviewResult:\n", "    review: str\n", "    session_id: str\n", "    approved: bool" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above set of messages defines the protocol for our example reflection design pattern:\n", "- The application sends a `CodeWritingTask` message to the coder agent\n", "- The coder agent generates a `CodeReviewTask` message, which is sent to the reviewer agent\n", "- The reviewer agent generates a `CodeReviewResult` message, which is sent back to the coder agent\n", "- Depending on the `CodeReviewResult` message, if the code is approved, the coder agent sends a `CodeWritingResult` message\n", "back to the application; otherwise, the coder agent sends another `CodeReviewTask` message to the reviewer agent,\n", "and the process continues.\n", "\n", "We can visualize the message protocol using a data flow diagram:\n", "\n", "![coder-reviewer data flow](coder-reviewer-data-flow.svg)\n", "\n", "## Agents\n", "\n", "Now, let's define the agents for the reflection design pattern." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import json\n", "import re\n", "import uuid\n", "from typing import Dict, List, Union\n", "\n", "from autogen_core.base import MessageContext, TopicId\n", "from autogen_core.components import RoutedAgent, default_subscription, message_handler\n", "from autogen_core.components.models import (\n", "    AssistantMessage,\n", "    ChatCompletionClient,\n", "    LLMMessage,\n", "    SystemMessage,\n", "    UserMessage,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use the [Broadcast](../framework/message-and-communication.ipynb#broadcast) API\n", "to implement the design pattern. The agents implement the pub/sub model.\n", "The coder agent subscribes to the `CodeWritingTask` and `CodeReviewResult` messages,\n", "and publishes the `CodeReviewTask` and `CodeWritingResult` messages." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "@default_subscription\n", "class CoderAgent(RoutedAgent):\n", "    \"\"\"An agent that performs code writing tasks.\"\"\"\n", "\n", "    def __init__(self, model_client: ChatCompletionClient) -> None:\n", "        super().__init__(\"A code writing agent.\")\n", "        self._system_messages: List[LLMMessage] = [\n", "            SystemMessage(\n", "                content=\"\"\"You are a proficient coder. 
You write code to solve problems.\n", "Work with the reviewer to improve your code.\n", "Always put all finished code in a single Markdown code block.\n", "For example:\n", "```python\n", "def hello_world():\n", "    print(\"Hello, World!\")\n", "```\n", "\n", "Respond using the following format:\n", "\n", "Thoughts: \n", "Code: \n", "\"\"\",\n", "            )\n", "        ]\n", "        self._model_client = model_client\n", "        self._session_memory: Dict[str, List[CodeWritingTask | CodeReviewTask | CodeReviewResult]] = {}\n", "\n", "    @message_handler\n", "    async def handle_code_writing_task(self, message: CodeWritingTask, ctx: MessageContext) -> None:\n", "        # Store the messages in a temporary memory for this request only.\n", "        session_id = str(uuid.uuid4())\n", "        self._session_memory.setdefault(session_id, []).append(message)\n", "        # Generate a response using the chat completion API.\n", "        response = await self._model_client.create(\n", "            self._system_messages + [UserMessage(content=message.task, source=self.metadata[\"type\"])],\n", "            cancellation_token=ctx.cancellation_token,\n", "        )\n", "        assert isinstance(response.content, str)\n", "        # Extract the code block from the response.\n", "        code_block = self._extract_code_block(response.content)\n", "        if code_block is None:\n", "            raise ValueError(\"Code block not found.\")\n", "        # Create a code review task.\n", "        code_review_task = CodeReviewTask(\n", "            session_id=session_id,\n", "            code_writing_task=message.task,\n", "            code_writing_scratchpad=response.content,\n", "            code=code_block,\n", "        )\n", "        # Store the code review task in the session memory.\n", "        self._session_memory[session_id].append(code_review_task)\n", "        # Publish a code review task.\n", "        await self.publish_message(code_review_task, topic_id=TopicId(\"default\", self.id.key))\n", "\n", "    @message_handler\n", "    async def handle_code_review_result(self, message: CodeReviewResult, ctx: MessageContext) -> None:\n", "        # Store the review result in the session memory.\n", " 
        self._session_memory[message.session_id].append(message)\n", "        # Obtain the request from previous messages.\n", "        review_request = next(\n", "            m for m in reversed(self._session_memory[message.session_id]) if isinstance(m, CodeReviewTask)\n", "        )\n", "        assert review_request is not None\n", "        # Check if the code is approved.\n", "        if message.approved:\n", "            # Publish the code writing result.\n", "            await self.publish_message(\n", "                CodeWritingResult(\n", "                    code=review_request.code,\n", "                    task=review_request.code_writing_task,\n", "                    review=message.review,\n", "                ),\n", "                topic_id=TopicId(\"default\", self.id.key),\n", "            )\n", "            print(\"Code Writing Result:\")\n", "            print(\"-\" * 80)\n", "            print(f\"Task:\\n{review_request.code_writing_task}\")\n", "            print(\"-\" * 80)\n", "            print(f\"Code:\\n{review_request.code}\")\n", "            print(\"-\" * 80)\n", "            print(f\"Review:\\n{message.review}\")\n", "            print(\"-\" * 80)\n", "        else:\n", "            # Create a list of LLM messages to send to the model.\n", "            messages: List[LLMMessage] = [*self._system_messages]\n", "            for m in self._session_memory[message.session_id]:\n", "                if isinstance(m, CodeReviewResult):\n", "                    messages.append(UserMessage(content=m.review, source=\"Reviewer\"))\n", "                elif isinstance(m, CodeReviewTask):\n", "                    messages.append(AssistantMessage(content=m.code_writing_scratchpad, source=\"Coder\"))\n", "                elif isinstance(m, CodeWritingTask):\n", "                    messages.append(UserMessage(content=m.task, source=\"User\"))\n", "                else:\n", "                    raise ValueError(f\"Unexpected message type: {m}\")\n", "            # Generate a revision using the chat completion API.\n", "            response = await self._model_client.create(messages, cancellation_token=ctx.cancellation_token)\n", "            assert isinstance(response.content, str)\n", "            # Extract the code block from the response.\n", "            code_block = self._extract_code_block(response.content)\n", "            if code_block is None:\n", "                raise ValueError(\"Code block not found.\")\n", "            # Create a new code review task.\n", "            code_review_task = 
CodeReviewTask(\n", "                session_id=message.session_id,\n", "                code_writing_task=review_request.code_writing_task,\n", "                code_writing_scratchpad=response.content,\n", "                code=code_block,\n", "            )\n", "            # Store the code review task in the session memory.\n", "            self._session_memory[message.session_id].append(code_review_task)\n", "            # Publish a new code review task.\n", "            await self.publish_message(code_review_task, topic_id=TopicId(\"default\", self.id.key))\n", "\n", "    def _extract_code_block(self, markdown_text: str) -> Union[str, None]:\n", "        pattern = r\"```(\\w+)\\n(.*?)\\n```\"\n", "        # Search for the pattern in the markdown text\n", "        match = re.search(pattern, markdown_text, re.DOTALL)\n", "        # Return only the code (the second capture group) if a match is found\n", "        if match:\n", "            return match.group(2)\n", "        return None" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A few things to note about `CoderAgent`:\n", "- It uses chain-of-thought prompting in its system message.\n", "- It stores message histories for different `CodeWritingTask` messages in a dictionary,\n", "so each task has its own history.\n", "- When making an LLM inference request using its model client, it transforms\n", "the message history into a list of {py:class}`autogen_core.components.models.LLMMessage` objects\n", "to pass to the model client.\n", "\n", "The reviewer agent subscribes to the `CodeReviewTask` message and publishes the `CodeReviewResult` message." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "@default_subscription\n", "class ReviewerAgent(RoutedAgent):\n", "    \"\"\"An agent that performs code review tasks.\"\"\"\n", "\n", "    def __init__(self, model_client: ChatCompletionClient) -> None:\n", "        super().__init__(\"A code reviewer agent.\")\n", "        self._system_messages: List[LLMMessage] = [\n", "            SystemMessage(\n", "                content=\"\"\"You are a code reviewer. 
You focus on correctness, efficiency and safety of the code.\n", "Respond using the following JSON format:\n", "{\n", "    \"correctness\": \"\",\n", "    \"efficiency\": \"\",\n", "    \"safety\": \"\",\n", "    \"approval\": \"<APPROVE or REVISE>\",\n", "    \"suggested_changes\": \"\"\n", "}\n", "\"\"\",\n", "            )\n", "        ]\n", "        self._session_memory: Dict[str, List[CodeReviewTask | CodeReviewResult]] = {}\n", "        self._model_client = model_client\n", "\n", "    @message_handler\n", "    async def handle_code_review_task(self, message: CodeReviewTask, ctx: MessageContext) -> None:\n", "        # Format the prompt for the code review.\n", "        # Gather the previous feedback if available.\n", "        previous_feedback = \"\"\n", "        if message.session_id in self._session_memory:\n", "            previous_review = next(\n", "                (m for m in reversed(self._session_memory[message.session_id]) if isinstance(m, CodeReviewResult)),\n", "                None,\n", "            )\n", "            if previous_review is not None:\n", "                previous_feedback = previous_review.review\n", "        # Store the messages in a temporary memory for this request only.\n", "        self._session_memory.setdefault(message.session_id, []).append(message)\n", "        prompt = f\"\"\"The problem statement is: {message.code_writing_task}\n", "The code is:\n", "```\n", "{message.code}\n", "```\n", "\n", "Previous feedback:\n", "{previous_feedback}\n", "\n", "Please review the code. If previous feedback was provided, see if it was addressed.\n", "\"\"\"\n", "        # Generate a response using the chat completion API.\n", "        response = await self._model_client.create(\n", "            self._system_messages + [UserMessage(content=prompt, source=self.metadata[\"type\"])],\n", "            cancellation_token=ctx.cancellation_token,\n", "            json_output=True,\n", "        )\n", "        assert isinstance(response.content, str)\n", "        # TODO: use structured generation library e.g. 
guidance to ensure the response is in the expected format.\n", "        # Parse the response JSON.\n", "        review = json.loads(response.content)\n", "        # Construct the review text.\n", "        review_text = \"Code review:\\n\" + \"\\n\".join([f\"{k}: {v}\" for k, v in review.items()])\n", "        approved = review[\"approval\"].lower().strip() == \"approve\"\n", "        result = CodeReviewResult(\n", "            review=review_text,\n", "            session_id=message.session_id,\n", "            approved=approved,\n", "        )\n", "        # Store the review result in the session memory.\n", "        self._session_memory[message.session_id].append(result)\n", "        # Publish the review result.\n", "        await self.publish_message(result, topic_id=TopicId(\"default\", self.id.key))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `ReviewerAgent` uses JSON mode when making an LLM inference request, and\n", "also uses chain-of-thought prompting in its system message." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Logging\n", "\n", "Turn on logging to see the messages exchanged between the agents." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import logging\n", "\n", "logging.basicConfig(level=logging.WARNING)\n", "logging.getLogger(\"autogen_core\").setLevel(logging.DEBUG)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Running the Design Pattern\n", "\n", "Let's test the design pattern with a coding task.\n", "Since all the agents are decorated with the {py:meth}`~autogen_core.components.default_subscription` class decorator,\n", "they automatically subscribe to the default topic when they are created.\n", "We publish a `CodeWritingTask` message to the default topic to start the reflection process." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:autogen_core:Publishing message of type CodeWritingTask to all subscribers: {'task': 'Write a function to find the sum of all even numbers in a list.'}\n", "INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeWritingTask published by Unknown\n", "INFO:autogen_core:Calling message handler for CoderAgent with message type CodeWritingTask published by Unknown\n", "INFO:autogen_core:Unhandled message: CodeWritingTask(task='Write a function to find the sum of all even numbers in a list.')\n", "INFO:autogen_core.events:{\"prompt_tokens\": 101, \"completion_tokens\": 88, \"type\": \"LLMCall\"}\n", "INFO:autogen_core:Publishing message of type CodeReviewTask to all subscribers: {'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'code_writing_task': 'Write a function to find the sum of all even numbers in a list.', 'code_writing_scratchpad': 'Thoughts: To find the sum of all even numbers in a list, we can use a list comprehension to filter out the even numbers and then use the `sum()` function to calculate their total. The implementation should handle edge cases like an empty list or a list with no even numbers.\\n\\nCode:\\n```python\\ndef sum_of_even_numbers(numbers):\\n return sum(num for num in numbers if num % 2 == 0)\\n```', 'code': 'def sum_of_even_numbers(numbers):\\n return sum(num for num in numbers if num % 2 == 0)'}\n", "INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeReviewTask published by CoderAgent:default\n", "INFO:autogen_core.events:{\"prompt_tokens\": 163, \"completion_tokens\": 235, \"type\": \"LLMCall\"}\n", "INFO:autogen_core:Publishing message of type CodeReviewResult to all subscribers: {'review': \"Code review:\\ncorrectness: The function correctly identifies and sums all even numbers in the provided list. 
The use of a generator expression ensures that only even numbers are processed, which is correct.\\nefficiency: The function is efficient as it utilizes a generator expression that avoids creating an intermediate list, therefore using less memory. The time complexity is O(n) where n is the number of elements in the input list, which is optimal for this task.\\nsafety: The function does not include checks for input types. If a non-iterable or a list containing non-integer types is passed, it could lead to unexpected behavior or errors. It’s advisable to handle such cases.\\napproval: REVISE\\nsuggested_changes: Consider adding input validation to ensure that 'numbers' is a list and contains only integers. You could raise a ValueError if the input is invalid. Example: 'if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers): raise ValueError('Input must be a list of integers')'. This will make the function more robust.\", 'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'approved': False}\n", "INFO:autogen_core:Calling message handler for CoderAgent with message type CodeReviewResult published by ReviewerAgent:default\n", "INFO:autogen_core.events:{\"prompt_tokens\": 421, \"completion_tokens\": 119, \"type\": \"LLMCall\"}\n", "INFO:autogen_core:Publishing message of type CodeReviewTask to all subscribers: {'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'code_writing_task': 'Write a function to find the sum of all even numbers in a list.', 'code_writing_scratchpad': \"Thoughts: I appreciate the reviewer's feedback on input validation. Adding type checks ensures that the function can handle unexpected inputs gracefully. 
I will implement the suggested changes and include checks for both the input type and the elements within the list to confirm that they are integers.\\n\\nCode:\\n```python\\ndef sum_of_even_numbers(numbers):\\n if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\\n raise ValueError('Input must be a list of integers')\\n \\n return sum(num for num in numbers if num % 2 == 0)\\n```\", 'code': \"def sum_of_even_numbers(numbers):\\n if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\\n raise ValueError('Input must be a list of integers')\\n \\n return sum(num for num in numbers if num % 2 == 0)\"}\n", "INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeReviewTask published by CoderAgent:default\n", "INFO:autogen_core.events:{\"prompt_tokens\": 420, \"completion_tokens\": 153, \"type\": \"LLMCall\"}\n", "INFO:autogen_core:Publishing message of type CodeReviewResult to all subscribers: {'review': 'Code review:\\ncorrectness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\\nefficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\\nsafety: The function includes input validation, which enhances safety by preventing incorrect input types. 
It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\\napproval: APPROVE\\nsuggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.', 'session_id': '51db93d5-3e29-4b7f-9f96-77be7bb02a5e', 'approved': True}\n", "INFO:autogen_core:Calling message handler for CoderAgent with message type CodeReviewResult published by ReviewerAgent:default\n", "INFO:autogen_core:Publishing message of type CodeWritingResult to all subscribers: {'task': 'Write a function to find the sum of all even numbers in a list.', 'code': \"def sum_of_even_numbers(numbers):\\n if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\\n raise ValueError('Input must be a list of integers')\\n \\n return sum(num for num in numbers if num % 2 == 0)\", 'review': 'Code review:\\ncorrectness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\\nefficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\\nsafety: The function includes input validation, which enhances safety by preventing incorrect input types. 
It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\\napproval: APPROVE\\nsuggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.'}\n", "INFO:autogen_core:Calling message handler for ReviewerAgent with message type CodeWritingResult published by CoderAgent:default\n", "INFO:autogen_core:Unhandled message: CodeWritingResult(task='Write a function to find the sum of all even numbers in a list.', code=\"def sum_of_even_numbers(numbers):\\n if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\\n raise ValueError('Input must be a list of integers')\\n \\n return sum(num for num in numbers if num % 2 == 0)\", review='Code review:\\ncorrectness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\\nefficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\\nsafety: The function includes input validation, which enhances safety by preventing incorrect input types. 
It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\\napproval: APPROVE\\nsuggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Code Writing Result:\n", "--------------------------------------------------------------------------------\n", "Task:\n", "Write a function to find the sum of all even numbers in a list.\n", "--------------------------------------------------------------------------------\n", "Code:\n", "def sum_of_even_numbers(numbers):\n", " if not isinstance(numbers, list) or not all(isinstance(num, int) for num in numbers):\n", " raise ValueError('Input must be a list of integers')\n", " \n", " return sum(num for num in numbers if num % 2 == 0)\n", "--------------------------------------------------------------------------------\n", "Review:\n", "Code review:\n", "correctness: The function correctly sums all even numbers in the provided list. It raises a ValueError if the input is not a list of integers, which is a necessary check for correctness.\n", "efficiency: The function remains efficient with a time complexity of O(n) due to the use of a generator expression. There are no unnecessary intermediate lists created, so memory usage is optimal.\n", "safety: The function includes input validation, which enhances safety by preventing incorrect input types. 
It raises a ValueError for invalid inputs, making the function more robust against unexpected data.\n", "approval: APPROVE\n", "suggested_changes: No further changes are necessary as the previous feedback has been adequately addressed.\n", "--------------------------------------------------------------------------------\n" ] } ], "source": [ "from autogen_core.application import SingleThreadedAgentRuntime\n", "from autogen_core.components import DefaultTopicId\n", "from autogen_ext.models import OpenAIChatCompletionClient\n", "\n", "runtime = SingleThreadedAgentRuntime()\n", "await ReviewerAgent.register(\n", "    runtime, \"ReviewerAgent\", lambda: ReviewerAgent(model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"))\n", ")\n", "await CoderAgent.register(\n", "    runtime, \"CoderAgent\", lambda: CoderAgent(model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"))\n", ")\n", "runtime.start()\n", "await runtime.publish_message(\n", "    message=CodeWritingTask(task=\"Write a function to find the sum of all even numbers in a list.\"),\n", "    topic_id=DefaultTopicId(),\n", ")\n", "\n", "# Keep processing messages until idle.\n", "await runtime.stop_when_idle()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The log messages show the interaction between the coder and reviewer agents.\n", "The final output shows the code snippet generated by the coder agent and the critique generated by the reviewer agent." ] } ], "metadata": { "kernelspec": { "display_name": "agnext", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 2 }