diff --git a/.github/workflows/ci-workflow.yml b/.github/workflows/ci-workflow.yml index 740cce9f6db..2512178f9a3 100644 --- a/.github/workflows/ci-workflow.yml +++ b/.github/workflows/ci-workflow.yml @@ -359,7 +359,6 @@ jobs: run: | ulimit -n 10240 source /Users/runner/.cargo/env - cargo update -p cc --precise '1.0.83' make rusttest - name: Rust example diff --git a/Makefile b/Makefile index c779e796a1f..d5e99124fc3 100644 --- a/Makefile +++ b/Makefile @@ -128,11 +128,11 @@ nodejstest: nodejs javatest: java ifeq ($(OS),Windows_NT) $(call mkdirp,tools/java_api/build/test) && cd tools/java_api/ && \ - javac -d build/test -cp ".;build/kuzu_java.jar;third_party/junit-platform-console-standalone-1.9.3.jar" -sourcepath src/test/java/com/kuzudb/test/*.java && \ + javac -d build/test -cp ".;build/kuzu_java.jar;third_party/junit-platform-console-standalone-1.9.3.jar" src/test/java/com/kuzudb/test/*.java && \ java -jar third_party/junit-platform-console-standalone-1.9.3.jar -cp ".;build/kuzu_java.jar;build/test/" --scan-classpath --include-package=com.kuzudb.java_test --details=verbose else $(call mkdirp,tools/java_api/build/test) && cd tools/java_api/ && \ - javac -d build/test -cp ".:build/kuzu_java.jar:third_party/junit-platform-console-standalone-1.9.3.jar" -sourcepath src/test/java/com/kuzudb/test/*.java && \ + javac -d build/test -cp ".:build/kuzu_java.jar:third_party/junit-platform-console-standalone-1.9.3.jar" src/test/java/com/kuzudb/test/*.java && \ java -jar third_party/junit-platform-console-standalone-1.9.3.jar -cp ".:build/kuzu_java.jar:build/test/" --scan-classpath --include-package=com.kuzudb.java_test --details=verbose endif diff --git a/src/c_api/connection.cpp b/src/c_api/connection.cpp index add525690e2..acd9b91e3e8 100644 --- a/src/c_api/connection.cpp +++ b/src/c_api/connection.cpp @@ -67,7 +67,7 @@ kuzu_prepared_statement* kuzu_connection_prepare(kuzu_connection* connection, co auto* c_prepared_statement = new kuzu_prepared_statement; c_prepared_statement->_prepared_statement = prepared_statement; c_prepared_statement->_bound_values = - new std::unordered_map>; + new std::unordered_map>; return c_prepared_statement; } @@ -75,11 +75,20 @@ kuzu_query_result* kuzu_connection_execute( kuzu_connection* connection, kuzu_prepared_statement* prepared_statement) { auto prepared_statement_ptr = static_cast(prepared_statement->_prepared_statement); - auto bound_values = static_cast>*>( + auto bound_values = static_cast>*>( prepared_statement->_bound_values); - auto query_result = static_cast(connection->_connection) - ->executeWithParams(prepared_statement_ptr, *bound_values) - .release(); + + // Must copy the parameters for safety, and so that the parameters in the prepared statement + // stay the same. + std::unordered_map> copied_bound_values; + for (auto& [name, value] : *bound_values) { + copied_bound_values.emplace(name, value->copy()); + } + + auto query_result = + static_cast(connection->_connection) + ->executeWithParams(prepared_statement_ptr, std::move(copied_bound_values)) + .release(); if (query_result == nullptr) { return nullptr; } diff --git a/src/c_api/prepared_statement.cpp b/src/c_api/prepared_statement.cpp index 8133616331b..04479d75777 100644 --- a/src/c_api/prepared_statement.cpp +++ b/src/c_api/prepared_statement.cpp @@ -8,10 +8,10 @@ using namespace kuzu::common; using namespace kuzu::main; void kuzu_prepared_statement_bind_cpp_value(kuzu_prepared_statement* prepared_statement, - const char* param_name, const std::shared_ptr& value) { - auto* bound_values = static_cast>*>( + const char* param_name, std::unique_ptr value) { + auto* bound_values = static_cast>*>( prepared_statement->_bound_values); - bound_values->insert({param_name, value}); + bound_values->insert({param_name, std::move(value)}); } void kuzu_prepared_statement_destroy(kuzu_prepared_statement* prepared_statement) { @@ -22,7 +22,7 @@ void kuzu_prepared_statement_destroy(kuzu_prepared_statement* prepared_statement delete static_cast(prepared_statement->_prepared_statement); } if (prepared_statement->_bound_values != nullptr) { - delete static_cast>*>( + delete static_cast>*>( prepared_statement->_bound_values); } delete prepared_statement; @@ -48,97 +48,97 @@ char* kuzu_prepared_statement_get_error_message(kuzu_prepared_statement* prepare void kuzu_prepared_statement_bind_bool( kuzu_prepared_statement* prepared_statement, const char* param_name, bool value) { - auto value_ptr = std::make_shared(value); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(value); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_int64( kuzu_prepared_statement* prepared_statement, const char* param_name, int64_t value) { - auto value_ptr = std::make_shared(value); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(value); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_int32( kuzu_prepared_statement* prepared_statement, const char* param_name, int32_t value) { - auto value_ptr = std::make_shared(value); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(value); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_int16( kuzu_prepared_statement* prepared_statement, const char* param_name, int16_t value) { - auto value_ptr = std::make_shared(value); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(value); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_int8( kuzu_prepared_statement* prepared_statement, const char* param_name, int8_t value) { - auto value_ptr = std::make_shared(value); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(value); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_uint64( kuzu_prepared_statement* prepared_statement, const char* param_name, uint64_t value) { - auto value_ptr = std::make_shared(value); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(value); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_uint32( kuzu_prepared_statement* prepared_statement, const char* param_name, uint32_t value) { - auto value_ptr = std::make_shared(value); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(value); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_uint16( kuzu_prepared_statement* prepared_statement, const char* param_name, uint16_t value) { - auto value_ptr = std::make_shared(value); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(value); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_uint8( kuzu_prepared_statement* prepared_statement, const char* param_name, uint8_t value) { - auto value_ptr = std::make_shared(value); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(value); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_double( kuzu_prepared_statement* prepared_statement, const char* param_name, double value) { - auto value_ptr = std::make_shared(value); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(value); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_float( kuzu_prepared_statement* prepared_statement, const char* param_name, float value) { - auto value_ptr = std::make_shared(value); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(value); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_date( kuzu_prepared_statement* prepared_statement, const char* param_name, kuzu_date_t value) { - auto value_ptr = std::make_shared(date_t(value.days)); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(date_t(value.days)); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_timestamp( kuzu_prepared_statement* prepared_statement, const char* param_name, kuzu_timestamp_t value) { - auto value_ptr = std::make_shared(timestamp_t(value.value)); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(timestamp_t(value.value)); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_interval( kuzu_prepared_statement* prepared_statement, const char* param_name, kuzu_interval_t value) { - auto value_ptr = std::make_shared(interval_t(value.months, value.days, value.micros)); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(interval_t(value.months, value.days, value.micros)); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_string( kuzu_prepared_statement* prepared_statement, const char* param_name, const char* value) { auto value_ptr = - std::make_shared(LogicalType{LogicalTypeID::STRING}, std::string(value)); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + std::make_unique(LogicalType{LogicalTypeID::STRING}, std::string(value)); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } void kuzu_prepared_statement_bind_value( kuzu_prepared_statement* prepared_statement, const char* param_name, kuzu_value* value) { - auto value_ptr = std::make_shared(*static_cast(value->_value)); - kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr); + auto value_ptr = std::make_unique(*static_cast(value->_value)); + kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr)); } diff --git a/src/include/common/enums/statement_type.h b/src/include/common/enums/statement_type.h index 04e99f07d80..97bd6bf9ffd 100644 --- a/src/include/common/enums/statement_type.h +++ b/src/include/common/enums/statement_type.h @@ -27,9 +27,9 @@ struct StatementTypeUtils { case StatementType::ALTER: case StatementType::CREATE_MACRO: case StatementType::COPY_FROM: - return true; - default: return false; + default: + return true; } } }; diff --git a/src/include/common/types/value/value.h b/src/include/common/types/value/value.h index 6fc95cc1de3..73843a8ff1f 100644 --- a/src/include/common/types/value/value.h +++ b/src/include/common/types/value/value.h @@ -156,6 +156,14 @@ class Value { * @return a Value with the same value as other. */ KUZU_API Value(const Value& other); + + /** + * @param other the value to move from. + * @return a Value with the same value as other. + */ + KUZU_API Value(Value&& other) = default; + KUZU_API Value& operator=(Value&& other) = default; + /** * @brief Sets the data type of the Value. * @param dataType_ the data type to set to. diff --git a/src/include/main/connection.h b/src/include/main/connection.h index c12ef2de429..48cca26242e 100644 --- a/src/include/main/connection.h +++ b/src/include/main/connection.h @@ -82,8 +82,8 @@ class Connection { template inline std::unique_ptr execute( PreparedStatement* preparedStatement, std::pair... args) { - std::unordered_map> inputParameters; - return executeWithParams(preparedStatement, inputParameters, args...); + std::unordered_map> inputParameters; + return executeWithParams(preparedStatement, std::move(inputParameters), args...); } /** * @brief Executes the given prepared statement with inputParams and returns the result. @@ -93,7 +93,7 @@ class Connection { * @return the result of the query. */ KUZU_API std::unique_ptr executeWithParams(PreparedStatement* preparedStatement, - std::unordered_map>& inputParams); + std::unordered_map> inputParams); /** * @brief interrupts all queries currently executing within this connection. */ @@ -151,16 +151,16 @@ class Connection { template std::unique_ptr executeWithParams(PreparedStatement* preparedStatement, - std::unordered_map>& params, + std::unordered_map> params, std::pair arg, std::pair... args) { auto name = arg.first; - auto val = std::make_shared((T)arg.second); - params.insert({name, val}); - return executeWithParams(preparedStatement, params, args...); + auto val = std::make_unique((T)arg.second); + params.insert({name, std::move(val)}); + return executeWithParams(preparedStatement, std::move(params), args...); } void bindParametersNoLock(PreparedStatement* preparedStatement, - std::unordered_map>& inputParams); + std::unordered_map> inputParams); std::unique_ptr executeAndAutoCommitIfNecessaryNoLock( PreparedStatement* preparedStatement, uint32_t planIdx = 0u); diff --git a/src/main/connection.cpp b/src/main/connection.cpp index 77f3850563c..0cdb1351df5 100644 --- a/src/main/connection.cpp +++ b/src/main/connection.cpp @@ -157,13 +157,13 @@ uint64_t Connection::getQueryTimeOut() { } std::unique_ptr Connection::executeWithParams(PreparedStatement* preparedStatement, - std::unordered_map>& inputParams) { + std::unordered_map> inputParams) { lock_t lck{mtx}; if (!preparedStatement->isSuccess()) { return queryResultWithError(preparedStatement->errMsg); } try { - bindParametersNoLock(preparedStatement, inputParams); + bindParametersNoLock(preparedStatement, std::move(inputParams)); } catch (Exception& exception) { std::string errMsg = exception.what(); return queryResultWithError(errMsg); @@ -172,7 +172,7 @@ std::unique_ptr Connection::executeWithParams(PreparedStatement* pr } void Connection::bindParametersNoLock(PreparedStatement* preparedStatement, - std::unordered_map>& inputParams) { + std::unordered_map> inputParams) { auto& parameterMap = preparedStatement->parameterMap; for (auto& [name, value] : inputParams) { if (!parameterMap.contains(name)) { @@ -184,7 +184,10 @@ void Connection::bindParametersNoLock(PreparedStatement* preparedStatement, value->getDataType()->toString() + " but expects " + expectParam->getDataType()->toString() + "."); } - parameterMap.at(name)->copyValueFrom(*value); + // The much more natural `parameterMap.at(name) = std::move(v)` fails. + // The reason is that other parts of the code rely on the existing Value object to be + // modified in-place, not replaced in this map. + *parameterMap.at(name) = std::move(*value); } } diff --git a/src/main/prepared_statement.cpp b/src/main/prepared_statement.cpp index bc70e4fb985..73e494ae47c 100644 --- a/src/main/prepared_statement.cpp +++ b/src/main/prepared_statement.cpp @@ -10,7 +10,7 @@ namespace kuzu { namespace main { bool PreparedStatement::allowActiveTransaction() const { - return !StatementTypeUtils::allowActiveTransaction(preparedSummary.statementType); + return StatementTypeUtils::allowActiveTransaction(preparedSummary.statementType); } bool PreparedStatement::isTransactionStatement() const { diff --git a/test/c_api/connection_test.cpp b/test/c_api/connection_test.cpp index e448c46d29f..c3989dc2ac6 100644 --- a/test/c_api/connection_test.cpp +++ b/test/c_api/connection_test.cpp @@ -82,7 +82,7 @@ TEST_F(CApiConnectionTest, Execute) { auto connection = getConnection(); auto query = "MATCH (a:person) WHERE a.isStudent = $1 RETURN COUNT(*)"; auto statement = kuzu_connection_prepare(connection, query); - kuzu_prepared_statement_bind_bool(statement, (char*)"1", true); + kuzu_prepared_statement_bind_bool(statement, "1", true); auto result = kuzu_connection_execute(connection, statement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); diff --git a/test/c_api/prepared_statement_test.cpp b/test/c_api/prepared_statement_test.cpp index fdc29c6cec1..f0f8bcccc00 100644 --- a/test/c_api/prepared_statement_test.cpp +++ b/test/c_api/prepared_statement_test.cpp @@ -71,7 +71,7 @@ TEST_F(CApiPreparedStatementTest, BindBool) { auto connection = getConnection(); auto query = "MATCH (a:person) WHERE a.isStudent = $1 RETURN COUNT(*)"; auto preparedStatement = kuzu_connection_prepare(connection, query); - kuzu_prepared_statement_bind_bool(preparedStatement, (char*)"1", true); + kuzu_prepared_statement_bind_bool(preparedStatement, "1", true); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -91,7 +91,7 @@ TEST_F(CApiPreparedStatementTest, BindInt64) { auto connection = getConnection(); auto query = "MATCH (a:person) WHERE a.age > $1 RETURN COUNT(*)"; auto preparedStatement = kuzu_connection_prepare(connection, query); - kuzu_prepared_statement_bind_int64(preparedStatement, (char*)"1", 30); + kuzu_prepared_statement_bind_int64(preparedStatement, "1", 30); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -111,7 +111,7 @@ TEST_F(CApiPreparedStatementTest, BindInt32) { auto connection = getConnection(); auto query = "MATCH (a:movies) WHERE a.length > $1 RETURN COUNT(*)"; auto preparedStatement = kuzu_connection_prepare(connection, query); - kuzu_prepared_statement_bind_int32(preparedStatement, (char*)"1", 200); + kuzu_prepared_statement_bind_int32(preparedStatement, "1", 200); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -132,7 +132,7 @@ TEST_F(CApiPreparedStatementTest, BindInt16) { auto query = "MATCH (a:person) -[s:studyAt]-> (b:organisation) WHERE s.length > $1 RETURN COUNT(*)"; auto preparedStatement = kuzu_connection_prepare(connection, query); - kuzu_prepared_statement_bind_int16(preparedStatement, (char*)"1", 10); + kuzu_prepared_statement_bind_int16(preparedStatement, "1", 10); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -153,7 +153,7 @@ TEST_F(CApiPreparedStatementTest, BindInt8) { auto query = "MATCH (a:person) -[s:studyAt]-> (b:organisation) WHERE s.level > $1 RETURN COUNT(*)"; auto preparedStatement = kuzu_connection_prepare(connection, query); - kuzu_prepared_statement_bind_int8(preparedStatement, (char*)"1", 3); + kuzu_prepared_statement_bind_int8(preparedStatement, "1", 3); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -174,7 +174,7 @@ TEST_F(CApiPreparedStatementTest, BindUInt64) { auto query = "MATCH (a:person) -[s:studyAt]-> (b:organisation) WHERE s.code > $1 RETURN COUNT(*)"; auto preparedStatement = kuzu_connection_prepare(connection, query); - kuzu_prepared_statement_bind_uint64(preparedStatement, (char*)"1", 100); + kuzu_prepared_statement_bind_uint64(preparedStatement, "1", 100); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -195,7 +195,7 @@ TEST_F(CApiPreparedStatementTest, BindUInt32) { auto query = "MATCH (a:person) -[s:studyAt]-> (b:organisation) WHERE s.temprature> $1 RETURN COUNT(*)"; auto preparedStatement = kuzu_connection_prepare(connection, query); - kuzu_prepared_statement_bind_uint32(preparedStatement, (char*)"1", 10); + kuzu_prepared_statement_bind_uint32(preparedStatement, "1", 10); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -216,7 +216,7 @@ TEST_F(CApiPreparedStatementTest, BindUInt16) { auto query = "MATCH (a:person) -[s:studyAt]-> (b:organisation) WHERE s.ulength> $1 RETURN COUNT(*)"; auto preparedStatement = kuzu_connection_prepare(connection, query); - kuzu_prepared_statement_bind_uint16(preparedStatement, (char*)"1", 100); + kuzu_prepared_statement_bind_uint16(preparedStatement, "1", 100); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -237,7 +237,7 @@ TEST_F(CApiPreparedStatementTest, BindUInt8) { auto query = "MATCH (a:person) -[s:studyAt]-> (b:organisation) WHERE s.ulevel> $1 RETURN COUNT(*)"; auto preparedStatement = kuzu_connection_prepare(connection, query); - kuzu_prepared_statement_bind_uint8(preparedStatement, (char*)"1", 14); + kuzu_prepared_statement_bind_uint8(preparedStatement, "1", 14); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -257,7 +257,7 @@ TEST_F(CApiPreparedStatementTest, BindDouble) { auto connection = getConnection(); auto query = "MATCH (a:person) WHERE a.eyeSight > $1 RETURN COUNT(*)"; auto preparedStatement = kuzu_connection_prepare(connection, query); - kuzu_prepared_statement_bind_double(preparedStatement, (char*)"1", 4.5); + kuzu_prepared_statement_bind_double(preparedStatement, "1", 4.5); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -277,7 +277,7 @@ TEST_F(CApiPreparedStatementTest, BindFloat) { auto connection = getConnection(); auto query = "MATCH (a:person) WHERE a.height < $1 RETURN COUNT(*)"; auto preparedStatement = kuzu_connection_prepare(connection, query); - kuzu_prepared_statement_bind_float(preparedStatement, (char*)"1", 1.0); + kuzu_prepared_statement_bind_float(preparedStatement, "1", 1.0); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -298,7 +298,7 @@ TEST_F(CApiPreparedStatementTest, BindString) { auto query = "MATCH (a:person) WHERE a.fName = $1 RETURN COUNT(*)"; auto preparedStatement = kuzu_connection_prepare(connection, query); ASSERT_TRUE(kuzu_prepared_statement_is_success(preparedStatement)); - kuzu_prepared_statement_bind_string(preparedStatement, (char*)"1", (char*)"Alice"); + kuzu_prepared_statement_bind_string(preparedStatement, "1", "Alice"); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -320,7 +320,7 @@ TEST_F(CApiPreparedStatementTest, BindDate) { auto preparedStatement = kuzu_connection_prepare(connection, query); ASSERT_TRUE(kuzu_prepared_statement_is_success(preparedStatement)); auto date = kuzu_date_t{0}; - kuzu_prepared_statement_bind_date(preparedStatement, (char*)"1", date); + kuzu_prepared_statement_bind_date(preparedStatement, "1", date); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -342,7 +342,7 @@ TEST_F(CApiPreparedStatementTest, BindTimestamp) { auto preparedStatement = kuzu_connection_prepare(connection, query); ASSERT_TRUE(kuzu_prepared_statement_is_success(preparedStatement)); auto timestamp = kuzu_timestamp_t{0}; - kuzu_prepared_statement_bind_timestamp(preparedStatement, (char*)"1", timestamp); + kuzu_prepared_statement_bind_timestamp(preparedStatement, "1", timestamp); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -364,7 +364,7 @@ TEST_F(CApiPreparedStatementTest, BindInteval) { auto preparedStatement = kuzu_connection_prepare(connection, query); ASSERT_TRUE(kuzu_prepared_statement_is_success(preparedStatement)); auto interval = kuzu_interval_t{0, 0, 0}; - kuzu_prepared_statement_bind_interval(preparedStatement, (char*)"1", interval); + kuzu_prepared_statement_bind_interval(preparedStatement, "1", interval); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); ASSERT_NE(result->_query_result, nullptr); @@ -387,7 +387,7 @@ TEST_F(CApiPreparedStatementTest, BindValue) { ASSERT_TRUE(kuzu_prepared_statement_is_success(preparedStatement)); auto timestamp = kuzu_timestamp_t{0}; auto timestampValue = kuzu_value_create_timestamp(timestamp); - kuzu_prepared_statement_bind_value(preparedStatement, (char*)"1", timestampValue); + kuzu_prepared_statement_bind_value(preparedStatement, "1", timestampValue); kuzu_value_destroy(timestampValue); auto result = kuzu_connection_execute(connection, preparedStatement); ASSERT_NE(result, nullptr); diff --git a/tools/java_api/src/jni/kuzu_java.cpp b/tools/java_api/src/jni/kuzu_java.cpp index 411f17b3e85..e2cae8e956b 100644 --- a/tools/java_api/src/jni/kuzu_java.cpp +++ b/tools/java_api/src/jni/kuzu_java.cpp @@ -1,7 +1,5 @@ -#include #include -#include "binder/bound_statement_result.h" // This header is generated at build time. See CMakeLists.txt. #include "com_kuzudb_KuzuNative.h" #include "common/exception/conversion.h" @@ -11,10 +9,8 @@ #include "common/types/value/node.h" #include "common/types/value/rel.h" #include "common/types/value/value.h" -#include "json.hpp" #include "main/kuzu.h" #include "main/query_summary.h" -#include "planner/operator/logical_plan.h" #include using namespace kuzu::main; @@ -116,8 +112,8 @@ std::string dataTypeToString(const LogicalType& dataType) { return LogicalTypeUtils::toString(dataType.getLogicalTypeID()); } -void javaMapToCPPMap( - JNIEnv* env, jobject javaMap, std::unordered_map>& cppMap) { +std::unordered_map> javaMapToCPPMap( + JNIEnv* env, jobject javaMap) { jclass mapClass = env->FindClass("java/util/Map"); jmethodID entrySet = env->GetMethodID(mapClass, "entrySet", "()Ljava/util/Set;"); @@ -132,20 +128,22 @@ void javaMapToCPPMap( jmethodID entryGetKey = env->GetMethodID(entryClass, "getKey", "()Ljava/lang/Object;"); jmethodID entryGetValue = env->GetMethodID(entryClass, "getValue", "()Ljava/lang/Object;"); + std::unordered_map> result; while (env->CallBooleanMethod(iter, hasNext)) { jobject entry = env->CallObjectMethod(iter, next); jstring key = (jstring)env->CallObjectMethod(entry, entryGetKey); jobject value = env->CallObjectMethod(entry, entryGetValue); const char* keyStr = env->GetStringUTFChars(key, JNI_FALSE); - Value* v = getValue(env, value); - std::shared_ptr value_ptr(v); - cppMap.insert({keyStr, value_ptr}); + const Value* v = getValue(env, value); + // Java code can keep a reference to the value, so we cannot move. + result.insert({keyStr, v->copy()}); env->DeleteLocalRef(entry); env->ReleaseStringUTFChars(key, keyStr); env->DeleteLocalRef(key); env->DeleteLocalRef(value); } + return result; } /** @@ -301,14 +299,10 @@ JNIEXPORT jobject JNICALL Java_com_kuzudb_KuzuNative_kuzu_1connection_1execute( Connection* conn = getConnection(env, thisConn); PreparedStatement* ps = getPreparedStatement(env, preStm); - std::unordered_map> param; - javaMapToCPPMap(env, param_map, param); + std::unordered_map> params = + javaMapToCPPMap(env, param_map); - for (auto const& pair : param) { - std::cout << "{" << pair.first << ": " << pair.second.get()->toString() << "}\n"; - } - - auto query_result = conn->executeWithParams(ps, param).release(); + auto query_result = conn->executeWithParams(ps, std::move(params)).release(); if (query_result == nullptr) { return nullptr; } diff --git a/tools/java_api/src/main/java/com/kuzudb/KuzuConnection.java b/tools/java_api/src/main/java/com/kuzudb/KuzuConnection.java index fd80c515046..1e6d59d24c4 100644 --- a/tools/java_api/src/main/java/com/kuzudb/KuzuConnection.java +++ b/tools/java_api/src/main/java/com/kuzudb/KuzuConnection.java @@ -16,7 +16,8 @@ public class KuzuConnection { * @param db: KuzuDatabase instance. */ public KuzuConnection(KuzuDatabase db) { - assert db != null : "Cannot create connection, database is null."; + if (db == null) + throw new AssertionError("Cannot create connection, database is null."); conn_ref = KuzuNative.kuzu_connection_init(db); } diff --git a/tools/java_api/src/test/java/com/kuzudb/test/ConnectionTest.java b/tools/java_api/src/test/java/com/kuzudb/test/ConnectionTest.java index eddeff1f45f..4c99df39894 100644 --- a/tools/java_api/src/test/java/com/kuzudb/test/ConnectionTest.java +++ b/tools/java_api/src/test/java/com/kuzudb/test/ConnectionTest.java @@ -256,8 +256,10 @@ void ConnPrepareInterval() throws KuzuObjectRefDestroyedException { void ConnPrepareMultiParam() throws KuzuObjectRefDestroyedException { String query = "MATCH (a:person) WHERE a.lastJobDuration > $1 AND a.fName = $2 RETURN COUNT(*)"; Map m = new HashMap(); - m.put("1", new KuzuValue(Duration.ofDays(730))); - m.put("2", new KuzuValue("Alice")); + KuzuValue v1 = new KuzuValue(Duration.ofDays(730)); + KuzuValue v2 = new KuzuValue("Alice"); + m.put("1", v1); + m.put("2", v2); KuzuPreparedStatement statement = conn.prepare(query); assertNotNull(statement); KuzuQueryResult result = conn.execute(statement, m); @@ -270,6 +272,11 @@ void ConnPrepareMultiParam() throws KuzuObjectRefDestroyedException { assertEquals(((long) tuple.getValue(0).getValue()), 1); statement.destroy(); result.destroy(); + + // Not strictly necessary, but this makes sure if we freed v1 or v2 in + // the execute() call, we segfault here. + v1.destroy(); + v2.destroy(); } @Test diff --git a/tools/nodejs_api/src_cpp/include/node_connection.h b/tools/nodejs_api/src_cpp/include/node_connection.h index 19104d11b88..d4a1761a43e 100644 --- a/tools/nodejs_api/src_cpp/include/node_connection.h +++ b/tools/nodejs_api/src_cpp/include/node_connection.h @@ -1,7 +1,5 @@ #pragma once -#include - #include "main/kuzu.h" #include "node_database.h" #include "node_prepared_statement.h" @@ -65,15 +63,15 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker { public: ConnectionExecuteAsyncWorker(Napi::Function& callback, std::shared_ptr& connection, std::shared_ptr preparedStatement, NodeQueryResult* nodeQueryResult, - std::unordered_map>& params) + std::unordered_map> params) : Napi::AsyncWorker(callback), preparedStatement(preparedStatement), - nodeQueryResult(nodeQueryResult), connection(connection), params(params) {} + nodeQueryResult(nodeQueryResult), connection(connection), params(std::move(params)) {} ~ConnectionExecuteAsyncWorker() = default; inline void Execute() override { try { std::shared_ptr result = - std::move(connection->executeWithParams(preparedStatement.get(), params)); + connection->executeWithParams(preparedStatement.get(), std::move(params)); nodeQueryResult->SetQueryResult(result); if (!result->isSuccess()) { SetError(result->getErrorMessage()); @@ -90,5 +88,5 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker { std::shared_ptr connection; std::shared_ptr preparedStatement; NodeQueryResult* nodeQueryResult; - std::unordered_map> params; + std::unordered_map> params; }; diff --git a/tools/nodejs_api/src_cpp/include/node_util.h b/tools/nodejs_api/src_cpp/include/node_util.h index 26bbe6d1562..e8980f61d5d 100644 --- a/tools/nodejs_api/src_cpp/include/node_util.h +++ b/tools/nodejs_api/src_cpp/include/node_util.h @@ -1,10 +1,6 @@ #pragma once -#include -#include -#include - -#include "main/kuzu.h" +#include "common/types/value/value.h" #include using namespace kuzu::common; @@ -12,8 +8,9 @@ using namespace kuzu::common; class Util { public: static Napi::Value ConvertToNapiObject(const Value& value, Napi::Env env); - static std::unordered_map> TransformParametersForExec( - Napi::Array params, std::unordered_map>& parameterMap); + static std::unordered_map> TransformParametersForExec( + Napi::Array params, + const std::unordered_map>& parameterMap); private: static Napi::Object ConvertNodeIdToNapiObject(const nodeID_t& nodeId, Napi::Env env); diff --git a/tools/nodejs_api/src_cpp/node_connection.cpp b/tools/nodejs_api/src_cpp/node_connection.cpp index 7d7b542461b..98c4bbcf46c 100644 --- a/tools/nodejs_api/src_cpp/node_connection.cpp +++ b/tools/nodejs_api/src_cpp/node_connection.cpp @@ -75,10 +75,10 @@ Napi::Value NodeConnection::ExecuteAsync(const Napi::CallbackInfo& info) { auto nodeQueryResult = Napi::ObjectWrap::Unwrap(info[1].As()); auto callback = info[3].As(); try { - auto parameterMap = nodePreparedStatement->preparedStatement->getParameterMap(); + const auto& parameterMap = nodePreparedStatement->preparedStatement->getParameterMap(); auto params = Util::TransformParametersForExec(info[2].As(), parameterMap); auto asyncWorker = new ConnectionExecuteAsyncWorker(callback, connection, - nodePreparedStatement->preparedStatement, nodeQueryResult, params); + nodePreparedStatement->preparedStatement, nodeQueryResult, std::move(params)); asyncWorker->Queue(); } catch (const std::exception& exc) { Napi::Error::New(env, std::string(exc.what())).ThrowAsJavaScriptException(); diff --git a/tools/nodejs_api/src_cpp/node_util.cpp b/tools/nodejs_api/src_cpp/node_util.cpp index 80062b5103d..d7f3523c934 100644 --- a/tools/nodejs_api/src_cpp/node_util.cpp +++ b/tools/nodejs_api/src_cpp/node_util.cpp @@ -176,9 +176,10 @@ Napi::Value Util::ConvertToNapiObject(const Value& value, Napi::Env env) { return Napi::Value(); } -std::unordered_map> Util::TransformParametersForExec( - Napi::Array params, std::unordered_map>& parameterMap) { - std::unordered_map> result; +std::unordered_map> Util::TransformParametersForExec( + Napi::Array params, + const std::unordered_map>& parameterMap) { + std::unordered_map> result; for (size_t i = 0; i < params.Length(); i++) { auto param = params.Get(i).As(); KU_ASSERT(param.Length() == 2); @@ -187,11 +188,10 @@ std::unordered_map> Util::TransformParameter if (!parameterMap.count(key)) { throw Exception("Parameter " + key + " is not defined in the prepared statement"); } - auto paramValue = parameterMap[key]; auto napiValue = param.Get(uint32_t(1)); - auto expectedDataType = paramValue->getDataType(); + auto expectedDataType = parameterMap.at(key)->getDataType(); auto transformedVal = TransformNapiValue(napiValue, expectedDataType, key); - result[key] = std::make_shared(transformedVal); + result[key] = std::make_unique(transformedVal); } return result; } @@ -280,8 +280,7 @@ Value Util::TransformNapiValue( return Value(normalizedInterval); } default: - throw Exception("Unsupported type " + - expectedDataType->toString() + - " for parameter: " + key); + throw Exception( + "Unsupported type " + expectedDataType->toString() + " for parameter: " + key); } } diff --git a/tools/python_api/src_cpp/include/py_connection.h b/tools/python_api/src_cpp/include/py_connection.h index 526725859f5..4c8f43d5acc 100644 --- a/tools/python_api/src_cpp/include/py_connection.h +++ b/tools/python_api/src_cpp/include/py_connection.h @@ -32,12 +32,6 @@ class PyConnection { static bool isPandasDataframe(const py::object& object); -private: - std::unordered_map> transformPythonParameters( - py::dict params); - - kuzu::common::Value transformPythonValue(py::handle val); - private: std::unique_ptr storageDriver; std::unique_ptr conn; diff --git a/tools/python_api/src_cpp/py_connection.cpp b/tools/python_api/src_cpp/py_connection.cpp index 228ff0b743c..3402b6ee25f 100644 --- a/tools/python_api/src_cpp/py_connection.cpp +++ b/tools/python_api/src_cpp/py_connection.cpp @@ -1,10 +1,8 @@ #include "include/py_connection.h" -#include "binder/bound_statement_result.h" #include "common/string_format.h" #include "datetime.h" // from Python #include "main/connection.h" -#include "planner/operator/logical_plan.h" #include "pandas/pandas_scan.h" #include "processor/result/factorized_table.h" @@ -30,7 +28,7 @@ void PyConnection::initialize(py::handle& m) { PyConnection::PyConnection(PyDatabase* pyDatabase, uint64_t numThreads) { storageDriver = std::make_unique(pyDatabase->database.get()); conn = std::make_unique(pyDatabase->database.get()); - //TODO(Xiyang): We should implement a generic replacement framework in binder. + // TODO(Xiyang): We should implement a generic replacement framework in binder. conn->setReplaceFunc(kuzu::replacePD); if (numThreads > 0) { conn->setMaxNumThreadForExec(numThreads); @@ -41,12 +39,15 @@ void PyConnection::setQueryTimeout(uint64_t timeoutInMS) { conn->setQueryTimeOut(timeoutInMS); } -std::unique_ptr PyConnection::execute(PyPreparedStatement* preparedStatement, - py::dict params) { +static std::unordered_map> transformPythonParameters( + py::dict params); + +std::unique_ptr PyConnection::execute( + PyPreparedStatement* preparedStatement, py::dict params) { auto parameters = transformPythonParameters(params); py::gil_scoped_release release; auto queryResult = - conn->executeWithParams(preparedStatement->preparedStatement.get(), parameters); + conn->executeWithParams(preparedStatement->preparedStatement.get(), std::move(parameters)); py::gil_scoped_acquire acquire; if (!queryResult->isSuccess()) { throw std::runtime_error(queryResult->getErrorMessage()); @@ -103,10 +104,10 @@ void PyConnection::getAllEdgesForTorchGeometric(py::array_t& npArray, int64_t start = batch * queryBatchSize; int64_t end = (batch + 1) * queryBatchSize; end = end > numDstNodes ? numDstNodes : end; - std::unordered_map> parameters; - parameters["s"] = std::make_shared(start); - parameters["e"] = std::make_shared(end); - auto result = conn->executeWithParams(preparedStatement.get(), parameters); + std::unordered_map> parameters; + parameters["s"] = std::make_unique(start); + parameters["e"] = std::make_unique(end); + auto result = conn->executeWithParams(preparedStatement.get(), std::move(parameters)); if (!result->isSuccess()) { throw std::runtime_error(result->getErrorMessage()); } @@ -151,22 +152,23 @@ bool PyConnection::isPandasDataframe(const py::object& object) { return py::isinstance(object, pandas.attr("DataFrame")); } -std::unordered_map> PyConnection::transformPythonParameters( - py::dict params) { - std::unordered_map> result; +static Value transformPythonValue(py::handle val); + +std::unordered_map> transformPythonParameters(py::dict params) { + std::unordered_map> result; for (auto& [key, value] : params) { if (!py::isinstance(key)) { throw std::runtime_error("Parameter name must be of type string but get " + py::str(key.get_type()).cast()); } auto name = key.cast(); - auto val = std::make_shared(transformPythonValue(value)); - result.insert({name, val}); + auto val = std::make_unique(transformPythonValue(value)); + result.insert({name, std::move(val)}); } return result; } -Value PyConnection::transformPythonValue(py::handle val) { +Value transformPythonValue(py::handle val) { auto datetime_mod = py::module::import("datetime"); auto datetime_datetime = datetime_mod.attr("datetime"); auto time_delta = datetime_mod.attr("timedelta"); diff --git a/tools/rust_api/include/kuzu_rs.h b/tools/rust_api/include/kuzu_rs.h index e41d8e492f9..3080deafda2 100644 --- a/tools/rust_api/include/kuzu_rs.h +++ b/tools/rust_api/include/kuzu_rs.h @@ -27,7 +27,7 @@ struct TypeListBuilder { std::unique_ptr create_type_list(); struct QueryParams { - std::unordered_map> inputParams; + std::unordered_map> inputParams; void insert(const rust::Str key, std::unique_ptr value) { inputParams.insert(std::make_pair(key, std::move(value))); diff --git a/tools/rust_api/src/kuzu_rs.cpp b/tools/rust_api/src/kuzu_rs.cpp index 2f63b909884..48dbb7a63a3 100644 --- a/tools/rust_api/src/kuzu_rs.cpp +++ b/tools/rust_api/src/kuzu_rs.cpp @@ -101,7 +101,7 @@ std::unique_ptr database_connect(kuzu::main::Database& d std::unique_ptr connection_execute(kuzu::main::Connection& connection, kuzu::main::PreparedStatement& query, std::unique_ptr params) { - return connection.executeWithParams(&query, params->inputParams); + return connection.executeWithParams(&query, std::move(params->inputParams)); } rust::String prepared_statement_error_message(const kuzu::main::PreparedStatement& statement) {