diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 134cde8976d67..5a2aaa3e50fd5 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -88,6 +88,12 @@ version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +[[package]] +name = "anyhow" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" + [[package]] name = "apache-avro" version = "0.16.0" @@ -244,6 +250,34 @@ dependencies = [ "num", ] +[[package]] +name = "arrow-flight" +version = "52.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e7ffbc96072e466ae5188974725bb46757587eafe427f77a25b828c375ae882" +dependencies = [ + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "arrow-string", + "base64 0.22.1", + "bytes", + "futures", + "once_cell", + "paste", + "prost", + "prost-types", + "tokio", + "tonic", +] + [[package]] name = "arrow-ipc" version = "52.2.0" @@ -378,6 +412,28 @@ dependencies = [ "zstd-safe 7.2.1", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + [[package]] name = "async-trait" version = "0.1.81" @@ -711,6 +767,51 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" +dependencies = [ + "async-trait", + "axum-core", + "bitflags 1.3.2", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.30", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper 0.1.2", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.73" @@ -1134,6 +1235,7 @@ dependencies = [ "apache-avro", "arrow", "arrow-array", + "arrow-flight", "arrow-ipc", "arrow-schema", "async-compression", @@ -1177,6 +1279,7 @@ dependencies = [ "tempfile", "tokio", "tokio-util", + "tonic", "url", "uuid", "xz2", @@ -2092,6 +2195,18 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +dependencies = [ + "hyper 0.14.30", + "pin-project-lite", + "tokio", + "tokio-io-timeout", +] + [[package]] name = "hyper-util" version = "0.1.7" @@ -2399,6 +2514,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "md-5" version = "0.10.6" @@ -2884,6 +3005,38 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prost" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" +dependencies = [ + "anyhow", + "itertools 0.12.1", + "proc-macro2", + "quote", + "syn 2.0.72", +] + +[[package]] +name = "prost-types" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9091c90b0a32608e984ff2fa4091273cbdd755d54935c51d520887f4a1dbd5b0" +dependencies = [ + "prost", +] + [[package]] name = "quad-rand" version = "0.2.1" @@ -3084,7 +3237,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 1.0.1", "tokio", "tokio-rustls 0.26.0", "tokio-util", @@ -3627,6 +3780,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + [[package]] name = "sync_wrapper" version = "1.0.1" @@ -3770,6 +3929,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "tokio-io-timeout" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf" +dependencies = [ + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-macros" version = "2.4.0" @@ -3827,6 +3996,33 @@ dependencies = [ "tokio", ] +[[package]] +name = "tonic" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76c4eb7a4e9ef9d4763600161f12f5070b92a578e1b634db88a6887844c91a13" +dependencies = [ + "async-stream", + "async-trait", + "axum", + "base64 0.21.7", + "bytes", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.30", + "hyper-timeout", + "percent-encoding", + "pin-project", + "prost", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower" version = "0.4.13" @@ -3835,9 +4031,13 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", + "indexmap 1.9.3", "pin-project", "pin-project-lite", + "rand", + "slab", "tokio", + "tokio-util", "tower-layer", "tower-service", "tracing", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index cbd9ffd0febab..d078868506d1c 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -44,6 +44,7 @@ datafusion = { path = "../datafusion/core", version = "41.0.0", features = [ "regex_expressions", "unicode_expressions", "compression", + "flight", ] } dirs = "4.0.0" env_logger = "0.9" diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index e678c93ede8be..c47e8441d9694 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -59,10 +59,12 @@ default = [ "unicode_expressions", "compression", "parquet", + "flight", ] encoding_expressions = ["datafusion-functions/encoding_expressions"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = ["datafusion-physical-plan/force_hash_collisions", "datafusion-common/force_hash_collisions"] +flight = ["dep:arrow-flight", "dep:tonic"] math_expressions = ["datafusion-functions/math_expressions"] parquet = ["datafusion-common/parquet", "dep:parquet"] pyarrow = ["datafusion-common/pyarrow", "parquet"] @@ -83,6 +85,7 @@ ahash = { workspace = true } apache-avro = { version = "0.16", optional = true } arrow = { workspace = true } arrow-array = { workspace = true } +arrow-flight = { workspace = true, optional = true } arrow-ipc = { workspace = true } arrow-schema = { workspace = true } async-compression = { version = "0.4.0", features = [ @@ -133,6 +136,7 @@ sqlparser = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true } tokio-util = { version = "0.7.4", features = ["io"], optional = true } +tonic = { version = "0.11", optional = true } url = { workspace = true } uuid = { version = "1.7", features = ["v4"] } xz2 = { version = "0.1", optional = true, features = ["static"] } @@ -151,6 +155,7 @@ half = { workspace = true, default-features = true } paste = "^1.0" postgres-protocol = "0.6.4" postgres-types = { version = "0.2.4", features = ["derive", "with-chrono-0_4"] } +prost = "0.12" rand = { workspace = true, features = ["small_rng"] } rand_distr = "0.4.3" regex = { workspace = true } @@ -161,6 +166,7 @@ test-utils = { path = "../../test-utils" } thiserror = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } tokio-postgres = "0.7.7" +tokio-stream = { version = "0.1.15", features = ["net"] } [target.'cfg(not(target_os = "windows"))'.dev-dependencies] nix = { version = "0.29.0", features = ["fs"] } diff --git a/datafusion/core/src/datasource/flight/mod.rs b/datafusion/core/src/datasource/flight/mod.rs new file mode 100644 index 0000000000000..eb95c2200a68d --- /dev/null +++ b/datafusion/core/src/datasource/flight/mod.rs @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Generic [FlightTableFactory] that can connect to Arrow Flight services, +//! with a [sql::FlightSqlDriver] provided out-of-the-box. + +use std::any::Any; +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::Arc; + +use arrow_flight::error::FlightError; +use arrow_flight::FlightInfo; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use tonic::transport::Channel; + +use datafusion_catalog::{Session, TableProvider, TableProviderFactory}; +use datafusion_common::DataFusionError; +use datafusion_expr::TableType::View; +use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_expr::Partitioning::UnknownPartitioning; +use datafusion_physical_plan::{ExecutionMode, ExecutionPlan, PlanProperties}; + +use crate::datasource::physical_plan::FlightExec; + +pub mod sql; + +/// Generic Arrow Flight data source. Requires a [FlightDriver] that allows implementors +/// to integrate any custom Flight RPC service by producing a [FlightMetadata] for some DDL. +/// +/// # Sample usage: +/// ``` +/// use std::collections::HashMap; +/// use arrow_flight::{FlightClient, FlightDescriptor}; +/// use tonic::transport::Channel; +/// use datafusion::datasource::flight::{FlightMetadata, FlightDriver}; +/// use datafusion::prelude::SessionContext; +/// use std::sync::Arc; +/// use datafusion::datasource::flight::FlightTableFactory; +/// +/// #[derive(Debug, Clone, Default)] +/// struct CustomFlightDriver {} +/// #[async_trait::async_trait] +/// impl FlightDriver for CustomFlightDriver { +/// async fn metadata(&self, channel: Channel, opts: &HashMap) +/// -> arrow_flight::error::Result { +/// let mut client = FlightClient::new(channel); +/// let descriptor = FlightDescriptor::new_cmd(opts["custom.flight.command"].clone()); +/// let flight_info = client.get_flight_info(descriptor).await?; +/// FlightMetadata::try_from(flight_info) +/// } +/// } +/// +/// #[tokio::main] +/// async fn main() -> datafusion_common::Result<()> { +/// let ctx = SessionContext::new(); +/// ctx.state_ref().write().table_factories_mut() +/// .insert("CUSTOM_FLIGHT".into(), Arc::new(FlightTableFactory::new( +/// Arc::new(CustomFlightDriver::default()) +/// ))); +/// let _ = ctx.sql(r#" +/// CREATE EXTERNAL TABLE custom_flight_table STORED AS CUSTOM_FLIGHT +/// LOCATION 'https://custom.flight.rpc' +/// OPTIONS ('custom.flight.command' 'select * from everywhere') +/// "#).await; // will fail as it can't connect to the bogus URL, but we ignore the error +/// Ok(()) +/// } +/// +/// ``` +#[derive(Clone, Debug)] +pub struct FlightTableFactory { + driver: Arc, +} + +impl FlightTableFactory { + /// Create a data source using the provided driver + pub fn new(driver: Arc) -> Self { + Self { driver } + } +} + +#[async_trait] +impl TableProviderFactory for FlightTableFactory { + async fn create( + &self, + _state: &dyn Session, + cmd: &CreateExternalTable, + ) -> datafusion_common::Result> { + let channel = Channel::from_shared(cmd.location.clone()) + .unwrap() + .connect() + .await + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let metadata = self + .driver + .metadata(channel, &cmd.options) + .await + .map_err(|e| DataFusionError::External(Box::new(e)))?; + Ok(Arc::new(FlightTable { + metadata, + origin: cmd.location.clone(), + })) + } +} + +/// Extension point for integrating any Flight RPC service as a [FlightTableFactory]. +/// Handles the initial `GetFlightInfo` call and all its prerequisites (such as `Handshake`), +/// to produce a [FlightMetadata]. +#[async_trait] +pub trait FlightDriver: Sync + Send + Debug { + /// Returns a [FlightMetadata] from the specified channel, + /// according to the provided table options. + /// The driver must provide at least a [FlightInfo] in order to construct a flight metadata. + async fn metadata( + &self, + channel: Channel, + opts: &HashMap, + ) -> arrow_flight::error::Result; +} + +/// The information required for registering flights as DataFusion tables. +#[derive(Clone, Debug)] +pub struct FlightMetadata { + info: Arc, + props: Arc, + schema: SchemaRef, +} + +impl FlightMetadata { + /// Wrap the provided [FlightInfo], [PlanProperties] and [SchemaRef]. + /// Provides full control over both plan properties and schema. + pub fn new(info: FlightInfo, props: PlanProperties, schema: SchemaRef) -> Self { + let info = Arc::new(info); + let props = Arc::new(props); + Self { + info, + props, + schema, + } + } + + /// Provide custom [PlanProperties] to account for service specifics, + /// such as known partitioning scheme, unbounded execution mode etc. + /// Infers the schema from the [FlightInfo]. + /// If the schema was already decoded for building the [PlanProperties], + /// the [FlightMetadata::new()] constructor should be used instead to + /// pass a reference to it instead of decoding it twice. + pub fn try_new( + info: FlightInfo, + props: PlanProperties, + ) -> arrow_flight::error::Result { + Ok(Self::new( + info.clone(), + props, + Arc::new(info.try_decode_schema()?), + )) + } +} + +/// Uses the default [PlanProperties] and infers the schema from the FlightInfo response. +impl TryFrom for FlightMetadata { + type Error = FlightError; + + fn try_from(info: FlightInfo) -> Result { + let schema = Arc::new(info.clone().try_decode_schema()?); + let partitions = info.endpoint.len(); + let info = Arc::new(info); + let props = Arc::new(PlanProperties::new( + EquivalenceProperties::new(schema.clone()), + UnknownPartitioning(partitions), + ExecutionMode::Bounded, + )); + Ok(Self { + info, + props, + schema, + }) + } +} + +/// Table provider that wraps a specific flight from an Arrow Flight service +struct FlightTable { + metadata: FlightMetadata, + origin: String, +} + +#[async_trait] +impl TableProvider for FlightTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.metadata.schema.clone() + } + + fn table_type(&self) -> TableType { + View + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> datafusion_common::Result> { + Ok(Arc::new(FlightExec::new( + self.metadata.info.clone(), + self.metadata.props.clone(), + self.origin.clone(), + ))) + } +} diff --git a/datafusion/core/src/datasource/flight/sql.rs b/datafusion/core/src/datasource/flight/sql.rs new file mode 100644 index 0000000000000..ab5aa5f843259 --- /dev/null +++ b/datafusion/core/src/datasource/flight/sql.rs @@ -0,0 +1,289 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Default [FlightDriver] for Flight SQL + +use std::collections::HashMap; + +use arrow_flight::error::Result; +use arrow_flight::sql::client::FlightSqlServiceClient; +use async_trait::async_trait; +use tonic::transport::Channel; + +use crate::datasource::flight::{FlightDriver, FlightMetadata}; + +/// Default Flight SQL driver. Requires a `flight.sql.query` to be passed as a table option. +/// If `flight.sql.username` (and optionally `flight.sql.password`) are passed, +/// will perform the `Handshake` using basic authentication. +/// Any additional headers can be passed as table options using the `flight.sql.header.` prefix. +/// +/// A [crate::datasource::flight::FlightTableFactory] using this driver is registered +/// with the default `SessionContext` under the name `FLIGHT_SQL`. +#[derive(Clone, Debug, Default)] +pub struct FlightSqlDriver {} + +#[async_trait] +impl FlightDriver for FlightSqlDriver { + async fn metadata( + &self, + channel: Channel, + opts: &HashMap, + ) -> Result { + let mut client = FlightSqlServiceClient::new(channel); + let headers = opts.iter().filter_map(|(key, value)| { + key.strip_prefix("flight.sql.header.") + .map(|header_name| (header_name, value)) + }); + for header in headers { + client.set_header(header.0, header.1) + } + if let Some(username) = opts.get("flight.sql.username") { + let default_password = "".to_string(); + let password = opts.get("flight.sql.password").unwrap_or(&default_password); + _ = client.handshake(username, password).await?; + } + let info = client + .execute(opts["flight.sql.query"].clone(), None) + .await?; + info.try_into() + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::net::SocketAddr; + use std::pin::Pin; + use std::sync::Arc; + use std::time::Duration; + + use crate::prelude::SessionContext; + use arrow_array::{Int8Array, RecordBatch}; + use arrow_flight::encode::FlightDataEncoderBuilder; + use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; + use arrow_flight::sql::server::FlightSqlService; + use arrow_flight::sql::{ + CommandStatementQuery, ProstMessageExt, SqlInfo, TicketStatementQuery, + }; + use arrow_flight::{ + FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, + HandshakeResponse, Ticket, + }; + use arrow_schema::{DataType, Field, Schema}; + use async_trait::async_trait; + use futures::{stream, Stream, TryStreamExt}; + use prost::Message; + use tokio::net::TcpListener; + use tokio::sync::oneshot::{channel, Receiver, Sender}; + use tokio_stream::wrappers::TcpListenerStream; + use tonic::codegen::http::HeaderMap; + use tonic::codegen::tokio_stream; + use tonic::metadata::MetadataMap; + use tonic::transport::Server; + use tonic::{Extensions, Request, Response, Status, Streaming}; + + const AUTH_HEADER: &str = "authorization"; + const BEARER_TOKEN: &str = "Bearer flight-sql-token"; + + struct TestFlightSqlService { + flight_info: FlightInfo, + partition_data: RecordBatch, + expected_handshake_headers: HashMap, + expected_flight_info_query: String, + shutdown_sender: Option>, + } + + impl TestFlightSqlService { + async fn run_in_background(self, rx: Receiver<()>) -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let service = FlightServiceServer::new(self); + #[allow(clippy::disallowed_methods)] // spawn allowed only in tests + tokio::spawn(async move { + Server::builder() + .timeout(Duration::from_secs(1)) + .add_service(service) + .serve_with_incoming_shutdown( + TcpListenerStream::new(listener), + async { + rx.await.ok(); + }, + ) + .await + .unwrap(); + }); + tokio::time::sleep(Duration::from_millis(25)).await; + addr + } + } + + impl Drop for TestFlightSqlService { + fn drop(&mut self) { + if let Some(tx) = self.shutdown_sender.take() { + tx.send(()).ok(); + } + } + } + + fn check_header( + request: &Request, + rpc: &str, + header_name: &str, + expected_value: &str, + ) { + let actual_value = request + .metadata() + .get(header_name) + .unwrap_or_else(|| panic!("[{}] missing header `{}`", rpc, header_name)) + .to_str() + .unwrap_or_else(|e| { + panic!( + "[{}] error parsing value for header `{}`: {:?}", + rpc, header_name, e + ) + }); + assert_eq!( + actual_value, expected_value, + "[{}] unexpected value for header `{}`", + rpc, header_name + ) + } + + #[async_trait] + impl FlightSqlService for TestFlightSqlService { + type FlightService = TestFlightSqlService; + + async fn do_handshake( + &self, + request: Request>, + ) -> Result< + Response< + Pin> + Send>>, + >, + Status, + > { + for (header_name, expected_value) in self.expected_handshake_headers.iter() { + check_header(&request, "do_handshake", header_name, expected_value); + } + Ok(Response::from_parts( + MetadataMap::from_headers(HeaderMap::from_iter([( + AUTH_HEADER.parse().unwrap(), + BEARER_TOKEN.parse().unwrap(), + )])), // the client should send this header back on the next request (i.e. GetFlightInfo) + Box::pin(tokio_stream::empty()), + Extensions::default(), + )) + } + + async fn get_flight_info_statement( + &self, + query: CommandStatementQuery, + request: Request, + ) -> Result, Status> { + let mut expected_flight_info_headers = + self.expected_handshake_headers.clone(); + expected_flight_info_headers.insert(AUTH_HEADER.into(), BEARER_TOKEN.into()); + for (header_name, expected_value) in expected_flight_info_headers.iter() { + check_header(&request, "get_flight_info", header_name, expected_value); + } + assert_eq!( + query.query.to_lowercase(), + self.expected_flight_info_query.to_lowercase() + ); + Ok(Response::new(self.flight_info.clone())) + } + + async fn do_get_statement( + &self, + _ticket: TicketStatementQuery, + _request: Request, + ) -> Result::DoGetStream>, Status> { + let data = self.partition_data.clone(); + let rb = async move { Ok(data) }; + + let stream = FlightDataEncoderBuilder::default() + .with_schema(self.partition_data.schema()) + .build(stream::once(rb)) + .map_err(|e| Status::from_error(Box::new(e))); + + Ok(Response::new(Box::pin(stream))) + } + + async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} + } + + #[tokio::test] + async fn flight_sql_data_source() -> datafusion_common::Result<()> { + let partition_data = RecordBatch::try_new( + Arc::new(Schema::new([Arc::new(Field::new( + "col1", + DataType::Int8, + false, + ))])), + vec![Arc::new(Int8Array::from(vec![0, 1, 2, 3]))], + ) + .unwrap(); + let rows_per_partition = partition_data.num_rows(); + + let query = "SELECT * FROM some_table"; + let ticket_payload = TicketStatementQuery::default().as_any().encode_to_vec(); + let endpoint_archetype = + FlightEndpoint::default().with_ticket(Ticket::new(ticket_payload)); + let endpoints = vec![ + endpoint_archetype.clone(), + endpoint_archetype.clone(), + endpoint_archetype, + ]; + let num_partitions = endpoints.len(); + let flight_info = FlightInfo::default() + .try_with_schema(partition_data.schema().as_ref()) + .unwrap(); + let flight_info = endpoints + .into_iter() + .fold(flight_info, |fi, e| fi.with_endpoint(e)); + let (tx, rx) = channel(); + let service = TestFlightSqlService { + flight_info, + partition_data, + expected_handshake_headers: HashMap::from([ + (AUTH_HEADER.into(), "Basic YWRtaW46cGFzc3dvcmQ=".into()), + ("custom-hdr1".into(), "v1".into()), + ("custom-hdr2".into(), "v2".into()), + ]), + expected_flight_info_query: query.into(), + shutdown_sender: Some(tx), + }; + let port = service.run_in_background(rx).await.port(); + let ctx = SessionContext::new(); + let _ = ctx.sql(&format!(r#" + CREATE EXTERNAL TABLE fsql STORED AS FLIGHT_SQL LOCATION 'http://localhost:{port}' + OPTIONS( + 'flight.sql.username' 'admin', + 'flight.sql.password' 'password', + 'flight.sql.query' '{query}', + 'flight.sql.header.custom-hdr1' 'v1', + 'flight.sql.header.custom-hdr2' 'v2', + )"# + )).await.unwrap(); + let df = ctx.sql("select col1 from fsql").await.unwrap(); + assert_eq!( + df.count().await.unwrap(), + rows_per_partition * num_partitions + ); + Ok(()) + } +} diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 1c9924735735d..58ca50c6a68b1 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -24,6 +24,8 @@ pub mod cte_worktable; pub mod default_table_source; pub mod empty; pub mod file_format; +#[cfg(feature = "flight")] +pub mod flight; pub mod function; pub mod listing; pub mod listing_table_factory; diff --git a/datafusion/core/src/datasource/physical_plan/flight.rs b/datafusion/core/src/datasource/physical_plan/flight.rs new file mode 100644 index 0000000000000..5e91f047423e0 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/flight.rs @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution plan for reading flights from Arrow Flight services + +use std::any::Any; +use std::error::Error; +use std::fmt::Formatter; +use std::sync::Arc; + +use arrow_flight::error::FlightError; +use arrow_flight::{FlightClient, FlightInfo, Ticket}; +use arrow_schema::SchemaRef; +use futures::TryStreamExt; +use tonic::transport::Channel; + +use datafusion_common::DataFusionError; +use datafusion_common::Result; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, +}; + +/// Arrow Flight physical plan that maps flight endpoints to partitions +#[derive(Debug)] +pub struct FlightExec { + info: Arc, + props: Arc, + origin: String, +} + +impl FlightExec { + /// Creates a FlightExec with the provided [FlightInfo], [PlanProperties] + /// and origin URL (used as fallback location as per the protocol spec). + pub fn new( + info: Arc, + props: Arc, + origin: String, + ) -> Self { + Self { + info, + props, + origin, + } + } +} + +async fn flight_stream( + info: Arc, + fallback_location: String, + schema: SchemaRef, + partition: usize, +) -> Result { + let endpoint = &info.endpoint[partition]; + let locations = if endpoint.location.is_empty() { + vec![fallback_location] + } else { + endpoint + .location + .iter() + .map(|loc| { + if loc.uri.starts_with("arrow-flight-reuse-connection://") { + fallback_location.clone() + } else { + loc.uri.clone() + } + }) + .collect() + }; + let mut errors: Vec> = vec![]; + for loc in locations { + match try_fetch_stream(loc, endpoint.ticket.clone(), schema.clone()).await { + Ok(stream) => return Ok(stream), + Err(e) => errors.push(Box::new(e)), + } + } + let err = errors.into_iter().last().unwrap_or_else(|| { + Box::new(FlightError::ProtocolError(format!( + "No available endpoint for partition {}: {:?}", + partition, &endpoint.location + ))) + }); + Err(DataFusionError::External(err)) +} + +async fn try_fetch_stream( + source: String, + ticket: Option, + schema: SchemaRef, +) -> arrow_flight::error::Result { + let ticket = ticket.ok_or(FlightError::ProtocolError("no flight ticket".into()))?; + let dest = Channel::from_shared(source) + .map_err(|e| FlightError::ExternalError(Box::new(e)))?; + let mut client = dest + .connect() + .await + .map(FlightClient::new) + .map_err(|e| FlightError::ExternalError(Box::new(e)))?; + let stream = client.do_get(ticket).await?; + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + stream.map_err(|e| DataFusionError::External(Box::new(e))), + ))) +} + +impl DisplayAs for FlightExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default => f.write_str("FlightExec"), + DisplayFormatType::Verbose => write!( + f, + "FlightExec [{}]: {:?}, {:?}", + self.origin.as_str(), + self.info.as_ref(), + self.props.as_ref() + ), + } + } +} + +impl ExecutionPlan for FlightExec { + fn name(&self) -> &str { + "FlightExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + self.props.as_ref() + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> Result { + let future_stream = flight_stream( + self.info.clone(), + self.origin.clone(), + self.schema(), + partition, + ); + let stream = futures::stream::once(future_stream).try_flatten(); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + stream, + ))) + } +} diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index f810fb86bd896..fc298c99b9518 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -23,6 +23,8 @@ mod csv; mod file_groups; mod file_scan_config; mod file_stream; +#[cfg(feature = "flight")] +mod flight; mod json; #[cfg(feature = "parquet")] pub mod parquet; @@ -43,6 +45,9 @@ pub use file_scan_config::{ pub use file_stream::{FileOpenFuture, FileOpener, FileStream, OnError}; pub use json::{JsonOpener, NdJsonExec}; +#[cfg(feature = "flight")] +pub use self::flight::FlightExec; + use std::{ fmt::{Debug, Formatter, Result as FmtResult}, ops::Range, diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 07420afe842f7..1577e793861b4 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -25,6 +25,8 @@ use crate::datasource::file_format::json::JsonFormatFactory; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormatFactory; use crate::datasource::file_format::FileFormatFactory; +#[cfg(feature = "flight")] +use crate::datasource::flight::{sql::FlightSqlDriver, FlightTableFactory}; use crate::datasource::provider::DefaultTableFactory; use crate::execution::context::SessionState; #[cfg(feature = "nested_expressions")] @@ -55,6 +57,13 @@ impl SessionStateDefaults { table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); + #[cfg(feature = "flight")] + table_factories.insert( + "FLIGHT_SQL".into(), + Arc::new(FlightTableFactory::new( + Arc::new(FlightSqlDriver::default()), + )), + ); table_factories }