Replace AbortOnDrop / AbortOnDropMany with tokio JoinSet (#6750)
* Use JoinSet in MemTable

* Fix error handling

* Refactor AbortOnDropSingle in csv physical plan

* Fix csv write physical plan error propagation

* Refactor json write physical plan to use JoinSet

* Refactor parquet write physical plan to use JoinSet

* Refactor collect_partitioned to use JoinSet

* Refactor pull_from_input method to make it easier to read

* Fix typo
aprimadi authored Jul 4, 2023
1 parent 07a721f commit 02a470f
Showing 6 changed files with 127 additions and 93 deletions.
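Every rewritten call site below follows the same recipe: stop collecting JoinHandles into a Vec guarded by AbortOnDropSingle and awaited via futures::future::join_all, and instead spawn onto a tokio::task::JoinSet, which aborts any still-running tasks when it is dropped and reports task panics as JoinErrors. A minimal standalone sketch of the new pattern (the names and the toy workload are illustrative, not code from this commit):

use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
    let mut join_set = JoinSet::new();

    // Spawn one task per partition; if this JoinSet is dropped early
    // (e.g. on error), all still-running tasks are aborted, which is
    // what AbortOnDropSingle used to guarantee per task.
    for part_idx in 0..4u64 {
        join_set.spawn(async move { part_idx * 2 });
    }

    // join_next() yields results in completion order, not spawn order.
    let mut results = Vec::new();
    while let Some(result) = join_set.join_next().await {
        match result {
            Ok(value) => results.push(value),
            // A JoinError is either a panic or a cancellation; this commit
            // re-raises panics so they are not silently swallowed.
            Err(e) if e.is_panic() => std::panic::resume_unwind(e.into_panic()),
            Err(_) => unreachable!("no task is ever cancelled here"),
        }
    }
    println!("{results:?}");
}

The key design point is that join_next yields whichever task finishes first, so callers that care about partition order must track it themselves, as collect_partitioned does below.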
39 changes: 22 additions & 17 deletions datafusion/core/src/datasource/memory.rs
@@ -29,12 +29,12 @@ use async_trait::async_trait;
 use datafusion_common::SchemaExt;
 use datafusion_execution::TaskContext;
 use tokio::sync::RwLock;
+use tokio::task::JoinSet;

 use crate::datasource::{TableProvider, TableType};
 use crate::error::{DataFusionError, Result};
 use crate::execution::context::SessionState;
 use crate::logical_expr::Expr;
-use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::insert::{DataSink, InsertExec};
 use crate::physical_plan::memory::MemoryExec;
 use crate::physical_plan::{common, SendableRecordBatchStream};
@@ -89,26 +89,31 @@ impl MemTable {
         let exec = t.scan(state, None, &[], None).await?;
         let partition_count = exec.output_partitioning().partition_count();

-        let tasks = (0..partition_count)
-            .map(|part_i| {
-                let task = state.task_ctx();
-                let exec = exec.clone();
-                let task = tokio::spawn(async move {
-                    let stream = exec.execute(part_i, task)?;
-                    common::collect(stream).await
-                });
-
-                AbortOnDropSingle::new(task)
-            })
-            // this collect *is needed* so that the join below can
-            // switch between tasks
-            .collect::<Vec<_>>();
+        let mut join_set = JoinSet::new();
+
+        for part_idx in 0..partition_count {
+            let task = state.task_ctx();
+            let exec = exec.clone();
+            join_set.spawn(async move {
+                let stream = exec.execute(part_idx, task)?;
+                common::collect(stream).await
+            });
+        }

         let mut data: Vec<Vec<RecordBatch>> =
             Vec::with_capacity(exec.output_partitioning().partition_count());

-        for result in futures::future::join_all(tasks).await {
-            data.push(result.map_err(|e| DataFusionError::External(Box::new(e)))??)
+        while let Some(result) = join_set.join_next().await {
+            match result {
+                Ok(res) => data.push(res?),
+                Err(e) => {
+                    if e.is_panic() {
+                        std::panic::resume_unwind(e.into_panic());
+                    } else {
+                        unreachable!();
+                    }
+                }
+            }
         }

         let exec = MemoryExec::try_new(&data, schema.clone(), None)?;
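The Err(e) arm at the end of the MemTable change reappears verbatim in every file below. If the repetition were a concern, it could be hoisted into a helper along these lines (a hypothetical sketch, not part of this commit):

use tokio::task::JoinError;

// Hypothetical helper: unwrap a value joined from a JoinSet, re-raising a
// task panic on the awaiting thread. The cancellation branch is unreachable
// at these call sites because nothing calls abort() and the JoinSet is not
// dropped while it is still being drained.
fn unwrap_joined<T>(result: Result<T, JoinError>) -> T {
    match result {
        Ok(value) => value,
        Err(e) if e.is_panic() => std::panic::resume_unwind(e.into_panic()),
        Err(_) => unreachable!("JoinSet tasks are never cancelled here"),
    }
}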
32 changes: 19 additions & 13 deletions datafusion/core/src/datasource/physical_plan/csv.rs
@@ -23,7 +23,6 @@ use crate::datasource::physical_plan::file_stream::{
 };
 use crate::datasource::physical_plan::FileMeta;
 use crate::error::{DataFusionError, Result};
-use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
 use crate::physical_plan::{
@@ -46,7 +45,7 @@ use std::fs;
 use std::path::Path;
 use std::sync::Arc;
 use std::task::Poll;
-use tokio::task::{self, JoinHandle};
+use tokio::task::JoinSet;

 /// Execution plan for scanning a CSV file
 #[derive(Debug, Clone)]
@@ -331,7 +330,7 @@ pub async fn plan_to_csv(
         )));
     }

-    let mut tasks = vec![];
+    let mut join_set = JoinSet::new();
     for i in 0..plan.output_partitioning().partition_count() {
         let plan = plan.clone();
         let filename = format!("part-{i}.csv");
@@ -340,22 +339,29 @@
         let mut writer = csv::Writer::new(file);
         let stream = plan.execute(i, task_ctx.clone())?;

-        let handle: JoinHandle<Result<()>> = task::spawn(async move {
-            stream
+        join_set.spawn(async move {
+            let result: Result<()> = stream
                 .map(|batch| writer.write(&batch?))
                 .try_collect()
                 .await
-                .map_err(DataFusionError::from)
+                .map_err(DataFusionError::from);
+            result
         });
-        tasks.push(AbortOnDropSingle::new(handle));
     }

-    futures::future::join_all(tasks)
-        .await
-        .into_iter()
-        .try_for_each(|result| {
-            result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
-        })?;
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok(res) => res?, // propagate DataFusion error
+            Err(e) => {
+                if e.is_panic() {
+                    std::panic::resume_unwind(e.into_panic());
+                } else {
+                    unreachable!();
+                }
+            }
+        }
+    }

     Ok(())
 }
32 changes: 19 additions & 13 deletions datafusion/core/src/datasource/physical_plan/json.rs
@@ -22,7 +22,6 @@ use crate::datasource::physical_plan::file_stream::{
 };
 use crate::datasource::physical_plan::FileMeta;
 use crate::error::{DataFusionError, Result};
-use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
 use crate::physical_plan::{
@@ -44,7 +43,7 @@ use std::io::BufReader;
 use std::path::Path;
 use std::sync::Arc;
 use std::task::Poll;
-use tokio::task::{self, JoinHandle};
+use tokio::task::JoinSet;

 use super::FileScanConfig;

@@ -266,30 +265,37 @@ pub async fn plan_to_json(
         )));
     }

-    let mut tasks = vec![];
+    let mut join_set = JoinSet::new();
     for i in 0..plan.output_partitioning().partition_count() {
         let plan = plan.clone();
         let filename = format!("part-{i}.json");
         let path = fs_path.join(filename);
         let file = fs::File::create(path)?;
         let mut writer = json::LineDelimitedWriter::new(file);
         let stream = plan.execute(i, task_ctx.clone())?;
-        let handle: JoinHandle<Result<()>> = task::spawn(async move {
-            stream
+        join_set.spawn(async move {
+            let result: Result<()> = stream
                 .map(|batch| writer.write(&batch?))
                 .try_collect()
                 .await
-                .map_err(DataFusionError::from)
+                .map_err(DataFusionError::from);
+            result
         });
-        tasks.push(AbortOnDropSingle::new(handle));
     }

-    futures::future::join_all(tasks)
-        .await
-        .into_iter()
-        .try_for_each(|result| {
-            result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
-        })?;
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok(res) => res?, // propagate DataFusion error
+            Err(e) => {
+                if e.is_panic() {
+                    std::panic::resume_unwind(e.into_panic());
+                } else {
+                    unreachable!();
+                }
+            }
+        }
+    }

     Ok(())
 }
43 changes: 23 additions & 20 deletions datafusion/core/src/datasource/physical_plan/parquet.rs
@@ -31,7 +31,6 @@ use crate::{
     execution::context::TaskContext,
     physical_optimizer::pruning::PruningPredicate,
     physical_plan::{
-        common::AbortOnDropSingle,
         metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet},
         ordering_equivalence_properties_helper, DisplayFormatType, ExecutionPlan,
         Partitioning, SendableRecordBatchStream, Statistics,
@@ -64,6 +63,7 @@ use parquet::arrow::{ArrowWriter, ParquetRecordBatchStreamBuilder, ProjectionMas
 use parquet::basic::{ConvertedType, LogicalType};
 use parquet::file::{metadata::ParquetMetaData, properties::WriterProperties};
 use parquet::schema::types::ColumnDescriptor;
+use tokio::task::JoinSet;

 mod metrics;
 pub mod page_filter;
@@ -701,7 +701,7 @@ pub async fn plan_to_parquet(
         )));
     }

-    let mut tasks = vec![];
+    let mut join_set = JoinSet::new();
     for i in 0..plan.output_partitioning().partition_count() {
         let plan = plan.clone();
         let filename = format!("part-{i}.parquet");
@@ -710,27 +710,30 @@
         let mut writer =
             ArrowWriter::try_new(file, plan.schema(), writer_properties.clone())?;
         let stream = plan.execute(i, task_ctx.clone())?;
-        let handle: tokio::task::JoinHandle<Result<()>> =
-            tokio::task::spawn(async move {
-                stream
-                    .map(|batch| {
-                        writer.write(&batch?).map_err(DataFusionError::ParquetError)
-                    })
-                    .try_collect()
-                    .await
-                    .map_err(DataFusionError::from)?;
+        join_set.spawn(async move {
+            stream
+                .map(|batch| writer.write(&batch?).map_err(DataFusionError::ParquetError))
+                .try_collect()
+                .await
+                .map_err(DataFusionError::from)?;

-                writer.close().map_err(DataFusionError::from).map(|_| ())
-            });
-        tasks.push(AbortOnDropSingle::new(handle));
+            writer.close().map_err(DataFusionError::from).map(|_| ())
+        });
     }

-    futures::future::join_all(tasks)
-        .await
-        .into_iter()
-        .try_for_each(|result| {
-            result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
-        })?;
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok(res) => res?,
+            Err(e) => {
+                if e.is_panic() {
+                    std::panic::resume_unwind(e.into_panic());
+                } else {
+                    unreachable!();
+                }
+            }
+        }
+    }
+
     Ok(())
 }
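One detail worth noting in the parquet task: unlike the CSV and JSON writers, ArrowWriter buffers row groups and only writes the parquet footer in close(), so the task keeps the writer.close() call after the stream is drained; dropping it would leave files no reader can open. A minimal standalone sketch of that write-then-close shape (illustrative example, not commit code):

use std::fs::File;
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use parquet::arrow::ArrowWriter;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
    )?;

    let file = File::create("part-0.parquet")?;
    // None = default WriterProperties; plan_to_parquet threads the caller's
    // properties through instead.
    let mut writer = ArrowWriter::try_new(file, schema, None)?;
    writer.write(&batch)?;
    // close() flushes buffered row groups and writes the parquet footer;
    // without it the file is truncated and unreadable.
    writer.close()?;
    Ok(())
}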
41 changes: 29 additions & 12 deletions datafusion/core/src/physical_plan/mod.rs
@@ -38,6 +38,7 @@ pub use display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay};
 use futures::stream::{Stream, TryStreamExt};
 use std::fmt;
 use std::fmt::Debug;
+use tokio::task::JoinSet;

 use datafusion_common::tree_node::Transformed;
 use datafusion_common::DataFusionError;
@@ -445,20 +446,37 @@ pub async fn collect_partitioned(
 ) -> Result<Vec<Vec<RecordBatch>>> {
     let streams = execute_stream_partitioned(plan, context)?;

+    let mut join_set = JoinSet::new();
     // Execute the plan and collect the results into batches.
-    let handles = streams
-        .into_iter()
-        .enumerate()
-        .map(|(idx, stream)| async move {
-            let handle = tokio::task::spawn(stream.try_collect());
-            AbortOnDropSingle::new(handle).await.map_err(|e| {
-                DataFusionError::Execution(format!(
-                    "collect_partitioned partition {idx} panicked: {e}"
-                ))
-            })?
+    streams.into_iter().enumerate().for_each(|(idx, stream)| {
+        join_set.spawn(async move {
+            let result: Result<Vec<RecordBatch>> = stream.try_collect().await;
+            (idx, result)
         });
+    });

+    let mut batches = vec![];
+    // Note that currently this doesn't identify the thread that panicked
+    //
+    // TODO: Replace with [join_next_with_id](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.join_next_with_id)
+    // once it is stable
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok((idx, res)) => batches.push((idx, res?)),
+            Err(e) => {
+                if e.is_panic() {
+                    std::panic::resume_unwind(e.into_panic());
+                } else {
+                    unreachable!();
+                }
+            }
+        }
+    }
+
+    batches.sort_by_key(|(idx, _)| *idx);
+    let batches = batches.into_iter().map(|(_, batch)| batch).collect();

-    futures::future::try_join_all(handles).await
+    Ok(batches)
 }
@@ -713,7 +731,6 @@ pub mod unnest;
 pub mod values;
 pub mod windows;

-use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::repartition::RepartitionExec;
 use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
 use datafusion_execution::TaskContext;
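The idx tagging and final sort in collect_partitioned exist because join_next() returns results in completion order, and the tokio release pinned at the time exposed join_next_with_id only as an unstable API (hence the TODO above). On a tokio version where task ids are stable, the tagging could be dropped in favor of something like this (a sketch under that assumption, not the committed code):

use std::collections::HashMap;

use tokio::task::{Id, JoinSet};

#[tokio::main]
async fn main() {
    let mut join_set = JoinSet::new();

    // Record which task id corresponds to which partition index.
    let mut partition_of: HashMap<Id, usize> = HashMap::new();
    for idx in 0..4usize {
        let abort_handle = join_set.spawn(async move { idx * 10 });
        partition_of.insert(abort_handle.id(), idx);
    }

    let mut results = vec![None; 4];
    while let Some(joined) = join_set.join_next_with_id().await {
        match joined {
            // The task id identifies the finished partition, so the future
            // itself no longer needs to carry idx in its return value.
            Ok((id, value)) => results[partition_of[&id]] = Some(value),
            Err(e) if e.is_panic() => std::panic::resume_unwind(e.into_panic()),
            Err(_) => unreachable!(),
        }
    }
    println!("{results:?}");
}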