fix: loadbalance stream based on response #6122

Merged: 11 commits, Dec 6, 2023
73 changes: 31 additions & 42 deletions jina/serve/runtimes/gateway/request_handling.py
@@ -157,48 +157,37 @@
try:
async with aiohttp.ClientSession() as session:

if request.method == 'GET':
request_kwargs = {}
try:
payload = await request.json()
if payload:
request_kwargs['json'] = payload
except Exception:
self.logger.debug('No JSON payload found in request')
Comment on lines -164 to -167
Looking at the original implementation, I have an idea.
The original logic only writes a debug log, which is not useful in a production application. Can we act as a pure proxy here, for performance reasons? Something like:

async with session.request(request.method, data=request.iter_any(), **request_kwargs) as response:
    ....

@NarekA @JoanFM

Contributor Author

For some reason, if I try to pass the content in any way other than the json field, I get an error here. I've tried everything at this point; if you can get this to work, I am interested.

What error do you see?


async with session.get(
url=target_url, **request_kwargs
) as response:
# Create a StreamResponse with the same headers and status as the target response
stream_response = web.StreamResponse(
status=response.status,
headers=response.headers,
)

# Prepare the response to send headers
await stream_response.prepare(request)

# Stream the response from the target server to the client
async for chunk in response.content.iter_any():
await stream_response.write(chunk)

# Close the stream response once all chunks are sent
await stream_response.write_eof()
return stream_response

elif request.method == 'POST':
d = await request.read()
import json

async with session.post(
url=target_url, json=json.loads(d.decode())
) as response:
content = await response.read()
return web.Response(
body=content,
status=response.status,
content_type=response.content_type,
)
request_kwargs = {}
try:
payload = await request.json()
if payload:
request_kwargs['json'] = payload
except Exception:
self.logger.debug('No JSON payload found in request')

Codecov / codecov/patch warning on line 166 in jina/serve/runtimes/gateway/request_handling.py: added lines #L165 - L166 were not covered by tests.
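
The two uncovered lines appear to be the `except` branch of the payload handling above, which is taken whenever the incoming request carries no parseable JSON body (a plain GET, for example). Below is a minimal stand-alone sketch of that logic, useful for seeing which inputs reach that branch. It assumes the body is available as raw bytes; `build_request_kwargs` is a hypothetical helper, not part of Jina or aiohttp, and it catches narrower exceptions than the `except Exception` in the diff.

```python
import json
from typing import Any, Dict


def build_request_kwargs(raw_body: bytes) -> Dict[str, Any]:
    """Mirror the diff's logic: forward a JSON payload only if one is present."""
    request_kwargs: Dict[str, Any] = {}
    try:
        payload = json.loads(raw_body.decode())
        if payload:
            # Only a non-empty payload is forwarded under the 'json' key.
            request_kwargs['json'] = payload
    except (UnicodeDecodeError, json.JSONDecodeError):
        # Equivalent of the debug-log branch: no usable JSON body was found,
        # so the upstream request is sent without a payload.
        pass
    return request_kwargs
```

Under these assumptions, an empty body, a non-JSON body, and an empty JSON object all leave `request_kwargs` empty; only a non-empty JSON document produces a `json` entry.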

async with session.request(
request.method,
url=target_url,
auto_decompress=False,
**request_kwargs,
) as response:
@NarekA why not use the session's request method directly? It already wraps the context manager for you.

    def request(
        self, method: str, url: StrOrURL, **kwargs: Any
    ) -> "_RequestContextManager":
        """Perform HTTP request."""
        return _RequestContextManager(self._request(method, url, **kwargs))

Contributor Author

Somehow I missed this, will fix.

# Create a StreamResponse with the same headers and status as the target response
stream_response = web.StreamResponse(
status=response.status,
headers=response.headers,
)

# Prepare the response to send headers
await stream_response.prepare(request)

# Stream the response from the target server to the client
async for chunk in response.content.iter_any():
await stream_response.write(chunk)

# Close the stream response once all chunks are sent
await stream_response.write_eof()
return stream_response

except aiohttp.ClientError as e:
return web.Response(text=f'Error: {str(e)}', status=500)
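
The streaming path in this diff forwards each upstream chunk as soon as it arrives, where the removed POST branch buffered the whole body with `response.read()`. Here is a minimal sketch of that relay pattern using only `asyncio`. `upstream_chunks`, `RecordingSink` and `relay` are hypothetical stand-ins for `response.content.iter_any()` and `web.StreamResponse`; they are not aiohttp APIs and the chunk contents are made up for illustration.

```python
import asyncio
from typing import AsyncIterator, List


async def upstream_chunks() -> AsyncIterator[bytes]:
    """Hypothetical upstream body, yielded piece by piece."""
    for chunk in (b'event: update\n', b'data: {"text": "a"}\n\n', b'event: end\n\n'):
        await asyncio.sleep(0)  # hand control back to the event loop between chunks
        yield chunk


class RecordingSink:
    """Stand-in for a streaming response: records every write it receives."""

    def __init__(self) -> None:
        self.writes: List[bytes] = []
        self.closed = False

    async def write(self, chunk: bytes) -> None:
        self.writes.append(chunk)

    async def write_eof(self) -> None:
        self.closed = True


async def relay(source: AsyncIterator[bytes], sink: RecordingSink) -> int:
    """Forward chunks one at a time, then close the sink; return bytes forwarded."""
    forwarded = 0
    async for chunk in source:
        await sink.write(chunk)
        forwarded += len(chunk)
    await sink.write_eof()
    return forwarded


def demo() -> RecordingSink:
    sink = RecordingSink()
    asyncio.run(relay(upstream_chunks(), sink))
    return sink
```

Calling `demo()` returns a sink that received one write per upstream chunk, which is the property that lets a client start consuming events before the upstream response has finished.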

6 changes: 2 additions & 4 deletions tests/integration/docarray_v2/test_issues.py
@@ -178,19 +178,17 @@ async def test_issue_6090_get_params(streaming_deployment):

     docs = []
     url = (
-        f"htto://localhost:{streaming_deployment.port}/stream-simple?text=my_input_text"
+        f"http://localhost:{streaming_deployment.port}/stream-simple?text=my_input_text"
     )
     async with aiohttp.ClientSession() as session:

         async with session.get(url) as resp:
             async for chunk in resp.content.iter_any():
-                print(chunk)
                 events = chunk.split(b'event: ')[1:]
                 for event in events:
                     if event.startswith(b'update'):
-                        parsed = event[HTTPClientlet.UPDATE_EVENT_PREFIX:].decode()
+                        parsed = event[HTTPClientlet.UPDATE_EVENT_PREFIX :].decode()
                         parsed = SimpleInput.parse_raw(parsed)
-                        print(parsed)
                         docs.append(parsed)
                     elif event.startswith(b'end'):
                         pass
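
The test above splits each chunk on `b'event: '` and decodes the payload of every `update` event. The same parsing step can be sketched with the standard library alone. This sketch assumes an update event has the form `update\ndata: <json>\n\n` once the `event: ` marker is split off; `UPDATE_PREFIX` and `parse_update_events` are hypothetical names, and the real offset used by Jina is `HTTPClientlet.UPDATE_EVENT_PREFIX`, whose value is not shown here.

```python
import json
from typing import Any, Dict, List

# Assumed layout of an update event after splitting on b'event: '.
# This constant is a stand-in for HTTPClientlet.UPDATE_EVENT_PREFIX.
UPDATE_PREFIX = b'update\ndata: '


def parse_update_events(chunk: bytes) -> List[Dict[str, Any]]:
    """Return the decoded JSON payload of every update event in one chunk."""
    docs: List[Dict[str, Any]] = []
    # The [1:] slice drops whatever precedes the first b'event: ' marker.
    for event in chunk.split(b'event: ')[1:]:
        if event.startswith(UPDATE_PREFIX):
            # json.loads tolerates the trailing blank lines that end an event.
            docs.append(json.loads(event[len(UPDATE_PREFIX):].decode()))
        # Other events, such as b'end', carry no document and are skipped.
    return docs
```

Like the test, this sketch assumes every chunk contains only whole events. A chunk boundary that falls in the middle of an event would need buffering across chunks, which is not shown.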