Skip to content

Commit

Permalink
Update typing
Browse files Browse the repository at this point in the history
  • Loading branch information
greg-el committed Aug 27, 2024
1 parent 6d8d588 commit f58d845
Show file tree
Hide file tree
Showing 15 changed files with 94 additions and 51 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ jobs:
- name: Tests
run: |
pytest --cov=snowplow_tracker --cov-report=xml
- name: MyPy
run: |
python -m pip install mypy
mypy snowplow_tracker --exclude '/test'
- name: Demo
run: |
Expand Down
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,9 @@
"Programming Language :: Python :: 3.12",
"Operating System :: OS Independent",
],
install_requires=["requests>=2.25.1,<3.0", "typing_extensions>=3.7.4"],
install_requires=[
"requests>=2.25.1,<3.0",
"types-requests>=2.25.1,<3.0",
"typing_extensions>=3.7.4",
],
)
2 changes: 1 addition & 1 deletion snowplow_tracker/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from snowplow_tracker import _version, SelfDescribingJson

VERSION = "py-%s" % _version.__version__
DEFAULT_ENCODE_BASE64 = True
DEFAULT_ENCODE_BASE64: bool = True # Type hint required for Python 3.6 MyPy check
BASE_SCHEMA_PATH = "iglu:com.snowplowanalytics.snowplow"
MOBILE_SCHEMA_PATH = "iglu:com.snowplowanalytics.mobile"
SCHEMA_TAG = "jsonschema"
Expand Down
2 changes: 1 addition & 1 deletion snowplow_tracker/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _get_parameter_name() -> str:
match = _MATCH_FIRST_PARAMETER_REGEX.search(code)
if not match:
return "Unnamed parameter"
return match.groups(0)[0]
return str(match.groups(0)[0])


def _check_form_element(element: Dict[str, Any]) -> bool:
Expand Down
35 changes: 28 additions & 7 deletions snowplow_tracker/emitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import threading
import requests
import random
from typing import Optional, Union, Tuple, Dict
from typing import Optional, Union, Tuple, Dict, cast, Callable
from queue import Queue

from snowplow_tracker.self_describing_json import SelfDescribingJson
Expand All @@ -31,6 +31,7 @@
Method,
SuccessCallback,
FailureCallback,
EmitterProtocol,
)
from snowplow_tracker.contracts import one_of
from snowplow_tracker.event_store import EventStore, InMemoryEventStore
Expand All @@ -48,7 +49,20 @@
METHODS = {"get", "post"}


class Emitter(object):
# Unifes the two request methods under one interface
class Requester:
post: Callable
get: Callable

def __init__(self, post: Callable, get: Callable):
# 3.6 MyPy compatibility:
# error: Cannot assign to a method
# https://github.com/python/mypy/issues/2427
setattr(self, "post", post)
setattr(self, "get", get)


class Emitter(EmitterProtocol):
"""
Synchronously send Snowplow events to a Snowplow collector
Supports both GET and POST requests
Expand Down Expand Up @@ -151,12 +165,15 @@ def __init__(
self.retry_timer = FlushTimer(emitter=self, repeating=False)

self.max_retry_delay_seconds = max_retry_delay_seconds
self.retry_delay = 0
self.retry_delay: Union[int, float] = 0

self.custom_retry_codes = custom_retry_codes
logger.info("Emitter initialized with endpoint " + self.endpoint)

self.request_method = requests if session is None else session
if session is None:
self.request_method = Requester(post=requests.post, get=requests.get)
else:
self.request_method = Requester(post=session.post, get=session.get)

@staticmethod
def as_collector_uri(
Expand All @@ -183,7 +200,7 @@ def as_collector_uri(

if endpoint.split("://")[0] in PROTOCOLS:
endpoint_arr = endpoint.split("://")
protocol = endpoint_arr[0]
protocol = cast(HttpProtocol, endpoint_arr[0])
endpoint = endpoint_arr[1]

if method == "get":
Expand Down Expand Up @@ -427,6 +444,10 @@ def _cancel_retry_timer(self) -> None:
"""
self.retry_timer.cancel()

