Skip to content

Commit

Permalink
pr feedback: refactor code and extract TableauSiteSource into its own…
Browse files Browse the repository at this point in the history
… class
  • Loading branch information
Yanik Häni authored and Yanik Häni committed Jul 2, 2024
1 parent c85d387 commit fc3a05b
Showing 1 changed file with 127 additions and 117 deletions.
244 changes: 127 additions & 117 deletions metadata-ingestion/src/datahub/ingestion/source/tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ class TableauConnectionConfig(ConfigModel):
def remove_trailing_slash(cls, v):
return config_clean.remove_trailing_slashes(v)

def make_tableau_client(self, site: str) -> Server:
def get_tableau_auth(self, site: str) -> TableauAuth:
# https://tableau.github.io/server-client-python/docs/api-ref#authentication
authentication: Union[TableauAuth, PersonalAccessTokenAuth]
if self.username and self.password:
Expand All @@ -224,7 +224,10 @@ def make_tableau_client(self, site: str) -> Server:
raise ConfigurationError(
"Tableau Source: Either username/password or token_name/token_value must be set"
)
return authentication

def make_tableau_client(self, site: str) -> Server:
authentication: Union[TableauAuth, PersonalAccessTokenAuth] = self.get_tableau_auth(site)
try:
server = Server(
self.connect_uri,
Expand Down Expand Up @@ -548,6 +551,102 @@ class TableauSourceReport(StaleEntityRemovalSourceReport):
"Enabled by default, configure using `extract_column_level_lineage`",
)
class TableauSource(StatefulIngestionSourceBase, TestableSource):
def __init__(
self,
config: TableauConfig,
ctx: PipelineContext,
):
super().__init__(config, ctx)
self.config: TableauConfig = config
self.report: TableauSourceReport = TableauSourceReport()
self.server: Optional[Server] = None
self._authenticate(self.config.site)


def _authenticate(self, site_content_url) -> Server:
try:
logger.info(f"Authenticated to Tableau site: '{site_content_url}'")
self.server = self.config.make_tableau_client(site_content_url)
# Note that we're not catching ConfigurationError, since we want that to throw.
except ValueError as e:
self.report.failure(
key="tableau-login",
reason=str(e),
)
@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
try:
source_config = TableauConfig.parse_obj_allow_extras(config_dict)
source_config.make_tableau_client(source_config.site)
test_report.basic_connectivity = CapabilityReport(capable=True)
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=str(e)
)
return test_report


def get_report(self) -> TableauSourceReport:
return self.report


@classmethod
def create(cls, config_dict: dict, ctx: PipelineContext) -> Source:
config = TableauConfig.parse_obj(config_dict)
return cls(config, ctx)


def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
return [
*super().get_workunit_processors(),
StaleEntityRemovalHandler.create(
self, self.config, self.ctx
).workunit_processor,
]

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
if self.server is None or not self.server.is_signed_in():
return
try:
if self.config.ingest_multiple_sites:
for site in TSC.Pager(self.server.sites):
if (
site.state != "Active"
or not self.config.site_name_pattern.allowed(site.name)
):
logger.info(f"Skip site '{site.name}' as it's excluded in site_name_pattern or inactive.")
continue
self.server.auth.switch_site(site)
site_source = TableauSiteSource(config=self.config, ctx=self.ctx, site=site, report=self.report, server=self.server)
logger.info(f"Ingesting assets of site '{site.content_url}'.")
yield from site_source.ingest_tableau_site()
else:
site = self.server.sites.get_by_id(self.server.site_id)
site_source = TableauSiteSource(config=self.config, ctx=self.ctx, site=site, report=self.report, server=self.server)
yield from site_source.ingest_tableau_site()
except MetadataQueryException as md_exception:
self.report.failure(
key="tableau-metadata",
reason=f"Unable to retrieve metadata from tableau. Information: {str(md_exception)}",
)


def close(self) -> None:
try:
if self.server is not None:
self.server.auth.sign_out()
except Exception as ex:
logger.warning(
"During graceful closing of Tableau source a sign-out call was tried but ended up with"
" an Exception (%s). Continuing closing of the source",
ex,
)
self.server = None
super().close()


class TableauSiteSource:
platform = "tableau"

def __hash__(self):
Expand All @@ -557,18 +656,21 @@ def __init__(
self,
config: TableauConfig,
ctx: PipelineContext,
site: SiteItem,
report: TableauSourceReport,
server: Server
):
super().__init__(config, ctx)

