使用自定义大模型来部署Wren AI(开源的文本生成SQL方案)
关于
- 首次发表日期:2024-07-15
- Wren AI官方文档: https://docs.getwren.ai/overview/introduction
- Wren AI Github仓库: https://github.com/Canner/WrenAI
关于Wren AI
Wren AI 是一个开源的文本生成SQL解决方案。
前提准备
由于之后会使用docker来启动服务,所以首先确保docker已经安装好了,并且网络没问题。
先克隆仓库:
git clone https://github.com/Canner/WrenAI.git
关于在Wren AI中使用自定义大模型和Embedding模型
Wren AI目前是支持自定义LLM和Embedding模型的,其官方文档 https://docs.getwren.ai/installation/custom_llm 中有提及,需要创建自己的provider类。
其中Wren AI本身已经支持和OpenAI兼容的大模型了;但是自定义的Embedding模型方面,可能会报错,具体来说是
wren-ai-service/src/providers/embedder/openai.py
中的以下代码
if self.dimensions is not None:
    response = await self.client.embeddings.create(
        model=self.model, dimensions=self.dimensions, input=text_to_embed
    )
else:
    response = await self.client.embeddings.create(
        model=self.model, input=text_to_embed
    )
其中
if self.dimensions is not None
这个条件分支是会报错的(默认会运行这个分支),所以我的临时解决方案是注释掉它。
具体而言是在
wren-ai-service/src/providers/embedder
文件夹中创建一个
openai_like.py
文件,表示定义一个和open ai类似的embedding provider,取个名字叫做
openai_like_embedder
,具体的完整代码见本文附录。
配置docker环境变量等并启动服务
首先,进入
docker
文件夹,拷贝
.env.example
并重命名为
.env.local
。
然后拷贝
.env.ai.example
并重命名为
.env.ai
,修改其中的LLM和Embedding的配置,相关部分如下:
LLM_PROVIDER=openai_llm
LLM_OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxx
LLM_OPENAI_API_BASE=https://api.siliconflow.cn/v1
GENERATION_MODEL=meta-llama/Meta-Llama-3-70B
# GENERATION_MODEL_KWARGS={"temperature": 0, "n": 1, "max_tokens": 32768, "response_format": {"type": "json_object"}}
EMBEDDER_PROVIDER=openai_like_embedder
EMBEDDING_MODEL=bge-m3
EMBEDDING_MODEL_DIMENSION=1024
EMBEDDER_OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxx
EMBEDDER_OPENAI_API_BASE=https://xxxxxxxxxxxxxxxx/v1
由于我们创建了一个自定义的embedding provider,需要将文件映射到docker容器中,具体可以通过配置
docker-compose.yaml
中的
wren-ai-service
,添加
volumes
属性:
wren-ai-service:
  image: ghcr.io/canner/wren-ai-service:${WREN_AI_SERVICE_VERSION}
  volumes:
    - /root/WrenAI/wren-ai-service/src:/src
最后,启动服务:
docker-compose -f docker-compose.yaml -f docker-compose.llm.yaml --env-file .env.local --env-file .env.ai up -d
或者停止服务:
docker-compose -f docker-compose.yaml -f docker-compose.llm.yaml --env-file .env.local --env-file .env.ai down
附录
openai_like.py
文件(提供自定义embedding服务):
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
import backoff
import openai
from haystack import Document, component
from haystack.components.embedders import OpenAIDocumentEmbedder, OpenAITextEmbedder
from haystack.utils import Secret
from openai import AsyncOpenAI, OpenAI
from tqdm import tqdm
from src.core.provider import EmbedderProvider
from src.providers.loader import provider
import logging  # NOTE(review): duplicate of the import at the top of the file; kept to avoid touching import structure
import sys

# Send DEBUG-level logs to stdout so they show up in `docker logs`.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

# Module-level logger shared by the embedder classes below.
logger = logging.getLogger("wren-ai-service")

