# Source code for autogen_core.tools._static_workbench

import asyncio
import builtins
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping

from pydantic import BaseModel
from typing_extensions import Self

from .._cancellation_token import CancellationToken
from .._component_config import Component, ComponentModel
from ._base import BaseTool, StreamTool, ToolSchema
from ._workbench import StreamWorkbench, TextResultContent, ToolResult, Workbench


class StaticWorkbenchConfig(BaseModel):
    """Declarative configuration for :class:`StaticWorkbench`.

    Used as the workbench's ``component_config_schema``: it holds the dumped
    component model of each tool so the workbench can be rebuilt from config.
    """

    # One serialized ComponentModel per tool. Pydantic copies mutable field
    # defaults per instance, so this ``[]`` default is not shared.
    tools: List[ComponentModel] = []


class StaticWorkbenchState(BaseModel):
    """Saved state of a :class:`StaticWorkbench`.

    ``tools`` maps each tool's name to the state mapping that tool produced
    via ``save_state_json``.
    """

    # Constant tag naming this state payload's type.
    type: Literal["StaticWorkbenchState"] = "StaticWorkbenchState"
    # Per-tool saved state keyed by tool name. Pydantic copies mutable field
    # defaults per instance, so this ``{}`` default is not shared.
    tools: Dict[str, Mapping[str, Any]] = {}


# Backward-compatible alias for the original, misspelled class name; existing
# code (e.g. save_state/load_state in this module) still refers to it.
StateicWorkbenchState = StaticWorkbenchState