self.config: TableauConfig = config
self.report: TableauSourceReport = TableauSourceReport()
self.server: Optional[Server] = None
self.report = report
self.server: Server = server
self.ctx: PipelineContext = ctx
self.site: SiteItem = site

self.database_tables: Dict[str, DatabaseTable] = {}
self.tableau_stat_registry: Dict[str, UsageStat] = {}
self.tableau_project_registry: Dict[str, TableauProject] = {}
self.workbook_project_map: Dict[str, str] = {}
self.datasource_project_map: Dict[str, str] = {}
self.current_site: Optional[SiteItem] = None

# This map keeps track of the database server connection hostnames.
self.database_server_hostname_map: Dict[str, str] = {}
Expand All @@ -588,34 +690,6 @@ def __init__(
# when emitting custom SQL data sources.
self.custom_sql_ids_being_used: List[str] = []

self._authenticate()

@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
try:
source_config = TableauConfig.parse_obj_allow_extras(config_dict)
source_config.make_tableau_client(source_config.site)
test_report.basic_connectivity = CapabilityReport(capable=True)
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=str(e)
)
return test_report

def close(self) -> None:
try:
if self.server is not None:
self.server.auth.sign_out()
except Exception as ex:
logger.warning(
"During graceful closing of Tableau source a sign-out call was tried but ended up with"
" an Exception (%s). Continuing closing of the source",
ex,
)
self.server = None
super().close()

@property
def no_env_browse_prefix(self) -> str:
# Prefix to use with browse path (v1)
Expand All @@ -631,8 +705,8 @@ def no_env_browse_prefix(self) -> str:
@property
def site_name_browse_path(self) -> str:
site_name_prefix = (
self.current_site.name
if self.current_site and self.config.add_site_container
self.site.name
if self.site and self.config.add_site_container
else ""
)
return f"/{site_name_prefix}" if site_name_prefix else ""
Expand All @@ -642,18 +716,10 @@ def dataset_browse_prefix(self) -> str:
# datasets also have the env in the browse path
return f"/{self.config.env.lower()}{self.no_env_browse_prefix}"

def _init_ingestion_variables(self):
# Reset / initialize all ingestion variables.
self.tableau_stat_registry = {}
self.tableau_project_registry = {}
self.datasource_project_map = {}
self.database_tables = {}
self.workbook_project_map = {}
self.sheet_ids = []
self.dashboard_ids = []
self.embedded_datasource_ids_being_used = []
self.datasource_ids_being_used = []
self.custom_sql_ids_being_used = []

def _re_authenticate(self):
tableau_auth: Union[TableauAuth, PersonalAccessTokenAuth] = self.config.get_tableau_auth(self.site.content_url)
self.server.auth.sign_in(tableau_auth)

def _populate_usage_stat_registry(self) -> None:
if self.server is None:
Expand Down Expand Up @@ -841,20 +907,6 @@ def _populate_projects_registry(self) -> None:
f"Tableau workbooks {self.workbook_project_map}",
)

def _authenticate(self) -> None:
try:
site_content_url = (
self.current_site.content_url if self.current_site else self.config.site
)
self.server = self.config.make_tableau_client(site_content_url)
logger.info(f"Authenticated to Tableau site: '{site_content_url}'")
# Note that we're not catching ConfigurationError, since we want that to throw.
except ValueError as e:
self.report.failure(
key="tableau-login",
reason=str(e),
)

