chore(dispatcher): refactor out training code #3663

Merged · 4 commits · Mar 16, 2023

190 changes: 50 additions & 140 deletions src/bentoml/_internal/marshal/dispatcher.py
@@ -53,7 +53,7 @@ def __init__(self, max_latency: float):
         self.o_a = min(2, max_latency * 2.0 / 30)
         self.o_b = min(1, max_latency * 1.0 / 30)

-        self.wait = 0.01  # the avg wait time before outbound called
+        self.wait = 0  # the avg wait time before outbound called

         self._refresh_tb = TokenBucket(2)  # to limit params refresh interval
         self.outbound_counter = 0
@@ -168,52 +168,24 @@ async def _func(data: t.Any) -> t.Any:

         return _func

-    async def controller(self):
-        """
-        A standalone coroutine to wait/dispatch calling.
-        """
-        logger.debug("Starting dispatcher optimizer training...")
-        # warm up the model
-        while self.optimizer.outbound_counter <= self.optimizer.N_SKIPPED_SAMPLE:
-            try:
-                async with self._wake_event:  # block until there's any request in queue
-                    await self._wake_event.wait_for(self._queue.__len__)
-
-                n = len(self._queue)
-                now = time.time()
-                w0 = now - self._queue[0][0]
-
-                # only cancel requests if there are more than enough for training
-                if (
-                    n
-                    > self.optimizer.N_SKIPPED_SAMPLE
-                    - self.optimizer.outbound_counter
-                    + 6
-                    and w0 >= self.max_latency_in_ms
-                ):
-                    # we're being very conservative and only canceling requests if they have already timed out
-                    self._queue.popleft()[2].cancel()
-                    continue
-                # don't try to be smart here, just serve the first few requests
-                if self._sema.is_locked():
-                    await asyncio.sleep(self.tick_interval)
-                    continue
-
-                n_call_out = 1
-                # call
-                self._sema.acquire()
-                inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
-                self._loop.create_task(self.outbound_call(inputs_info))
-            except asyncio.CancelledError:
-                return
-            except Exception as e:  # pylint: disable=broad-except
-                logger.error(traceback.format_exc(), exc_info=e)
-
-        logger.debug("Dispatcher finished warming up model.")
-
-        while self.optimizer.outbound_counter <= self.optimizer.N_SKIPPED_SAMPLE + 1:
-            try:
-                # step 1: attempt to serve a single request immediately
+    async def train_optimizer(
+        self,
+        num_required_reqs: int,
+        num_reqs_to_train: int,
+        batch_size: int,
+    ):
+        if self.max_batch_size < batch_size:
+            batch_size = self.max_batch_size
+
+        if batch_size > 1:
+            wait = min(
+                self.max_latency_in_ms * 0.95,
+                (batch_size * 2 + 1) * (self.optimizer.o_a + self.optimizer.o_b),
+            )
+
+        req_count = 0
+        try:
+            while req_count < num_reqs_to_train:
                 async with self._wake_event:  # block until there's any request in queue
                     await self._wake_event.wait_for(self._queue.__len__)

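For intuition: the new train_optimizer collapses the old controller's hard-coded per-step waits into a single cap, wait = min(max_latency_in_ms * 0.95, (2 * batch_size + 1) * (o_a + o_b)). A minimal sketch with assumed toy values (not numbers from this PR) shows that batch sizes 2 and 3 reproduce the old step_2_wait and step_3_wait:

# Toy illustration of the generalized wait cap; o_a/o_b come from the
# optimizer's linear latency model t(n) ~= o_a * n + o_b. Values assumed.
o_a, o_b = 0.010, 0.005  # assumed slope (s/item) and fixed overhead (s)
max_latency = 0.300      # assumed latency budget in seconds

for batch_size in (1, 2, 3):
    wait = min(max_latency * 0.95, (batch_size * 2 + 1) * (o_a + o_b))
    print(batch_size, round(wait, 3))  # 1 -> 0.045, 2 -> 0.075, 3 -> 0.105

# (2 * 2 + 1) = 5 and (2 * 3 + 1) = 7, matching the old controller's
# step_2_wait = 5 * (o_a + o_b) and step_3_wait = 7 * (o_a + o_b).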
@@ -222,117 +194,55 @@ async def controller(self):
                 w0 = now - self._queue[0][0]

                 # only cancel requests if there are more than enough for training
-                if n > 6 and w0 >= self.max_latency_in_ms:
+                if n > num_required_reqs - req_count and w0 >= self.max_latency_in_ms:
                     # we're being very conservative and only canceling requests if they have already timed out
                     self._queue.popleft()[2].cancel()
                     continue
+                if batch_size > 1:  # only wait if batch_size
+                    a = self.optimizer.o_a
+                    b = self.optimizer.o_b
+
+                    if n < batch_size and (batch_size * a + b) + w0 <= wait:
+                        await asyncio.sleep(self.tick_interval)
+                        continue
                 if self._sema.is_locked():
                     await asyncio.sleep(self.tick_interval)
                     continue

-                n_call_out = 1
+                n_call_out = min(n, batch_size)
+                req_count += 1
                 # call
                 self._sema.acquire()
                 inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
                 self._loop.create_task(self.outbound_call(inputs_info))
-            except asyncio.CancelledError:
-                return
-            except Exception as e:  # pylint: disable=broad-except
-                logger.error(traceback.format_exc(), exc_info=e)
-
-        logger.debug("Dispatcher finished optimizer training request 1.")
-        self.optimizer.trigger_refresh()
-
-        if self.max_batch_size >= 2:
-            # we will attempt to keep the second request served within this time
-            step_2_wait = min(
-                self.max_latency_in_ms * 0.95,
-                5 * (self.optimizer.o_a + self.optimizer.o_b),
-            )
-
-            # step 2: attempt to serve 2 requests
-            while (
-                self.optimizer.outbound_counter <= self.optimizer.N_SKIPPED_SAMPLE + 2
-            ):
-                try:
-                    async with self._wake_event:  # block until there's any request in queue
-                        await self._wake_event.wait_for(self._queue.__len__)
-
-                    n = len(self._queue)
-                    dt = self.tick_interval
-                    now = time.time()
-                    w0 = now - self._queue[0][0]
-                    a = self.optimizer.o_a
-                    b = self.optimizer.o_b
-
-                    # only cancel requests if there are more than enough for training
-                    if n > 5 and w0 >= self.max_latency_in_ms:
-                        # we're being very conservative and only canceling requests if they have already timed out
-                        self._queue.popleft()[2].cancel()
-                        continue
-                    if n < 2 and (2 * a + b) + w0 <= step_2_wait:
-                        await asyncio.sleep(self.tick_interval)
-                        continue
-                    if self._sema.is_locked():
-                        await asyncio.sleep(self.tick_interval)
-                        continue
-
-                    n_call_out = min(n, 2)
-                    # call
-                    self._sema.acquire()
-                    inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
-                    self._loop.create_task(self.outbound_call(inputs_info))
-                except asyncio.CancelledError:
-                    return
-                except Exception as e:  # pylint: disable=broad-except
-                    logger.error(traceback.format_exc(), exc_info=e)
-
-            logger.debug("Dispatcher finished optimizer training request 2.")
-            self.optimizer.trigger_refresh()
-
-        if self.max_batch_size >= 3:
-            # step 3: attempt to serve 3 requests
-
-            # we will attempt to keep the second request served within this time
-            step_3_wait = min(
-                self.max_latency_in_ms * 0.95,
-                7 * (self.optimizer.o_a + self.optimizer.o_b),
-            )
-            while (
-                self.optimizer.outbound_counter <= self.optimizer.N_SKIPPED_SAMPLE + 3
-            ):
-                try:
-                    async with self._wake_event:  # block until there's any request in queue
-                        await self._wake_event.wait_for(self._queue.__len__)
-
-                    n = len(self._queue)
-                    dt = self.tick_interval
-                    now = time.time()
-                    w0 = now - self._queue[0][0]
-                    a = self.optimizer.o_a
-                    b = self.optimizer.o_b
-
-                    # only cancel requests if there are more than enough for training
-                    if n > 3 and w0 >= self.max_latency_in_ms:
-                        # we're being very conservative and only canceling requests if they have already timed out
-                        self._queue.popleft()[2].cancel()
-                        continue
-                    if n < 3 and (3 * a + b) + w0 <= step_3_wait:
-                        await asyncio.sleep(self.tick_interval)
-                        continue
-
-                    n_call_out = min(n, 3)
-                    # call
-                    self._sema.acquire()
-                    inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
-                    self._loop.create_task(self.outbound_call(inputs_info))
-                except asyncio.CancelledError:
-                    return
-                except Exception as e:  # pylint: disable=broad-except
-                    logger.error(traceback.format_exc(), exc_info=e)
-
-            logger.debug("Dispatcher finished optimizer training request 3.")
-            self.optimizer.trigger_refresh()
+        except asyncio.CancelledError:
+            return
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error(traceback.format_exc(), exc_info=e)
+
+    async def controller(self):
+        """
+        A standalone coroutine to wait/dispatch calling.
+        """
+        logger.debug("Starting dispatcher optimizer training...")
+        # warm up the model
+        await self.train_optimizer(
+            self.optimizer.N_SKIPPED_SAMPLE, self.optimizer.N_SKIPPED_SAMPLE + 6, 1
+        )
+
+        logger.debug("Dispatcher finished warming up model.")
+
+        await self.train_optimizer(1, 6, 1)
+        self.optimizer.trigger_refresh()
+        logger.debug("Dispatcher finished optimizer training request 1.")
+
+        await self.train_optimizer(1, 5, 2)
+        self.optimizer.trigger_refresh()
+        logger.debug("Dispatcher finished optimizer training request 2.")
+
+        await self.train_optimizer(1, 3, 3)
+        self.optimizer.trigger_refresh()
+        logger.debug("Dispatcher finished optimizer training request 3.")

         if self.optimizer.o_a + self.optimizer.o_b >= self.max_latency_in_ms:
             logger.warning(
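The refactor keeps the dispatch policy and only parameterizes it: after warm-up, the controller trains at batch sizes 1, 2, and 3 via train_optimizer(1, 6, 1), train_optimizer(1, 5, 2), and train_optimizer(1, 3, 3). Below is a minimal standalone sketch of the keep-waiting test inside the training loop; a, b and the timings are assumed toy values, mirroring the check n < batch_size and (batch_size * a + b) + w0 <= wait:

# Hypothetical rendering of the batching decision; values assumed, not
# taken from the PR.
a, b = 0.010, 0.005
batch_size = 2
wait = (batch_size * 2 + 1) * (a + b)  # 0.075 s with these toy numbers

def keep_waiting(n: int, w0: float) -> bool:
    # hold the batch only while it is short and a full batch would still
    # finish within the wait cap, given the oldest request already waited w0
    return n < batch_size and (batch_size * a + b) + w0 <= wait

print(keep_waiting(1, 0.010))  # True: short batch, plenty of budget left
print(keep_waiting(1, 0.060))  # False: oldest request has waited too long
print(keep_waiting(2, 0.010))  # False: batch is full, dispatch immediately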