From 803afb34ac49ee9649f74e5a80cb6f816c668f30 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 24 Sep 2019 15:49:39 +0800 Subject: [PATCH] feat(snoflake-uuid): add snowflake uuid generator --- gnes/helper.py | 20 ++++++++++- gnes/uuid.py | 88 ++++++++++++++++++++++++++++++++++++++++++++++ tests/test_uuid.py | 21 +++++++++++ 3 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 gnes/uuid.py create mode 100644 tests/test_uuid.py diff --git a/gnes/helper.py b/gnes/helper.py index 450e32e5..239d6b03 100644 --- a/gnes/helper.py +++ b/gnes/helper.py @@ -19,6 +19,7 @@ import os import sys import time +import threading from copy import copy from functools import wraps from itertools import islice @@ -41,7 +42,24 @@ 'profile_logger', 'load_contrib_module', 'parse_arg', 'profiling', 'FileLock', 'train_required', 'get_first_available_gpu', - 'PathImporter', 'progressbar'] + 'PathImporter', 'progressbar', 'Singleton'] + + +class Singleton: + """ + Make your class singeton + """ + def __init__(self, cls): + self.__instance = None + self.__cls = cls + self._lock = threading.Lock() + + def __call__(self, *args, **kwargs): + self._lock.acquire() + if self.__instance is None: + self.__instance = self.__cls(*args, **kwargs) + self._lock.release() + return self.__instance def progressbar(i, prefix="", suffix="", count=100, size=60): diff --git a/gnes/uuid.py b/gnes/uuid.py new file mode 100644 index 00000000..4dcfd940 --- /dev/null +++ b/gnes/uuid.py @@ -0,0 +1,88 @@ +import threading +import time +from datetime import datetime + +from . import helper + + +@helper.Singleton +class BaseIDGenerator(object): + """ + Thread-safe (auto incremental) uuid generator + """ + + def __init__(self, start_id: int = 0, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self._lock = threading.Lock() + self._next_id = start_id + + def reset(self, start_id: int = 0): + with self._lock: + self._next_id = start_id + + def next(self) -> int: + with self._lock: + temp = self._next_id + self._next_id += 1 + return temp + + +@helper.Singleton +class SnowflakeIDGenerator(object): + + def __init__(self, + machine_id: int = 0, + datacenter_id: int = 0, + *args, + **kwargs): + self._lock = threading.Lock() + self._next_id = 0 + + self.machine_id = machine_id + self.datacenter_id = datacenter_id + + self.machine_bits = 5 + self.datacenter_bits = 5 + self.max_machine_id = -1 ^ -1 << self.machine_bits + self.max_datacenter_id = -1 ^ (-1 << self.datacenter_bits) + + self.counter_bits = 12 + self.max_counter_mask = -1 ^ -1 << self.counter_bits + + self.machine_shift = self.counter_bits + self.datacenter_shift = self.counter_bits + self.machine_shift + self.timestamp_shift = self.counter_bits + self.machine_shift + self.datacenter_shift + + self.twepoch = 687888001020 + self.last_timestamp = -1 + + def _get_timestamp(self) -> int: + return int(datetime.now().timestamp() * 1000) + + def _get_next_timestamp(self, last_timestamp) -> int: + timestamp = self._get_timestamp() + while timestamp <= last_timestamp: + timestamp = self._get_timestamp() + return timestamp + + def next(self) -> int: + with self._lock: + timestamp = int(datetime.now().timestamp() * 1000) + if self.last_timestamp == timestamp: + self._next_id = (self._next_id + 1) & self.max_counter_mask + if self._next_id == 0: + timestamp = self._get_next_timestamp(self.last_timestamp) + else: + self._next_id = 0 + + if timestamp < self.last_timestamp: + raise ValueError( + 'the current timestamp is smaller than the last timestamp') + + self.last_timestamp = timestamp + uuid = ((timestamp - self.twepoch) << self.timestamp_shift) \ + | (self.datacenter_id << self.datacenter_shift) \ + | (self.machine_id << self.machine_shift) \ + | self._next_id + return uuid diff --git a/tests/test_uuid.py b/tests/test_uuid.py new file mode 100644 index 00000000..09a6580a --- /dev/null +++ b/tests/test_uuid.py @@ -0,0 +1,21 @@ +import unittest + +from gnes.uuid import BaseIDGenerator, SnowflakeIDGenerator + +class TestUUID(unittest.TestCase): + def test_base_uuid(self): + uuid_generator = BaseIDGenerator() + last = -1 + for _ in range(10000): + nid = uuid_generator.next() + self.assertGreater(nid, last) + last = nid + + + def test_snoflake(self): + uuid_generator = SnowflakeIDGenerator() + last = -1 + for _ in range(10000): + nid = uuid_generator.next() + self.assertGreater(nid, last) + last = nid