

Deploying Wren AI (an Open-Source Text-to-SQL Solution) with a Custom LLM


About Wren AI

Wren AI is an open-source text-to-SQL solution.

Prerequisites

Since the services will be started with Docker, first make sure Docker is installed and your network connection is working.

First, clone the repository:

git clone https://github.com/Canner/WrenAI.git

Using a custom LLM and embedding model with Wren AI

Wren AI supports custom LLM and embedding models; the official documentation at https://docs.getwren.ai/installation/custom_llm explains that you need to create your own provider class.

Wren AI already supports OpenAI-compatible LLMs out of the box. With a custom embedding model, however, you may hit an error. Specifically, in

wren-ai-service/src/providers/embedder/openai.py

the following code

if self.dimensions is not None:
    response = await self.client.embeddings.create(
        model=self.model, dimensions=self.dimensions, input=text_to_embed
    )
else:
    response = await self.client.embeddings.create(
        model=self.model, input=text_to_embed
    )

raises an error in the

if self.dimensions is not None

branch (which is the branch taken by default), so my temporary workaround is to comment it out.
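The failure mode can be sketched with a small, hypothetical helper (build_embedding_kwargs is not part of Wren AI): some OpenAI-compatible embedding servers reject the optional dimensions request parameter, so a safer pattern than always forwarding it is to include it only when the backend is known to accept it.

```python
from typing import Any, Dict, Optional


def build_embedding_kwargs(
    model: str,
    text: str,
    dimensions: Optional[int] = None,
    backend_supports_dimensions: bool = False,
) -> Dict[str, Any]:
    """Assemble keyword arguments for client.embeddings.create()."""
    kwargs: Dict[str, Any] = {"model": model, "input": text}
    # Forward `dimensions` only when the server understands it;
    # otherwise omit it instead of letting the request fail.
    if dimensions is not None and backend_supports_dimensions:
        kwargs["dimensions"] = dimensions
    return kwargs
```

With this pattern, a bge-m3 backend that does not accept the parameter still gets a valid request, while an OpenAI backend can keep using reduced dimensions.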

Specifically, create a file named

openai_like.py

in the

wren-ai-service/src/providers/embedder

folder. It defines an embedding provider that behaves like the OpenAI one, registered under the name

openai_like_embedder

. The complete code is in the appendix of this article.
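For context on what registering under a name means: the @provider("openai_like_embedder") decorator used in the appendix stores the class in a string-keyed registry so it can later be selected via an environment variable. A minimal illustrative registry (not Wren AI's actual loader) might look like this:

```python
from typing import Callable, Dict, Type

# Global registry mapping provider names to their classes.
PROVIDERS: Dict[str, Type] = {}


def provider(name: str) -> Callable[[Type], Type]:
    """Register the decorated class under `name`."""
    def decorator(cls: Type) -> Type:
        PROVIDERS[name] = cls
        return cls
    return decorator


@provider("openai_like_embedder")
class OpenAILikeEmbedderProvider:
    pass


# Setting EMBEDDER_PROVIDER=openai_like_embedder then resolves to the class:
print(PROVIDERS["openai_like_embedder"].__name__)  # OpenAILikeEmbedderProvider
```

This is why the file only needs to be importable inside the container for the new provider name to become available.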

Configure the Docker environment variables and start the services

First, enter the

docker

folder, copy

.env.example

and rename it to

.env.local

. Then copy

.env.ai.example

and rename it to

.env.ai

and modify its LLM and embedding settings; the relevant parts are:

LLM_PROVIDER=openai_llm
LLM_OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxx
LLM_OPENAI_API_BASE=http://api.siliconflow.cn/v1
GENERATION_MODEL=meta-llama/Meta-Llama-3-70B
# GENERATION_MODEL_KWARGS={"temperature": 0, "n": 1, "max_tokens": 32768, "response_format": {"type": "json_object"}}

EMBEDDER_PROVIDER=openai_like_embedder
EMBEDDING_MODEL=bge-m3
EMBEDDING_MODEL_DIMENSION=1024
EMBEDDER_OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxx
EMBEDDER_OPENAI_API_BASE=https://xxxxxxxxxxxxxxxx/v1
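As a sketch of how these variables are consumed, the provider in the appendix resolves EMBEDDING_MODEL_DIMENSION from the environment with a fallback default. A minimal standalone version of that logic (the default value mirrors the appendix's constant for text-embedding-3-large):

```python
import os

DEFAULT_DIMENSION = 3072  # fallback used in the appendix code


def resolve_embedding_dimension(default: int = DEFAULT_DIMENSION) -> int:
    """Return EMBEDDING_MODEL_DIMENSION from the environment, else the default."""
    raw = os.getenv("EMBEDDING_MODEL_DIMENSION")
    # Treat an unset or empty variable as "use the default".
    return int(raw) if raw else default


os.environ["EMBEDDING_MODEL_DIMENSION"] = "1024"
print(resolve_embedding_dimension())  # 1024
```

With the .env.ai values above, the provider would therefore run with a 1024-dimensional bge-m3 embedding model.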

Since we created a custom embedding provider, its file must be mapped into the Docker container. Do this by adding a

volumes

property to the

wren-ai-service

service in

docker-compose.yaml

:

wren-ai-service:
  image: ghcr.io/canner/wren-ai-service:${WREN_AI_SERVICE_VERSION}
  volumes:
    - /root/WrenAI/wren-ai-service/src:/src

Finally, start the services:

docker-compose -f docker-compose.yaml -f docker-compose.llm.yaml --env-file .env.local --env-file .env.ai up -d

Or stop them:

docker-compose -f docker-compose.yaml -f docker-compose.llm.yaml --env-file .env.local --env-file .env.ai down

Appendix

The

openai_like.py

file (which provides the custom embedding service):

import logging
import os
import sys
from typing import Any, Dict, List, Optional, Tuple

import backoff
import openai
from haystack import Document, component
from haystack.components.embedders import OpenAIDocumentEmbedder, OpenAITextEmbedder
from haystack.utils import Secret
from openai import AsyncOpenAI, OpenAI
from tqdm import tqdm

from src.core.provider import EmbedderProvider
from src.providers.loader import provider

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

logger = logging.getLogger("wren-ai-service")

EMBEDDER_OPENAI_API_BASE = "https://api.openai.com/v1"
EMBEDDING_MODEL = "text-embedding-3-large"
EMBEDDING_MODEL_DIMENSION = 3072


@component
class AsyncTextEmbedder(OpenAITextEmbedder):
    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),
        model: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
    ):
        super(AsyncTextEmbedder, self).__init__(
            api_key,
            model,
            dimensions,
            api_base_url,
            organization,
            prefix,
            suffix,
        )
        self.client = AsyncOpenAI(
            api_key=api_key.resolve_value(),
            organization=organization,
            base_url=api_base_url,
        )

    @component.output_types(embedding=List[float], meta=Dict[str, Any])
    @backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60, max_tries=3)
    async def run(self, text: str):
        if not isinstance(text, str):
            raise TypeError(
                "OpenAITextEmbedder expects a string as an input."
                "In case you want to embed a list of Documents, please use the OpenAIDocumentEmbedder."
            )

        logger.debug(f"Running Async OpenAI text embedder with text: {text}")

        text_to_embed = self.prefix + text + self.suffix

        # copied from OpenAI embedding_utils (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)
        # replace newlines, which can negatively affect performance.
        text_to_embed = text_to_embed.replace("\n", " ")

        # if self.dimensions is not None:
        #     response = await self.client.embeddings.create(
        #         model=self.model, dimensions=self.dimensions, input=text_to_embed
        #     )
        # else:
        response = await self.client.embeddings.create(
            model=self.model, input=text_to_embed
        )

        meta = {"model": response.model, "usage": dict(response.usage)}

        return {"embedding": response.data[0].embedding, "meta": meta}


@component
class AsyncDocumentEmbedder(OpenAIDocumentEmbedder):
    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),
        model: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        super(AsyncDocumentEmbedder, self).__init__(
            api_key,
            model,
            dimensions,
            api_base_url,
            organization,
            prefix,
            suffix,
            batch_size,
            progress_bar,
            meta_fields_to_embed,
            embedding_separator,
        )
        self.client = AsyncOpenAI(
            api_key=api_key.resolve_value(),
            organization=organization,
            base_url=api_base_url,
        )

    async def _embed_batch(
        self, texts_to_embed: List[str], batch_size: int
    ) -> Tuple[List[List[float]], Dict[str, Any]]:
        all_embeddings = []
        meta: Dict[str, Any] = {}
        for i in tqdm(
            range(0, len(texts_to_embed), batch_size),
            disable=not self.progress_bar,
            desc="Calculating embeddings",
        ):
            batch = texts_to_embed[i : i + batch_size]
            # if self.dimensions is not None:
            #     response = await self.client.embeddings.create(
            #         model=self.model, dimensions=self.dimensions, input=batch
            #     )
            # else:
            response = await self.client.embeddings.create(
                model=self.model, input=batch
            )
            embeddings = [el.embedding for el in response.data]
            all_embeddings.extend(embeddings)

            if "model" not in meta:
                meta["model"] = response.model
            if "usage" not in meta:
                meta["usage"] = dict(response.usage)
            else:
                meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
                meta["usage"]["total_tokens"] += response.usage.total_tokens

        return all_embeddings, meta

    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    @backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60, max_tries=3)
    async def run(self, documents: List[Document]):
        if (
            not isinstance(documents, list)
            or documents
            and not isinstance(documents[0], Document)
        ):
            raise TypeError(
                "OpenAIDocumentEmbedder expects a list of Documents as input."
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

        logger.debug(f"Running Async OpenAI document embedder with documents: {documents}")

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings, meta = await self._embed_batch(
            texts_to_embed=texts_to_embed, batch_size=self.batch_size
        )

        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents, "meta": meta}


@provider("openai_like_embedder")
class OpenAIEmbedderProvider(EmbedderProvider):
    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),
        api_base: str = os.getenv("EMBEDDER_OPENAI_API_BASE") or EMBEDDER_OPENAI_API_BASE,
        embedding_model: str = os.getenv("EMBEDDING_MODEL") or EMBEDDING_MODEL,
        embedding_model_dim: int = (
            int(os.getenv("EMBEDDING_MODEL_DIMENSION"))
            if os.getenv("EMBEDDING_MODEL_DIMENSION")
            else 0
        )
        or EMBEDDING_MODEL_DIMENSION,
    ):
        def _verify_api_key(api_key: str, api_base: str) -> None:
            """
            this is a temporary solution to verify that the required environment variables are set
            """
            OpenAI(api_key=api_key, base_url=api_base).models.list()

        logger.info(f"Initializing OpenAIEmbedder provider with API base: {api_base}")

        # TODO: currently only OpenAI api key can be verified
        if api_base == EMBEDDER_OPENAI_API_BASE:
            _verify_api_key(api_key.resolve_value(), api_base)
            logger.info(f"Using OpenAI Embedding Model: {embedding_model}")
        else:
            logger.info(f"Using OpenAI API-compatible Embedding Model: {embedding_model}")

        self._api_key = api_key
        self._api_base = api_base
        self._embedding_model = embedding_model
        self._embedding_model_dim = embedding_model_dim

    def get_text_embedder(self):
        return AsyncTextEmbedder(
            api_key=self._api_key,
            api_base_url=self._api_base,
            model=self._embedding_model,
            dimensions=self._embedding_model_dim,
        )

    def get_document_embedder(self):
        return AsyncDocumentEmbedder(
            api_key=self._api_key,
            api_base_url=self._api_base,
            model=self._embedding_model,
            dimensions=self._embedding_model_dim,
        )

Reprinted from: https://blog.csdn.net/shizidushu/article/details/140449466
Copyright belongs to the original author, shizidushu. In case of infringement, please contact us for removal.
