Skip to content

Commit

Permalink
Switch the RL locks to asyncio
Browse files Browse the repository at this point in the history
  • Loading branch information
olijeffers0n committed May 5, 2023
1 parent 353e235 commit a0f8b60
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 59 deletions.
8 changes: 4 additions & 4 deletions rustplus/api/base_rust_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,15 @@ async def _handle_ratelimit(self, amount=1) -> None:
:return: None
"""
while True:
if self.remote.ratelimiter.can_consume(self.server_id, amount):
self.remote.ratelimiter.consume(self.server_id, amount)
if await self.remote.ratelimiter.can_consume(self.server_id, amount):
await self.remote.ratelimiter.consume(self.server_id, amount)
break

if self.raise_ratelimit_exception:
raise RateLimitError("Out of tokens")

await asyncio.sleep(
self.remote.ratelimiter.get_estimated_delay_time(self.server_id, amount)
await self.remote.ratelimiter.get_estimated_delay_time(self.server_id, amount)
)

self.heartbeat.reset_rhythm()
Expand Down Expand Up @@ -237,7 +237,7 @@ async def switch_server(

# reset ratelimiter
self.remote.use_proxy = use_proxy
self.remote.ratelimiter.remove(self.server_id)
await self.remote.ratelimiter.remove(self.server_id)
self.remote.ratelimiter.add_socket(
self.server_id,
self.ratelimit_limit,
Expand Down
92 changes: 44 additions & 48 deletions rustplus/api/remote/ratelimiter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
import time
import threading
import asyncio
from typing import Dict

from ...exceptions.exceptions import RateLimitError
Expand Down Expand Up @@ -50,7 +50,7 @@ def default(cls) -> "RateLimiter":
def __init__(self) -> None:
self.socket_buckets: Dict[ServerID, TokenBucket] = {}
self.server_buckets: Dict[str, TokenBucket] = {}
self.lock = threading.Lock()
self.lock = asyncio.Lock()

def add_socket(
self,
Expand All @@ -68,66 +68,62 @@ def add_socket(
self.SERVER_LIMIT, self.SERVER_LIMIT, 1, self.SERVER_REFRESH_AMOUNT
)

def can_consume(self, server_id: ServerID, amount: int = 1) -> bool:
async def can_consume(self, server_id: ServerID, amount: int = 1) -> bool:
"""
Returns whether the user can consume the amount of tokens provided
"""
self.lock.acquire(blocking=True)
can_consume = True

for bucket in [
self.socket_buckets.get(server_id),
self.server_buckets.get(server_id.get_server_string()),
]:
bucket.refresh()
if not bucket.can_consume(amount):
can_consume = False

self.lock.release()
async with self.lock:
can_consume = True

for bucket in [
self.socket_buckets.get(server_id),
self.server_buckets.get(server_id.get_server_string()),
]:
bucket.refresh()
if not bucket.can_consume(amount):
can_consume = False

return can_consume

def consume(self, server_id: ServerID, amount: int = 1) -> None:
async def consume(self, server_id: ServerID, amount: int = 1) -> None:
"""
Consumes an amount of tokens from the bucket. You should first check to see whether it is possible with can_consume
"""
self.lock.acquire(blocking=True)
for bucket in [
self.socket_buckets.get(server_id),
self.server_buckets.get(server_id.get_server_string()),
]:
bucket.refresh()
if not bucket.can_consume(amount):
self.lock.release()
raise RateLimitError("Not Enough Tokens")
bucket.consume(amount)
self.lock.release()

def get_estimated_delay_time(self, server_id: ServerID, target_cost: int) -> float:
async with self.lock:
for bucket in [
self.socket_buckets.get(server_id),
self.server_buckets.get(server_id.get_server_string()),
]:
bucket.refresh()
if not bucket.can_consume(amount):
self.lock.release()
raise RateLimitError("Not Enough Tokens")
bucket.consume(amount)

async def get_estimated_delay_time(self, server_id: ServerID, target_cost: int) -> float:
"""
Returns how long until the amount of tokens needed will be available
"""
self.lock.acquire(blocking=True)
delay = 0
for bucket in [
self.socket_buckets.get(server_id),
self.server_buckets.get(server_id.get_server_string()),
]:
val = (
math.ceil(
(((target_cost - bucket.current) / bucket.refresh_per_second) + 0.1)
* 100
async with self.lock:
delay = 0
for bucket in [
self.socket_buckets.get(server_id),
self.server_buckets.get(server_id.get_server_string()),
]:
val = (
math.ceil(
(((target_cost - bucket.current) / bucket.refresh_per_second) + 0.1)
* 100
)
/ 100
)
/ 100
)
if val > delay:
delay = val
self.lock.release()
if val > delay:
delay = val
return delay

def remove(self, server_id: ServerID) -> None:
async def remove(self, server_id: ServerID) -> None:
"""
Removes the limiter
"""
self.lock.acquire(blocking=True)
del self.socket_buckets[server_id]
self.lock.release()
async with self.lock:
del self.socket_buckets[server_id]
8 changes: 3 additions & 5 deletions rustplus/api/remote/rust_remote_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,7 @@ async def get_response(
)

# Fully Refill the bucket

bucket = self.ratelimiter.socket_buckets.get(self.server_id)

bucket.current = 0

while bucket.current < bucket.max:
Expand All @@ -150,12 +148,12 @@ async def get_response(
cost = self.ws.get_proto_cost(app_request)

while True:
if self.ratelimiter.can_consume(self.server_id, cost):
self.ratelimiter.consume(self.server_id, cost)
if await self.ratelimiter.can_consume(self.server_id, cost):
await self.ratelimiter.consume(self.server_id, cost)
break

await asyncio.sleep(
self.ratelimiter.get_estimated_delay_time(self.server_id, cost)
await self.ratelimiter.get_estimated_delay_time(self.server_id, cost)
)

await self.send_message(app_request)
Expand Down
4 changes: 2 additions & 2 deletions rustplus/api/remote/rustws.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ async def run(self) -> None:

try:
await self.handle_message(app_message)
except Exception as e:
self.logger.error(e)
except Exception:
self.logger.exception("An Error occurred whilst handling the event")

async def handle_message(self, app_message: AppMessage) -> None:
if app_message.response.seq in self.remote.ignored_responses:
Expand Down

0 comments on commit a0f8b60

Please sign in to comment.