Skip to content

Commit

Permalink
tool: sql_database generate sql-query filter redundant text (#612)
Browse files Browse the repository at this point in the history
* tool: sql_database generate sql-query filter redundant text
  • Loading branch information
devinyf authored Mar 13, 2024
1 parent b4e6529 commit ca2969c
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 1 deletion.
53 changes: 52 additions & 1 deletion chains/sql_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,12 @@ func (s SQLDatabaseChain) Call(ctx context.Context, inputs map[string]any, optio
if err != nil {
return nil, err
}
sqlQuery := strings.TrimSpace(out)

sqlQuery := extractSQLQuery(out)

if sqlQuery == "" {
return nil, fmt.Errorf("no sql query generated")
}

// Execute sql query
queryResult, err := s.Database.Query(ctx, sqlQuery)
Expand Down Expand Up @@ -148,3 +153,49 @@ func (s SQLDatabaseChain) GetInputKeys() []string {
func (s SQLDatabaseChain) GetOutputKeys() []string {
return []string{s.OutputKey}
}

// sometimes llm model returned result is not only the SQLQuery,
// it also contains some extra text,
// which will cause the entire process to fail.
// this function is used to extract the exact SQLQuery from the result.
// nolint:cyclop
func extractSQLQuery(rawOut string) string {
outStrings := strings.Split(rawOut, "\n")

var sqlQuery string
containSQLQuery := strings.Contains(rawOut, "SQLQuery:")
findSQLQuery := false

for _, v := range outStrings {
line := strings.TrimSpace(v)

// filter empty line and markdown symbols
if line == "" || strings.HasPrefix(line, "```") {
continue
}

// stop when we find SQLResult: or Answer:
if strings.HasPrefix(line, "SQLResult:") || strings.HasPrefix(line, "Answer:") {
break
}

var currentLine string
switch {
case containSQLQuery && strings.HasPrefix(line, "SQLQuery:"):
findSQLQuery = true
currentLine = strings.TrimPrefix(line, "SQLQuery:")
if strings.TrimSpace(currentLine) == "" {
continue
}
case containSQLQuery && !findSQLQuery:
// filter unwanted text above the SQLQuery:
continue
default:
currentLine = line
}

sqlQuery += currentLine + "\n"
}

return strings.TrimSpace(sqlQuery)
}
91 changes: 91 additions & 0 deletions chains/sql_database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,94 @@ func TestSQLDatabaseChain_Call(t *testing.T) {

t.Log(ret)
}

func TestExtractSQLQuery(t *testing.T) {
t.Parallel()

cases := []struct {
inputStr string
expected string
}{
{
inputStr: "SELECT * FROM example_table;",
expected: "SELECT * FROM example_table;",
},
{
inputStr: `
I am a clumsy llm model
I just feel good to put some extra text here.
SQLQuery: SELECT * FROM example_table;
SQLResult: 3 (this is not a real data)
Answer: There are 3 data in the table. (this is not a real data)`,
expected: "SELECT * FROM example_table;",
},
{
inputStr: `
SELECT * FROM example_table;
SQLResult: 3 (this is not a real data)
Answer: There are 3 data in the table. (this is not a real data)`,
expected: "SELECT * FROM example_table;",
},
{
inputStr: "```sql" + `
SELECT * FROM example_table;
` + "```" + `
SQLResult: 3 (this is not a real data)
Answer: There are 3 data in the table. (this is not a real data)`,
expected: "SELECT * FROM example_table;",
},
{ // multi-line sql query with markdown symbols and redundant text above and below
inputStr: `
I am also a clumsy llm model, I don't fully understand the prompt
And accidentally put some extra text here.
SQLQuery:
` + "```sql\n" + `
SELECT
order_id,
customer_name,
order_date
FROM orders;
` + "```" + `
SQLResult: xxx (this is not a real data)
Answer: some illusion answer. (this is not a real data)`,
expected: `SELECT
order_id,
customer_name,
order_date
FROM orders;`,
},
{ // slightly complexed multi-line query, no extra text before but only with redundant text after
inputStr: `SELECT
orders.order_id,
customers.customer_name,
orders.order_date
FROM
orders
INNER JOIN customers ON orders.customer_id = customers.customer_id
WHERE
orders.order_date >= '2023-01-01'
ORDER BY
orders.order_date;
SQLResult: xxx (this is not a real data)
Answer: some illusion answer. (this is not a real data)`,
expected: `SELECT
orders.order_id,
customers.customer_name,
orders.order_date
FROM
orders
INNER JOIN customers ON orders.customer_id = customers.customer_id
WHERE
orders.order_date >= '2023-01-01'
ORDER BY
orders.order_date;`,
},
}

for _, tc := range cases {
filterQuerySyntax := extractSQLQuery(tc.inputStr)
require.Equal(t, tc.expected, filterQuerySyntax)
}
}

0 comments on commit ca2969c

Please sign in to comment.