0


AI大模型应用(3)开源框架Vanna: 利用RAG方法做Text2SQL任务

AI大模型应用(3)开源框架Vanna: 利用RAG方法做Text2SQL任务

1 Vanna概述

1.1 Vanna的工作原理

  • vanna是基于检索增强(RAG)的sql生成框架,具体的执行逻辑如下:- 先用向量数据库将待查询数据库的建表语句、文档、常用SQL及其自然语言查询问题存储起来。- 用户发起查询请求时,会先从向量数据库中检索出相关的建表语句、文档、SQL问答对放入到prompt里(DDL和文档作为上下文、SQL问答对作为few-shot样例)- LLM根据prompt生成查询SQL并执行,框架会进一步将查询结果使用plotly可视化出来或用LLM生成后续问题。- 如果用户反馈LLM生成的结果是正确的,可以将这一问答对存储到向量数据库,可以使得以后的生成结果更准确在这里插入图片描述
  • Vanna 的工作过程可以概括为:- 在用户的数据上训练 RAG模型,然后提出问题,这些问题将返回 SQL 查询,这些查询可以设置为在用户的数据库上自动运行。- 这里的"训练"是指:根据数据结构构建向量库。- 用户可以使用 DDL 语句、文档或样例 SQL 查询对 Vanna 进行训练,让它掌握数据库的结构、业务术语和查询模式。- Vanna 会将训练数据转化为向量嵌入,存储在向量数据库中,并建立元数据索引,以便于后续检索。
  • Vanna的优缺点:- Vanna最大的优点就是:用户可以选择在成功执行的查询上“自动训练”,或让界面提示用户对结果提供反馈,使未来的结果更加准确。- 缺点也很明显: - 生成的 SQL 查询可能不完全准确,需要人工干预来修正。- 复杂查询生成能力有限,这也是Text2SQL场景的挑战了,尤其是涉及到多表查询。

1.2 Vanna的快速上手

1.2.1 利用OPENAI提供的API_KEY以及ChromaDB向量数据库

"""
Vanna的源码安装:
github地址:https://github.com/vanna-ai/vanna
拉取vanna的源码,使用下面命令安装pyproject.toml中定义的依赖项:
pip install .

我这里使用国内的代理,使用代理网站提供的url和key,使用起来和原生的区别不大。
代理网站如下:
https://api.zetatechs.com/
"""import os
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.flask import VannaFlaskApp

api_url =str(os.getenv('OPENAI_URL'))
api_key =str(os.getenv('OPENAI_API_KEY'))classMyVanna(ChromaDB_VectorStore, OpenAI_Chat):def__init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)"""
我这里由于使用的代理网站提供的url和api_key,需要在vanna.openai.openai_chat.py中修改:
修改前:
if "api_key" in config:
    self.client = OpenAI(api_key=config["api_key"])

修改后:
if "api_key" in config:
    self.client = OpenAI(api_key=config["api_key"], base_url=config['base_url'])
"""
vn = MyVanna(config={'api_key': api_key,'model':'gpt-3.5-turbo','base_url': api_url})# 链接本地的Mysql数据库
vn.connect_to_mysql(host='localhost', dbname='test', user='root', password='root', port=3306)# 训练Vanna,构建知识库
vn.train(ddl="""
CREATE TABLE `goods` (
  `id` int(10) unsigned NOT NULL AUTO_INCREMENT,
  `name` varchar(150) NOT NULL,
  `cate_name` varchar(40) NOT NULL,
  `brand_name` varchar(40) NOT NULL,
  `price` decimal(10,3) NOT NULL DEFAULT '0.000',
  `is_show` bit(1) NOT NULL DEFAULT b'1',
  `is_saleoff` bit(1) NOT NULL DEFAULT b'0',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=22 DEFAULT CHARSET=utf8;
""")

vn.train(
    documentation="""
    goods表中的字段cate_name为电脑类型,包括:笔记本、游戏本、超极本、平板电脑、台式机、服务器/工作站、笔记本配件
    goods表中的字段brand_name为品牌名字,包括:华硕、联想、索尼、戴尔、苹果等
    goods表中的字段name为电子产品具体型号,例如:ipad air 9.7英寸平板电脑
    """)

