# Customized LLM API
We welcome developers to use their own customized LLM APIs in TaskWeaver. In this tutorial, we will show you how to contribute your LLM API to TaskWeaver.
- Create a new Python script `<your_LLM_name>.py` in the `taskweaver/llm` folder.
- Import `CompletionService`, `LLMServiceConfig`, and `EmbeddingService` from `taskweaver.llm.base`, along with other necessary libraries.
```python
from typing import Any, Generator, List, Optional

from injector import inject

from taskweaver.llm.base import CompletionService, EmbeddingService, LLMServiceConfig
from taskweaver.llm.util import ChatMessageType
...
```
- Create a new class `YourLLMServiceConfig` that inherits from `LLMServiceConfig` and implements the `_configure` method. In this method, you can set the name, API key, model name, backup model name, and embedding model name of your LLM.
```python
class YourLLMServiceConfig(LLMServiceConfig):
    def _configure(self) -> None:
        self._set_name("your_llm_name")

        shared_api_key = self.llm_module_config.api_key
        self.api_key = self._get_str(
            "api_key",
            shared_api_key,
        )

        shared_model = self.llm_module_config.model
        self.model = self._get_str(
            "model",
            shared_model if shared_model is not None else "your_llm_model_name",
        )

        shared_backup_model = self.llm_module_config.backup_model
        self.backup_model = self._get_str(
            "backup_model",
            shared_backup_model if shared_backup_model is not None else self.model,
        )

        shared_embedding_model = self.llm_module_config.embedding_model
        self.embedding_model = self._get_str(
            "embedding_model",
            shared_embedding_model if shared_embedding_model is not None else self.model,
        )
```
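For context, each `self._get_str(...)` call reads a service-specific setting and falls back to the shared value passed as its second argument. Assuming TaskWeaver's usual configuration naming, where `_set_name("your_llm_name")` scopes service-specific keys under `llm.your_llm_name.*` (the exact key names below are illustrative assumptions), the values above could come from entries such as:

```json
{
  "llm.api_key": "SHARED_API_KEY_USED_AS_FALLBACK",
  "llm.model": "shared-model-name",
  "llm.your_llm_name.api_key": "API_KEY_SPECIFIC_TO_YOUR_SERVICE",
  "llm.your_llm_name.model": "your_llm_model_name"
}
```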
- Create a new class `YourLLMService` that inherits from `CompletionService` and `EmbeddingService` and implements the `chat_completion` and `get_embeddings` methods.
```python
class YourLLMService(CompletionService, EmbeddingService):
    @inject
    def __init__(self, config: YourLLMServiceConfig):
        self.config = config

    def chat_completion(
        self,
        messages: List[ChatMessageType],
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        pass

    def get_embeddings(self, strings: List[str]) -> List[List[float]]:
        pass
```
Note:
- We set stream mode by default in `chat_completion`.
- You need to use `self.config` to get the configuration variables of your LLM API (e.g., API key and model name) in the `YourLLMService` class.
- The `get_embeddings` method is optional.
- If you need to import other libraries for your LLM API, please import them in the `__init__` function of the `YourLLMService` class. You can refer to the QWen dashscope library import for an example.
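To make the skeleton above concrete, here is a minimal sketch of a service that calls a hypothetical HTTP chat endpoint and reuses the `YourLLMServiceConfig` defined earlier. The endpoint URL, request payload, and response shape are illustrative assumptions only; replace them with your provider's SDK or REST API. It also assumes `ChatMessageType` is a role/content dictionary, as in `taskweaver.llm.util`.

```python
from typing import Any, Generator, List, Optional

from injector import inject

from taskweaver.llm.base import CompletionService, EmbeddingService
from taskweaver.llm.util import ChatMessageType


class YourLLMService(CompletionService, EmbeddingService):
    @inject
    def __init__(self, config: YourLLMServiceConfig):
        self.config = config
        # Import your provider's SDK here if it is an optional dependency,
        # following the QWen dashscope import as an example.
        import requests  # assumption: a plain HTTP client is enough for your API

        self._http = requests

    def chat_completion(
        self,
        messages: List[ChatMessageType],
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        # Hypothetical endpoint and payload; adapt both to your provider.
        response = self._http.post(
            "https://api.example.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.config.api_key}"},
            json={
                "model": self.config.model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "top_p": top_p,
                "stop": stop,
            },
            timeout=60,
        )
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]  # assumed response shape
        # chat_completion is consumed as a generator of role/content messages,
        # so yield the reply (in one or more chunks) rather than returning it.
        yield {"role": "assistant", "content": content}

    def get_embeddings(self, strings: List[str]) -> List[List[float]]:
        # Optional: return one embedding vector per input string, or register
        # PlaceholderEmbeddingService instead if your API has no embedding endpoint.
        raise NotImplementedError("This example service does not provide embeddings.")
```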
- Register your LLM service in `taskweaver/llm/__init__.py` by adding it to the `LLMApi` `__init__` function.
```python
...
from .your_llm_name import YourLLMService  # import your LLM service here

class LLMApi(object):
    @inject
    def __init__(self, config: LLMModuleConfig, injector: Injector) -> None:
        self.config = config
        self.injector = injector

        if self.config.api_type in ["openai", "azure", "azure_ad"]:
            self._set_completion_service(OpenAIService)
            ...
        elif self.config.api_type == "your_llm_name":
            self._set_completion_service(YourLLMService)  # register your LLM service here
        else:
            raise ValueError(f"API type {self.config.api_type} is not supported")

        if self.config.embedding_api_type in ["openai", "azure", "azure_ad"]:
            self._set_embedding_service(OpenAIService)
            ...
        elif self.config.embedding_api_type == "azure_ml":
            self.embedding_service = PlaceholderEmbeddingService(
                "Azure ML does not support embeddings yet. Please configure a different embedding API.",
            )
        # Register your embedding service here. If your LLM API does not provide embeddings,
        # use `PlaceholderEmbeddingService` as in the branch above.
        elif self.config.embedding_api_type == "your_llm_name":
            self._set_embedding_service(YourLLMService)
        else:
            raise ValueError(
                f"Embedding API type {self.config.embedding_api_type} is not supported",
            )
```
- Configure the `taskweaver_config.json` file in the `project` directory based on your implemented LLM API (see the example configuration at the end of this tutorial).
- Run the following command to test your LLM API. If the LLM API is set up successfully, you will see a response from your LLM API.
```bash
cd ./scripts
python llm_api_test.py
```
You can also specify the project directory and the query to be sent to your LLM API using the following command:
```bash
python llm_api_test.py --project <your_project_path> --query "hello, what can you do?"
```
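As referenced in the configuration step above, a minimal `taskweaver_config.json` for this tutorial might look like the following sketch. It uses TaskWeaver's shared `llm.*` settings; treat the values as placeholders for your own API, and omit the embedding keys if you rely on `PlaceholderEmbeddingService`.

```json
{
  "llm.api_type": "your_llm_name",
  "llm.api_key": "YOUR_API_KEY",
  "llm.model": "your_llm_model_name",
  "llm.embedding_api_type": "your_llm_name",
  "llm.embedding_model": "your_embedding_model_name"
}
```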