# Fallback values used when the corresponding environment variables
# (EMBEDDER_OPENAI_API_BASE / EMBEDDING_MODEL) are unset.
EMBEDDER_OPENAI_API_BASE = "https://api.openai.com/v1"
EMBEDDING_MODEL = "text-embedding-3-large"
# Fallback embedding dimension used when EMBEDDING_MODEL_DIMENSION is unset.
EMBEDDING_MODEL_DIMENSION = 3072


@component
class AsyncTextEmbedder(OpenAITextEmbedder):
    """Async drop-in replacement for haystack's OpenAITextEmbedder.

    Behaves like the parent class except that it calls the API through
    ``AsyncOpenAI`` and deliberately does NOT forward ``dimensions`` to
    ``embeddings.create`` — some OpenAI-compatible backends reject that
    parameter (the reason this custom provider exists).
    """

    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),
        model: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
    ):
        super(AsyncTextEmbedder, self).__init__(
            api_key,
            model,
            dimensions,
            api_base_url,
            organization,
            prefix,
            suffix,
        )
        # Replace the synchronous client created by the parent __init__
        # with an async one so that run() can be awaited.
        self.client = AsyncOpenAI(
            api_key=api_key.resolve_value(),
            organization=organization,
            base_url=api_base_url,
        )

    @component.output_types(embedding=List[float], meta=Dict[str, Any])
    @backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60, max_tries=3)
    async def run(self, text: str):
        """Embed a single string; returns its embedding and response metadata."""
        if not isinstance(text, str):
            raise TypeError(
                "OpenAITextEmbedder expects a string as an input."
                "In case you want to embed a list of Documents, please use the OpenAIDocumentEmbedder."
            )

        logger.debug(f"Running Async OpenAI text embedder with text: {text}")

        text_to_embed = self.prefix + text + self.suffix

        # copied from OpenAI embedding_utils (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)
        # replace newlines, which can negatively affect performance.
        text_to_embed = text_to_embed.replace("\n", " ")

        # HACK: the dimensions branch is intentionally disabled because some
        # OpenAI-compatible embedding endpoints error out when it is passed.
        # if self.dimensions is not None:
        #     response = await self.client.embeddings.create(
        #         model=self.model, dimensions=self.dimensions, input=text_to_embed
        #     )
        # else:
        response = await self.client.embeddings.create(
            model=self.model, input=text_to_embed
        )

        meta = {"model": response.model, "usage": dict(response.usage)}

        return {"embedding": response.data[0].embedding, "meta": meta}


@component
class AsyncDocumentEmbedder(OpenAIDocumentEmbedder):
    """Async drop-in replacement for haystack's OpenAIDocumentEmbedder.

    Same contract as the parent class, but uses ``AsyncOpenAI`` and skips
    the ``dimensions`` parameter for compatibility with OpenAI-like
    embedding services.
    """

    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),
        model: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        super(AsyncDocumentEmbedder, self).__init__(
            api_key,
            model,
            dimensions,
            api_base_url,
            organization,
            prefix,
            suffix,
            batch_size,
            progress_bar,
            meta_fields_to_embed,
            embedding_separator,
        )
        # Swap in an async client (parent builds a synchronous one).
        self.client = AsyncOpenAI(
            api_key=api_key.resolve_value(),
            organization=organization,
            base_url=api_base_url,
        )

    async def _embed_batch(
        self, texts_to_embed: List[str], batch_size: int
    ) -> Tuple[List[List[float]], Dict[str, Any]]:
        """Embed texts batch by batch, accumulating token usage into meta."""
        all_embeddings = []
        meta: Dict[str, Any] = {}
        for i in tqdm(
            range(0, len(texts_to_embed), batch_size),
            disable=not self.progress_bar,
            desc="Calculating embeddings",
        ):
            batch = texts_to_embed[i : i + batch_size]
            # HACK: dimensions branch intentionally disabled, same reason as
            # in AsyncTextEmbedder.run above.
            # if self.dimensions is not None:
            #     response = await self.client.embeddings.create(
            #         model=self.model, dimensions=self.dimensions, input=batch
            #     )
            # else:
            response = await self.client.embeddings.create(
                model=self.model, input=batch
            )
            embeddings = [el.embedding for el in response.data]
            all_embeddings.extend(embeddings)

            # First batch sets model/usage; later batches add their token counts.
            if "model" not in meta:
                meta["model"] = response.model
            if "usage" not in meta:
                meta["usage"] = dict(response.usage)
            else:
                meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
                meta["usage"]["total_tokens"] += response.usage.total_tokens

        return all_embeddings, meta

    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    @backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60, max_tries=3)
    async def run(self, documents: List[Document]):
        """Embed a list of Documents in place; returns them plus usage metadata."""
        if not isinstance(documents, list) or (
            documents and not isinstance(documents[0], Document)
        ):
            raise TypeError(
                "OpenAIDocumentEmbedder expects a list of Documents as input."
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

        logger.debug(
            f"Running Async OpenAI document embedder with documents: {documents}"
        )

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings, meta = await self._embed_batch(
            texts_to_embed=texts_to_embed, batch_size=self.batch_size
        )

        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents, "meta": meta}