vn.train(question="华硕品牌的笔记本的平均价格是多少?", sql="SELECT AVG(price) AS avg_price FROM goods WHERE brand_name = '华硕' AND cate_name = '笔记本';")# 访问地址: http://localhost:8084
VannaFlaskApp(vn).run()

在TrainingData中就会出现代码中训练数据:

在这里插入图片描述

我们就可以利用自然语言进行查询了:

在这里插入图片描述

如果生成的SQL是正确的,我们就可以点击下面的按钮,将此条

Question-SQL pair

添加到知识库中:

在这里插入图片描述

右上角的

Open Debugger

中可以看到如下信息:

  • 包括生成SQL的提示词,大模型的回复,以及提取到最终执行的SQL

在这里插入图片描述

1.2.2 利用fastchat部署本地模型

"""
# fastchat的安装
pip install "fschat[model_worker,webui]"

# 1、启动controller
python -m fastchat.serve.controller  --host 0.0.0.0 --port 21001

# 2、启动worker(这里采用小模型,可以CPU运行)
python -m fastchat.serve.model_worker --model-names qwen2-1.5b --model-path D:\python\models\qwen\Qwen2-1.5B-Instruct
--host 0.0.0.0 --device cpu

# 3、启动openai_api_server
# 兼容OpenAI的RESTful API
python -m fastchat.serve.openai_api_server --controller-address http://127.0.0.1:21001 --host 0.0.0.0 --port 48000

# 4、测试OpenAI的RESTful API
​```python
from openai import OpenAI

api_key = 'none'
api_url = 'http://127.0.0.1:48000/v1'

client = OpenAI(base_url=api_url, api_key=api_key)

completion = client.chat.completions.create(
  model="qwen2-1.5b",
  messages=[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "你是谁?"}
  ]
)
print(completion)
​```
"""from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.flask import VannaFlaskApp

api_url ='http://127.0.0.1:48000/v1'
api_key ='none'classMyVanna(ChromaDB_VectorStore, OpenAI_Chat):def__init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

vn = MyVanna(config={'api_key': api_key,'model':'qwen2-1.5b','base_url': api_url})# 链接数据库
vn.connect_to_mysql(host='localhost', dbname='test', user='root', password='root', port=3306)# 训练Vanna,构建知识库
vn.train(ddl="""
CREATE TABLE `goods` (
  `id` int(10) unsigned NOT NULL AUTO_INCREMENT,
  `name` varchar(150) NOT NULL,
  `cate_name` varchar(40) NOT NULL,
  `brand_name` varchar(40) NOT NULL,
  `price` decimal(10,3) NOT NULL DEFAULT '0.000',
  `is_show` bit(1) NOT NULL DEFAULT b'1',
  `is_saleoff` bit(1) NOT NULL DEFAULT b'0',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=22 DEFAULT CHARSET=utf8;
""")

vn.train(
    documentation="""
    goods表中的字段cate_name为电脑类型,包括:笔记本、游戏本、超极本、平板电脑、台式机、服务器/工作站、笔记本配件
    goods表中的字段brand_name为品牌名字,包括:华硕、联想、索尼、戴尔、苹果等
    goods表中的字段name为电子产品具体型号,例如:ipad air 9.7英寸平板电脑
    """)

vn.train(question="华硕品牌的笔记本的平均价格是多少?", sql="SELECT AVG(price) AS avg_price FROM goods WHERE brand_name = '华硕' AND cate_name = '笔记本';")# 访问地址: http://localhost:8084
VannaFlaskApp(vn).run()

2 Vanna源码分析

我们可以利用下面代码,查看源码:

from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore

api_url ='http://127.0.0.1:48000/v1'
api_key ='none'classMyVanna(ChromaDB_VectorStore, OpenAI_Chat):def__init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

vn = MyVanna(config={'api_key': api_key,'model':'qwen2-1.5b','base_url': api_url})# 链接数据库
vn.connect_to_mysql(host='localhost', dbname='test', user='root', password='root', port=3306)# 训练Vanna,构建知识库
vn.train(ddl="""
CREATE TABLE `goods` (
  `id` int(10) unsigned NOT NULL AUTO_INCREMENT,
  `name` varchar(150) NOT NULL,
  `cate_name` varchar(40) NOT NULL,
  `brand_name` varchar(40) NOT NULL,
  `price` decimal(10,3) NOT NULL DEFAULT '0.000',
  `is_show` bit(1) NOT NULL DEFAULT b'1',
  `is_saleoff` bit(1) NOT NULL DEFAULT b'0',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=22 DEFAULT CHARSET=utf8;
""")

