Skip to content

Commit

Permalink
lint changes in few files
Browse files Browse the repository at this point in the history
  • Loading branch information
bharatjetti committed Oct 8, 2024
1 parent 5401505 commit 441f63f
Show file tree
Hide file tree
Showing 8 changed files with 12 additions and 12 deletions.
5 changes: 3 additions & 2 deletions official/nlp/modeling/layers/transformer_encoder_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class RMSNorm(tf_keras.layers.Layer):

def __init__(
self,
axis: Union[int , Sequence[int]] = -1,
axis: Union[int, Sequence[int]] = -1,
epsilon: float = 1e-6,
**kwargs
):
Expand All @@ -43,7 +43,8 @@ def __init__(
self.axis = [axis] if isinstance(axis, int) else axis
self.epsilon = epsilon

def build(self, input_shape: Union[tf.TensorShape, Sequence[Union[int, None]]]):
def build(self,
input_shape: Union[tf.TensorShape, Sequence[Union[int, None]]]):
input_shape = tf.TensorShape(input_shape)
scale_shape = [1] * input_shape.rank
for dim in self.axis:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
"""Defines base abstract uplift network layers."""

import abc
from typing import Union

import tensorflow as tf, tf_keras

from official.recommendation.uplift import types

from typing import Union


class BaseTwoTowerUpliftNetwork(tf_keras.layers.Layer, metaclass=abc.ABCMeta):
"""Abstract class for uplift layers that compute control and treatment logits.
Expand Down
3 changes: 2 additions & 1 deletion official/recommendation/uplift/metrics/label_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

"""Keras metric for computing the label mean sliced by treatment group."""

from typing import Union

import tensorflow as tf, tf_keras

from official.recommendation.uplift import types
from official.recommendation.uplift.metrics import treatment_sliced_metric

from typing import Union

@tf_keras.utils.register_keras_serializable(package="Uplift")
class LabelMean(tf_keras.metrics.Metric):
Expand Down
2 changes: 1 addition & 1 deletion official/recommendation/uplift/metrics/label_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.

"""Keras metric for computing the label variance sliced by treatment group."""
from typing import Union

import tensorflow as tf, tf_keras

from official.recommendation.uplift import types
from official.recommendation.uplift.metrics import treatment_sliced_metric
from official.recommendation.uplift.metrics import variance

from typing import Union

@tf_keras.utils.register_keras_serializable(package="Uplift")
class LabelVariance(tf_keras.metrics.Metric):
Expand Down
2 changes: 1 addition & 1 deletion official/recommendation/uplift/metrics/metric_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class SlicedMetricConfig(base_config.Config):

slicing_feature: Union[str, None] = None
slicing_spec: Union[Mapping[str, int], None] = None
slicing_feature_dtype: Union[str, None ]= None
slicing_feature_dtype: Union[str, None] = None

def __post_init__(
self, default_params: dict[str, Any], restrictions: list[str]
Expand Down
3 changes: 1 addition & 2 deletions official/recommendation/uplift/metrics/sliced_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
"""Keras metric for reporting metrics sliced by a feature."""

import copy
from typing import Union

import tensorflow as tf, tf_keras

from typing import Union


class SlicedMetric(tf_keras.metrics.Metric):
"""A metric sliced by integer, boolean, or string features.
Expand Down
3 changes: 2 additions & 1 deletion official/recommendation/uplift/metrics/uplift_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

"""Keras metric for computing the mean uplift sliced by treatment group."""

from typing import Union

import tensorflow as tf, tf_keras

from official.recommendation.uplift import types
from official.recommendation.uplift.metrics import treatment_sliced_metric

from typing import Union

@tf_keras.utils.register_keras_serializable(package="Uplift")
class UpliftMean(tf_keras.metrics.Metric):
Expand Down
3 changes: 1 addition & 2 deletions official/recommendation/uplift/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
# limitations under the License.

"""Defines types used by the keras uplift modeling library."""
from typing import Union

import tensorflow as tf, tf_keras
from typing import Union

TensorType = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]

ListOfTensors = list[TensorType]
TupleOfTensors = tuple[TensorType, ...]
DictOfTensors = dict[str, TensorType]
Expand Down

0 comments on commit 441f63f

Please sign in to comment.