Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

初始化 notify 时自动根据 workflow audit 取 workflow #2363

Merged
merged 2 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions sql/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,16 @@ class WorkflowAudit(models.Model):
create_time = models.DateTimeField("申请时间", auto_now_add=True)
sys_time = models.DateTimeField("系统时间", auto_now=True)

def get_workflow(self):
"""尝试从 audit 中取出 workflow"""
if self.workflow_type == WorkflowType.QUERY:
return QueryPrivilegesApply.objects.get(apply_id=self.workflow_id)
elif self.workflow_type == WorkflowType.SQL_REVIEW:
return SqlWorkflow.objects.get(id=self.workflow_id)
elif self.workflow_type == WorkflowType.ARCHIVE:
return ArchiveConfig.objects.get(id=self.workflow_id)
raise ValueError("无法获取到关联工单")

def __int__(self):
return self.audit_id

Expand Down
54 changes: 22 additions & 32 deletions sql/notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,25 @@ class My2SqlResult:
error: str = ""


@dataclass
class Notifier:
name = "base"
sys_config_key: str = ""

def __init__(
self,
workflow: Union[SqlWorkflow, ArchiveConfig, QueryPrivilegesApply, My2SqlResult],
sys_config: SysConfig,
audit: WorkflowAudit = None,
audit_detail: WorkflowAuditDetail = None,
event_type: EventType = EventType.AUDIT,
):
self.workflow = workflow
self.audit = audit
self.audit_detail = audit_detail
self.event_type = event_type
self.sys_config = sys_config
workflow: Union[SqlWorkflow, ArchiveConfig, QueryPrivilegesApply, My2SqlResult]
sys_config: SysConfig = None
# init false, class property, 不是 instance property
name: str = field(init=False, default="base")
sys_config_key: str = field(init=False, default="")
event_type: EventType = EventType.AUDIT
audit: WorkflowAudit = None
audit_detail: WorkflowAuditDetail = None

def __post_init__(self):
if not self.workflow:
if not self.audit:
raise ValueError("需要提供 WorkflowAudit 或 workflow")
self.workflow = self.audit.get_workflow()
# 防止 get_auditor 显式的传了个 None
if not self.sys_config:
self.sys_config = SysConfig()

def render(self):
raise NotImplementedError
Expand All @@ -91,12 +93,9 @@ def run(self):


class GenericWebhookNotifier(Notifier):
name = "generic_webhook"
name: str = "generic_webhook"
sys_config_key: str = "generic_webhook_url"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.request_data = None
request_data: dict = None

def render(self):
self.request_data = {}
Expand Down Expand Up @@ -133,13 +132,9 @@ class LegacyMessage:
msg_cc: List[Users] = field(default_factory=list)


@dataclass
class LegacyRender(Notifier):
messages: List[LegacyMessage]
sys_config_key: str = ""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.messages = []
messages: List[LegacyMessage] = field(default_factory=list)

def render_audit(self):
# 获取审核信息
Expand Down Expand Up @@ -476,11 +471,6 @@ def auto_notify(
加载所有的 notifier, 调用 notifier 的 render 和 send 方法
内部方法, 有数据库查询, 为了方便测试, 请勿使用 async_task 调用, 防止 patch 后调用失败
"""
if not workflow and event_type == EventType.AUDIT:
if audit.workflow_type == 1:
workflow = QueryPrivilegesApply.objects.get(apply_id=audit.workflow_id)
if audit.workflow_type == 2:
workflow = SqlWorkflow.objects.get(id=audit.workflow_id)
for notifier in settings.ENABLED_NOTIFIERS:
file, _class = notifier.split(":")
try:
Expand Down
26 changes: 25 additions & 1 deletion sql/test_notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_base_notifier(self):
@patch("sql.notify.FeishuWebhookNotifier.run")
def test_auto_notify(self, mock_run):
with self.settings(ENABLED_NOTIFIERS=("sql.notify:FeishuWebhookNotifier",)):
auto_notify(self.sys_config, event_type=EventType.EXECUTE)
auto_notify(self.sys_config, event_type=EventType.EXECUTE, workflow=self.wf)
mock_run.assert_called_once()

@patch("sql.notify.auto_notify")
Expand Down Expand Up @@ -280,6 +280,17 @@ def test_legacy_render_audit(self):
notifier.render()
self.assertEqual(len(notifier.messages), 1)
self.assertIn("新的工单申请", notifier.messages[0].msg_title)
# 测试一下不传 workflow
notifier = LegacyRender(
event_type=EventType.AUDIT,
workflow=None,
audit=self.audit_wf,
audit_detail=self.audit_wf_detail,
sys_config=self.sys_config,
)
notifier.render()
self.assertEqual(len(notifier.messages), 1)
self.assertIn("新的工单申请", notifier.messages[0].msg_title)

def test_legacy_render_query_audit(self):
# 默认是库权限的
Expand Down Expand Up @@ -494,10 +505,13 @@ def tearDownClass(cls):
def setUp(self):
self.patcher = patch("sql.notify.MsgSender")
self.mock_msg_sender = self.patcher.start()
self.get_workflow_patcher = patch("sql.models.WorkflowAudit.get_workflow")
self.mock_get_workflow = self.get_workflow_patcher.start()
self.sys_config = SysConfig()

def tearDown(self):
self.patcher.stop()
self.get_workflow_patcher.stop()

def generate_notifier(self, module) -> Notifier:
return module(workflow=None, audit=self.audit_wf, sys_config=self.sys_config)
Expand Down Expand Up @@ -561,3 +575,13 @@ def test_mail(self):
]
notifier.send()
mocker.assert_called_once()


def test_override_sys_key():
"""dataclass 的继承有时候让人有点困惑, 在这里补一个测试确认可以正常覆盖一些值"""

class OverrideNotifier(Notifier):
sys_config_key = "test"

n = OverrideNotifier(workflow="test")
assert n.sys_config_key == "test"
10 changes: 2 additions & 8 deletions sql/utils/workflow_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,8 @@ def review_info(self) -> (str, str):

def get_workflow(self):
"""尝试从 audit 中取出 workflow"""
if self.audit.workflow_type == WorkflowType.QUERY:
self.workflow = QueryPrivilegesApply.objects.get(
apply_id=self.audit.workflow_id
)
elif self.audit.workflow_type == WorkflowType.SQL_REVIEW:
self.workflow = SqlWorkflow.objects.get(id=self.audit.workflow_id)
elif self.audit.workflow_type == WorkflowType.ARCHIVE:
self.workflow = ArchiveConfig.objects.get(id=self.audit.workflow_id)
self.workflow = self.audit.get_workflow()
if self.audit.workflow_type == WorkflowType.ARCHIVE:
self.resource_group = self.audit.group_name
self.resource_group_id = self.audit.group_id

Expand Down
Loading