Add tornado #25

Open · wants to merge 1 commit into master
38 changes: 35 additions & 3 deletions README.md
@@ -167,8 +167,8 @@ class Subscription(graphene.ObjectType):


def resolve_count_seconds(
    root,
    info,
    up_to=5
):
return Observable.interval(1000)\
@@ -202,4 +202,36 @@ from graphql_ws.django_channels import GraphQLSubscriptionConsumer
channel_routing = [
route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"),
]
```

### Tornado
```python
from asyncio import Queue
from tornado import web, ioloop, websocket

from graphql_ws.tornado import TornadoSubscriptionServer


subscription_server = TornadoSubscriptionServer(schema)


class SubscriptionHandler(websocket.WebSocketHandler):
def initialize(self, sub_server):
self.subscription_server = subscription_server
Review comment:

I am trying the example locally, but this part was a bit confusing. We have `subscription_server` defined before this class declaration.

Here we set `self.subscription_server = subscription_server`, but we also receive a `sub_server` parameter that is never used.

The `open()` method then uses the previously created module-level `subscription_server`, not `self.subscription_server`. (One way to wire this up consistently is sketched after the code block below.)

self.queue = Queue()

def select_subprotocol(self, subprotocols):
return 'graphql-ws'

def open(self):
ioloop.IOLoop.current().spawn_callback(subscription_server.handle, self)

async def on_message(self, message):
await self.queue.put(message)

async def recv(self):
return await self.queue.get()

app = web.Application([(r"/subscriptions", SubscriptionHandler)]).listen(8000)
ioloop.IOLoop.current().start()
```
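
Following up on the review comment above: a minimal sketch, not part of this PR, of how the handler could consistently use the server injected through `initialize`. It assumes `schema` is a `graphene.Schema` defined as in the earlier examples.

```python
# Hypothetical, consistent variant of the README example above (not part of
# the PR); `schema` is assumed to be a graphene.Schema with a Subscription type.
from asyncio import Queue

from tornado import web, ioloop, websocket

from graphql_ws.tornado import TornadoSubscriptionServer


class SubscriptionHandler(websocket.WebSocketHandler):
    def initialize(self, sub_server):
        # Keep the instance handed in by the route instead of reaching for
        # the module-level variable.
        self.subscription_server = sub_server
        self.queue = Queue()

    def select_subprotocol(self, subprotocols):
        return 'graphql-ws'

    def open(self):
        ioloop.IOLoop.current().spawn_callback(
            self.subscription_server.handle, self)

    async def on_message(self, message):
        await self.queue.put(message)

    async def recv(self):
        return await self.queue.get()


subscription_server = TornadoSubscriptionServer(schema)

web.Application([
    (r"/subscriptions", SubscriptionHandler,
     dict(sub_server=subscription_server)),
]).listen(8000)
ioloop.IOLoop.current().start()
```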
Empty file added examples/tornado/__init__.py
47 changes: 47 additions & 0 deletions examples/tornado/app.py
@@ -0,0 +1,47 @@
from asyncio import Queue
Review comment (Author):

This is a bigger problem for py2 support than pathlib, and would need figuring out...

Review comment (Contributor):

I wouldn't worry about py2 support. We'll just call this a py3-only supported backend :)

from tornado import web, ioloop, websocket

from graphene_tornado.tornado_graphql_handler import TornadoGraphQLHandler

from graphql_ws.tornado import TornadoSubscriptionServer
from graphql_ws.constants import GRAPHQL_WS

from .template import render_graphiql
from .schema import schema


class GraphiQLHandler(web.RequestHandler):
def get(self):
self.finish(render_graphiql())


class SubscriptionHandler(websocket.WebSocketHandler):
def initialize(self, subscription_server):
self.subscription_server = subscription_server
self.queue = Queue(100)

def select_subprotocol(self, subprotocols):
return GRAPHQL_WS

def open(self):
ioloop.IOLoop.current().spawn_callback(self.subscription_server.handle, self)

async def on_message(self, message):
await self.queue.put(message)

async def recv(self):
return await self.queue.get()
Review comment:

In the project where we have been using this code for a while, we ran into issues with the blocking queue here, where some subscriptions would hang, causing memory leaks.

We added a recv_nowait. This way the code calling the queue had a chance to check whether the connection was interrupted, and then clean up the remaining asyncio tasks used for the subscription.
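
A minimal sketch, not part of this PR, of the non-blocking variant described in the comment above; `recv_nowait` is a hypothetical helper sitting next to the existing `recv` method.

```python
# Hypothetical addition to the SubscriptionHandler above (not part of the PR):
# lets the caller poll for a message, notice a closed connection, and cancel
# the subscription's remaining asyncio tasks instead of blocking on recv().
from asyncio import QueueEmpty


class SubscriptionHandler(websocket.WebSocketHandler):
    # ... initialize / select_subprotocol / open / on_message / recv as above ...

    def recv_nowait(self):
        # Return a queued message immediately, or None if nothing is pending.
        try:
            return self.queue.get_nowait()
        except QueueEmpty:
            return None
```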



subscription_server = TornadoSubscriptionServer(schema)

app = web.Application([
(r"/graphql$", TornadoGraphQLHandler, dict(
schema=schema)),
(r"/subscriptions", SubscriptionHandler, dict(
subscription_server=subscription_server)),
(r"/graphiql$", GraphiQLHandler),
])

app.listen(8000)
ioloop.IOLoop.current().start()
3 changes: 3 additions & 0 deletions examples/tornado/requirements.txt
@@ -0,0 +1,3 @@
graphql_ws
tornado
graphene>=2.0
34 changes: 34 additions & 0 deletions examples/tornado/schema.py
@@ -0,0 +1,34 @@
import random
import asyncio
import graphene


class Query(graphene.ObjectType):
base = graphene.String()


class RandomType(graphene.ObjectType):
seconds = graphene.Int()
random_int = graphene.Int()


class Subscription(graphene.ObjectType):
count_seconds = graphene.Float(up_to=graphene.Int())
random_int = graphene.Field(RandomType)

async def resolve_count_seconds(root, info, up_to=5):
for i in range(up_to):
print("YIELD SECOND", i)
yield i
await asyncio.sleep(1.)
yield up_to

async def resolve_random_int(root, info):
i = 0
while True:
yield RandomType(seconds=i, random_int=random.randint(0, 500))
await asyncio.sleep(1.)
i += 1


schema = graphene.Schema(query=Query, subscription=Subscription)
125 changes: 125 additions & 0 deletions examples/tornado/template.py
@@ -0,0 +1,125 @@

from string import Template


def render_graphiql():
return Template('''
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>GraphiQL</title>
<meta name="robots" content="noindex" />
<style>
html, body {
height: 100%;
margin: 0;
overflow: hidden;
width: 100%;
}
</style>
<link href="//cdn.jsdelivr.net/graphiql/${GRAPHIQL_VERSION}/graphiql.css" rel="stylesheet" />
<script src="//cdn.jsdelivr.net/fetch/0.9.0/fetch.min.js"></script>
<script src="//cdn.jsdelivr.net/react/15.0.0/react.min.js"></script>
<script src="//cdn.jsdelivr.net/react/15.0.0/react-dom.min.js"></script>
<script src="//cdn.jsdelivr.net/graphiql/${GRAPHIQL_VERSION}/graphiql.min.js"></script>
<script src="//unpkg.com/subscriptions-transport-ws@${SUBSCRIPTIONS_TRANSPORT_VERSION}/browser/client.js"></script>
<script src="//unpkg.com/[email protected]/browser/client.js"></script>
Review comment:

We updated our template recently, using the latest version of graphiql.js, and had issues with the graphiql-subscriptions-fetcher.

Unfortunately the repo for this dependency is archived and hasn't been updated in a while, which causes subscriptions to fail when you update it. So for this PR I think we should leave the current versions, even if old.

Ref: https://github.com/cylc/cylc-ui/pull/457/files#diff-1696770be90f4f901322de74987fdf90R23

</head>
<body>
<script>
// Collect the URL parameters
var parameters = {};
window.location.search.substr(1).split('&').forEach(function (entry) {
var eq = entry.indexOf('=');
if (eq >= 0) {
parameters[decodeURIComponent(entry.slice(0, eq))] =
decodeURIComponent(entry.slice(eq + 1));
}
});
// Produce a Location query string from a parameter object.
function locationQuery(params, location) {
return (location ? location: '') + '?' + Object.keys(params).map(function (key) {
return encodeURIComponent(key) + '=' +
encodeURIComponent(params[key]);
}).join('&');
}
// Derive a fetch URL from the current URL, sans the GraphQL parameters.
var graphqlParamNames = {
query: true,
variables: true,
operationName: true
};
var otherParams = {};
for (var k in parameters) {
if (parameters.hasOwnProperty(k) && graphqlParamNames[k] !== true) {
otherParams[k] = parameters[k];
}
}
var fetcher;
if (true) {
var subscriptionsClient = new window.SubscriptionsTransportWs.SubscriptionClient('${subscriptionsEndpoint}', {
reconnect: true
});
fetcher = window.GraphiQLSubscriptionsFetcher.graphQLFetcher(subscriptionsClient, graphQLFetcher);
} else {
fetcher = graphQLFetcher;
}
// We don't use safe-serialize for location, because it's not client input.
var fetchURL = locationQuery(otherParams, '${endpointURL}');
// Defines a GraphQL fetcher using the fetch API.
function graphQLFetcher(graphQLParams) {
return fetch(fetchURL, {
method: 'post',
headers: {
'Accept': 'application/json',
'Content-Type': 'application/json',
},
body: JSON.stringify(graphQLParams),
credentials: 'include',
}).then(function (response) {
return response.text();
}).then(function (responseBody) {
try {
return JSON.parse(responseBody);
} catch (error) {
return responseBody;
}
});
}
// When the query and variables string is edited, update the URL bar so
// that it can be easily shared.
function onEditQuery(newQuery) {
parameters.query = newQuery;
updateURL();
}
function onEditVariables(newVariables) {
parameters.variables = newVariables;
updateURL();
}
function onEditOperationName(newOperationName) {
parameters.operationName = newOperationName;
updateURL();
}
function updateURL() {
history.replaceState(null, null, locationQuery(parameters) + window.location.hash);
}
// Render <GraphiQL /> into the body.
ReactDOM.render(
React.createElement(GraphiQL, {
fetcher: fetcher,
onEditQuery: onEditQuery,
onEditVariables: onEditVariables,
onEditOperationName: onEditOperationName,
}),
document.body
);
</script>
</body>
</html>''').substitute(
GRAPHIQL_VERSION='0.10.2',
SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0',
subscriptionsEndpoint='ws://localhost:8000/subscriptions',
# subscriptionsEndpoint='ws://localhost:5000/',
endpointURL='/graphql',
)
114 changes: 114 additions & 0 deletions graphql_ws/tornado.py
@@ -0,0 +1,114 @@
from inspect import isawaitable

from asyncio import ensure_future, wait, shield
Review comment (Author):

And all this...

Review comment:

I think Tornado also offers something similar to ensure_future (not sure about wait and shield, though).
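
For reference, a minimal sketch, not part of this PR: on Tornado 5+ the IOLoop runs on top of asyncio, so `ensure_future`, `wait`, and `shield` work as-is, while `tornado.gen.convert_yielded` is the closest Tornado-native analogue of `ensure_future`.

```python
# Hypothetical comparison, not part of the PR: scheduling a coroutine with
# tornado.gen.convert_yielded, roughly equivalent to asyncio.ensure_future.
from tornado import gen, ioloop


async def slow_add(a, b):
    await gen.sleep(0.1)
    return a + b


async def main():
    # convert_yielded wraps the coroutine in a Future running on the IOLoop,
    # much like asyncio.ensure_future would create a Task.
    future = gen.convert_yielded(slow_add(1, 2))
    print(await future)


ioloop.IOLoop.current().run_sync(main)
```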

from tornado.websocket import WebSocketClosedError
from graphql.execution.executors.asyncio import AsyncioExecutor

from .base import ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer
from .observable_aiter import setup_observable_extension

from .constants import (
GQL_CONNECTION_ACK,
GQL_CONNECTION_ERROR,
GQL_COMPLETE
)

setup_observable_extension()


class TornadoConnectionContext(BaseConnectionContext):
async def receive(self):
try:
msg = await self.ws.recv()
return msg
except WebSocketClosedError:
raise ConnectionClosedException()

async def send(self, data):
if self.closed:
return
await self.ws.write_message(data)

@property
def closed(self):
return self.ws.close_code is not None

async def close(self, code):
await self.ws.close(code)


class TornadoSubscriptionServer(BaseSubscriptionServer):
def __init__(self, schema, keep_alive=True, loop=None):
self.loop = loop
super().__init__(schema, keep_alive)

def get_graphql_params(self, *args, **kwargs):
params = super(TornadoSubscriptionServer,
self).get_graphql_params(*args, **kwargs)
return dict(params, return_promise=True, executor=AsyncioExecutor(loop=self.loop))

async def _handle(self, ws, request_context):
connection_context = TornadoConnectionContext(ws, request_context)
await self.on_open(connection_context)
pending = set()
while True:
try:
if connection_context.closed:
raise ConnectionClosedException()
message = await connection_context.receive()
except ConnectionClosedException:
break
finally:
if pending:
(_, pending) = await wait(pending, timeout=0, loop=self.loop)

task = ensure_future(
self.on_message(connection_context, message), loop=self.loop)
pending.add(task)

self.on_close(connection_context)
for task in pending:
task.cancel()

async def handle(self, ws, request_context=None):
await shield(self._handle(ws, request_context), loop=self.loop)

async def on_open(self, connection_context):
pass

def on_close(self, connection_context):
remove_operations = list(connection_context.operations.keys())
for op_id in remove_operations:
self.unsubscribe(connection_context, op_id)

async def on_connect(self, connection_context, payload):
pass

async def on_connection_init(self, connection_context, op_id, payload):
try:
await self.on_connect(connection_context, payload)
await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK)
except Exception as e:
await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR)
await connection_context.close(1011)

async def on_start(self, connection_context, op_id, params):
execution_result = self.execute(
connection_context.request_context, params)

if isawaitable(execution_result):
execution_result = await execution_result

if not hasattr(execution_result, '__aiter__'):
await self.send_execution_result(connection_context, op_id, execution_result)
else:
iterator = await execution_result.__aiter__()
connection_context.register_operation(op_id, iterator)
async for single_result in iterator:
if not connection_context.has_operation(op_id):
break
await self.send_execution_result(connection_context, op_id, single_result)
await self.send_message(connection_context, op_id, GQL_COMPLETE)

async def on_stop(self, connection_context, op_id):
self.unsubscribe(connection_context, op_id)