From 58ba7083e2c1165e1f6978d447874f39a857b9d4 Mon Sep 17 00:00:00 2001 From: Guy Margalit Date: Sun, 2 Jan 2022 17:32:37 +0200 Subject: [PATCH] Server streaming body + response headers Signed-off-by: Guy Margalit --- .../ServerOperationHandlerGenerator.kt | 10 +- .../protocols/ServerHttpProtocolGenerator.kt | 249 +++++++++++++----- .../aws-smithy-http-server/Cargo.toml | 2 +- .../aws-smithy-http-server/src/lib.rs | 3 + .../aws-smithy-http-server/src/rejection.rs | 6 + .../aws-smithy-http/src/byte_stream.rs | 8 + 6 files changed, 211 insertions(+), 67 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt index da1a7153240..810a1e8fef8 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt @@ -18,6 +18,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpPro import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.util.hasStreamingMember +import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.outputShape /** @@ -132,13 +134,19 @@ class ServerOperationHandlerGenerator( } else { symbolProvider.toSymbol(operation.outputShape(model)).fullName } + val streamingBodyTraitBounds = if (operation.inputShape(model).hasStreamingMember(model)) { + "B: Into<#{SmithyHttpServer}::ByteStream>," + } else { + "" + } return """ $inputFn Fut: std::future::Future + Send, B: $serverCrate::HttpBody + Send + 'static, + $streamingBodyTraitBounds B::Data: Send, B::Error: Into<$serverCrate::BoxError>, $serverCrate::rejection::SmithyRejection: From<::Error> - """ + """.trimIndent() } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt index 30225f36a15..565d413c981 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt @@ -7,15 +7,20 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.node.ExpectationNotMetException import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.HttpErrorTrait +import software.amazon.smithy.model.traits.MediaTypeTrait import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.RustModule @@ -23,6 +28,7 @@ import software.amazon.smithy.rust.codegen.rustlang.RustType import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.asType +import software.amazon.smithy.rust.codegen.rustlang.autoDeref import software.amazon.smithy.rust.codegen.rustlang.render import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock @@ -47,11 +53,15 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescripto import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolBodyGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.util.UNREACHABLE import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectTrait +import software.amazon.smithy.rust.codegen.util.findStreamingMember import software.amazon.smithy.rust.codegen.util.getTrait import software.amazon.smithy.rust.codegen.util.hasStreamingMember +import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.inputShape +import software.amazon.smithy.rust.codegen.util.isPrimitive import software.amazon.smithy.rust.codegen.util.outputShape import software.amazon.smithy.rust.codegen.util.toSnakeCase import java.util.logging.Logger @@ -99,6 +109,7 @@ private class ServerHttpProtocolImplGenerator( val httpBindingResolver = protocol.httpBindingResolver private val operationDeserModule = RustModule.private("operation_deser") private val operationSerModule = RustModule.private("operation_ser") + private val Encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder") private val codegenScope = arrayOf( "AsyncTrait" to ServerCargoDependency.AsyncTrait.asType(), @@ -123,13 +134,12 @@ private class ServerHttpProtocolImplGenerator( } /* - * Generation of `FromRequest` and `IntoResponse`. They are currently only implemented for non-streaming request - * and response bodies, that is, models without streaming traits - * (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html). - * For non-streaming request bodies, we require the HTTP body to be fully read in memory before parsing or - * deserialization. From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize + * Generation of `FromRequest` and `IntoResponse`. + * For non-streaming request bodies, that is, models without streaming traits + * (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html) + * we require the HTTP body to be fully read in memory before parsing or deserialization. + * From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize * an HTTP response to `Bytes`. - * TODO Add support for streaming. * These traits are the public entrypoint of the ser/de logic of the `aws-smithy-http-server` server. */ private fun RustWriter.renderTraits( @@ -138,25 +148,9 @@ private class ServerHttpProtocolImplGenerator( operationShape: OperationShape ) { val operationName = symbolProvider.toSymbol(operationShape).name - // Implement Axum `FromRequest` trait for input types. val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" - val fromRequest = if (operationShape.inputShape(model).hasStreamingMember(model)) { - // For streaming request bodies, we need to generate a different implementation of the `FromRequest` trait. - // It will first offer the streaming input to the parser and potentially read the body into memory - // if an error occurred or if the streaming parser indicates that it needs the full data to proceed. - """ - async fn from_request(_req: &mut #{AxumCore}::extract::RequestParts) -> Result { - todo!("Streaming support for input shapes is not yet supported in `smithy-rs`") - } - """.trimIndent() - } else { - """ - async fn from_request(req: &mut #{AxumCore}::extract::RequestParts) -> Result { - Ok($inputName(#{parse_request}(req).await?)) - } - """.trimIndent() - } + // Implement Axum `FromRequest` trait for input types. rustTemplate( """ pub struct $inputName(pub #{I}); @@ -164,12 +158,15 @@ private class ServerHttpProtocolImplGenerator( impl #{AxumCore}::extract::FromRequest for $inputName where B: #{SmithyHttpServer}::HttpBody + Send, + ${getStreamingBodyTraitBounds(operationShape)} B::Data: Send, B::Error: Into<#{SmithyHttpServer}::BoxError>, #{SmithyRejection}: From<::Error> { type Rejection = #{SmithyRejection}; - $fromRequest + async fn from_request(req: &mut #{AxumCore}::extract::RequestParts) -> Result { + Ok($inputName(#{parse_request}(req).await?)) + } } """.trimIndent(), *codegenScope, @@ -178,21 +175,19 @@ private class ServerHttpProtocolImplGenerator( ) // Implement Axum `IntoResponse` for output types. + val outputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" val errorSymbol = operationShape.errorSymbol(symbolProvider) - val httpExtensions = setHttpExtensions(operationShape) - // For streaming response bodies, we need to generate a different implementation of the `IntoResponse` trait. - // The body type will have to be a `StreamBody`. The service implementer will return a `Stream` from their handler. - val intoResponseStreaming = "todo!(\"Streaming support for output shapes is not yet supported in `smithy-rs`\")" + if (operationShape.errors.isNotEmpty()) { - val intoResponseImpl = if (operationShape.outputShape(model).hasStreamingMember(model)) { - intoResponseStreaming - } else { + // The output of fallible operations is a `Result` which we convert into an + // isomorphic `enum` type we control that can in turn be converted into a response. + val intoResponseImpl = """ let mut response = match self { Self::Output(o) => { - match #{serialize_response}(&o) { + match #{serialize_response}(o) { Ok(response) => response, Err(e) => { e.into_response() @@ -214,9 +209,7 @@ private class ServerHttpProtocolImplGenerator( $httpExtensions response """.trimIndent() - } - // The output of fallible operations is a `Result` which we convert into an isomorphic `enum` type we control - // that can in turn be converted into a response. + rustTemplate( """ pub enum $outputName { @@ -237,25 +230,18 @@ private class ServerHttpProtocolImplGenerator( "serialize_error" to serverSerializeError(operationShape) ) } else { - val handleSerializeOutput = if (operationShape.outputShape(model).hasStreamingMember(model)) { - intoResponseStreaming - } else { - """ - match #{serialize_response}(&self.0) { - Ok(response) => response, - Err(e) => e.into_response() - } - """.trimIndent() - } - // The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type - // we control that can in turn be converted into a response. + // The output of non-fallible operations is a model type which we convert into + // a "wrapper" unit `struct` type we control that can in turn be converted into a response. rustTemplate( """ pub struct $outputName(pub #{O}); ##[#{AsyncTrait}::async_trait] impl #{AxumCore}::response::IntoResponse for $outputName { fn into_response(self) -> #{AxumCore}::response::Response { - $handleSerializeOutput + match #{serialize_response}(self.0) { + Ok(response) => response, + Err(e) => e.into_response() + } } } """.trimIndent(), @@ -324,6 +310,7 @@ private class ServerHttpProtocolImplGenerator( val inputSymbol = symbolProvider.toSymbol(inputShape) val includedMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) val unusedVars = if (includedMembers.isEmpty()) "##[allow(unused_variables)] " else "" + return RuntimeType.forInlineFun(fnName, operationDeserModule) { Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it) it.rustBlockTemplate( @@ -336,10 +323,11 @@ private class ServerHttpProtocolImplGenerator( > where B: #{SmithyHttpServer}::HttpBody + Send, + ${getStreamingBodyTraitBounds(operationShape)} B::Data: Send, B::Error: Into<#{SmithyHttpServer}::BoxError>, #{SmithyRejection}: From<::Error> - """, + """.trimIndent(), *codegenScope, "I" to inputSymbol, ) { @@ -360,8 +348,12 @@ private class ServerHttpProtocolImplGenerator( val outputSymbol = symbolProvider.toSymbol(outputShape) return RuntimeType.forInlineFun(fnName, operationSerModule) { Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it) + + // Note we only need to take ownership of the output in the case that it contains streaming members. + // However we currently always take ownership here, but worth noting in case in the future we want + // to generate different signatures for streaming vs non-streaming for some reason. it.rustBlockTemplate( - "pub fn $fnName(output: &#{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>", + "pub fn $fnName(output: #{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>", *codegenScope, "O" to outputSymbol, ) { @@ -444,13 +436,6 @@ private class ServerHttpProtocolImplGenerator( operationShape: OperationShape, bindings: List, ) { - val structuredDataSerializer = protocol.structuredDataSerializer(operationShape) - structuredDataSerializer.serverOutputSerializer(operationShape).also { serializer -> - rust( - "let payload = #T(output)?;", - serializer - ) - } // avoid non-usage warnings for response Attribute.AllowUnusedMut.render(this) rustTemplate("let mut response = #{http}::Response::builder();", *codegenScope) @@ -460,6 +445,24 @@ private class ServerHttpProtocolImplGenerator( serializedValue(this) } } + val streamingMember = operationShape.outputShape(model).findStreamingMember(model) + if (streamingMember != null) { + val memberName = symbolProvider.toMemberName(streamingMember) + rustTemplate( + """ + let payload = #{SmithyHttpServer}::body::Body::wrap_stream(output.$memberName); + """, + *codegenScope, + ) + } else { + val structuredDataSerializer = protocol.structuredDataSerializer(operationShape) + structuredDataSerializer.serverOutputSerializer(operationShape).also { serializer -> + rust( + "let payload = #T(&output)?;", + serializer + ) + } + } rustTemplate( """ response.body(#{SmithyHttpServer}::body::to_boxed(payload))? @@ -475,14 +478,10 @@ private class ServerHttpProtocolImplGenerator( val operationName = symbolProvider.toSymbol(operationShape).name val member = binding.member return when (binding.location) { - HttpLocation.HEADER, HttpLocation.PREFIX_HEADERS, HttpLocation.PAYLOAD -> { - logger.warning("[rust-server-codegen] $operationName: response serialization does not currently support ${binding.location} bindings") - null - } - HttpLocation.DOCUMENT -> { - // document is handled separately - null - } + HttpLocation.HEADER -> writable { serverRenderHeaderGenerator(this, binding) } + HttpLocation.PREFIX_HEADERS -> writable { serverRenderPrefixHeadersGenerator(this, binding) } + HttpLocation.PAYLOAD -> { null } // payload is handled separately + HttpLocation.DOCUMENT -> { null } // document is handled separately HttpLocation.RESPONSE_CODE -> writable { val memberName = symbolProvider.toMemberName(member) rustTemplate( @@ -508,6 +507,7 @@ private class ServerHttpProtocolImplGenerator( bindings: List, ) { val structuredDataParser = protocol.structuredDataParser(operationShape) + val streamingMember = inputShape.findStreamingMember(model) Attribute.AllowUnusedMut.render(this) rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider)) val parser = structuredDataParser.serverInputParser(operationShape) @@ -525,6 +525,14 @@ private class ServerHttpProtocolImplGenerator( *codegenScope, "parser" to parser, ) + } else if (streamingMember != null) { + rustTemplate( + """ + let body = request.take_body().ok_or(#{SmithyHttpServer}::rejection::BodyAlreadyExtracted)?; + input = input.${streamingMember.setterName()}(Some(body.into())); + """.trimIndent(), + *codegenScope + ) } for (binding in bindings) { val member = binding.member @@ -801,6 +809,109 @@ private class ServerHttpProtocolImplGenerator( ) } + private fun serverRenderHeaderGenerator(writer: RustWriter, binding: HttpBindingDescriptor) { + val memberShape = binding.member + val memberType = model.expectShape(memberShape.target) + val memberSymbol = symbolProvider.toSymbol(memberShape) + val memberName = symbolProvider.toMemberName(memberShape) + writer.ifSet(memberType, memberSymbol, "&output.$memberName") { field -> + writer.listForEach(memberType, field) { innerField, targetId -> + val innerMemberType = model.expectShape(targetId) + if (innerMemberType.isPrimitive()) { + rust("let mut encoder = #T::from(${autoDeref(innerField)});", Encoder) + } + val formatted = headerFmtFun(writer, innerMemberType, memberShape, innerField) + val safeName = safeName("formatted") + writer.write("let $safeName = $formatted;") + writer.rustBlock("if !$safeName.is_empty()") { + writer.rustTemplate( + """ + use std::convert::TryFrom; + let header_value = $safeName; + let header_value = http::header::HeaderValue::try_from(&*header_value) + .map_err(|err| { + #{SmithyHttpServer}::rejection::Serialize::from( + format!("{} cannot be used as a header value: {}", ${memberName.dq()}, err).as_str() + ) + })?; + response = response.header(${binding.locationName.dq()}, header_value); + """.trimIndent(), + *codegenScope + ) + } + } + } + } + + private fun serverRenderPrefixHeadersGenerator(writer: RustWriter, binding: HttpBindingDescriptor) { + val memberShape = binding.member + val memberType = model.expectShape(memberShape.target) + val memberSymbol = symbolProvider.toSymbol(memberShape) + val memberName = symbolProvider.toMemberName(memberShape) + val target = when (memberType) { + is CollectionShape -> model.expectShape(memberType.member.target) + is MapShape -> model.expectShape(memberType.value.target) + else -> UNREACHABLE("unexpected member for prefix headers: $memberType") + } + writer.ifSet(memberType, memberSymbol, "&output.$memberName") { field -> + writer.rustTemplate( + """ + for (k, v) in $field { + use std::str::FromStr; + let header_name = http::header::HeaderName::from_str( + &format!("{}{}", ${binding.locationName.dq()}, &k) + ) + .map_err(|err| { + #{SmithyHttpServer}::rejection::Serialize::from( + format!("{} cannot be used as a header name: {}", ${memberName.dq()}, err).as_str() + ) + })?; + use std::convert::TryFrom; + let header_value = ${headerFmtFun(writer, target, memberShape, "v")}; + let header_value = http::header::HeaderValue::try_from(header_value) + .map_err(|err| { + #{SmithyHttpServer}::rejection::Serialize::from( + format!("{} cannot be used as a header value: {}", ${memberName.dq()}, err).as_str() + ) + })?; + response = response.header(header_name, header_value); + } + """.trimIndent(), + *codegenScope + ) + } + } + + /** + * Format [member] in the when used as an HTTP header + */ + private fun headerFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String): String { + return when { + target.isStringShape -> { + if (target.hasTrait()) { + val func = writer.format(RuntimeType.Base64Encode(runtimeConfig)) + "$func(&$targetName)" + } else { + "AsRef::::as_ref($targetName)" + } + } + target.isTimestampShape -> { + val index = HttpBindingIndex.of(model) + val timestampFormat = + index.determineTimestampFormat(member, HttpBinding.Location.HEADER, protocol.defaultTimestampFormat) + val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) + "$targetName.fmt(${writer.format(timestampFormatType)})?" + } + target.isListShape || target.isMemberShape -> { + throw IllegalArgumentException("lists should be handled at a higher level") + } + target.isPrimitive() -> { + "encoder.encode()" + } + else -> throw CodegenException("unexpected shape: $target") + } + } + private fun generateParsePercentEncodedStrFn(binding: HttpBindingDescriptor): RuntimeType { // HTTP bindings we support that contain percent-encoded data. check(binding.location == HttpLocation.LABEL || binding.location == HttpLocation.QUERY) @@ -906,4 +1017,12 @@ private class ServerHttpProtocolImplGenerator( } } } + + private fun getStreamingBodyTraitBounds(operationShape: OperationShape): String { + if (operationShape.inputShape(model).hasStreamingMember(model)) { + return "B: Into<#{SmithyHttpServer}::ByteStream>," + } else { + return "" + } + } } diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index 816ddb973ef..4b697ab2f02 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -26,7 +26,7 @@ bytes = "1.1" futures-util = { version = "0.3", default-features = false } http = "0.2" http-body = "0.4" -hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp"] } +hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp", "stream"] } mime = "0.3" pin-project-lite = "0.2" regex = "1.0" diff --git a/rust-runtime/aws-smithy-http-server/src/lib.rs b/rust-runtime/aws-smithy-http-server/src/lib.rs index 95302b50270..f8f2a822622 100644 --- a/rust-runtime/aws-smithy-http-server/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/src/lib.rs @@ -30,6 +30,9 @@ pub use self::routing::Router; #[doc(inline)] pub use tower_http::add_extension::{AddExtension, AddExtensionLayer}; +#[doc(inline)] +pub use aws_smithy_http::byte_stream::ByteStream; + /// Alias for a type-erased error type. pub use axum_core::BoxError; diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index e350d2922d6..97cfad86fe0 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -186,6 +186,12 @@ impl From for SmithyRejection { } } +impl From for SmithyRejection { + fn from(err: aws_smithy_types::date_time::DateTimeFormatError) -> Self { + SmithyRejection::Serialize(Serialize::from_err(err)) + } +} + impl From for SmithyRejection { fn from(err: aws_smithy_types::primitive::PrimitiveParseError) -> Self { SmithyRejection::Deserialize(Deserialize::from_err(err)) diff --git a/rust-runtime/aws-smithy-http/src/byte_stream.rs b/rust-runtime/aws-smithy-http/src/byte_stream.rs index 00b8cecb00c..fdd7ec91243 100644 --- a/rust-runtime/aws-smithy-http/src/byte_stream.rs +++ b/rust-runtime/aws-smithy-http/src/byte_stream.rs @@ -326,6 +326,14 @@ impl From> for ByteStream { } } +impl From for ByteStream { + fn from(input: hyper::Body) -> Self { + ByteStream::new(SdkBody::from_dyn( + input.map_err(|e| e.into_cause().unwrap()).boxed(), + )) + } +} + #[derive(Debug)] pub struct Error(Box);