diff --git a/common/check.py b/common/check.py index 07f74ce571..1c2aec66dc 100644 --- a/common/check.py +++ b/common/check.py @@ -5,6 +5,9 @@ import MySQLdb import simplejson as json from django.http import HttpResponse +from paramiko import Transport, SFTPClient +import oss2 +import os from common.utils.permission import superuser_required from sql.engines import get_engine @@ -131,3 +134,84 @@ def instance(request): result["msg"] = "无法连接实例,\n{}".format(str(e)) # 返回结果 return HttpResponse(json.dumps(result), content_type="application/json") + + +@superuser_required +def file_storage_connect(request): + result = {"status": 0, "msg": "ok", "data": []} + storage_type = request.POST.get("storage_type") + # 检查是否存在该变量 + max_export_rows = request.POST.get("max_export_rows", "10000") + max_export_exec_time = request.POST.get("max_export_exec_time", "60") + files_expire_with_days = request.POST.get("files_expire_with_days", "0") + # 若变量已经定义,检查是否为空 + max_export_rows = max_export_rows if max_export_rows else "10000" + max_export_exec_time = max_export_exec_time if max_export_exec_time else "60" + files_expire_with_days = files_expire_with_days if files_expire_with_days else "0" + check_list = { + "max_export_rows": max_export_rows, + "max_export_exec_time": max_export_exec_time, + "files_expire_with_days": files_expire_with_days, + } + try: + # 遍历字典,判断是否只有数字 + for key, value in check_list.items(): + if not value.isdigit(): + raise TypeError(f"Value: {key} \nmust be an integer.") + except TypeError as e: + result["status"] = 1 + result["msg"] = "参数类型错误,\n{}".format(str(e)) + + if storage_type == "sftp": + sftp_host = request.POST.get("sftp_host") + sftp_port = int(request.POST.get("sftp_port")) + sftp_user = request.POST.get("sftp_user") + sftp_password = request.POST.get("sftp_password") + sftp_path = request.POST.get("sftp_path") + + try: + with Transport((sftp_host, sftp_port)) as transport: + transport.connect(username=sftp_user, password=sftp_password) + # 创建 SFTPClient + sftp = SFTPClient.from_transport(transport) + remote_path = sftp_path + try: + sftp.listdir(remote_path) + except FileNotFoundError: + raise Exception(f"SFTP 远程路径 '{remote_path}' 不存在") + + except Exception as e: + result["status"] = 1 + result["msg"] = "无法连接,\n{}".format(str(e)) + elif storage_type == "oss": + access_key_id = request.POST.get("access_key_id") + access_key_secret = request.POST.get("access_key_secret") + endpoint = request.POST.get("endpoint") + bucket_name = request.POST.get("bucket_name") + try: + # 创建 OSS 认证 + auth = oss2.Auth(access_key_id, access_key_secret) + # 创建 OSS Bucket 对象 + bucket = oss2.Bucket(auth, endpoint, bucket_name) + + # 判断配置的 Bucket 是否存在 + try: + bucket.get_bucket_info() + except oss2.exceptions.NoSuchBucket: + raise Exception(f"OSS 存储桶 '{bucket_name}' 不存在") + + except Exception as e: + result["status"] = 1 + result["msg"] = "无法连接,\n{}".format(str(e)) + elif storage_type == "local": + local_path = r"{}".format(request.POST.get("local_path")) + try: + if not os.path.exists(local_path): + raise FileNotFoundError( + f"Destination directory '{local_path}' not found." + ) + except Exception as e: + result["status"] = 1 + result["msg"] = "本地路径不存在,\n{}".format(str(e)) + + return HttpResponse(json.dumps(result), content_type="application/json") diff --git a/common/templates/config.html b/common/templates/config.html index 93af8eeb1f..d2febe739d 100755 --- a/common/templates/config.html +++ b/common/templates/config.html @@ -399,6 +399,169 @@
注:开启脱敏功能必须要配置goInception信息 placeholder="管理员/DBA查询结果集限制" /> +
SQL离线导出配置
+
+
+ +
+ +
+ +
+ + + +
SQL优化

