-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathfactory.py
264 lines (203 loc) · 9.15 KB
/
factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
from contextlib import contextmanager
from importlib import import_module
from multiprocessing.context import SpawnContext
from pathlib import Path
import threading
import traceback
from typing import Any, Dict, List, Optional, Set, Type
from dbt_common.events.functions import fire_event
from dbt_common.events.base_types import EventLevel
from dbt_common.exceptions import DbtInternalError, DbtRuntimeError
from dbt_common.semver import VersionSpecifier
from dbt.adapters.base.plugin import AdapterPlugin
from dbt.adapters.contracts.connection import AdapterRequiredConfig, Credentials
from dbt.adapters.events.types import (
AdapterImportError,
PluginLoadError,
AdapterRegistered,
)
from dbt.include.global_project import (
PACKAGE_PATH as GLOBAL_PROJECT_PATH,
PROJECT_NAME as GLOBAL_PROJECT_NAME,
)
from dbt.adapters.protocol import AdapterConfig, AdapterProtocol, RelationProtocol
Adapter = AdapterProtocol
class AdapterContainer:
def __init__(self) -> None:
self.lock = threading.Lock()
self.adapters: Dict[str, Adapter] = {}
self.plugins: Dict[str, AdapterPlugin] = {}
# map package names to their include paths
self.packages: Dict[str, Path] = {
GLOBAL_PROJECT_NAME: Path(GLOBAL_PROJECT_PATH),
}
def get_plugin_by_name(self, name: str) -> AdapterPlugin:
with self.lock:
if name in self.plugins:
return self.plugins[name]
names = ", ".join(self.plugins.keys())
message = f"Invalid adapter type {name}! Must be one of {names}"
raise DbtRuntimeError(message)
def get_adapter_class_by_name(self, name: str) -> Type[Adapter]:
plugin = self.get_plugin_by_name(name)
return plugin.adapter
def get_relation_class_by_name(self, name: str) -> Type[RelationProtocol]:
adapter = self.get_adapter_class_by_name(name)
return adapter.Relation
def get_config_class_by_name(self, name: str) -> Type[AdapterConfig]:
adapter = self.get_adapter_class_by_name(name)
return adapter.AdapterSpecificConfigs
def load_plugin(self, name: str) -> Type[Credentials]:
# this doesn't need a lock: in the worst case we'll overwrite packages
# and adapter_type entries with the same value, as they're all
# singletons
try:
# mypy doesn't think modules have any attributes.
mod: Any = import_module("." + name, "dbt.adapters")
except ModuleNotFoundError as exc:
# if we failed to import the target module in particular, inform
# the user about it via a runtime error
if exc.name == "dbt.adapters." + name:
fire_event(AdapterImportError(exc=str(exc)))
raise DbtRuntimeError(f"Could not find adapter type {name}!")
# otherwise, the error had to have come from some underlying
# library. Log the stack trace.
fire_event(PluginLoadError(exc_info=traceback.format_exc()))
raise
plugin: AdapterPlugin = mod.Plugin
plugin_type = plugin.adapter.type()
if plugin_type != name:
raise DbtRuntimeError(
f"Expected to find adapter with type named {name}, got "
f"adapter with type {plugin_type}"
)
with self.lock:
# things do hold the lock to iterate over it so we need it to add
self.plugins[name] = plugin
self.packages[plugin.project_name] = Path(plugin.include_path)
for dep in plugin.dependencies:
self.load_plugin(dep)
return plugin.credentials
def register_adapter(
self,
config: AdapterRequiredConfig,
mp_context: SpawnContext,
adapter_registered_log_level: Optional[EventLevel] = EventLevel.INFO,
) -> None:
adapter_name = config.credentials.type
adapter_type = self.get_adapter_class_by_name(adapter_name)
adapter_version = self._adapter_version(adapter_name)
fire_event(
AdapterRegistered(adapter_name=adapter_name, adapter_version=adapter_version),
level=adapter_registered_log_level,
)
with self.lock:
if adapter_name in self.adapters:
# this shouldn't really happen...
return
adapter: Adapter = adapter_type(config, mp_context) # type: ignore
self.adapters[adapter_name] = adapter
def _adapter_version(self, adapter_name: str) -> str:
try:
raw_version = import_module(f".{adapter_name}.__about__", "dbt.adapters").version
except ModuleNotFoundError:
raw_version = import_module(f".{adapter_name}.__version__", "dbt.adapters").version
return self._validate_version(raw_version)
def _validate_version(self, raw_version: str) -> str:
return VersionSpecifier.from_version_string(raw_version).to_version_string()
def lookup_adapter(self, adapter_name: str) -> Adapter:
return self.adapters[adapter_name]
def reset_adapters(self):
"""Clear the adapters. This is useful for tests, which change configs."""
with self.lock:
for adapter in self.adapters.values():
adapter.cleanup_connections()
self.adapters.clear()
def cleanup_connections(self):
"""Only clean up the adapter connections list without resetting the
actual adapters.
"""
with self.lock:
for adapter in self.adapters.values():
adapter.cleanup_connections()
def get_adapter_plugins(self, name: Optional[str]) -> List[AdapterPlugin]:
"""Iterate over the known adapter plugins. If a name is provided,
iterate in dependency order over the named plugin and its dependencies.
"""
if name is None:
return list(self.plugins.values())
plugins: List[AdapterPlugin] = []
seen: Set[str] = set()
plugin_names: List[str] = [name]
while plugin_names:
plugin_name = plugin_names[0]
plugin_names = plugin_names[1:]
try:
plugin = self.plugins[plugin_name]
except KeyError:
raise DbtInternalError(f"No plugin found for {plugin_name}") from None
plugins.append(plugin)
seen.add(plugin_name)
for dep in plugin.dependencies:
if dep not in seen:
plugin_names.append(dep)
return plugins
def get_adapter_package_names(self, name: Optional[str]) -> List[str]:
package_names: List[str] = [p.project_name for p in self.get_adapter_plugins(name)]
package_names.append(GLOBAL_PROJECT_NAME)
return package_names
def get_include_paths(self, name: Optional[str]) -> List[Path]:
paths = []
for package_name in self.get_adapter_package_names(name):
try:
path = self.packages[package_name]
except KeyError:
raise DbtInternalError(f"No internal package listing found for {package_name}")
paths.append(path)
return paths
def get_adapter_type_names(self, name: Optional[str]) -> List[str]:
return [p.adapter.type() for p in self.get_adapter_plugins(name)]
def get_adapter_constraint_support(self, name: Optional[str]) -> List[str]:
return self.lookup_adapter(name).CONSTRAINT_SUPPORT # type: ignore
FACTORY: AdapterContainer = AdapterContainer()
def register_adapter(
config: AdapterRequiredConfig,
mp_context: SpawnContext,
adapter_registered_log_level: Optional[EventLevel] = EventLevel.INFO,
) -> None:
FACTORY.register_adapter(config, mp_context, adapter_registered_log_level)
def get_adapter(config: AdapterRequiredConfig):
return FACTORY.lookup_adapter(config.credentials.type)
def get_adapter_by_type(adapter_type):
return FACTORY.lookup_adapter(adapter_type)
def reset_adapters():
"""Clear the adapters. This is useful for tests, which change configs."""
FACTORY.reset_adapters()
def cleanup_connections():
"""Only clean up the adapter connections list without resetting the actual
adapters.
"""
FACTORY.cleanup_connections()
def get_adapter_class_by_name(name: str) -> Type[AdapterProtocol]:
return FACTORY.get_adapter_class_by_name(name)
def get_config_class_by_name(name: str) -> Type[AdapterConfig]:
return FACTORY.get_config_class_by_name(name)
def get_relation_class_by_name(name: str) -> Type[RelationProtocol]:
return FACTORY.get_relation_class_by_name(name)
def load_plugin(name: str) -> Type[Credentials]:
return FACTORY.load_plugin(name)
def get_include_paths(name: Optional[str]) -> List[Path]:
return FACTORY.get_include_paths(name)
def get_adapter_package_names(name: Optional[str]) -> List[str]:
return FACTORY.get_adapter_package_names(name)
def get_adapter_type_names(name: Optional[str]) -> List[str]:
return FACTORY.get_adapter_type_names(name)
def get_adapter_constraint_support(name: Optional[str]) -> List[str]:
return FACTORY.get_adapter_constraint_support(name)
@contextmanager
def adapter_management():
reset_adapters()
try:
yield
finally:
cleanup_connections()