# Source code for autogen_ext.agents.openai._openai_assistant_agent

import asyncio
import json
import logging
import os
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Awaitable,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
    cast,
)

from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import (
    AgentEvent,
    ChatMessage,
    HandoffMessage,
    MultiModalMessage,
    StopMessage,
    TextMessage,
    ToolCallExecutionEvent,
    ToolCallRequestEvent,
)
from autogen_core import CancellationToken, FunctionCall
from autogen_core.models._types import FunctionExecutionResult
from autogen_core.tools import FunctionTool, Tool

# Records whether the optional runtime dependencies (aiofiles and the openai
# SDK) could be imported. Code in this module checks the flag and raises a
# RuntimeError with installation guidance when it is False.
_has_openai_dependencies: bool = True
try:
    import aiofiles

    from openai import NOT_GIVEN
    from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
    from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
    from openai.types.beta.file_search_tool_param import FileSearchToolParam
    from openai.types.beta.function_tool_param import FunctionToolParam
    from openai.types.shared_params.function_definition import FunctionDefinition
except ImportError:
    # Record the failure instead of raising here; the flag is checked at the
    # points where the dependencies are actually needed.
    _has_openai_dependencies = False

if TYPE_CHECKING:
    import aiofiles

    from openai import NOT_GIVEN, AsyncClient, NotGiven
    from openai.pagination import AsyncCursorPage
    from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
    from openai.types import FileObject
    from openai.types.beta import thread_update_params
    from openai.types.beta.assistant import Assistant
    from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam
    from openai.types.beta.assistant_tool_param import AssistantToolParam
    from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
    from openai.types.beta.file_search_tool_param import FileSearchToolParam
    from openai.types.beta.function_tool_param import FunctionToolParam
    from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter
    from openai.types.beta.threads import Message, MessageDeleted, Run
    from openai.types.beta.vector_store import VectorStore
    from openai.types.shared_params.function_definition import FunctionDefinition

# Module-wide logger registered under EVENT_LOGGER_NAME; this module logs
# tool-call events at debug level and resource-cleanup failures at error level.
event_logger = logging.getLogger(EVENT_LOGGER_NAME)


def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
    """Build the OpenAI Assistant function-tool parameter for an autogen ``Tool``.

    Only the ``type``, ``properties`` and (when present) ``required`` entries of
    the tool schema's ``parameters`` section are forwarded; a schema without a
    ``parameters`` section yields an empty parameter mapping.

    Args:
        tool: The tool whose ``schema`` is converted.

    Returns:
        A ``FunctionToolParam`` of type ``"function"`` wrapping the definition.

    Raises:
        RuntimeError: If the optional OpenAI dependencies failed to import.
    """
    if not _has_openai_dependencies:
        raise RuntimeError(
            "Missing dependecies for OpenAIAssistantAgent. Please ensure the autogen-ext package was installed with the 'openai' extra."
        )

    spec = tool.schema
    if "parameters" in spec:
        declared = spec["parameters"]
        fn_parameters: Dict[str, object] = {
            "type": declared["type"],
            "properties": declared["properties"],
        }
        # "required" is optional in the schema, so copy it only when declared.
        if "required" in declared:
            fn_parameters["required"] = declared["required"]
    else:
        fn_parameters = {}

    return FunctionToolParam(
        type="function",
        function=FunctionDefinition(
            name=spec["name"],
            description=spec.get("description", ""),
            parameters=fn_parameters,
        ),
    )


