Skip to content

Commit

Permalink
cli/conn/preview: Migrate multiprocess->multithreaded
Browse files Browse the repository at this point in the history
Requires recent cyanodbc patches that release the GIL.
  • Loading branch information
detule committed Sep 20, 2020
1 parent 6815728 commit bea2288
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 297 deletions.
3 changes: 0 additions & 3 deletions odbcli/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from multiprocessing import set_start_method
from .cli import main
#main()

if __name__ == "__main__":
set_start_method('spawn')
main()
53 changes: 25 additions & 28 deletions odbcli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from prompt_toolkit.utils import get_cwidth
from .app import sqlApp, ExitEX
from .layout import sqlAppLayout
from .conn import connStatus
from .executor import cmsg, commandStatus
from .conn import connStatus, executionStatus


def main():
Expand All @@ -31,42 +30,40 @@ def main():
# If it's a preview query we need an indication
# of where to run the query
if(app_res[0] == "preview"):
sqlConn = my_app.selected_object.conn
sql_conn = my_app.selected_object.conn
else:
sqlConn = my_app.active_conn
if sqlConn is not None:
sql_conn = my_app.active_conn
if sql_conn is not None:
#TODO also check that it is connected
try:
secho("Executing query...Ctrl-c to cancel", err = False)
start = time()
res = sqlConn.async_execute(app_res[1])
crsr = sql_conn.async_execute(app_res[1])
execution = time() - start
sqlConn.status = connStatus.IDLE
secho("Query execution...done", err = False)
if(app_res[0] == "preview"):
continue
if my_app.timing_enabled:
print("Time: %0.03fs" % execution)
if res.status == commandStatus.OKWRESULTS:
ht = my_app.application.output.get_size()[0]
formatted = sqlConn.formatted_fetch(ht - 3 - my_app.pager_reserve_lines, my_app.table_format)
sqlConn.status = connStatus.FETCHING
echo_via_pager(formatted)
elif res.status == commandStatus.OK:
secho("No rows returned\n", err = False)

if sql_conn.execution_status == executionStatus.FAIL:
err = sql_conn.execution_err
secho("Query error: %s\n" % err, err = True, fg = "red")
else:
secho("Query error: %s\n" % res.payload, err = True, fg = "red")
except BrokenPipeError:
my_app.logger.debug('BrokenPipeError caught. Recovering...', file = stderr)
if crsr.description:
cols = [col.name for col in crsr.description]
else:
cols = []
if len(cols):
ht = my_app.application.output.get_size()[0]
formatted = sql_conn.formatted_fetch(ht - 3 - my_app.pager_reserve_lines, cols, my_app.table_format)
sql_conn.status = connStatus.FETCHING
echo_via_pager(formatted)
else:
secho("No rows returned\n", err = False)
except KeyboardInterrupt:
secho("Cancelling query...", err = True, fg = 'red')
sqlConn.executor.terminate()
sqlConn.executor.join()
secho("Query cancelled.", err = True, fg='red')
#TODO: catch ConnectError
sqlConn.connect(start_executor = True)
sqlConn.status = connStatus.IDLE
# TODO check status of return
sqlConn.async_fetchdone()
# sqlConn.parent_chan.send(cmsg("fetchdone", None, None))
# sqlConn.parent_chan.recv()
secho("Cancelling query...", err = True, fg = "red")
sql_conn.cancel()
secho("Query cancelled.", err = True, fg = "red")
sql_conn.status = connStatus.IDLE
sql_conn.close_cursor()
197 changes: 93 additions & 104 deletions odbcli/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from cyanodbc import connect, Connection, SQLGetInfo, Cursor, DatabaseError, ConnectError
from typing import Optional
from cli_helpers.tabular_output import TabularOutputFormatter
from multiprocessing import Process, Pipe
from logging import getLogger
from re import sub
from threading import Lock
from .executor import executor_process, cmsg, commandStatus
from threading import Lock, Event, Thread
from enum import IntEnum

formatter = TabularOutputFormatter()

Expand All @@ -17,6 +16,11 @@ class connStatus(Enum):
FETCHING = 3
ERROR = 4

class executionStatus(IntEnum):
    """Outcome of the most recent statement run through ``execute``.

    The value is stored on the connection as ``_execution_status`` and read
    back by callers through the ``execution_status`` property.
    """
    # Recorded after ``cursor.execute`` returns without raising.
    OK = 0
    # Recorded when ``cursor.execute`` raises ``DatabaseError``.
    FAIL = 1
    # NOTE(review): never assigned in the code visible here; presumably means
    # "succeeded and produced a result set" -- confirm before relying on it.
    OKWRESULTS = 2

