Skip to content

Commit

Permalink
fix(ingest/powerbi): reduce type cast usage (#12004)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored Dec 6, 2024
1 parent ea9eaf4 commit b495205
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 56 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from abc import ABC
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional
Expand All @@ -12,18 +11,8 @@
TRACE_POWERBI_MQUERY_PARSER = os.getenv("DATAHUB_TRACE_POWERBI_MQUERY_PARSER", False)


class AbstractIdentifierAccessor(ABC): # To pass lint
pass


# @dataclass
# class ItemSelector:
# items: Dict[str, Any]
# next: Optional[AbstractIdentifierAccessor]


@dataclass
class IdentifierAccessor(AbstractIdentifierAccessor):
class IdentifierAccessor:
"""
statement
public_order_date = Source{[Schema="public",Item="order_date"]}[Data]
Expand All @@ -40,7 +29,7 @@ class IdentifierAccessor(AbstractIdentifierAccessor):

identifier: str
items: Dict[str, Any]
next: Optional[AbstractIdentifierAccessor]
next: Optional["IdentifierAccessor"]


@dataclass
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, List, Optional, Tuple, Type, Union, cast
from typing import Dict, List, Optional, Tuple, Type, cast

from lark import Tree

Expand All @@ -22,7 +22,6 @@
)
from datahub.ingestion.source.powerbi.m_query import native_sql_parser, tree_function
from datahub.ingestion.source.powerbi.m_query.data_classes import (
AbstractIdentifierAccessor,
DataAccessFunctionDetail,
DataPlatformTable,
FunctionName,
Expand Down Expand Up @@ -412,33 +411,25 @@ def create_lineage(
)
table_detail: Dict[str, str] = {}
temp_accessor: Optional[
Union[IdentifierAccessor, AbstractIdentifierAccessor]
IdentifierAccessor
] = data_access_func_detail.identifier_accessor

while temp_accessor:
if isinstance(temp_accessor, IdentifierAccessor):
# Condition to handle databricks M-query pattern where table, schema and database all are present in
# the same invoke statement
if all(
element in temp_accessor.items
for element in ["Item", "Schema", "Catalog"]
):
table_detail["Schema"] = temp_accessor.items["Schema"]
table_detail["Table"] = temp_accessor.items["Item"]
else:
table_detail[temp_accessor.items["Kind"]] = temp_accessor.items[
"Name"
]

if temp_accessor.next is not None:
temp_accessor = temp_accessor.next
else:
break
# Condition to handle databricks M-query pattern where table, schema and database all are present in
# the same invoke statement
if all(
element in temp_accessor.items
for element in ["Item", "Schema", "Catalog"]
):
table_detail["Schema"] = temp_accessor.items["Schema"]
table_detail["Table"] = temp_accessor.items["Item"]
else:
logger.debug(
"expecting instance to be IdentifierAccessor, please check if parsing is done properly"
)
return Lineage.empty()
table_detail[temp_accessor.items["Kind"]] = temp_accessor.items["Name"]

if temp_accessor.next is not None:
temp_accessor = temp_accessor.next
else:
break

table_reference = self.create_reference_table(
arg_list=data_access_func_detail.arg_list,
Expand Down Expand Up @@ -786,9 +777,10 @@ def get_db_name(self, data_access_tokens: List[str]) -> Optional[str]:
def create_lineage(
self, data_access_func_detail: DataAccessFunctionDetail
) -> Lineage:
t1: Tree = cast(
Tree, tree_function.first_arg_list_func(data_access_func_detail.arg_list)
t1: Optional[Tree] = tree_function.first_arg_list_func(
data_access_func_detail.arg_list
)
assert t1 is not None
flat_argument_list: List[Tree] = tree_function.flat_argument_list(t1)

if len(flat_argument_list) != 2:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union

from lark import Tree

Expand Down Expand Up @@ -95,14 +95,12 @@ def get_item_selector_tokens(
# remove whitespaces and quotes from token
tokens: List[str] = tree_function.strip_char_from_list(
tree_function.remove_whitespaces_from_list(
tree_function.token_values(
cast(Tree, item_selector), parameters=self.parameters
)
tree_function.token_values(item_selector, parameters=self.parameters)
),
)
identifier: List[str] = tree_function.token_values(
cast(Tree, identifier_tree)
) # type :ignore
identifier_tree, parameters={}
)

# convert tokens to dict
iterator = iter(tokens)
Expand Down Expand Up @@ -238,10 +236,10 @@ def _process_invoke_expression(
def _process_item_selector_expression(
self, rh_tree: Tree
) -> Tuple[Optional[str], Optional[Dict[str, str]]]:
new_identifier, key_vs_value = self.get_item_selector_tokens( # type: ignore
cast(Tree, tree_function.first_expression_func(rh_tree))
)
first_expression: Optional[Tree] = tree_function.first_expression_func(rh_tree)
assert first_expression is not None

new_identifier, key_vs_value = self.get_item_selector_tokens(first_expression)
return new_identifier, key_vs_value

@staticmethod
Expand Down Expand Up @@ -327,7 +325,7 @@ def internal(
# The first argument can be a single table argument or list of table.
# For example Table.Combine({t1,t2},....), here first argument is list of table.
# Table.AddColumn(t1,....), here first argument is single table.
for token in cast(List[str], result):
for token in result:
internal(token, identifier_accessor)

else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from functools import partial
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Optional, Union

from lark import Token, Tree

Expand Down Expand Up @@ -58,7 +58,7 @@ def internal(node: Union[Tree, Token]) -> Optional[Tree]:
if isinstance(node, Token):
return None

for child in cast(Tree, node).children:
for child in node.children:
child_node: Optional[Tree] = internal(child)
if child_node is not None:
return child_node
Expand Down Expand Up @@ -99,7 +99,7 @@ def internal(node: Union[Tree, Token]) -> None:
logger.debug(f"Unable to resolve parameter reference to {ref}")
values.append(ref)
elif isinstance(node, Token):
values.append(cast(Token, node).value)
values.append(node.value)
return
else:
for child in node.children:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, cast
from typing import Dict, List, Optional, Set

import pydantic
from pydantic import Field, SecretStr, root_validator, validator
Expand Down Expand Up @@ -118,9 +118,10 @@ def validate_legacy_schema_pattern(cls, values: Dict) -> Dict:
)

# Always exclude reporting metadata for INFORMATION_SCHEMA schema
if schema_pattern is not None and schema_pattern:
if schema_pattern:
logger.debug("Adding deny for INFORMATION_SCHEMA to schema_pattern.")
cast(AllowDenyPattern, schema_pattern).deny.append(r".*INFORMATION_SCHEMA$")
assert isinstance(schema_pattern, AllowDenyPattern)
schema_pattern.deny.append(r".*INFORMATION_SCHEMA$")

return values

Expand Down

0 comments on commit b495205

Please sign in to comment.