From 2441dffbc0f4982ff9b7eeed4dcbfb9ae7db1d05 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 25 Jun 2024 17:46:34 +0800 Subject: [PATCH] refactor: return record output in run (#219) --- CHANGELOG.md | 2 ++ Cargo.lock | 6 ++-- Cargo.toml | 2 +- sqllogictest-bin/Cargo.toml | 4 +-- sqllogictest-engines/Cargo.toml | 2 +- sqllogictest/src/runner.rs | 59 +++++++++++++++++++++------------ 6 files changed, 47 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d117f72..9d0243d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +* runner: `RecordOutput` is now returned by `Runner::run` (or `Runner::run_async`). This allows users to access the output of each record, or check whether the record is skipped. + ## [0.20.6] - 2024-06-21 * runner: add logs for `system` command (with target `sqllogictest::system_command`) for ease of debugging. diff --git a/Cargo.lock b/Cargo.lock index e9dc9e5..4d4092d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1405,7 +1405,7 @@ dependencies = [ [[package]] name = "sqllogictest" -version = "0.20.6" +version = "0.21.0" dependencies = [ "async-trait", "educe", @@ -1428,7 +1428,7 @@ dependencies = [ [[package]] name = "sqllogictest-bin" -version = "0.20.6" +version = "0.21.0" dependencies = [ "anyhow", "async-trait", @@ -1450,7 +1450,7 @@ dependencies = [ [[package]] name = "sqllogictest-engines" -version = "0.20.6" +version = "0.21.0" dependencies = [ "async-trait", "bytes", diff --git a/Cargo.toml b/Cargo.toml index b85f72e..3812a57 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ resolver = "2" members = ["sqllogictest", "sqllogictest-bin", "sqllogictest-engines", "tests"] [workspace.package] -version = "0.20.6" +version = "0.21.0" edition = "2021" homepage = "https://github.com/risinglightdb/sqllogictest-rs" keywords = ["sql", "database", "parser", "cli"] diff --git a/sqllogictest-bin/Cargo.toml b/sqllogictest-bin/Cargo.toml index f0d7402..783e29f 100644 --- a/sqllogictest-bin/Cargo.toml +++ b/sqllogictest-bin/Cargo.toml @@ -23,8 +23,8 @@ glob = "0.3" itertools = "0.13" quick-junit = { version = "0.4" } rand = "0.8" -sqllogictest = { path = "../sqllogictest", version = "0.20" } -sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.20" } +sqllogictest = { path = "../sqllogictest", version = "0.21" } +sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.21" } tokio = { version = "1", features = [ "rt", "rt-multi-thread", diff --git a/sqllogictest-engines/Cargo.toml b/sqllogictest-engines/Cargo.toml index 670e0e9..ad5ffb4 100644 --- a/sqllogictest-engines/Cargo.toml +++ b/sqllogictest-engines/Cargo.toml @@ -19,7 +19,7 @@ postgres-types = { version = "0.2.5", features = ["derive", "with-chrono-0_4"] } rust_decimal = { version = "1.30.0", features = ["tokio-pg"] } serde = { version = "1", features = ["derive"] } serde_json = "1" -sqllogictest = { path = "../sqllogictest", version = "0.20" } +sqllogictest = { path = "../sqllogictest", version = "0.21" } thiserror = "1" tokio = { version = "1", features = [ "rt", diff --git a/sqllogictest/src/runner.rs b/sqllogictest/src/runner.rs index 87a4928..d455f4c 100644 --- a/sqllogictest/src/runner.rs +++ b/sqllogictest/src/runner.rs @@ -23,19 +23,22 @@ use crate::{ColumnType, Connections, MakeConnection}; /// Type-erased error type. type AnyError = Arc; +/// Output of a record. #[derive(Debug, Clone)] #[non_exhaustive] pub enum RecordOutput { + /// No output. Occurs when the record is skipped or not a `query`, `statement`, or `system` + /// command. Nothing, + /// The output of a `query`. Query { types: Vec, rows: Vec>, error: Option, }, - Statement { - count: u64, - error: Option, - }, + /// The output of a `statement`. + Statement { count: u64, error: Option }, + /// The output of a `system` command. #[non_exhaustive] System { stdout: Option, @@ -833,10 +836,13 @@ impl> Runner { } /// Run a single record. - pub async fn run_async(&mut self, record: Record) -> Result<(), TestError> { + pub async fn run_async( + &mut self, + record: Record, + ) -> Result, TestError> { let result = self.apply_record(record.clone()).await; - match (record, result) { + match (record, &result) { (_, RecordOutput::Nothing) => {} // Tolerate the mismatched return type... ( @@ -894,7 +900,7 @@ impl> Runner { .at(loc)) } (None, StatementExpect::Count(expected_count)) => { - if expected_count != count { + if expected_count != *count { return Err(TestErrorKind::StatementResultMismatch { sql, expected: expected_count, @@ -908,7 +914,7 @@ impl> Runner { if !expected_error.is_match(&e.to_string()) { return Err(TestErrorKind::ErrorMismatch { sql, - err: Arc::new(e), + err: Arc::clone(e), expected_err: expected_error.to_string(), kind: RecordKind::Statement, } @@ -918,7 +924,7 @@ impl> Runner { (Some(e), StatementExpect::Count(_) | StatementExpect::Ok) => { return Err(TestErrorKind::Fail { sql, - err: Arc::new(e), + err: Arc::clone(e), kind: RecordKind::Statement, } .at(loc)); @@ -946,7 +952,7 @@ impl> Runner { if !expected_error.is_match(&e.to_string()) { return Err(TestErrorKind::ErrorMismatch { sql, - err: Arc::new(e), + err: Arc::clone(e), expected_err: expected_error.to_string(), kind: RecordKind::Query, } @@ -956,7 +962,7 @@ impl> Runner { (Some(e), QueryExpect::Results { .. }) => { return Err(TestErrorKind::Fail { sql, - err: Arc::new(e), + err: Arc::clone(e), kind: RecordKind::Query, } .at(loc)); @@ -969,7 +975,7 @@ impl> Runner { .. }, ) => { - if !(self.column_type_validator)(&types, &expected_types) { + if !(self.column_type_validator)(types, &expected_types) { return Err(TestErrorKind::QueryResultColumnsMismatch { sql, expected: expected_types.iter().map(|c| c.to_char()).join(""), @@ -978,11 +984,9 @@ impl> Runner { .at(loc)); } - if !(self.validator)(&rows, &expected_results) { - let output_rows = rows - .into_iter() - .map(|strs| strs.iter().join(" ")) - .collect_vec(); + if !(self.validator)(rows, &expected_results) { + let output_rows = + rows.iter().map(|strs| strs.iter().join(" ")).collect_vec(); return Err(TestErrorKind::QueryResultMismatch { sql, expected: expected_results.join("\n"), @@ -1006,12 +1010,16 @@ impl> Runner { }, ) => { if let Some(err) = error { - return Err(TestErrorKind::SystemFail { command, err }.at(loc)); + return Err(TestErrorKind::SystemFail { + command, + err: Arc::clone(err), + } + .at(loc)); } match (expected_stdout, actual_stdout) { (None, _) => {} (Some(expected_stdout), actual_stdout) => { - let actual_stdout = actual_stdout.unwrap_or_default(); + let actual_stdout = actual_stdout.clone().unwrap_or_default(); // TODO: support newlines contained in expected_stdout if expected_stdout != actual_stdout.trim() { return Err(TestErrorKind::SystemStdoutMismatch { @@ -1027,17 +1035,24 @@ impl> Runner { _ => unreachable!(), } - Ok(()) + Ok(result) } /// Run a single record. - pub fn run(&mut self, record: Record) -> Result<(), TestError> { + /// + /// Returns the output of the record if successful. + pub fn run( + &mut self, + record: Record, + ) -> Result, TestError> { futures::executor::block_on(self.run_async(record)) } /// Run multiple records. /// /// The runner will stop early once a halt record is seen. + /// + /// To acquire the result of each record, manually call `run_async` for each record instead. pub async fn run_multi_async( &mut self, records: impl IntoIterator>, @@ -1054,6 +1069,8 @@ impl> Runner { /// Run multiple records. /// /// The runner will stop early once a halt record is seen. + /// + /// To acquire the result of each record, manually call `run` for each record instead. pub fn run_multi( &mut self, records: impl IntoIterator>,