AutoGen with Custom Models: Empowering Users to Use Their Own Inference Mechanism

Olga Vrousgou

TL;DR

AutoGen now supports custom models! This feature empowers users to define and load their own models, allowing for a more flexible and personalized inference mechanism. By adhering to a specific protocol, you can integrate your custom model for use with AutoGen and respond to prompts any way you need, using any model, API call, or even a hardcoded response.

NOTE: Depending on what model you use, you may need to adjust the Agent's default prompts.

Quickstart

An interactive and easy way to get started is to follow the notebook here, which loads a local model from Hugging Face into AutoGen, uses it for inference, and shows how to make changes to the provided class.

Step 1: Create the custom model client class

To get started with using custom models in AutoGen, you need to create a model client class that adheres to the ModelClient protocol defined in client.py. The new model client class should implement these methods:

  • create(): Returns a response object that implements the ModelClientResponseProtocol (more details in the Protocol section).
  • message_retrieval(): Processes the response object and returns a list of strings or a list of message objects (more details in the Protocol section).
  • cost(): Returns the cost of the response.
  • get_usage(): Returns a dictionary with keys from RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"].

For example, here is a bare-bones dummy custom model client class:

from types import SimpleNamespace


class CustomModelClient:
    def __init__(self, config, **kwargs):
        print(f"CustomModelClient config: {config}")

    def create(self, params):
        num_of_responses = params.get("n", 1)

        # can create my own data response class
        # here using SimpleNamespace for simplicity
        # as long as it adheres to the ModelClientResponseProtocol

        response = SimpleNamespace()
        response.choices = []
        response.model = "model_name"  # should match the OAI_CONFIG_LIST registration

        for _ in range(num_of_responses):
            text = "this is a dummy text response"
            choice = SimpleNamespace()
            choice.message = SimpleNamespace()
            choice.message.content = text
            choice.message.function_call = None
            response.choices.append(choice)
        return response

    def message_retrieval(self, response):
        choices = response.choices
        return [choice.message.content for choice in choices]

    def cost(self, response) -> float:
        response.cost = 0
        return 0

    @staticmethod
    def get_usage(response):
        return {}
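
The same skeleton can wrap a real local model instead of a dummy response. Below is a minimal sketch, assuming a Hugging Face transformers chat model that supports apply_chat_template; the class name, device handling, and generation parameters are illustrative assumptions, not the notebook's exact code:

from types import SimpleNamespace

from transformers import AutoModelForCausalLM, AutoTokenizer


class HuggingFaceLocalModelClient:
    """Illustrative client that runs inference with a local transformers model."""

    def __init__(self, config, **kwargs):
        self.model_name = config["model"]
        self.device = config.get("device", "cpu")
        self.max_length = config.get("params", {}).get("max_length", 256)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name).to(self.device)

    def create(self, params):
        response = SimpleNamespace()
        response.choices = []
        response.model = self.model_name

        # Render the chat history into model inputs using the tokenizer's chat template.
        inputs = self.tokenizer.apply_chat_template(
            params["messages"], add_generation_prompt=True, return_tensors="pt"
        ).to(self.device)
        outputs = self.model.generate(inputs, max_length=self.max_length)

        # Decode only the newly generated tokens, not the echoed prompt.
        text = self.tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)

        choice = SimpleNamespace()
        choice.message = SimpleNamespace()
        choice.message.content = text
        choice.message.function_call = None
        response.choices.append(choice)
        return response

    def message_retrieval(self, response):
        return [choice.message.content for choice in response.choices]

    def cost(self, response) -> float:
        response.cost = 0
        return 0

    @staticmethod
    def get_usage(response):
        return {}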

Step 2: Add the configuration to the OAI_CONFIG_LIST

The only required field is "model_client_cls", set to the name of the new class (as a string): "model_client_cls": "CustomModelClient". Any other fields will be forwarded to the class constructor, so you have full control over what parameters to specify and how to use them. E.g.:

{
    "model": "Open-Orca/Mistral-7B-OpenOrca",
    "model_client_cls": "CustomModelClient",
    "device": "cuda",
    "n": 1,
    "params": {
        "max_length": 1000
    }
}
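
A minimal sketch of loading just this entry, assuming the list lives in an OAI_CONFIG_LIST file or environment variable (the filter_dict usage mirrors AutoGen's standard config filtering):

import autogen

config_list = autogen.config_list_from_json(
    "OAI_CONFIG_LIST",
    filter_dict={"model_client_cls": ["CustomModelClient"]},
)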

Step 3: Register the new custom model to the agent that will use it

If a configuration with the field "model_client_cls":"<class name>" has been added to an Agent's config list, then the corresponding model with the desired class must be registered after the agent is created and before the conversation is initialized:

my_agent.register_model_client(model_client_cls=CustomModelClient, [other args that will be forwarded to CustomModelClient constructor])

The model_client_cls=CustomModelClient argument must match the name specified in the OAI_CONFIG_LIST, and CustomModelClient is the class that adheres to the ModelClient protocol (more details on the protocol below).

If the new model client is in the config list but not registered by the time the chat is initialized, then an error will be raised.
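
Putting the three steps together, a minimal end-to-end sketch (assuming the config_list from Step 2 and the CustomModelClient class from Step 1; the agent names and message are illustrative):

import autogen

assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
user_proxy = autogen.UserProxyAgent(
    "user_proxy", human_input_mode="NEVER", code_execution_config=False
)

# Register the custom client class after agent creation and before the chat starts.
assistant.register_model_client(model_client_cls=CustomModelClient)

user_proxy.initiate_chat(assistant, message="Write a haiku about clouds.")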

Protocol details

A custom model class can be created in many ways, but it needs to adhere to the ModelClient protocol and response structure, which are defined in client.py and shown below.

The response protocol currently uses the minimum fields required by the autogen codebase that match the OpenAI response structure. Any response protocol that matches the OpenAI response structure will probably be more resilient to future changes, but we are starting off with minimum requirements to make adoption of this feature easier.


from typing import Dict, List, Optional, Protocol, Union


class ModelClient(Protocol):
    """
    A client class must implement the following methods:
    - create must return a response object that implements the ModelClientResponseProtocol
    - cost must return the cost of the response
    - get_usage must return a dict with the following keys:
        - prompt_tokens
        - completion_tokens
        - total_tokens
        - cost
        - model

    This class is used to create a client that can be used by OpenAIWrapper.
    The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed.
    The message_retrieval method must be implemented to return a list of str or a list of messages from the response.
    """

    RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]

    class ModelClientResponseProtocol(Protocol):
        class Choice(Protocol):
            class Message(Protocol):
                content: Optional[str]

            message: Message

        choices: List[Choice]
        model: str

    def create(self, params) -> ModelClientResponseProtocol:
        ...

    def message_retrieval(
        self, response: ModelClientResponseProtocol
    ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
        """
        Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        ...

    def cost(self, response: ModelClientResponseProtocol) -> float:
        ...

    @staticmethod
    def get_usage(response: ModelClientResponseProtocol) -> Dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        ...
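
As the NOTE in message_retrieval indicates, you may also return whole message objects rather than plain strings, for example to support function or tool calling. A hedged sketch of that variant (the messages must carry the fields of OpenAI's ChatCompletion Message object, which the rest of the codebase expects):

def message_retrieval(self, response):
    # Return the message objects themselves so downstream code can inspect
    # fields like content and function_call, not just the text.
    return [choice.message for choice in response.choices]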

Troubleshooting steps

If something doesn't work, run through this checklist:

  • Make sure you have followed the client protocol and client response protocol when creating the custom model class
    • create() method: the response returned during the create call must follow the ModelClientResponseProtocol.
    • message_retrieval() method: returns a list of strings or a list of message objects. If a list of message objects is returned, they currently must contain the fields of OpenAI's ChatCompletion Message object, since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
    • cost() method: returns a float; if you don't care about cost tracking you can just return 0.
    • get_usage(): returns a dictionary; if you don't care about usage tracking you can just return an empty dictionary {} (a populated example is sketched after this list).
  • Make sure you have a corresponding entry in the OAI_CONFIG_LIST and that the entry has the "model_client_cls":"<custom-model-class-name>" field.
  • Make sure you have registered the client using the corresponding config entry and your new class: agent.register_model_client(model_client_cls=<class-of-custom-model>, [other optional args])
  • Make sure that all of the custom models defined in the OAI_CONFIG_LIST have been registered.
  • Any other troubleshooting might need to be done in the custom code itself.
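
For completeness, here is a hedged sketch of a get_usage() that does track usage, assuming your create() implementation stored token counts and cost on the response object (the getattr defaults are illustrative):

@staticmethod
def get_usage(response):
    # Keys must match ModelClient.RESPONSE_USAGE_KEYS.
    return {
        "prompt_tokens": getattr(response, "prompt_tokens", 0),
        "completion_tokens": getattr(response, "completion_tokens", 0),
        "total_tokens": getattr(response, "total_tokens", 0),
        "cost": getattr(response, "cost", 0),
        "model": response.model,
    }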

Conclusion

With the ability to use custom models, AutoGen now offers even more flexibility and power for your AI applications. Whether you've trained your own model or want to use a specific pre-trained model, AutoGen can accommodate your needs. Happy coding!