Skip to content

Commit

Permalink
Check datasource level perms for downloading csv and fetching results (
Browse files Browse the repository at this point in the history
…apache#2032)

* Check datasource level perms for downloading csv and fetching results

* Add index on the query table on the result key column
  • Loading branch information
bkyryliuk authored and Saleh Hindi committed Jun 9, 2017
1 parent 5bc5d49 commit 5328bf8
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Add index on the result key to the query table.
Revision ID: f18570e03440
Revises: 1296d28ec131
Create Date: 2017-01-24 12:40:42.494787
"""
from alembic import op

# revision identifiers, used by Alembic.
revision = 'f18570e03440'
down_revision = '1296d28ec131'


def upgrade():
op.create_index(op.f('ix_query_results_key'), 'query', ['results_key'], unique=False)


def downgrade():
op.drop_index(op.f('ix_query_results_key'), table_name='query')
2 changes: 1 addition & 1 deletion superset/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2660,7 +2660,7 @@ class Query(Model):
rows = Column(Integer)
error_message = Column(Text)
# key used to store the results in the results backend
results_key = Column(String(64))
results_key = Column(String(64), index=True)

# Using Numeric in place of DateTime for sub-second precision
# stored as seconds since epoch, allowing for milliseconds
Expand Down
69 changes: 37 additions & 32 deletions superset/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,24 @@ def datasource_access_by_name(
return True
return False

def datasource_access_by_fullname(
self, database, full_table_name, schema):
table_name_pieces = full_table_name.split(".")
if len(table_name_pieces) == 2:
table_schema = table_name_pieces[0]
table_name = table_name_pieces[1]
else:
table_schema = schema
table_name = table_name_pieces[0]
return self.datasource_access_by_name(
database, table_name, schema=table_schema)

def rejected_datasources(self, sql, database, schema):
superset_query = sql_parse.SupersetQuery(sql)
return [
t for t in superset_query.tables if not
self.datasource_access_by_fullname(database, t, schema)]


class ListWidgetWithCheckboxes(ListWidget):
"""An alternative to list view that renders Boolean fields as checkboxes
Expand Down Expand Up @@ -2419,18 +2437,19 @@ def results(self, key):

blob = results_backend.get(key)
if blob:
json_payload = zlib.decompress(blob)
obj = json.loads(json_payload)
db_id = obj['query']['dbId']
session = db.session()
mydb = session.query(models.Database).filter_by(id=db_id).one()

if not self.database_access(mydb):
return json_error_response(
get_database_access_error_msg(mydb.database_name))
query = (
db.session.query(models.Query)
.filter_by(results_key=key)
.one()
)
rejected_tables = self.rejected_datasources(
query.sql, query.database, query.schema)
if rejected_tables:
return json_error_response(get_datasource_access_error_msg(
'{}'.format(rejected_tables)))

return Response(
json_payload,
zlib.decompress(blob),
status=200,
mimetype="application/json")
else:
Expand All @@ -2449,20 +2468,10 @@ def results(self, key):
@log_this
def sql_json(self):
"""Runs arbitrary sql and returns and json"""
def table_accessible(database, full_table_name, schema_name=None):
table_name_pieces = full_table_name.split(".")
if len(table_name_pieces) == 2:
table_schema = table_name_pieces[0]
table_name = table_name_pieces[1]
else:
table_schema = schema_name
table_name = table_name_pieces[0]
return self.datasource_access_by_name(
database, table_name, schema=table_schema)

async = request.form.get('runAsync') == 'true'
sql = request.form.get('sql')
database_id = request.form.get('database_id')
schema = request.form.get('schema') or None

session = db.session()
mydb = session.query(models.Database).filter_by(id=database_id).one()
Expand All @@ -2471,16 +2480,10 @@ def table_accessible(database, full_table_name, schema_name=None):
json_error_response(
'Database with id {} is missing.'.format(database_id))

superset_query = sql_parse.SupersetQuery(sql)
schema = request.form.get('schema')
schema = schema if schema else None

rejected_tables = [
t for t in superset_query.tables if not
table_accessible(mydb, t, schema_name=schema)]
rejected_tables = self.rejected_datasources(sql, mydb, schema)
if rejected_tables:
return json_error_response(
get_datasource_access_error_msg('{}'.format(rejected_tables)))
return json_error_response(get_datasource_access_error_msg(
'{}'.format(rejected_tables)))
session.commit()

select_as_cta = request.form.get('select_as_cta') == 'true'
Expand Down Expand Up @@ -2555,8 +2558,10 @@ def csv(self, client_id):
.one()
)

if not self.database_access(query.database):
flash(get_database_access_error_msg(query.database.database_name))
rejected_tables = self.rejected_datasources(
query.sql, query.database, query.schema)
if rejected_tables:
flash(get_datasource_access_error_msg('{}'.format(rejected_tables)))
return redirect('/')

sql = query.select_sql or query.sql
Expand Down

0 comments on commit 5328bf8

Please sign in to comment.