{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using LangGraph-Backed Agent\n",
    "\n",
    "This example demonstrates how to create an AI agent using LangGraph.\n",
    "Based on the example in the LangGraph documentation:\n",
    "https://langchain-ai.github.io/langgraph/.\n",
    "\n",
    "First install the dependencies:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "# pip install langgraph langchain-openai azure-identity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's import the modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from typing import Any, Callable, List, Literal\n",
    "\n",
    "from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler\n",
    "from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n",
    "from langchain_core.messages import HumanMessage, SystemMessage\n",
    "from langchain_core.tools import tool  # pyright: ignore\n",
    "from langchain_openai import AzureChatOpenAI, ChatOpenAI\n",
    "from langgraph.graph import END, MessagesState, StateGraph\n",
    "from langgraph.prebuilt import ToolNode"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define our message type that will be used to communicate with the agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class Message:\n",
    "    content: str"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the tools the agent will use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tool  # pyright: ignore\n",
    "def get_weather(location: str) -> str:\n",
    "    \"\"\"Call to surf the web.\"\"\"\n",
    "    # This is a placeholder, but don't tell the LLM that...\n",
    "    if \"sf\" in location.lower() or \"san francisco\" in location.lower():\n",
    "        return \"It's 60 degrees and foggy.\"\n",
    "    return \"It's 90 degrees and sunny.\""
   ]
  },
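  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the tool can be invoked directly before wiring it\n",
    "into the graph. Functions decorated with `@tool` are LangChain runnables and\n",
    "expose `invoke` (a minimal sketch; the locations are arbitrary):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Invoke the tool directly to verify both branches of the placeholder logic.\n",
    "print(get_weather.invoke(\"San Francisco\"))  # It's 60 degrees and foggy.\n",
    "print(get_weather.invoke(\"New York\"))  # It's 90 degrees and sunny."
   ]
  },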
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the agent using LangGraph's API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LangGraphToolUseAgent(RoutedAgent):\n",
    "    def __init__(self, description: str, model: ChatOpenAI, tools: List[Callable[..., Any]]) -> None:  # pyright: ignore\n",
    "        super().__init__(description)\n",
    "        self._model = model.bind_tools(tools)  # pyright: ignore\n",
    "\n",
    "        # Define the function that determines whether to continue or not\n",
    "        def should_continue(state: MessagesState) -> Literal[\"tools\", END]:  # type: ignore\n",
    "            messages = state[\"messages\"]\n",
    "            last_message = messages[-1]\n",
    "            # If the LLM makes a tool call, then we route to the \"tools\" node\n",
    "            if last_message.tool_calls:  # type: ignore\n",
    "                return \"tools\"\n",
    "            # Otherwise, we stop (reply to the user)\n",
    "            return END\n",
    "\n",
    "        # Define the function that calls the model\n",
    "        async def call_model(state: MessagesState):  # type: ignore\n",
    "            messages = state[\"messages\"]\n",
    "            response = await self._model.ainvoke(messages)\n",
    "            # We return a list, because this will get added to the existing list\n",
    "            return {\"messages\": [response]}\n",
    "\n",
    "        tool_node = ToolNode(tools)  # pyright: ignore\n",
    "\n",
    "        # Define a new graph\n",
    "        self._workflow = StateGraph(MessagesState)\n",
    "\n",
    "        # Define the two nodes we will cycle between\n",
    "        self._workflow.add_node(\"agent\", call_model)  # pyright: ignore\n",
    "        self._workflow.add_node(\"tools\", tool_node)  # pyright: ignore\n",
    "\n",
    "        # Set the entrypoint as `agent`\n",
    "        # This means that this node is the first one called\n",
    "        self._workflow.set_entry_point(\"agent\")\n",
    "\n",
    "        # We now add a conditional edge\n",
    "        self._workflow.add_conditional_edges(\n",
    "            # First, we define the start node. We use `agent`.\n",
    "            # This means these are the edges taken after the `agent` node is called.\n",
    "            \"agent\",\n",
    "            # Next, we pass in the function that will determine which node is called next.\n",
    "            should_continue,  # type: ignore\n",
    "        )\n",
    "\n",
    "        # We now add a normal edge from `tools` to `agent`.\n",
    "        # This means that after `tools` is called, `agent` node is called next.\n",
    "        self._workflow.add_edge(\"tools\", \"agent\")\n",
    "\n",
    "        # Finally, we compile it!\n",
    "        # This compiles it into a LangChain Runnable,\n",
    "        # meaning you can use it as you would any other runnable.\n",
    "        # Note that we're (optionally) passing the memory when compiling the graph\n",
    "        self._app = self._workflow.compile()\n",
    "\n",
    "    @message_handler\n",
    "    async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
    "        # Use the Runnable\n",
    "        final_state = await self._app.ainvoke(\n",
    "            {\n",
    "                \"messages\": [\n",
    "                    SystemMessage(\n",
    "                        content=\"You are a helpful AI assistant. You can use tools to help answer questions.\"\n",
    "                    ),\n",
    "                    HumanMessage(content=message.content),\n",
    "                ]\n",
    "            },\n",
    "            config={\"configurable\": {\"thread_id\": 42}},\n",
    "        )\n",
    "        response = Message(content=final_state[\"messages\"][-1].content)\n",
    "        return response"
   ]
  },
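  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the `thread_id` in the invocation config above only takes effect\n",
    "if the graph is compiled with a checkpointer. A minimal sketch, assuming the\n",
    "`MemorySaver` checkpointer from a recent langgraph release (the agent above\n",
    "compiles without one, so each invocation starts fresh):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "\n",
    "# Inside LangGraphToolUseAgent.__init__, compiling with a checkpointer would\n",
    "# persist conversation state per thread_id across invocations:\n",
    "# self._app = self._workflow.compile(checkpointer=MemorySaver())"
   ]
  },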
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's test the agent. First we need to create an agent runtime and\n",
    "register the agent, by providing the agent's name and a factory function\n",
    "that will create the agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "runtime = SingleThreadedAgentRuntime()\n",
    "await LangGraphToolUseAgent.register(\n",
    "    runtime,\n",
    "    \"langgraph_tool_use_agent\",\n",
    "    lambda: LangGraphToolUseAgent(\n",
    "        \"Tool use agent\",\n",
    "        ChatOpenAI(\n",
    "            model=\"gpt-4o\",\n",
    "            # api_key=os.getenv(\"OPENAI_API_KEY\"),\n",
    "        ),\n",
    "        # AzureChatOpenAI(\n",
    "        #     azure_deployment=os.getenv(\"AZURE_OPENAI_DEPLOYMENT\"),\n",
    "        #     azure_endpoint=os.getenv(\"AZURE_OPENAI_ENDPOINT\"),\n",
    "        #     api_version=os.getenv(\"AZURE_OPENAI_API_VERSION\"),\n",
    "        #     # Using Azure Active Directory authentication.\n",
    "        #     azure_ad_token_provider=get_bearer_token_provider(DefaultAzureCredential()),\n",
    "        #     # Using API key.\n",
    "        #     # api_key=os.getenv(\"AZURE_OPENAI_API_KEY\"),\n",
    "        # ),\n",
    "        [get_weather],\n",
    "    ),\n",
    ")\n",
    "agent = AgentId(\"langgraph_tool_use_agent\", key=\"default\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Start the agent runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "runtime.start()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Send a direct message to the agent, and print the response."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The current weather in San Francisco is 60 degrees and foggy.\n"
     ]
    }
   ],
   "source": [
    "response = await runtime.send_message(Message(\"What's the weather in SF?\"), agent)\n",
    "print(response.content)"
   ]
  },
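  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Asking about a different location exercises the tool's other branch\n",
    "(output not shown here):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = await runtime.send_message(Message(\"What's the weather in New York?\"), agent)\n",
    "print(response.content)"
   ]
  },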
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Stop the agent runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "await runtime.stop()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "autogen_core",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}