Agent Chat with custom model loading
In this notebook, we demonstrate how a custom model can be defined and loaded, and what protocol it needs to comply to.
NOTE: Depending on what model you use, you may need to play with the default prompts of the Agent’s
Requirements
Some extra dependencies are needed for this notebook, which can be installed via pip:
pip install autogen-agentchat~=0.2 torch transformers sentencepiece
For more information, please refer to the installation guide.
from types import SimpleNamespace
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import autogen
from autogen import AssistantAgent, UserProxyAgent
Create and configure the custom model
A custom model class can be created in many ways, but needs to adhere to
the ModelClient
protocol and response structure which is defined in
client.py and shown below.
The response protocol has some minimum requirements, but can be extended
to include any additional information that is needed. Message retrieval
therefore can be customized, but needs to return a list of strings or a
list of ModelClientResponseProtocol.Choice.Message
objects.
class ModelClient(Protocol):
"""
A client class must implement the following methods:
- create must return a response object that implements the ModelClientResponseProtocol
- cost must return the cost of the response
- get_usage must return a dict with the following keys:
- prompt_tokens
- completion_tokens
- total_tokens
- cost
- model
This class is used to create a client that can be used by OpenAIWrapper.
The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed.
The message_retrieval method must be implemented to return a list of str or a list of messages from the response.
"""
RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]
class ModelClientResponseProtocol(Protocol):
class Choice(Protocol):
class Message(Protocol):
content: Optional[str]
message: Message
choices: List[Choice]
model: str
def create(self, params) -> ModelClientResponseProtocol:
...
def message_retrieval(
self, response: ModelClientResponseProtocol
) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
"""
Retrieve and return a list of strings or a list of Choice.Message from the response.
NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
"""
...
def cost(self, response: ModelClientResponseProtocol) -> float:
...
@staticmethod
def get_usage(response: ModelClientResponseProtocol) -> Dict:
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
...
Example of simple custom client
Following the huggingface example for using Mistral’s Open-Orca
For the response object, python’s SimpleNamespace
is used to create a
simple object that can be used to store the response data, but any
object that follows the ClientResponseProtocol
can be used.
# custom client with custom model loader
class CustomModelClient:
def __init__(self, config, **kwargs):
print(f"CustomModelClient config: {config}")
self.device = config.get("device", "cpu")
self.model = AutoModelForCausalLM.from_pretrained(config["model"]).to(self.device)
self.model_name = config["model"]
self.tokenizer = AutoTokenizer.from_pretrained(config["model"], use_fast=False)
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
# params are set by the user and consumed by the user since they are providing a custom model
# so anything can be done here
gen_config_params = config.get("params", {})
self.max_length = gen_config_params.get("max_length", 256)
print(f"Loaded model {config['model']} to {self.device}")
def create(self, params):
if params.get("stream", False) and "messages" in params:
raise NotImplementedError("Local models do not support streaming.")
else:
num_of_responses = params.get("n", 1)
# can create my own data response class
# here using SimpleNamespace for simplicity
# as long as it adheres to the ClientResponseProtocol
response = SimpleNamespace()
inputs = self.tokenizer.apply_chat_template(
params["messages"], return_tensors="pt", add_generation_prompt=True
).to(self.device)
inputs_length = inputs.shape[-1]
# add inputs_length to max_length
max_length = self.max_length + inputs_length
generation_config = GenerationConfig(
max_length=max_length,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.eos_token_id,
)
response.choices = []
response.model = self.model_name
for _ in range(num_of_responses):
outputs = self.model.generate(inputs, generation_config=generation_config)
# Decode only the newly generated text, excluding the prompt
text = self.tokenizer.decode(outputs[0, inputs_length:])
choice = SimpleNamespace()
choice.message = SimpleNamespace()
choice.message.content = text
choice.message.function_call = None
response.choices.append(choice)
return response
def message_retrieval(self, response):
"""Retrieve the messages from the response."""
choices = response.choices
return [choice.message.content for choice in choices]
def cost(self, response) -> float:
"""Calculate the cost of the response."""
response.cost = 0
return 0
@staticmethod
def get_usage(response):
# returns a dict of prompt_tokens, completion_tokens, total_tokens, cost, model
# if usage needs to be tracked, else None
return {}
Set your API Endpoint
The
config_list_from_json
function loads a list of configurations from an environment variable or
a json file.
It first looks for an environment variable of a specified name (“OAI_CONFIG_LIST” in this example), which needs to be a valid json string. If that variable is not found, it looks for a json file with the same name. It filters the configs by models (you can filter by other keys as well).
The json looks like the following:
[
{
"model": "gpt-4",
"api_key": "<your OpenAI API key here>"
},
{
"model": "gpt-4",
"api_key": "<your Azure OpenAI API key here>",
"base_url": "<your Azure OpenAI API base here>",
"api_type": "azure",
"api_version": "2024-02-01"
},
{
"model": "gpt-4-32k",
"api_key": "<your Azure OpenAI API key here>",
"base_url": "<your Azure OpenAI API base here>",
"api_type": "azure",
"api_version": "2024-02-01"
}
]
You can set the value of config_list in any way you prefer. Please refer to this notebook for full code examples of the different methods.
Set the config for the custom model
You can add any paramteres that are needed for the custom model loading in the same configuration list.
It is important to add the model_client_cls
field and set it to a
string that corresponds to the class name: "CustomModelClient"
.
{
"model": "Open-Orca/Mistral-7B-OpenOrca",
"model_client_cls": "CustomModelClient",
"device": "cuda",
"n": 1,
"params": {
"max_length": 1000,
}
},
config_list_custom = autogen.config_list_from_json(
"OAI_CONFIG_LIST",
filter_dict={"model_client_cls": ["CustomModelClient"]},
)
Construct Agents
Consturct a simple conversation between a User proxy and an Assistent agent
assistant = AssistantAgent("assistant", llm_config={"config_list": config_list_custom})
user_proxy = UserProxyAgent(
"user_proxy",
code_execution_config={
"work_dir": "coding",
"use_docker": False, # Please set use_docker=True if docker is available to run the generated code. Using docker is safer than running the generated code directly.
},
)
Register the custom client class to the assistant agent
assistant.register_model_client(model_client_cls=CustomModelClient)
user_proxy.initiate_chat(assistant, message="Write python code to print Hello World!")
Register a custom client class with a pre-loaded model
If you want to have more control over when the model gets loaded, you can load the model yourself and pass it as an argument to the CustomClient during registration
# custom client with custom model loader
class CustomModelClientWithArguments(CustomModelClient):
def __init__(self, config, loaded_model, tokenizer, **kwargs):
print(f"CustomModelClientWithArguments config: {config}")
self.model_name = config["model"]
self.model = loaded_model
self.tokenizer = tokenizer
self.device = config.get("device", "cpu")
gen_config_params = config.get("params", {})
self.max_length = gen_config_params.get("max_length", 256)
print(f"Loaded model {config['model']} to {self.device}")
# load model here
config = config_list_custom[0]
device = config.get("device", "cpu")
loaded_model = AutoModelForCausalLM.from_pretrained(config["model"]).to(device)
tokenizer = AutoTokenizer.from_pretrained(config["model"], use_fast=False)
tokenizer.pad_token_id = tokenizer.eos_token_id
Add the config of the new custom model
{
"model": "Open-Orca/Mistral-7B-OpenOrca",
"model_client_cls": "CustomModelClientWithArguments",
"device": "cuda",
"n": 1,
"params": {
"max_length": 1000,
}
},
config_list_custom = autogen.config_list_from_json(
"OAI_CONFIG_LIST",
filter_dict={"model_client_cls": ["CustomModelClientWithArguments"]},
)
assistant = AssistantAgent("assistant", llm_config={"config_list": config_list_custom})
assistant.register_model_client(
model_client_cls=CustomModelClientWithArguments,
loaded_model=loaded_model,
tokenizer=tokenizer,
)
user_proxy.initiate_chat(assistant, message="Write python code to print Hello World!")