class OpenAIAssistantAgent(BaseChatAgent):
    """An agent implementation that uses the OpenAI Assistant API to generate responses.

    This agent leverages the OpenAI Assistant API to create AI assistants with capabilities like:

    * Code interpretation and execution
    * File handling and search
    * Custom function calling
    * Multi-turn conversations

    The agent maintains a thread of conversation and can use various tools including

    * Code interpreter: For executing code and working with files
    * File search: For searching through uploaded documents
    * Custom functions: For extending capabilities with user-defined tools

    Key Features:

    * Supports multiple file formats including code, documents, images
    * Can handle up to 128 tools per assistant
    * Maintains conversation context in threads
    * Supports file uploads for code interpreter and search
    * Vector store integration for efficient file search
    * Automatic file parsing and embedding

    Example:

        .. code-block:: python

            from openai import AsyncClient
            from autogen_core import CancellationToken
            import asyncio
            from autogen_ext.agents.openai import OpenAIAssistantAgent
            from autogen_agentchat.messages import TextMessage


            async def example():
                cancellation_token = CancellationToken()

                # Create an OpenAI client
                client = AsyncClient(api_key="your-api-key", base_url="your-base-url")

                # Create an assistant with code interpreter
                assistant = OpenAIAssistantAgent(
                    name="Python Helper",
                    description="Helps with Python programming",
                    client=client,
                    model="gpt-4",
                    instructions="You are a helpful Python programming assistant.",
                    tools=["code_interpreter"],
                )

                # Upload files for the assistant to use
                await assistant.on_upload_for_code_interpreter("data.csv", cancellation_token)

                # Get response from the assistant
                _response = await assistant.on_messages(
                    [TextMessage(source="user", content="Analyze the data in data.csv")], cancellation_token
                )

                # Clean up resources
                await assistant.delete_uploaded_files(cancellation_token)
                await assistant.delete_assistant(cancellation_token)


            asyncio.run(example())

    Args:
        name (str): Name of the assistant
        description (str): Description of the assistant's purpose
        client (AsyncClient): OpenAI API client instance
        model (str): Model to use (e.g. "gpt-4")
        instructions (str): System instructions for the assistant
        tools (Optional[Iterable[Union[Literal["code_interpreter", "file_search"], Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]]]]): Tools the assistant can use
        assistant_id (Optional[str]): ID of existing assistant to use
        thread_id (Optional[str]): ID of existing thread to use; a new thread is created when omitted
        metadata (Optional[object]): Additional metadata for the assistant
        response_format (Optional[AssistantResponseFormatOptionParam]): Response format settings
        temperature (Optional[float]): Temperature for response generation
        tool_resources (Optional[ToolResources]): Additional tool configuration
        top_p (Optional[float]): Top p sampling parameter

    Raises:
        RuntimeError: If the optional OpenAI dependencies are not installed.
        ValueError: If an entry of ``tools`` is neither a supported built-in tool
            name, a ``Tool`` nor a callable.
    """

    def __init__(
        self,
        name: str,
        description: str,
        client: "AsyncClient",
        model: str,
        instructions: str,
        tools: Optional[
            Iterable[
                Union[
                    Literal["code_interpreter", "file_search"],
                    Tool | Callable[..., Any] | Callable[..., Awaitable[Any]],
                ]
            ]
        ] = None,
        assistant_id: Optional[str] = None,
        thread_id: Optional[str] = None,
        metadata: Optional[object] = None,
        response_format: Optional["AssistantResponseFormatOptionParam"] = None,
        temperature: Optional[float] = None,
        tool_resources: Optional["ToolResources"] = None,
        top_p: Optional[float] = None,
    ) -> None:
        if not _has_openai_dependencies:
            raise RuntimeError(
                "Missing dependecies for OpenAIAssistantAgent. Please ensure the autogen-ext package was installed with the 'openai' extra."
            )

        super().__init__(name, description)
        if tools is None:
            tools = []

        # Keep the executable tools (needed to run function calls locally) apart
        # from their API representation (sent to the Assistant API).
        self._original_tools: List[Tool] = []
        converted_tools: List["AssistantToolParam"] = []
        for tool in tools:
            if isinstance(tool, str):
                if tool == "code_interpreter":
                    converted_tools.append(CodeInterpreterToolParam(type="code_interpreter"))
                elif tool == "file_search":
                    converted_tools.append(FileSearchToolParam(type="file_search"))
                else:
                    # Fail loudly instead of silently dropping a misspelled tool name.
                    raise ValueError(f"Unsupported built-in tool name: {tool!r}")
            elif isinstance(tool, Tool):
                self._original_tools.append(tool)
                converted_tools.append(_convert_tool_to_function_param(tool))
            elif callable(tool):
                # Plain callables are wrapped; their docstring (if any) becomes the
                # tool description. A distinct local name avoids shadowing the
                # ``description`` parameter.
                tool_doc = getattr(tool, "__doc__", None)
                tool_description = tool_doc if tool_doc is not None else ""
                function_tool = FunctionTool(tool, description=tool_description)
                self._original_tools.append(function_tool)
                converted_tools.append(_convert_tool_to_function_param(function_tool))
            else:
                raise ValueError(f"Unsupported tool type: {type(tool)}")

        self._client = client
        self._assistant: Optional["Assistant"] = None
        self._thread: Optional["Thread"] = None
        self._init_thread_id = thread_id
        self._model = model
        self._instructions = instructions
        self._api_tools = converted_tools
        self._assistant_id = assistant_id
        self._metadata = metadata
        self._response_format = response_format
        self._temperature = temperature
        self._tool_resources = tool_resources
        self._top_p = top_p
        self._vector_store_id: Optional[str] = None
        self._uploaded_file_ids: List[str] = []

        # Variables to track initial state
        self._initial_message_ids: Set[str] = set()
        self._initial_state_retrieved: bool = False

    async def _ensure_initialized(self) -> None:
        """Ensure assistant and thread are created.

        Retrieves the assistant/thread when an ID was supplied to the constructor,
        otherwise creates them, and snapshots the thread's initial message IDs the
        first time it runs. Subsequent calls are no-ops.
        """
        if self._assistant is None:
            if self._assistant_id:
                self._assistant = await self._client.beta.assistants.retrieve(assistant_id=self._assistant_id)
            else:
                self._assistant = await self._client.beta.assistants.create(
                    model=self._model,
                    description=self.description,
                    instructions=self._instructions,
                    tools=self._api_tools,
                    metadata=self._metadata,
                    response_format=self._response_format if self._response_format else NOT_GIVEN,  # type: ignore
                    temperature=self._temperature,
                    tool_resources=self._tool_resources if self._tool_resources else NOT_GIVEN,  # type: ignore
                    top_p=self._top_p,
                )

        if self._thread is None:
            if self._init_thread_id:
                self._thread = await self._client.beta.threads.retrieve(thread_id=self._init_thread_id)
            else:
                self._thread = await self._client.beta.threads.create()

        # Retrieve initial state only once
        if not self._initial_state_retrieved:
            await self._retrieve_initial_state()
            self._initial_state_retrieved = True

    async def _retrieve_initial_state(self) -> None:
        """Record the IDs of all messages already present in the thread.

        ``on_reset`` uses this snapshot to decide which messages are new.
        """
        initial_message_ids: Set[str] = set()
        after: str | NotGiven = NOT_GIVEN
        # Page through the thread in ascending order, 100 messages at a time.
        while True:
            msgs: AsyncCursorPage[Message] = await self._client.beta.threads.messages.list(
                self._thread_id, after=after, order="asc", limit=100
            )
            for msg in msgs.data:
                initial_message_ids.add(msg.id)
            if not msgs.has_next_page():
                break
            after = msgs.data[-1].id

        self._initial_message_ids = initial_message_ids

    @property
    def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
        """The types of messages that the assistant agent produces."""
        return (TextMessage,)

    # The return annotations below are quoted so that defining the class does
    # not require the optional openai package to have been imported.
    @property
    def threads(self) -> "AsyncThreads":
        """The client's ``beta.threads`` resource."""
        return self._client.beta.threads

    @property
    def runs(self) -> "AsyncRuns":
        """The client's ``beta.threads.runs`` resource."""
        return self._client.beta.threads.runs

    @property
    def messages(self) -> "AsyncMessages":
        """The client's ``beta.threads.messages`` resource."""
        return self._client.beta.threads.messages

    @property
    def _get_assistant_id(self) -> str:
        """ID of the initialized assistant.

        Raises:
            ValueError: If the assistant has not been initialized yet.
        """
        if self._assistant is None:
            raise ValueError("Assistant not initialized")
        return self._assistant.id

    @property
    def _thread_id(self) -> str:
        """ID of the initialized thread.

        Raises:
            ValueError: If the thread has not been initialized yet.
        """
        if self._thread is None:
            raise ValueError("Thread not initialized")
        return self._thread.id

    async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str:
        """Execute a tool call and return the result.

        Any failure (unknown tool, malformed JSON arguments, tool error) is
        returned as an ``"Error: ..."`` string rather than raised, so it can be
        submitted back to the run as the tool output.
        """
        try:
            if not self._original_tools:
                raise ValueError("No tools are available.")
            tool = next((t for t in self._original_tools if t.name == tool_call.name), None)
            if tool is None:
                raise ValueError(f"The tool '{tool_call.name}' is not available.")
            arguments = json.loads(tool_call.arguments)
            result = await tool.run_json(arguments, cancellation_token)
            return tool.return_value_as_string(result)
        except Exception as e:
            return f"Error: {e}"

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        """Handle incoming messages and return a response.

        Drains ``on_messages_stream`` and returns the ``Response`` it yields.
        """
        async for message in self.on_messages_stream(messages, cancellation_token):
            if isinstance(message, Response):
                return message
        raise AssertionError("The stream should have returned the final result.")

    async def on_messages_stream(
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
        """Handle incoming messages and stream tool events followed by the response.

        Adds the incoming messages to the thread, starts a run, polls it until it
        completes (executing requested function tools along the way) and finally
        yields a ``Response`` built from the text of the newest thread message.

        Yields:
            ``ToolCallRequestEvent`` / ``ToolCallExecutionEvent`` for every round of
            function calls, then the final ``Response``.

        Raises:
            ValueError: If the run fails, is cancelled, expires or ends incomplete,
                or if the newest thread message has no text content.
        """
        await self._ensure_initialized()

        # Process all messages in sequence
        for message in messages:
            if isinstance(message, (TextMessage, MultiModalMessage)):
                await self.handle_text_message(str(message.content), cancellation_token)
            elif isinstance(message, (StopMessage, HandoffMessage)):
                await self.handle_text_message(message.content, cancellation_token)

        # Inner messages for tool calls
        inner_messages: List[AgentEvent | ChatMessage] = []

        # Create and start a run
        run: Run = await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.runs.create(
                    thread_id=self._thread_id,
                    assistant_id=self._get_assistant_id,
                )
            )
        )

        # Wait for run completion by polling
        while True:
            run = await cancellation_token.link_future(
                asyncio.ensure_future(
                    self._client.beta.threads.runs.retrieve(
                        thread_id=self._thread_id,
                        run_id=run.id,
                    )
                )
            )

            if run.status == "failed":
                raise ValueError(f"Run failed: {run.last_error}")

            # These statuses are terminal as well; without this check the loop
            # would keep polling a run that can never reach "completed".
            if run.status in ("cancelled", "expired", "incomplete"):
                detail = run.last_error or getattr(run, "incomplete_details", None)
                raise ValueError(f"Run ended with status '{run.status}' before completing: {detail}")

            # If the run requires action (function calls), execute tools and continue
            if run.status == "requires_action" and run.required_action is not None:
                tool_calls: List[FunctionCall] = [
                    FunctionCall(
                        id=required_tool_call.id,
                        name=required_tool_call.function.name,
                        arguments=required_tool_call.function.arguments,
                    )
                    for required_tool_call in run.required_action.submit_tool_outputs.tool_calls
                    if required_tool_call.type == "function"
                ]

                # Add tool call message to inner messages
                tool_call_msg = ToolCallRequestEvent(source=self.name, content=tool_calls)
                inner_messages.append(tool_call_msg)
                event_logger.debug(tool_call_msg)
                yield tool_call_msg

                # Execute tool calls and get results
                tool_outputs: List[FunctionExecutionResult] = []
                for tool_call in tool_calls:
                    result = await self._execute_tool_call(tool_call, cancellation_token)
                    tool_outputs.append(FunctionExecutionResult(content=result, call_id=tool_call.id))

                # Add tool result message to inner messages
                tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs)
                inner_messages.append(tool_result_msg)
                event_logger.debug(tool_result_msg)
                yield tool_result_msg

                # Submit tool outputs back to the run
                run = await cancellation_token.link_future(
                    asyncio.ensure_future(
                        self._client.beta.threads.runs.submit_tool_outputs(
                            thread_id=self._thread_id,
                            run_id=run.id,
                            tool_outputs=[{"tool_call_id": t.call_id, "output": t.content} for t in tool_outputs],
                        )
                    )
                )
                continue

            if run.status == "completed":
                break

            await asyncio.sleep(0.5)

        # Get messages after run completion (newest first, only the latest one)
        assistant_messages: AsyncCursorPage[Message] = await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.messages.list(thread_id=self._thread_id, order="desc", limit=1)
            )
        )

        if not assistant_messages.data:
            raise ValueError("No messages received from assistant")

        # Get the last message's content
        last_message = assistant_messages.data[0]
        if not last_message.content:
            raise ValueError(f"No content in the last message: {last_message}")

        # Extract text content
        text_content = [content for content in last_message.content if content.type == "text"]
        if not text_content:
            raise ValueError(f"Expected text content in the last message: {last_message.content}")

        # Return the assistant's response as a Response with inner messages
        chat_message = TextMessage(source=self.name, content=text_content[0].text.value)
        yield Response(chat_message=chat_message, inner_messages=inner_messages)

    async def handle_text_message(self, content: str, cancellation_token: CancellationToken) -> None:
        """Handle regular text messages by adding them to the thread as a user message."""
        await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.messages.create(
                    thread_id=self._thread_id,
                    content=content,
                    role="user",
                )
            )
        )

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Handle reset command by deleting the messages added since initialization."""
        await self._ensure_initialized()

        # Collect every message ID that was not part of the initial snapshot
        new_message_ids: List[str] = []
        after: str | NotGiven = NOT_GIVEN
        while True:
            msgs: AsyncCursorPage[Message] = await cancellation_token.link_future(
                asyncio.ensure_future(
                    self._client.beta.threads.messages.list(self._thread_id, after=after, order="asc", limit=100)
                )
            )
            for msg in msgs.data:
                if msg.id not in self._initial_message_ids:
                    new_message_ids.append(msg.id)
            if not msgs.has_next_page():
                break
            after = msgs.data[-1].id

        # Delete new messages
        for msg_id in new_message_ids:
            status: MessageDeleted = await cancellation_token.link_future(
                asyncio.ensure_future(
                    self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
                )
            )
            assert status.deleted is True

    async def _upload_files(self, file_paths: str | Iterable[str], cancellation_token: CancellationToken) -> List[str]:
        """Upload files and return their IDs.

        Each uploaded file ID is also remembered so ``delete_uploaded_files`` can
        remove it later.
        """
        if isinstance(file_paths, str):
            file_paths = [file_paths]

        file_ids: List[str] = []
        for file_path in file_paths:
            async with aiofiles.open(file_path, mode="rb") as f:
                file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read()))
            file_name = os.path.basename(file_path)

            uploaded: FileObject = await cancellation_token.link_future(
                asyncio.ensure_future(self._client.files.create(file=(file_name, file_content), purpose="assistants"))
            )
            file_ids.append(uploaded.id)
            self._uploaded_file_ids.append(uploaded.id)

        return file_ids

    async def on_upload_for_code_interpreter(
        self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
    ) -> None:
        """Handle file uploads for the code interpreter.

        Uploads the given file(s) and appends their IDs to the thread's
        code-interpreter tool resources.
        """
        # These models are instantiated at runtime, but the module imports them
        # only under TYPE_CHECKING, so bring them into scope here.
        from openai.types.beta.thread import ToolResources, ToolResourcesCodeInterpreter

        # Make sure the thread exists before uploading, so this method can be
        # called before the first on_messages() call.
        await self._ensure_initialized()

        file_ids = await self._upload_files(file_paths, cancellation_token)

        # Update thread with the new files
        thread = await cancellation_token.link_future(
            asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id))
        )
        tool_resources: ToolResources = thread.tool_resources or ToolResources()
        code_interpreter: ToolResourcesCodeInterpreter = (
            tool_resources.code_interpreter or ToolResourcesCodeInterpreter()
        )
        combined_file_ids: List[str] = [*(code_interpreter.file_ids or []), *file_ids]
        tool_resources.code_interpreter = ToolResourcesCodeInterpreter(file_ids=combined_file_ids)

        await cancellation_token.link_future(
            asyncio.ensure_future(
                self._client.beta.threads.update(
                    thread_id=self._thread_id,
                    # The cast target is given as a string because
                    # ``thread_update_params`` is only imported for type checking.
                    tool_resources=cast("thread_update_params.ToolResources", tool_resources.model_dump()),
                )
            )
        )

    async def delete_uploaded_files(self, cancellation_token: CancellationToken) -> None:
        """Delete all files that were uploaded by this agent instance.

        Deletion is best-effort: failures are logged and the remaining files are
        still attempted. The tracked ID list is cleared afterwards.
        """
        for file_id in self._uploaded_file_ids:
            try:
                await cancellation_token.link_future(asyncio.ensure_future(self._client.files.delete(file_id=file_id)))
            except Exception as e:
                event_logger.error(f"Failed to delete file {file_id}: {str(e)}")
        self._uploaded_file_ids = []

    async def delete_assistant(self, cancellation_token: CancellationToken) -> None:
        """Delete the assistant if it was created by this instance.

        An assistant supplied through ``assistant_id`` is never deleted. Failures
        are logged rather than raised.
        """
        if self._assistant is not None and not self._assistant_id:
            try:
                await cancellation_token.link_future(
                    asyncio.ensure_future(self._client.beta.assistants.delete(assistant_id=self._get_assistant_id))
                )
                self._assistant = None
            except Exception as e:
                event_logger.error(f"Failed to delete assistant: {str(e)}")

    async def delete_vector_store(self, cancellation_token: CancellationToken) -> None:
        """Delete the vector store if it was created by this instance.

        Does nothing when no vector store ID is recorded. Failures are logged
        rather than raised.
        """
        if self._vector_store_id is not None:
            try:
                await cancellation_token.link_future(
                    asyncio.ensure_future(self._client.beta.vector_stores.delete(vector_store_id=self._vector_store_id))
                )
                self._vector_store_id = None
            except Exception as e:
                event_logger.error(f"Failed to delete vector store: {str(e)}")