Modernize Python type hints for apache_beam. #32872

Merged: 2 commits, Nov 19, 2024
Changes from 1 commit
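The pattern applied throughout this PR is the same in every file: PEP 484 comment annotations (`# type: ...`) are rewritten as inline annotations on the same declarations. A minimal before/after sketch of that conversion (the `Recorder` classes and their members are hypothetical illustrations, not code from the Beam repository):

```python
from typing import Any, Dict, List, Optional


class RecorderOld(object):
  # Old style: the type lives in a trailing comment that type checkers parse
  # but the interpreter never evaluates.
  _results = {}  # type: Dict[str, List[Any]]

  def lookup(self, key, default=None):
    # type: (str, Optional[List[Any]]) -> List[Any]
    return self._results.get(key, default or [])


class RecorderNew(object):
  # New style: the same types as inline annotations (PEP 526 for attributes,
  # PEP 484 for signatures), which is what this PR does across the SDK.
  _results: Dict[str, List[Any]] = {}

  def lookup(self, key: str, default: Optional[List[Any]] = None) -> List[Any]:
    return self._results.get(key, default or [])
```

Both spellings mean the same thing to a type checker; the inline form is also visible at runtime via `__annotations__`, which is why some previously guarded imports in the hunks below become ordinary imports.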
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/dataframe/doctests.py
@@ -146,7 +146,7 @@ class _InMemoryResultRecorder(object):
"""

# Class-level value to survive pickling.
_ALL_RESULTS = {} # type: Dict[str, List[Any]]
_ALL_RESULTS: Dict[str, List[Any]] = {}

def __init__(self):
self._id = id(self)
33 changes: 15 additions & 18 deletions sdks/python/apache_beam/dataframe/expressions.py
@@ -36,12 +36,12 @@ class Session(object):
def __init__(self, bindings=None):
self._bindings = dict(bindings or {})

def evaluate(self, expr): # type: (Expression) -> Any
def evaluate(self, expr: 'Expression') -> Any:
if expr not in self._bindings:
self._bindings[expr] = expr.evaluate_at(self)
return self._bindings[expr]

def lookup(self, expr): # type: (Expression) -> Any
def lookup(self, expr: 'Expression') -> Any:
return self._bindings[expr]
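One detail worth noting in the hunk above: `Expression` stays quoted because the class is defined later in the same module, so the name does not exist yet when `evaluate` is defined; a quoted annotation is a standard forward reference. A self-contained sketch of the same idiom (hypothetical classes, not Beam code):

```python
from typing import Any


class Session(object):
  def __init__(self):
    self._bindings = {}

  # 'Expression' must be a string here: the class is defined further down,
  # so an unquoted name would raise NameError when this def is evaluated.
  def evaluate(self, expr: 'Expression') -> Any:
    if expr not in self._bindings:
      self._bindings[expr] = expr.evaluate_at(self)
    return self._bindings[expr]


class Expression(object):
  def evaluate_at(self, session: Session) -> Any:
    return 42  # placeholder value for the sketch
```

Adding `from __future__ import annotations` would make the quotes unnecessary, but the quoted form keeps the diff minimal.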


@@ -251,9 +251,9 @@ def preserves_partition_by(self) -> partitionings.Partitioning:
class PlaceholderExpression(Expression):
"""An expression whose value must be explicitly bound in the session."""
def __init__(
self, # type: PlaceholderExpression
proxy, # type: T
reference=None, # type: Any
self,
proxy: T,
reference: Any = None,
):
"""Initialize a placeholder expression.

@@ -282,11 +282,7 @@ def preserves_partition_by(self):

class ConstantExpression(Expression):
"""An expression whose value is known at pipeline construction time."""
def __init__(
self, # type: ConstantExpression
value, # type: T
proxy=None # type: Optional[T]
):
def __init__(self, value: T, proxy: Optional[T] = None):
"""Initialize a constant expression.

Args:
@@ -319,14 +315,15 @@ def preserves_partition_by(self):
class ComputedExpression(Expression):
"""An expression whose value must be computed at pipeline execution time."""
def __init__(
self, # type: ComputedExpression
name, # type: str
func, # type: Callable[...,T]
args, # type: Iterable[Expression]
proxy=None, # type: Optional[T]
_id=None, # type: Optional[str]
requires_partition_by=partitionings.Index(), # type: partitionings.Partitioning
preserves_partition_by=partitionings.Singleton(), # type: partitionings.Partitioning
self,
name: str,
func: Callable[..., T],
args: Iterable[Expression],
proxy: Optional[T] = None,
_id: Optional[str] = None,
requires_partition_by: partitionings.Partitioning = partitionings.Index(),
preserves_partition_by: partitionings.Partitioning = partitionings.
Singleton(),
):
"""Initialize a computed expression.

27 changes: 11 additions & 16 deletions sdks/python/apache_beam/dataframe/transforms.py
@@ -16,7 +16,6 @@

import collections
import logging
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import List
@@ -30,18 +29,16 @@
import apache_beam as beam
from apache_beam import transforms
from apache_beam.dataframe import expressions
from apache_beam.dataframe import frame_base
from apache_beam.dataframe import frames # pylint: disable=unused-import
from apache_beam.dataframe import partitionings
from apache_beam.pvalue import PCollection
from apache_beam.utils import windowed_value

__all__ = [
'DataframeTransform',
]

if TYPE_CHECKING:
# pylint: disable=ungrouped-imports
from apache_beam.pvalue import PCollection

T = TypeVar('T')

TARGET_PARTITION_SIZE = 1 << 23 # 8M
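The import changes above follow from the annotation change: with comment annotations, `PCollection` was only needed by the type checker, so it sat behind an `if TYPE_CHECKING:` guard; once it appears in real, evaluated annotations it has to be a plain runtime import. A small sketch of the two idioms, using `collections.OrderedDict` as a stand-in for the guarded import (hypothetical functions, not Beam code):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
  # Only seen by type checkers; never imported at runtime. Sufficient when
  # the name appears solely in comments or quoted annotations.
  from collections import OrderedDict


def old_style(d):
  # type: (OrderedDict) -> int
  return len(d)


# Once the annotation is an ordinary expression evaluated at definition
# time, the import has to be real.
from collections import OrderedDict


def new_style(d: OrderedDict) -> int:
  return len(d)
```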
@@ -108,15 +105,15 @@ def expand(self, input_pcolls):
from apache_beam.dataframe import convert

# Convert inputs to a flat dict.
input_dict = _flatten(input_pcolls) # type: Dict[Any, PCollection]
input_dict: Dict[Any, PCollection] = _flatten(input_pcolls)
proxies = _flatten(self._proxy) if self._proxy is not None else {
tag: None
for tag in input_dict
}
input_frames = {
input_frames: Dict[Any, frame_base.DeferredFrame] = {
k: convert.to_dataframe(pc, proxies[k])
for k, pc in input_dict.items()
} # type: Dict[Any, DeferredFrame] # noqa: F821
} # noqa: F821

# Apply the function.
frames_input = _substitute(input_pcolls, input_frames)
@@ -152,9 +149,9 @@ def expand(self, inputs):

def _apply_deferred_ops(
self,
inputs, # type: Dict[expressions.Expression, PCollection]
outputs, # type: Dict[Any, expressions.Expression]
): # -> Dict[Any, PCollection]
inputs: Dict[expressions.Expression, PCollection],
outputs: Dict[Any, expressions.Expression],
): # -> Dict[Any, PCollection]
"""Construct a Beam graph that evaluates a set of expressions on a set of
input PCollections.

@@ -585,11 +582,9 @@ def _concat(parts):


def _flatten(
valueish, # type: Union[T, List[T], Tuple[T], Dict[Any, T]]
root=(), # type: Tuple[Any, ...]
):
# type: (...) -> Mapping[Tuple[Any, ...], T]

valueish: Union[T, List[T], Tuple[T], Dict[Any, T]],
root: Tuple[Any, ...] = (),
) -> Mapping[Tuple[Any, ...], T]:
"""Given a nested structure of dicts, tuples, and lists, return a flat
dictionary where the values are the leafs and the keys are the "paths" to
these leaves.
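For reference, the `_flatten` docstring above describes a path-keyed flattening of nested containers. A rough illustration of the behavior it promises (a simplified re-implementation for illustration, not the Beam code):

```python
from typing import Any, Dict, List, Mapping, Tuple, TypeVar, Union

T = TypeVar('T')


def flatten_sketch(
    valueish: Union[T, List[T], Tuple[T, ...], Dict[Any, T]],
    root: Tuple[Any, ...] = ()) -> Mapping[Tuple[Any, ...], T]:
  # Dicts recurse by key, lists/tuples by index, anything else is a leaf.
  if isinstance(valueish, dict):
    items = valueish.items()
  elif isinstance(valueish, (list, tuple)):
    items = enumerate(valueish)
  else:
    return {root: valueish}
  out: Dict[Tuple[Any, ...], T] = {}
  for key, value in items:
    out.update(flatten_sketch(value, root + (key, )))
  return out


# {'a': [1, 2], 'b': 3} -> {('a', 0): 1, ('a', 1): 2, ('b',): 3}
print(flatten_sketch({'a': [1, 2], 'b': 3}))
```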
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/io/avroio_test.py
@@ -82,7 +82,7 @@

class AvroBase(object):

_temp_files = [] # type: List[str]
_temp_files: List[str] = []

def __init__(self, methodName='runTest'):
super().__init__(methodName)
37 changes: 12 additions & 25 deletions sdks/python/apache_beam/io/fileio.py
@@ -94,7 +94,6 @@
import uuid
from collections import namedtuple
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import BinaryIO # pylint: disable=unused-import
from typing import Callable
@@ -115,15 +114,13 @@
from apache_beam.options.value_provider import ValueProvider
from apache_beam.transforms.periodicsequence import PeriodicImpulse
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.transforms.window import BoundedWindow
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import IntervalWindow
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import Timestamp

if TYPE_CHECKING:
from apache_beam.transforms.window import BoundedWindow

__all__ = [
'EmptyMatchTreatment',
'MatchFiles',
@@ -382,8 +379,7 @@ def create_metadata(
mime_type="application/octet-stream",
compression_type=CompressionTypes.AUTO)

def open(self, fh):
# type: (BinaryIO) -> None
def open(self, fh: BinaryIO) -> None:
raise NotImplementedError

def write(self, record):
@@ -575,8 +571,7 @@ class signature or an instance of FileSink to this parameter. If none is
self._max_num_writers_per_bundle = max_writers_per_bundle

@staticmethod
def _get_sink_fn(input_sink):
# type: (...) -> Callable[[Any], FileSink]
def _get_sink_fn(input_sink) -> Callable[[Any], FileSink]:
if isinstance(input_sink, type) and issubclass(input_sink, FileSink):
return lambda x: input_sink()
elif isinstance(input_sink, FileSink):
@@ -588,8 +583,7 @@ def _get_sink_fn(input_sink):
return lambda x: TextSink()

@staticmethod
def _get_destination_fn(destination):
# type: (...) -> Callable[[Any], str]
def _get_destination_fn(destination) -> Callable[[Any], str]:
if isinstance(destination, ValueProvider):
return lambda elm: destination.get()
elif callable(destination):
@@ -757,12 +751,8 @@ def _check_orphaned_files(self, writer_key):


class _WriteShardedRecordsFn(beam.DoFn):

def __init__(self,
base_path,
sink_fn, # type: Callable[[Any], FileSink]
shards # type: int
):
def __init__(
self, base_path, sink_fn: Callable[[Any], FileSink], shards: int):
self.base_path = base_path
self.sink_fn = sink_fn
self.shards = shards
@@ -805,17 +795,13 @@ def process(


class _AppendShardedDestination(beam.DoFn):
def __init__(
self,
destination, # type: Callable[[Any], str]
shards # type: int
):
def __init__(self, destination: Callable[[Any], str], shards: int):
self.destination_fn = destination
self.shards = shards

# We start the shards for a single destination at an arbitrary point.
self._shard_counter = collections.defaultdict(
lambda: random.randrange(self.shards)) # type: DefaultDict[str, int]
self._shard_counter: DefaultDict[str, int] = collections.defaultdict(
lambda: random.randrange(self.shards))

def _next_shard_for_destination(self, destination):
self._shard_counter[destination] = ((self._shard_counter[destination] + 1) %
@@ -835,8 +821,9 @@ class _WriteUnshardedRecordsFn(beam.DoFn):
SPILLED_RECORDS = 'spilled_records'
WRITTEN_FILES = 'written_files'

_writers_and_sinks = None # type: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO, FileSink]]
_file_names = None # type: Dict[Tuple[str, BoundedWindow], str]
_writers_and_sinks: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO,
FileSink]] = None
_file_names: Dict[Tuple[str, BoundedWindow], str] = None

def __init__(
self,
10 changes: 6 additions & 4 deletions sdks/python/apache_beam/io/gcp/bigquery_tools.py
@@ -560,7 +560,7 @@ def _insert_load_job(

def _start_job(
self,
request, # type: bigquery.BigqueryJobsInsertRequest
request: 'bigquery.BigqueryJobsInsertRequest',
stream=None,
):
"""Inserts a BigQuery job.
@@ -1786,9 +1786,11 @@ def generate_bq_job_name(job_name, step_id, job_type, random=None):


def check_schema_equal(
left, right, *, ignore_descriptions=False, ignore_field_order=False):
# type: (Union[bigquery.TableSchema, bigquery.TableFieldSchema], Union[bigquery.TableSchema, bigquery.TableFieldSchema], bool, bool) -> bool

left: Union['bigquery.TableSchema', 'bigquery.TableFieldSchema'],
right: Union['bigquery.TableSchema', 'bigquery.TableFieldSchema'],
*,
ignore_descriptions: bool = False,
ignore_field_order: bool = False) -> bool:
"""Check whether schemas are equivalent.

This comparison function differs from using == to compare TableSchema
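The `check_schema_equal` hunk above also shows why the inline form reads better for keyword-only parameters: the old single `# type: (...) -> bool` comment listed the two flags positionally, while the new signature annotates each one in place. A stripped-down sketch with hypothetical names:

```python
from typing import Union


def compare_old(left, right, *, ignore_case=False, ignore_order=False):
  # type: (Union[str, bytes], Union[str, bytes], bool, bool) -> bool
  # Old form: the reader has to match the trailing bools to the keyword-only
  # flags by position.
  return left == right


def compare_new(
    left: Union[str, bytes],
    right: Union[str, bytes],
    *,
    ignore_case: bool = False,
    ignore_order: bool = False) -> bool:
  # New form: each parameter carries its own annotation and default.
  return left == right
```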
6 changes: 4 additions & 2 deletions sdks/python/apache_beam/io/gcp/gcsio.py
@@ -137,8 +137,10 @@ def create_storage_client(pipeline_options, use_credentials=True):

class GcsIO(object):
"""Google Cloud Storage I/O client."""
def __init__(self, storage_client=None, pipeline_options=None):
# type: (Optional[storage.Client], Optional[Union[dict, PipelineOptions]]) -> None
def __init__(
self,
storage_client: Optional[storage.Client] = None,
pipeline_options: Optional[Union[dict, PipelineOptions]] = None) -> None:
if pipeline_options is None:
pipeline_options = PipelineOptions()
elif isinstance(pipeline_options, dict):