vn.train(
    documentation="""
    goods表中的字段cate_name为电脑类型,包括:笔记本、游戏本、超极本、平板电脑、台式机、服务器/工作站、笔记本配件
    goods表中的字段brand_name为品牌名字,包括:华硕、联想、索尼、戴尔、苹果等
    goods表中的字段name为电子产品具体型号,例如:ipad air 9.7英寸平板电脑
    """)

vn.train(question="华硕品牌的笔记本的平均价格是多少?", sql="SELECT AVG(price) AS avg_price FROM goods WHERE brand_name = '华硕' AND cate_name = '笔记本';")# 用户提问
sql, df, fig = vn.ask("华硕品牌的笔记本的最低价格、最高价格分别是多少?")print('=======================================')print('the final sql = \n',  sql)print('the final df = \n', df)

Vanna的核心代码就是

src.vanna.base.base.py

文件下的train函数和ask函数

2.1 train函数

  • train函数就是根据documentation、sql以及ddl构建知识库
# src.vanna.base.base.pydeftrain(
        self,
        question:str=None,
        sql:str=None,
        ddl:str=None,
        documentation:str=None,
        plan: TrainingPlan =None,)->str:if question andnot sql:raise ValidationError("Please also provide a SQL query")if documentation:print("Adding documentation....")return self.add_documentation(documentation)if sql:if question isNone:
                question = self.generate_question(sql)print("Question generated with sql:", question,"\nAdding SQL...")return self.add_question_sql(question=question, sql=sql)if ddl:print("Adding ddl:", ddl)return self.add_ddl(ddl)......
  • 函数中add_documentation、add_ddl以及add_question_sql均为抽象函数,不同的向量数据库有不同的实现方式
  • 比如,这里使用的chromadb的实现方式如下:
# src.vanna.chromadb.chromadb_vector.pyclassChromaDB_VectorStore(VannaBase):def__init__(self, config=None):
        VannaBase.__init__(self, config=config)if config isNone:
            config ={}......# 创建三个集合,分别存储:文档、DDL、以及sql 三种经过embedding的知识库信息
        self.documentation_collection = self.chroma_client.get_or_create_collection(
            name="documentation",
            embedding_function=self.embedding_function,
            metadata=collection_metadata,)
        self.ddl_collection = self.chroma_client.get_or_create_collection(
            name="ddl",
            embedding_function=self.embedding_function,
            metadata=collection_metadata,)
        self.sql_collection = self.chroma_client.get_or_create_collection(
            name="sql",
            embedding_function=self.embedding_function,
            metadata=collection_metadata,)defgenerate_embedding(self, data:str,**kwargs)-> List[float]:
        embedding = self.embedding_function([data])iflen(embedding)==1:return embedding[0]return embedding

    defadd_question_sql(self, question:str, sql:str,**kwargs)->str:
        question_sql_json = json.dumps({"question": question,"sql": sql,},
            ensure_ascii=False,)id= deterministic_uuid(question_sql_json)+"-sql"
        self.sql_collection.add(
            documents=question_sql_json,
            embeddings=self.generate_embedding(question_sql_json),
            ids=id,)returniddefadd_ddl(self, ddl:str,**kwargs)->str:id= deterministic_uuid(ddl)+"-ddl"
        self.ddl_collection.add(
            documents=ddl,
            embeddings=self.generate_embedding(ddl),
            ids=id,)returniddefadd_documentation(self, documentation:str,**kwargs)->str:id= deterministic_uuid(documentation)+"-doc"
        self.documentation_collection.add(
            documents=documentation,
            embeddings=self.generate_embedding(documentation),
            ids=id,)returnid......

2.2 ask函数

