From c198535292869ae1ced6d66d08de536188be7e05 Mon Sep 17 00:00:00 2001 From: Bogdan Date: Thu, 20 Oct 2016 09:17:51 -0700 Subject: [PATCH] Change slice ids in the position json during dashboard import. (#1380) * Change slice ids in the position json during dashboard import. * Update slice ids in the dashboard json metadata. --- caravel/models.py | 57 +++++++++++++++++++++++++++++++++-- tests/import_export_tests.py | 58 ++++++++++++++++++++++++++++++------ 2 files changed, 103 insertions(+), 12 deletions(-) diff --git a/caravel/models.py b/caravel/models.py index e4734b697d671..ddfdee252afeb 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -449,6 +449,12 @@ def params(self): def params(self, value): self.json_metadata = value + @property + def position_array(self): + if self.position_json: + return json.loads(self.position_json) + return [] + @classmethod def import_obj(cls, dashboard_to_import, import_time=None): """Imports the dashboard from the object to the database. @@ -460,6 +466,28 @@ def import_obj(cls, dashboard_to_import, import_time=None): to import/export dashboards between multiple caravel instances. Audit metadata isn't copies over. """ + def alter_positions(dashboard, old_to_new_slc_id_dict): + """ Updates slice_ids in the position json. + + Sample position json: + [{ + "col": 5, + "row": 10, + "size_x": 4, + "size_y": 2, + "slice_id": "3610" + }] + """ + position_array = dashboard.position_array + for position in position_array: + if 'slice_id' not in position: + continue + old_slice_id = int(position['slice_id']) + if old_slice_id in old_to_new_slc_id_dict: + position['slice_id'] = '{}'.format( + old_to_new_slc_id_dict[old_slice_id]) + dashboard.position_json = json.dumps(position_array) + logging.info('Started import of the dashboard: {}' .format(dashboard_to_import.to_json())) session = db.session @@ -468,11 +496,25 @@ def import_obj(cls, dashboard_to_import, import_time=None): # copy slices object as Slice.import_slice will mutate the slice # and will remove the existing dashboard - slice association slices = copy(dashboard_to_import.slices) - slice_ids = set() + old_to_new_slc_id_dict = {} + new_filter_immune_slices = [] + new_expanded_slices = {} + i_params_dict = dashboard_to_import.params_dict for slc in slices: logging.info('Importing slice {} from the dashboard: {}'.format( slc.to_json(), dashboard_to_import.dashboard_title)) - slice_ids.add(Slice.import_obj(slc, import_time=import_time)) + new_slc_id = Slice.import_obj(slc, import_time=import_time) + old_to_new_slc_id_dict[slc.id] = new_slc_id + # update json metadata that deals with slice ids + if ('filter_immune_slices' in i_params_dict and + slc.id in i_params_dict['filter_immune_slices']): + new_filter_immune_slices.append(new_slc_id) + new_slc_id_str = '{}'.format(new_slc_id) + old_slc_id_str = '{}'.format(slc.id) + if ('expanded_slices' in i_params_dict and + old_slc_id_str in i_params_dict['expanded_slices']): + new_expanded_slices[new_slc_id_str] = ( + i_params_dict['expanded_slices'][old_slc_id_str]) # override the dashboard existing_dashboard = None @@ -483,8 +525,17 @@ def import_obj(cls, dashboard_to_import, import_time=None): existing_dashboard = dash dashboard_to_import.id = None + alter_positions(dashboard_to_import, old_to_new_slc_id_dict) dashboard_to_import.alter_params(import_time=import_time) - new_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all() + if new_expanded_slices: + dashboard_to_import.alter_params( + expanded_slices=new_expanded_slices) + if new_filter_immune_slices: + dashboard_to_import.alter_params( + filter_immune_slices=new_filter_immune_slices) + + new_slices = session.query(Slice).filter( + Slice.id.in_(old_to_new_slc_id_dict.values())).all() if existing_dashboard: existing_dashboard.override(dashboard_to_import) diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index 6a929a6c4292a..4b36881e1104d 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -97,6 +97,10 @@ def create_table(self, name, schema='', id=0, cols_names=[], metric_names=[]): def get_slice(self, slc_id): return db.session.query(models.Slice).filter_by(id=slc_id).first() + def get_slice_by_name(self, name): + return db.session.query(models.Slice).filter_by( + slice_name=name).first() + def get_dash(self, dash_id): return db.session.query(models.Dashboard).filter_by( id=dash_id).first() @@ -113,12 +117,11 @@ def get_table_by_name(self, name): return db.session.query(models.SqlaTable).filter_by( table_name=name).first() - def assert_dash_equals(self, expected_dash, actual_dash): + def assert_dash_equals(self, expected_dash, actual_dash, + check_position=True): self.assertEquals(expected_dash.slug, actual_dash.slug) self.assertEquals( expected_dash.dashboard_title, actual_dash.dashboard_title) - self.assertEquals( - expected_dash.position_json, actual_dash.position_json) self.assertEquals( len(expected_dash.slices), len(actual_dash.slices)) expected_slices = sorted( @@ -127,6 +130,9 @@ def assert_dash_equals(self, expected_dash, actual_dash): actual_dash.slices, key=lambda s: s.slice_name) for e_slc, a_slc in zip(expected_slices, actual_slices): self.assert_slice_equals(e_slc, a_slc) + if check_position: + self.assertEquals( + expected_dash.position_json, actual_dash.position_json) def assert_table_equals(self, expected_ds, actual_ds): self.assertEquals(expected_ds.table_name, actual_ds.table_name) @@ -221,7 +227,6 @@ def test_import_2_slices_for_same_table(self): self.assert_slice_equals(slc_1, imported_slc_1) self.assertEquals(imported_slc_1.datasource.perm, imported_slc_1.perm) - self.assertEquals(table_id, imported_slc_2.datasource_id) self.assert_slice_equals(slc_2, imported_slc_2) self.assertEquals(imported_slc_2.datasource.perm, imported_slc_2.perm) @@ -246,12 +251,22 @@ def test_import_empty_dashboard(self): imported_dash_id = models.Dashboard.import_obj( empty_dash, import_time=1989) imported_dash = self.get_dash(imported_dash_id) - self.assert_dash_equals(empty_dash, imported_dash) + self.assert_dash_equals( + empty_dash, imported_dash, check_position=False) def test_import_dashboard_1_slice(self): slc = self.create_slice('health_slc', id=10006) dash_with_1_slice = self.create_dashboard( 'dash_with_1_slice', slcs=[slc], id=10002) + dash_with_1_slice.position_json = """ + [{{ + "col": 5, + "row": 10, + "size_x": 4, + "size_y": 2, + "slice_id": "{}" + }}] + """.format(slc.id) imported_dash_id = models.Dashboard.import_obj( dash_with_1_slice, import_time=1990) imported_dash = self.get_dash(imported_dash_id) @@ -259,15 +274,27 @@ def test_import_dashboard_1_slice(self): expected_dash = self.create_dashboard( 'dash_with_1_slice', slcs=[slc], id=10002) make_transient(expected_dash) - self.assert_dash_equals(expected_dash, imported_dash) + self.assert_dash_equals( + expected_dash, imported_dash, check_position=False) self.assertEquals({"remote_id": 10002, "import_time": 1990}, json.loads(imported_dash.json_metadata)) + expected_position = dash_with_1_slice.position_array + expected_position[0]['slice_id'] = '{}'.format( + imported_dash.slices[0].id) + self.assertEquals(expected_position, imported_dash.position_array) + def test_import_dashboard_2_slices(self): e_slc = self.create_slice('e_slc', id=10007, table_name='energy_usage') b_slc = self.create_slice('b_slc', id=10008, table_name='birth_names') dash_with_2_slices = self.create_dashboard( 'dash_with_2_slices', slcs=[e_slc, b_slc], id=10003) + dash_with_2_slices.json_metadata = json.dumps({ + "remote_id": 10003, + "filter_immune_slices": [e_slc.id], + "expanded_slices": {e_slc.id: True, b_slc.id: False} + }) + imported_dash_id = models.Dashboard.import_obj( dash_with_2_slices, import_time=1991) imported_dash = self.get_dash(imported_dash_id) @@ -275,8 +302,20 @@ def test_import_dashboard_2_slices(self): expected_dash = self.create_dashboard( 'dash_with_2_slices', slcs=[e_slc, b_slc], id=10003) make_transient(expected_dash) - self.assert_dash_equals(imported_dash, expected_dash) - self.assertEquals({"remote_id": 10003, "import_time": 1991}, + self.assert_dash_equals( + imported_dash, expected_dash, check_position=False) + i_e_slc = self.get_slice_by_name('e_slc') + i_b_slc = self.get_slice_by_name('b_slc') + expected_json_metadata = { + "remote_id": 10003, + "import_time": 1991, + "filter_immune_slices": [i_e_slc.id], + "expanded_slices": { + '{}'.format(i_e_slc.id): True, + '{}'.format(i_b_slc.id): False + } + } + self.assertEquals(expected_json_metadata, json.loads(imported_dash.json_metadata)) def test_import_override_dashboard_2_slices(self): @@ -304,7 +343,8 @@ def test_import_override_dashboard_2_slices(self): 'override_dashboard_new', slcs=[e_slc, b_slc, c_slc], id=10004) make_transient(expected_dash) imported_dash = self.get_dash(imported_dash_id_2) - self.assert_dash_equals(expected_dash, imported_dash) + self.assert_dash_equals( + expected_dash, imported_dash, check_position=False) self.assertEquals({"remote_id": 10004, "import_time": 1992}, json.loads(imported_dash.json_metadata))