oai.client

ModelClient

class ModelClient(Protocol)

A client class must implement the following methods:

  • create must return a response object that implements the ModelClientResponseProtocol
  • cost must return the cost of the response
  • get_usage must return a dict with the following keys:
    • prompt_tokens
    • completion_tokens
    • total_tokens
    • cost
    • model

This class is used to create a client that can be used by OpenAIWrapper. The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed. The message_retrieval method must be implemented to return a list of str or a list of messages from the response.

message_retrieval

def message_retrieval(
response: ModelClientResponseProtocol
) -> Union[List[str],
List[ModelClient.ModelClientResponseProtocol.Choice.Message]]

Retrieve and return a list of strings or a list of Choice.Message from the response.

NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object, since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.

get_usage

@staticmethod
def get_usage(response: ModelClientResponseProtocol) -> Dict

Return usage summary of the response using RESPONSE_USAGE_KEYS.
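
The four methods above can be sketched with a toy client. This is a minimal illustration, not the library's implementation: EchoClient, its constructor signature, and the extra usage and cost attributes on the response are hypothetical, and the response shape (choices[i].message.content plus model) is an assumption about what ModelClientResponseProtocol requires.

```python
from types import SimpleNamespace
from typing import Any, Dict, List


class EchoClient:
    """Hypothetical client that echoes the last message back (illustration only)."""

    def __init__(self, config: Dict[str, Any], **kwargs: Any) -> None:
        # Assumption: the client is constructed from one config dict plus extra kwargs.
        self.model = config.get("model", "echo-model")

    def create(self, params: Dict[str, Any]) -> SimpleNamespace:
        # Assumed response shape: choices[i].message.content and a model name.
        text = params["messages"][-1]["content"]
        message = SimpleNamespace(content=text, function_call=None, tool_calls=None)
        response = SimpleNamespace(choices=[SimpleNamespace(message=message)], model=self.model)
        # Extension of the response: token counts (whitespace split, not a real tokenizer).
        n_tokens = len(text.split())
        response.usage = SimpleNamespace(prompt_tokens=n_tokens, completion_tokens=n_tokens)
        response.cost = 0.0  # a local echo costs nothing
        return response

    def message_retrieval(self, response: SimpleNamespace) -> List[str]:
        # Return one string per choice.
        return [choice.message.content for choice in response.choices]

    def cost(self, response: SimpleNamespace) -> float:
        return response.cost

    @staticmethod
    def get_usage(response: SimpleNamespace) -> Dict[str, Any]:
        # Return exactly the keys listed for get_usage above.
        usage = response.usage
        return {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.prompt_tokens + usage.completion_tokens,
            "cost": response.cost,
            "model": response.model,
        }


client = EchoClient({"model": "echo-model"})
response = client.create({"messages": [{"role": "user", "content": "hello there"}]})
texts = client.message_retrieval(response)
usage = EchoClient.get_usage(response)
print(texts, usage)
```

A class like this would be handed to OpenAIWrapper through register_model_client, described further down.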

OpenAIClient

class OpenAIClient()

Follows the ModelClient protocol and wraps the OpenAI client.

message_retrieval

def message_retrieval(
response: Union[ChatCompletion, Completion]
) -> Union[List[str], List[ChatCompletionMessage]]

Retrieve the messages from the response.

create

def create(params: Dict[str, Any]) -> ChatCompletion

Create a completion for a given config using openai's client.

Arguments:

  • client - The openai client.
  • params - The params for the completion.

Returns:

The completion.

cost

def cost(response: Union[ChatCompletion, Completion]) -> float

Calculate the cost of the response.

OpenAIWrapper

class OpenAIWrapper()

A wrapper class for openai client.

__init__

def __init__(*,
config_list: Optional[List[Dict[str, Any]]] = None,
**base_config: Any)

Arguments:

  • config_list - a list of config dicts to override the base_config. They can contain additional kwargs as allowed in the create method. E.g.,
    config_list=[
        {
            "model": "gpt-4",
            "api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
            "api_type": "azure",
            "base_url": os.environ.get("AZURE_OPENAI_API_BASE"),
            "api_version": "2024-02-15-preview",
        },
        {
            "model": "gpt-3.5-turbo",
            "api_key": os.environ.get("OPENAI_API_KEY"),
            "api_type": "openai",
            "base_url": "https://api.openai.com/v1",
        },
        {
            "model": "llama-7B",
            "base_url": "http://127.0.0.1:8080",
        }
    ]
  • base_config - base config. It can contain both keyword arguments for openai client and additional kwargs.
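
One way to picture "config_list overrides base_config" is as a per-entry dict merge. The helper below is a hypothetical sketch of that idea under the assumption that each entry simply takes precedence over the base values; resolve_configs is not part of the library, and the real wrapper may treat some keys differently.

```python
from typing import Any, Dict, List, Optional


def resolve_configs(
    config_list: Optional[List[Dict[str, Any]]] = None, **base_config: Any
) -> List[Dict[str, Any]]:
    """Hypothetical helper: merge base_config into every config_list entry."""
    if not config_list:
        # With no config_list, the base config is the only configuration.
        return [dict(base_config)]
    # Keys in an entry win over keys of the same name in base_config.
    return [{**base_config, **entry} for entry in config_list]


configs = resolve_configs(
    [
        {"model": "gpt-4", "temperature": 0},
        {"model": "llama-7B", "base_url": "http://127.0.0.1:8080"},
    ],
    temperature=0.7,
    timeout=60,
)
for cfg in configs:
    print(cfg)
```

Here the first entry keeps its own temperature, while the second inherits temperature and timeout from the base config.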

register_model_client

def register_model_client(model_client_cls: ModelClient, **kwargs)

Register a model client.

Arguments:

  • model_client_cls - A custom client class that follows the ModelClient interface
  • **kwargs - The kwargs for the custom client class to be initialized with

create

def create(**config: Any) -> ModelClient.ModelClientResponseProtocol

Make a completion for a given config using the available clients. Besides the kwargs allowed in openai's (or another) client, the following additional kwargs are allowed. The config stored for each client is overridden by the config passed to this call.

Arguments:

  • context (Dict | None): The context to instantiate the prompt or messages. Default to None. It needs to contain keys that are used by the prompt template or the filter function. E.g., prompt="Complete the following sentence: {prefix}", context={"prefix": "Today I feel"}. The actual prompt will be: "Complete the following sentence: Today I feel". More examples can be found at templating.
  • cache (Cache | None): A Cache object to use for response cache. Default to None. Note that the cache argument overrides the legacy cache_seed argument: if this argument is provided, then the cache_seed argument is ignored. If this argument is not provided or None, then the cache_seed argument is used.
  • cache_seed (int | None): (Legacy) The seed for the DiskCache. Default to 41. An integer cache_seed is useful when implementing "controlled randomness" for the completion. None for no caching.
    • Note - this is a legacy argument. It is only used when the cache argument is not provided.
  • filter_func (Callable | None): A function that takes in the context and the response and returns a boolean to indicate whether the response is valid. E.g.,
    def yes_or_no_filter(context, response):
        return context.get("yes_or_no_choice", False) is False or any(
            text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response)
        )
  • allow_format_str_template (bool | None): Whether to allow format string templates in the config. Default to False.
  • api_version (str | None): The api version. Default to None. E.g., "2024-02-15-preview".
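
The interaction between context and allow_format_str_template can be sketched with plain str.format. This is an illustration of the documented example only: instantiate is a hypothetical helper, and the assumption that a template is left untouched when format strings are disabled is not confirmed by the text above.

```python
from typing import Any, Dict, Optional


def instantiate(
    template: Optional[str],
    context: Optional[Dict[str, Any]] = None,
    allow_format_str_template: bool = False,
) -> Optional[str]:
    """Hypothetical helper: fill a prompt template from the context."""
    if template is None or context is None:
        return template
    if not allow_format_str_template:
        # Assumption: with templates disabled, the prompt is passed through unchanged.
        return template
    # Substitute {placeholders} with the matching context values.
    return template.format(**context)


prompt = instantiate(
    "Complete the following sentence: {prefix}",
    {"prefix": "Today I feel"},
    allow_format_str_template=True,
)
print(prompt)  # Complete the following sentence: Today I feel
```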

Raises:

  • RuntimeError: If any declared custom model client has not been registered
  • APIError: If any model client create call raises an APIError

print_usage_summary

def print_usage_summary(
mode: Union[str, List[str]] = ["actual", "total"]) -> None

Print the usage summary.

clear_usage_summary

def clear_usage_summary() -> None

Clear the usage summary.

extract_text_or_completion_object

@classmethod
def extract_text_or_completion_object(
cls, response: ModelClient.ModelClientResponseProtocol
) -> Union[List[str],
List[ModelClient.ModelClientResponseProtocol.Choice.Message]]

Extract the text or ChatCompletion objects from a completion or chat response.

Arguments:

  • response (ChatCompletion | Completion) - The response from openai.

Returns:

A list of text, or a list of message objects (Choice.Message) if function_call/tool_calls are present.
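
The return rule can be sketched as follows. This is a stand-alone illustration that uses SimpleNamespace objects in place of real responses; extract_messages is a hypothetical function, and deciding per choice whether to return the text or the whole message is an assumption.

```python
from types import SimpleNamespace
from typing import Any, List


def extract_messages(response: Any) -> List[Any]:
    """Hypothetical sketch: text for plain replies, message objects for calls."""
    results = []
    for choice in response.choices:
        message = choice.message
        has_call = (
            getattr(message, "function_call", None) is not None
            or getattr(message, "tool_calls", None) is not None
        )
        # Keep the full message when it carries a function or tool call.
        results.append(message if has_call else message.content)
    return results


plain = SimpleNamespace(
    choices=[SimpleNamespace(message=SimpleNamespace(content="Yes.", function_call=None, tool_calls=None))]
)
call_message = SimpleNamespace(content=None, function_call=None, tool_calls=[{"id": "call_0"}])
with_call = SimpleNamespace(choices=[SimpleNamespace(message=call_message)])

plain_result = extract_messages(plain)
call_result = extract_messages(with_call)
print(plain_result)
```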