From a61379a0561a25787150aedcf81ad5e24c3707e1 Mon Sep 17 00:00:00 2001 From: Sauyon Lee <2347889+sauyon@users.noreply.github.com> Date: Wed, 15 Feb 2023 11:24:32 -0800 Subject: [PATCH] fix: update formparser for new starlette (#3569) Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> --- src/bentoml/_internal/utils/formparser.py | 348 ++++++++++++++-------- 1 file changed, 229 insertions(+), 119 deletions(-) diff --git a/src/bentoml/_internal/utils/formparser.py b/src/bentoml/_internal/utils/formparser.py index 449ce06d07d..756e9e65496 100644 --- a/src/bentoml/_internal/utils/formparser.py +++ b/src/bentoml/_internal/utils/formparser.py @@ -3,92 +3,247 @@ import io import uuid import typing as t -from typing import TYPE_CHECKING +from enum import Enum from tempfile import SpooledTemporaryFile +from dataclasses import field +from dataclasses import dataclass +from urllib.parse import unquote_plus import multipart.multipart as multipart from starlette.requests import Request from starlette.responses import Response -from starlette.formparsers import MultiPartMessage from starlette.datastructures import Headers from starlette.datastructures import FormData from starlette.datastructures import UploadFile from starlette.datastructures import MutableHeaders from .http import set_cookies -from ...exceptions import BadInput from ...exceptions import BentoMLException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..context import InferenceApiContext as Context -_ItemsBody = t.List[ - t.Tuple[str, t.List[t.Tuple[bytes, bytes]], t.Union[bytes, UploadFile]] -] +# Code below adapted from starlette's formparser. See the license: https://github.com/encode/starlette/blob/fc480890fe1f1e421746de303c6f8da1323e5626/LICENSE.md -def user_safe_decode(src: bytes, codec: str) -> str: +class FormMessage(Enum): + FIELD_START = 1 + FIELD_NAME = 2 + FIELD_DATA = 3 + FIELD_END = 4 + END = 5 + + +@dataclass +class MultipartPart: + content_disposition: bytes | None = None + field_name: str = "" + data: bytes = b"" + file: UploadFile | None = None + item_headers: list[tuple[bytes, bytes]] = field(default_factory=list) + + +def _user_safe_decode(src: bytes, codec: str) -> str: try: return src.decode(codec) except (UnicodeDecodeError, LookupError): return src.decode("latin-1") -MAX_FILE_SIZE = 1024 * 1024 +class MultiPartException(Exception): + def __init__(self, message: str) -> None: + self.message = message -class MultiPartParser: - """ - An modified version of starlette MultiPartParser. - """ - +class FormParser: def __init__(self, headers: Headers, stream: t.AsyncGenerator[bytes, None]) -> None: assert ( multipart is not None ), "The `python-multipart` library must be installed to use form parsing." - self.headers: Headers = headers + self.headers = headers self.stream = stream - self.messages: t.List[t.Tuple[MultiPartMessage, bytes]] = list() + self.messages: list[tuple[FormMessage, bytes]] = [] - def on_part_begin(self) -> None: - message = (MultiPartMessage.PART_BEGIN, b"") + def on_field_start(self) -> None: + message = (FormMessage.FIELD_START, b"") self.messages.append(message) - def on_part_data(self, data: bytes, start: int, end: int) -> None: - message = (MultiPartMessage.PART_DATA, data[start:end]) + def on_field_name(self, data: bytes, start: int, end: int) -> None: + message = (FormMessage.FIELD_NAME, data[start:end]) self.messages.append(message) - def on_part_end(self) -> None: - message = (MultiPartMessage.PART_END, b"") + def on_field_data(self, data: bytes, start: int, end: int) -> None: + message = (FormMessage.FIELD_DATA, data[start:end]) self.messages.append(message) - def on_header_field(self, data: bytes, start: int, end: int) -> None: - message = (MultiPartMessage.HEADER_FIELD, data[start:end]) + def on_field_end(self) -> None: + message = (FormMessage.FIELD_END, b"") self.messages.append(message) - def on_header_value(self, data: bytes, start: int, end: int) -> None: - message = (MultiPartMessage.HEADER_VALUE, data[start:end]) + def on_end(self) -> None: + message = (FormMessage.END, b"") self.messages.append(message) + async def parse(self) -> FormData: + # Callbacks dictionary. + callbacks = { + "on_field_start": self.on_field_start, + "on_field_name": self.on_field_name, + "on_field_data": self.on_field_data, + "on_field_end": self.on_field_end, + "on_end": self.on_end, + } + + # Create the parser. + parser = multipart.QuerystringParser(callbacks) + field_name = b"" + field_value = b"" + + items: list[tuple[str, str | UploadFile]] = [] + + # Feed the parser with data from the request. + async for chunk in self.stream: + if chunk: + parser.write(chunk) + else: + parser.finalize() + messages = list(self.messages) + self.messages.clear() + for message_type, message_bytes in messages: + if message_type == FormMessage.FIELD_START: + field_name = b"" + field_value = b"" + elif message_type == FormMessage.FIELD_NAME: + field_name += message_bytes + elif message_type == FormMessage.FIELD_DATA: + field_value += message_bytes + elif message_type == FormMessage.FIELD_END: + name = unquote_plus(field_name.decode("latin-1")) + value = unquote_plus(field_value.decode("latin-1")) + items.append((name, value)) + + return FormData(items) + + +class MultiPartParser: + max_file_size = 1024 * 1024 + + def __init__( + self, + headers: Headers, + stream: t.AsyncGenerator[bytes, None], + *, + max_files: int | float = 1000, + max_fields: int | float = 1000, + ) -> None: + assert ( + multipart is not None + ), "The `python-multipart` library must be installed to use form parsing." + self.headers = headers + self.stream = stream + self.max_files = max_files + self.max_fields = max_fields + self.items: list[tuple[str, str | UploadFile]] = [] + self._current_files = 0 + self._current_fields = 0 + self._current_partial_header_name: bytes = b"" + self._current_partial_header_value: bytes = b"" + self._current_part = MultipartPart() + self._charset: str = "" + self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = [] + self._file_parts_to_finish: list[MultipartPart] = [] + self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = [] + + def on_part_begin(self) -> None: + self._current_part = MultipartPart() + + def on_part_data(self, data: bytes, start: int, end: int) -> None: + message_bytes = data[start:end] + if self._current_part.file is None: + self._current_part.data += message_bytes + else: + self._file_parts_to_write.append((self._current_part, message_bytes)) + + def on_part_end(self) -> None: + if self._current_part.file is None: + self.items.append( + ( + self._current_part.field_name, + _user_safe_decode(self._current_part.data, self._charset), + ) + ) + else: + self._file_parts_to_finish.append(self._current_part) + # The file can be added to the items right now even though it's not + # finished yet, because it will be finished in the `parse()` method, before + # self.items is used in the return value. + self.items.append((self._current_part.field_name, self._current_part.file)) + + def on_header_field(self, data: bytes, start: int, end: int) -> None: + self._current_partial_header_name += data[start:end] + + def on_header_value(self, data: bytes, start: int, end: int) -> None: + self._current_partial_header_value += data[start:end] + def on_header_end(self) -> None: - message = (MultiPartMessage.HEADER_END, b"") - self.messages.append(message) + field = self._current_partial_header_name.lower() + if field == b"content-disposition": + self._current_part.content_disposition = self._current_partial_header_value + self._current_part.item_headers.append( + (field, self._current_partial_header_value) + ) + self._current_partial_header_name = b"" + self._current_partial_header_value = b"" def on_headers_finished(self) -> None: - message = (MultiPartMessage.HEADERS_FINISHED, b"") - self.messages.append(message) + assert self._current_part.content_disposition is not None + _, options = multipart.parse_options_header( + self._current_part.content_disposition + ) + try: + self._current_part.field_name = _user_safe_decode( + bytes(options[b"name"]), self._charset + ) + except KeyError: + raise MultiPartException( + 'The Content-Disposition header field "name" must be ' "provided." + ) + if b"filename" in options: + self._current_files += 1 + if self._current_files > self.max_files: + raise MultiPartException( + f"Too many files. Maximum number of files is {self.max_files}." + ) + filename = _user_safe_decode(bytes(options[b"filename"]), self._charset) + tempfile = SpooledTemporaryFile(max_size=self.max_file_size) + self._files_to_close_on_error.append(tempfile) + self._current_part.file = UploadFile( + file=tempfile, + filename=filename, + headers=Headers(raw=self._current_part.item_headers), + ) + else: + self._current_fields += 1 + if self._current_fields > self.max_fields: + raise MultiPartException( + f"Too many fields. Maximum number of fields is {self.max_fields}." + ) + self._current_part.file = None def on_end(self) -> None: - message = (MultiPartMessage.END, b"") - self.messages.append(message) + pass - async def parse(self) -> _ItemsBody: + async def parse(self) -> FormData: # Parse the Content-Type header to get the multipart boundary. _, params = multipart.parse_options_header(self.headers["Content-Type"]) - params = t.cast(t.Dict[bytes, bytes], params) - charset = params.get(b"charset", b"utf-8") - charset = charset.decode("latin-1") - boundary = params.get(b"boundary") + charset = params.get(b"charset", "utf-8") + if isinstance(charset, bytes): + charset = charset.decode("latin-1") + self._charset = str(charset) + try: + boundary = params[b"boundary"] + except KeyError: + raise MultiPartException("Missing boundary in multipart.") # Callbacks dictionary. callbacks = { @@ -104,80 +259,31 @@ async def parse(self) -> _ItemsBody: # Create the parser. parser = multipart.MultipartParser(boundary, callbacks) - header_field = b"" - header_value = b"" - field_name = "" - content_disposition = None - multipart_file = None - - data = b"" - - items: _ItemsBody = [] - headers: t.List[t.Tuple[bytes, bytes]] = [] - - # Feed the parser with data from the request. - async for chunk in self.stream: - parser.write(chunk) - messages = list(self.messages) - self.messages.clear() - for message_type, message_bytes in messages: - if message_type == MultiPartMessage.PART_BEGIN: - content_disposition = None - field_name = "" - data = b"" - headers = list() - elif message_type == MultiPartMessage.HEADER_FIELD: # type: ignore - header_field += message_bytes - elif message_type == MultiPartMessage.HEADER_VALUE: # type: ignore - header_value += message_bytes - elif message_type == MultiPartMessage.HEADER_END: # type: ignore - field = header_field.lower() - if field == b"content-disposition": - content_disposition = header_value - elif field == b"bentoml-payload-field": - field_name = user_safe_decode(header_value, charset) - else: - headers.append((field, header_value)) - header_field = b"" - header_value = b"" - elif message_type == MultiPartMessage.HEADERS_FINISHED: # type: ignore - if content_disposition is None: - raise BadInput( - 'The Content-Disposition header field "name" must be provided.' - ) - _, options = multipart.parse_options_header(content_disposition) - options = t.cast(t.Dict[bytes, bytes], options) - try: - field_name = user_safe_decode(options[b"name"], charset) - except KeyError: - raise BadInput( - 'The Content-Disposition header field "name" must be provided.' - ) - if b"filename" in options: - filename = user_safe_decode(options[b"filename"], charset) - tempfile = SpooledTemporaryFile(max_size=MAX_FILE_SIZE) - multipart_file = UploadFile( - file=tempfile, - # size=0, # TODO: support size for starlette 0.24 onwards - filename=filename, - headers=Headers(raw=headers), # type: ignore (incomplete starlette types) - ) - else: - multipart_file = None - elif message_type == MultiPartMessage.PART_DATA: # type: ignore - if multipart_file is None: - data += message_bytes - else: - await multipart_file.write(message_bytes) - elif message_type == MultiPartMessage.PART_END: # type: ignore - if multipart_file is None: - items.append((field_name, headers, data)) - else: - await multipart_file.seek(0) - items.append((field_name, headers, multipart_file)) + try: + # Feed the parser with data from the request. + async for chunk in self.stream: + parser.write(chunk) + # Write file data, it needs to use await with the UploadFile methods + # that call the corresponding file methods *in a threadpool*, + # otherwise, if they were called directly in the callback methods above + # (regular, non-async functions), that would block the event loop in + # the main thread. + for part, data in self._file_parts_to_write: + assert part.file # for type checkers + await part.file.write(data) + for part in self._file_parts_to_finish: + assert part.file # for type checkers + await part.file.seek(0) + self._file_parts_to_write.clear() + self._file_parts_to_finish.clear() + except MultiPartException as exc: + # Close all the files if there was an error. + for file in self._files_to_close_on_error: + file.close() + raise exc parser.finalize() - return items + return FormData(self.items) def file_body_to_message(f: UploadFile): @@ -202,18 +308,22 @@ async def populate_multipart_requests(request: Request) -> t.Dict[str, Request]: raise BentoMLException("Invalid multipart requests") reqs: dict[str, Request] = dict() - for field_name, headers, data in form: + for field_name, data in form.items(): scope = dict(request.scope) - ori_headers = dict(scope.get("headers", dict())) - ori_headers = t.cast(t.Dict[bytes, bytes], ori_headers) - ori_headers.update(dict(headers)) - scope["headers"] = list(ori_headers.items()) + if isinstance(data, UploadFile): + ori_headers = dict(scope.get("headers", dict())) + ori_headers = t.cast(t.Dict[bytes, bytes], ori_headers) + ori_headers.update(dict(data.headers)) + scope["headers"] = list(ori_headers.items()) + if "headers" not in scope: + scope["headers"] = [] + req = Request(scope) - req._form = FormData([(field_name, data)]) # type: ignore (using internal starlette APIs) - if isinstance(data, bytes): - req._body = data - else: - req._receive = ( # type: ignore (using internal starlette APIs) + req._form = FormData([(field_name, data)]) # type: ignore # using internal starlette APIs + if isinstance(data, str): + req._body = bytes(data, "utf-8") + elif isinstance(data, UploadFile): + req._receive = ( # type: ignore # using internal starlette APIs file_body_to_message(data) ) reqs[field_name] = req