# This is only here to satisfy the `EmitterProtocol` interface
def async_flush(self) -> None:
return


class AsyncEmitter(Emitter):
"""
Expand All @@ -446,7 +467,7 @@ def __init__(
byte_limit: Optional[int] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
max_retry_delay_seconds: int = 60,
buffer_capacity: int = None,
buffer_capacity: Optional[int] = None,
custom_retry_codes: Dict[int, bool] = {},
event_store: Optional[EventStore] = None,
session: Optional[requests.Session] = None,
Expand Down Expand Up @@ -501,7 +522,7 @@ def __init__(
event_store=event_store,
session=session,
)
self.queue = Queue()
self.queue: Queue = Queue()
for i in range(thread_count):
t = threading.Thread(target=self.consume)
t.daemon = True
Expand Down
11 changes: 6 additions & 5 deletions snowplow_tracker/event_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# language governing permissions and limitations there under.
# """

from typing import List
from typing_extensions import Protocol
from snowplow_tracker.typing import PayloadDict, PayloadDictList
from logging import Logger
Expand All @@ -25,7 +26,7 @@ class EventStore(Protocol):
EventStore protocol. For buffering events in the Emitter.
"""

def add_event(payload: PayloadDict) -> bool:
def add_event(self, payload: PayloadDict) -> bool:
"""
Add PayloadDict to buffer. Returns True if successful.
Expand All @@ -35,15 +36,15 @@ def add_event(payload: PayloadDict) -> bool:
"""
...

def get_events_batch() -> PayloadDictList:
def get_events_batch(self) -> PayloadDictList:
"""
Get a list of all the PayloadDicts in the buffer.
:rtype PayloadDictList
"""
...

def cleanup(batch: PayloadDictList, need_retry: bool) -> None:
def cleanup(self, batch: PayloadDictList, need_retry: bool) -> None:
"""
Removes sent events from the event store. If events need to be retried they are re-added to the buffer.
Expand All @@ -54,7 +55,7 @@ def cleanup(batch: PayloadDictList, need_retry: bool) -> None:
"""
...

def size() -> int:
def size(self) -> int:
"""
Returns the number of events in the buffer
Expand All @@ -76,7 +77,7 @@ def __init__(self, logger: Logger, buffer_capacity: int = 10000) -> None:
When the buffer is full new events are lost.
:type buffer_capacity int
"""
self.event_buffer = []
self.event_buffer: List[PayloadDict] = []
self.buffer_capacity = buffer_capacity
self.logger = logger

Expand Down
5 changes: 2 additions & 3 deletions snowplow_tracker/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,9 @@ def build_payload(
if self.event_subject is not None:
fin_payload_dict = self.event_subject.combine_subject(subject)
else:
fin_payload_dict = None if subject is None else subject.standard_nv_pairs
fin_payload_dict = {} if subject is None else subject.standard_nv_pairs

if fin_payload_dict is not None:
self.payload.add_dict(fin_payload_dict)
self.payload.add_dict(fin_payload_dict)
return self.payload

@property
Expand Down
4 changes: 2 additions & 2 deletions snowplow_tracker/events/screen_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# language governing permissions and limitations there under.
# """

