diff --git a/aws/rust-runtime/aws-inlineable/src/http_response_checksum.rs b/aws/rust-runtime/aws-inlineable/src/http_response_checksum.rs index 2b3c622f24..7c59b6899f 100644 --- a/aws/rust-runtime/aws-inlineable/src/http_response_checksum.rs +++ b/aws/rust-runtime/aws-inlineable/src/http_response_checksum.rs @@ -10,7 +10,7 @@ use aws_smithy_checksums::ChecksumAlgorithm; use aws_smithy_runtime_api::box_error::BoxError; use aws_smithy_runtime_api::client::interceptors::context::{ - BeforeDeserializationInterceptorContextMut, BeforeSerializationInterceptorContextRef, Input, + BeforeDeserializationInterceptorContextMut, BeforeSerializationInterceptorContextMut, Input, }; use aws_smithy_runtime_api::client::interceptors::Intercept; use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; @@ -28,12 +28,13 @@ impl Storable for ResponseChecksumInterceptorState { type Storer = StoreReplace; } -pub(crate) struct ResponseChecksumInterceptor { +pub(crate) struct ResponseChecksumInterceptor { response_algorithms: &'static [&'static str], validation_enabled: VE, + checksum_mutator: CM, } -impl fmt::Debug for ResponseChecksumInterceptor { +impl fmt::Debug for ResponseChecksumInterceptor { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ResponseChecksumInterceptor") .field("response_algorithms", &self.response_algorithms) @@ -41,32 +42,36 @@ impl fmt::Debug for ResponseChecksumInterceptor { } } -impl ResponseChecksumInterceptor { +impl ResponseChecksumInterceptor { pub(crate) fn new( response_algorithms: &'static [&'static str], validation_enabled: VE, + checksum_mutator: CM, ) -> Self { Self { response_algorithms, validation_enabled, + checksum_mutator, } } } -impl Intercept for ResponseChecksumInterceptor +impl Intercept for ResponseChecksumInterceptor where VE: Fn(&Input) -> bool + Send + Sync, + CM: Fn(&mut Input, &ConfigBag) -> Result<(), BoxError> + Send + Sync, { fn name(&self) -> &'static str { "ResponseChecksumInterceptor" } - fn read_before_serialization( + fn modify_before_serialization( &self, - context: &BeforeSerializationInterceptorContextRef<'_>, + context: &mut BeforeSerializationInterceptorContextMut<'_>, _runtime_components: &RuntimeComponents, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { + let _ = (self.checksum_mutator)(context.input_mut(), cfg); let validation_enabled = (self.validation_enabled)(context.input()); let mut layer = Layer::new("ResponseChecksumInterceptor"); diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt index ed47f1a452..f27ca2c45b 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt @@ -44,7 +44,6 @@ val DECORATORS: List = SigV4AuthDecorator(), HttpRequestChecksumDecorator(), HttpResponseChecksumDecorator(), - HttpResponseChecksumMutationInterceptorDecorator(), IntegrationTestDecorator(), AwsFluentClientDecorator(), CrateLicenseDecorator(), diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpRequestChecksumDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpRequestChecksumDecorator.kt index ffd7b81062..ee14bde53f 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpRequestChecksumDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpRequestChecksumDecorator.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.rustsdk import software.amazon.smithy.aws.traits.HttpChecksumTrait import software.amazon.smithy.model.knowledge.TopDownIndex +import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.configReexport @@ -347,3 +348,15 @@ fun serviceHasHttpChecksumOperation(codegenContext: ClientCodegenContext): Boole val ops = index.getContainedOperations(codegenContext.serviceShape.id) return ops.any { it.hasTrait() } } + +/** + * Get the top-level operation input member used to opt-in to best-effort validation of a checksum returned in + * the HTTP response of the operation. + */ +fun HttpChecksumTrait.requestAlgorithmMemberShape( + codegenContext: ClientCodegenContext, + operationShape: OperationShape, +): MemberShape? { + val requestAlgorithmMember = this.requestAlgorithmMember.orNull() ?: return null + return operationShape.inputShape(codegenContext.model).expectMember(requestAlgorithmMember) +} diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpResponseChecksumDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpResponseChecksumDecorator.kt index 98c4b37886..0f0ef550f9 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpResponseChecksumDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpResponseChecksumDecorator.kt @@ -120,6 +120,7 @@ class HttpResponseChecksumCustomization( val checksumTrait = operationShape.getTrait() ?: return@writable val requestValidationModeMember = checksumTrait.requestValidationModeMember(codegenContext, operationShape) ?: return@writable + val requestValidationModeName = codegenContext.symbolProvider.toSymbol(requestValidationModeMember).name val requestValidationModeMemberInner = if (requestValidationModeMember.isOptional) { codegenContext.model.expectShape(requestValidationModeMember.target) @@ -150,6 +151,34 @@ class HttpResponseChecksumCustomization( } let input: &#{OperationInput} = input.downcast_ref().expect("correct type"); matches!(input.$validationModeName(), #{Some}(#{ValidationModeShape}::Enabled)) + }, + |input: &mut #{Input}, cfg: &#{ConfigBag}| { + let input = input + .downcast_mut::<#{OperationInputType}>() + .ok_or("failed to downcast to #{OperationInputType}")?; + + let request_validation_enabled = + matches!(input.$requestValidationModeName(), Some(#{ValidationModeShape}::Enabled)); + + if !request_validation_enabled { + // This value is set by the user on the SdkConfig to indicate their preference + let response_checksum_validation = cfg + .load::<#{ResponseChecksumValidation}>() + .unwrap_or(&#{ResponseChecksumValidation}::WhenSupported); + + // If validation setting is WhenSupported (or unknown) we enable response checksum + // validation. If it is WhenRequired we do not enable (since there is no way to + // indicate that a response checksum is required). + ##[allow(clippy::wildcard_in_or_patterns)] + match response_checksum_validation { + #{ResponseChecksumValidation}::WhenRequired => {} + #{ResponseChecksumValidation}::WhenSupported | _ => { + input.$requestValidationModeName = Some(#{ValidationModeShape}::Enabled); + } + } + } + + #{Ok}(()) } ) """, @@ -163,14 +192,20 @@ class HttpResponseChecksumCustomization( codegenContext.symbolProvider.toSymbol( requestValidationModeMemberInner, ), - ) - } - section.registerInterceptor(codegenContext.runtimeConfig, this) { - val interceptorName = "${operationName}HttpResponseChecksumMutationInterceptor" - rustTemplate( - """ - $interceptorName - """, + "OperationInputType" to + codegenContext.symbolProvider.toSymbol( + operationShape.inputShape( + codegenContext.model, + ), + ), + "ValidationModeShape" to + codegenContext.symbolProvider.toSymbol( + requestValidationModeMemberInner, + ), + "ResponseChecksumValidation" to + CargoDependency.smithyTypes(codegenContext.runtimeConfig).toType() + .resolve("checksum_config::ResponseChecksumValidation"), + "ConfigBag" to RuntimeType.configBag(codegenContext.runtimeConfig), ) } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpResponseChecksumMutationInterceptorGenerator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpResponseChecksumMutationInterceptorGenerator.kt deleted file mode 100644 index 413421fa06..0000000000 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpResponseChecksumMutationInterceptorGenerator.kt +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rustsdk - -import software.amazon.smithy.aws.traits.HttpChecksumTrait -import software.amazon.smithy.model.shapes.MemberShape -import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator -import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCustomization -import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection -import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency -import software.amazon.smithy.rust.codegen.core.rustlang.Writable -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope -import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.expectMember -import software.amazon.smithy.rust.codegen.core.util.getTrait -import software.amazon.smithy.rust.codegen.core.util.inputShape -import software.amazon.smithy.rust.codegen.core.util.orNull - -/** - * This class generates an interceptor for operations with the `httpChecksum` trait that support response validations. - * In the `modify_before_serialization` hook the interceptor checks the operation's `requestValidationModeMember`. If - * that member is `ENABLED` then we end early and do nothing. If it is not `ENABLED` then it checks the - * `response_checksum_validation` set by the user on the SdkConfig. If that is `WhenSupported` (or unknown) then we - * update the `requestValidationModeMember` to `ENABLED` and if the value is `WhenRequired` we end without modifying - * anything since there is no way to indicate that a response checksum is required. - * - * Note that although there is an existing inlineable `ResponseChecksumInterceptor` this logic could not live there. - * Since that interceptor is inlineable it does not have access to the name of the `requestValidationModeMember` on the - * operation's input, and in certain circumstances we need to mutate that member on the input before serializing the - * request and sending it to the service. - */ -class HttpResponseChecksumMutationInterceptorDecorator : ClientCodegenDecorator { - override val name: String = "HttpResponseChecksumMutationInterceptorGenerator" - override val order: Byte = 0 - - override fun operationCustomizations( - codegenContext: ClientCodegenContext, - operation: OperationShape, - baseCustomizations: List, - ): List { - // If the operation doesn't have the HttpChecksum trait we return early - val checksumTrait = operation.getTrait() ?: return baseCustomizations - // Also return early if there is no requestValidationModeMember on the trait - val requestValidationModeMember = - (checksumTrait.requestValidationModeMemberShape(codegenContext, operation) ?: return baseCustomizations) - - return baseCustomizations + - listOf( - InterceptorSection( - codegenContext, - operation, - requestValidationModeMember, - checksumTrait, - ), - ) - } - - private class InterceptorSection( - private val codegenContext: ClientCodegenContext, - private val operation: OperationShape, - private val requestValidationModeMember: MemberShape, - private val checksumTrait: HttpChecksumTrait, - ) : OperationCustomization() { - override fun section(section: OperationSection): Writable = - writable { - if (section is OperationSection.RuntimePluginSupportingTypes) { - val model = codegenContext.model - val symbolProvider = codegenContext.symbolProvider - val codegenScope = - codegenContext.runtimeConfig.let { rc -> - val runtimeApi = CargoDependency.smithyRuntimeApiClient(rc).toType() - val interceptors = runtimeApi.resolve("client::interceptors") - - arrayOf( - *preludeScope, - "BoxError" to RuntimeType.boxError(rc), - "ConfigBag" to RuntimeType.configBag(rc), - "Intercept" to RuntimeType.intercept(rc), - "BeforeSerializationInterceptorContextMut" to - RuntimeType.beforeSerializationInterceptorContextMut( - rc, - ), - "Input" to interceptors.resolve("context::Input"), - "Output" to interceptors.resolve("context::Output"), - "Error" to interceptors.resolve("context::Error"), - "RuntimeComponents" to RuntimeType.runtimeComponents(rc), - "ResponseChecksumValidation" to - CargoDependency.smithyTypes(rc).toType() - .resolve("checksum_config::ResponseChecksumValidation"), - ) - } - - val requestValidationModeName = symbolProvider.toSymbol(requestValidationModeMember).name - val requestValidationModeMemberInner = - if (requestValidationModeMember.isOptional) { - codegenContext.model.expectShape(requestValidationModeMember.target) - } else { - requestValidationModeMember - } - - val operationName = symbolProvider.toSymbol(operation).name - val interceptorName = "${operationName}HttpResponseChecksumMutationInterceptor" - - rustTemplate( - """ - ##[derive(Debug)] - struct $interceptorName; - - impl #{Intercept} for $interceptorName { - fn name(&self) -> &'static str { - ${interceptorName.dq()} - } - - fn modify_before_serialization( - &self, - context: &mut #{BeforeSerializationInterceptorContextMut}<'_, #{Input}, #{Output}, #{Error}>, - _runtime_comps: &#{RuntimeComponents}, - cfg: &mut #{ConfigBag}, - ) -> #{Result}<(), #{BoxError}> { - let input = context - .input_mut() - .downcast_mut::<#{OperationInputType}>() - .ok_or("failed to downcast to #{OperationInputType}")?; - - let request_validation_enabled = - matches!(input.$requestValidationModeName(), Some(#{ValidationModeShape}::Enabled)); - - if !request_validation_enabled { - // This value is set by the user on the SdkConfig to indicate their preference - let response_checksum_validation = cfg - .load::<#{ResponseChecksumValidation}>() - .unwrap_or(&#{ResponseChecksumValidation}::WhenSupported); - - // If validation setting is WhenSupported (or unknown) we enable response checksum - // validation. If it is WhenRequired we do not enable (since there is no way to - // indicate that a response checksum is required). - ##[allow(clippy::wildcard_in_or_patterns)] - match response_checksum_validation { - #{ResponseChecksumValidation}::WhenRequired => {} - #{ResponseChecksumValidation}::WhenSupported | _ => { - input.$requestValidationModeName = Some(#{ValidationModeShape}::Enabled); - } - } - } - - #{Ok}(()) - } - } - """, - *codegenScope, - "OperationInputType" to codegenContext.symbolProvider.toSymbol(operation.inputShape(model)), - "ValidationModeShape" to - codegenContext.symbolProvider.toSymbol( - requestValidationModeMemberInner, - ), - ) - } - } - } -} - -/** - * Get the top-level operation input member used to opt in to best-effort validation of a checksum returned in - * the HTTP response of the operation. - */ -fun HttpChecksumTrait.requestValidationModeMemberShape( - codegenContext: ClientCodegenContext, - operationShape: OperationShape, -): MemberShape? { - val requestValidationModeMember = this.requestValidationModeMember.orNull() ?: return null - return operationShape.inputShape(codegenContext.model).expectMember(requestValidationModeMember) -}