@@ -1213,6 +1376,56 @@
当前审批流程:" + storage_display(storages[i]) + "" + } else { + storage = "" + } + $("#sqlfile_storage").append(storage) + } + // sms_provider参数处理 function provider_display(provider) { if (provider === 'disabled') { @@ -1557,6 +1770,66 @@
当前审批流程:当前审批流程: str: """字符串参数转义""" return pymysql.escape_string(value) - def execute_check(self, instance=None, db_name=None, sql=""): + def execute_check( + self, instance=None, db_name=None, sql="", is_offline_export=None + ): """inception check""" # 判断如果配置了隧道则连接隧道 host, port, user, password = self.remote_instance_conn(instance) @@ -99,6 +101,8 @@ def execute_check(self, instance=None, db_name=None, sql=""): if check_result.syntax_type == 2: if get_syntax_type(r[5], parser=False, db_type="mysql") == "DDL": check_result.syntax_type = 1 + if is_offline_export == "yes": + check_result.syntax_type = 3 check_result.column_list = inception_result.column_list check_result.checked = True check_result.error = inception_result.error diff --git a/sql/engines/mysql.py b/sql/engines/mysql.py index 4a5b399ec4..bf5b7d8c6f 100644 --- a/sql/engines/mysql.py +++ b/sql/engines/mysql.py @@ -17,6 +17,7 @@ from .models import ResultSet, ReviewResult, ReviewSet from sql.utils.data_masking import data_masking from common.config import SysConfig +from sql.engines.offlinedownload import OffLineDownLoad logger = logging.getLogger("default") @@ -71,6 +72,7 @@ def __init__(self, instance=None): super().__init__(instance=instance) self.config = SysConfig() self.inc_engine = GoInceptionEngine() + self.sql_export = OffLineDownLoad() def get_connection(self, db_name=None): # https://stackoverflow.com/questions/19256155/python-mysqldb-returning-x01-for-bit-values @@ -621,12 +623,19 @@ def query_masking(self, db_name=None, sql="", resultset=None): mask_result = resultset return mask_result - def execute_check(self, db_name=None, sql=""): + def execute_check(self, db_name=None, sql="", offline_data=None): """上线单执行前的检查, 返回Review set""" + # 获取离线导出工单参数 + offline_exp = ( + offline_data["is_offline_export"] if offline_data is not None else "0" + ) # 进行Inception检查,获取检测结果 try: check_result = self.inc_engine.execute_check( - instance=self.instance, db_name=db_name, sql=sql + instance=self.instance, + db_name=db_name, + sql=sql, + is_offline_export=offline_exp, ) except Exception as e: logger.debug( @@ -659,10 +668,11 @@ def execute_check(self, db_name=None, sql=""): syntax_type = get_syntax_type(statement, parser=False, db_type="mysql") # 禁用语句 if re.match(r"^select", statement.lower()): - check_result.error_count += 1 - row.stagestatus = "驳回不支持语句" - row.errlevel = 2 - row.errormessage = "仅支持DML和DDL语句,查询语句请使用SQL查询功能!" + if offline_exp != "yes": + check_result.error_count += 1 + row.stagestatus = "驳回不支持语句" + row.errlevel = 2 + row.errormessage = "仅支持DML和DDL语句,查询语句请使用SQL查询功能!" # 高危语句 elif critical_ddl_regex and p.match(statement.strip().lower()): check_result.error_count += 1 @@ -681,28 +691,31 @@ def execute_check(self, db_name=None, sql=""): def execute_workflow(self, workflow): """执行上线单,返回Review set""" - # 判断实例是否只读 - read_only = self.query(sql="SELECT @@global.read_only;").rows[0][0] - if read_only in (1, "ON"): - result = ReviewSet( - full_sql=workflow.sqlworkflowcontent.sql_content, - rows=[ - ReviewResult( - id=1, - errlevel=2, - stagestatus="Execute Failed", - errormessage="实例read_only=1,禁止执行变更语句!", - sql=workflow.sqlworkflowcontent.sql_content, - ) - ], - ) - result.error = ("实例read_only=1,禁止执行变更语句!",) - return result - # TODO 原生执行 - # if workflow.is_manual == 1: - # return self.execute(db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content) - # inception执行 - return self.inc_engine.execute(workflow) + if workflow.is_offline_export == "yes": + return self.sql_export.execute_offline_download(workflow) + else: + # 判断实例是否只读 + read_only = self.query(sql="SELECT @@global.read_only;").rows[0][0] + if read_only in (1, "ON"): + result = ReviewSet( + full_sql=workflow.sqlworkflowcontent.sql_content, + rows=[ + ReviewResult( + id=1, + errlevel=2, + stagestatus="Execute Failed", + errormessage="实例read_only=1,禁止执行变更语句!", + sql=workflow.sqlworkflowcontent.sql_content, + ) + ], + ) + result.error = ("实例read_only=1,禁止执行变更语句!",) + return result + # TODO 原生执行 + # if workflow.is_manual == 1: + # return self.execute(db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content) + # inception执行 + return self.inc_engine.execute(workflow) def execute(self, db_name=None, sql="", close_conn=True, parameters=None): """原生执行语句""" diff --git a/sql/engines/offlinedownload.py b/sql/engines/offlinedownload.py new file mode 100644 index 0000000000..47b6d765f2 --- /dev/null +++ b/sql/engines/offlinedownload.py @@ -0,0 +1,574 @@ +# -*- coding: UTF-8 -*- +import logging +import re +import traceback +import os +import tempfile +import csv +from io import BytesIO +import hashlib +import shutil +import datetime +import xml.etree.ElementTree as ET +import zipfile +import sqlparse +from threading import Thread +import queue +import time + +import MySQLdb +import simplejson as json +from paramiko import Transport, SFTPClient +import oss2 +import pandas as pd +from django.http import HttpResponse +from urllib.parse import quote + +from sql.models import SqlWorkflow, AuditEntry, Config +from . import EngineBase +from .models import ReviewSet, ReviewResult + + +logger = logging.getLogger("default") + + +class TimeoutException(Exception): + pass + + +class OffLineDownLoad(EngineBase): + def execute_offline_download(self, workflow): + if workflow.is_offline_export == "yes": + # 创建一个临时目录用于存放文件 + temp_dir = tempfile.mkdtemp() + # 获取系统配置 + config = get_sys_config() + # 先进行 max_export_exec_time 变量的判断是否存在以及是否为空,默认值60 + timeout_str = config.get("max_export_exec_time", "60") + timeout = int(timeout_str) if timeout_str else 60 + storage_type = config["sqlfile_storage"] + # 获取前端提交的 SQL 和其他工单信息 + full_sql = workflow.sqlworkflowcontent.sql_content + full_sql = sqlparse.format(full_sql, strip_comments=True) + full_sql = sqlparse.split(full_sql)[0] + sql = full_sql.strip() + instance = workflow.instance + host, port, user, password = self.remote_instance_conn(instance) + execute_result = ReviewSet(full_sql=sql) + # 定义数据库连接 + conn = MySQLdb.connect( + host=host, + port=port, + user=user, + password=password, + db=workflow.db_name, + charset="utf8mb4", + ) + + start_time = time.time() + try: + check_result = execute_check_sql(conn, sql, config) + if isinstance(check_result, Exception): + raise check_result + except Exception as e: + execute_result.rows = [ + ReviewResult( + stage="Execute failed", + error=1, + errlevel=2, + stagestatus="异常终止", + errormessage=f"{e}", + sql=full_sql, + ) + ] + execute_result.error = e + return execute_result + + try: + # 执行 SQL 查询 + results = self.execute_with_timeout( + conn, workflow.sqlworkflowcontent.sql_content, timeout + ) + if results: + columns = results["columns"] + result = results["data"] + + # 保存查询结果为 CSV or JSON or XML or XLSX or SQL 文件 + get_format_type = workflow.export_format + file_name = save_to_format_file( + get_format_type, result, workflow, columns, temp_dir + ) + + # 将导出的文件上传至 OSS 或 FTP 或 本地保存 + upload_file_to_storage(file_name, storage_type, temp_dir) + + end_time = time.time() # 记录结束时间 + elapsed_time = round(end_time - start_time, 3) + execute_result.rows = [ + ReviewResult( + stage="Executed", + errlevel=0, + stagestatus="执行正常", + errormessage=f"保存文件: {file_name}", + sql=full_sql, + execute_time=elapsed_time, + affected_rows=check_result, + ) + ] + + change_workflow = SqlWorkflow.objects.get(id=workflow.id) + change_workflow.file_name = file_name + change_workflow.save() + + return execute_result + except Exception as e: + # 返回工单执行失败的状态和错误信息 + execute_result.rows = [ + ReviewResult( + stage="Execute failed", + error=1, + errlevel=2, + stagestatus="异常终止", + errormessage=f"{e}", + sql=full_sql, + ) + ] + execute_result.error = e + return execute_result + finally: + # 清理本地文件和临时目录 + clean_local_files(temp_dir) + # 关闭游标和数据库连接 + conn.close() + + @staticmethod + def execute_query(conn, sql): + try: + cursor = conn.cursor() + cursor.execute(sql) + columns = [column[0] for column in cursor.description] + result = {"columns": columns, "data": cursor.fetchall()} + cursor.close() + return result + except Exception as e: + raise Exception(f"Query execution failed: {e}") + + def worker(self, conn, sql, result_queue): + try: + result = self.execute_query(conn, sql) + result_queue.put(result) + except Exception as e: + result_queue.put(e) + + def execute_with_timeout(self, conn, sql, timeout): + result_queue = queue.Queue() + thread = Thread(target=self.worker, args=(conn, sql, result_queue)) + thread.start() + thread.join(timeout) + + if thread.is_alive(): + thread.join() + raise TimeoutException( + f"Query execution timed out after {timeout} seconds." + ) + else: + result = result_queue.get() + if isinstance(result, Exception): + raise result + else: + return result + + +def get_sys_config(): + all_config = Config.objects.all().values("item", "value") + sys_config = {} + for items in all_config: + sys_config[items["item"]] = items["value"] + return sys_config + + +def save_to_format_file( + format_type=None, result=None, workflow=None, columns=None, temp_dir=None +): + # 生成唯一的文件名(包含工单ID、日期和随机哈希值) + timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + hash_value = hashlib.sha256(os.urandom(32)).hexdigest()[:8] # 使用前8位作为哈希值 + base_name = f"{workflow.db_name}_{timestamp}_{hash_value}" + file_name = f"{base_name}.{format_type}" + file_path = os.path.join(temp_dir, file_name) + # 将查询结果写入 CSV 文件 + if format_type == "csv": + save_csv(file_path, result, columns) + elif format_type == "json": + save_json(file_path, result, columns) + elif format_type == "xml": + save_xml(file_path, result, columns) + elif format_type == "xlsx": + save_xlsx(file_path, result, columns) + elif format_type == "sql": + save_sql(file_path, result, columns) + else: + raise ValueError(f"Unsupported format type: {format_type}") + + zip_file_name = f"{base_name}.zip" + zip_file_path = os.path.join(temp_dir, zip_file_name) + with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as zipf: + zipf.write(file_path, os.path.basename(file_path)) + return zip_file_name + + +def upload_file_to_storage(file_name=None, storage_type=None, temp_dir=None): + action_exec = StorageControl( + file_name=file_name, storage_type=storage_type, temp_dir=temp_dir + ) + try: + if storage_type == "oss": + # 使用阿里云 OSS 进行上传 + action_exec.upload_to_oss() + elif storage_type == "sftp": + # 使用 SFTP 进行上传 + action_exec.upload_to_sftp() + elif storage_type == "local": + # 本地存储 + action_exec.upload_to_local() + else: + # 未知存储类型,可以抛出异常或处理其他逻辑 + raise ValueError(f"Unknown storage type: {storage_type}") + except Exception as e: + raise e + + +def clean_local_files(temp_dir): + # 删除临时目录及其内容 + shutil.rmtree(temp_dir) + + +def datetime_serializer(obj): + if isinstance(obj, (datetime.date, datetime.datetime)): + return obj.isoformat() + raise TypeError("Type %s not serializable" % type(obj)) + + +def save_csv(file_path, result, columns): + with open(file_path, "w", newline="", encoding="utf-8") as csv_file: + csv_writer = csv.writer(csv_file, quoting=csv.QUOTE_ALL) + + if columns: + csv_writer.writerow(columns) + + for row in result: + csv_row = ["null" if value is None else value for value in row] + csv_writer.writerow(csv_row) + + +def save_json(file_path, result, columns): + with open(file_path, "w", encoding="utf-8") as json_file: + json.dump( + [dict(zip(columns, row)) for row in result], + json_file, + indent=2, + default=datetime_serializer, + ensure_ascii=False, + ) + + +def save_xml(file_path, result, columns): + root = ET.Element("tabledata") + + # Create fields element + fields_elem = ET.SubElement(root, "fields") + for column in columns: + field_elem = ET.SubElement(fields_elem, "field") + field_elem.text = column + + # Create data element + data_elem = ET.SubElement(root, "data") + for row_id, row in enumerate(result, start=1): + row_elem = ET.SubElement(data_elem, "row", id=str(row_id)) + for col_idx, value in enumerate(row, start=1): + col_elem = ET.SubElement(row_elem, f"column-{col_idx}") + if value is None: + col_elem.text = "(null)" + elif isinstance(value, (datetime.date, datetime.datetime)): + col_elem.text = value.isoformat() + else: + col_elem.text = str(value) + + tree = ET.ElementTree(root) + tree.write(file_path, encoding="utf-8", xml_declaration=True) + + +def save_xlsx(file_path, result, columns): + try: + df = pd.DataFrame( + [ + [ + str(value) if value is not None and value != "NULL" else "" + for value in row + ] + for row in result + ], + columns=columns, + ) + df.to_excel(file_path, index=False, header=True) + except ValueError as e: + raise ValueError(f"Excel最大支持行数为1048576,已超出!") + + +def save_sql(file_path, result, columns): + with open(file_path, "w") as sql_file: + for row in result: + table_name = "your_table_name" + if columns: + sql_file.write( + f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES " + ) + + values = ", ".join( + [ + ( + "'{}'".format(str(value).replace("'", "''")) + if isinstance(value, str) + or isinstance(value, datetime.date) + or isinstance(value, datetime.datetime) + else "NULL" if value is None or value == "" else str(value) + ) + for value in row + ] + ) + sql_file.write(f"({values});\n") + + +def offline_file_download(request): + file_name = request.GET.get("file_name", " ") + workflow_id = request.GET.get("workflow_id", " ") + action = "离线下载" + extra_info = f"工单id:{workflow_id},文件:{file_name}" + config = get_sys_config() + storage_type = config["sqlfile_storage"] + + try: + action_exec = StorageControl(storage_type=storage_type, file_name=file_name) + if storage_type == "sftp": + response = action_exec.download_from_sftp() + return response + elif storage_type == "oss": + response = action_exec.download_from_oss() + return response + elif storage_type == "local": + response = action_exec.download_from_local() + return response + + except Exception as e: + action = "离线下载失败" + return HttpResponse(f"下载失败:{e}", status=500) + finally: + AuditEntry.objects.create( + user_id=request.user.id, + user_name=request.user.username, + user_display=request.user.display, + action=action, + extra_info=extra_info, + ) + + +class StorageControl: + def __init__( + self, storage_type=None, do_action=None, file_name=None, temp_dir=None + ): + """根据存储服务进行文件的上传下载""" + # 存储类型 + self.storage_type = storage_type + # 暂时无用,可考虑删除 + self.do_action = do_action + # 导出文件的压缩包名称 + self.file_name = file_name + # 导出文件的本地临时目录,上传完成后会自动清理 + self.temp_dir = temp_dir + + # 获取系统配置 + self.config = get_sys_config() + # 先进行系统管理内配置的 files_expire_with_days 参数的判断是否存在以及是否为空,默认值 0-不过期 + self.expire_time_str = self.config.get("files_expire_with_days", "0") + self.expire_time_with_days = ( + int(self.expire_time_str) if self.expire_time_str else 0 + ) + # 获取当前时间 + self.current_time = datetime.datetime.now() + # 获取过期的时间 + self.expire_time = self.current_time - datetime.timedelta( + days=self.expire_time_with_days + ) + + # SFTP 存储相关配置信息 + self.sftp_host = self.config["sftp_host"] + self.sftp_user = self.config["sftp_user"] + self.sftp_password = self.config["sftp_password"] + self.sftp_port_str = self.config.get("sftp_port", "22") + self.sftp_port = int(self.sftp_port_str) if self.sftp_port_str else 22 + self.sftp_path = self.config["sftp_path"] + + # OSS 存储相关配置信息 + self.access_key_id = self.config["oss_access_key_id"] + self.access_key_secret = self.config["oss_access_key_secret"] + self.endpoint = self.config["oss_endpoint"] + self.bucket_name = self.config["oss_bucket_name"] + self.oss_path = self.config["oss_path"] + + # 本地存储相关配置信息 + # self.local_path = r'{}'.format(self.config['local_path']) + self.local_path = r"{}".format(self.config.get("local_path", "/tmp")) + + def upload_to_sftp(self): + # SFTP 配置 + try: + with Transport((self.sftp_host, self.sftp_port)) as transport: + transport.connect(username=self.sftp_user, password=self.sftp_password) + with SFTPClient.from_transport(transport) as sftp: + remote_file = os.path.join( + self.sftp_path, os.path.basename(self.file_name) + ) + # 判断时间是否配置,为 0 则默认不删除,大于 0 则调用删除方法进行删除过期文件 + if self.expire_time_with_days > 0: + self.del_file_before_upload_to_sftp(sftp) + # 上传离线导出的文件压缩包到SFTP + sftp.put(os.path.join(self.temp_dir, self.file_name), remote_file) + + except Exception as e: + upload_to_sftp_exception = Exception(f"上传失败: {e}") + raise upload_to_sftp_exception + + def download_from_sftp(self): + file_path = os.path.join(self.sftp_path, self.file_name) + + with Transport((self.sftp_host, self.sftp_port)) as transport: + transport.connect(username=self.sftp_user, password=self.sftp_password) + with SFTPClient.from_transport(transport) as sftp: + # 获取压缩包内容 + file_content = BytesIO() + sftp.getfo(file_path, file_content) + + # 构造 HttpResponse 返回 ZIP 文件内容 + response = HttpResponse(file_content.getvalue(), content_type="application/zip") + response["Content-Disposition"] = ( + f"attachment; filename={quote(self.file_name)}" + ) + return response + + def del_file_before_upload_to_sftp(self, sftp): + for file_info in sftp.listdir_attr(self.sftp_path): + file_path = os.path.join(self.sftp_path, file_info.filename) + + # 获取文件的修改时间 + modified_time = datetime.datetime.fromtimestamp(file_info.st_mtime) + + # 如果文件过期,则删除 + if modified_time < self.expire_time: + sftp.remove(file_path) + + def upload_to_oss(self): + # 创建 OSS 认证 + auth = oss2.Auth(self.access_key_id, self.access_key_secret) + + # 创建 OSS Bucket 对象 + bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name) + + # 上传文件到 OSS + remote_key = os.path.join(self.oss_path, os.path.basename(self.file_name)) + # 判断时间是否配置,为 0 则默认不删除,大于 0 则调用删除方法进行删除过期文件 + if self.expire_time_with_days > 0: + self.del_file_before_upload_to_oss(bucket) + # 读取并上传离线导出的文件压缩包到OSS + with open(os.path.join(self.temp_dir, self.file_name), "rb") as file: + bucket.put_object(remote_key, file) + + def download_from_oss(self): + # 创建 OSS 认证 + auth = oss2.Auth(self.access_key_id, self.access_key_secret) + + # 创建 OSS Bucket 对象 + bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name) + + # 从OSS下载文件 + remote_path = self.oss_path + remote_key = os.path.join(remote_path, self.file_name) + object_stream = bucket.get_object(remote_key) + response = HttpResponse(object_stream.read(), content_type="application/zip") + response["Content-Disposition"] = ( + f"attachment; filename={quote(self.file_name)}" + ) + return response + + def del_file_before_upload_to_oss(self, bucket): + for object_info in oss2.ObjectIterator(bucket, prefix=self.oss_path): + # 获取 bucket 存储路径下的文件名 + file_path = object_info.key + + # 获取文件的修改时间 + modified_time = datetime.datetime.fromtimestamp(object_info.last_modified) + + # 如果文件过期,则删除 + if modified_time < self.expire_time: + bucket.delete_object(file_path) + + def upload_to_local(self): + try: + source_path = os.path.join(self.temp_dir, self.file_name) + # 判断配置内的本地存储路径是否存在,若不存在则抛出报错 + if not os.path.exists(self.local_path): + raise FileNotFoundError( + f"Destination directory '{self.local_path}' not found." + ) + # 判断时间是否配置,为 0 则默认不删除,大于 0 则调用删除方法进行删除过期文件 + if self.expire_time_with_days > 0: + self.del_file_before_upload_to_local() + # 拷贝离线导出的文件压缩包到指定路径 + shutil.copy(source_path, self.local_path) + except Exception as e: + raise e + + def download_from_local(self): + file_path = os.path.join(self.local_path, self.file_name) + + with open(file_path, "rb") as file: + response = HttpResponse(file.read(), content_type="application/zip") + response["Content-Disposition"] = ( + f"attachment; filename={quote(self.file_name)}" + ) + return response + + def del_file_before_upload_to_local(self): + for local_file_info in os.listdir(self.local_path): + file_path = os.path.join(self.local_path, local_file_info) + if ( + os.path.isfile(file_path) + and os.path.getmtime(file_path) < self.expire_time.timestamp() + ): + os.remove(file_path) + + +def execute_check_sql(conn, sql, config): + # 先进行 max_export_rows 变量的判断是否存在以及是否为空,默认值10000 + max_export_rows_str = config.get("max_export_rows", "10000") + max_export_rows = int(max_export_rows_str) if max_export_rows_str else 10000 + + # 判断sql是否以 select 开头 + if not sql.strip().lower().startswith("select"): + return Exception(f"违规语句:{sql}") + sql = "explain " + sql + cursor = conn.cursor() + try: + cursor.execute(sql) + check_result = cursor.fetchall() + total_explain_scan_rows = sum( + row[9] if row[9] is not None else 0 for row in check_result + ) + if int(total_explain_scan_rows) > max_export_rows: + return Exception(f"扫描行数超出阈值: {max_export_rows}") + else: + return total_explain_scan_rows + except Exception as e: + return e + finally: + # 关闭游标和数据库连接 + cursor.close() diff --git a/sql/models.py b/sql/models.py index 405692459d..a6ee2536f7 100755 --- a/sql/models.py +++ b/sql/models.py @@ -290,8 +290,8 @@ class SqlWorkflow(models.Model, WorkflowAuditMixin): instance = models.ForeignKey(Instance, on_delete=models.CASCADE) db_name = models.CharField("数据库", max_length=64) syntax_type = models.IntegerField( - "工单类型 0、未知,1、DDL,2、DML", - choices=((0, "其他"), (1, "DDL"), (2, "DML")), + "工单类型 0、未知,1、DDL,2、DML,3、离线导出工单", + choices=((0, "其他"), (1, "DDL"), (2, "DML"), (3, "离线导出工单")), default=0, ) is_backup = models.BooleanField( @@ -313,6 +313,38 @@ class SqlWorkflow(models.Model, WorkflowAuditMixin): is_manual = models.IntegerField( "是否原生执行", choices=((0, "否"), (1, "是")), default=0 ) + is_offline_export = models.CharField( + "是否为离线导出工单", + max_length=3, + choices=( + ("no", "否"), + ("yes", "是"), + ), + default="no", + ) + + # 导出格式 + export_format = models.CharField( + "导出格式", + max_length=10, + choices=( + ("csv", "CSV"), + ("xlsx", "Excel"), + ("sql", "SQL"), + ("json", "JSON"), + ("xml", "XML"), + ), + # default="csv", + null=True, + blank=True, + ) + + file_name = models.CharField( + "文件名", + max_length=255, # 适当调整最大长度 + null=True, # 允许为空 + blank=True, # 允许为空字符串 + ) def __str__(self): return self.workflow_name @@ -965,6 +997,7 @@ class Meta: ("archive_mgt", "管理归档申请"), ("audit_user", "审计权限"), ("query_download", "在线查询下载权限"), + ("offline_download", "离线下载权限"), ) diff --git a/sql/sql_workflow.py b/sql/sql_workflow.py index b812be7c51..5624b63e23 100644 --- a/sql/sql_workflow.py +++ b/sql/sql_workflow.py @@ -116,6 +116,7 @@ def _sql_workflow_list(request): "db_name", "group_name", "syntax_type", + "export_format", ) # QuerySet 序列化 diff --git a/sql/templates/detail.html b/sql/templates/detail.html index bc40b4d58e..b7a3c00ec8 100644 --- a/sql/templates/detail.html +++ b/sql/templates/detail.html @@ -9,11 +9,13 @@