from typing import Optional, List
from typing import Dict, Optional, List
from snowplow_tracker.typing import JsonEncoderFunction
from snowplow_tracker.events.event import Event
from snowplow_tracker.events.self_describing import SelfDescribing
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(
super(ScreenView, self).__init__(
event_subject=event_subject, context=context, true_timestamp=true_timestamp
)
self.screen_view_properties = {}
self.screen_view_properties: Dict[str, str] = {}
self.id_ = id_
self.name = name
self.type = type
Expand Down
12 changes: 6 additions & 6 deletions snowplow_tracker/events/structured_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# language governing permissions and limitations there under.
# """
from snowplow_tracker.events.event import Event
from typing import Optional, List
from typing import Optional, List, Union
from snowplow_tracker.subject import Subject
from snowplow_tracker.self_describing_json import SelfDescribingJson
from snowplow_tracker.contracts import non_empty_string
Expand All @@ -41,7 +41,7 @@ def __init__(
action: str,
label: Optional[str] = None,
property_: Optional[str] = None,
value: Optional[int] = None,
value: Optional[Union[int, float]] = None,
event_subject: Optional[Subject] = None,
context: Optional[List[SelfDescribingJson]] = None,
true_timestamp: Optional[float] = None,
Expand Down Expand Up @@ -84,7 +84,7 @@ def category(self) -> Optional[str]:
return self.payload.nv_pairs.get("se_ca")

@category.setter
def category(self, value: Optional[str]):
def category(self, value: str):
non_empty_string(value)
self.payload.add("se_ca", value)

Expand All @@ -96,7 +96,7 @@ def action(self) -> Optional[str]:
return self.payload.nv_pairs.get("se_ac")

@action.setter
def action(self, value: Optional[str]):
def action(self, value: str):
non_empty_string(value)
self.payload.add("se_ac", value)

Expand All @@ -123,12 +123,12 @@ def property_(self, value: Optional[str]):
self.payload.add("se_pr", value)

@property
def value(self) -> Optional[int]:
def value(self) -> Optional[Union[int, float]]:
"""
A value associated with the user action
"""
return self.payload.nv_pairs.get("se_va")

@value.setter
def value(self, value: Optional[int]):
def value(self, value: Optional[Union[int, float]]):
self.payload.add("se_va", value)
5 changes: 2 additions & 3 deletions snowplow_tracker/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ def add_json(

if encode_base64:
encoded_dict = base64.urlsafe_b64encode(json_dict.encode("utf-8"))
if not isinstance(encoded_dict, str):
encoded_dict = encoded_dict.decode("utf-8")
self.add(type_when_encoded, encoded_dict)
encoded_dict_str = encoded_dict.decode("utf-8")
self.add(type_when_encoded, encoded_dict_str)

else:
self.add(type_when_not_encoded, json_dict)
Expand Down
6 changes: 3 additions & 3 deletions snowplow_tracker/snowplow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# """

import logging
from typing import Optional
from typing import Dict, Optional
from snowplow_tracker import (
Tracker,
Emitter,
Expand All @@ -37,7 +37,7 @@


class Snowplow:
_trackers = {}
_trackers: Dict[str, Tracker] = {}

@staticmethod
def create_tracker(
Expand Down Expand Up @@ -149,7 +149,7 @@ def reset(cls):
cls._trackers = {}

@classmethod
def get_tracker(cls, namespace: str) -> Tracker:
def get_tracker(cls, namespace: str) -> Optional[Tracker]:
"""
Returns a Snowplow tracker from the Snowplow object if it exists
:param namespace: Snowplow tracker namespace
Expand Down
4 changes: 2 additions & 2 deletions snowplow_tracker/subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# language governing permissions and limitations there under.
# """

from typing import Optional
from typing import Dict, Optional, Union
from snowplow_tracker.contracts import one_of, greater_than
from snowplow_tracker.typing import SupportedPlatform, SUPPORTED_PLATFORMS, PayloadDict

Expand All @@ -30,7 +30,7 @@ class Subject(object):
"""

def __init__(self) -> None:
self.standard_nv_pairs = {"p": DEFAULT_PLATFORM}
self.standard_nv_pairs: Dict[str, Union[str, int]] = {"p": DEFAULT_PLATFORM}

def set_platform(self, value: SupportedPlatform) -> "Subject":
"""
Expand Down
Loading

0 comments on commit f58d845

Please sign in to comment.