def get_data_platform_instance(self) -> DataPlatformInstanceClass:
return DataPlatformInstanceClass(
platform=builder.make_data_platform_urn(self.platform),
Expand Down Expand Up @@ -895,7 +947,7 @@ def get_connection_object_page(
# If ingestion has been running for over 2 hours, the Tableau
# temporary credentials will expire. If this happens, this exception
# will be thrown and we need to re-authenticate and retry.
self._authenticate()
self._re_authenticate()
return self.get_connection_object_page(
query,
connection_type,
Expand Down Expand Up @@ -1474,6 +1526,7 @@ def emit_custom_sql_datasources(self) -> Iterable[MetadataWorkUnit]:

datasource_name = None
project = None
columns = []
if len(csql[c.DATA_SOURCES]) > 0:
# CustomSQLTable id owned by exactly one tableau data source
logger.debug(
Expand Down Expand Up @@ -2366,7 +2419,7 @@ def emit_sheets_as_charts(

if sheet.get(c.PATH):
site_part = (
f"/site/{self.current_site.content_url}" if self.current_site else ""
f"/site/{self.site.content_url}" if self.site else ""
)
sheet_external_url = (
f"{self.config.connect_uri}/#{site_part}/views/{sheet.get(c.PATH)}"
Expand All @@ -2379,7 +2432,7 @@ def emit_sheets_as_charts(
):
# sheet contained in dashboard
site_part = (
f"/t/{self.current_site.content_url}" if self.current_site else ""
f"/t/{self.site.content_url}" if self.site else ""
)
dashboard_path = sheet[c.CONTAINED_IN_DASHBOARDS][0][c.PATH]
sheet_external_url = f"{self.config.connect_uri}{site_part}/authoring/{dashboard_path}/{sheet.get(c.NAME, '')}"
Expand Down Expand Up @@ -2513,7 +2566,7 @@ def emit_workbook_as_container(self, workbook: Dict) -> Iterable[MetadataWorkUni
)

site_part = (
f"/site/{self.current_site.content_url}" if self.current_site else ""
f"/site/{self.site.content_url}" if self.site else ""
)
workbook_uri = workbook.get("uri")
workbook_part = (
Expand Down Expand Up @@ -2674,7 +2727,7 @@ def emit_dashboard(
last_modified = self.get_last_modified(creator, created_at, updated_at)

site_part = (
f"/site/{self.current_site.content_url}" if self.current_site else ""
f"/site/{self.site.content_url}" if self.site else ""
)
dashboard_external_url = (
f"{self.config.connect_uri}/#{site_part}/views/{dashboard.get(c.PATH, '')}"
Expand Down Expand Up @@ -2818,22 +2871,17 @@ def _get_ownership(self, user: str) -> Optional[OwnershipClass]:

return None

@classmethod
def create(cls, config_dict: dict, ctx: PipelineContext) -> Source:
config = TableauConfig.parse_obj(config_dict)
return cls(config, ctx)

def emit_project_containers(self) -> Iterable[MetadataWorkUnit]:
for _id, project in self.tableau_project_registry.items():
parent_container_key: Optional[ContainerKey] = None
if project.parent_id:
parent_container_key = self.gen_project_key(project.parent_id)
elif (
self.config.add_site_container
and self.current_site
and self.current_site.id
and self.site
and self.site.id
):
parent_container_key = self.gen_site_key(self.current_site.id)
parent_container_key = self.gen_site_key(self.site.id)

yield from gen_containers(
container_key=self.gen_project_key(_id),
Expand All @@ -2856,52 +2904,17 @@ def emit_project_containers(self) -> Iterable[MetadataWorkUnit]:
)

def emit_site_container(self):
if not self.current_site or not self.current_site.id:
if not self.site or not self.site.id:
logger.warning("Can not ingest site container. No site information found.")
return

yield from gen_containers(
container_key=self.gen_site_key(self.current_site.id),
name=self.current_site.name or "Default",
container_key=self.gen_site_key(self.site.id),
name=self.site.name or "Default",
sub_types=[c.SITE],
)

def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
return [
*super().get_workunit_processors(),
StaleEntityRemovalHandler.create(
self, self.config, self.ctx
).workunit_processor,
]

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
if self.server is None or not self.server.is_signed_in():
return
try:
if self.config.ingest_multiple_sites:
for site in TSC.Pager(self.server.sites):
if (
site.state != "Active"
or not self.config.site_name_pattern.allowed(site.name)
):
logger.info(f"Skip site '{site.name}' as it's excluded in site_name_pattern or inactive.")
continue
self.current_site = site
self.server.auth.switch_site(site)
logger.info(f"Ingesting assets of site '{site.content_url}'.")
yield from self._ingest_tableau_site()
else:
site = self.server.sites.get_by_id(self.server.site_id)
self.current_site = site
yield from self._ingest_tableau_site()
except MetadataQueryException as md_exception:
self.report.failure(
key="tableau-metadata",
reason=f"Unable to retrieve metadata from tableau. Information: {str(md_exception)}",
)

def _ingest_tableau_site(self):
self._init_ingestion_variables()
def ingest_tableau_site(self):
# Initialise the dictionary to later look-up for chart and dashboard stat
if self.config.extract_usage_stats:
self._populate_usage_stat_registry()
Expand Down Expand Up @@ -2929,6 +2942,3 @@ def _ingest_tableau_site(self):
yield from self.emit_custom_sql_datasources()
if self.database_tables:
yield from self.emit_upstream_tables()

def get_report(self) -> TableauSourceReport:
return self.report

0 comments on commit fc3a05b

Please sign in to comment.