diff --git a/sql/engines/pgsql.py b/sql/engines/pgsql.py index e7d491428a..cbe9150813 100644 --- a/sql/engines/pgsql.py +++ b/sql/engines/pgsql.py @@ -5,6 +5,7 @@ @file: pgsql.py @time: 2019/03/29 """ +import json import re import psycopg2 import logging @@ -197,16 +198,38 @@ def query( f"SET search_path TO %(schema_name)s;", {"schema_name": schema_name} ) cursor.execute(sql, parameters) - effect_row = cursor.rowcount + # effect_row = cursor.rowcount if int(limit_num) > 0: rows = cursor.fetchmany(size=int(limit_num)) else: rows = cursor.fetchall() fields = cursor.description + column_type_codes = [i[1] for i in fields] if fields else [] + # 定义 JSON 和 JSONB 的 type_code,# 114 是 json,3802 是 jsonb + JSON_TYPE_CODE = 114 + JSONB_TYPE_CODE = 3802 + # 对 rows 进行循环处理,判断是否是 jsonb 或 json 类型 + converted_rows = [] + for row in rows: + new_row = [] + for idx, col_value in enumerate(row): + # 理论上, 下标不会越界的 + column_type_code = ( + column_type_codes[idx] if idx < len(column_type_codes) else None + ) + # 只在列类型为 json 或 jsonb 时转换 + if column_type_code in [JSON_TYPE_CODE, JSONB_TYPE_CODE]: + if isinstance(col_value, (dict, list)): + new_row.append(json.dumps(col_value)) # 转为 JSON 字符串 + else: + new_row.append(col_value) + else: + new_row.append(col_value) + converted_rows.append(tuple(new_row)) result_set.column_list = [i[0] for i in fields] if fields else [] - result_set.rows = rows - result_set.affected_rows = effect_row + result_set.rows = converted_rows + result_set.affected_rows = len(converted_rows) except Exception as e: logger.warning( f"PgSQL命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}" diff --git a/sql/engines/tests.py b/sql/engines/tests.py index 2cd119b761..9a85dc04dd 100644 --- a/sql/engines/tests.py +++ b/sql/engines/tests.py @@ -1,6 +1,6 @@ import json from datetime import timedelta, datetime -from unittest.mock import patch, Mock, ANY +from unittest.mock import MagicMock, patch, Mock, ANY import sqlparse from django.contrib.auth import get_user_model @@ -576,16 +576,46 @@ def test_query(self, _conn, _cursor, _execute): @patch("psycopg2.connect.cursor") @patch("psycopg2.connect") def test_query_not_limit(self, _conn, _cursor, _execute): - _conn.return_value.cursor.return_value.fetchall.return_value = [(1,)] + # 模拟数据库连接和游标 + mock_cursor = MagicMock() + _conn.return_value.cursor.return_value = mock_cursor + + # 模拟 SQL 查询的返回结果,包含 JSONB 类型、字符串和数字数据 + mock_cursor.fetchall.return_value = [ + ({"key": "value"}, "test_string", 123) # 返回一行数据,三列 + ] + mock_cursor.description = [ + ("json_column", 3802), # JSONB 类型 + ("string_column", 25), # 25 表示 TEXT 类型的 OID + ("number_column", 23), # 23 表示 INTEGER 类型的 OID + ] + + # _conn.return_value.cursor.return_value.fetchall.return_value = [(1,)] new_engine = PgSQLEngine(instance=self.ins) query_result = new_engine.query( db_name="some_dbname", - sql="select 1", + sql="SELECT json_column, string_column, number_column FROM some_table", limit_num=0, schema_name="some_schema", ) + + # 断言查询结果的类型和数据 self.assertIsInstance(query_result, ResultSet) - self.assertListEqual(query_result.rows, [(1,)]) + # 验证返回的 JSONB 列已转换为 JSON 字符串 + expected_row = ('{"key": "value"}', "test_string", 123) + self.assertListEqual(query_result.rows, [expected_row]) + + expected_column = ["json_column", "string_column", "number_column"] + # 验证列名是否正确 + self.assertEqual(query_result.column_list, expected_column) + + # 验证受影响的行数 + self.assertEqual(query_result.affected_rows, 1) + + # 验证类型代码是否正确(3802 表示 JSONB,25 表示 TEXT,23 表示 INTEGER) + expected_column_type_codes = [3802, 25, 23] + actual_column_type_codes = [desc[1] for desc in mock_cursor.description] + self.assertListEqual(actual_column_type_codes, expected_column_type_codes) @patch( "sql.engines.pgsql.PgSQLEngine.query",