class sqlConnection:
def __init__(
self,
Expand All @@ -32,8 +36,6 @@ def __init__(
self.username = username
self.password = password
self.status = connStatus.DISCONNECTED
self.executor: Process = None
self.parent_chan, self.child_chan = Pipe()
self.logger = getLogger(__name__)
self._quotechar = None
self._search_escapechar = None
Expand All @@ -48,6 +50,26 @@ def __init__(
# multiple auto-completion result queries before each has had a chance
# to return.
self._lock = Lock()
self._fetch_res: list = None
self._execution_status: executionStatus = executionStatus.OK
self._execution_err: str = None

@property
def execution_status(self) -> executionStatus:
    """Return the status recorded by the most recent ``execute`` call.

    ``execute`` may assign the status from a different thread, so the
    value is read while holding the connection lock.
    """
    with self._lock:
        return self._execution_status

@property
def execution_err(self) -> str:
    """Return the error text of the last execution, if any.

    The value is reset to ``None`` at the start of every ``execute`` call.
    ``execute`` may assign it from a different thread, so it is read while
    holding the connection lock.
    """
    with self._lock:
        return self._execution_err

@property
def quotechar(self) -> str:
Expand Down Expand Up @@ -85,8 +107,7 @@ def connect(
self,
username: str = "",
password: str = "",
force: bool = False,
start_executor: bool = False) -> None:
force: bool = False) -> None:
uid = username or self.username
pwd = password or self.password
conn_str = "DSN=" + self.dsn + ";"
Expand All @@ -103,97 +124,65 @@ def connect(
except ConnectError as e:
self.logger.error("Error while connecting: %s", str(e))
raise ConnectError(e)
if start_executor:
self.executor = Process(
target = executor_process,
args=(self.child_chan, self.logger.getEffectiveLevel(),))
self.executor.start()
self.logger.info("Started executor process: %d", self.executor.pid)
self.parent_chan.send(cmsg("connect", conn_str, None))
resp = self.parent_chan.recv()
# How do you handle failure here?
if not resp.status == commandStatus.OK:
self.logger.error("Error atempting to connect in executor process")
self.executor.terminate()
self.executor.join()
raise ConnectError("Connection failure in executor")

def async_lastresponse(self) -> cmsg:
if self.executor and self.executor.is_alive():
self.logger.debug("Asking for last message, executor pid %d",
self.executor.pid)
self.parent_chan.send(cmsg("lastresponse", None, None))
resp = self.parent_chan.recv()
# Above should never fail
return resp

def async_execute(self, query) -> cmsg:
if self.executor and self.executor.is_alive():
self.logger.debug("Sending query %s to pid %d",
query, self.executor.pid)
# TODO: message should carry
# current catalog. One might
# think that the main process
# connection always "follows"
# database changes since all
# main queries get executed
# against executor thread
# and main process conn only
# gets used for sidebar/auto
# completion. But, for
# example the MYSQL driver
# if starting without a
# declared database will just
# switch to the first db
# when running find_columns
self.parent_chan.send(
cmsg("execute", query, None))
# Will block but can be interrupted
res = self.parent_chan.recv()
self.logger.debug("Execution done")
self.query = query
# Check if catalog has changed in which case
# execute query locally
self.parent_chan.send(cmsg("currentcatalog", None, None))
rescat = self.parent_chan.recv()
if rescat.status == commandStatus.FAIL:
# TODO raise exception here since
# connection catalogs are possibly out of sync
# and we don't have a way of knowing
res = cmsg("execute", "", commandStatus.FAIL)
elif not rescat.payload == self.current_catalog():
# query changed the catalog
# so let's change the database locally
self.logger.debug("Execution changed catalog")
self.execute("USE " + rescat.payload)
else:
res = cmsg("execute", "", commandStatus.FAIL)
return res

def async_fetch(self, size) -> cmsg:
if self.executor and self.executor.is_alive():
self.logger.debug("Fetching size %d from pid %d",
size, self.executor.pid)
self.parent_chan.send(cmsg("fetch", size, None))
res = self.parent_chan.recv()
self.logger.debug("Fetching done")
else:
res = cmsg("fetch", "", commandStatus.FAIL)
return res

def async_fetchdone(self) -> cmsg:
if self.executor and self.executor.is_alive():
self.parent_chan.send(cmsg("fetchdone", None, None))
res = self.parent_chan.recv()
def fetchmany(self, size, event: Event = None) -> list:
if self.cursor:
self._fetch_res = self.cursor.fetchmany(size)
else:
res = cmsg("fetchdone", "", commandStatus.FAIL)
return res

def execute(self, query, parameters = None) -> Cursor:
self._fetch_res = []
if event is not None:
event.set()
return self._fetch_res

def async_fetchmany(self, size) -> list:
    """Fetch up to ``size`` rows on a worker thread and wait for the result.

    Despite the ``async_`` prefix this call blocks until the fetch has
    finished. The driver call is pushed onto a daemon thread only so that
    the calling thread waits on an ``Event`` and can still register a
    ``KeyboardInterrupt`` while the fetch is in progress. It may evolve
    into something more genuinely asynchronous later.

    Args:
        size: Maximum number of rows to request; forwarded to
            ``self.fetchmany``.

    Returns:
        The rows the worker stored in ``self._fetch_res``, or an empty list
        if the worker failed before storing a result.
    """
    exec_event = Event()
    # Forget rows from any previous fetch so that a failing worker cannot
    # cause stale data to be handed back to the caller.
    self._fetch_res = []

    def _worker() -> None:
        try:
            self.fetchmany(size = size, event = exec_event)
        except Exception:
            # Nobody joins this thread, so the error has to be reported
            # here or it would be lost.
            self.logger.exception("Fetch failed in worker thread")
        finally:
            # Always wake the waiter: without this, an exception raised
            # inside fetchmany would leave the caller blocked forever.
            exec_event.set()

    t = Thread(target = _worker, daemon = True)
    t.start()
    # Will block but can be interrupted
    exec_event.wait()
    return self._fetch_res

def execute(self, query, parameters = None, event: Event = None) -> Cursor:
with self._lock:
self.close_cursor()
self.cursor = self.conn.cursor()
self.cursor.execute(query, parameters)
self.query = query
try:
self._execution_err = None
self.status = connStatus.EXECUTING
self.cursor.execute(query, parameters)
self.status = connStatus.IDLE
self._execution_status = executionStatus.OK
self.query = query
except DatabaseError as e:
self._execution_status = executionStatus.FAIL
self._execution_err = str(e)
self.logger.warning("Execution error: %s", str(e))
if event is not None:
event.set()
return self.cursor

def async_execute(self, query) -> Cursor:
    """Run ``query`` on a worker thread and wait for it to complete.

    Despite the ``async_`` prefix this call blocks until execution has
    finished. The statement is pushed onto a daemon thread only so that
    the calling thread waits on an ``Event`` and can still register a
    ``KeyboardInterrupt`` while the driver is busy. It may evolve into
    something more genuinely asynchronous later.

    Args:
        query: SQL text forwarded to ``self.execute`` (without parameters).

    Returns:
        ``self.cursor`` as it stands once the worker has finished. Callers
        should consult ``execution_status`` / ``execution_err`` to learn
        whether the statement succeeded.
    """
    exec_event = Event()

    def _worker() -> None:
        try:
            self.execute(query = query, parameters = None, event = exec_event)
        except Exception as e:
            # execute() only handles DatabaseError itself. Anything else
            # would otherwise die silently in this thread, so log it and
            # record the failure where the caller looks for it.
            self.logger.exception("Execution failed in worker thread")
            with self._lock:
                self._execution_status = executionStatus.FAIL
                self._execution_err = str(e)
        finally:
            # Always wake the waiter: without this, an unexpected exception
            # in execute() would leave the caller blocked forever.
            exec_event.set()

    t = Thread(target = _worker, daemon = True)
    t.start()
    # Will block but can be interrupted
    exec_event.wait()
    return self.cursor

def list_catalogs(self) -> list:
Expand Down Expand Up @@ -296,9 +285,6 @@ def get_info(self, code: int) -> str:
return self.conn.get_info(code)

def close(self) -> None:
if self.executor and self.executor.is_alive():
self.executor.terminate()
self.executor.join()
# TODO: When disconnecting
# We likely don't want to allow any exception to
# propagate. Catch DatabaseError?
Expand All @@ -311,24 +297,27 @@ def close_cursor(self) -> None:
self.cursor = None
self.query = None

def cancel(self) -> None:
if self.cursor:
self.cursor.cancel()
self.query = None

def preview_query(self, table, filter_query = "", limit = -1) -> str:
    """Build the SELECT statement used to preview the contents of a table.

    Args:
        table: Table name, inserted verbatim after ``FROM``.
        filter_query: Optional text appended after the table name,
            separated from it by a single space.
        limit: When greater than zero, a ``LIMIT`` clause with this value
            is appended; otherwise no limit is added.

    Returns:
        The assembled query string.
    """
    pieces = ["SELECT * FROM ", table, " ", filter_query]
    if limit > 0:
        pieces.append(" LIMIT " + str(limit))
    return "".join(pieces)

def formatted_fetch(self, size, format_name = "psql"):
def formatted_fetch(self, size, cols, format_name = "psql"):
while True:
res = self.async_fetch(size)
if (res.status == commandStatus.FAIL) or (not res.type == "fetch"):
return "Encountered a problem while fetching"
elif len(res.payload[1]) == 0:
res = self.async_fetchmany(size)
if len(res) < 1:
break
else:
yield "\n".join(
formatter.format_output(
res.payload[1],
res.payload[0],
res,
cols,
format_name = format_name))

connWrappers = {}
Expand Down
Loading

0 comments on commit bea2288

Please sign in to comment.