Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

在线查询 支持AI根据描述生成查询语句 #2726

Merged
merged 7 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion common/templates/config.html
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ <h5 style="color: darkgrey"><b>SQL上线</b></h5>
</div>
</div>
<h5 style="color: darkgrey"><b>SQL查询</b></h5>
<h6 style="color:red">注:开启脱敏功能必须要配置goInception信息,用于SQL语法解析</h6>
<h6 style="color:red">注:开启脱敏功能必须要配置goInception信息,用于SQL语法解析;若无OPENAI配置则不开启AI生成SQL语句的功能</h6>
<hr/>
<div class="form-horizontal">
<div class="form-group">
Expand Down Expand Up @@ -932,6 +932,41 @@ <h4 style="color: darkgrey; display: inline;"><b>OIDC 配置</b></h4>
</div>
</div>
</div>

<h4 style="color: darkgrey; display: inline;"><b>OPENAI 配置</b></h4>
<hr/>
<div class="form-horizontal">
<div class="form-group">
<label for="openai_base_url"
class="col-sm-4 control-label">OPENAI_BASE_URL</label>
<div class="col-sm-5">
<input type="text" class="form-control" id="openai_base_url"
key="openai_base_url"
value="{{ config.openai_base_url }}"
placeholder="openai base url" />
</div>
</div>
<div class="form-group">
<label for="openai_api_key"
class="col-sm-4 control-label">OPENAI_API_KEY</label>
<div class="col-sm-5">
<input type="text" class="form-control" id="openai_api_key"
key="openai_api_key"
value="{{ config.openai_api_key }}"
placeholder="openai api key" />
</div>
</div>
<div class="form-group">
<label for="default_chat_model"
class="col-sm-4 control-label">DEFAULT_CHAT_MODEL</label>
<div class="col-sm-5">
<input type="text" class="form-control" id="default_chat_model"
key="default_chat_model"
value="{{ config.default_chat_model }}"
placeholder="openai default chat model" />
</div>
</div>
</div>

<h4 style="color: darkgrey"><b>其他配置</b></h4>
<hr/>
Expand Down
44 changes: 44 additions & 0 deletions common/utils/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from openai import OpenAI
import logging
from common.config import SysConfig

logger = logging.getLogger("default")


class OpenaiClient:
def __init__(self):
all_config = SysConfig()
self.base_url = all_config.get("openai_base_url", "")
self.api_key = all_config.get("openai_api_key", "")
self.default_chat_model = all_config.get("default_chat_model", "")
LeoQuote marked this conversation as resolved.
Show resolved Hide resolved
self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)

def request_chat_completion(self, messages, **kwargs):
"""chat_completion"""
completion = self.client.chat.completions.create(
model=self.default_chat_model, messages=messages, **kwargs
)
return completion

def generate_sql_by_openai(self, prompt: str, table_schema: str, query_desc: str):
"""根据传入的基本信息生成查询语句"""
messages = [
dict(role="user", content=f"{prompt}: {table_schema}\n{query_desc}")
]
logger.info(messages)
try:
res = self.request_chat_completion(messages)
return res.choices[0].message.content
except Exception as e:
raise ValueError(f"请求openai生成查询语句失败: {e}")


def check_openai_config():
"""校验openai所需配置是否存在"""
all_config = SysConfig()
base_url = all_config.get("openai_base_url")
api_key = all_config.get("openai_api_key")
default_chat_model = all_config.get("default_chat_model")
if base_url and api_key and default_chat_model:
return True
return False
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@ mozilla-django-oidc==3.0.0
django-auth-dingding==0.0.3
django-cas-ng==4.3.0
cassandra-driver
httpx
OpenAI
77 changes: 77 additions & 0 deletions sql/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from django.http import HttpResponse
from common.config import SysConfig
from common.utils.extend_json_encoder import ExtendJSONEncoder, ExtendJSONEncoderFTime
from common.utils.openai import OpenaiClient, check_openai_config
from common.utils.timer import FuncTimer
from sql.query_privileges import query_priv_check
from sql.utils.resource_group import user_instances
Expand Down Expand Up @@ -313,3 +314,79 @@ def kill_query_conn(instance_id, thread_id):
instance = Instance.objects.get(pk=instance_id)
query_engine = get_engine(instance)
query_engine.kill_connection(thread_id)


@permission_required("sql.menu_sqlquery", raise_exception=True)
def generate_sql(request):
"""
利用AI生成查询SQL, 传入数据基本结构和查询描述
:param request:
:return:
"""
query_desc = request.POST.get("query_desc")
query_prompt = request.POST.get("query_prompt")
if not query_desc or not query_prompt:
return HttpResponse(
json.dumps(
{"status": 1, "msg": "query_desc or query_prompt不存在", "data": []}
),
content_type="application/json",
)

instance_name = request.POST.get("instance_name")
try:
instance = Instance.objects.get(instance_name=instance_name)
except Instance.DoesNotExist:
return HttpResponse(
json.dumps({"status": 1, "msg": "实例不存在", "data": []}),
content_type="application/json",
)
db_name = request.POST.get("db_name")
schema_name = request.POST.get("schema_name")
tb_name = request.POST.get("tb_name")

