Skip to content

Commit

Permalink
[AIRFLOW-4335] Add default num_retries to GCP connection (apache#5117)
Browse files Browse the repository at this point in the history
Add default num_retries to GCP connection

(cherry picked from commit 16e7e61)
  • Loading branch information
ryanyuan authored and ashb committed Apr 29, 2019
1 parent ec472fb commit 272e2df
Show file tree
Hide file tree
Showing 19 changed files with 204 additions and 141 deletions.
63 changes: 35 additions & 28 deletions airflow/contrib/hooks/bigquery_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self,
gcp_conn_id=bigquery_conn_id, delegate_to=delegate_to)
self.use_legacy_sql = use_legacy_sql
self.location = location
self.num_retries = self._get_field('num_retries', 5)

def get_conn(self):
"""
Expand All @@ -72,6 +73,7 @@ def get_conn(self):
project_id=project,
use_legacy_sql=self.use_legacy_sql,
location=self.location,
num_retries=self.num_retries
)

def get_service(self):
Expand Down Expand Up @@ -134,7 +136,7 @@ def table_exists(self, project_id, dataset_id, table_id):
try:
service.tables().get(
projectId=project_id, datasetId=dataset_id,
tableId=table_id).execute()
tableId=table_id).execute(num_retries=self.num_retries)
return True
except HttpError as e:
if e.resp['status'] == '404':
Expand Down Expand Up @@ -207,7 +209,8 @@ def __init__(self,
project_id,
use_legacy_sql=True,
api_resource_configs=None,
location=None):
location=None,
num_retries=None):

self.service = service
self.project_id = project_id
Expand All @@ -218,6 +221,7 @@ def __init__(self,
if api_resource_configs else {}
self.running_job_id = None
self.location = location
self.num_retries = num_retries

def create_empty_table(self,
project_id,
Expand All @@ -228,7 +232,7 @@ def create_empty_table(self,
cluster_fields=None,
labels=None,
view=None,
num_retries=5):
num_retries=None):
"""
Creates a new, empty table in the dataset.
To create a view, which is defined by a SQL query, parse a dictionary to 'view' kwarg
Expand Down Expand Up @@ -301,6 +305,8 @@ def create_empty_table(self,
if view:
table_resource['view'] = view

num_retries = num_retries if num_retries else self.num_retries

self.log.info('Creating Table %s:%s.%s',
project_id, dataset_id, table_id)

Expand Down Expand Up @@ -504,7 +510,7 @@ def create_external_table(self,
projectId=project_id,
datasetId=dataset_id,
body=table_resource
).execute()
).execute(num_retries=self.num_retries)

self.log.info('External table created successfully: %s',
external_project_dataset_table)
Expand Down Expand Up @@ -612,7 +618,7 @@ def patch_table(self,
projectId=project_id,
datasetId=dataset_id,
tableId=table_id,
body=table_resource).execute()
body=table_resource).execute(num_retries=self.num_retries)

