Skip to content

Commit

Permalink
fix: Map BigQuery policy tags to datahub column-level tags
Browse files Browse the repository at this point in the history
  • Loading branch information
sagar-salvi-apptware committed Jun 10, 2024
1 parent 812bcbb commit 2db4d41
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 23 deletions.
1 change: 1 addition & 0 deletions metadata-ingestion/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@
# Google cloud logging library
"google-cloud-logging<=3.5.0",
"google-cloud-bigquery",
"google-cloud-datacatalog>=1.5.0",
"more-itertools>=8.12.0",
"sqlalchemy-bigquery>=1.4.1",
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ def __init__(self, ctx: PipelineContext, config: BigQueryV2Config):
BigqueryTableIdentifier._BQ_SHARDED_TABLE_SUFFIX = ""

self.bigquery_data_dictionary = BigQuerySchemaApi(
self.report.schema_api_perf, self.config.get_bigquery_client()
self.report.schema_api_perf,
self.config.get_bigquery_client(),
self.config.get_policy_tag_manager_client(),
)
self.sql_parser_schema_resolver = self._init_schema_resolver()

Expand Down Expand Up @@ -1275,6 +1277,9 @@ def gen_schema_fields(self, columns: List[BigqueryColumn]) -> List[SchemaField]:
)
)

if col.policy_tags:
for policy_tag in col.policy_tags:
tags.append(TagAssociationClass(make_tag_urn(policy_tag)))
field = SchemaField(
fieldPath=col.name,
type=SchemaFieldDataType(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import timedelta
from typing import Any, Dict, List, Optional, Union

from google.cloud import bigquery
from google.cloud import bigquery, datacatalog_v1
from google.cloud.logging_v2.client import Client as GCPLoggingClient
from pydantic import Field, PositiveInt, PrivateAttr, root_validator, validator

Expand Down Expand Up @@ -70,6 +70,9 @@ def get_bigquery_client(self) -> bigquery.Client:
client_options = self.extra_client_options
return bigquery.Client(self.project_on_behalf, **client_options)

def get_policy_tag_manager_client(self) -> datacatalog_v1.PolicyTagManagerClient:
return datacatalog_v1.PolicyTagManagerClient()

def make_gcp_logging_client(
self, project_id: Optional[str] = None
) -> GCPLoggingClient:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime, timezone
from typing import Any, Dict, Iterator, List, Optional

from google.cloud import bigquery
from google.cloud import bigquery, datacatalog_v1
from google.cloud.bigquery.table import (
RowIterator,
TableListItem,
Expand All @@ -31,6 +31,7 @@ class BigqueryColumn(BaseColumn):
field_path: str
is_partition_column: bool
cluster_column_position: Optional[int]
policy_tags: Optional[List[str]] = None


RANGE_PARTITION_NAME: str = "RANGE"
Expand Down Expand Up @@ -137,10 +138,14 @@ class BigqueryProject:

class BigQuerySchemaApi:
def __init__(
self, report: BigQuerySchemaApiPerfReport, client: bigquery.Client
self,
report: BigQuerySchemaApiPerfReport,
client: bigquery.Client,
datacatalog_client: Optional[datacatalog_v1.PolicyTagManagerClient] = None,
) -> None:
self.bq_client = client
self.report = report
self.datacatalog_client = datacatalog_client

def get_query_result(self, query: str) -> RowIterator:
logger.debug(f"Query : {query}")
Expand Down Expand Up @@ -347,6 +352,28 @@ def _make_bigquery_view(view: bigquery.Row) -> BigqueryView:
rows_count=view.get("row_count"),
)

def get_policy_tags_for_column(
self, project_id: str, dataset_name: str, table_name: str, column_name: str
) -> List[str]:
assert self.datacatalog_client
# Get the table schema
table_ref = f"{project_id}.{dataset_name}.{table_name}"
table = self.bq_client.get_table(table_ref)
schema = table.schema

# Find the specific field in the schema
field = next((f for f in schema if f.name == column_name), None)
if not field or not field.policy_tags:
return []

# Retrieve policy tag display names
policy_tag_display_names = [
self.datacatalog_client.get_policy_tag(name=policy_tag_name).display_name
for policy_tag_name in field.policy_tags.names
]

return policy_tag_display_names

def get_columns_for_dataset(
self,
project_id: str,
Expand Down Expand Up @@ -387,6 +414,9 @@ def get_columns_for_dataset(
)
last_seen_table = column.table_name
else:
policy_tags = self.get_policy_tags_for_column(
project_id, dataset_name, column.table_name, column.column_name
)
columns[column.table_name].append(
BigqueryColumn(
name=column.column_name,
Expand All @@ -397,6 +427,7 @@ def get_columns_for_dataset(
comment=column.comment,
is_partition_column=column.is_partitioning_column == "YES",
cluster_column_position=column.clustering_ordinal_position,
policy_tags=policy_tags,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,11 @@
"nativeDataType": "INT",
"recursive": false,
"globalTags": {
"tags": []
"tags": [
{
"tag": "urn:li:tag:Test Policy Tag"
}
]
},
"glossaryTerms": {
"terms": [
Expand Down Expand Up @@ -428,5 +432,21 @@
"runId": "bigquery-2022_02_03-07_00_00",
"lastRunId": "no-run-id-provided"
}
},
{
"entityType": "tag",
"entityUrn": "urn:li:tag:Test Policy Tag",
"changeType": "UPSERT",
"aspectName": "tagKey",
"aspect": {
"json": {
"name": "Test Policy Tag"
}
},
"systemMetadata": {
"lastObserved": 1643871600000,
"runId": "bigquery-2022_02_03-07_00_00",
"lastRunId": "no-run-id-provided"
}
}
]
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def random_email():
@patch.object(BigQuerySchemaApi, "get_columns_for_dataset")
@patch.object(BigQueryDataReader, "get_sample_data_for_table")
@patch("google.cloud.bigquery.Client")
@patch("google.cloud.datacatalog_v1.PolicyTagManagerClient")
def test_bigquery_v2_ingest(
client,
policy_tag_manager_client,
get_sample_data_for_table,
get_columns_for_dataset,
get_datasets_for_project_id,
Expand Down Expand Up @@ -78,6 +80,7 @@ def test_bigquery_v2_ingest(
comment="comment",
is_partition_column=False,
cluster_column_position=None,
policy_tags=["Test Policy Tag"]
),
BigqueryColumn(
name="email",
Expand Down
Loading

0 comments on commit 2db4d41

Please sign in to comment.