href="{{ workflow_detail.demand_url }}">{{ workflow_detail.workflow_name }}

- {% if user.username == workflow_detail.engineer %} + {% if user.username == workflow_detail.engineer and workflow_detail.is_offline_export != "yes" %} 上线其他实例 {% endif %} {% if is_can_review or is_can_execute or is_can_rollback or user.is_superuser %} - 查看提交信息 + {% if workflow_detail.is_offline_export != "yes" %} + 查看提交信息 + {% endif %} {% endif %}
@@ -225,6 +227,14 @@

{% endif %} + {% if workflow_detail.status == 'workflow_finish' and workflow_detail.is_offline_export == 'yes' %} +
+ +
+ {% endif %} {% if is_can_rollback %} {% if workflow_detail.status == 'workflow_finish' or workflow_detail.status == 'workflow_exception' %} @@ -458,7 +468,8 @@

// 执行确认 $("#btnExecuteOnly").click(function () { - if ("{{ workflow_detail.is_backup }}" === 'False') { + if ("{{ workflow_detail.is_backup }}" === 'False' && + "{{workflow_detail.is_offline_export}}" !== "yes") { var isContinue = confirm("该工单未选择备份,将不会自动备份数据,请确认是否立即执行?"); } else { var isContinue = confirm("请确认是否立即执行?"); @@ -964,5 +975,46 @@ }); }); + + {% endblock %} diff --git a/sql/templates/sqlquery.html b/sql/templates/sqlquery.html index e6c3d39174..6db24f5482 100644 --- a/sql/templates/sqlquery.html +++ b/sql/templates/sqlquery.html @@ -65,6 +65,38 @@

                         
                         
+
+ +
+
+ +
+
+ +
-
+
-
+
+ + +
- + + + {% if config.sqlfile_storage == "sftp" or config.sqlfile_storage == "oss" or config.sqlfile_storage == "local"%} + {% if can_offline_download == 1 %} + + {% endif %} + {% endif %}
-
+
  • 支持注释行,可选择指定语句执行,默认执行第一条;
  • 查询结果行数限制见权限管理,会选择查询涉及表的最小limit值
  • +
    +
  • 导出工单仅支持导出一条查询语句
  • +
  • 导出请先看执行计划,若扫描行数过多将禁止执行,阈值: {{ config.max_export_rows }}
  • +
    +
    @@ -256,6 +318,190 @@ + + +