self.log.info('Table patched successfully: %s:%s.%s',
project_id, dataset_id, table_id)
Expand Down Expand Up @@ -1235,7 +1241,7 @@ def run_with_configuration(self, configuration):
# Send query and wait for reply.
query_reply = jobs \
.insert(projectId=self.project_id, body=job_data) \
.execute()
.execute(num_retries=self.num_retries)
self.running_job_id = query_reply['jobReference']['jobId']
if 'location' in query_reply['jobReference']:
location = query_reply['jobReference']['location']
Expand All @@ -1250,11 +1256,11 @@ def run_with_configuration(self, configuration):
job = jobs.get(
projectId=self.project_id,
jobId=self.running_job_id,
location=location).execute()
location=location).execute(num_retries=self.num_retries)
else:
job = jobs.get(
projectId=self.project_id,
jobId=self.running_job_id).execute()
jobId=self.running_job_id).execute(num_retries=self.num_retries)
if job['status']['state'] == 'DONE':
keep_polling_job = False
# Check if job had errors.
Expand Down Expand Up @@ -1286,10 +1292,10 @@ def poll_job_complete(self, job_id):
if self.location:
job = jobs.get(projectId=self.project_id,
jobId=job_id,
location=self.location).execute()
location=self.location).execute(num_retries=self.num_retries)
else:
job = jobs.get(projectId=self.project_id,
jobId=job_id).execute()
jobId=job_id).execute(num_retries=self.num_retries)
if job['status']['state'] == 'DONE':
return True
except HttpError as err:
Expand All @@ -1316,11 +1322,11 @@ def cancel_query(self):
jobs.cancel(
projectId=self.project_id,
jobId=self.running_job_id,
location=self.location).execute()
location=self.location).execute(num_retries=self.num_retries)
else:
jobs.cancel(
projectId=self.project_id,
jobId=self.running_job_id).execute()
jobId=self.running_job_id).execute(num_retries=self.num_retries)
else:
self.log.info('No running BigQuery jobs to cancel.')
return
Expand Down Expand Up @@ -1357,7 +1363,7 @@ def get_schema(self, dataset_id, table_id):
"""
tables_resource = self.service.tables() \
.get(projectId=self.project_id, datasetId=dataset_id, tableId=table_id) \
.execute()
.execute(num_retries=self.num_retries)
return tables_resource['schema']

def get_tabledata(self, dataset_id, table_id,
Expand Down Expand Up @@ -1390,7 +1396,7 @@ def get_tabledata(self, dataset_id, table_id,
projectId=self.project_id,
datasetId=dataset_id,
tableId=table_id,
**optional_params).execute())
**optional_params).execute(num_retries=self.num_retries))

def run_table_delete(self, deletion_dataset_table,
ignore_if_missing=False):
Expand All @@ -1417,7 +1423,7 @@ def run_table_delete(self, deletion_dataset_table,
.delete(projectId=deletion_project,
datasetId=deletion_dataset,
tableId=deletion_table) \
.execute()
.execute(num_retries=self.num_retries)
self.log.info('Deleted table %s:%s.%s.', deletion_project,
deletion_dataset, deletion_table)
except HttpError:
Expand Down Expand Up @@ -1446,7 +1452,7 @@ def run_table_upsert(self, dataset_id, table_resource, project_id=None):
table_id = table_resource['tableReference']['tableId']
project_id = project_id if project_id is not None else self.project_id
tables_list_resp = self.service.tables().list(
projectId=project_id, datasetId=dataset_id).execute()
projectId=project_id, datasetId=dataset_id).execute(num_retries=self.num_retries)
while True:
for table in tables_list_resp.get('tables', []):
if table['tableReference']['tableId'] == table_id:
Expand All @@ -1457,14 +1463,14 @@ def run_table_upsert(self, dataset_id, table_resource, project_id=None):
projectId=project_id,
datasetId=dataset_id,
tableId=table_id,
body=table_resource).execute()
body=table_resource).execute(num_retries=self.num_retries)
# If there is a next page, we need to check the next page.
if 'nextPageToken' in tables_list_resp:
tables_list_resp = self.service.tables()\
.list(projectId=project_id,
datasetId=dataset_id,
pageToken=tables_list_resp['nextPageToken'])\
.execute()
.execute(num_retries=self.num_retries)
# If there is no next page, then the table doesn't exist.
else:
# do insert
Expand All @@ -1473,7 +1479,7 @@ def run_table_upsert(self, dataset_id, table_resource, project_id=None):
return self.service.tables().insert(
projectId=project_id,
datasetId=dataset_id,
body=table_resource).execute()
body=table_resource).execute(num_retries=self.num_retries)

def run_grant_dataset_view_access(self,
source_dataset,
Expand Down Expand Up @@ -1508,7 +1514,7 @@ def run_grant_dataset_view_access(self,
# we don't want to clobber any existing accesses, so we have to get
# info on the dataset before we can add view access
source_dataset_resource = self.service.datasets().get(
projectId=source_project, datasetId=source_dataset).execute()
projectId=source_project, datasetId=source_dataset).execute(num_retries=self.num_retries)
access = source_dataset_resource[
'access'] if 'access' in source_dataset_resource else []
view_access = {
Expand All @@ -1530,7 +1536,7 @@ def run_grant_dataset_view_access(self,
datasetId=source_dataset,
body={
'access': access
}).execute()
}).execute(num_retries=self.num_retries)
else:
# if view is already in access, do nothing.
self.log.info(
Expand Down Expand Up @@ -1596,7 +1602,7 @@ def create_empty_dataset(self, dataset_id="", project_id="",
try:
self.service.datasets().insert(
projectId=dataset_project_id,
body=dataset_reference).execute()
body=dataset_reference).execute(num_retries=self.num_retries)
self.log.info('Dataset created successfully: In project %s '
'Dataset %s', dataset_project_id, dataset_id)

Expand All @@ -1621,7 +1627,7 @@ def delete_dataset(self, project_id, dataset_id):
try:
self.service.datasets().delete(
projectId=project_id,
datasetId=dataset_id).execute()
datasetId=dataset_id).execute(num_retries=self.num_retries)
self.log.info('Dataset deleted successfully: In project %s '
'Dataset %s', project_id, dataset_id)

Expand Down Expand Up @@ -1654,7 +1660,7 @@ def get_dataset(self, dataset_id, project_id=None):

try:
dataset_resource = self.service.datasets().get(
datasetId=dataset_id, projectId=dataset_project_id).execute()
datasetId=dataset_id, projectId=dataset_project_id).execute(num_retries=self.num_retries)
self.log.info("Dataset Resource: %s", dataset_resource)
except HttpError as err:
raise AirflowException(
Expand Down Expand Up @@ -1701,7 +1707,7 @@ def get_datasets_list(self, project_id=None):

try:
datasets_list = self.service.datasets().list(
projectId=dataset_project_id).execute()['datasets']
projectId=dataset_project_id).execute(num_retries=self.num_retries)['datasets']
self.log.info("Datasets List: %s", datasets_list)

except HttpError as err:
Expand Down Expand Up @@ -1765,7 +1771,7 @@ def insert_all(self, project_id, dataset_id, table_id,
resp = self.service.tabledata().insertAll(
projectId=dataset_project_id, datasetId=dataset_id,
tableId=table_id, body=body
).execute()
).execute(num_retries=self.num_retries)

if 'insertErrors' not in resp:
self.log.info(
Expand Down Expand Up @@ -1796,12 +1802,13 @@ class BigQueryCursor(BigQueryBaseCursor):
https://github.com/dropbox/PyHive/blob/master/pyhive/common.py
"""

def __init__(self, service, project_id, use_legacy_sql=True, location=None):
def __init__(self, service, project_id, use_legacy_sql=True, location=None, num_retries=None):
super(BigQueryCursor, self).__init__(
service=service,
project_id=project_id,
use_legacy_sql=use_legacy_sql,
location=location,
num_retries=num_retries
)
self.buffersize = None
self.page_token = None
Expand Down Expand Up @@ -1869,7 +1876,7 @@ def next(self):
query_results = (self.service.jobs().getQueryResults(
projectId=self.project_id,
jobId=self.job_id,
pageToken=self.page_token).execute())
pageToken=self.page_token).execute(num_retries=self.num_retries))

if 'rows' in query_results and query_results['rows']:
self.page_token = query_results.get('pageToken')
Expand Down
Loading

0 comments on commit 272e2df

Please sign in to comment.