Using Large Language Models (LLMs) for Text-to-SQL Tasks - Langchain SQL Chain Explained

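The previous post, Building a Data Q&A System with Langchain SQL Chain, showed that Langchain's SQL query chain makes it quick to build a data question-answering feature. Langchain's modules are convenient, but they also hide many details of how the large language model is actually called. In this post we read the source code to look at a few of those hidden details.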

The prompt used to interact with the LLM

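As shown in the post Using Large Language Models (LLMs) for Text-to-SQL Tasks, turning an input question into a SQL query requires preparing a prompt for the LLM. If you have written your own prompt, you can pass it to create_sql_query_chain through the prompt parameter. If you have not, Langchain falls back to the default prompts defined in its sql_database module.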

from langchain.chains.sql_database.prompt import SQL_PROMPTS
SQL_PROMPTS[db.dialect].pretty_print() # sqlite

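When create_sql_query_chain is called without a prompt, it reads the dialect from the SQLDatabase and looks up the matching prompt in Langchain's sql_database module. Langchain currently ships prompts for 11 databases, and each one is worded differently, most notably in how it tells the model to get the current date.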
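To make the current-date difference concrete: SQLite writes the current date as date('now'), whereas MySQL, for instance, uses CURDATE() and PostgreSQL uses CURRENT_DATE. The minimal sketch below only checks the SQLite expression, using Python's standard sqlite3 module and an in-memory database. It illustrates the SQL function itself and is not part of Langchain.

```python
import sqlite3

# In-memory database; no Langchain involved, only the SQLite date function.
conn = sqlite3.connect(":memory:")
today = conn.execute("SELECT date('now')").fetchone()[0]
conn.close()

print(today)  # an ISO date string (UTC) in the form YYYY-MM-DD
```

Because the expression differs per database, a prompt written for one dialect can lead the model to generate a date function that another database rejects.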

Taking SQLite as the example, the prompt below contains three variables that are filled in at run time: top_k, table_info, and input.

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
{table_info}

Question: {input}
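As a rough illustration of how the three placeholders get filled, the sketch below uses plain str.format on a shortened copy of the template. This is only a stand-in for Langchain's prompt template machinery: the TEMPLATE text is abbreviated, and the values passed in are sample values of the kind described in the sections that follow.

```python
# Shortened stand-in for the SQLite template. The placeholders are the same
# three names the real prompt uses: {top_k}, {table_info}, {input}.
TEMPLATE = (
    "Unless the user specifies in the question a specific number of examples "
    "to obtain, query for at most {top_k} results using the LIMIT clause.\n\n"
    "Only use the following tables:\n"
    "{table_info}\n\n"
    "Question: {input}"
)

filled = TEMPLATE.format(
    top_k=5,  # sample value for the row limit
    table_info='CREATE TABLE "Employee" ("EmployeeId" INTEGER NOT NULL)',
    input="How many employees are there\nSQLQuery: ",
)
print(filled)
```

The filled text ends with "SQLQuery: ", which is what nudges the model to continue with a SQL statement.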

Variables substituted into the prompt

top_k
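# langchain/langchain/chains/sql_database/query.py
# (excerpt: the rest of the signature and the function body are omitted)
def create_sql_query_chain(
    llm: BaseLanguageModel,
    db: SQLDatabase,
    prompt: Optional[BasePromptTemplate] = None,
    k: int = 5,

# k fills the {top_k} placeholder, so top_k = 5 by default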

top_k comes from k, an input parameter of create_sql_query_chain itself. If it is not specified, it defaults to 5.

input
# langchain/langchain/chains/sql_database/query.py
inputs = {
    "input": lambda x: x["question"] + "\nSQLQuery: ",
    "table_info": lambda x: db.get_table_info(
        table_names=x.get("table_names_to_use")
    ),
}
# chain.invoke({"question": "How many employees are there"})

# input = How many employees are there
# SQLQuery:

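input is built by taking question from the input dict, which is the question passed to invoke, and then appending "SQLQuery: " after it so that the LLM continues by generating the SQL query.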
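To see what these lambdas produce without calling an LLM, the sketch below evaluates the same mapping by hand. FakeDatabase is a hypothetical stand-in for SQLDatabase, introduced only for this illustration, and the dictionary comprehension replaces the runnable machinery Langchain uses to apply the mapping.

```python
class FakeDatabase:
    """Hypothetical stand-in for SQLDatabase, just enough for this demo."""

    def get_table_info(self, table_names=None):
        # The real method returns CREATE TABLE statements plus sample rows.
        return f"<schema for {table_names or 'all tables'}>"


db = FakeDatabase()

# Same shape as the `inputs` mapping in query.py.
inputs = {
    "input": lambda x: x["question"] + "\nSQLQuery: ",
    "table_info": lambda x: db.get_table_info(
        table_names=x.get("table_names_to_use")
    ),
}

payload = {"question": "How many employees are there"}
# Evaluate each lambda against the payload, as the chain would at invoke time.
resolved = {key: fn(payload) for key, fn in inputs.items()}

print(repr(resolved["input"]))
print(resolved["table_info"])
```

Since the payload has no table_names_to_use key, x.get returns None and the stand-in reports that all tables would be described.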

table_info

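table_info calls the get_table_info function of SQLDatabase, which looks up the schema of the tables in the database and adds three sample rows for each. The chain input can also carry a list, table_names_to_use, which is passed on as table_names to list explicitly which tables to look up. Note that if no tables are specified, get_table_info returns information for every table.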

db.get_table_info(table_names=["Employee"])

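Calling the function with a table name returns the information below. The goal is to let the LLM see the table name, the column list, and some sample data so that it can translate the question into SQL more accurately.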

CREATE TABLE "Employee" (
"EmployeeId" INTEGER NOT NULL,
"LastName" NVARCHAR(20) NOT NULL,
"FirstName" NVARCHAR(20) NOT NULL,
"Title" NVARCHAR(30),
"ReportsTo" INTEGER,
"BirthDate" DATETIME,
"HireDate" DATETIME,
"Address" NVARCHAR(70),
"City" NVARCHAR(40),
"State" NVARCHAR(40),
"Country" NVARCHAR(40),
"PostalCode" NVARCHAR(10),
"Phone" NVARCHAR(24),
"Fax" NVARCHAR(24),
"Email" NVARCHAR(60),
PRIMARY KEY ("EmployeeId"),
FOREIGN KEY("ReportsTo") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Employee table:
EmployeeId LastName FirstName Title ReportsTo BirthDate HireDate Address City State Country PostalCode Phone Fax Email
1 Adams Andrew General Manager None 1962-02-18 00:00:00 2002-08-14 00:00:00 11120 Jasper Ave NW Edmonton AB Canada T5K 2N1 +1 (780) 428-9482 +1 (780) 428-3457 andrew@chinookcorp.com
2 Edwards Nancy Sales Manager 1 1958-12-08 00:00:00 2002-05-01 00:00:00 825 8 Ave SW Calgary AB Canada T2P 2T3 +1 (403) 262-3443 +1 (403) 262-3322 nancy@chinookcorp.com
3 Peacock Jane Sales Support Agent 2 1973-08-29 00:00:00 2002-04-01 00:00:00 1111 6 Ave SW Calgary AB Canada T2P 5M5 +1 (403) 262-3443 +1 (403) 262-6712 jane@chinookcorp.com
*/
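The output above can be approximated with nothing but Python's standard sqlite3 module. The sketch below is not Langchain's implementation: the table_info helper and the tiny two-column Employee table are made up for illustration. It only shows the general idea of combining the stored CREATE TABLE text with a few sample rows, and of rejecting table names that do not exist (the error text mirrors the message mentioned later in this post).

```python
import sqlite3


def table_info(conn, table_names, sample_rows=3):
    """Approximate get_table_info: CREATE TABLE text plus a few sample rows."""
    known = {
        row[0]
        for row in conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
    }
    missing = set(table_names) - known
    if missing:
        raise ValueError(f"table_names {missing} not found in database")

    parts = []
    for name in table_names:
        # SQLite keeps the original CREATE TABLE statement in sqlite_master.
        ddl = conn.execute(
            "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?",
            (name,),
        ).fetchone()[0]
        # `name` was validated against sqlite_master above, so quoting it is safe.
        cursor = conn.execute(f'SELECT * FROM "{name}" LIMIT ?', (sample_rows,))
        header = "\t".join(col[0] for col in cursor.description)
        rows = ["\t".join(str(v) for v in row) for row in cursor.fetchall()]
        sample = "\n".join([f"{sample_rows} rows from {name} table:", header, *rows])
        parts.append(f"{ddl}\n\n/*\n{sample}\n*/")
    return "\n\n".join(parts)


conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE "Employee" ("EmployeeId" INTEGER NOT NULL, "LastName" NVARCHAR(20))'
)
conn.executemany(
    'INSERT INTO "Employee" VALUES (?, ?)',
    [(1, "Adams"), (2, "Edwards"), (3, "Peacock"), (4, "Park")],
)
info = table_info(conn, ["Employee"])
print(info)
```

Only three of the four inserted rows appear in the result, which keeps the prompt short while still showing the model what the data looks like.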

The complete prompt

Putting the prompt template together with the three substituted variables top_k, table_info, and input, the complete prompt finally sent to the LLM looks like this:

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
CREATE TABLE "Employee" (
"EmployeeId" INTEGER NOT NULL,
"LastName" NVARCHAR(20) NOT NULL,
"FirstName" NVARCHAR(20) NOT NULL,
"Title" NVARCHAR(30),
"ReportsTo" INTEGER,
"BirthDate" DATETIME,
"HireDate" DATETIME,
"Address" NVARCHAR(70),
"City" NVARCHAR(40),
"State" NVARCHAR(40),
"Country" NVARCHAR(40),
"PostalCode" NVARCHAR(10),
"Phone" NVARCHAR(24),
"Fax" NVARCHAR(24),
"Email" NVARCHAR(60),
PRIMARY KEY ("EmployeeId"),
FOREIGN KEY("ReportsTo") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Employee table:
EmployeeId LastName FirstName Title ReportsTo BirthDate HireDate Address City State Country PostalCode Phone Fax Email
1 Adams Andrew General Manager None 1962-02-18 00:00:00 2002-08-14 00:00:00 11120 Jasper Ave NW Edmonton AB Canada T5K 2N1 +1 (780) 428-9482 +1 (780) 428-3457 andrew@chinookcorp.com
2 Edwards Nancy Sales Manager 1 1958-12-08 00:00:00 2002-05-01 00:00:00 825 8 Ave SW Calgary AB Canada T2P 2T3 +1 (403) 262-3443 +1 (403) 262-3322 nancy@chinookcorp.com
3 Peacock Jane Sales Support Agent 2 1973-08-29 00:00:00 2002-04-01 00:00:00 1111 6 Ave SW Calgary AB Canada T2P 5M5 +1 (403) 262-3443 +1 (403) 262-6712 jane@chinookcorp.com
*/

Question: How many employees are there
SQLQuery:

To pass a list of tables as table_names_to_use, you can assign it with RunnablePassthrough. Note that if a table does not exist in the database, the call fails with a "table_names not found in database" error.

from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# llm, db, and parse_sql are assumed to be defined already (they are not defined in this post).
execute_query = QuerySQLDataBaseTool(db=db, verbose=True)
write_query = create_sql_query_chain(llm, db)

chain = (
    # Add table_names_to_use to the input so only the Employee table is described.
    RunnablePassthrough.assign(table_names_to_use=RunnableLambda(lambda x: ["Employee"]))
    | write_query
    # Clean the raw LLM output down to a SQL string before executing it.
    | RunnableLambda(parse_sql)
    | execute_query
)

chain.invoke({"question": "How many employees are there"})
[(8,)]
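The chain above calls parse_sql, which is not defined in this post. A hypothetical version might look like the sketch below. It assumes the model may echo the "SQLQuery:" label, wrap the query in a Markdown code fence, or keep generating a "SQLResult:" line after the query. None of this is taken from Langchain, and the exact cleanup needed depends on the model.

```python
import re

FENCE = "`" * 3  # Markdown code fence marker, built here to avoid writing it literally


def parse_sql(text: str) -> str:
    """Hypothetical cleaner: pull a bare SQL statement out of raw LLM output."""
    # Drop a leading "SQLQuery:" label if the model echoed it.
    text = re.sub(r"^\s*SQLQuery:\s*", "", text)
    # If the model wrapped the query in a code fence, keep only its body.
    fenced = re.search(
        FENCE + r"(?:sql)?\s*(.*?)" + FENCE, text, flags=re.DOTALL | re.IGNORECASE
    )
    if fenced:
        text = fenced.group(1)
    # Cut off anything the model generated past the query itself.
    text = text.split("SQLResult:")[0]
    return text.strip()


raw = 'SQLQuery: SELECT COUNT(*) FROM "Employee"\nSQLResult: 8'
print(parse_sql(raw))
```

Keeping this step as a separate function makes it easy to adjust when a different model formats its answers differently.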