diff --git a/import_export/resources.py b/import_export/resources.py index cb2a6d279..5a3f6e0e1 100644 --- a/import_export/resources.py +++ b/import_export/resources.py @@ -8,6 +8,8 @@ from django.utils.safestring import mark_safe from django.utils.datastructures import SortedDict +from django.db import transaction +from django.conf import settings from .results import Error, Result, RowResult from .fields import Field @@ -17,6 +19,9 @@ ) +USE_TRANSACTIONS = getattr(settings, 'IMPORT_EXPORT_USE_TRANSACTIONS', False) + + class ResourceOptions(object): """ The inner Meta class allows for class-level configuration of how the @@ -41,6 +46,10 @@ class ResourceOptions(object): * ``widgets`` - dictionary defines widget kwargs for fields. + * ``use_transactions`` - Controls if import should use database + transactions. Default value is ``None`` meaning + ``settings.IMPORT_EXPORT_USE_TRANSACTIONS`` will be evaluated. + """ fields = None model = None @@ -49,6 +58,7 @@ class ResourceOptions(object): import_id_fields = ['id'] export_order = None widgets = None + use_transactions = None def __new__(cls, meta=None): overrides = {} @@ -89,6 +99,12 @@ class Resource(object): """ __metaclass__ = DeclarativeMetaclass + def get_use_transactions(self): + if self._meta.use_transactions is None: + return USE_TRANSACTIONS + else: + return self._meta.use_transactions + def get_fields(self): """ Returns fields in ``export_order`` order. @@ -207,9 +223,32 @@ def get_diff_headers(self): """ return self.get_export_headers() - def import_data(self, dataset, dry_run=False, raise_errors=False): + def import_data(self, dataset, dry_run=False, raise_errors=False, + use_transactions=None): + """ + Imports data from ``dataset``. + + ``use_transactions`` + If ``True`` import process will be processed inside transaction. + If ``dry_run`` is set, or error occurs, transaction will be rolled + back. + """ result = Result() + + if use_transactions is None: + use_transactions = self.get_use_transactions() + + if use_transactions is True: + # when transactions are used we want to create/update/delete object + # as transaction will be rolled back if dry_run is set + real_dry_run = False + transaction.enter_transaction_management() + transaction.managed(True) + else: + real_dry_run = dry_run + instance_loader = self._meta.instance_loader_class(self, dataset) + for row in dataset.dict: try: row_result = RowResult() @@ -223,24 +262,36 @@ def import_data(self, dataset, dry_run=False, raise_errors=False): if self.for_delete(row, instance): if new: row_result.import_type = RowResult.IMPORT_TYPE_SKIP - row_result.diff = self.get_diff(None, None, dry_run) + row_result.diff = self.get_diff(None, None, + real_dry_run) else: row_result.import_type = RowResult.IMPORT_TYPE_DELETE - self.delete_instance(instance, dry_run) + self.delete_instance(instance, real_dry_run) row_result.diff = self.get_diff(original, None, - dry_run) + real_dry_run) else: self.import_obj(instance, row) - self.save_instance(instance, dry_run) - self.save_m2m(instance, row, dry_run) + self.save_instance(instance, real_dry_run) + self.save_m2m(instance, row, real_dry_run) row_result.diff = self.get_diff(original, instance, - dry_run) + real_dry_run) except Exception, e: tb_info = traceback.format_exc(sys.exc_info()[2]) row_result.errors.append(Error(repr(e), tb_info)) if raise_errors: + if use_transactions: + transaction.rollback() + transaction.leave_transaction_management() raise result.rows.append(row_result) + + if use_transactions: + if dry_run or result.has_errors(): + transaction.rollback() + else: + transaction.commit() + transaction.leave_transaction_management() + return result def get_export_order(self): diff --git a/requirements/dev.txt b/requirements/dev.txt index b95038da1..605dfbd35 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,2 +1,3 @@ -r base.txt sphinx +mysql-python diff --git a/tests/core/tests/resources_tests.py b/tests/core/tests/resources_tests.py index dcdf1776c..bb3df0d44 100644 --- a/tests/core/tests/resources_tests.py +++ b/tests/core/tests/resources_tests.py @@ -1,7 +1,12 @@ from decimal import Decimal from datetime import date -from django.test import TestCase +from django.test import ( + TestCase, + TransactionTestCase, + skipUnlessDBFeature, + ) +from django.utils.html import strip_tags import tablib @@ -261,6 +266,38 @@ def test_m2m_import(self): self.assertIn(cat1, book.categories.all()) +class ModelResourceTransactionTest(TransactionTestCase): + + def setUp(self): + self.resource = BookResource() + + @skipUnlessDBFeature('supports_transactions') + def test_m2m_import_with_transactions(self): + cat1 = Category.objects.create(name='Cat 1') + headers = ['id', 'name', 'categories'] + row = [None, 'FooBook', "%s" % cat1.pk] + dataset = tablib.Dataset(row, headers=headers) + + result = self.resource.import_data(dataset, dry_run=True, + use_transactions=True) + + row_diff = result.rows[0].diff + fields = self.resource.get_fields() + + id_field = self.resource.fields['id'] + id_diff = row_diff[fields.index(id_field)] + #id diff should exists because in rollbacked transaction + #FooBook has been saved + self.assertTrue(id_diff) + + category_field = self.resource.fields['categories'] + categories_diff = row_diff[fields.index(category_field)] + self.assertEqual(strip_tags(categories_diff), unicode(cat1.pk)) + + #check that it is really rollbacked + self.assertFalse(Book.objects.filter(name='FooBook')) + + class ModelResourceFactoryTest(TestCase): def test_create(self): diff --git a/tests/settings.py b/tests/settings.py index b73ea3bc9..7184d58d1 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -1,4 +1,4 @@ -import os.path +import os INSTALLED_APPS = [ 'django.contrib.admin', @@ -20,9 +20,22 @@ STATIC_URL = '/static/' -DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': os.path.join(os.path.dirname(__file__), 'database.db'), +if os.environ.get('IMPORT_EXPORT_TEST_TYPE') == 'mysql-innodb': + IMPORT_EXPORT_USE_TRANSACTIONS = True + DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.mysql', + 'NAME': 'import_export_test', + 'USER': os.environ.get('IMPORT_EXPORT_MYSQL_USER', 'root'), + 'OPTIONS': { + 'init_command': 'SET storage_engine=INNODB', + } + } + } +else: + DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.sqlite3', + 'NAME': os.path.join(os.path.dirname(__file__), 'database.db'), + } } -} diff --git a/tox.ini b/tox.ini index 1a48b44bf..b102eeef3 100644 --- a/tox.ini +++ b/tox.ini @@ -1,8 +1,13 @@ [tox] -envlist = py26, py27, py27-tablib-dev +envlist = py26, py27, py27-tablib-dev, py27-mysql-innodb [testenv] commands=python {toxinidir}/tests/manage.py test core [testenv:py27-tablib-dev] deps = -egit+https://github.com/kennethreitz/tablib.git#egg=tablib + +[testenv:py27-mysql-innodb] +deps = mysql-python +setenv = + IMPORT_EXPORT_TEST_TYPE=mysql-innodb