diff --git a/pyorient/ogm/property.py b/pyorient/ogm/property.py index 60f08064..ceffe7c3 100644 --- a/pyorient/ogm/property.py +++ b/pyorient/ogm/property.py @@ -105,6 +105,16 @@ def encode_name(name): raise ValueError('Prohibited character in property name: {}'.format(name)) return name + @staticmethod + def encode_operator(value): + """Encode the correct SQL operator based on the value""" + + # If the value is "None" (SQL null) use " is " all other cases use " = " + if value is None: + return u' is ' + else: + return u' = ' + @staticmethod def encode_value(value, expressions): from pyorient.ogm.what import What @@ -227,4 +237,3 @@ def __call__(self, graph, attr): :param attr: Name of attribute specifying PreOp """ pass - diff --git a/pyorient/ogm/query.py b/pyorient/ogm/query.py index 3d3b9266..28179fec 100644 --- a/pyorient/ogm/query.py +++ b/pyorient/ogm/query.py @@ -107,7 +107,7 @@ def format(self, *args, **kwargs): new_query = self.from_string(self.compile().format(*[encode(arg) for arg in args], **{k:encode(v) for k,v in kwargs.items()}), self._graph) new_query.source_name = self.source_name - new_query._class_props = self._class_props + new_query._class_props = self._class_props new_query._params = self._params return new_query @@ -527,7 +527,8 @@ def build_assign_what(self, k, v): (u'(' + str(v) + ')' if isinstance(v, RetrievalCommand) else self.build_what(v)) def build_assign_vertex(self, k, v): - return PropertyEncoder.encode_name(k) + u' = ' + \ + return PropertyEncoder.encode_name(k) + \ + PropertyEncoder.encode_operator(v) + \ ArgConverter.convert_to(ArgConverter.Vertex, v, self) def build_lets(self, params): @@ -733,4 +734,3 @@ def __exit__(self, type, value, traceback): del self.params[k] else: self.params[k] = v - diff --git a/tests/test_ogm.py b/tests/test_ogm.py index b7742ffd..48056c82 100644 --- a/tests/test_ogm.py +++ b/tests/test_ogm.py @@ -1681,7 +1681,7 @@ def testTokens(self): not_fun = enjoy_query.format(False).all() self.assertEqual(len(not_fun), 3) - next_query = g.next.query().what(outV().as_('o'), inV().as_('i')).filter(OGMTokensCase.Next.probability > 0.5) + next_query = g.next.query().what(outV().as_('o'), inV().as_('i')).filter(OGMTokensCase.Next.probability > 0.5) uncached = next_query.query().what(unionall('o', 'i')) cache = {} @@ -1709,7 +1709,7 @@ def testTokens(self): cached = token_sub.query().what(unionall('o', 'i')).fetch_plan('*:1', cache) self.assertIsInstance(cached.compile(), STR_TYPES) - + probable = cached.all() self.assertEqual(len(probable), 3) for p in probable: @@ -1817,4 +1817,85 @@ def testPretty(self): print(q.pretty()) print('\n') +class OGMDictQueryTestCase(unittest.TestCase): + + Node = declarative_node() + + class DictQueryTest(Node): + element_type = 'dict_query_test' + element_plural = 'dict_query_tests' + + column_1 = String() + column_2 = String() + + def __init__(self, *args, **kwargs): + super(OGMDictQueryTestCase, self).__init__(*args, **kwargs) + self.g = None + + def setUp(self): + + g = self.g = Graph(Config.from_url('dict_queries', 'root', 'root', + initial_drop=True)) + + g.create_all(OGMDictQueryTestCase.Node.registry) + + self.db_data = [ + {"column_1":"Collection 1", "column_2" : "Test"}, + {"column_1":"Collection 1"}, # this will populate a null in column_2 + {"column_1":"Collection 2", "column_2" : "Test"}, + {"column_1":"Collection 2", "column_2" : None}, + {"column_1":"Collection 3", "column_2" : ""}, + ] + + for data in self.db_data: + g.dict_query_tests.create(**data) + + def testDictBasicQueryTest(self): + assert len(OGMDictQueryTestCase.Node.registry) == 1 + g = self.g + # Validate the setup was ok + query_res = g.dict_query_tests.query().all() + assert len(query_res) == 5, "Expected 4 tuples, retrieved {}".format(len(query_res)) + + # Test a query where the kwargs contain a full match + query_res = g.dict_query_tests.query(**self.db_data[0]).all() + assert len(query_res) == 1, "Expected 1 tuple, retrieved {}".format(len(query_res)) + assert query_res[0].column_1 == self.db_data[0]["column_1"], "Retrieved tuple did not match expected data" + assert query_res[0].column_2 == self.db_data[0]["column_2"], "Retrieved tuple did not match expected data" + + # Test a query where the kwargs contain a partial set of values + # Where the kwargs are created using missing values the missing values will be assigned null + # Where the query is made with missing values the missing values will match any value + query_res = g.dict_query_tests.query(**self.db_data[1]).all() + assert len(query_res) == 2, "Expected 2 tuples, retrieved {}".format(len(query_res)) + # allow for tuples returned in either order + assert ( + query_res[0].column_1 == self.db_data[0]["column_1"] and + query_res[0].column_2 == self.db_data[0]["column_2"] and + query_res[1].column_1 == self.db_data[1]["column_1"] and + query_res[1].column_2 is None + ) or ( + query_res[1].column_1 == self.db_data[0]["column_1"] and + query_res[1].column_2 == self.db_data[0]["column_2"] and + query_res[0].column_1 == self.db_data[1]["column_1"] and + query_res[0].column_2 is None + ), "Retrieved tuples did not match expected data" + + # Test a query where the kwargs contain a full match + query_res = g.dict_query_tests.query(**self.db_data[2]).all() + assert len(query_res) == 1, "Expected 1 tuple, retrieved {}".format(len(query_res)) + assert query_res[0].column_1 == self.db_data[2]["column_1"], "Retrieved tuple did not match expected data" + assert query_res[0].column_2 == self.db_data[2]["column_2"], "Retrieved tuple did not match expected data" + + # Test a query where the kwargs contain a full match, where one of the values is None/null + query_res = g.dict_query_tests.query(**self.db_data[3]).all() + assert len(query_res) == 1, "Expected 1 tuple, retrieved {}".format(len(query_res)) + assert query_res[0].column_1 == self.db_data[3]["column_1"], "Retrieved tuple did not match expected data" + assert query_res[0].column_2 == self.db_data[3]["column_2"], "Retrieved tuple did not match expected data" + + # Test a query where the kwargs contain a full match, where one of the values is an empty string + query_res = g.dict_query_tests.query(**self.db_data[4]).all() + assert len(query_res) == 1, "Expected 1 tuple, retrieved {}".format(len(query_res)) + assert query_res[0].column_1 == self.db_data[4]["column_1"], "Retrieved tuple did not match expected data" + assert query_res[0].column_2 == self.db_data[4]["column_2"], "Retrieved tuple did not match expected data"