defask(
        self,
        question: Union[str,None]=None,
        print_results:bool=True,
        auto_train:bool=True,
        visualize:bool=True,# if False, will not generate plotly code
        allow_llm_to_see_data:bool=False,)-> Union[
        Tuple[
            Union[str,None],
            Union[pd.DataFrame,None],
            Union[plotly.graph_objs.Figure,None],],None,]:"""
        **Example:**
        ```python
        vn.ask("What are the top 10 customers by sales?")
        ```

        Ask Vanna.AI a question and get the SQL query that answers it.

        Args:
            question (str): The question to ask.
                               提出的问题
            print_results (bool): Whether to print the results of the SQL query.
                               是否打印SQL查询的结果, 默认为True
            auto_train (bool): Whether to automatically train Vanna.AI on the question and SQL query.
                               是否自动使用问题和SQL查询对Vanna.AI进行训练, 默认为True
            visualize (bool): Whether to generate plotly code and display the plotly figure.
                               是否生成plotly代码并显示plotly图表, 默认为True

        Returns:
            Tuple[str, pd.DataFrame, plotly.graph_objs.Figure]:
                  The SQL query, the results of the SQL query, and the plotly figure.
                  包含SQL查询语句、SQL查询的结果(以pandas DataFrame形式)以及plotly图表对象的元组
        """if question isNone:
            question =input("Enter a question: ")try:# 1、根据用户的question产生SQL
            sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)except Exception as e:print(e)returnNone,None,None......try:# 2、在相应的数据库中执行SQL语句,获取结果# 这里支持不同的数据库的查询语句,如:sqlite、mysql、oracle、hive、clickhouse等
            df = self.run_sql(sql)if print_results:try:
                    display =__import__("IPython.display", fromList=["display"]).display
                    display(df)except Exception as e:print(df)iflen(df)>0and auto_train:
                self.add_question_sql(question=question, sql=sql)# 3、对查询的结果进行可视化# Only generate plotly code if visualize is Trueif visualize:try:
                    plotly_code = self.generate_plotly_code(
                        question=question,
                        sql=sql,
                        df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",)
                    fig = self.get_plotly_figure(plotly_code=plotly_code, df=df)if print_results:try:
                            display =__import__("IPython.display", fromlist=["display"]).display
                            Image =__import__("IPython.display", fromlist=["Image"]).Image
                            img_bytes = fig.to_image(format="png", scale=2)
                            display(Image(img_bytes))except Exception as e:
                            fig.show()except Exception as e:# Print stack trace
                    traceback.print_exc()print("Couldn't run plotly code: ", e)if print_results:returnNoneelse:return sql, df,Noneelse:return sql, df,Noneexcept Exception as e:print("Couldn't run sql: ", e)if print_results:returnNoneelse:return sql,None,Nonereturn sql, df, fig
  • 用户发起查询请求时,会先从向量数据库中检索出相关的建表语句、文档、SQL问答对放入到prompt里(DDL和文档作为上下文、SQL问答对作为few-shot样例)
defgenerate_sql(self, question:str, allow_llm_to_see_data=False,**kwargs)->str:"""
        Example:
        ```python
        vn.generate_sql("What are the top 10 customers by sales?")
        ```

        Uses the LLM to generate a SQL query that answers a question. It runs the following methods:
        该函数使用大语言模型(LLM)生成一个能够回答特定问题的SQL查询。它按顺序执行以下方法:
          1、获取与输入问题相似的SQL查询
          2、获取与问题相关的数据定义语言(DDL)语句
          3、获取与问题相关的文档
          4、生成用于提交给LLM的SQL查询prompt
          5、将提示提交给LLM并获取生成的SQL查询

        Args:
            question (str): The question to generate a SQL query for.
            allow_llm_to_see_data (bool):
                  Whether to allow the LLM to see the data (for the purposes of introspecting the data to generate the final SQL).
                  是否允许大型语言模型(LLM)查看数据,以便更好地理解数据结构并生成相应的SQL查询

        Returns:
            str: The SQL query that answers the question.
        """if self.config isnotNone:
            initial_prompt = self.config.get("initial_prompt",None)else:
            initial_prompt =None# 1、获取与输入问题相似的SQL查询,默认最多返回10条数据
        question_sql_list = self.get_similar_question_sql(question,**kwargs)# 2、获取与问题相关的DDL语句,默认最多返回10条数据
        ddl_list = self.get_related_ddl(question,**kwargs)# 3、获取与问题相关的文档
        doc_list = self.get_related_documentation(question,**kwargs)# 4、生成用于提交给LLM的SQL查询prompt
        prompt = self.get_sql_prompt(
            initial_prompt=initial_prompt,
            question=question,
            question_sql_list=question_sql_list,
            ddl_list=ddl_list,
            doc_list=doc_list,**kwargs,)
        self.log(title="SQL Prompt", message=prompt)......

函数中get_similar_question_sql均为抽象函数,不同的向量数据库有不同的实现方式

比如,这里使用的chromadb的实现方式如下:

# src.vanna.chromadb.chromadb_vector.pyclassChromaDB_VectorStore(VannaBase):......defget_similar_question_sql(self, question:str,**kwargs)->list:return ChromaDB_VectorStore._extract_documents(
            self.sql_collection.query(
                query_texts=[question],
                n_results=self.n_results_sql,# 默认为10))defget_related_ddl(self, question:str,**kwargs)->list:return ChromaDB_VectorStore._extract_documents(
            self.ddl_collection.query(
                query_texts=[question],
                n_results=self.n_results_ddl,# 默认为10))defget_related_documentation(self, question:str,**kwargs)->list:return ChromaDB_VectorStore._extract_documents(
            self.documentation_collection.query(
                query_texts=[question],
                n_results=self.n_results_documentation,# 默认为10))

我们看下,最终的prompt:

DDL和文档作为上下文、SQL问答对作为few-shot样例
[[{'role':'system', 'content':"""
You are a SQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. 
===Tables 

CREATE TABLE `goods`(`id` int(10) unsigned NOT NULL AUTO_INCREMENT,
  `name` varchar(150) NOT NULL,
  `cate_name` varchar(40) NOT NULL,
  `brand_name` varchar(40) NOT NULL,
  `price` decimal(10,3) NOT NULL DEFAULT '0.000',
  `is_show` bit(1) NOT NULL DEFAULT b'1',
  `is_saleoff` bit(1) NOT NULL DEFAULT b'0',
  PRIMARY KEY (`id`))ENGINE=InnoDB AUTO_INCREMENT=22 DEFAULT CHARSET=utf8;===Additional Context 

    goods表中的字段cate_name为电脑类型,包括:笔记本、游戏本、超极本、平板电脑、台式机、服务器/工作站、笔记本配件
    goods表中的字段brand_name为品牌名字,包括:华硕、联想、索尼、戴尔、苹果等
    goods表中的字段name为电子产品具体型号,例如:ipad air 9.7英寸平板电脑
    

===Response Guidelines 
1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. 
2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql 
3. If the provided context is insufficient, please explain why it can't be generated. 
4. Please use the most relevant table(s). 
5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. 
6. Ensure that the output SQL is SQL-compliant and executable, and free of syntax errors. 
  """
  , {"role":"user","content":"华硕品牌的笔记本的平均价格是多少?"}
  , {"role":"assistant","content":"SELECT AVG(price) AS avg_price\nFROM goods\nWHERE brand_name = '华硕' AND cate_name = '笔记本';"}
  , {"role":"user","content":"华硕品牌的笔记本的最低价格、最高价格分别是多少?"}]
  • LLM根据prompt生成查询SQL
defgenerate_sql(self, question:str, allow_llm_to_see_data=False,**kwargs)->str:......# 5、将提示提交给LLM并获取生成的SQL查询# submit_prompt为抽象方法,这里实现方法在OpenAI_Chat中
        llm_response = self.submit_prompt(prompt,**kwargs)
        self.log(title="LLM Response", message=llm_response)if'intermediate_sql'in llm_response:......# 6、提取最终的SQLreturn self.extract_sql(llm_response)
  • 然后会执行查询SQL(支持的数据库如下所示,可以参考源码,这里不再赘述),框架会进一步将查询结果使用plotly可视化出来或用LLM生成后续问题。

在这里插入图片描述


本文转载自: https://blog.csdn.net/qq_44665283/article/details/143213115
版权归原作者 undo_try 所有, 如有侵权,请联系我们删除。

“AI大模型应用(3)开源框架Vanna: 利用RAG方法做Text2SQL任务”的评论:

还没有评论