result = {"status": 0, "msg": "ok", "data": ""}
try:
query_engine = get_engine(instance=instance)
query_result = query_engine.describe_table(
db_name, tb_name, schema_name=schema_name
)
openai_client = OpenaiClient()
# 有些不存在表结构, 例如 redis
if len(query_result.rows) != 0:
result["data"] = openai_client.generate_sql_by_openai(
query_prompt, query_result.rows[0][-1], query_desc
)
else:
result["data"] = openai_client.generate_sql_by_openai(
query_prompt, "", query_desc
)
except Exception as msg:
result["status"] = 1
result["msg"] = str(msg)
return HttpResponse(json.dumps(result), content_type="application/json")


def check_openai(request):
"""
校验openai配置是否存在
:param request:
:return:
"""
config_validate = check_openai_config()
if not config_validate:
return HttpResponse(
json.dumps(
{
"status": 1,
"msg": "openai 缺少配置, 必需配置[openai_base_url, openai_api_key, default_chat_model]",
"data": False,
}
),
content_type="application/json",
)

return HttpResponse(
json.dumps({"status": 0, "msg": "ok", "data": True}),
content_type="application/json",
)
102 changes: 102 additions & 0 deletions sql/templates/sqlquery.html
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ <h4 class="modal-title text-danger">收藏语句</h4>
<option value={{ sql.id }}>{{ sql.alias }}</option>
{% endfor %}
</select>
<input id="generatePrompt"class="form-control" style="display: none" placeholder="AI 查询提示词" value="{{openai_prompt}}" />
QSummerY marked this conversation as resolved.
Show resolved Hide resolved
<input id="generateDesc" class="form-control" style="display: none" placeholder="AI 查询描述" />
<input id="btn-generatesql" type="button" class="btn btn-info" style="display: none" value="生成SQL"/>
QSummerY marked this conversation as resolved.
Show resolved Hide resolved
</div>
<div class="panel-body">
<form id="form-sqlquery" action="/sqlquery/" method="post" class="form-horizontal" role="form">
Expand Down Expand Up @@ -495,6 +498,28 @@ <h4 class="modal-title text-danger">收藏语句</h4>
}
sessionStorage.removeItem('re_query');
}

// 获取sysconfig
function check_openai() {
$.ajax({
type: "get",
url: "/check/openai/",
dataType: "json",
data: false,
complete: function () {
},
success: function (data) {
if (data["data"]) {
$("#generatePrompt").show()
$("#generateDesc").show()
$("#btn-generatesql").show()
}
},
error: function (XMLHttpRequest, textStatus, errorThrown) {
alert(errorThrown);
}
});
}
</script>
<!-- 执行结果 -->
<script>
Expand Down Expand Up @@ -624,6 +649,35 @@ <h4 class="modal-title text-danger">收藏语句</h4>
return result;
}

//提交AI生成sql语句请求
$("#btn-generatesql").click(function () {
var check = false
var optgroup = $('#instance_name :selected').parent().attr('label')
var instance_name = $("#instance_name").val()
var db_name = $("#db_name").val()
var tb_name = $("#table_name").val()
var query_desc = $("#generateDesc").val()
var prompt = $("#generatePrompt").val()

if (!instance_name) {
alert("请选择实例!")
} else if (!db_name) {
alert("请选择数据库!")
} else if (optgroup !== 'Redis' && !tb_name){
alert("请选择表结构!")
} else if (!prompt) {
alert("请输入查询提示词!")
} else if (!query_desc) {
alert("请输入查询描述!")
} else {
check = true
}
if (check) {
generatesql()
}
}
);

//先做表单验证,验证成功再成功提交查询请求
$("#btn-sqlquery").click(function () {
dosqlquery();
Expand Down Expand Up @@ -1023,6 +1077,36 @@ <h4 class="modal-title text-danger">收藏语句</h4>
});
}

function generatesql() {
const data = {
instance_name: $("#instance_name").val(),
db_name: $("#db_name").val(),
schema_name: $("#schema_name").val(),
tb_name: $("#table_name").val(),
query_desc: $("#generateDesc").val(),
query_prompt: $("#generatePrompt").val(),
}
//提交请求
$.ajax({
type: "post",
url: "/query/generate_sql/",
dataType: "json",
data: data,
complete: function () {
$('input[type=button]').removeClass('disabled');
$('input[type=button]').prop('disabled', false);
optgroup_control();
},
success: function (data) {
editor.setValue(data["data"]);
editor.clearSelection();
},
error: function (XMLHttpRequest, textStatus, errorThrown) {
alert(errorThrown);
}
});
}

function dosqlquery() {
if (sqlquery_validate()) {
$('input[type=button]').addClass('disabled');
Expand Down Expand Up @@ -1099,6 +1183,21 @@ <h4 class="modal-title text-danger">收藏语句</h4>
} else {
get_instance(true)
}
// 若存在AI提示词, 则根据实例自动更新数据库类型在开头, 否则自定义输入
var prompt = $('#generatePrompt').val()
if (prompt) {
console.log("aaa", prompt)
var optgroupLabel = $(this).find(':selected').parent().attr('label');
var parts = prompt.split(' ');
var combinePrompt = ""
if (parts.length > 1 && /^[A-Za-z]+$/.test(parts[0])){
parts[0] = optgroupLabel
combinePrompt = parts.join(" ");
} else {
combinePrompt = optgroupLabel + " " + prompt
}
$('#generatePrompt').val(combinePrompt)
}
});

function get_instance(async) {
Expand Down Expand Up @@ -1325,6 +1424,9 @@ <h4 class="modal-title text-danger">收藏语句</h4>
} else {
editor.setValue("");
}

// check openai 配置是否存在以支持AI生成查询语句功能
check_openai()

//默认获取查询历史
get_querylog();
Expand Down
Loading
Loading