diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt index da1a7153240..e8f24471dc5 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt @@ -18,6 +18,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpPro import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.util.hasStreamingMember +import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.outputShape /** @@ -132,13 +134,18 @@ class ServerOperationHandlerGenerator( } else { symbolProvider.toSymbol(operation.outputShape(model)).fullName } + val streamingBodyTraitBounds = if (operation.inputShape(model).hasStreamingMember(model)) { + "\n B: Into<#{SmithyHttpServer}::ByteStream>," + } else { + "" + } return """ $inputFn Fut: std::future::Future + Send, - B: $serverCrate::HttpBody + Send + 'static, + B: $serverCrate::HttpBody + Send + 'static, $streamingBodyTraitBounds B::Data: Send, B::Error: Into<$serverCrate::BoxError>, $serverCrate::rejection::SmithyRejection: From<::Error> - """ + """.trimIndent() } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 87cab3c3403..c724694cc40 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -212,16 +212,16 @@ class ServerProtocolTestGenerator( rustTemplate( """ - ##[allow(unused_mut)] let mut http_request = http::Request::builder() - .uri("${httpRequestTestCase.uri}") - """, - *codegenScope + ##[allow(unused_mut)] let mut http_request = http::Request::builder() + .uri("${httpRequestTestCase.uri}") + """, + *codegenScope ) for (header in httpRequestTestCase.headers) { rust(".header(${header.key.dq()}, ${header.value.dq()})") } rustTemplate( - """ + """ .body(#{SmithyHttpServer}::Body::from(#{Bytes}::from_static(b${httpRequestTestCase.body.orNull()?.dq()}))) .unwrap(); """, @@ -363,9 +363,9 @@ class ServerProtocolTestGenerator( basicCheck( requireHeaders, rustWriter, - "required_headers", - actualExpression, - "require_headers" + "required_headers", + actualExpression, + "require_headers" ) } @@ -373,9 +373,9 @@ class ServerProtocolTestGenerator( basicCheck( forbidHeaders, rustWriter, - "forbidden_headers", + "forbidden_headers", actualExpression, - "forbid_headers" + "forbid_headers" ) } @@ -526,14 +526,14 @@ class ServerProtocolTestGenerator( FailingTest(RestJson, "RestJsonSupportsNegativeInfinityFloatInputs", Action.Response), FailingTest(RestJson, "RestJsonStreamingTraitsWithBlob", Action.Request), FailingTest(RestJson, "RestJsonStreamingTraitsWithNoBlobBody", Action.Request), - FailingTest(RestJson, "RestJsonStreamingTraitsWithBlob", Action.Response), - FailingTest(RestJson, "RestJsonStreamingTraitsWithNoBlobBody", Action.Response), + // FailingTest(RestJson, "RestJsonStreamingTraitsWithBlob", Action.Response), + // FailingTest(RestJson, "RestJsonStreamingTraitsWithNoBlobBody", Action.Response), FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Request), FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithNoBlobBody", Action.Request), FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Response), - FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithNoBlobBody", Action.Response), + // FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithNoBlobBody", Action.Response), FailingTest(RestJson, "RestJsonStreamingTraitsWithMediaTypeWithBlob", Action.Request), - FailingTest(RestJson, "RestJsonStreamingTraitsWithMediaTypeWithBlob", Action.Response), + // FailingTest(RestJson, "RestJsonStreamingTraitsWithMediaTypeWithBlob", Action.Response), FailingTest(RestJson, "RestJsonHttpWithEmptyBlobPayload", Action.Request), FailingTest(RestJson, "RestJsonHttpWithEmptyStructurePayload", Action.Request), @@ -604,8 +604,9 @@ class ServerProtocolTestGenerator( ).asObjectNode().get() ).build() private fun fixRestJsonAllQueryStringTypes(testCase: HttpRequestTestCase): HttpRequestTestCase = - testCase.toBuilder().params( - Node.parse("""{ + testCase.toBuilder().params( + Node.parse( + """{ "queryString": "Hello there", "queryStringList": ["a", "b", "c"], "queryStringSet": ["a", "b", "c"], @@ -644,8 +645,9 @@ class ServerProtocolTestGenerator( "Enum": ["Foo"], "EnumList": ["Foo", "Baz", "Bar"] } - }""".trimMargin()).asObjectNode().get() - ).build() + }""".trimMargin() + ).asObjectNode().get() + ).build() // This test assumes that errors in responses are identified by an `X-Amzn-Errortype` header with the error shape name. // However, Smithy specifications for AWS protocols that serialize to JSON recommend that new server implementations // serialize error types using a `__type` field in the body. diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt index cca36a0adf5..59bcb9c2ec9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt @@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.node.ExpectationNotMetException import software.amazon.smithy.model.shapes.CollectionShape @@ -55,6 +54,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredData import software.amazon.smithy.rust.codegen.util.UNREACHABLE import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectTrait +import software.amazon.smithy.rust.codegen.util.findStreamingMember import software.amazon.smithy.rust.codegen.util.getTrait import software.amazon.smithy.rust.codegen.util.hasStreamingMember import software.amazon.smithy.rust.codegen.util.inputShape @@ -132,13 +132,12 @@ private class ServerHttpProtocolImplGenerator( } /* - * Generation of `FromRequest` and `IntoResponse`. They are currently only implemented for non-streaming request - * and response bodies, that is, models without streaming traits - * (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html). - * For non-streaming request bodies, we require the HTTP body to be fully read in memory before parsing or - * deserialization. From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize + * Generation of `FromRequest` and `IntoResponse`. + * For non-streaming request bodies, that is, models without streaming traits + * (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html) + * we require the HTTP body to be fully read in memory before parsing or deserialization. + * From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize * an HTTP response to `Bytes`. - * TODO Add support for streaming. * These traits are the public entrypoint of the ser/de logic of the `aws-smithy-http-server` server. */ private fun RustWriter.renderTraits( @@ -147,38 +146,24 @@ private class ServerHttpProtocolImplGenerator( operationShape: OperationShape ) { val operationName = symbolProvider.toSymbol(operationShape).name - // Implement Axum `FromRequest` trait for input types. val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" - val fromRequest = if (operationShape.inputShape(model).hasStreamingMember(model)) { - // For streaming request bodies, we need to generate a different implementation of the `FromRequest` trait. - // It will first offer the streaming input to the parser and potentially read the body into memory - // if an error occurred or if the streaming parser indicates that it needs the full data to proceed. - """ - async fn from_request(_req: &mut #{AxumCore}::extract::RequestParts) -> Result { - todo!("Streaming support for input shapes is not yet supported in `smithy-rs`") - } - """.trimIndent() - } else { - """ - async fn from_request(req: &mut #{AxumCore}::extract::RequestParts) -> Result { - Ok($inputName(#{parse_request}(req).await?)) - } - """.trimIndent() - } + // Implement Axum `FromRequest` trait for input types. rustTemplate( """ pub struct $inputName(pub #{I}); ##[#{AsyncTrait}::async_trait] impl #{AxumCore}::extract::FromRequest for $inputName where - B: #{SmithyHttpServer}::HttpBody + Send, + B: #{SmithyHttpServer}::HttpBody + Send, ${getStreamingBodyTraitBounds(operationShape)} B::Data: Send, B::Error: Into<#{SmithyHttpServer}::BoxError>, #{SmithyRejection}: From<::Error> { type Rejection = #{SmithyRejection}; - $fromRequest + async fn from_request(req: &mut #{AxumCore}::extract::RequestParts) -> Result { + Ok($inputName(#{parse_request}(req).await?)) + } } """.trimIndent(), *codegenScope, @@ -187,21 +172,19 @@ private class ServerHttpProtocolImplGenerator( ) // Implement Axum `IntoResponse` for output types. + val outputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" val errorSymbol = operationShape.errorSymbol(symbolProvider) - val httpExtensions = setHttpExtensions(operationShape) - // For streaming response bodies, we need to generate a different implementation of the `IntoResponse` trait. - // The body type will have to be a `StreamBody`. The service implementer will return a `Stream` from their handler. - val intoResponseStreaming = "todo!(\"Streaming support for output shapes is not yet supported in `smithy-rs`\")" + if (operationShape.errors.isNotEmpty()) { - val intoResponseImpl = if (operationShape.outputShape(model).hasStreamingMember(model)) { - intoResponseStreaming - } else { + // The output of fallible operations is a `Result` which we convert into an + // isomorphic `enum` type we control that can in turn be converted into a response. + val intoResponseImpl = """ let mut response = match self { Self::Output(o) => { - match #{serialize_response}(&o) { + match #{serialize_response}(o) { Ok(response) => response, Err(e) => { e.into_response() @@ -223,9 +206,7 @@ private class ServerHttpProtocolImplGenerator( $httpExtensions response """.trimIndent() - } - // The output of fallible operations is a `Result` which we convert into an isomorphic `enum` type we control - // that can in turn be converted into a response. + rustTemplate( """ pub enum $outputName { @@ -246,27 +227,25 @@ private class ServerHttpProtocolImplGenerator( "serialize_error" to serverSerializeError(operationShape) ) } else { - val handleSerializeOutput = if (operationShape.outputShape(model).hasStreamingMember(model)) { - intoResponseStreaming - } else { + // The output of non-fallible operations is a model type which we convert into + // a "wrapper" unit `struct` type we control that can in turn be converted into a response. + val intoResponseImpl = """ - let mut response = match #{serialize_response}(&self.0) { + let mut response = match #{serialize_response}(self.0) { Ok(response) => response, Err(e) => e.into_response() }; $httpExtensions response """.trimIndent() - } - // The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type - // we control that can in turn be converted into a response. + rustTemplate( """ pub struct $outputName(pub #{O}); ##[#{AsyncTrait}::async_trait] impl #{AxumCore}::response::IntoResponse for $outputName { fn into_response(self) -> #{AxumCore}::response::Response { - $handleSerializeOutput + $intoResponseImpl } } """.trimIndent(), @@ -335,6 +314,7 @@ private class ServerHttpProtocolImplGenerator( val inputSymbol = symbolProvider.toSymbol(inputShape) val includedMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) val unusedVars = if (includedMembers.isEmpty()) "##[allow(unused_variables)] " else "" + return RuntimeType.forInlineFun(fnName, operationDeserModule) { Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it) it.rustBlockTemplate( @@ -346,11 +326,11 @@ private class ServerHttpProtocolImplGenerator( #{SmithyRejection} > where - B: #{SmithyHttpServer}::HttpBody + Send, + B: #{SmithyHttpServer}::HttpBody + Send, ${getStreamingBodyTraitBounds(operationShape)} B::Data: Send, B::Error: Into<#{SmithyHttpServer}::BoxError>, #{SmithyRejection}: From<::Error> - """, + """.trimIndent(), *codegenScope, "I" to inputSymbol, ) { @@ -371,8 +351,12 @@ private class ServerHttpProtocolImplGenerator( val outputSymbol = symbolProvider.toSymbol(outputShape) return RuntimeType.forInlineFun(fnName, operationSerModule) { Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it) + + // Note we only need to take ownership of the output in the case that it contains streaming members. + // However we currently always take ownership here, but worth noting in case in the future we want + // to generate different signatures for streaming vs non-streaming for some reason. it.rustBlockTemplate( - "pub fn $fnName(output: &#{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>", + "pub fn $fnName(output: #{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>", *codegenScope, "O" to outputSymbol, ) { @@ -459,13 +443,6 @@ private class ServerHttpProtocolImplGenerator( operationShape: OperationShape, bindings: List, ) { - val structuredDataSerializer = protocol.structuredDataSerializer(operationShape) - structuredDataSerializer.serverOutputSerializer(operationShape)?.let { serializer -> - rust( - "let payload = #T(output)?;", - serializer - ) - } ?: rust("""let payload = "";""") // avoid non-usage warnings for response Attribute.AllowUnusedMut.render(this) rustTemplate("let mut builder = #{http}::Response::builder();", *codegenScope) @@ -477,6 +454,24 @@ private class ServerHttpProtocolImplGenerator( serializedValue(this) } } + val streamingMember = operationShape.outputShape(model).findStreamingMember(model) + if (streamingMember != null) { + val memberName = symbolProvider.toMemberName(streamingMember) + rustTemplate( + """ + let payload = #{SmithyHttpServer}::body::Body::wrap_stream(output.$memberName); + """, + *codegenScope, + ) + } else { + val structuredDataSerializer = protocol.structuredDataSerializer(operationShape) + structuredDataSerializer.serverOutputSerializer(operationShape)?.let { serializer -> + rust( + "let payload = #T(&output)?;", + serializer + ) + } ?: rust("""let payload = "";""") + } rustTemplate( """ builder.body(#{SmithyHttpServer}::body::to_boxed(payload))? @@ -510,11 +505,16 @@ private class ServerHttpProtocolImplGenerator( } val bindingGenerator = ServerResponseBindingGenerator(protocol, codegenContext, operationShape) - val addHeadersFn = bindingGenerator.generateAddHeadersFn(errorShape?: operationShape) + val addHeadersFn = bindingGenerator.generateAddHeadersFn(errorShape ?: operationShape) if (addHeadersFn != null) { + // we need to allow needless_borrow here because some of the occurrences + // of this function require borrowing, and some do not. + Attribute.Custom("allow(clippy::needless_borrow)").render(this) rust( """ - builder = #{T}(output, builder)?; + { + builder = #{T}(&output, builder)?; + } """.trimIndent(), addHeadersFn ) @@ -528,12 +528,11 @@ private class ServerHttpProtocolImplGenerator( val operationName = symbolProvider.toSymbol(operationShape).name val member = binding.member return when (binding.location) { - HttpLocation.HEADER, HttpLocation.PREFIX_HEADERS, HttpLocation.DOCUMENT -> { - // All of these are handled separately. - null - } + HttpLocation.HEADER, + HttpLocation.PREFIX_HEADERS, + HttpLocation.DOCUMENT, HttpLocation.PAYLOAD -> { - logger.warning("[rust-server-codegen] $operationName: response serialization does not currently support ${binding.location} bindings") + // All of these are handled separately. null } HttpLocation.RESPONSE_CODE -> writable { @@ -562,6 +561,7 @@ private class ServerHttpProtocolImplGenerator( ) { val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) val structuredDataParser = protocol.structuredDataParser(operationShape) + val streamingMember = inputShape.findStreamingMember(model) Attribute.AllowUnusedMut.render(this) rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider)) val parser = structuredDataParser.serverInputParser(operationShape) @@ -579,13 +579,23 @@ private class ServerHttpProtocolImplGenerator( *codegenScope, "parser" to parser, ) + } else if (streamingMember != null) { + rustTemplate( + """ + let body = request.take_body().ok_or(#{SmithyHttpServer}::rejection::BodyAlreadyExtracted)?; + input = input.${streamingMember.setterName()}(Some(body.into())); + """.trimIndent(), + *codegenScope + ) } - for (binding in bindings) { - val member = binding.member - val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) - if (parsedValue != null) { - withBlock("input = input.${member.setterName()}(", ");") { - parsedValue(this) + if (streamingMember == null) { + for (binding in bindings) { + val member = binding.member + val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) + if (parsedValue != null) { + withBlock("input = input.${member.setterName()}(", ");") { + parsedValue(this) + } } } } @@ -1047,4 +1057,12 @@ private class ServerHttpProtocolImplGenerator( } } } + + private fun getStreamingBodyTraitBounds(operationShape: OperationShape): String { + if (operationShape.inputShape(model).hasStreamingMember(model)) { + return "\n B: Into<#{SmithyHttpServer}::ByteStream>," + } else { + return "" + } + } } diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index 9606e1787e5..5afeaa1e4a5 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -26,7 +26,7 @@ bytes = "1.1" futures-util = { version = "0.3", default-features = false } http = "0.2" http-body = "0.4" -hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp"] } +hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp", "stream"] } mime = "0.3" nom = "7" pin-project-lite = "0.2" diff --git a/rust-runtime/aws-smithy-http-server/src/lib.rs b/rust-runtime/aws-smithy-http-server/src/lib.rs index 2e604645c4d..6e2d9b031b6 100644 --- a/rust-runtime/aws-smithy-http-server/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/src/lib.rs @@ -30,6 +30,9 @@ pub use self::routing::Router; #[doc(inline)] pub use tower_http::add_extension::{AddExtension, AddExtensionLayer}; +#[doc(inline)] +pub use aws_smithy_http::byte_stream::ByteStream; + /// Alias for a type-erased error type. pub use axum_core::BoxError; diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index 6af2ff17129..2787bcab2cf 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -178,6 +178,12 @@ impl From for SmithyRejection { } } +impl From for SmithyRejection { + fn from(err: aws_smithy_types::date_time::DateTimeFormatError) -> Self { + SmithyRejection::Serialize(Serialize::from_err(err)) + } +} + impl From for SmithyRejection { fn from(err: aws_smithy_types::primitive::PrimitiveParseError) -> Self { SmithyRejection::Deserialize(Deserialize::from_err(err)) diff --git a/rust-runtime/aws-smithy-http/src/byte_stream.rs b/rust-runtime/aws-smithy-http/src/byte_stream.rs index 00b8cecb00c..fdd7ec91243 100644 --- a/rust-runtime/aws-smithy-http/src/byte_stream.rs +++ b/rust-runtime/aws-smithy-http/src/byte_stream.rs @@ -326,6 +326,14 @@ impl From> for ByteStream { } } +impl From for ByteStream { + fn from(input: hyper::Body) -> Self { + ByteStream::new(SdkBody::from_dyn( + input.map_err(|e| e.into_cause().unwrap()).boxed(), + )) + } +} + #[derive(Debug)] pub struct Error(Box);