diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt index da1a7153240..073cafb91c6 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt @@ -18,6 +18,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpPro import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.util.hasStreamingMember +import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.outputShape /** @@ -39,6 +41,7 @@ class ServerOperationHandlerGenerator( "PinProjectLite" to ServerCargoDependency.PinProjectLite.asType(), "Tower" to ServerCargoDependency.Tower.asType(), "FuturesUtil" to ServerCargoDependency.FuturesUtil.asType(), + "SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType(), "SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(), "SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig), "Phantom" to ServerRuntimeType.Phantom, @@ -132,13 +135,19 @@ class ServerOperationHandlerGenerator( } else { symbolProvider.toSymbol(operation.outputShape(model)).fullName } + val streamingBodyTraitBounds = if (operation.inputShape(model).hasStreamingMember(model)) { + "B: Into<#{SmithyHttp}::byte_stream::ByteStream>," + } else { + "" + } return """ $inputFn Fut: std::future::Future + Send, B: $serverCrate::HttpBody + Send + 'static, + $streamingBodyTraitBounds B::Data: Send, B::Error: Into<$serverCrate::BoxError>, $serverCrate::rejection::SmithyRejection: From<::Error> - """ + """.trimIndent() } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt index 30225f36a15..db297645dea 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt @@ -7,15 +7,20 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.node.ExpectationNotMetException import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.HttpErrorTrait +import software.amazon.smithy.model.traits.MediaTypeTrait import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.RustModule @@ -23,6 +28,7 @@ import software.amazon.smithy.rust.codegen.rustlang.RustType import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.asType +import software.amazon.smithy.rust.codegen.rustlang.autoDeref import software.amazon.smithy.rust.codegen.rustlang.render import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock @@ -47,11 +53,15 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescripto import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolBodyGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.util.UNREACHABLE import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectTrait +import software.amazon.smithy.rust.codegen.util.findStreamingMember import software.amazon.smithy.rust.codegen.util.getTrait import software.amazon.smithy.rust.codegen.util.hasStreamingMember +import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.inputShape +import software.amazon.smithy.rust.codegen.util.isPrimitive import software.amazon.smithy.rust.codegen.util.outputShape import software.amazon.smithy.rust.codegen.util.toSnakeCase import java.util.logging.Logger @@ -99,6 +109,7 @@ private class ServerHttpProtocolImplGenerator( val httpBindingResolver = protocol.httpBindingResolver private val operationDeserModule = RustModule.private("operation_deser") private val operationSerModule = RustModule.private("operation_ser") + private val Encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder") private val codegenScope = arrayOf( "AsyncTrait" to ServerCargoDependency.AsyncTrait.asType(), @@ -110,6 +121,7 @@ private class ServerHttpProtocolImplGenerator( "PercentEncoding" to CargoDependency.PercentEncoding.asType(), "Regex" to CargoDependency.Regex.asType(), "SerdeUrlEncoded" to ServerCargoDependency.SerdeUrlEncoded.asType(), + "SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType(), "SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(), "SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig), "http" to RuntimeType.http, @@ -123,13 +135,12 @@ private class ServerHttpProtocolImplGenerator( } /* - * Generation of `FromRequest` and `IntoResponse`. They are currently only implemented for non-streaming request - * and response bodies, that is, models without streaming traits - * (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html). - * For non-streaming request bodies, we require the HTTP body to be fully read in memory before parsing or - * deserialization. From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize + * Generation of `FromRequest` and `IntoResponse`. + * For non-streaming request bodies, that is, models without streaming traits + * (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html) + * we require the HTTP body to be fully read in memory before parsing or deserialization. + * From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize * an HTTP response to `Bytes`. - * TODO Add support for streaming. * These traits are the public entrypoint of the ser/de logic of the `aws-smithy-http-server` server. */ private fun RustWriter.renderTraits( @@ -138,25 +149,9 @@ private class ServerHttpProtocolImplGenerator( operationShape: OperationShape ) { val operationName = symbolProvider.toSymbol(operationShape).name - // Implement Axum `FromRequest` trait for input types. val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" - val fromRequest = if (operationShape.inputShape(model).hasStreamingMember(model)) { - // For streaming request bodies, we need to generate a different implementation of the `FromRequest` trait. - // It will first offer the streaming input to the parser and potentially read the body into memory - // if an error occurred or if the streaming parser indicates that it needs the full data to proceed. - """ - async fn from_request(_req: &mut #{AxumCore}::extract::RequestParts) -> Result { - todo!("Streaming support for input shapes is not yet supported in `smithy-rs`") - } - """.trimIndent() - } else { - """ - async fn from_request(req: &mut #{AxumCore}::extract::RequestParts) -> Result { - Ok($inputName(#{parse_request}(req).await?)) - } - """.trimIndent() - } + // Implement Axum `FromRequest` trait for input types. rustTemplate( """ pub struct $inputName(pub #{I}); @@ -164,12 +159,15 @@ private class ServerHttpProtocolImplGenerator( impl #{AxumCore}::extract::FromRequest for $inputName where B: #{SmithyHttpServer}::HttpBody + Send, + ${getStreamingBodyTraitBounds(operationShape)} B::Data: Send, B::Error: Into<#{SmithyHttpServer}::BoxError>, #{SmithyRejection}: From<::Error> { type Rejection = #{SmithyRejection}; - $fromRequest + async fn from_request(req: &mut #{AxumCore}::extract::RequestParts) -> Result { + Ok($inputName(#{parse_request}(req).await?)) + } } """.trimIndent(), *codegenScope, @@ -178,21 +176,20 @@ private class ServerHttpProtocolImplGenerator( ) // Implement Axum `IntoResponse` for output types. + // For streaming response bodies, the service implementer will return a `Stream` from their handler. + val outputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" val errorSymbol = operationShape.errorSymbol(symbolProvider) - val httpExtensions = setHttpExtensions(operationShape) - // For streaming response bodies, we need to generate a different implementation of the `IntoResponse` trait. - // The body type will have to be a `StreamBody`. The service implementer will return a `Stream` from their handler. - val intoResponseStreaming = "todo!(\"Streaming support for output shapes is not yet supported in `smithy-rs`\")" + if (operationShape.errors.isNotEmpty()) { - val intoResponseImpl = if (operationShape.outputShape(model).hasStreamingMember(model)) { - intoResponseStreaming - } else { + // The output of fallible operations is a `Result` which we convert into an + // isomorphic `enum` type we control that can in turn be converted into a response. + val intoResponseImpl = """ let mut response = match self { Self::Output(o) => { - match #{serialize_response}(&o) { + match #{serialize_response}(o) { Ok(response) => response, Err(e) => { e.into_response() @@ -214,7 +211,7 @@ private class ServerHttpProtocolImplGenerator( $httpExtensions response """.trimIndent() - } + // The output of fallible operations is a `Result` which we convert into an isomorphic `enum` type we control // that can in turn be converted into a response. rustTemplate( @@ -237,25 +234,18 @@ private class ServerHttpProtocolImplGenerator( "serialize_error" to serverSerializeError(operationShape) ) } else { - val handleSerializeOutput = if (operationShape.outputShape(model).hasStreamingMember(model)) { - intoResponseStreaming - } else { - """ - match #{serialize_response}(&self.0) { - Ok(response) => response, - Err(e) => e.into_response() - } - """.trimIndent() - } - // The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type - // we control that can in turn be converted into a response. + // The output of non-fallible operations is a model type which we convert into + // a "wrapper" unit `struct` type we control that can in turn be converted into a response. rustTemplate( """ pub struct $outputName(pub #{O}); ##[#{AsyncTrait}::async_trait] impl #{AxumCore}::response::IntoResponse for $outputName { fn into_response(self) -> #{AxumCore}::response::Response { - $handleSerializeOutput + match #{serialize_response}(self.0) { + Ok(response) => response, + Err(e) => e.into_response() + } } } """.trimIndent(), @@ -324,6 +314,7 @@ private class ServerHttpProtocolImplGenerator( val inputSymbol = symbolProvider.toSymbol(inputShape) val includedMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) val unusedVars = if (includedMembers.isEmpty()) "##[allow(unused_variables)] " else "" + return RuntimeType.forInlineFun(fnName, operationDeserModule) { Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it) it.rustBlockTemplate( @@ -336,10 +327,11 @@ private class ServerHttpProtocolImplGenerator( > where B: #{SmithyHttpServer}::HttpBody + Send, + ${getStreamingBodyTraitBounds(operationShape)} B::Data: Send, B::Error: Into<#{SmithyHttpServer}::BoxError>, #{SmithyRejection}: From<::Error> - """, + """.trimIndent(), *codegenScope, "I" to inputSymbol, ) { @@ -361,7 +353,7 @@ private class ServerHttpProtocolImplGenerator( return RuntimeType.forInlineFun(fnName, operationSerModule) { Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it) it.rustBlockTemplate( - "pub fn $fnName(output: &#{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>", + "pub fn $fnName(output: #{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>", *codegenScope, "O" to outputSymbol, ) { @@ -444,13 +436,6 @@ private class ServerHttpProtocolImplGenerator( operationShape: OperationShape, bindings: List, ) { - val structuredDataSerializer = protocol.structuredDataSerializer(operationShape) - structuredDataSerializer.serverOutputSerializer(operationShape).also { serializer -> - rust( - "let payload = #T(output)?;", - serializer - ) - } // avoid non-usage warnings for response Attribute.AllowUnusedMut.render(this) rustTemplate("let mut response = #{http}::Response::builder();", *codegenScope) @@ -460,6 +445,24 @@ private class ServerHttpProtocolImplGenerator( serializedValue(this) } } + val streamingMember = operationShape.outputShape(model).findStreamingMember(model) + if (streamingMember != null) { + val memberName = symbolProvider.toMemberName(streamingMember) + rustTemplate( + """ + let payload = #{SmithyHttpServer}::body::Body::wrap_stream(output.$memberName); + """, + *codegenScope, + ) + } else { + val structuredDataSerializer = protocol.structuredDataSerializer(operationShape) + structuredDataSerializer.serverOutputSerializer(operationShape).also { serializer -> + rust( + "let payload = #T(&output)?;", + serializer + ) + } + } rustTemplate( """ response.body(#{SmithyHttpServer}::body::to_boxed(payload))? @@ -475,14 +478,10 @@ private class ServerHttpProtocolImplGenerator( val operationName = symbolProvider.toSymbol(operationShape).name val member = binding.member return when (binding.location) { - HttpLocation.HEADER, HttpLocation.PREFIX_HEADERS, HttpLocation.PAYLOAD -> { - logger.warning("[rust-server-codegen] $operationName: response serialization does not currently support ${binding.location} bindings") - null - } - HttpLocation.DOCUMENT -> { - // document is handled separately - null - } + HttpLocation.HEADER -> writable { serverRenderHeaderGenerator(this, binding) } + HttpLocation.PREFIX_HEADERS -> writable { serverRenderPrefixHeadersGenerator(this, binding) } + HttpLocation.PAYLOAD -> { null } // payload is handled separately + HttpLocation.DOCUMENT -> { null } // document is handled separately HttpLocation.RESPONSE_CODE -> writable { val memberName = symbolProvider.toMemberName(member) rustTemplate( @@ -508,6 +507,7 @@ private class ServerHttpProtocolImplGenerator( bindings: List, ) { val structuredDataParser = protocol.structuredDataParser(operationShape) + val streamingMember = inputShape.findStreamingMember(model) Attribute.AllowUnusedMut.render(this) rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider)) val parser = structuredDataParser.serverInputParser(operationShape) @@ -525,6 +525,15 @@ private class ServerHttpProtocolImplGenerator( *codegenScope, "parser" to parser, ) + } else if (streamingMember != null) { + // TODO Need contentTypeCheck here? + rustTemplate( + """ + let body = request.take_body().ok_or(#{SmithyHttpServer}::rejection::BodyAlreadyExtracted)?; + input = input.${streamingMember.setterName()}(Some(body.into())); + """.trimIndent(), + *codegenScope + ) } for (binding in bindings) { val member = binding.member @@ -801,6 +810,109 @@ private class ServerHttpProtocolImplGenerator( ) } + private fun serverRenderHeaderGenerator(writer: RustWriter, binding: HttpBindingDescriptor) { + val memberShape = binding.member + val memberType = model.expectShape(memberShape.target) + val memberSymbol = symbolProvider.toSymbol(memberShape) + val memberName = symbolProvider.toMemberName(memberShape) + writer.ifSet(memberType, memberSymbol, "&output.$memberName") { field -> + writer.listForEach(memberType, field) { innerField, targetId -> + val innerMemberType = model.expectShape(targetId) + if (innerMemberType.isPrimitive()) { + rust("let mut encoder = #T::from(${autoDeref(innerField)});", Encoder) + } + val formatted = headerFmtFun(writer, innerMemberType, memberShape, innerField) + val safeName = safeName("formatted") + writer.write("let $safeName = $formatted;") + writer.rustBlock("if !$safeName.is_empty()") { + writer.rustTemplate( + """ + use std::convert::TryFrom; + let header_value = $safeName; + let header_value = http::header::HeaderValue::try_from(&*header_value) + .map_err(|err| { + #{SmithyHttpServer}::rejection::Serialize::from( + ${(memberName + " cannot be used as a header value").dq()} + ) + })?; + response = response.header(${binding.locationName.dq()}, header_value); + """.trimIndent(), + *codegenScope + ) + } + } + } + } + + private fun serverRenderPrefixHeadersGenerator(writer: RustWriter, binding: HttpBindingDescriptor) { + val memberShape = binding.member + val memberType = model.expectShape(memberShape.target) + val memberSymbol = symbolProvider.toSymbol(memberShape) + val memberName = symbolProvider.toMemberName(memberShape) + val target = when (memberType) { + is CollectionShape -> model.expectShape(memberType.member.target) + is MapShape -> model.expectShape(memberType.value.target) + else -> UNREACHABLE("unexpected member for prefix headers: $memberType") + } + writer.ifSet(memberType, memberSymbol, "&output.$memberName") { field -> + writer.rustTemplate( + """ + for (k, v) in $field { + use std::str::FromStr; + let header_name = http::header::HeaderName::from_str( + &format!("{}{}", ${binding.locationName.dq()}, &k) + ) + .map_err(|err| { + #{SmithyHttpServer}::rejection::Serialize::from( + ${(memberName + " cannot be used as a header name").dq()} + ) + })?; + use std::convert::TryFrom; + let header_value = ${headerFmtFun(writer, target, memberShape, "v")}; + let header_value = http::header::HeaderValue::try_from(header_value) + .map_err(|err| { + #{SmithyHttpServer}::rejection::Serialize::from( + ${(memberName + " cannot be used as a header value").dq()} + ) + })?; + response = response.header(header_name, header_value); + } + """.trimIndent(), + *codegenScope + ) + } + } + + /** + * Format [member] in the when used as an HTTP header + */ + private fun headerFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String): String { + return when { + target.isStringShape -> { + if (target.hasTrait()) { + val func = writer.format(RuntimeType.Base64Encode(runtimeConfig)) + "$func(&$targetName)" + } else { + "AsRef::::as_ref($targetName)" + } + } + target.isTimestampShape -> { + val index = HttpBindingIndex.of(model) + val timestampFormat = + index.determineTimestampFormat(member, HttpBinding.Location.HEADER, protocol.defaultTimestampFormat) + val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) + "$targetName.fmt(${writer.format(timestampFormatType)})?" + } + target.isListShape || target.isMemberShape -> { + throw IllegalArgumentException("lists should be handled at a higher level") + } + target.isPrimitive() -> { + "encoder.encode()" + } + else -> throw CodegenException("unexpected shape: $target") + } + } + private fun generateParsePercentEncodedStrFn(binding: HttpBindingDescriptor): RuntimeType { // HTTP bindings we support that contain percent-encoded data. check(binding.location == HttpLocation.LABEL || binding.location == HttpLocation.QUERY) @@ -906,4 +1018,12 @@ private class ServerHttpProtocolImplGenerator( } } } + + private fun getStreamingBodyTraitBounds(operationShape: OperationShape): String { + if (operationShape.inputShape(model).hasStreamingMember(model)) { + return "B: Into<#{SmithyHttp}::byte_stream::ByteStream>," + } else { + return "" + } + } } diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index e350d2922d6..97cfad86fe0 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -186,6 +186,12 @@ impl From for SmithyRejection { } } +impl From for SmithyRejection { + fn from(err: aws_smithy_types::date_time::DateTimeFormatError) -> Self { + SmithyRejection::Serialize(Serialize::from_err(err)) + } +} + impl From for SmithyRejection { fn from(err: aws_smithy_types::primitive::PrimitiveParseError) -> Self { SmithyRejection::Deserialize(Deserialize::from_err(err)) diff --git a/rust-runtime/aws-smithy-http/src/byte_stream.rs b/rust-runtime/aws-smithy-http/src/byte_stream.rs index 00b8cecb00c..fdd7ec91243 100644 --- a/rust-runtime/aws-smithy-http/src/byte_stream.rs +++ b/rust-runtime/aws-smithy-http/src/byte_stream.rs @@ -326,6 +326,14 @@ impl From> for ByteStream { } } +impl From for ByteStream { + fn from(input: hyper::Body) -> Self { + ByteStream::new(SdkBody::from_dyn( + input.map_err(|e| e.into_cause().unwrap()).boxed(), + )) + } +} + #[derive(Debug)] pub struct Error(Box);