class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
    """
    A workbench that provides a static set of tools that do not change after
    each tool execution.

    Args:
        tools (List[BaseTool[Any, Any]]): A list of tools to be included in the workbench.
            The tools should be subclasses of :class:`~autogen_core.tools.BaseTool`.
    """

    component_provider_override = "autogen_core.tools.StaticWorkbench"
    component_config_schema = StaticWorkbenchConfig

    def __init__(self, tools: List[BaseTool[Any, Any]]) -> None:
        # NOTE(review): the list is stored by reference, not copied, so later
        # mutation of the caller's list is visible to this workbench.
        self._tools = tools

    async def list_tools(self) -> List[ToolSchema]:
        """Return the schema of every tool, in the order the tools were given."""
        return [tool.schema for tool in self._tools]

    async def call_tool(
        self,
        name: str,
        arguments: Mapping[str, Any] | None = None,
        cancellation_token: CancellationToken | None = None,
        call_id: str | None = None,
    ) -> ToolResult:
        """Run the first tool whose name equals ``name`` and wrap its output.

        Args:
            name: Name of the tool to run.
            arguments: Arguments passed to ``tool.run_json``. An empty mapping is
                used when this is ``None`` or empty.
            cancellation_token: Token linked to the running call. A fresh token is
                created when none is given.
            call_id: Optional identifier forwarded to ``tool.run_json``.

        Returns:
            A :class:`ToolResult` with a single text content item. If no tool
            matches, or the tool raises an ``Exception``, ``is_error`` is ``True``
            and the content is an error message rather than the tool output.
        """
        tool = next((tool for tool in self._tools if tool.name == name), None)
        if tool is None:
            return ToolResult(
                name=name,
                result=[TextResultContent(content=f"Tool {name} not found.")],
                is_error=True,
            )
        if not cancellation_token:
            cancellation_token = CancellationToken()
        if not arguments:
            arguments = {}
        try:
            # Run the tool as a future so it can be linked to the cancellation token.
            result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id))
            cancellation_token.link_future(result_future)
            actual_tool_output = await result_future
            is_error = False
            result_str = tool.return_value_as_string(actual_tool_output)
        except Exception as e:
            # Tool failures are reported inside the result instead of being raised.
            result_str = self._format_errors(e)
            is_error = True
        return ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error)

    async def start(self) -> None:
        """Do nothing; this workbench has nothing to start."""
        return None

    async def stop(self) -> None:
        """Do nothing; this workbench has nothing to stop."""
        return None

    async def reset(self) -> None:
        """Do nothing; the tool set is static, so there is nothing to reset."""
        return None

    async def save_state(self) -> Mapping[str, Any]:
        """Collect every tool's saved state, keyed by tool name.

        Returns:
            The ``model_dump()`` of a workbench-state model whose ``tools`` maps
            each tool name to ``await tool.save_state_json()``. Tools that share a
            name overwrite each other's entry.
        """
        tool_states = StateicWorkbenchState()
        for tool in self._tools:
            tool_states.tools[tool.name] = await tool.save_state_json()
        return tool_states.model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore tool states previously produced by :meth:`save_state`.

        Only tools whose name appears in the saved state are updated; the rest
        are left untouched.
        """
        parsed_state = StateicWorkbenchState.model_validate(state)
        for tool in self._tools:
            if tool.name in parsed_state.tools:
                await tool.load_state_json(parsed_state.tools[tool.name])

    def _to_config(self) -> StaticWorkbenchConfig:
        """Dump each tool as a component model into the workbench config."""
        return StaticWorkbenchConfig(tools=[tool.dump_component() for tool in self._tools])

    @classmethod
    def _from_config(cls, config: StaticWorkbenchConfig) -> Self:
        """Rebuild the workbench by loading each tool component from ``config``."""
        return cls(tools=[BaseTool.load_component(tool) for tool in config.tools])

    def _format_errors(self, error: Exception) -> str:
        """Recursively format an exception into a newline-separated message.

        For an ``ExceptionGroup`` (built in from Python 3.11), each leaf
        exception's message goes on its own line and nested groups are
        flattened. Any other exception is rendered with ``str(error)``.
        The result has no leading or trailing whitespace.
        """
        exception_group = getattr(builtins, "ExceptionGroup", None)
        # TODO: how to make this compatible with Python 3.10 (no builtin ExceptionGroup)?
        if exception_group is not None and isinstance(error, exception_group):
            # Each recursive call returns a stripped string, so the parts must be
            # joined with an explicit separator. Plain concatenation would run
            # adjacent messages together.
            parts = (self._format_errors(sub_exception) for sub_exception in error.exceptions)  # type: ignore
            return "\n".join(part for part in parts if part)
        return str(error).strip()
class StaticStreamWorkbench(StaticWorkbench, StreamWorkbench):
    """
    A workbench that provides a static set of tools that do not change after
    each tool execution, and supports streaming results.
    """

    component_provider_override = "autogen_core.tools.StaticStreamWorkbench"

    async def call_tool_stream(
        self,
        name: str,
        arguments: Mapping[str, Any] | None = None,
        cancellation_token: CancellationToken | None = None,
        call_id: str | None = None,
    ) -> AsyncGenerator[Any | ToolResult, None]:
        """Run the named tool, yielding intermediate items and then one final ToolResult.

        For a :class:`StreamTool`, each streamed item except the last is yielded
        as-is, and the last item becomes the final result. Items that are ``None``
        are never yielded as intermediates. Any other tool is run once through
        ``run_json``.

        Yields:
            Zero or more intermediate stream items, followed by a
            :class:`ToolResult`. That final result has ``is_error=True`` when no
            tool matches ``name`` or the tool raises an ``Exception``.
        """
        # Locate the first tool registered under the requested name.
        selected = None
        for candidate in self._tools:
            if candidate.name == name:
                selected = candidate
                break

        if selected is None:
            yield ToolResult(
                name=name,
                result=[TextResultContent(content=f"Tool {name} not found.")],
                is_error=True,
            )
            return

        token = cancellation_token or CancellationToken()
        args = arguments or {}

        try:
            if isinstance(selected, StreamTool):
                # Always hold back the newest item: it is only known to be an
                # intermediate once another item arrives after it.
                held: Any | None = None
                try:
                    async for item in selected.run_json_stream(args, token, call_id=call_id):
                        if held is not None:
                            yield held
                        held = item
                except Exception as stream_error:
                    # Flush whatever was held back, then finish with the error.
                    if held is not None:
                        yield held
                    error_text = self._format_errors(stream_error)
                    yield ToolResult(name=selected.name, result=[TextResultContent(content=error_text)], is_error=True)
                    return
                final_output = held
            else:
                # Non-streaming tool: one awaited call, linked to the token.
                pending_call = asyncio.ensure_future(selected.run_json(args, token, call_id=call_id))
                token.link_future(pending_call)
                final_output = await pending_call
            failed = False
            text = selected.return_value_as_string(final_output)
        except Exception as error:
            text = self._format_errors(error)
            failed = True

        yield ToolResult(name=selected.name, result=[TextResultContent(content=text)], is_error=failed)