Skip to content

Commit

Permalink
refactor: return record output in run (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
BugenZhao authored Jun 25, 2024
1 parent 17d81db commit 2441dff
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 28 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

* runner: `RecordOutput` is now returned by `Runner::run` (or `Runner::run_async`). This allows users to access the output of each record, or check whether the record is skipped.

## [0.20.6] - 2024-06-21

* runner: add logs for `system` command (with target `sqllogictest::system_command`) for ease of debugging.
Expand Down
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ resolver = "2"
members = ["sqllogictest", "sqllogictest-bin", "sqllogictest-engines", "tests"]

[workspace.package]
version = "0.20.6"
version = "0.21.0"
edition = "2021"
homepage = "https://github.com/risinglightdb/sqllogictest-rs"
keywords = ["sql", "database", "parser", "cli"]
Expand Down
4 changes: 2 additions & 2 deletions sqllogictest-bin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ glob = "0.3"
itertools = "0.13"
quick-junit = { version = "0.4" }
rand = "0.8"
sqllogictest = { path = "../sqllogictest", version = "0.20" }
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.20" }
sqllogictest = { path = "../sqllogictest", version = "0.21" }
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.21" }
tokio = { version = "1", features = [
"rt",
"rt-multi-thread",
Expand Down
2 changes: 1 addition & 1 deletion sqllogictest-engines/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ postgres-types = { version = "0.2.5", features = ["derive", "with-chrono-0_4"] }
rust_decimal = { version = "1.30.0", features = ["tokio-pg"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
sqllogictest = { path = "../sqllogictest", version = "0.20" }
sqllogictest = { path = "../sqllogictest", version = "0.21" }
thiserror = "1"
tokio = { version = "1", features = [
"rt",
Expand Down
59 changes: 38 additions & 21 deletions sqllogictest/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,22 @@ use crate::{ColumnType, Connections, MakeConnection};
/// Type-erased error type.
type AnyError = Arc<dyn std::error::Error + Send + Sync>;

/// Output of a record.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RecordOutput<T: ColumnType> {
/// No output. Occurs when the record is skipped or not a `query`, `statement`, or `system`
/// command.
Nothing,
/// The output of a `query`.
Query {
types: Vec<T>,
rows: Vec<Vec<String>>,
error: Option<AnyError>,
},
Statement {
count: u64,
error: Option<AnyError>,
},
/// The output of a `statement`.
Statement { count: u64, error: Option<AnyError> },
/// The output of a `system` command.
#[non_exhaustive]
System {
stdout: Option<String>,
Expand Down Expand Up @@ -833,10 +836,13 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
}

/// Run a single record.
pub async fn run_async(&mut self, record: Record<D::ColumnType>) -> Result<(), TestError> {
pub async fn run_async(
&mut self,
record: Record<D::ColumnType>,
) -> Result<RecordOutput<D::ColumnType>, TestError> {
let result = self.apply_record(record.clone()).await;

match (record, result) {
match (record, &result) {
(_, RecordOutput::Nothing) => {}
// Tolerate the mismatched return type...
(
Expand Down Expand Up @@ -894,7 +900,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
.at(loc))
}
(None, StatementExpect::Count(expected_count)) => {
if expected_count != count {
if expected_count != *count {
return Err(TestErrorKind::StatementResultMismatch {
sql,
expected: expected_count,
Expand All @@ -908,7 +914,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
if !expected_error.is_match(&e.to_string()) {
return Err(TestErrorKind::ErrorMismatch {
sql,
err: Arc::new(e),
err: Arc::clone(e),
expected_err: expected_error.to_string(),
kind: RecordKind::Statement,
}
Expand All @@ -918,7 +924,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
(Some(e), StatementExpect::Count(_) | StatementExpect::Ok) => {
return Err(TestErrorKind::Fail {
sql,
err: Arc::new(e),
err: Arc::clone(e),
kind: RecordKind::Statement,
}
.at(loc));
Expand Down Expand Up @@ -946,7 +952,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
if !expected_error.is_match(&e.to_string()) {
return Err(TestErrorKind::ErrorMismatch {
sql,
err: Arc::new(e),
err: Arc::clone(e),
expected_err: expected_error.to_string(),
kind: RecordKind::Query,
}
Expand All @@ -956,7 +962,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
(Some(e), QueryExpect::Results { .. }) => {
return Err(TestErrorKind::Fail {
sql,
err: Arc::new(e),
err: Arc::clone(e),
kind: RecordKind::Query,
}
.at(loc));
Expand All @@ -969,7 +975,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
..
},
) => {
if !(self.column_type_validator)(&types, &expected_types) {
if !(self.column_type_validator)(types, &expected_types) {
return Err(TestErrorKind::QueryResultColumnsMismatch {
sql,
expected: expected_types.iter().map(|c| c.to_char()).join(""),
Expand All @@ -978,11 +984,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
.at(loc));
}

if !(self.validator)(&rows, &expected_results) {
let output_rows = rows
.into_iter()
.map(|strs| strs.iter().join(" "))
.collect_vec();
if !(self.validator)(rows, &expected_results) {
let output_rows =
rows.iter().map(|strs| strs.iter().join(" ")).collect_vec();
return Err(TestErrorKind::QueryResultMismatch {
sql,
expected: expected_results.join("\n"),
Expand All @@ -1006,12 +1010,16 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
},
) => {
if let Some(err) = error {
return Err(TestErrorKind::SystemFail { command, err }.at(loc));
return Err(TestErrorKind::SystemFail {
command,
err: Arc::clone(err),
}
.at(loc));
}
match (expected_stdout, actual_stdout) {
(None, _) => {}
(Some(expected_stdout), actual_stdout) => {
let actual_stdout = actual_stdout.unwrap_or_default();
let actual_stdout = actual_stdout.clone().unwrap_or_default();
// TODO: support newlines contained in expected_stdout
if expected_stdout != actual_stdout.trim() {
return Err(TestErrorKind::SystemStdoutMismatch {
Expand All @@ -1027,17 +1035,24 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
_ => unreachable!(),
}

Ok(())
Ok(result)
}

/// Run a single record.
pub fn run(&mut self, record: Record<D::ColumnType>) -> Result<(), TestError> {
///
/// Returns the output of the record if successful.
pub fn run(
&mut self,
record: Record<D::ColumnType>,
) -> Result<RecordOutput<D::ColumnType>, TestError> {
futures::executor::block_on(self.run_async(record))
}

/// Run multiple records.
///
/// The runner will stop early once a halt record is seen.
///
/// To acquire the result of each record, manually call `run_async` for each record instead.
pub async fn run_multi_async(
&mut self,
records: impl IntoIterator<Item = Record<D::ColumnType>>,
Expand All @@ -1054,6 +1069,8 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
/// Run multiple records.
///
/// The runner will stop early once a halt record is seen.
///
/// To acquire the result of each record, manually call `run` for each record instead.
pub fn run_multi(
&mut self,
records: impl IntoIterator<Item = Record<D::ColumnType>>,
Expand Down

0 comments on commit 2441dff

Please sign in to comment.