Modernize python type hints for apache_beam. #32872

Merged: 2 commits, Nov 19, 2024
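The diffs below all apply the same mechanical change: comment-style type hints (`# type: ...`) are replaced with inline annotations. A minimal sketch of the pattern, using made-up names rather than any particular Beam module, and assuming Python 3.9+ so built-in generics like dict[str, list[Any]] can be used in evaluated annotations:

from typing import Any

# Before: pre-PEP 484 / PEP 526 comment-style hints (never evaluated at runtime).
_RESULTS = {}  # type: dict[str, list[Any]]

def evaluate(expr):  # type: (Expression) -> Any
    ...

# After: inline annotations. Forward references to names that are not imported
# at runtime stay quoted (e.g. 'Expression'); functions whose parameters are
# untyped keep only a return annotation.
_RESULTS: dict[str, list[Any]] = {}

def evaluate(expr: 'Expression') -> Any:
    ...

Some signatures are also reflowed in the process: long parameter lists with per-argument type comments collapse onto annotated parameters, which accounts for most of the deleted lines.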
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/dataframe/doctests.py
@@ -144,7 +144,7 @@ class _InMemoryResultRecorder(object):
"""

# Class-level value to survive pickling.
_ALL_RESULTS = {} # type: dict[str, list[Any]]
_ALL_RESULTS: dict[str, list[Any]] = {}

def __init__(self):
self._id = id(self)
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/dataframe/expressions.py
@@ -36,12 +36,12 @@ class Session(object):
def __init__(self, bindings=None):
self._bindings = dict(bindings or {})

def evaluate(self, expr): # type: (Expression) -> Any
def evaluate(self, expr: 'Expression') -> Any:
if expr not in self._bindings:
self._bindings[expr] = expr.evaluate_at(self)
return self._bindings[expr]

def lookup(self, expr): # type: (Expression) -> Any
def lookup(self, expr: 'Expression') -> Any:
return self._bindings[expr]


7 changes: 4 additions & 3 deletions sdks/python/apache_beam/dataframe/transforms.py
@@ -26,6 +26,7 @@
import apache_beam as beam
from apache_beam import transforms
from apache_beam.dataframe import expressions
from apache_beam.dataframe import frame_base
from apache_beam.dataframe import frames # pylint: disable=unused-import
from apache_beam.dataframe import partitionings
from apache_beam.pvalue import PCollection
@@ -101,15 +102,15 @@ def expand(self, input_pcolls):
from apache_beam.dataframe import convert

# Convert inputs to a flat dict.
input_dict = _flatten(input_pcolls) # type: dict[Any, PCollection]
input_dict: dict[Any, PCollection] = _flatten(input_pcolls)
proxies = _flatten(self._proxy) if self._proxy is not None else {
tag: None
for tag in input_dict
}
input_frames = {
input_frames: dict[Any, frame_base.DeferredFrame] = {
k: convert.to_dataframe(pc, proxies[k])
for k, pc in input_dict.items()
} # type: dict[Any, DeferredFrame] # noqa: F821
} # noqa: F821

# Apply the function.
frames_input = _substitute(input_pcolls, input_frames)
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/io/avroio_test.py
@@ -82,7 +82,7 @@

class AvroBase(object):

_temp_files = [] # type: List[str]
_temp_files: List[str] = []

def __init__(self, methodName='runTest'):
super().__init__(methodName)
37 changes: 12 additions & 25 deletions sdks/python/apache_beam/io/fileio.py
@@ -94,7 +94,6 @@
import uuid
from collections import namedtuple
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import BinaryIO # pylint: disable=unused-import
from typing import Callable
@@ -115,15 +114,13 @@
from apache_beam.options.value_provider import ValueProvider
from apache_beam.transforms.periodicsequence import PeriodicImpulse
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.transforms.window import BoundedWindow
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import IntervalWindow
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import Timestamp

if TYPE_CHECKING:
from apache_beam.transforms.window import BoundedWindow

__all__ = [
'EmptyMatchTreatment',
'MatchFiles',
@@ -382,8 +379,7 @@ def create_metadata(
mime_type="application/octet-stream",
compression_type=CompressionTypes.AUTO)

def open(self, fh):
# type: (BinaryIO) -> None
def open(self, fh: BinaryIO) -> None:
raise NotImplementedError

def write(self, record):
@@ -575,8 +571,7 @@ class signature or an instance of FileSink to this parameter. If none is
self._max_num_writers_per_bundle = max_writers_per_bundle

@staticmethod
def _get_sink_fn(input_sink):
# type: (...) -> Callable[[Any], FileSink]
def _get_sink_fn(input_sink) -> Callable[[Any], FileSink]:
if isinstance(input_sink, type) and issubclass(input_sink, FileSink):
return lambda x: input_sink()
elif isinstance(input_sink, FileSink):
@@ -588,8 +583,7 @@ def _get_sink_fn(input_sink):
return lambda x: TextSink()

@staticmethod
def _get_destination_fn(destination):
# type: (...) -> Callable[[Any], str]
def _get_destination_fn(destination) -> Callable[[Any], str]:
if isinstance(destination, ValueProvider):
return lambda elm: destination.get()
elif callable(destination):
@@ -757,12 +751,8 @@ def _check_orphaned_files(self, writer_key):


class _WriteShardedRecordsFn(beam.DoFn):

def __init__(self,
base_path,
sink_fn, # type: Callable[[Any], FileSink]
shards # type: int
):
def __init__(
self, base_path, sink_fn: Callable[[Any], FileSink], shards: int):
self.base_path = base_path
self.sink_fn = sink_fn
self.shards = shards
@@ -805,17 +795,13 @@ def process(


class _AppendShardedDestination(beam.DoFn):
def __init__(
self,
destination, # type: Callable[[Any], str]
shards # type: int
):
def __init__(self, destination: Callable[[Any], str], shards: int):
self.destination_fn = destination
self.shards = shards

# We start the shards for a single destination at an arbitrary point.
self._shard_counter = collections.defaultdict(
lambda: random.randrange(self.shards)) # type: DefaultDict[str, int]
self._shard_counter: DefaultDict[str, int] = collections.defaultdict(
lambda: random.randrange(self.shards))

def _next_shard_for_destination(self, destination):
self._shard_counter[destination] = ((self._shard_counter[destination] + 1) %
@@ -835,8 +821,9 @@ class _WriteUnshardedRecordsFn(beam.DoFn):
SPILLED_RECORDS = 'spilled_records'
WRITTEN_FILES = 'written_files'

_writers_and_sinks = None # type: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO, FileSink]]
_file_names = None # type: Dict[Tuple[str, BoundedWindow], str]
_writers_and_sinks: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO,
FileSink]] = None
_file_names: Dict[Tuple[str, BoundedWindow], str] = None

def __init__(
self,
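One side effect visible in fileio.py: BoundedWindow moves out of the `if TYPE_CHECKING:` block and becomes a regular import. That is presumably required because the new class-level annotations on _WriteUnshardedRecordsFn are evaluated when the class body runs, unlike the old type comments. A small sketch of the constraint, with a stand-in class rather than Beam's BoundedWindow:

from typing import Dict, Tuple


class BoundedWindow:
    """Stand-in for apache_beam.transforms.window.BoundedWindow."""


class _WriterCache:
    # Class-body annotations are evaluated at class-definition time and stored
    # in __annotations__, so BoundedWindow must be a real runtime name here.
    # If it were only imported under `if TYPE_CHECKING:`, importing this module
    # would raise NameError (unless the annotation were quoted).
    _file_names: Dict[Tuple[str, BoundedWindow], str] = None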
10 changes: 6 additions & 4 deletions sdks/python/apache_beam/io/gcp/bigquery_tools.py
@@ -574,7 +574,7 @@ def _parse_location_from_exc(content, job_id):

def _start_job(
self,
request, # type: bigquery.BigqueryJobsInsertRequest
request: 'bigquery.BigqueryJobsInsertRequest',
stream=None,
):
"""Inserts a BigQuery job.
@@ -1802,9 +1802,11 @@ def generate_bq_job_name(job_name, step_id, job_type, random=None):


def check_schema_equal(
left, right, *, ignore_descriptions=False, ignore_field_order=False):
# type: (Union[bigquery.TableSchema, bigquery.TableFieldSchema], Union[bigquery.TableSchema, bigquery.TableFieldSchema], bool, bool) -> bool

left: Union['bigquery.TableSchema', 'bigquery.TableFieldSchema'],
right: Union['bigquery.TableSchema', 'bigquery.TableFieldSchema'],
*,
ignore_descriptions: bool = False,
ignore_field_order: bool = False) -> bool:
"""Check whether schemas are equivalent.

This comparison function differs from using == to compare TableSchema
6 changes: 4 additions & 2 deletions sdks/python/apache_beam/io/gcp/gcsio.py
@@ -137,8 +137,10 @@ def create_storage_client(pipeline_options, use_credentials=True):

class GcsIO(object):
"""Google Cloud Storage I/O client."""
def __init__(self, storage_client=None, pipeline_options=None):
# type: (Optional[storage.Client], Optional[Union[dict, PipelineOptions]]) -> None
def __init__(
self,
storage_client: Optional[storage.Client] = None,
pipeline_options: Optional[Union[dict, PipelineOptions]] = None) -> None:
if pipeline_options is None:
pipeline_options = PipelineOptions()
elif isinstance(pipeline_options, dict):
54 changes: 26 additions & 28 deletions sdks/python/apache_beam/metrics/monitoring_infos.py
@@ -182,9 +182,8 @@ def create_labels(ptransform=None, namespace=None, name=None, pcollection=None):
return labels


def int64_user_counter(namespace, name, metric, ptransform=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_user_counter(
namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo:
"""Return the counter monitoring info for the specifed URN, metric and labels.

Args:
@@ -199,9 +198,12 @@ def int64_user_counter(namespace, name, metric, ptransform=None):
USER_COUNTER_URN, SUM_INT64_TYPE, metric, labels)


def int64_counter(urn, metric, ptransform=None, pcollection=None, labels=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_counter(
urn,
metric,
ptransform=None,
pcollection=None,
labels=None) -> metrics_pb2.MonitoringInfo:
"""Return the counter monitoring info for the specifed URN, metric and labels.

Args:
@@ -217,9 +219,8 @@ def int64_counter(urn, metric, ptransform=None, pcollection=None, labels=None):
return create_monitoring_info(urn, SUM_INT64_TYPE, metric, labels)


def int64_user_distribution(namespace, name, metric, ptransform=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_user_distribution(
namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo:
"""Return the distribution monitoring info for the URN, metric and labels.

Args:
@@ -234,9 +235,11 @@ def int64_user_distribution(namespace, name, metric, ptransform=None):
USER_DISTRIBUTION_URN, DISTRIBUTION_INT64_TYPE, payload, labels)


def int64_distribution(urn, metric, ptransform=None, pcollection=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_distribution(
urn,
metric,
ptransform=None,
pcollection=None) -> metrics_pb2.MonitoringInfo:
"""Return a distribution monitoring info for the URN, metric and labels.

Args:
@@ -251,9 +254,8 @@ def int64_distribution(urn, metric, ptransform=None, pcollection=None):
return create_monitoring_info(urn, DISTRIBUTION_INT64_TYPE, payload, labels)


def int64_user_gauge(namespace, name, metric, ptransform=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_user_gauge(
namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo:
"""Return the gauge monitoring info for the URN, metric and labels.

Args:
@@ -276,9 +278,7 @@ def int64_user_gauge(namespace, name, metric, ptransform=None):
USER_GAUGE_URN, LATEST_INT64_TYPE, payload, labels)


def int64_gauge(urn, metric, ptransform=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def int64_gauge(urn, metric, ptransform=None) -> metrics_pb2.MonitoringInfo:
"""Return the gauge monitoring info for the URN, metric and labels.

Args:
@@ -320,9 +320,8 @@ def user_set_string(namespace, name, metric, ptransform=None):
USER_STRING_SET_URN, STRING_SET_TYPE, metric, labels)


def create_monitoring_info(urn, type_urn, payload, labels=None):
# type: (...) -> metrics_pb2.MonitoringInfo

def create_monitoring_info(
urn, type_urn, payload, labels=None) -> metrics_pb2.MonitoringInfo:
"""Return the gauge monitoring info for the URN, type, metric and labels.

Args:
@@ -366,9 +365,9 @@ def is_user_monitoring_info(monitoring_info_proto):
return monitoring_info_proto.urn in USER_METRIC_URNS


def extract_metric_result_map_value(monitoring_info_proto):
# type: (...) -> Union[None, int, DistributionResult, GaugeResult, set]

def extract_metric_result_map_value(
monitoring_info_proto
) -> Union[None, int, DistributionResult, GaugeResult, set]:
"""Returns the relevant GaugeResult, DistributionResult or int value for
counter metric, set for StringSet metric.

@@ -408,14 +407,13 @@ def get_step_name(monitoring_info_proto):
return monitoring_info_proto.labels.get(PTRANSFORM_LABEL)


def to_key(monitoring_info_proto):
# type: (metrics_pb2.MonitoringInfo) -> FrozenSet[Hashable]

def to_key(
monitoring_info_proto: metrics_pb2.MonitoringInfo) -> FrozenSet[Hashable]:
"""Returns a key based on the URN and labels.

This is useful in maps to prevent reporting the same MonitoringInfo twice.
"""
key_items = list(monitoring_info_proto.labels.items()) # type: List[Hashable]
key_items: List[Hashable] = list(monitoring_info_proto.labels.items())
key_items.append(monitoring_info_proto.urn)
return frozenset(key_items)

26 changes: 9 additions & 17 deletions sdks/python/apache_beam/options/pipeline_options.py
@@ -267,8 +267,7 @@ def __getstate__(self):
return self.__dict__

@classmethod
def _add_argparse_args(cls, parser):
# type: (_BeamArgumentParser) -> None
def _add_argparse_args(cls, parser: _BeamArgumentParser) -> None:
# Override this in subclasses to provide options.
pass

@@ -317,11 +316,8 @@ def from_dictionary(cls, options):
def get_all_options(
self,
drop_default=False,
add_extra_args_fn=None, # type: Optional[Callable[[_BeamArgumentParser], None]]
retain_unknown_options=False
):
# type: (...) -> Dict[str, Any]

add_extra_args_fn: Optional[Callable[[_BeamArgumentParser], None]] = None,
retain_unknown_options=False) -> Dict[str, Any]:
"""Returns a dictionary of all defined arguments.

Returns a dictionary of all defined arguments (arguments that are defined in
@@ -446,9 +442,7 @@ def from_urn(key):
def display_data(self):
return self.get_all_options(drop_default=True, retain_unknown_options=True)

def view_as(self, cls):
# type: (Type[PipelineOptionsT]) -> PipelineOptionsT

def view_as(self, cls: Type[PipelineOptionsT]) -> PipelineOptionsT:
"""Returns a view of current object as provided PipelineOption subclass.

Example Usage::
@@ -487,13 +481,11 @@ def view_as(self, cls):
view._all_options = self._all_options
return view

def _visible_option_list(self):
# type: () -> List[str]
def _visible_option_list(self) -> List[str]:
return sorted(
option for option in dir(self._visible_options) if option[0] != '_')

def __dir__(self):
# type: () -> List[str]
def __dir__(self) -> List[str]:
return sorted(
dir(type(self)) + list(self.__dict__) + self._visible_option_list())

@@ -643,9 +635,9 @@ def additional_option_ptransform_fn():


# Optional type checks that aren't enabled by default.
additional_type_checks = {
additional_type_checks: Dict[str, Callable[[], None]] = {
'ptransform_fn': additional_option_ptransform_fn,
} # type: Dict[str, Callable[[], None]]
}


def enable_all_additional_type_checks():
@@ -1840,7 +1832,7 @@ class OptionsContext(object):

Can also be used as a decorator.
"""
overrides = [] # type: List[Dict[str, Any]]
overrides: List[Dict[str, Any]] = []

def __init__(self, **options):
self.options = options
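
With view_as now annotated as (self, cls: Type[PipelineOptionsT]) -> PipelineOptionsT, type checkers can infer the concrete options class of the returned view. A short usage sketch (the flag value is illustrative only, not taken from this PR):

from apache_beam.options.pipeline_options import PipelineOptions, StandardOptions

options = PipelineOptions(['--streaming'])
# The TypeVar-based signature lets a checker infer `standard: StandardOptions`
# rather than a bare PipelineOptions.
standard = options.view_as(StandardOptions)
print(standard.streaming)  # True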