Migrate to dbt-adapter and common #1071

Merged · 8 commits · Jan 22, 2024
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240116-154305.yaml
@@ -0,0 +1,6 @@
+kind: Under the Hood
+body: Migrate to dbt-common and dbt-adapters package
+time: 2024-01-16T15:43:05.046735-08:00
+custom:
+  Author: colin-rogers-dbt
+  Issue: "1071"
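
This changelog entry follows the changie-style convention used in this repo: each unreleased change lives in its own small YAML file under `.changes/unreleased/` with a kind, a body, a timestamp, and custom metadata. A minimal sketch of reading such an entry programmatically (illustrative only, assuming PyYAML is available; not part of this PR):

```python
# Illustrative only: load the changelog entry added above and inspect its fields.
import yaml  # assumes PyYAML is installed

with open(".changes/unreleased/Under the Hood-20240116-154305.yaml") as handle:
    entry = yaml.safe_load(handle)

print(entry["kind"])             # "Under the Hood"
print(entry["body"])             # "Migrate to dbt-common and dbt-adapters package"
print(entry["custom"]["Issue"])  # "1071"
```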
18 changes: 9 additions & 9 deletions dbt/adapters/bigquery/connections.py
@@ -5,9 +5,9 @@
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 
-from dbt.common.invocation import get_invocation_id
+from dbt_common.invocation import get_invocation_id
 
-from dbt.common.events.contextvars import get_node_info
+from dbt_common.events.contextvars import get_node_info
 from mashumaro.helper import pass_through
 
 from functools import lru_cache
@@ -27,21 +27,21 @@
 )
 
 from dbt.adapters.bigquery import gcloud
-from dbt.common.clients import agate_helper
-from dbt.adapters.contracts.connection import ConnectionState, AdapterResponse
-from dbt.common.exceptions import (
+from dbt_common.clients import agate_helper
+from dbt.adapters.contracts.connection import ConnectionState, AdapterResponse, Credentials
+from dbt_common.exceptions import (
     DbtRuntimeError,
     DbtConfigError,
 )
-from dbt.common.exceptions import DbtDatabaseError
+from dbt_common.exceptions import DbtDatabaseError
 from dbt.adapters.exceptions.connection import FailedToConnectError
-from dbt.adapters.base import BaseConnectionManager, Credentials
+from dbt.adapters.base import BaseConnectionManager
 from dbt.adapters.events.logging import AdapterLogger
 from dbt.adapters.events.types import SQLQuery
-from dbt.common.events.functions import fire_event
+from dbt_common.events.functions import fire_event
 from dbt.adapters.bigquery import __version__ as dbt_version
 
-from dbt.common.dataclass_schema import ExtensibleDbtClassMixin, StrEnum
+from dbt_common.dataclass_schema import ExtensibleDbtClassMixin, StrEnum
 
 logger = AdapterLogger("BigQuery")
 
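Nearly every hunk in this file, and in the files below, follows the same mechanical pattern: shared utilities that previously rode along with dbt-core under the `dbt.common` namespace now come from the standalone `dbt_common` package, while adapter-facing contracts such as `Credentials` move under `dbt.adapters.contracts` (provided by dbt-adapters). A minimal before/after sketch built from imports that appear in this file:

```python
# Before this PR: shared helpers were imported from the dbt.common namespace,
# and Credentials was imported from dbt.adapters.base.
#   from dbt.common.events.functions import fire_event
#   from dbt.adapters.base import BaseConnectionManager, Credentials

# After this PR: dbt_common is its own distribution, and Credentials lives in
# the adapter connection contracts module.
from dbt_common.events.functions import fire_event

from dbt.adapters.base import BaseConnectionManager
from dbt.adapters.contracts.connection import Credentials
```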
7 changes: 4 additions & 3 deletions dbt/adapters/bigquery/gcloud.py
@@ -1,6 +1,7 @@
+from dbt_common.exceptions import DbtRuntimeError
+
 from dbt.adapters.events.logging import AdapterLogger
-import dbt.common.exceptions
-from dbt.common.clients.system import run_cmd
+from dbt_common.clients.system import run_cmd
 
 NOT_INSTALLED_MSG = """
 dbt requires the gcloud SDK to be installed to authenticate with BigQuery.
@@ -25,4 +26,4 @@ def setup_default_credentials():
     if gcloud_installed():
         run_cmd(".", ["gcloud", "auth", "application-default", "login"])
     else:
-        raise dbt.common.exceptions.DbtRuntimeError(NOT_INSTALLED_MSG)
+        raise DbtRuntimeError(NOT_INSTALLED_MSG)
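
With the module-level `import dbt.common.exceptions` removed, the exception class is imported once at the top and raised by name. Roughly, the resulting module reads as below; this is a sketch assembled from the hunks above, with the full `NOT_INSTALLED_MSG` text elided and `gcloud_installed` stubbed in so the example is self-contained (the real helper lives in the same file):

```python
from dbt_common.clients.system import run_cmd
from dbt_common.exceptions import DbtRuntimeError

NOT_INSTALLED_MSG = "dbt requires the gcloud SDK to be installed ..."  # full text elided


def gcloud_installed() -> bool:
    # Stub for this sketch: probe for the gcloud CLI. The real helper in this
    # module does the same probe with narrower error handling.
    try:
        run_cmd(".", ["gcloud", "--version"])
        return True
    except Exception:
        return False


def setup_default_credentials():
    # Trigger an interactive application-default login via the gcloud CLI,
    # or raise a dbt runtime error if the SDK is missing.
    if gcloud_installed():
        run_cmd(".", ["gcloud", "auth", "application-default", "login"])
    else:
        raise DbtRuntimeError(NOT_INSTALLED_MSG)
```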
30 changes: 15 additions & 15 deletions dbt/adapters/bigquery/impl.py
@@ -9,7 +9,7 @@
 import agate
 from dbt.adapters.contracts.relation import RelationConfig
 
-import dbt.common.exceptions.base
+import dbt_common.exceptions.base
 from dbt.adapters.base import ( # type: ignore
     AdapterConfig,
     BaseAdapter,
@@ -21,15 +21,15 @@
     available,
 )
 from dbt.adapters.cache import _make_ref_key_dict # type: ignore
-import dbt.common.clients.agate_helper
+import dbt_common.clients.agate_helper
 from dbt.adapters.contracts.connection import AdapterResponse
-from dbt.common.contracts.constraints import ColumnLevelConstraint, ConstraintType, ModelLevelConstraint # type: ignore
-from dbt.common.dataclass_schema import dbtClassMixin
+from dbt_common.contracts.constraints import ColumnLevelConstraint, ConstraintType, ModelLevelConstraint # type: ignore
+from dbt_common.dataclass_schema import dbtClassMixin
 from dbt.adapters.events.logging import AdapterLogger
-from dbt.common.events.functions import fire_event
+from dbt_common.events.functions import fire_event
 from dbt.adapters.events.types import SchemaCreation, SchemaDrop
-import dbt.common.exceptions
-from dbt.common.utils import filter_null_values
+import dbt_common.exceptions
+from dbt_common.utils import filter_null_values
 import google.api_core
 import google.auth
 import google.oauth2
@@ -147,7 +147,7 @@ def drop_relation(self, relation: BigQueryRelation) -> None:
         conn.handle.delete_table(table_ref, not_found_ok=True)
 
     def truncate_relation(self, relation: BigQueryRelation) -> None:
-        raise dbt.common.exceptions.base.NotImplementedError(
+        raise dbt_common.exceptions.base.NotImplementedError(
             "`truncate` is not implemented for this adapter!"
         )
 
@@ -164,7 +164,7 @@ def rename_relation(
             or from_relation.type == RelationType.View
             or to_relation.type == RelationType.View
         ):
-            raise dbt.common.exceptions.DbtRuntimeError(
+            raise dbt_common.exceptions.DbtRuntimeError(
                 "Renaming of views is not currently supported in BigQuery"
             )
 
@@ -390,7 +390,7 @@ def copy_table(self, source, destination, materialization):
         elif materialization == "table":
             write_disposition = WRITE_TRUNCATE
         else:
-            raise dbt.common.exceptions.CompilationError(
+            raise dbt_common.exceptions.CompilationError(
                 'Copy table materialization must be "copy" or "table", but '
                 f"config.get('copy_materialization', 'table') was "
                 f"{materialization}"
@@ -437,11 +437,11 @@ def poll_until_job_completes(cls, job, timeout):
             job.reload()
 
         if job.state != "DONE":
-            raise dbt.common.exceptions.DbtRuntimeError("BigQuery Timeout Exceeded")
+            raise dbt_common.exceptions.DbtRuntimeError("BigQuery Timeout Exceeded")
 
         elif job.error_result:
             message = "\n".join(error["message"].strip() for error in job.errors)
-            raise dbt.common.exceptions.DbtRuntimeError(message)
+            raise dbt_common.exceptions.DbtRuntimeError(message)
 
     def _bq_table_to_relation(self, bq_table) -> Union[BigQueryRelation, None]:
         if bq_table is None:
@@ -465,7 +465,7 @@ def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
         if self.nice_connection_name() in ["on-run-start", "on-run-end"]:
             self.warning_on_hooks(self.nice_connection_name())
         else:
-            raise dbt.common.exceptions.base.NotImplementedError(
+            raise dbt_common.exceptions.base.NotImplementedError(
                 "`add_query` is not implemented for this adapter!"
            )
 
@@ -777,7 +777,7 @@ def describe_relation(
             bq_table = self.get_bq_table(relation)
             parser = BigQueryMaterializedViewConfig
         else:
-            raise dbt.common.exceptions.DbtRuntimeError(
+            raise dbt_common.exceptions.DbtRuntimeError(
                 f"The method `BigQueryAdapter.describe_relation` is not implemented "
                 f"for the relation type: {relation.type}"
             )
@@ -843,7 +843,7 @@ def string_add_sql(
         elif location == "prepend":
             return f"concat('{value}', {add_to})"
         else:
-            raise dbt.common.exceptions.DbtRuntimeError(
+            raise dbt_common.exceptions.DbtRuntimeError(
                 f'Got an unexpected location value of "{location}"'
             )
 
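Because only the top-level package name changes (`dbt.common` becomes `dbt_common`), every fully qualified call site in this file is a one-line edit; the attribute path below the package stays identical. A hedged sketch of the resulting raise pattern (the helper function here is illustrative, not from the PR):

```python
import dbt_common.exceptions


def require_copy_materialization(materialization: str) -> None:
    # Illustrative helper mirroring the validation in copy_table() above,
    # raising CompilationError from its new dbt_common location.
    if materialization not in ("copy", "table"):
        raise dbt_common.exceptions.CompilationError(
            'Copy table materialization must be "copy" or "table", '
            f"but got {materialization!r}"
        )


require_copy_materialization("table")   # passes
# require_copy_materialization("view")  # would raise CompilationError
```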
4 changes: 2 additions & 2 deletions dbt/adapters/bigquery/relation.py
@@ -12,8 +12,8 @@
     BigQueryPartitionConfigChange,
 )
 from dbt.adapters.contracts.relation import RelationType, RelationConfig
-from dbt.common.exceptions import CompilationError
-from dbt.common.utils.dict import filter_null_values
+from dbt_common.exceptions import CompilationError
+from dbt_common.utils.dict import filter_null_values
 
 
 Self = TypeVar("Self", bound="BigQueryRelation")
8 changes: 4 additions & 4 deletions dbt/adapters/bigquery/relation_configs/_partition.py
@@ -1,10 +1,10 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
 
-import dbt.common.exceptions
+import dbt_common.exceptions
 from dbt.adapters.relation_configs import RelationConfigChange
 from dbt.adapters.contracts.relation import RelationConfig
-from dbt.common.dataclass_schema import dbtClassMixin, ValidationError
+from dbt_common.dataclass_schema import dbtClassMixin, ValidationError
 from google.cloud.bigquery.table import Table as BigQueryTable
 
 
@@ -92,11 +92,11 @@ def parse(cls, raw_partition_by) -> Optional["PartitionConfig"]:
                 }
             )
         except ValidationError as exc:
-            raise dbt.common.exceptions.base.DbtValidationError(
+            raise dbt_common.exceptions.base.DbtValidationError(
                 "Could not parse partition config"
             ) from exc
         except TypeError:
-            raise dbt.common.exceptions.CompilationError(
+            raise dbt_common.exceptions.CompilationError(
                 f"Invalid partition_by config:\n"
                 f" Got: {raw_partition_by}\n"
                 f' Expected a dictionary with "field" and "data_type" keys'
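For context, `PartitionConfig.parse` turns the raw `partition_by` value from a model config into a typed config object, re-raising validation problems from their new `dbt_common` locations. A hedged usage sketch; the config values are made up, and the expected dict shape with "field" and "data_type" keys comes from the error message above:

```python
from dbt.adapters.bigquery.relation_configs import PartitionConfig
import dbt_common.exceptions.base

# A well-formed partition_by config (values are illustrative).
partition = PartitionConfig.parse({"field": "created_at", "data_type": "timestamp"})

try:
    # An empty config is rejected, mirroring the unit test below that expects
    # DbtValidationError for parse_partition_by({}).
    PartitionConfig.parse({})
except dbt_common.exceptions.base.DbtValidationError as exc:
    print(f"invalid partition_by: {exc}")
```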
4 changes: 2 additions & 2 deletions dbt/adapters/bigquery/utility.py
@@ -1,7 +1,7 @@
 import json
 from typing import Any, Optional
 
-import dbt.common.exceptions
+import dbt_common.exceptions
 
 
 def bool_setting(value: Optional[Any] = None) -> Optional[bool]:
@@ -41,5 +41,5 @@ def float_setting(value: Optional[Any] = None) -> Optional[float]:
 
 def sql_escape(string):
     if not isinstance(string, str):
-        raise dbt.common.exceptions.CompilationError(f"cannot escape a non-string: {string}")
+        raise dbt_common.exceptions.CompilationError(f"cannot escape a non-string: {string}")
     return json.dumps(string)[1:-1]
3 changes: 2 additions & 1 deletion setup.py
@@ -74,7 +74,8 @@ def _dbt_core_version(plugin_version: str) -> str:
     packages=find_namespace_packages(include=["dbt", "dbt.*"]),
     include_package_data=True,
     install_requires=[
-        f"dbt-core~={_dbt_core_version(_dbt_bigquery_version())}",
+        "dbt-common<1.0",
+        "dbt-adapters~=0.1.0a1",
         "google-cloud-bigquery~=3.0",
         "google-cloud-storage~=2.4",
         "google-cloud-dataproc~=5.0",
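The plugin no longer pins dbt-core directly; it now depends on the two split-out packages. For reference, `~=0.1.0a1` is a compatible-release specifier (roughly `>=0.1.0a1, ==0.1.*`), and `<1.0` keeps dbt-common below its first stable major release. The resulting dependency block looks roughly like this (the surrounding `setup()` call and the remaining pinned dependencies are elided):

```python
# Sketch of the install_requires list after this change; the rest of setup()
# (name, version, packages, ...) is unchanged and omitted here.
install_requires = [
    "dbt-common<1.0",
    "dbt-adapters~=0.1.0a1",
    "google-cloud-bigquery~=3.0",
    "google-cloud-storage~=2.4",
    "google-cloud-dataproc~=5.0",
    # ...additional pinned dependencies follow in the real file
]
```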
4 changes: 2 additions & 2 deletions tests/functional/adapter/column_types/fixtures.py
@@ -26,7 +26,7 @@
 version: 2
 models:
   - name: model
-    tests:
+    data_tests:
       - is_type:
           column_map:
             int64_col: ['integer', 'number']
@@ -39,7 +39,7 @@
 version: 2
 models:
   - name: model
-    tests:
+    data_tests:
       - is_type:
           column_map:
             int64_col: ['string', 'not number']
4 changes: 2 additions & 2 deletions tests/functional/adapter/test_aliases.py
@@ -32,12 +32,12 @@
 version: 2
 models:
   - name: model_a
-    tests:
+    data_tests:
       - expect_value:
           field: tablename
           value: duped_alias
   - name: model_b
-    tests:
+    data_tests:
       - expect_value:
           field: tablename
           value: duped_alias
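The last two test fixtures pick up dbt's rename of the schema-YAML key for generic tests: `tests:` becomes `data_tests:` (disambiguating data tests from unit tests). Since these fixtures are Python modules holding YAML strings, the updated pattern looks roughly like this; the constant name is hypothetical:

```python
# Hypothetical fixture fragment mirroring the change above: only the YAML key
# moves from `tests:` to `data_tests:`, everything else stays the same.
MODEL_A_SCHEMA_YML = """
version: 2
models:
  - name: model_a
    data_tests:
      - expect_value:
          field: tablename
          value: duped_alias
"""
```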
36 changes: 19 additions & 17 deletions tests/unit/test_bigquery_adapter.py
@@ -10,14 +10,16 @@
 import unittest
 from unittest.mock import patch, MagicMock, create_autospec
 
-import dbt.common.dataclass_schema
-import dbt.common.exceptions.base
+import dbt_common.dataclass_schema
+import dbt_common.exceptions.base
+
+import dbt.adapters
 from dbt.adapters.bigquery.relation_configs import PartitionConfig
 from dbt.adapters.bigquery import BigQueryAdapter, BigQueryRelation
 from google.cloud.bigquery.table import Table
 from dbt.adapters.bigquery.connections import _sanitize_label, _VALIDATE_LABEL_LENGTH_LIMIT
-from dbt.common.clients import agate_helper
-import dbt.common.exceptions
+from dbt_common.clients import agate_helper
+import dbt_common.exceptions
 from dbt.context.manifest import generate_query_header_context
 from dbt.contracts.files import FileHash
 from dbt.contracts.graph.manifest import ManifestStateCheck
@@ -214,7 +216,7 @@ def test_acquire_connection_oauth_no_project_validations(
             connection = adapter.acquire_connection("dummy")
             self.assertEqual(connection.type, "bigquery")
 
-        except dbt.common.exceptions.base.DbtValidationError as e:
+        except dbt_common.exceptions.base.DbtValidationError as e:
             self.fail("got DbtValidationError: {}".format(str(e)))
 
         except BaseException:
@@ -231,7 +233,7 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection):
             connection = adapter.acquire_connection("dummy")
             self.assertEqual(connection.type, "bigquery")
 
-        except dbt.common.exceptions.base.DbtValidationError as e:
+        except dbt_common.exceptions.base.DbtValidationError as e:
             self.fail("got DbtValidationError: {}".format(str(e)))
 
         except BaseException:
@@ -255,7 +257,7 @@ def test_acquire_connection_dataproc_serverless(
             connection = adapter.acquire_connection("dummy")
             self.assertEqual(connection.type, "bigquery")
 
-        except dbt.common.exceptions.ValidationException as e:
+        except dbt_common.exceptions.ValidationException as e:
             self.fail("got ValidationException: {}".format(str(e)))
 
         except BaseException:
@@ -272,7 +274,7 @@ def test_acquire_connection_service_account_validations(self, mock_open_connecti
             connection = adapter.acquire_connection("dummy")
             self.assertEqual(connection.type, "bigquery")
 
-        except dbt.common.exceptions.base.DbtValidationError as e:
+        except dbt_common.exceptions.base.DbtValidationError as e:
             self.fail("got DbtValidationError: {}".format(str(e)))
 
         except BaseException:
@@ -289,7 +291,7 @@ def test_acquire_connection_oauth_token_validations(self, mock_open_connection):
             connection = adapter.acquire_connection("dummy")
             self.assertEqual(connection.type, "bigquery")
 
-        except dbt.common.exceptions.base.DbtValidationError as e:
+        except dbt_common.exceptions.base.DbtValidationError as e:
             self.fail("got DbtValidationError: {}".format(str(e)))
 
         except BaseException:
@@ -306,7 +308,7 @@ def test_acquire_connection_oauth_credentials_validations(self, mock_open_connec
             connection = adapter.acquire_connection("dummy")
             self.assertEqual(connection.type, "bigquery")
 
-        except dbt.common.exceptions.base.DbtValidationError as e:
+        except dbt_common.exceptions.base.DbtValidationError as e:
             self.fail("got DbtValidationError: {}".format(str(e)))
 
         except BaseException:
@@ -325,7 +327,7 @@ def test_acquire_connection_impersonated_service_account_validations(
             connection = adapter.acquire_connection("dummy")
             self.assertEqual(connection.type, "bigquery")
 
-        except dbt.common.exceptions.base.DbtValidationError as e:
+        except dbt_common.exceptions.base.DbtValidationError as e:
             self.fail("got DbtValidationError: {}".format(str(e)))
 
         except BaseException:
@@ -343,7 +345,7 @@ def test_acquire_connection_priority(self, mock_open_connection):
             self.assertEqual(connection.type, "bigquery")
             self.assertEqual(connection.credentials.priority, "batch")
 
-        except dbt.common.exceptions.base.DbtValidationError as e:
+        except dbt_common.exceptions.base.DbtValidationError as e:
             self.fail("got DbtValidationError: {}".format(str(e)))
 
         mock_open_connection.assert_not_called()
@@ -358,7 +360,7 @@ def test_acquire_connection_maximum_bytes_billed(self, mock_open_connection):
             self.assertEqual(connection.type, "bigquery")
             self.assertEqual(connection.credentials.maximum_bytes_billed, 0)
 
-        except dbt.common.exceptions.base.DbtValidationError as e:
+        except dbt_common.exceptions.base.DbtValidationError as e:
             self.fail("got DbtValidationError: {}".format(str(e)))
 
         mock_open_connection.assert_not_called()
@@ -509,7 +511,7 @@ def test_invalid_relation(self):
             },
             "quote_policy": {"identifier": False, "schema": True},
         }
-        with self.assertRaises(dbt.common.dataclass_schema.ValidationError):
+        with self.assertRaises(dbt_common.dataclass_schema.ValidationError):
             BigQueryRelation.validate(kwargs)
 
 
@@ -581,10 +583,10 @@ def test_copy_table_materialization_incremental(self):
     def test_parse_partition_by(self):
         adapter = self.get_adapter("oauth")
 
-        with self.assertRaises(dbt.common.exceptions.base.DbtValidationError):
+        with self.assertRaises(dbt_common.exceptions.base.DbtValidationError):
             adapter.parse_partition_by("date(ts)")
 
-        with self.assertRaises(dbt.common.exceptions.base.DbtValidationError):
+        with self.assertRaises(dbt_common.exceptions.base.DbtValidationError):
             adapter.parse_partition_by("ts")
 
         self.assertEqual(
@@ -736,7 +738,7 @@ def test_parse_partition_by(self):
         )
 
         # Invalid, should raise an error
-        with self.assertRaises(dbt.common.exceptions.base.DbtValidationError):
+        with self.assertRaises(dbt_common.exceptions.base.DbtValidationError):
             adapter.parse_partition_by({})
 
         # passthrough
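The unit-test updates mirror the source changes: every `except`/`assertRaises` that referenced an exception under `dbt.common` now points at `dbt_common`. A minimal self-contained sketch of the assertion pattern (a standalone example, not a test from this file):

```python
import unittest

import dbt_common.exceptions.base


class MigratedExceptionPathTest(unittest.TestCase):
    def test_raises_from_new_location(self):
        # Assert against the exception type at its new dbt_common location,
        # exactly as the updated tests above do.
        with self.assertRaises(dbt_common.exceptions.base.DbtValidationError):
            raise dbt_common.exceptions.base.DbtValidationError("bad partition_by config")


if __name__ == "__main__":
    unittest.main()
```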