Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix all Amazon Provider MyPy errors #20935

Merged
merged 1 commit into from
Jan 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion airflow/contrib/sensors/aws_redshift_cluster_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

import warnings

from airflow.providers.amazon.aws.sensors.redshift_cluster import AwsRedshiftClusterSensor
from airflow.providers.amazon.aws.sensors.redshift_cluster import (
RedshiftClusterSensor as AwsRedshiftClusterSensor,
)

warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.redshift_cluster`.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

CLUSTER_NAME = 'fargate-demo'
FARGATE_PROFILE_NAME = f'{CLUSTER_NAME}-profile'
SELECTORS = environ.get('FARGATE_SELECTORS', [{'namespace': 'default'}])
SELECTORS = [{'namespace': 'default'}]

ROLE_ARN = environ.get('EKS_DEMO_ROLE_ARN', 'arn:aws:iam::123456789012:role/role_name')
SUBNETS = environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,10 @@
target_state=ClusterStates.NONEXISTENT,
)

create_cluster_and_nodegroup >> await_create_nodegroup >> start_pod >> delete_all >> await_delete_cluster
(
create_cluster_and_nodegroup
>> await_create_nodegroup
>> start_pod
>> delete_all
>> await_delete_cluster
)
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
) as dag:

# [START howto_operator_eks_create_cluster]
# Create an Amazon EKS Cluster control plane without attaching a compute service.
# Create an Amazon EKS Cluster control plane without attaching compute service.
create_cluster = EksCreateClusterOperator(
task_id='create_eks_cluster',
cluster_role_arn=ROLE_ARN,
Expand Down
12 changes: 6 additions & 6 deletions airflow/providers/amazon/aws/hooks/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from contextlib import contextmanager
from enum import Enum
from functools import partial
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, Generator, List, Optional

import yaml
from botocore.exceptions import ClientError
Expand Down Expand Up @@ -127,7 +127,7 @@ def create_nodegroup(
clusterName: str,
nodegroupName: str,
subnets: List[str],
nodeRole: str,
nodeRole: Optional[str],
*,
tags: Optional[Dict] = None,
**kwargs,
Expand Down Expand Up @@ -178,8 +178,8 @@ def create_nodegroup(
def create_fargate_profile(
self,
clusterName: str,
fargateProfileName: str,
podExecutionRoleArn: str,
fargateProfileName: Optional[str],
podExecutionRoleArn: Optional[str],
selectors: List,
**kwargs,
) -> Dict:
Expand Down Expand Up @@ -536,10 +536,10 @@ def _list_all(self, api_call: Callable, response_key: str, verbose: bool) -> Lis
def generate_config_file(
self,
eks_cluster_name: str,
pod_namespace: str,
pod_namespace: Optional[str],
pod_username: Optional[str] = None,
pod_context: Optional[str] = None,
) -> str:
) -> Generator[str, None, None]:
"""
Writes the kubeconfig file given an EKS Cluster.

Expand Down
9 changes: 6 additions & 3 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
import tempfile
import time
import warnings
from datetime import datetime
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Set
from typing import Any, Callable, Dict, Generator, List, Optional, Set, cast

from botocore.exceptions import ClientError

Expand Down Expand Up @@ -95,7 +96,7 @@ def secondary_training_status_changed(current_job_description: dict, prev_job_de


def secondary_training_status_message(
job_description: Dict[str, List[dict]], prev_description: Optional[dict]
job_description: Dict[str, List[Any]], prev_description: Optional[dict]
) -> str:
"""
Returns a string contains start time and the secondary training job status message.
Expand Down Expand Up @@ -125,7 +126,9 @@ def secondary_training_status_message(
status_strs = []
for transition in transitions_to_print:
message = transition['StatusMessage']
time_str = timezone.convert_to_utc(job_description['LastModifiedTime']).strftime('%Y-%m-%d %H:%M:%S')
time_str = timezone.convert_to_utc(cast(datetime, job_description['LastModifiedTime'])).strftime(
'%Y-%m-%d %H:%M:%S'
)
status_strs.append(f"{time_str} {transition['Status']} - {message}")

return '\n'.join(status_strs)
Expand Down
31 changes: 22 additions & 9 deletions airflow/providers/amazon/aws/operators/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

"""This module contains Amazon EKS operators."""
import warnings
from ast import literal_eval
from time import sleep
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast

from airflow import AirflowException
from airflow.models import BaseOperator
Expand Down Expand Up @@ -138,13 +139,13 @@ def __init__(
self,
cluster_name: str,
cluster_role_arn: str,
resources_vpc_config: Dict,
resources_vpc_config: Dict[str, Any],
compute: Optional[str] = DEFAULT_COMPUTE_TYPE,
create_cluster_kwargs: Optional[Dict] = None,
nodegroup_name: Optional[str] = DEFAULT_NODEGROUP_NAME,
nodegroup_name: str = DEFAULT_NODEGROUP_NAME,
nodegroup_role_arn: Optional[str] = None,
create_nodegroup_kwargs: Optional[Dict] = None,
fargate_profile_name: Optional[str] = DEFAULT_FARGATE_PROFILE_NAME,
fargate_profile_name: str = DEFAULT_FARGATE_PROFILE_NAME,
fargate_pod_execution_role_arn: Optional[str] = None,
fargate_selectors: Optional[List] = None,
create_fargate_profile_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -222,7 +223,7 @@ def execute(self, context: 'Context'):
eks_hook.create_nodegroup(
clusterName=self.cluster_name,
nodegroupName=self.nodegroup_name,
subnets=self.resources_vpc_config.get('subnetIds'),
subnets=cast(List[str], self.resources_vpc_config.get('subnetIds')),
nodeRole=self.nodegroup_role_arn,
**self.create_nodegroup_kwargs,
)
Expand Down Expand Up @@ -281,29 +282,41 @@ class EksCreateNodegroupOperator(BaseOperator):
def __init__(
self,
cluster_name: str,
nodegroup_subnets: List[str],
nodegroup_subnets: Union[List[str], str],
nodegroup_role_arn: str,
nodegroup_name: Optional[str] = DEFAULT_NODEGROUP_NAME,
nodegroup_name: str = DEFAULT_NODEGROUP_NAME,
create_nodegroup_kwargs: Optional[Dict] = None,
aws_conn_id: str = DEFAULT_CONN_ID,
region: Optional[str] = None,
**kwargs,
) -> None:
self.cluster_name = cluster_name
self.nodegroup_subnets = nodegroup_subnets
self.nodegroup_role_arn = nodegroup_role_arn
self.nodegroup_name = nodegroup_name
self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
self.aws_conn_id = aws_conn_id
self.region = region
nodegroup_subnets_list: List[str] = []
if isinstance(nodegroup_subnets, str):
if nodegroup_subnets != "":
try:
nodegroup_subnets_list = cast(List, literal_eval(nodegroup_subnets))
except ValueError:
self.log.warning(
"The nodegroup_subnets should be List or string representing "
"Python list and is %s. Defaulting to []",
nodegroup_subnets,
)
else:
nodegroup_subnets_list = nodegroup_subnets
self.nodegroup_subnets = nodegroup_subnets_list
super().__init__(**kwargs)

def execute(self, context: 'Context'):
eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
)

eks_hook.create_nodegroup(
clusterName=self.cluster_name,
nodegroupName=self.nodegroup_name,
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def __init__(
self.client_request_token = client_request_token or str(uuid4())
self.poll_interval = poll_interval
self.max_tries = max_tries
self.job_id = None
self.job_id: Optional[str] = None

@cached_property
def hook(self) -> EmrContainerHook:
Expand Down
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/operators/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ def __init__(
*,
job_name: str = 'aws_glue_default_job',
job_desc: str = 'AWS Glue Job with Airflow',
script_location: Optional[str] = None,
script_location: str,
potiuk marked this conversation as resolved.
Show resolved Hide resolved
concurrent_run_limit: Optional[int] = None,
script_args: Optional[dict] = None,
retry_limit: Optional[int] = None,
retry_limit: int = 0,
num_of_dpus: int = 6,
aws_conn_id: str = 'aws_default',
region_name: Optional[str] = None,
Expand Down Expand Up @@ -113,7 +113,7 @@ def execute(self, context: 'Context'):

:return: the id of the current glue job.
"""
if self.script_location and not self.script_location.startswith(self.s3_protocol):
if not self.script_location.startswith(self.s3_protocol):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
script_name = os.path.basename(self.script_location)
s3_hook.load_file(
Expand Down
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/operators/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,8 +523,9 @@ def execute(self, context: 'Context'):
close_fds=True,
) as process:
self.log.info("Output:")
for line in iter(process.stdout.readline, b''):
self.log.info(line.decode(self.output_encoding).rstrip())
if process.stdout is not None:
for line in iter(process.stdout.readline, b''):
self.log.info(line.decode(self.output_encoding).rstrip())

process.wait()

Expand Down