Simplify usage of BigQuery's TableReference and DatasetReference Classes (#98)

* drop conn method argument in _query_and_results.

* update Table Ref and Dataset Ref usage.

* update changelog.

Co-authored-by: Matthew McKnight <[email protected]>
drewmcdonald and McKnight-42 committed Jan 26, 2022
1 parent c45c8b0 commit ee7a3ae
Showing 5 changed files with 51 additions and 74 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,8 @@
+## dbt-bigquery 1.0.2 (Release TBD)
+
+### Under the hood
+- Address BigQuery API deprecation warning and simplify usage of `TableReference` and `DatasetReference` objects ([#97](https://github.com/dbt-labs/dbt-bigquery/issues/97))
+
 ## dbt-bigquery 1.0.0 (December 3, 2021)
 
 ## dbt-bigquery 1.0.0rc2 (November 24, 2021)
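For readers tracking the underlying API change: google-cloud-bigquery deprecated building references through the client handle (for example, client.dataset(...), which now emits a deprecation warning) in favor of constructing DatasetReference and TableReference objects directly. A minimal sketch of the new pattern; the project, dataset, and table names are placeholders:

    from google.cloud import bigquery

    # References are plain value objects; no client or live connection is
    # needed to build them, only to act on them.
    dataset_ref = bigquery.DatasetReference(project="my-project", dataset_id="my_schema")
    table_ref = bigquery.TableReference(dataset_ref, "my_table")

    print(table_ref.path)  # /projects/my-project/datasets/my_schema/tables/my_table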
35 changes: 16 additions & 19 deletions dbt/adapters/bigquery/connections.py
@@ -462,11 +462,13 @@ def copy_bq_table(self, source, destination, write_disposition):
         if type(source) is not list:
             source = [source]
 
-        source_ref_array = [self.table_ref(
-            src_table.database, src_table.schema, src_table.table, conn)
-            for src_table in source]
+        source_ref_array = [
+            self.table_ref(src_table.database, src_table.schema, src_table.table)
+            for src_table in source
+        ]
         destination_ref = self.table_ref(
-            destination.database, destination.schema, destination.table, conn)
+            destination.database, destination.schema, destination.table
+        )
 
         logger.debug(
             'Copying table(s) "{}" to "{}" with disposition: "{}"',
@@ -488,43 +490,38 @@ def copy_and_results():
             conn=conn, fn=copy_and_results)
 
     @staticmethod
-    def dataset(database, schema, conn):
-        dataset_ref = conn.handle.dataset(schema, database)
-        return google.cloud.bigquery.Dataset(dataset_ref)
+    def dataset_ref(database, schema):
+        return google.cloud.bigquery.DatasetReference(project=database, dataset_id=schema)
 
     @staticmethod
-    def dataset_from_id(dataset_id):
-        return google.cloud.bigquery.Dataset.from_string(dataset_id)
-
-    def table_ref(self, database, schema, table_name, conn):
-        dataset = self.dataset(database, schema, conn)
-        return dataset.table(table_name)
+    def table_ref(database, schema, table_name):
+        dataset_ref = google.cloud.bigquery.DatasetReference(database, schema)
+        return google.cloud.bigquery.TableReference(dataset_ref, table_name)
 
     def get_bq_table(self, database, schema, identifier):
         """Get a bigquery table for a schema/model."""
         conn = self.get_thread_connection()
-        table_ref = self.table_ref(database, schema, identifier, conn)
+        table_ref = self.table_ref(database, schema, identifier)
         return conn.handle.get_table(table_ref)
 
     def drop_dataset(self, database, schema):
         conn = self.get_thread_connection()
-        dataset = self.dataset(database, schema, conn)
+        dataset_ref = self.dataset_ref(database, schema)
         client = conn.handle
 
         def fn():
-            return client.delete_dataset(
-                dataset, delete_contents=True, not_found_ok=True)
+            return client.delete_dataset(dataset_ref, delete_contents=True, not_found_ok=True)
 
         self._retry_and_handle(
             msg='drop dataset', conn=conn, fn=fn)
 
     def create_dataset(self, database, schema):
        conn = self.get_thread_connection()
         client = conn.handle
-        dataset = self.dataset(database, schema, conn)
+        dataset_ref = self.dataset_ref(database, schema)
 
         def fn():
-            return client.create_dataset(dataset, exists_ok=True)
+            return client.create_dataset(dataset_ref, exists_ok=True)
         self._retry_and_handle(msg='create dataset', conn=conn, fn=fn)
 
     def _query_and_results(self, client, sql, conn, job_params, timeout=None):
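A quick sketch of how the reworked helpers compose with a client call; the client construction and all names here are illustrative, assuming default application credentials:

    import google.cloud.bigquery

    # Equivalent to the new static helpers above: references are built from
    # strings alone, and only the terminal API call needs a client.
    dataset_ref = google.cloud.bigquery.DatasetReference("my-project", "analytics")
    table_ref = google.cloud.bigquery.TableReference(dataset_ref, "orders")

    client = google.cloud.bigquery.Client(project="my-project")
    table = client.get_table(table_ref)  # raises google.api_core.exceptions.NotFound if absent
    print(table.num_rows)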
66 changes: 21 additions & 45 deletions dbt/adapters/bigquery/impl.py
@@ -150,12 +150,9 @@ def drop_relation(self, relation: BigQueryRelation) -> None:
         self.cache_dropped(relation)
 
         conn = self.connections.get_thread_connection()
-        client = conn.handle
 
-        dataset = self.connections.dataset(relation.database, relation.schema,
-                                           conn)
-        relation_object = dataset.table(relation.identifier)
-        client.delete_table(relation_object)
+        table_ref = self.get_table_ref_from_relation(relation)
+        conn.handle.delete_table(table_ref)
 
     def truncate_relation(self, relation: BigQueryRelation) -> None:
         raise dbt.exceptions.NotImplementedException(
@@ -169,10 +166,7 @@ def rename_relation(
         conn = self.connections.get_thread_connection()
         client = conn.handle
 
-        from_table_ref = self.connections.table_ref(from_relation.database,
-                                                    from_relation.schema,
-                                                    from_relation.identifier,
-                                                    conn)
+        from_table_ref = self.get_table_ref_from_relation(from_relation)
         from_table = client.get_table(from_table_ref)
         if from_table.table_type == "VIEW" or \
                 from_relation.type == RelationType.View or \
@@ -181,10 +175,7 @@ def rename_relation(
                 'Renaming of views is not currently supported in BigQuery'
             )
 
-        to_table_ref = self.connections.table_ref(to_relation.database,
-                                                  to_relation.schema,
-                                                  to_relation.identifier,
-                                                  conn)
+        to_table_ref = self.get_table_ref_from_relation(to_relation)
 
         self.cache_renamed(from_relation, to_relation)
         client.copy_table(from_table_ref, to_table_ref)
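BigQuery has no in-place rename, so the adapter implements rename_relation as a server-side copy, presumably followed by dropping the source table in the collapsed remainder of the diff. A hedged sketch of the underlying client calls, with placeholder names:

    from google.cloud import bigquery

    client = bigquery.Client(project="my-project")  # illustrative client
    dataset_ref = bigquery.DatasetReference("my-project", "analytics")
    old_ref = bigquery.TableReference(dataset_ref, "orders__tmp")
    new_ref = bigquery.TableReference(dataset_ref, "orders")

    # "Rename" = copy to the new name, wait for the job, then drop the source.
    client.copy_table(old_ref, new_ref).result()
    client.delete_table(old_ref)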
@@ -212,15 +203,13 @@ def check_schema_exists(self, database: str, schema: str) -> bool:
         conn = self.connections.get_thread_connection()
         client = conn.handle
 
-        bigquery_dataset = self.connections.dataset(
-            database, schema, conn
-        )
+        dataset_ref = self.connections.dataset_ref(database, schema)
         # try to do things with the dataset. If it doesn't exist it will 404.
         # we have to do it this way to handle underscore-prefixed datasets,
         # which appear in neither the information_schema.schemata view nor the
         # list_datasets method.
         try:
-            next(iter(client.list_tables(bigquery_dataset, max_results=1)))
+            next(iter(client.list_tables(dataset_ref, max_results=1)))
         except StopIteration:
             pass
         except google.api_core.exceptions.NotFound:
@@ -262,12 +251,12 @@ def list_relations_without_caching(
         connection = self.connections.get_thread_connection()
         client = connection.handle
 
-        bigquery_dataset = self.connections.dataset(
-            schema_relation.database, schema_relation.schema, connection
+        dataset_ref = self.connections.dataset_ref(
+            schema_relation.database, schema_relation.schema
         )
 
         all_tables = client.list_tables(
-            bigquery_dataset,
+            dataset_ref,
             # BigQuery paginates tables by alphabetizing them, and using
             # the name of the last table on a page as the key for the
             # next page. If that key table gets dropped before we run
"""
return PartitionConfig.parse(raw_partition_by)

def get_table_ref_from_relation(self, conn, relation):
return self.connections.table_ref(relation.database,
relation.schema,
relation.identifier,
conn)
def get_table_ref_from_relation(self, relation):
return self.connections.table_ref(
relation.database, relation.schema, relation.identifier
)

def _update_column_dict(self, bq_column_dict, dbt_columns, parent=''):
"""
@@ -629,7 +617,7 @@ def update_columns(self, relation, columns):
             return
 
         conn = self.connections.get_thread_connection()
-        table_ref = self.get_table_ref_from_relation(conn, relation)
+        table_ref = self.get_table_ref_from_relation(relation)
         table = conn.handle.get_table(table_ref)
 
         new_schema = []
@@ -651,12 +639,7 @@ def update_table_description(
         conn = self.connections.get_thread_connection()
         client = conn.handle
 
-        table_ref = self.connections.table_ref(
-            database,
-            schema,
-            identifier,
-            conn
-        )
+        table_ref = self.connections.table_ref(database, schema, identifier)
         table = client.get_table(table_ref)
         table.description = description
         client.update_table(table, ['description'])
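The hunk above follows the standard google-cloud-bigquery read-modify-write idiom for metadata: fetch the table, mutate one attribute locally, then send only that field name in the update mask. A minimal sketch with placeholder names:

    from google.cloud import bigquery

    client = bigquery.Client(project="my-project")  # illustrative
    dataset_ref = bigquery.DatasetReference("my-project", "analytics")
    table = client.get_table(bigquery.TableReference(dataset_ref, "orders"))

    table.description = "Cleaned orders model"    # mutate locally
    client.update_table(table, ["description"])   # PATCH only the named field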
@@ -670,9 +653,7 @@ def alter_table_add_columns(self, relation, columns):
         conn = self.connections.get_thread_connection()
         client = conn.handle
 
-        table_ref = self.connections.table_ref(relation.database,
-                                               relation.schema,
-                                               relation.identifier, conn)
+        table_ref = self.get_table_ref_from_relation(relation)
         table = client.get_table(table_ref)
 
         new_columns = [col.column_to_bq_schema() for col in columns]
@@ -688,14 +669,14 @@ def load_dataframe(self, database, schema, table_name, agate_table,
         conn = self.connections.get_thread_connection()
         client = conn.handle
 
-        table = self.connections.table_ref(database, schema, table_name, conn)
+        table_ref = self.connections.table_ref(database, schema, table_name)
 
         load_config = google.cloud.bigquery.LoadJobConfig()
         load_config.skip_leading_rows = 1
         load_config.schema = bq_schema
 
         with open(agate_table.original_abspath, "rb") as f:
-            job = client.load_table_from_file(f, table, rewind=True,
+            job = client.load_table_from_file(f, table_ref, rewind=True,
                                               job_config=load_config)
 
         timeout = self.connections.get_timeout(conn)
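For context, load_dataframe streams dbt's agate seed file into BigQuery as a CSV load job, and skip_leading_rows=1 skips the header row the file starts with. A sketch under assumed names; the schema and file path are placeholders:

    from google.cloud import bigquery

    client = bigquery.Client(project="my-project")  # illustrative
    dataset_ref = bigquery.DatasetReference("my-project", "analytics")
    table_ref = bigquery.TableReference(dataset_ref, "seed_data")

    load_config = bigquery.LoadJobConfig()
    load_config.skip_leading_rows = 1  # first row of the file is a header
    load_config.schema = [
        bigquery.SchemaField("id", "INTEGER"),
        bigquery.SchemaField("name", "STRING"),
    ]

    with open("/tmp/seed_data.csv", "rb") as f:
        job = client.load_table_from_file(f, table_ref, rewind=True,
                                          job_config=load_config)
    job.result(timeout=300)  # block until the load job finishes (or times out)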
@@ -793,16 +774,11 @@ def grant_access_to(self, entity, entity_type, role, grant_target_dict):
 
         GrantTarget.validate(grant_target_dict)
         grant_target = GrantTarget.from_dict(grant_target_dict)
-        dataset = client.get_dataset(
-            self.connections.dataset_from_id(grant_target.render())
-        )
+        dataset_ref = self.connections.dataset_ref(grant_target.project, grant_target.dataset)
+        dataset = client.get_dataset(dataset_ref)
 
         if entity_type == 'view':
-            entity = self.connections.table_ref(
-                entity.database,
-                entity.schema,
-                entity.identifier,
-                conn).to_api_repr()
+            entity = self.get_table_ref_from_relation(entity).to_api_repr()
 
         access_entry = AccessEntry(role, entity_type, entity)
         access_entries = dataset.access_entries
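grant_access_to mutates the dataset's ACL: it appends an AccessEntry and, in the collapsed part of the diff, patches access_entries back to the API. A hedged sketch of the pattern for authorizing a view; names are illustrative, and note that for entity_type="view" the role must be None:

    from google.cloud import bigquery

    client = bigquery.Client(project="my-project")  # illustrative
    dataset = client.get_dataset(
        bigquery.DatasetReference("my-project", "source_data"))
    view_ref = bigquery.TableReference(
        bigquery.DatasetReference("my-project", "analytics"), "orders_view")

    entry = bigquery.AccessEntry(role=None, entity_type="view",
                                 entity=view_ref.to_api_repr())
    if entry not in dataset.access_entries:
        dataset.access_entries = dataset.access_entries + [entry]
        client.update_dataset(dataset, ["access_entries"])  # PATCH the ACL only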
5 changes: 3 additions & 2 deletions
@@ -92,9 +92,10 @@ def test_bigquery_adapter_methods(self):
         client = conn.handle
 
         grant_target = GrantTarget.from_dict(ae_grant_target_dict)
-        dataset = client.get_dataset(
-            self.adapter.connections.dataset_from_id(grant_target.render())
+        dataset_ref = self.adapter.connections.dataset_ref(
+            grant_target.project, grant_target.dataset
         )
+        dataset = client.get_dataset(dataset_ref)
 
         expected_access_entry = AccessEntry(ae_role, ae_entity_type, ae_entity)
         self.assertTrue(expected_access_entry in dataset.access_entries)
14 changes: 6 additions & 8 deletions tests/unit/test_bigquery_adapter.py
@@ -615,8 +615,8 @@ def test_copy_bq_table_appends(self):
             write_disposition=dbt.adapters.bigquery.impl.WRITE_APPEND)
         args, kwargs = self.mock_client.copy_table.call_args
         self.mock_client.copy_table.assert_called_once_with(
-            [self._table_ref('project', 'dataset', 'table1', None)],
-            self._table_ref('project', 'dataset', 'table2', None),
+            [self._table_ref('project', 'dataset', 'table1')],
+            self._table_ref('project', 'dataset', 'table2'),
             job_config=ANY)
         args, kwargs = self.mock_client.copy_table.call_args
         self.assertEqual(
@@ -628,8 +628,8 @@ def test_copy_bq_table_truncates(self):
             write_disposition=dbt.adapters.bigquery.impl.WRITE_TRUNCATE)
         args, kwargs = self.mock_client.copy_table.call_args
         self.mock_client.copy_table.assert_called_once_with(
-            [self._table_ref('project', 'dataset', 'table1', None)],
-            self._table_ref('project', 'dataset', 'table2', None),
+            [self._table_ref('project', 'dataset', 'table1')],
+            self._table_ref('project', 'dataset', 'table2'),
             job_config=ANY)
         args, kwargs = self.mock_client.copy_table.call_args
         self.assertEqual(
@@ -645,12 +645,10 @@ def test_job_labels_invalid_json(self):
         labels = self.connections._labels_from_query_comment("not json")
         self.assertEqual(labels, {"query_comment": "not_json"})
 
-    def _table_ref(self, proj, ds, table, conn):
-        return google.cloud.bigquery.table.TableReference.from_string(
-            '{}.{}.{}'.format(proj, ds, table))
+    def _table_ref(self, proj, ds, table):
+        return self.connections.table_ref(proj, ds, table)
 
     def _copy_table(self, write_disposition):
-        self.connections.table_ref = self._table_ref
         source = BigQueryRelation.create(
             database='project', schema='dataset', identifier='table1')
         destination = BigQueryRelation.create(
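One property worth noting about the updated _table_ref test helper: TableReference compares by its (project, dataset_id, table_id) key, so a reference built through the adapter's helper should compare equal to one parsed from a dotted string, which is what keeps the assertions above valid without monkeypatching. A small sketch:

    import google.cloud.bigquery as bigquery

    built = bigquery.TableReference(
        bigquery.DatasetReference("project", "dataset"), "table1")
    parsed = bigquery.TableReference.from_string("project.dataset.table1")

    assert built == parsed  # references are value objects; equality is by path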
