TL;DR
AutoGen now supports custom models! This feature empowers users to define and load their own models, allowing for a more flexible and personalized inference mechanism. By adhering to a specific protocol, you can integrate your custom model for use with AutoGen and respond to prompts in any way needed, using any model, API call, or even a hardcoded response.
NOTE: Depending on what model you use, you may need to adjust the Agent's default prompts.
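For example, a smaller local model may respond better to a shorter, plainer system message than AutoGen's default. A minimal sketch (the system message wording is illustrative, and `config_list` is built in Step 2 below):

```python
import autogen

# Override the default system message for a model that struggles with long instructions.
assistant = autogen.AssistantAgent(
    name="assistant",
    system_message="You are a helpful assistant. Keep replies short and direct.",
    llm_config={"config_list": config_list},  # see Step 2 for how config_list is built
)
```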
Quickstart
An interactive and easy way to get started is to follow the notebook here, which loads a local model from HuggingFace into AutoGen and uses it for inference; you can then modify the provided class to suit your needs.
Step 1: Create the custom model client class
To get started with using custom models in AutoGen, you need to create a model client class that adheres to the `ModelClient` protocol defined in `client.py`. The new model client class should implement these methods:

- `create()`: Returns a response object that implements the `ModelClientResponseProtocol` (more details in the Protocol section).
- `message_retrieval()`: Processes the response object and returns a list of strings or a list of message objects (more details in the Protocol section).
- `cost()`: Returns the cost of the response.
- `get_usage()`: Returns a dictionary with keys from `RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]`.
Here is an example of a bare-bones dummy custom class:
```python
from types import SimpleNamespace


class CustomModelClient:
    def __init__(self, config, **kwargs):
        print(f"CustomModelClient config: {config}")

    def create(self, params):
        num_of_responses = params.get("n", 1)

        # can create my own data response class
        # here using SimpleNamespace for simplicity
        # as long as it adheres to the ModelClientResponseProtocol
        response = SimpleNamespace()
        response.choices = []
        response.model = "model_name"  # should match the OAI_CONFIG_LIST registration

        for _ in range(num_of_responses):
            text = "this is a dummy text response"
            choice = SimpleNamespace()
            choice.message = SimpleNamespace()
            choice.message.content = text
            choice.message.function_call = None
            response.choices.append(choice)

        return response

    def message_retrieval(self, response):
        choices = response.choices
        return [choice.message.content for choice in choices]

    def cost(self, response) -> float:
        response.cost = 0
        return 0

    @staticmethod
    def get_usage(response):
        # returning an empty dict opts out of usage tracking
        return {}
```
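For a more realistic sketch, here is roughly what a client backed by a local HuggingFace model could look like. This is an illustration, not the notebook's exact implementation: it assumes the `transformers` and `torch` packages are installed, the class name is hypothetical, and a real chat model will usually want its chat template applied rather than a plain join of message contents.

```python
from types import SimpleNamespace

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class HuggingFaceModelClient:
    def __init__(self, config, **kwargs):
        self.model_name = config["model"]
        self.device = config.get("device", "cpu")
        self.max_length = config.get("params", {}).get("max_length", 256)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name).to(self.device)

    def create(self, params):
        # Naively flatten the chat history into one prompt; a chat-tuned model
        # would normally get its own chat template applied here instead.
        prompt = "\n".join(m["content"] for m in params["messages"])
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output_ids = self.model.generate(**inputs, max_length=self.max_length)
        text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        response = SimpleNamespace()
        response.model = self.model_name
        choice = SimpleNamespace()
        choice.message = SimpleNamespace(content=text, function_call=None)
        response.choices = [choice]
        return response

    def message_retrieval(self, response):
        return [choice.message.content for choice in response.choices]

    def cost(self, response) -> float:
        response.cost = 0
        return 0

    @staticmethod
    def get_usage(response):
        return {}  # opting out of usage tracking
```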
Step 2: Add the configuration to the OAI_CONFIG_LIST
The only required field is `model_client_cls`, which must be set to the name of the new class (as a string): `"model_client_cls": "CustomModelClient"`. Any other fields will be forwarded to the class constructor, so you have full control over which parameters to specify and how to use them. E.g.:
```json
{
    "model": "Open-Orca/Mistral-7B-OpenOrca",
    "model_client_cls": "CustomModelClient",
    "device": "cuda",
    "n": 1,
    "params": {
        "max_length": 1000
    }
}
```
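To pick this entry up at runtime, you can filter the config list on the `model_client_cls` field. A sketch, assuming the configuration above is stored in a file named `OAI_CONFIG_LIST`:

```python
import autogen

# Keep only the entries that should be served by our custom client class.
config_list = autogen.config_list_from_json(
    "OAI_CONFIG_LIST",
    filter_dict={"model_client_cls": ["CustomModelClient"]},
)
```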
Step 3: Register the new custom model to the agent that will use it
If a configuration with the field `"model_client_cls": "<class name>"` has been added to an Agent's config list, then the corresponding model with the desired class must be registered after the agent is created and before the conversation is initialized:

```python
my_agent.register_model_client(model_client_cls=CustomModelClient, [other args that will be forwarded to CustomModelClient constructor])
```

The `model_client_cls=CustomModelClient` argument matches the one specified in the `OAI_CONFIG_LIST`, and `CustomModelClient` is the class that adheres to the `ModelClient` protocol (more details on the protocol below).
If the new model client is in the config list but not registered by the time the chat is initialized, then an error will be raised.
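Putting the pieces together, an end-to-end flow looks roughly like this (a sketch; the agent names and message are illustrative, and `config_list` comes from Step 2):

```python
import autogen

assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
user_proxy = autogen.UserProxyAgent(
    "user_proxy",
    human_input_mode="NEVER",
    code_execution_config=False,
)

# Register the custom client after creating the agent and before starting the chat.
assistant.register_model_client(model_client_cls=CustomModelClient)

user_proxy.initiate_chat(assistant, message="Write a haiku about custom model clients.")
```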
Protocol details
A custom model class can be created in many ways, but needs to adhere to the ModelClient
protocol and response structure which is defined in client.py
and shown below.
The response protocol currently uses the minimum required fields from the autogen codebase that match the OpenAI response structure. Any response protocol that matches the OpenAI response structure will probably be more resilient to future changes, but we are starting off with minimum requirements to make adoption of this feature easier.
```python
class ModelClient(Protocol):
    """
    A client class must implement the following methods:
    - create must return a response object that implements the ModelClientResponseProtocol
    - cost must return the cost of the response
    - get_usage must return a dict with the following keys:
      - prompt_tokens
      - completion_tokens
      - total_tokens
      - cost
      - model

    This class is used to create a client that can be used by OpenAIWrapper.
    The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed.
    The message_retrieval method must be implemented to return a list of str or a list of messages from the response.
    """

    RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]

    class ModelClientResponseProtocol(Protocol):
        class Choice(Protocol):
            class Message(Protocol):
                content: Optional[str]

            message: Message

        choices: List[Choice]
        model: str

    def create(self, params) -> ModelClientResponseProtocol:
        ...

    def message_retrieval(
        self, response: ModelClientResponseProtocol
    ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
        """
        Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        ...

    def cost(self, response: ModelClientResponseProtocol) -> float:
        ...

    @staticmethod
    def get_usage(response: ModelClientResponseProtocol) -> Dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        ...
```
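As an illustration, a client that does track usage might implement `get_usage()` along these lines. This is a sketch: the token-count and cost attributes are ones you would have attached to the response yourself inside `create()` and `cost()`, not part of the protocol.

```python
@staticmethod
def get_usage(response):
    # Keys mirror RESPONSE_USAGE_KEYS; attributes missing from the response default to 0.
    return {
        "prompt_tokens": getattr(response, "prompt_tokens", 0),
        "completion_tokens": getattr(response, "completion_tokens", 0),
        "total_tokens": getattr(response, "total_tokens", 0),
        "cost": getattr(response, "cost", 0),
        "model": response.model,
    }
```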
Troubleshooting steps
If something doesn't work, run through this checklist:

- Make sure you have followed the client protocol and client response protocol when creating the custom model class:
  - `create()` method: the `ModelClientResponseProtocol` must be followed when returning an inference response during the `create` call.
  - `message_retrieval()` method: returns a list of strings or a list of message objects. If a list of message objects is returned, they currently must contain the fields of OpenAI's ChatCompletion Message object, since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
  - `cost()` method: returns a float; if you don't care about cost tracking you can just return `0`.
  - `get_usage()` method: returns a dictionary; if you don't care about usage tracking you can just return an empty dictionary `{}`.
- Make sure you have a corresponding entry in the `OAI_CONFIG_LIST` and that the entry has the `"model_client_cls": "<custom-model-class-name>"` field.
- Make sure you have registered the client using the corresponding config entry and your new class: `agent.register_model_client(model_client_cls=<class-of-custom-model>, [other optional args])`
- Make sure that all of the custom models defined in the `OAI_CONFIG_LIST` have been registered.
- Any other troubleshooting might need to be done in the custom code itself.
Conclusion
With the ability to use custom models, AutoGen now offers even more flexibility and power for your AI applications. Whether you've trained your own model or want to use a specific pre-trained model, AutoGen can accommodate your needs. Happy coding!