Source code for autogen_ext.models.llama_cpp._llama_cpp_completion_client
import logging  # added import
import re
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, TypedDict, Union, cast

from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall, MessageHandlerContext
from autogen_core.logging import LLMCallEvent
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    CreateResult,
    FinishReasons,
    FunctionExecutionResultMessage,
    LLMMessage,
    ModelInfo,
    RequestUsage,
    SystemMessage,
    UserMessage,
    validate_model_info,
)
from autogen_core.tools import Tool, ToolSchema
from llama_cpp import (
    ChatCompletionFunctionParameters,
    ChatCompletionRequestAssistantMessage,
    ChatCompletionRequestFunctionMessage,
    ChatCompletionRequestSystemMessage,
    ChatCompletionRequestToolMessage,
    ChatCompletionRequestUserMessage,
    ChatCompletionTool,
    ChatCompletionToolFunction,
    Llama,
    llama_chat_format,
)
from typing_extensions import Unpack

logger = logging.getLogger(EVENT_LOGGER_NAME)  # initialize logger


def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
    if stop_reason is None:
        return "unknown"

    # Convert to lower case
    stop_reason = stop_reason.lower()

    KNOWN_STOP_MAPPINGS: Dict[str, FinishReasons] = {
        "stop": "stop",
        "length": "length",
        "content_filter": "content_filter",
        "function_calls": "function_calls",
        "end_turn": "stop",
        "tool_calls": "function_calls",
    }

    return KNOWN_STOP_MAPPINGS.get(stop_reason, "unknown")


def normalize_name(name: str) -> str:
    """
    LLMs sometimes call functions while ignoring their own format requirements; this function replaces
    invalid characters with "_" and truncates the result to 64 characters.

    Prefer assert_valid_name for validating user configuration or input.
    """
    return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]


def assert_valid_name(name: str) -> str:
    """
    Ensure that configured names are valid; raises ValueError if not.

    For munging LLM responses, use normalize_name to ensure LLM-specified names don't break the API.
    """
    if not re.match(r"^[a-zA-Z0-9_-]+$", name):
        raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
    if len(name) > 64:
        raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
    return name


def convert_tools(
    tools: Sequence[Tool | ToolSchema],
) -> List[ChatCompletionTool]:
    result: List[ChatCompletionTool] = []
    for tool in tools:
        if isinstance(tool, Tool):
            tool_schema = tool.schema
        else:
            assert isinstance(tool, dict)
            tool_schema = tool
        result.append(
            ChatCompletionTool(
                type="function",
                function=ChatCompletionToolFunction(
                    name=tool_schema["name"],
                    description=(tool_schema["description"] if "description" in tool_schema else ""),
                    parameters=(
                        cast(ChatCompletionFunctionParameters, tool_schema["parameters"])
                        if "parameters" in tool_schema
                        else {}
                    ),
                ),
            )
        )
    # Check if all tools have valid names.
    for tool_param in result:
        assert_valid_name(tool_param["function"]["name"])
    return result


class LlamaCppParams(TypedDict, total=False):
    # from_pretrained parameters:
    repo_id: Optional[str]
    filename: Optional[str]
    additional_files: Optional[List[Any]]
    local_dir: Optional[str]
    local_dir_use_symlinks: Union[bool, Literal["auto"]]
    cache_dir: Optional[str]
    # __init__ parameters:
    model_path: str
    n_gpu_layers: int
    split_mode: int
    main_gpu: int
    tensor_split: Optional[List[float]]
    rpc_servers: Optional[str]
    vocab_only: bool
    use_mmap: bool
    use_mlock: bool
    kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]]
    seed: int
    n_ctx: int
    n_batch: int
    n_ubatch: int
    n_threads: Optional[int]
    n_threads_batch: Optional[int]
    rope_scaling_type: Optional[int]
    pooling_type: int
    rope_freq_base: float
    rope_freq_scale: float
    yarn_ext_factor: float
    yarn_attn_factor: float
    yarn_beta_fast: float
    yarn_beta_slow: float
    yarn_orig_ctx: int
    logits_all: bool
    embedding: bool
    offload_kqv: bool
    flash_attn: bool
    no_perf: bool
    last_n_tokens_size: int
    lora_base: Optional[str]
    lora_scale: float
    lora_path: Optional[str]
    numa: Union[bool, int]
    chat_format: Optional[str]
    chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler]
    draft_model: Optional[Any]  # LlamaDraftModel not exposed by llama_cpp
    tokenizer: Optional[Any]  # BaseLlamaTokenizer not exposed by llama_cpp
    type_k: Optional[int]
    type_v: Optional[int]
    spm_infill: bool
    verbose: bool
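# Illustrative sketch (not part of the module): how convert_tools maps an autogen
# ToolSchema dict into the llama_cpp ChatCompletionTool format. The "get_weather"
# schema below is a hypothetical example, not something defined by this module.
#
#     from autogen_core.tools import ToolSchema
#
#     weather_schema: ToolSchema = {
#         "name": "get_weather",
#         "description": "Return the current weather for a city.",
#         "parameters": {
#             "type": "object",
#             "properties": {"city": {"type": "string"}},
#             "required": ["city"],
#         },
#     }
#     tools = convert_tools([weather_schema])
#     assert tools[0]["function"]["name"] == "get_weather"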
class LlamaCppChatCompletionClient(ChatCompletionClient):
    """Chat completion client for LlamaCpp models.

    To use this client, you must install the `llama-cpp` extra:

    .. code-block:: bash

        pip install "autogen-ext[llama-cpp]"

    This client allows you to interact with LlamaCpp models, either by specifying a local model path or by
    downloading a model from Hugging Face Hub.

    Args:
        model_path (optional, str): The path to the LlamaCpp model file. Required if repo_id and filename are not provided.
        repo_id (optional, str): The Hugging Face Hub repository ID. Required if model_path is not provided.
        filename (optional, str): The filename of the model within the Hugging Face Hub repository. Required if model_path is not provided.
        n_gpu_layers (optional, int): The number of layers to put on the GPU.
        n_ctx (optional, int): The context size.
        n_batch (optional, int): The batch size.
        verbose (optional, bool): Whether to print verbose output.
        model_info (optional, ModelInfo): The capabilities of the model. Defaults to a ModelInfo instance with function_calling set to True.
        **kwargs: Additional parameters to pass to the Llama class.

    Examples:

        The following code snippet shows how to use the client with a local model file:

        .. code-block:: python

            import asyncio

            from autogen_core.models import UserMessage
            from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


            async def main():
                llama_client = LlamaCppChatCompletionClient(model_path="/path/to/your/model.gguf")
                result = await llama_client.create([UserMessage(content="What is the capital of France?", source="user")])
                print(result)


            asyncio.run(main())

        The following code snippet shows how to use the client with a model from Hugging Face Hub:

        .. code-block:: python

            import asyncio

            from autogen_core.models import UserMessage
            from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


            async def main():
                llama_client = LlamaCppChatCompletionClient(
                    repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000
                )
                result = await llama_client.create([UserMessage(content="What is the capital of France?", source="user")])
                print(result)


            asyncio.run(main())
    """

    def __init__(
        self,
        model_info: Optional[ModelInfo] = None,
        **kwargs: Unpack[LlamaCppParams],
    ) -> None:
        """
        Initialize the LlamaCpp client.
        """
        if model_info:
            validate_model_info(model_info)
        if "repo_id" in kwargs and "filename" in kwargs and kwargs["repo_id"] and kwargs["filename"]:
            repo_id: str = cast(str, kwargs.pop("repo_id"))
            filename: str = cast(str, kwargs.pop("filename"))
            pretrained = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)  # type: ignore
            assert isinstance(pretrained, Llama)
            self.llm = pretrained
        elif "model_path" in kwargs:
            self.llm = Llama(**kwargs)  # pyright: ignore[reportUnknownMemberType]
        else:
            raise ValueError(
                "Please provide model_path for a local model, or provide both repo_id and filename "
                "to download a model from Hugging Face Hub."
            )
        self._total_usage = {"prompt_tokens": 0, "completion_tokens": 0}
    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        # Convert LLMMessage objects to dictionaries with 'role' and 'content'
        # converted_messages: List[Dict[str, str | Image | list[str | Image] | list[FunctionCall]]] = []
        converted_messages: list[
            ChatCompletionRequestSystemMessage
            | ChatCompletionRequestUserMessage
            | ChatCompletionRequestAssistantMessage
            | ChatCompletionRequestToolMessage
            | ChatCompletionRequestFunctionMessage
        ] = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                converted_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, UserMessage) and isinstance(msg.content, str):
                converted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AssistantMessage) and isinstance(msg.content, str):
                converted_messages.append({"role": "assistant", "content": msg.content})
            elif (
                isinstance(msg, SystemMessage) or isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage)
            ) and isinstance(msg.content, list):
                raise ValueError("Multi-part messages such as those containing images are currently not supported.")
            else:
                raise ValueError(f"Unsupported message type: {type(msg)}")

        if self.model_info["function_calling"]:
            response = self.llm.create_chat_completion(
                messages=converted_messages, tools=convert_tools(tools), stream=False
            )
        else:
            response = self.llm.create_chat_completion(messages=converted_messages, stream=False)
        if not isinstance(response, dict):
            raise ValueError("Unexpected response type from LlamaCpp model.")

        self._total_usage["prompt_tokens"] += response["usage"]["prompt_tokens"]
        self._total_usage["completion_tokens"] += response["usage"]["completion_tokens"]

        # Parse the response
        response_tool_calls: ChatCompletionTool | None = None
        response_text: str | None = None
        if "choices" in response and len(response["choices"]) > 0:
            if "message" in response["choices"][0]:
                response_text = response["choices"][0]["message"]["content"]
            if "tool_calls" in response["choices"][0]:
                response_tool_calls = response["choices"][0]["tool_calls"]  # type: ignore

        content: List[FunctionCall] | str = ""
        thought: str | None = None
        if response_tool_calls:
            content = []
            for tool_call in response_tool_calls:
                if not isinstance(tool_call, dict):
                    raise ValueError("Unexpected tool call type from LlamaCpp model.")
                content.append(
                    FunctionCall(
                        id=tool_call["id"],
                        arguments=tool_call["function"]["arguments"],
                        name=normalize_name(tool_call["function"]["name"]),
                    )
                )
            if response_text and len(response_text) > 0:
                thought = response_text
        else:
            if response_text:
                content = response_text

        # Detect tool usage in the response
        if not response_tool_calls and not response_text:
            logger.debug("DEBUG: No response text found. Returning empty response.")
            return CreateResult(
                content="",
                usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
                finish_reason="stop",
                cached=False,
            )

        # Create a CreateResult object
        if "finish_reason" in response["choices"][0]:
            finish_reason = response["choices"][0]["finish_reason"]
        else:
            finish_reason = "unknown"
        if finish_reason not in ("stop", "length", "function_calls", "content_filter", "unknown"):
            finish_reason = "unknown"
        create_result = CreateResult(
            content=content,
            thought=thought,
            usage=cast(RequestUsage, response["usage"]),
            finish_reason=normalize_stop_reason(finish_reason),  # type: ignore
            cached=False,
        )

        # If we are running in the context of a handler we can get the agent_id
        try:
            agent_id = MessageHandlerContext.agent_id()
        except RuntimeError:
            agent_id = None

        logger.info(
            LLMCallEvent(
                messages=cast(List[Dict[str, Any]], converted_messages),
                response=create_result.model_dump(),
                prompt_tokens=response["usage"]["prompt_tokens"],
                completion_tokens=response["usage"]["completion_tokens"],
                agent_id=agent_id,
            )
        )
        return create_result
    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        raise NotImplementedError("Stream not yet implemented for LlamaCppChatCompletionClient")
        yield ""
    def count_tokens(
        self,
        messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
        **kwargs: Any,
    ) -> int:
        total = 0
        for msg in messages:
            # Use the Llama model's tokenizer to encode the content
            tokens = self.llm.tokenize(str(msg.content).encode("utf-8"))
            total += len(tokens)
        return total
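# Usage sketch (illustrative, not part of the module): calling the client with a tool
# and estimating prompt tokens. The model path and the get_weather tool below are
# hypothetical placeholders; whether tool calls are actually produced depends on the
# loaded model, its chat format, and model_info["function_calling"].
#
#     import asyncio
#
#     from autogen_core.models import UserMessage
#     from autogen_core.tools import FunctionTool
#
#     async def get_weather(city: str) -> str:
#         return f"The weather in {city} is sunny."
#
#     async def example() -> None:
#         client = LlamaCppChatCompletionClient(model_path="/path/to/model.gguf")
#         weather_tool = FunctionTool(get_weather, description="Get the weather for a city.")
#         messages = [UserMessage(content="What is the weather in Paris?", source="user")]
#         print("prompt tokens (approx.):", client.count_tokens(messages))
#         result = await client.create(messages, tools=[weather_tool])
#         print(result.content)  # either text or a list of FunctionCall objects
#
#     asyncio.run(example())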