# Source code for autogen_ext.experimental.task_centric_memory.utils.grader

from __future__ import annotations

from typing import TYPE_CHECKING, List, Tuple, Union

from autogen_core import Image
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    CreateResult,
    LLMMessage,
    SystemMessage,
    UserMessage,
)

from ._functions import UserContent
from .page_logger import PageLogger

if TYPE_CHECKING:
    from .apprentice import Apprentice


class Grader:
    """
    Runs basic tests, and determines task success without limitation to string matches.

    Grading is done in two model calls: the first extracts the answer from a
    free-form response, the second compares that extracted answer with the
    expected one.

    Args:
        client: The client to call the model.
        logger: An optional logger. If None, no logging will be performed.
    """

    # Instructions for the first grading step: pull the answer out of the response text.
    _EXTRACTION_INSTRUCTIONS = (
        "Your job is to extract a possible answer to the following question from the given text.\n"
        "- First review the following task.\n"
        "- Then review the text that follows, which may contain an answer, plus reasoning that led to the answer.\n"
        "- Do not attempt to actually solve the task yourself.\n"
        "- Don't try to judge whether the reasoning steps were correct.\n"
        "- Simply respond by summarizing the answer described in the text, omitting any other parts of the text.\n"
        '- If no answer can be extracted from the text, simply reply "None".'
    )

    # Instructions for the second grading step: compare the extracted answer with the gold answer.
    _GRADING_INSTRUCTIONS = (
        "Your job is to decide whether a given answer to a task is correct or not.\n"
        "- You will be given the task description and the correct, gold-standard answer, "
        "along with the answer to be graded.\n"
        "- In general, an answer is correct if it is equivalent to the correct answer.\n"
        "- Specifically, the given answer must contain the important information from the correct answer, "
        "and must not in any way contradict the correct answer.\n"
        "- Ignore any differences of grammar, spelling mistakes, punctuation, capitalization, formatting, "
        "or extra commentary.\n"
        "- An answer should be considered correct if it omits information that is clearly inferred.\n"
        '- For instance, if the correct answer is "Paris, France", the answer "Paris" should be considered correct.\n'
        "- Respond with a single character: '1' if the answer to be graded is correct, '0' if not."
    )

    def __init__(self, client: ChatCompletionClient, logger: PageLogger | None = None) -> None:
        if logger is None:
            logger = PageLogger()  # Nothing will be logged by this object.
        self.logger = logger
        self.client = client

        # Create the chat history
        self._chat_history: List[LLMMessage] = []

    async def test_apprentice(
        self,
        apprentice: Apprentice,
        task_description: str,
        expected_answer: str,
        num_trials: int,
        use_memory: bool,
        client: ChatCompletionClient,
    ) -> Tuple[int, int]:
        """
        Assigns the task to the apprentice `num_trials` times and grades each response.

        Args:
            apprentice: The apprentice whose `assign_task` method is awaited on every trial.
            task_description: The task text given to the apprentice and to the grading prompts.
            expected_answer: The answer each response is graded against.
            num_trials: How many times to assign the task.
            use_memory: Forwarded to `apprentice.assign_task`.
            client: Not used by this method; kept so existing callers keep working.
                Grading goes through the client passed to the constructor.

        Returns:
            A tuple of (number of successful trials, `num_trials`).
        """
        self.logger.enter_function()

        self.logger.info("Testing the apprentice on the given task.\n")

        num_successes = 0

        for trial in range(num_trials):
            self.logger.info(f"\n----- TRIAL {trial + 1} -----\n")
            self.logger.info("Try to solve the task.\n")
            response = await apprentice.assign_task(task_description, use_memory=use_memory)

            response_is_correct, extracted_answer = await self.is_response_correct(
                task_description, response, expected_answer
            )
            self.logger.info(f"Extracted answer: {extracted_answer}")
            if response_is_correct:
                self.logger.info("Answer is CORRECT.\n")
                num_successes += 1
            else:
                self.logger.info("Answer is INCORRECT.\n")

        # With zero trials there is nothing to divide by; report 0% instead of raising.
        success_rate = round((num_successes / num_trials) * 100) if num_trials > 0 else 0
        self.logger.info(f"\nSuccess rate: {success_rate}%\n")
        self.logger.leave_function()
        return num_successes, num_trials

    async def call_model(
        self,
        summary: str,
        user_content: UserContent,
        system_message_content: str | None = None,
        keep_these_messages: bool = True,
    ) -> str:
        """
        Calls the model client with the given input and returns the response.

        Args:
            summary: Short description of the call, passed to the logger.
            user_content: Content of the user message to send.
            system_message_content: System prompt; defaults to "You are a helpful assistant."
            keep_these_messages: If True, the user message and the model's reply
                are appended to the chat history.

        Returns:
            The model's response content as a string.
        """
        # Prepare the input message list
        if system_message_content is None:
            system_message_content = "You are a helpful assistant."
        system_message: LLMMessage
        if self.client.model_info["family"] == "o1":
            # No system message allowed, so pass it as the first user message.
            system_message = UserMessage(content=system_message_content, source="User")
        else:
            # System message allowed.
            system_message = SystemMessage(content=system_message_content)

        user_message = UserMessage(content=user_content, source="User")
        input_messages = [system_message] + self._chat_history + [user_message]

        # Call the model.
        response = await self.client.create(input_messages)
        assert isinstance(response, CreateResult)
        response_string = response.content
        assert isinstance(response_string, str)
        response_message = AssistantMessage(content=response_string, source="Assistant")

        # Log the model call
        self.logger.log_model_call(summary=summary, input_messages=input_messages, response=response)

        # Manage the chat history
        if keep_these_messages:
            self._chat_history.append(user_message)
            self._chat_history.append(response_message)

        # Return the response as a string
        return response_string

    def _clear_history(self) -> None:
        """
        Empties the message list containing the chat history.
        """
        self._chat_history = []

    async def is_response_correct(
        self, task_description: str, response_to_be_graded: str, correct_answer: str
    ) -> Tuple[bool, str]:
        """
        Determines whether the response is equivalent to the task's correct answer.

        Args:
            task_description: The task the response was produced for.
            response_to_be_graded: The full response text, which may include reasoning.
            correct_answer: The expected answer.

        Returns:
            A tuple of (True if the model judged the answer correct, the extracted answer).
        """
        self.logger.enter_function()

        sys_message = """You are a helpful and thoughtful assistant."""

        # Ask the model to extract the answer from the response.
        user_message: List[Union[str, Image]] = [
            self._EXTRACTION_INSTRUCTIONS,
            "\n# Task description",
            task_description,
            "\n# Text that may contain an answer",
            response_to_be_graded,
        ]
        user_message_arg: UserContent = user_message
        self._clear_history()  # Each grading step starts from an empty chat history.
        extracted_answer = await self.call_model(
            summary="Ask the model to extract the answer",
            system_message_content=sys_message,
            user_content=user_message_arg,
        )
        self.logger.info("Extracted answer: " + extracted_answer)

        # Ask the model to check the answer for correctness.
        user_message = [
            self._GRADING_INSTRUCTIONS,
            "\n# Task description",
            task_description,
            "\n# Correct answer",
            correct_answer,
            "\n# Answer to be graded",
            extracted_answer,
        ]
        self._clear_history()
        decision = await self.call_model(
            summary="Ask the model to check the answer for correctness",
            system_message_content=sys_message,
            user_content=user_message,
        )
        self.logger.info("Decision: " + decision)

        self.logger.leave_function()
        # Tolerate surrounding whitespace or quotes around the single-character verdict.
        verdict = decision.strip().strip("'\"")
        return verdict == "1", extracted_answer