Skip to content

Commit

Permalink
Allow overriding of do_get & export useful macro (#2582)
Browse files Browse the repository at this point in the history
* Allow overriding of do_get & export useful macro

* All hail clippy

* Remove macro export

* Rename function
  • Loading branch information
Brent Gardner authored Aug 25, 2022
1 parent c692b25 commit a685c5f
Showing 1 changed file with 75 additions and 104 deletions.
179 changes: 75 additions & 104 deletions arrow-flight/src/sql/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ pub trait FlightSqlService:
))
}

/// Implementors may override to handle additional calls to do_get()
async fn do_get_fallback(
&self,
_request: Request<Ticket>,
message: prost_types::Any,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
Err(Status::unimplemented(format!(
"do_get: The defined request is invalid: {}",
message.type_url
)))
}

/// Get a FlightInfo for executing a SQL query.
async fn get_flight_info_statement(
&self,
Expand Down Expand Up @@ -301,92 +313,92 @@ where
&self,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let any: prost_types::Any =
let message: prost_types::Any =
Message::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?;

if any.is::<CommandStatementQuery>() {
let token = any
if message.is::<CommandStatementQuery>() {
let token = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.get_flight_info_statement(token, request).await;
}
if any.is::<CommandPreparedStatementQuery>() {
let handle = any
if message.is::<CommandPreparedStatementQuery>() {
let handle = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self
.get_flight_info_prepared_statement(handle, request)
.await;
}
if any.is::<CommandGetCatalogs>() {
let token = any
if message.is::<CommandGetCatalogs>() {
let token = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.get_flight_info_catalogs(token, request).await;
}
if any.is::<CommandGetDbSchemas>() {
let token = any
if message.is::<CommandGetDbSchemas>() {
let token = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.get_flight_info_schemas(token, request).await;
}
if any.is::<CommandGetTables>() {
let token = any
if message.is::<CommandGetTables>() {
let token = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.get_flight_info_tables(token, request).await;
}
if any.is::<CommandGetTableTypes>() {
let token = any
if message.is::<CommandGetTableTypes>() {
let token = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.get_flight_info_table_types(token, request).await;
}
if any.is::<CommandGetSqlInfo>() {
let token = any
if message.is::<CommandGetSqlInfo>() {
let token = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.get_flight_info_sql_info(token, request).await;
}
if any.is::<CommandGetPrimaryKeys>() {
let token = any
if message.is::<CommandGetPrimaryKeys>() {
let token = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.get_flight_info_primary_keys(token, request).await;
}
if any.is::<CommandGetExportedKeys>() {
let token = any
if message.is::<CommandGetExportedKeys>() {
let token = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.get_flight_info_exported_keys(token, request).await;
}
if any.is::<CommandGetImportedKeys>() {
let token = any
if message.is::<CommandGetImportedKeys>() {
let token = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.get_flight_info_imported_keys(token, request).await;
}
if any.is::<CommandGetCrossReference>() {
let token = any
if message.is::<CommandGetCrossReference>() {
let token = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.get_flight_info_cross_reference(token, request).await;
}

Err(Status::unimplemented(format!(
"get_flight_info: The defined request is invalid: {:?}",
String::from_utf8(any.encode_to_vec()).unwrap()
"get_flight_info: The defined request is invalid: {}",
message.type_url
)))
}

Expand All @@ -401,103 +413,62 @@ where
&self,
request: Request<Ticket>,
) -> Result<Response<Self::DoGetStream>, Status> {
let any: prost_types::Any = prost::Message::decode(&*request.get_ref().ticket)
let msg: prost_types::Any = prost::Message::decode(&*request.get_ref().ticket)
.map_err(decode_error_to_status)?;

if any.is::<TicketStatementQuery>() {
let token = any
.unpack()
fn unpack<T: ProstMessageExt>(msg: prost_types::Any) -> Result<T, Status> {
msg.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.do_get_statement(token, request).await;
.ok_or_else(|| Status::internal("Expected a command, but found none."))
}
if any.is::<CommandPreparedStatementQuery>() {
let token = any
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.do_get_prepared_statement(token, request).await;

if msg.is::<TicketStatementQuery>() {
return self.do_get_statement(unpack(msg)?, request).await;
}
if any.is::<CommandGetCatalogs>() {
let token = any
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.do_get_catalogs(token, request).await;
if msg.is::<CommandPreparedStatementQuery>() {
return self.do_get_prepared_statement(unpack(msg)?, request).await;
}
if any.is::<CommandGetDbSchemas>() {
let token = any
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.do_get_schemas(token, request).await;
if msg.is::<CommandGetCatalogs>() {
return self.do_get_catalogs(unpack(msg)?, request).await;
}
if any.is::<CommandGetTables>() {
let token = any
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.do_get_tables(token, request).await;
if msg.is::<CommandGetDbSchemas>() {
return self.do_get_schemas(unpack(msg)?, request).await;
}
if any.is::<CommandGetTableTypes>() {
let token = any
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.do_get_table_types(token, request).await;
if msg.is::<CommandGetTables>() {
return self.do_get_tables(unpack(msg)?, request).await;
}
if any.is::<CommandGetSqlInfo>() {
let token = any
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.do_get_sql_info(token, request).await;
if msg.is::<CommandGetTableTypes>() {
return self.do_get_table_types(unpack(msg)?, request).await;
}
if any.is::<CommandGetPrimaryKeys>() {
let token = any
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.do_get_primary_keys(token, request).await;
if msg.is::<CommandGetSqlInfo>() {
return self.do_get_sql_info(unpack(msg)?, request).await;
}
if any.is::<CommandGetExportedKeys>() {
let token = any
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.do_get_exported_keys(token, request).await;
if msg.is::<CommandGetPrimaryKeys>() {
return self.do_get_primary_keys(unpack(msg)?, request).await;
}
if any.is::<CommandGetImportedKeys>() {
let token = any
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.do_get_imported_keys(token, request).await;
if msg.is::<CommandGetExportedKeys>() {
return self.do_get_exported_keys(unpack(msg)?, request).await;
}
if any.is::<CommandGetCrossReference>() {
let token = any
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.do_get_cross_reference(token, request).await;
if msg.is::<CommandGetImportedKeys>() {
return self.do_get_imported_keys(unpack(msg)?, request).await;
}
if msg.is::<CommandGetCrossReference>() {
return self.do_get_cross_reference(unpack(msg)?, request).await;
}

Err(Status::unimplemented(format!(
"do_get: The defined request is invalid: {:?}",
String::from_utf8(request.get_ref().ticket.clone()).unwrap()
)))
self.do_get_fallback(request, msg).await
}

async fn do_put(
&self,
mut request: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoPutStream>, Status> {
let cmd = request.get_mut().message().await?.unwrap();
let any: prost_types::Any =
let message: prost_types::Any =
prost::Message::decode(&*cmd.flight_descriptor.unwrap().cmd)
.map_err(decode_error_to_status)?;
if any.is::<CommandStatementUpdate>() {
let token = any
if message.is::<CommandStatementUpdate>() {
let token = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
Expand All @@ -508,15 +479,15 @@ where
})]);
return Ok(Response::new(Box::pin(output)));
}
if any.is::<CommandPreparedStatementQuery>() {
let token = any
if message.is::<CommandPreparedStatementQuery>() {
let token = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
return self.do_put_prepared_statement_query(token, request).await;
}
if any.is::<CommandPreparedStatementUpdate>() {
let handle = any
if message.is::<CommandPreparedStatementUpdate>() {
let handle = message
.unpack()
.map_err(arrow_error_to_status)?
.expect("unreachable");
Expand All @@ -531,8 +502,8 @@ where
}

Err(Status::invalid_argument(format!(
"do_put: The defined request is invalid: {:?}",
String::from_utf8(any.encode_to_vec()).unwrap()
"do_put: The defined request is invalid: {}",
message.type_url
)))
}

Expand Down

0 comments on commit a685c5f

Please sign in to comment.