Source code for autogen_ext.experimental.task_centric_memory.utils.apprentice
import random
import time
from typing import TYPE_CHECKING, Any, List, Sequence, Tuple, TypedDict

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage
from autogen_core.models import (
    ChatCompletionClient,
    LLMMessage,
    SystemMessage,
    UserMessage,
)

from .page_logger import PageLogger

if TYPE_CHECKING:
    # Only needed for the string annotation below; MemoryController itself is
    # imported lazily inside Apprentice.__init__ (presumably to avoid an import
    # cycle -- confirm).
    from ..memory_controller import MemoryControllerConfig


# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
# the settings that change frequently, as when loading many settings from a single YAML file.
class ApprenticeConfig(TypedDict, total=False):
    """Optional overrides for ``Apprentice``; every key may be omitted (``total=False``)."""

    # Name of the target agent or team that tasks are assigned to.
    name_of_agent_or_team: str
    # True to disable prefix caching by prepending random ints to the first message.
    disable_prefix_caching: bool
    # Config dict passed through to the MemoryController.
    MemoryController: "MemoryControllerConfig"
class Apprentice:
    """
    A minimal wrapper combining task-centric memory with an agent or team.
    Applications may use the Apprentice class, or they may directly instantiate
    and call the Memory Controller using this class as an example.

    Args:
        client: The client to call the model.
        config: An optional dict that can be used to override the following values:

            - name_of_agent_or_team: The name of the target agent or team for assigning tasks to.
            - disable_prefix_caching: True to disable prefix caching by prepending random ints to the first message.
            - MemoryController: A config dict passed to MemoryController.

        logger: An optional logger. If None, a default logger will be created.
    """

    def __init__(
        self,
        client: ChatCompletionClient,
        config: ApprenticeConfig | None = None,
        logger: PageLogger | None = None,
    ) -> None:
        if logger is None:
            logger = PageLogger({"level": "DEBUG"})
        self.logger = logger

        # Apply default settings and any config overrides.
        self.name_of_agent_or_team = "AssistantAgent"
        self.disable_prefix_caching = False
        memory_controller_config = None
        if config is not None:
            self.name_of_agent_or_team = config.get("name_of_agent_or_team", self.name_of_agent_or_team)
            self.disable_prefix_caching = config.get("disable_prefix_caching", self.disable_prefix_caching)
            memory_controller_config = config.get("MemoryController", memory_controller_config)

        self.client = client
        if self.disable_prefix_caching:
            # self.rand only exists in this mode; it is read solely when disable_prefix_caching is True.
            # Seeding from the wall clock makes the random prefixes differ between runs.
            self.rand = random.Random()
            self.rand.seed(int(time.time() * 1000))

        # Create the MemoryController, which creates the MemoryBank.
        # Imported here rather than at module level (presumably to avoid an import cycle -- confirm).
        from ..memory_controller import MemoryController

        self.memory_controller = MemoryController(
            reset=True,
            client=self.client,
            task_assignment_callback=self.assign_task_to_agent_or_team,
            config=memory_controller_config,
            logger=self.logger,
        )

    def reset_memory(self) -> None:
        """
        Resets the memory bank.
        """
        self.memory_controller.reset_memory()

    async def handle_user_message(self, text: str, should_await: bool = True) -> str:
        """
        Handles a user message, extracting any advice and assigning a task to the agent.

        Returns:
            The memory controller's response string.
        """
        self.logger.enter_function()
        try:
            # Pass the user message through to the memory controller.
            return await self.memory_controller.handle_user_message(text, should_await)
        finally:
            # Always balance enter_function(), even when the controller raises.
            self.logger.leave_function()

    async def add_task_solution_pair_to_memory(self, task: str, solution: str) -> None:
        """
        Adds a task-solution pair to the memory bank, to be retrieved together later as a combined insight.
        This is useful when the insight is a demonstration of how to solve a given type of task.
        """
        self.logger.enter_function()
        try:
            # Pass the task and solution through to the memory controller.
            await self.memory_controller.add_task_solution_pair_to_memory(task, solution)
        finally:
            self.logger.leave_function()

    async def assign_task(self, task: str, use_memory: bool = True, should_await: bool = True) -> str:
        """
        Assigns a task to the agent, along with any relevant insights/memories.

        Returns:
            The memory controller's response string.
        """
        self.logger.enter_function()
        try:
            # Pass the task through to the memory controller.
            return await self.memory_controller.assign_task(task, use_memory, should_await)
        finally:
            self.logger.leave_function()

    async def train_on_task(self, task: str, expected_answer: str) -> None:
        """
        Repeatedly assigns a task to the completion agent, and tries to learn from failures by creating useful
        insights as memories.
        """
        self.logger.enter_function()
        try:
            # Pass the task through to the memory controller.
            await self.memory_controller.train_on_task(task, expected_answer)
        finally:
            self.logger.leave_function()

    async def assign_task_to_agent_or_team(self, task: str) -> Tuple[str, str]:
        """
        Passes the given task to the target agent or team.

        This is the task-assignment callback handed to the MemoryController.

        Returns:
            A ``(response, work_history)`` pair of strings.

        Raises:
            AssertionError: If ``name_of_agent_or_team`` is neither ``"MagenticOneGroupChat"``
                nor ``"AssistantAgent"``.
        """
        self.logger.enter_function()
        try:
            # Pass the task through.
            if self.name_of_agent_or_team == "MagenticOneGroupChat":
                return await self._assign_task_to_magentic_one(task)
            if self.name_of_agent_or_team == "AssistantAgent":
                return await self._assign_task_to_assistant_agent(task)
            # AssertionError is kept as the exception type so existing callers are unaffected.
            raise AssertionError(f"Invalid base agent: {self.name_of_agent_or_team!r}")
        finally:
            self.logger.leave_function()

    async def _assign_task_to_assistant_agent(self, task: str) -> Tuple[str, str]:
        """
        Passes the given task to a newly created AssistantAgent with a generic 6-step system prompt.

        Returns:
            A ``(response, work_history)`` pair; both hold the content of the agent's final message.
        """
        self.logger.enter_function()
        try:
            self.logger.info(task)

            system_message_content = """You are a helpful and thoughtful assistant.
In responding to every user message, you follow the same multi-step process given here:
1. Explain your understanding of the user message in detail, covering all the important points.
2. List as many possible responses as you can think of.
3. Carefully list and weigh the pros and cons (if any) of each possible response.
4. Critique the pros and cons above, looking for any flaws in your reasoning. But don't make up flaws that don't exist.
5. Decide on the best response, looping back to step 1 if none of the responses are satisfactory.
6. Finish by providing your final response in the particular format requested by the user."""

            if self.disable_prefix_caching:
                # Prepend a random int to disable prefix caching.
                random_str = "({})\n\n".format(self.rand.randint(0, 1000000))
                system_message_content = random_str + system_message_content

            system_message: LLMMessage
            if self.client.model_info["family"] == "o1":
                # No system message allowed, so pass it as the first user message.
                system_message = UserMessage(content=system_message_content, source="User")
            else:
                # System message allowed.
                system_message = SystemMessage(content=system_message_content)

            # NOTE(review): input_messages is only used for logging below; the AssistantAgent is
            # constructed with system_message=system_message_content regardless of model family.
            input_messages: List[LLMMessage] = [system_message, UserMessage(content=task, source="User")]

            assistant_agent = AssistantAgent(
                "assistant_agent",
                self.client,
                system_message=system_message_content,
            )

            # Get the agent's response to the task.
            task_result: TaskResult = await assistant_agent.run(task=TextMessage(content=task, source="User"))
            messages: Sequence[AgentEvent | ChatMessage] = task_result.messages
            message: AgentEvent | ChatMessage = messages[-1]
            content = message.content
            # Content is not guaranteed to be a str for every AgentEvent/ChatMessage type, but callers
            # expect Tuple[str, str]; coerce so the declared contract holds.
            response_str = content if isinstance(content, str) else str(content)

            # Log the model call
            self.logger.log_model_task(
                summary="Ask the model to complete the task", input_messages=input_messages, task_result=task_result
            )
            self.logger.info("\n----- RESPONSE -----\n\n{}\n".format(response_str))

            # Use the response as the work history as well.
            work_history = response_str
            return response_str, work_history
        finally:
            self.logger.leave_function()

    async def _assign_task_to_magentic_one(self, task: str) -> Tuple[str, str]:
        """
        Instantiates a MagenticOneGroupChat team, and passes the given task to it.

        Returns:
            A ``(response, work_history)`` pair; both hold the newline-joined contents of all
            messages in the team's run (non-string contents become ``"Not a string."``).
        """
        self.logger.enter_function()
        try:
            self.logger.info(task)

            general_agent = AssistantAgent(
                "general_agent",
                self.client,
                description="A general GPT-4o AI assistant capable of performing a variety of tasks.",
            )

            # Imported locally, presumably so these components are only needed on this path -- confirm.
            from autogen_ext.agents.web_surfer import MultimodalWebSurfer

            web_surfer = MultimodalWebSurfer(
                name="web_surfer",
                model_client=self.client,
                downloads_folder="logs",
                debug_dir="logs",
                to_save_screenshots=True,
            )

            from autogen_agentchat.teams import MagenticOneGroupChat

            team = MagenticOneGroupChat(
                [general_agent, web_surfer],
                model_client=self.client,
                max_turns=20,
            )

            # Get the team's response to the task.
            task_result: TaskResult = await team.run(task=task)
            assert isinstance(task_result, TaskResult)

            response_str = "\n".join(
                message.content if isinstance(message.content, str) else "Not a string."
                for message in task_result.messages
            )
            self.logger.info("\n----- RESPONSE -----\n\n{}\n".format(response_str))

            # MagenticOne's response is the chat history, which we use here as the work history.
            work_history = response_str
            return response_str, work_history
        finally:
            self.logger.leave_function()