@provider("openai_like_embedder")
class OpenAIEmbedderProvider(EmbedderProvider):
    """Embedder provider for OpenAI-compatible embedding services.

    Registered with Wren AI under the name ``openai_like_embedder``
    (select it via EMBEDDER_PROVIDER in the .env.ai file). Configuration
    comes from EMBEDDER_OPENAI_API_KEY, EMBEDDER_OPENAI_API_BASE,
    EMBEDDING_MODEL and EMBEDDING_MODEL_DIMENSION.
    """

    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),
        api_base: str = os.getenv("EMBEDDER_OPENAI_API_BASE") or EMBEDDER_OPENAI_API_BASE,
        embedding_model: str = os.getenv("EMBEDDING_MODEL") or EMBEDDING_MODEL,
        embedding_model_dim: int = (
            int(os.getenv("EMBEDDING_MODEL_DIMENSION"))
            if os.getenv("EMBEDDING_MODEL_DIMENSION")
            else 0
        )
        or EMBEDDING_MODEL_DIMENSION,
    ):
        def _verify_api_key(api_key: str, api_base: str) -> None:
            """
            this is a temporary solution to verify that the required environment variables are set
            """
            # A models.list() call fails fast if the key/base URL are wrong.
            OpenAI(api_key=api_key, base_url=api_base).models.list()

        logger.info(f"Initializing OpenAIEmbedder provider with API base: {api_base}")
        # TODO: currently only OpenAI api key can be verified
        if api_base == EMBEDDER_OPENAI_API_BASE:
            _verify_api_key(api_key.resolve_value(), api_base)
            logger.info(f"Using OpenAI Embedding Model: {embedding_model}")
        else:
            logger.info(
                f"Using OpenAI API-compatible Embedding Model: {embedding_model}"
            )

        self._api_key = api_key
        self._api_base = api_base
        self._embedding_model = embedding_model
        self._embedding_model_dim = embedding_model_dim

    def get_text_embedder(self):
        """Return an AsyncTextEmbedder configured from this provider."""
        return AsyncTextEmbedder(
            api_key=self._api_key,
            api_base_url=self._api_base,
            model=self._embedding_model,
            dimensions=self._embedding_model_dim,
        )

    def get_document_embedder(self):
        """Return an AsyncDocumentEmbedder configured from this provider."""
        return AsyncDocumentEmbedder(
            api_key=self._api_key,
            api_base_url=self._api_base,
            model=self._embedding_model,
            dimensions=self._embedding_model_dim,
        )
版权归原作者 shizidushu 所有, 如有侵权,请联系我们删除。