From f5f6b7e3aa10b80dc574abacf96b30e0927410fe Mon Sep 17 00:00:00 2001 From: Jorge Leitao Date: Sat, 2 Jul 2022 19:47:39 -0700 Subject: [PATCH] Simplified code in flight integration tests (#1136) --- integration-testing/Cargo.toml | 1 + integration-testing/README.md | 61 ++--- .../src/bin/arrow-json-integration-test.rs | 4 +- .../src/flight_client_scenarios.rs | 20 -- .../integration_test.rs | 194 ++++++-------- .../src/flight_client_scenarios/middleware.rs | 1 - .../src/flight_client_scenarios/mod.rs | 3 + .../auth_basic_proto.rs | 178 ++++--------- .../integration_test.rs | 249 ++++++++---------- .../mod.rs} | 21 +- integration-testing/src/lib.rs | 6 +- 11 files changed, 291 insertions(+), 447 deletions(-) delete mode 100644 integration-testing/src/flight_client_scenarios.rs create mode 100644 integration-testing/src/flight_client_scenarios/mod.rs rename integration-testing/src/{flight_server_scenarios.rs => flight_server_scenarios/mod.rs} (51%) diff --git a/integration-testing/Cargo.toml b/integration-testing/Cargo.toml index 4f9c1761bcb..8cc7d9a0e23 100644 --- a/integration-testing/Cargo.toml +++ b/integration-testing/Cargo.toml @@ -41,3 +41,4 @@ serde_json = { version = "1.0", features = ["preserve_order"] } tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } tonic = "0.7.0" tracing-subscriber = { version = "0.3.1", optional = true } +async-stream = { version = "0.3.2" } diff --git a/integration-testing/README.md b/integration-testing/README.md index 66248deb346..766f77b1151 100644 --- a/integration-testing/README.md +++ b/integration-testing/README.md @@ -1,30 +1,31 @@ - - -# Apache Arrow Rust Integration Testing - -See [Integration.rst](../../docs/source/format/Integration.rst) for an overview of integration testing. - -This crate contains the following binaries, which are invoked by Archery during integration testing with other Arrow implementations. - -| Binary | Purpose | -|--------|---------| -| arrow-file-to-stream | Converts an Arrow file to an Arrow stream | -| arrow-stream-to-file | Converts an Arrow stream to an Arrow file | -| arrow-json-integration-test | Converts between Arrow and JSON formats | +# Integration tests + +This directory contains integration tests against official Arrow implementations. + +They are run as part of the CI pipeline, called by apache/arrow/dev/crossbow. + +The IPC files tested on the official pipeline are already tested on our own tests. + +## Flight tests + +To run the flight scenarios across this implementation, use + +```bash +SCENARIO="auth:basic_proto" +cargo run --bin flight-test-integration-server -- --port 3333 --scenario $SCENARIO & +# wait for server to be up + +cargo run --bin flight-test-integration-client -- --host localhost --port 3333 --scenario $SCENARIO +``` + +to run an integration test against a file, use + +```bash +FILE="../testing/arrow-testing/data/arrow-ipc-stream/integration/1.0.0-littleendian/generated_dictionary.json.gz" +gzip -dc $FILE > generated.json + +cargo build --bin flight-test-integration-server +cargo run --bin flight-test-integration-server -- --port 3333 & +cargo run --bin flight-test-integration-client -- --host localhost --port 3333 --path generated.json +# kill with `fg` and stop process +``` diff --git a/integration-testing/src/bin/arrow-json-integration-test.rs b/integration-testing/src/bin/arrow-json-integration-test.rs index 82de2e435d3..7152fb40a70 100644 --- a/integration-testing/src/bin/arrow-json-integration-test.rs +++ b/integration-testing/src/bin/arrow-json-integration-test.rs @@ -64,7 +64,7 @@ fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()> options, )?; - for b in json_file.batches { + for b in json_file.chunks { writer.write(&b, None)?; } @@ -129,7 +129,7 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { ))); } - let json_batches = json_file.batches; + let json_batches = json_file.chunks; if verbose { eprintln!( diff --git a/integration-testing/src/flight_client_scenarios.rs b/integration-testing/src/flight_client_scenarios.rs deleted file mode 100644 index 66cced5f4c2..00000000000 --- a/integration-testing/src/flight_client_scenarios.rs +++ /dev/null @@ -1,20 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -pub mod auth_basic_proto; -pub mod integration_test; -pub mod middleware; diff --git a/integration-testing/src/flight_client_scenarios/integration_test.rs b/integration-testing/src/flight_client_scenarios/integration_test.rs index a731a06eab1..3cdcd17f2e7 100644 --- a/integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/integration-testing/src/flight_client_scenarios/integration_test.rs @@ -38,7 +38,7 @@ use arrow_format::{ }, ipc::planus::ReadAsRoot, }; -use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; +use futures::{stream::BoxStream, StreamExt, TryStreamExt}; use tonic::{Request, Streaming}; type Error = Box; @@ -55,17 +55,16 @@ pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result { let ArrowFile { schema, - batches, + chunks, fields, .. } = read_json_file(path)?; + let ipc_schema = IpcSchema { fields, is_little_endian: true, }; - let schema = Box::new(schema); - let mut descriptor = FlightDescriptor::default(); descriptor.set_type(DescriptorType::Path); descriptor.path = vec![path.to_string()]; @@ -73,12 +72,12 @@ pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result { upload_data( client.clone(), &schema, - &ipc_schema.fields, + ipc_schema.fields.clone(), descriptor.clone(), - batches.clone(), + chunks.clone(), ) .await?; - verify_data(client, descriptor, &schema, &ipc_schema, &batches).await?; + verify_data(client, descriptor, &schema, &ipc_schema, &chunks).await?; Ok(()) } @@ -86,77 +85,56 @@ pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result { async fn upload_data( mut client: Client, schema: &Schema, - fields: &[IpcField], + fields: Vec, descriptor: FlightDescriptor, - original_data: Vec, + chunks: Vec, ) -> Result { - let (mut upload_tx, upload_rx) = mpsc::channel(10); - - let options = write::WriteOptions { compression: None }; + let stream = new_stream(schema, fields, descriptor, chunks); - let mut schema = flight::serialize_schema(schema, Some(fields)); - schema.flight_descriptor = Some(descriptor.clone()); - upload_tx.send(schema).await?; - - let mut original_data_iter = original_data.iter().enumerate(); - - if let Some((counter, first_batch)) = original_data_iter.next() { - let metadata = counter.to_string().into_bytes(); - // Preload the first batch into the channel before starting the request - send_batch(&mut upload_tx, &metadata, first_batch, fields, &options).await?; + // put the stream in the client + let responses = client.do_put(Request::new(stream)).await?.into_inner(); - let outer = client.do_put(Request::new(upload_rx)).await?; - let mut inner = outer.into_inner(); - - let r = inner - .next() - .await - .expect("No response received") - .expect("Invalid response received"); - assert_eq!(metadata, r.app_metadata); - - // Stream the rest of the batches - for (counter, batch) in original_data_iter { - let metadata = counter.to_string().into_bytes(); - send_batch(&mut upload_tx, &metadata, batch, fields, &options).await?; - - let r = inner - .next() - .await - .expect("No response received") - .expect("Invalid response received"); - assert_eq!(metadata, r.app_metadata); - } - drop(upload_tx); - assert!( - inner.next().await.is_none(), - "Should not receive more results" - ); - } else { - drop(upload_tx); - client.do_put(Request::new(upload_rx)).await?; - } + // confirm that all chunks were received in the right order + let results = responses.try_collect::>().await?; + assert!(results + .into_iter() + // only record batches have a metadata; ignore dictionary batches + .filter(|r| !r.app_metadata.is_empty()) + .enumerate() + .all(|(counter, r)| r.app_metadata == counter.to_string().as_bytes())); Ok(()) } -async fn send_batch( - upload_tx: &mut mpsc::Sender, - metadata: &[u8], - batch: &ChunkBox, - fields: &[IpcField], - options: &write::WriteOptions, -) -> Result { - let (dictionary_flight_data, mut batch_flight_data) = serialize_batch(batch, fields, options)?; +fn new_stream( + schema: &Schema, + fields: Vec, + descriptor: FlightDescriptor, + chunks: Vec, +) -> BoxStream<'static, FlightData> { + let options = write::WriteOptions { compression: None }; - upload_tx - .send_all(&mut stream::iter(dictionary_flight_data).map(Ok)) - .await?; + let mut schema = flight::serialize_schema(schema, Some(&fields)); + schema.flight_descriptor = Some(descriptor); - // Only the record batch's FlightData gets app_metadata - batch_flight_data.app_metadata = metadata.to_vec(); - upload_tx.send(batch_flight_data).await?; - Ok(()) + // iterator of [dictionaries0, chunk0, dictionaries1, chunk1, ...] + let iter = chunks + .into_iter() + .enumerate() + .flat_map(move |(counter, chunk)| { + let metadata = counter.to_string().into_bytes(); + let (mut dictionaries, mut chunk) = serialize_batch(&chunk, &fields, &options).unwrap(); + + // assign `app_metadata` to chunks + chunk.app_metadata = metadata.to_vec(); + dictionaries.push(chunk); + dictionaries + }); + + // the stream as per flight spec: the schema followed by stream of chunks + futures::stream::once(futures::future::ready(schema)) + .chain(futures::stream::iter(iter)) + .boxed() } async fn verify_data( @@ -164,7 +142,7 @@ async fn verify_data( descriptor: FlightDescriptor, expected_schema: &Schema, ipc_schema: &IpcSchema, - expected_data: &[ChunkBox], + expected_chunks: &[ChunkBox], ) -> Result { let resp = client.get_flight_info(Request::new(descriptor)).await?; let info = resp.into_inner(); @@ -186,7 +164,7 @@ async fn verify_data( consume_flight_location( location, ticket.clone(), - expected_data, + expected_chunks, expected_schema, ipc_schema, ) @@ -200,7 +178,7 @@ async fn verify_data( async fn consume_flight_location( location: Location, ticket: Ticket, - expected_data: &[ChunkBox], + expected_chunks: &[ChunkBox], schema: &Schema, ipc_schema: &IpcSchema, ) -> Result { @@ -211,75 +189,73 @@ async fn consume_flight_location( let mut client = FlightServiceClient::connect(location.uri).await?; let resp = client.do_get(ticket).await?; - let mut resp = resp.into_inner(); + let mut stream = resp.into_inner(); // We already have the schema from the FlightInfo, but the server sends it again as the // first FlightData. Ignore this one. - let _schema_again = resp.next().await.unwrap(); + let _schema_again = stream.next().await.unwrap(); let mut dictionaries = Default::default(); - for (counter, expected_batch) in expected_data.iter().enumerate() { - let data = - receive_batch_flight_data(&mut resp, &schema.fields, ipc_schema, &mut dictionaries) - .await - .unwrap_or_else(|| { - panic!( - "Got fewer batches than expected, received so far: {} expected: {}", - counter, - expected_data.len(), - ) - }); + for (counter, expected_chunk) in expected_chunks.iter().enumerate() { + let data = read_dictionaries(&mut stream, &schema.fields, ipc_schema, &mut dictionaries) + .await + .unwrap_or_else(|| { + panic!( + "Got fewer chunkes than expected, received so far: {} expected: {}", + counter, + expected_chunks.len(), + ) + }); let metadata = counter.to_string().into_bytes(); assert_eq!(metadata, data.app_metadata); - let actual_batch = deserialize_batch(&data, &schema.fields, ipc_schema, &dictionaries) - .expect("Unable to convert flight data to Arrow batch"); - - assert_eq!(expected_batch.columns().len(), actual_batch.columns().len()); - assert_eq!(expected_batch.len(), actual_batch.len()); - for (i, (expected, actual)) in expected_batch - .columns() - .iter() - .zip(actual_batch.columns().iter()) - .enumerate() - { - let field_name = &schema.fields[i].name; - assert_eq!(expected, actual, "Data for field {}", field_name); - } + let chunk = deserialize_batch(&data, &schema.fields, ipc_schema, &dictionaries) + .expect("Unable to convert flight data to Arrow chunk"); + + assert_eq!(&chunk, expected_chunk); } assert!( - resp.next().await.is_none(), - "Got more batches than the expected: {}", - expected_data.len(), + stream.next().await.is_none(), + "Got more chunkes than the expected: {}", + expected_chunks.len(), ); Ok(()) } -async fn receive_batch_flight_data( - resp: &mut Streaming, +async fn read_dictionaries( + stream: &mut Streaming, fields: &[Field], ipc_schema: &IpcSchema, dictionaries: &mut Dictionaries, ) -> Option { - let mut data = resp.next().await?.ok()?; + let mut data = stream.next().await?.ok()?; let mut message = ipc::MessageRef::read_as_root(&data.data_header).expect("Error parsing first message"); - while let ipc::MessageHeaderRef::DictionaryBatch(batch) = message + while let ipc::MessageHeaderRef::DictionaryBatch(chunk) = message .header() .expect("Header to be valid flatbuffers") .expect("Header to be present") { let length = data.data_body.len(); let mut reader = std::io::Cursor::new(&data.data_body); - read::read_dictionary(batch, fields, ipc_schema, dictionaries, &mut reader, 0, length as u64, &mut Default::default()) - .expect("Error reading dictionary"); - - data = resp.next().await?.ok()?; + read::read_dictionary( + chunk, + fields, + ipc_schema, + dictionaries, + &mut reader, + 0, + length as u64, + &mut Default::default(), + ) + .expect("Error reading dictionary"); + + data = stream.next().await?.ok()?; message = ipc::MessageRef::read_as_root(&data.data_header).expect("Error parsing message"); } diff --git a/integration-testing/src/flight_client_scenarios/middleware.rs b/integration-testing/src/flight_client_scenarios/middleware.rs index 29c96ce67ea..c5fd752768c 100644 --- a/integration-testing/src/flight_client_scenarios/middleware.rs +++ b/integration-testing/src/flight_client_scenarios/middleware.rs @@ -73,7 +73,6 @@ pub async fn run_scenario(host: &str, port: u16) -> Result { Ok(()) } -#[allow(clippy::unnecessary_wraps)] fn middleware_interceptor(mut req: Request<()>) -> Result, Status> { let metadata = req.metadata_mut(); metadata.insert("x-middleware", "expected value".parse().unwrap()); diff --git a/integration-testing/src/flight_client_scenarios/mod.rs b/integration-testing/src/flight_client_scenarios/mod.rs new file mode 100644 index 00000000000..c397bbdafbc --- /dev/null +++ b/integration-testing/src/flight_client_scenarios/mod.rs @@ -0,0 +1,3 @@ +pub mod auth_basic_proto; +pub mod integration_test; +pub mod middleware; diff --git a/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs b/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs index 0938de27dbe..4bef88cbe5a 100644 --- a/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs +++ b/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs @@ -1,44 +1,18 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::pin::Pin; -use std::sync::Arc; +use async_stream::try_stream; +use futures::pin_mut; +use prost::Message; +use tonic::{metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming}; use arrow_format::flight::data::*; use arrow_format::flight::service::flight_service_server::{FlightService, FlightServiceServer}; -use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; -use tonic::{metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming}; - -type TonicStream = Pin + Send + Sync + 'static>>; - -type Error = Box; -type Result = std::result::Result; -use prost::Message; +use super::{Result, TonicStream}; use crate::{AUTH_PASSWORD, AUTH_USERNAME}; pub async fn scenario_setup(port: u16) -> Result { - let service = AuthBasicProtoScenarioImpl { - username: AUTH_USERNAME.into(), - password: AUTH_PASSWORD.into(), - }; let addr = super::listen_on(port).await?; - let svc = FlightServiceServer::new(service); + let svc = FlightServiceServer::new(Service {}); let server = Server::builder().add_service(svc).serve(addr); @@ -49,42 +23,21 @@ pub async fn scenario_setup(port: u16) -> Result { } #[derive(Clone)] -pub struct AuthBasicProtoScenarioImpl { - username: Arc, - password: Arc, -} +struct Service {} -impl AuthBasicProtoScenarioImpl { - async fn check_auth(&self, metadata: &MetadataMap) -> Result { - let token = metadata +impl Service { + fn check_auth(&self, metadata: &MetadataMap) -> Result { + metadata .get_bin("auth-token-bin") .and_then(|v| v.to_bytes().ok()) - .and_then(|b| String::from_utf8(b.to_vec()).ok()); - self.is_valid(token).await - } - - async fn is_valid(&self, token: Option) -> Result { - match token { - Some(t) if t == *self.username => Ok(GrpcServerCallContext { - peer_identity: self.username.to_string(), - }), - _ => Err(Status::unauthenticated("Invalid token")), - } - } -} - -struct GrpcServerCallContext { - peer_identity: String, -} - -impl GrpcServerCallContext { - pub fn peer_identity(&self) -> &str { - &self.peer_identity + .and_then(|b| String::from_utf8(b.to_vec()).ok()) + .and_then(|username| (username == AUTH_USERNAME).then(|| AUTH_USERNAME.to_string())) + .ok_or_else(|| Status::unauthenticated("Invalid token")) } } #[tonic::async_trait] -impl FlightService for AuthBasicProtoScenarioImpl { +impl FlightService for Service { type HandshakeStream = TonicStream>; type ListFlightsStream = TonicStream>; type DoGetStream = TonicStream>; @@ -97,79 +50,48 @@ impl FlightService for AuthBasicProtoScenarioImpl { &self, request: Request, ) -> Result, Status> { - self.check_auth(request.metadata()).await?; - Err(Status::unimplemented("Not yet implemented")) + self.check_auth(request.metadata())?; + Err(Status::unimplemented("get_schema")) } async fn do_get( &self, request: Request, ) -> Result, Status> { - self.check_auth(request.metadata()).await?; - Err(Status::unimplemented("Not yet implemented")) + self.check_auth(request.metadata())?; + Err(Status::unimplemented("do_get")) } async fn handshake( &self, request: Request>, ) -> Result, Status> { - let (tx, rx) = mpsc::channel(10); - - tokio::spawn({ - let username = self.username.clone(); - let password = self.password.clone(); - - async move { - let requests = request.into_inner(); - - requests - .for_each(move |req| { - let mut tx = tx.clone(); - let req = req.expect("Error reading handshake request"); - let HandshakeRequest { payload, .. } = req; - - let auth = - BasicAuth::decode(&*payload).expect("Error parsing handshake request"); - - let resp = if *auth.username == *username && *auth.password == *password { - Ok(HandshakeResponse { - payload: username.as_bytes().to_vec(), - ..HandshakeResponse::default() - }) - } else { - Err(Status::unauthenticated(format!( - "Don't know user {}", - auth.username - ))) - }; - - async move { - tx.send(resp) - .await - .expect("Error sending handshake response"); - } - }) - .await; - } - }); + let stream = request.into_inner(); - Ok(Response::new(Box::pin(rx))) + let stream = try_stream! { + pin_mut!(stream); + for await item in stream { + let HandshakeRequest {payload, ..} = item.map_err(|_| Status::invalid_argument(format!("Invalid"))).unwrap(); + yield handle(&payload, AUTH_USERNAME, AUTH_PASSWORD).unwrap() + } + }; + Ok(Response::new(Box::pin(stream))) } async fn list_flights( &self, request: Request, ) -> Result, Status> { - self.check_auth(request.metadata()).await?; - Err(Status::unimplemented("Not yet implemented")) + self.check_auth(request.metadata())?; + Err(Status::unimplemented("list_flights")) } async fn get_flight_info( &self, request: Request, ) -> Result, Status> { - self.check_auth(request.metadata()).await?; - Err(Status::unimplemented("Not yet implemented")) + self.check_auth(request.metadata())?; + Err(Status::unimplemented("get_flight_info")) } async fn do_put( @@ -177,36 +99,52 @@ impl FlightService for AuthBasicProtoScenarioImpl { request: Request>, ) -> Result, Status> { let metadata = request.metadata(); - self.check_auth(metadata).await?; - Err(Status::unimplemented("Not yet implemented")) + self.check_auth(metadata)?; + Err(Status::unimplemented("do_put")) } async fn do_action( &self, request: Request, ) -> Result, Status> { - let flight_context = self.check_auth(request.metadata()).await?; + let username = self.check_auth(request.metadata())?; // Respond with the authenticated username. - let buf = flight_context.peer_identity().as_bytes().to_vec(); - let result = arrow_format::flight::data::Result { body: buf }; + let result = arrow_format::flight::data::Result { + body: username.as_bytes().to_vec(), + }; let output = futures::stream::once(async { Ok(result) }); - Ok(Response::new(Box::pin(output) as Self::DoActionStream)) + Ok(Response::new(Box::pin(output))) } async fn list_actions( &self, request: Request, ) -> Result, Status> { - self.check_auth(request.metadata()).await?; - Err(Status::unimplemented("Not yet implemented")) + self.check_auth(request.metadata())?; + Err(Status::unimplemented("list_actions")) } async fn do_exchange( &self, request: Request>, ) -> Result, Status> { - let metadata = request.metadata(); - self.check_auth(metadata).await?; - Err(Status::unimplemented("Not yet implemented")) + self.check_auth(request.metadata())?; + Err(Status::unimplemented("do_exchange")) + } +} + +fn handle(payload: &[u8], username: &str, password: &str) -> Result { + let auth = BasicAuth::decode(payload)?; + + if auth.username == username && auth.password == password { + Ok(HandshakeResponse { + payload: username.as_bytes().to_vec(), + ..HandshakeResponse::default() + }) + } else { + Err(Box::new(Status::unauthenticated(format!( + "Don't know user {}", + auth.username + )))) } } diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/integration-testing/src/flight_server_scenarios/integration_test.rs index 10424759e87..c224967bfa2 100644 --- a/integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/integration-testing/src/flight_server_scenarios/integration_test.rs @@ -16,35 +16,34 @@ // under the License. use std::collections::HashMap; -use std::pin::Pin; use std::sync::Arc; -use arrow2::array::Array; -use arrow2::chunk::Chunk; -use arrow2::io::flight::{deserialize_schemas, serialize_batch, serialize_schema}; -use arrow2::io::ipc::read::Dictionaries; -use arrow2::io::ipc::IpcSchema; +use async_stream::try_stream; +use futures::pin_mut; +use tokio::sync::Mutex; +use tonic::{transport::Server, Request, Response, Status, Streaming}; + use arrow_format::flight::data::flight_descriptor::*; use arrow_format::flight::data::*; use arrow_format::flight::service::flight_service_server::*; use arrow_format::ipc::planus::ReadAsRoot; use arrow_format::ipc::MessageHeaderRef; -use arrow2::{datatypes::*, io::flight::serialize_schema_to_info, io::ipc}; - -use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; -use tokio::sync::Mutex; -use tonic::{transport::Server, Request, Response, Status, Streaming}; - -type TonicStream = Pin + Send + Sync + 'static>>; +use arrow2::array::Array; +use arrow2::chunk::Chunk; +use arrow2::datatypes::{Field, Schema}; +use arrow2::io::flight::{ + deserialize_schemas, serialize_batch, serialize_schema, serialize_schema_to_info, +}; +use arrow2::io::ipc; +use arrow2::io::ipc::read::Dictionaries; -type Error = Box; -type Result = std::result::Result; +use super::{Result, TonicStream}; pub async fn scenario_setup(port: u16) -> Result { let addr = super::listen_on(port).await?; - let service = FlightServiceImpl { + let service = Service { server_location: format!("grpc+tcp://{}", addr), ..Default::default() }; @@ -61,24 +60,24 @@ pub async fn scenario_setup(port: u16) -> Result { #[derive(Debug, Clone)] struct IntegrationDataset { schema: Schema, - ipc_schema: IpcSchema, + ipc_schema: ipc::IpcSchema, chunks: Vec>>, } #[derive(Clone, Default)] -pub struct FlightServiceImpl { +struct Service { server_location: String, uploaded_chunks: Arc>>, } -impl FlightServiceImpl { +impl Service { fn endpoint_from_path(&self, path: &str) -> FlightEndpoint { super::endpoint(path, &self.server_location) } } #[tonic::async_trait] -impl FlightService for FlightServiceImpl { +impl FlightService for Service { type HandshakeStream = TonicStream>; type ListFlightsStream = TonicStream>; type DoGetStream = TonicStream>; @@ -91,7 +90,7 @@ impl FlightService for FlightServiceImpl { &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("get_schema")) } async fn do_get( @@ -111,30 +110,31 @@ impl FlightService for FlightServiceImpl { let options = ipc::write::WriteOptions { compression: None }; - let schema = std::iter::once(Ok(serialize_schema( - &flight.schema, - Some(&flight.ipc_schema.fields), - ))); + let schema = serialize_schema(&flight.schema, Some(&flight.ipc_schema.fields)); let batches = flight .chunks .iter() .enumerate() .flat_map(|(counter, batch)| { - let (dictionary_flight_data, mut batch_flight_data) = + let (dictionaries, mut chunk) = serialize_batch(batch, &flight.ipc_schema.fields, &options).unwrap(); // Only the record batch's FlightData gets app_metadata let metadata = counter.to_string().into_bytes(); - batch_flight_data.app_metadata = metadata; + chunk.app_metadata = metadata; - dictionary_flight_data + dictionaries .into_iter() - .chain(std::iter::once(batch_flight_data)) + .chain(std::iter::once(chunk)) .map(Ok) }); - let output = futures::stream::iter(schema.chain(batches).collect::>()); + let output = futures::stream::iter( + std::iter::once(Ok(schema)) + .chain(batches) + .collect::>(), + ); Ok(Response::new(Box::pin(output) as Self::DoGetStream)) } @@ -143,14 +143,14 @@ impl FlightService for FlightServiceImpl { &self, _request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("handshake")) } async fn list_flights( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("list_flights")) } async fn get_flight_info( @@ -220,73 +220,67 @@ impl FlightService for FlightServiceImpl { let (schema, ipc_schema) = deserialize_schemas(&flight_data.data_header) .map_err(|e| Status::invalid_argument(format!("Invalid schema: {:?}", e)))?; - let (response_tx, response_rx) = mpsc::channel(10); - let uploaded_chunks = self.uploaded_chunks.clone(); - tokio::spawn(async { - let mut error_tx = response_tx.clone(); - if let Err(e) = save_uploaded_chunks( - uploaded_chunks, + let mut dictionaries = Dictionaries::default(); + let mut chunks = vec![]; + + let stream = try_stream! { + pin_mut!(input_stream); + for await item in input_stream { + let FlightData {data_header, data_body, app_metadata, ..} = item.map_err(|_| Status::invalid_argument(format!("Invalid")))?; + save_message(&data_header, + &data_body, + &schema, + &ipc_schema, + &mut dictionaries, + &mut chunks)?; + yield PutResult {app_metadata} + } + let dataset = IntegrationDataset { schema, + chunks, ipc_schema, - input_stream, - response_tx, - key, - ) - .await - { - error_tx.send(Err(e)).await.expect("Error sending error") - } - }); - - Ok(Response::new(Box::pin(response_rx) as Self::DoPutStream)) + }; + let mut uploaded_chunks = uploaded_chunks.lock().await; + uploaded_chunks.insert(key, dataset); + }; + Ok(Response::new(Box::pin(stream))) } async fn do_action( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_action")) } async fn list_actions( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("list_actions")) } async fn do_exchange( &self, _request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_exchange")) } } -async fn send_app_metadata( - tx: &mut mpsc::Sender>, - app_metadata: &[u8], -) -> Result<(), Status> { - tx.send(Ok(PutResult { - app_metadata: app_metadata.to_vec(), - })) - .await - .map_err(|e| Status::internal(format!("Could not send PutResult: {:?}", e))) -} - -async fn record_batch_from_message( +fn chunk_from_message( batch: arrow_format::ipc::RecordBatchRef<'_>, data_body: &[u8], fields: &[Field], - ipc_schema: &IpcSchema, + ipc_schema: &ipc::IpcSchema, dictionaries: &mut Dictionaries, ) -> Result>, Status> { let length = data_body.len(); let mut reader = std::io::Cursor::new(data_body); - let arrow_batch_result = ipc::read::read_record_batch( + ipc::read::read_record_batch( batch, fields, ipc_schema, @@ -296,99 +290,64 @@ async fn record_batch_from_message( &mut reader, 0, length as u64, - &mut Default::default() - ); - - arrow_batch_result.map_err(|e| Status::internal(format!("Could not convert to Chunk: {:?}", e))) + &mut Default::default(), + ) + .map_err(|e| Status::internal(format!("Could not convert to Chunk: {:?}", e))) } -async fn dictionary_from_message( +fn dictionary_from_message( dict_batch: arrow_format::ipc::DictionaryBatchRef<'_>, data_body: &[u8], fields: &[Field], - ipc_schema: &IpcSchema, + ipc_schema: &ipc::IpcSchema, dictionaries: &mut Dictionaries, ) -> Result<(), Status> { let length = data_body.len(); let mut reader = std::io::Cursor::new(data_body); - let dictionary_batch_result = - ipc::read::read_dictionary(dict_batch, fields, ipc_schema, dictionaries, &mut reader, 0, length as u64, &mut Default::default()); - dictionary_batch_result - .map_err(|e| Status::internal(format!("Could not convert to Dictionary: {:?}", e))) + ipc::read::read_dictionary( + dict_batch, + fields, + ipc_schema, + dictionaries, + &mut reader, + 0, + length as u64, + &mut Default::default(), + ) + .map_err(|e| Status::internal(format!("Could not convert to Dictionary: {:?}", e))) } -async fn save_uploaded_chunks( - uploaded_chunks: Arc>>, - schema: Schema, - ipc_schema: IpcSchema, - mut input_stream: Streaming, - mut response_tx: mpsc::Sender>, - key: String, +fn save_message( + header: &[u8], + body: &[u8], + schema: &Schema, + ipc_schema: &ipc::IpcSchema, + dictionaries: &mut Dictionaries, + chunks: &mut Vec>>, ) -> Result<(), Status> { - let mut chunks = vec![]; - let mut uploaded_chunks = uploaded_chunks.lock().await; - - let mut dictionaries = Default::default(); - - while let Some(Ok(data)) = input_stream.next().await { - let message = arrow_format::ipc::MessageRef::read_as_root(&data.data_header) - .map_err(|e| Status::internal(format!("Could not parse message: {:?}", e)))?; - let header = message - .header() - .map_err(|x| Status::internal(x.to_string()))? - .ok_or_else(|| { - Status::internal( - "Unable to convert flight data header to a record batch".to_string(), - ) - })?; - - match header { - MessageHeaderRef::Schema(_) => { - return Err(Status::internal( - "Not expecting a schema when messages are read", - )) - } - MessageHeaderRef::RecordBatch(batch) => { - send_app_metadata(&mut response_tx, &data.app_metadata).await?; - - let batch = record_batch_from_message( - batch, - &data.data_body, - &schema.fields, - &ipc_schema, - &mut dictionaries, - ) - .await?; - - chunks.push(batch); - } - MessageHeaderRef::DictionaryBatch(dict_batch) => { - dictionary_from_message( - dict_batch, - &data.data_body, - &schema.fields, - &ipc_schema, - &mut dictionaries, - ) - .await?; - } - t => { - return Err(Status::internal(format!( - "Reading types other than record batches not yet supported, \ - unable to read {:?}", - t - ))); - } + let message = arrow_format::ipc::MessageRef::read_as_root(header) + .map_err(|e| Status::internal(format!("Could not parse message: {:?}", e)))?; + let header = message + .header() + .map_err(|x| Status::internal(x.to_string()))? + .ok_or_else(|| Status::internal("Message must contain a header".to_string()))?; + + match header { + MessageHeaderRef::RecordBatch(batch) => { + let chunk = chunk_from_message(batch, body, &schema.fields, ipc_schema, dictionaries)?; + + chunks.push(chunk); + } + MessageHeaderRef::DictionaryBatch(dict_batch) => { + dictionary_from_message(dict_batch, body, &schema.fields, ipc_schema, dictionaries)?; + } + t => { + return Err(Status::internal(format!( + "Reading types other than record batches not yet supported, unable to read {:?}", + t + ))); } } - - let dataset = IntegrationDataset { - schema, - chunks, - ipc_schema, - }; - uploaded_chunks.insert(key, dataset); - Ok(()) } diff --git a/integration-testing/src/flight_server_scenarios.rs b/integration-testing/src/flight_server_scenarios/mod.rs similarity index 51% rename from integration-testing/src/flight_server_scenarios.rs rename to integration-testing/src/flight_server_scenarios/mod.rs index 5cdb930d09d..68869c70ad2 100644 --- a/integration-testing/src/flight_server_scenarios.rs +++ b/integration-testing/src/flight_server_scenarios/mod.rs @@ -1,29 +1,16 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - use std::net::SocketAddr; +use std::pin::Pin; use arrow_format::flight::data::{FlightEndpoint, Location, Ticket}; +use futures::Stream; use tokio::net::TcpListener; pub mod auth_basic_proto; pub mod integration_test; pub mod middleware; +type TonicStream = Pin + Send + 'static>>; + type Error = Box; type Result = std::result::Result; diff --git a/integration-testing/src/lib.rs b/integration-testing/src/lib.rs index ed70a32f95b..0a3b778e38c 100644 --- a/integration-testing/src/lib.rs +++ b/integration-testing/src/lib.rs @@ -44,7 +44,7 @@ pub struct ArrowFile { // we can evolve this into a concrete Arrow type // this is temporarily not being read from pub _dictionaries: HashMap, - pub batches: Vec>>, + pub chunks: Vec>>, } pub fn read_json_file(json_name: &str) -> Result { @@ -67,7 +67,7 @@ pub fn read_json_file(json_name: &str) -> Result { } } - let batches = arrow_json["batches"] + let chunks = arrow_json["batches"] .as_array() .unwrap() .iter() @@ -80,6 +80,6 @@ pub fn read_json_file(json_name: &str) -> Result { schema, fields, _dictionaries: dictionaries, - batches, + chunks, }) }