diff --git a/.editorconfig b/.editorconfig index 55bfe517da2..f87c1664096 100644 --- a/.editorconfig +++ b/.editorconfig @@ -17,3 +17,5 @@ ktlint_standard_filename = disabled ktlint_standard_max-line-length = disabled ktlint_standard_argument-list-wrapping = disabled ktlint_standard_parameter-list-wrapping = disabled +ktlint_standard_property-naming = disabled +ktlint_standard_comment-wrapping = disabled diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 16924be2112..b29005e2067 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,8 +9,6 @@ repos: - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks rev: v2.11.0 hooks: - - id: pretty-format-kotlin - args: [--autofix, --ktlint-version, 0.48.2] - id: pretty-format-yaml args: [--autofix, --indent, '2'] - id: pretty-format-rust @@ -18,6 +16,11 @@ repos: files: ^.*\.rs$ - repo: local hooks: + - id: ktlint + name: Ktlint + entry: ./.pre-commit-hooks/ktlint.sh + language: system + files: ^.*\.kt$ - id: kotlin-block-quotes name: Kotlin Block Quotes entry: ./.pre-commit-hooks/kotlin-block-quotes.py diff --git a/.pre-commit-hooks/kotlin-block-quotes.py b/.pre-commit-hooks/kotlin-block-quotes.py index 8a05f49550d..adbbbcffa85 100755 --- a/.pre-commit-hooks/kotlin-block-quotes.py +++ b/.pre-commit-hooks/kotlin-block-quotes.py @@ -64,7 +64,7 @@ def starts_or_ends_block_quote(line, inside_block_quotes): # Returns the indentation of a line def line_indent(line): - indent = re.search("[^\s]", line) + indent = re.search(r"[^\s]", line) if indent != None: return indent.start(0) else: @@ -72,7 +72,7 @@ def line_indent(line): # Changes the indentation of a line def adjust_indent(line, indent): - old_indent = re.search("[^\s]", line) + old_indent = re.search(r"[^\s]", line) if old_indent == None: return line line = line[old_indent.start(0):] @@ -168,7 +168,8 @@ def fix_file(file_name): print("INFO: Fixed indentation in `" + file_name + "`.") return True else: - print("INFO: `" + file_name + "` is fine.") + # This print is useful when debugging this script, but spammy otherwise. Leave it commented. + # print("INFO: `" + file_name + "` is fine.") return False class SelfTest(unittest.TestCase): diff --git a/.pre-commit-hooks/ktlint.sh b/.pre-commit-hooks/ktlint.sh new file mode 100755 index 00000000000..18cfe0772f4 --- /dev/null +++ b/.pre-commit-hooks/ktlint.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# + +set -e + +cd "$(git rev-parse --show-toplevel)" +# `-q`: run gradle in quiet mode +# `--console plain`: Turn off the fancy terminal printing in gradle +# `2>/dev/null`: Suppress the build success/failure output at the end since pre-commit will report failures +./gradlew -q --console plain ktlintPreCommit -DktlintPreCommitArgs="$*" 2>/dev/null diff --git a/aws/sdk-codegen/build.gradle.kts b/aws/sdk-codegen/build.gradle.kts index 6b255bd5a30..9221467a6d7 100644 --- a/aws/sdk-codegen/build.gradle.kts +++ b/aws/sdk-codegen/build.gradle.kts @@ -31,7 +31,7 @@ dependencies { } tasks.compileKotlin { - kotlinOptions.jvmTarget = "1.8" + kotlinOptions.jvmTarget = "11" } // Reusable license copySpec @@ -67,7 +67,7 @@ if (isTestingEnabled.toBoolean()) { } tasks.compileTestKotlin { - kotlinOptions.jvmTarget = "1.8" + kotlinOptions.jvmTarget = "11" } tasks.test { diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCargoDependency.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCargoDependency.kt index c803040f7d8..001f17984e1 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCargoDependency.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCargoDependency.kt @@ -9,15 +9,23 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.crateLocation -fun RuntimeConfig.awsRuntimeCrate(name: String, features: Set = setOf()): CargoDependency = - CargoDependency(name, awsRoot().crateLocation(name), features = features) +fun RuntimeConfig.awsRuntimeCrate( + name: String, + features: Set = setOf(), +): CargoDependency = CargoDependency(name, awsRoot().crateLocation(name), features = features) object AwsCargoDependency { fun awsConfig(runtimeConfig: RuntimeConfig) = runtimeConfig.awsRuntimeCrate("aws-config") + fun awsCredentialTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.awsRuntimeCrate("aws-credential-types") + fun awsHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.awsRuntimeCrate("aws-http") + fun awsRuntime(runtimeConfig: RuntimeConfig) = runtimeConfig.awsRuntimeCrate("aws-runtime") + fun awsRuntimeApi(runtimeConfig: RuntimeConfig) = runtimeConfig.awsRuntimeCrate("aws-runtime-api") + fun awsSigv4(runtimeConfig: RuntimeConfig) = runtimeConfig.awsRuntimeCrate("aws-sigv4") + fun awsTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.awsRuntimeCrate("aws-types") } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt index 5e694c1cead..3bc616a9b8f 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt @@ -27,56 +27,55 @@ import software.amazon.smithy.rustsdk.endpoints.AwsEndpointsStdLib import software.amazon.smithy.rustsdk.endpoints.OperationInputTestDecorator import software.amazon.smithy.rustsdk.endpoints.RequireEndpointRules -val DECORATORS: List = listOf( - // General AWS Decorators +val DECORATORS: List = listOf( - CredentialsProviderDecorator(), - RegionDecorator(), - RequireEndpointRules(), - UserAgentDecorator(), - SigV4AuthDecorator(), - HttpRequestChecksumDecorator(), - HttpResponseChecksumDecorator(), - RetryClassifierDecorator(), - IntegrationTestDecorator(), - AwsFluentClientDecorator(), - CrateLicenseDecorator(), - SdkConfigDecorator(), - ServiceConfigDecorator(), - AwsPresigningDecorator(), - AwsCrateDocsDecorator(), - AwsEndpointsStdLib(), - *PromotedBuiltInsDecorators, - GenericSmithySdkConfigSettings(), - OperationInputTestDecorator(), - AwsRequestIdDecorator(), - DisabledAuthDecorator(), - RecursionDetectionDecorator(), - InvocationIdDecorator(), - RetryInformationHeaderDecorator(), - RemoveDefaultsDecorator(), - ), - - // Service specific decorators - ApiGatewayDecorator().onlyApplyTo("com.amazonaws.apigateway#BackplaneControlService"), - Ec2Decorator().onlyApplyTo("com.amazonaws.ec2#AmazonEC2"), - GlacierDecorator().onlyApplyTo("com.amazonaws.glacier#Glacier"), - Route53Decorator().onlyApplyTo("com.amazonaws.route53#AWSDnsV20130401"), - "com.amazonaws.s3#AmazonS3".applyDecorators( - S3Decorator(), - S3ExtendedRequestIdDecorator(), - ), - S3ControlDecorator().onlyApplyTo("com.amazonaws.s3control#AWSS3ControlServiceV20180820"), - STSDecorator().onlyApplyTo("com.amazonaws.sts#AWSSecurityTokenServiceV20110615"), - SSODecorator().onlyApplyTo("com.amazonaws.sso#SWBPortalService"), - TimestreamDecorator().onlyApplyTo("com.amazonaws.timestreamwrite#Timestream_20181101"), - TimestreamDecorator().onlyApplyTo("com.amazonaws.timestreamquery#Timestream_20181101"), - - // Only build docs-rs for linux to reduce load on docs.rs - listOf( - DocsRsMetadataDecorator(DocsRsMetadataSettings(targets = listOf("x86_64-unknown-linux-gnu"), allFeatures = true)), - ), -).flatten() + // General AWS Decorators + listOf( + CredentialsProviderDecorator(), + RegionDecorator(), + RequireEndpointRules(), + UserAgentDecorator(), + SigV4AuthDecorator(), + HttpRequestChecksumDecorator(), + HttpResponseChecksumDecorator(), + RetryClassifierDecorator(), + IntegrationTestDecorator(), + AwsFluentClientDecorator(), + CrateLicenseDecorator(), + SdkConfigDecorator(), + ServiceConfigDecorator(), + AwsPresigningDecorator(), + AwsCrateDocsDecorator(), + AwsEndpointsStdLib(), + *PromotedBuiltInsDecorators, + GenericSmithySdkConfigSettings(), + OperationInputTestDecorator(), + AwsRequestIdDecorator(), + DisabledAuthDecorator(), + RecursionDetectionDecorator(), + InvocationIdDecorator(), + RetryInformationHeaderDecorator(), + RemoveDefaultsDecorator(), + ), + // Service specific decorators + ApiGatewayDecorator().onlyApplyTo("com.amazonaws.apigateway#BackplaneControlService"), + Ec2Decorator().onlyApplyTo("com.amazonaws.ec2#AmazonEC2"), + GlacierDecorator().onlyApplyTo("com.amazonaws.glacier#Glacier"), + Route53Decorator().onlyApplyTo("com.amazonaws.route53#AWSDnsV20130401"), + "com.amazonaws.s3#AmazonS3".applyDecorators( + S3Decorator(), + S3ExtendedRequestIdDecorator(), + ), + S3ControlDecorator().onlyApplyTo("com.amazonaws.s3control#AWSS3ControlServiceV20180820"), + STSDecorator().onlyApplyTo("com.amazonaws.sts#AWSSecurityTokenServiceV20110615"), + SSODecorator().onlyApplyTo("com.amazonaws.sso#SWBPortalService"), + TimestreamDecorator().onlyApplyTo("com.amazonaws.timestreamwrite#Timestream_20181101"), + TimestreamDecorator().onlyApplyTo("com.amazonaws.timestreamquery#Timestream_20181101"), + // Only build docs-rs for linux to reduce load on docs.rs + listOf( + DocsRsMetadataDecorator(DocsRsMetadataSettings(targets = listOf("x86_64-unknown-linux-gnu"), allFeatures = true)), + ), + ).flatten() class AwsCodegenDecorator : CombinedClientCodegenDecorator(DECORATORS) { override val name: String = "AwsSdkCodegenDecorator" diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCrateDocsDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCrateDocsDecorator.kt index bce5c6f0ef5..ffa58796334 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCrateDocsDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCrateDocsDecorator.kt @@ -52,28 +52,38 @@ class AwsCrateDocsDecorator : ClientCodegenDecorator { override fun libRsCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = baseCustomizations + listOf( - object : LibRsCustomization() { - override fun section(section: LibRsSection): Writable = when { - section is LibRsSection.ModuleDoc && section.subsection is ModuleDocSection.ServiceDocs -> writable { - // Include README contents in crate docs if they are to be generated - if (generateReadme(codegenContext)) { - AwsCrateDocGenerator(codegenContext).generateCrateDocComment()(this) - } - } - - else -> emptySection - } - }, - ) + ): List = + baseCustomizations + + listOf( + object : LibRsCustomization() { + override fun section(section: LibRsSection): Writable = + when { + section is LibRsSection.ModuleDoc && section.subsection is ModuleDocSection.ServiceDocs -> + writable { + // Include README contents in crate docs if they are to be generated + if (generateReadme(codegenContext)) { + AwsCrateDocGenerator(codegenContext).generateCrateDocComment()(this) + } + } + + else -> emptySection + } + }, + ) - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { if (generateReadme(codegenContext)) { AwsCrateDocGenerator(codegenContext).generateReadme(rustCrate) } } - override fun clientConstructionDocs(codegenContext: ClientCodegenContext, baseDocs: Writable): Writable = + override fun clientConstructionDocs( + codegenContext: ClientCodegenContext, + baseDocs: Writable, + ): Writable = writable { val serviceName = codegenContext.serviceShape.serviceNameOrDefault("the service") docs("Client for calling $serviceName.") @@ -97,142 +107,152 @@ internal class AwsCrateDocGenerator(private val codegenContext: ClientCodegenCon ?: throw IllegalStateException("missing `awsConfigVersion` codegen setting") } - private fun RustWriter.template(asComments: Boolean, text: String, vararg args: Pair) = - when (asComments) { - true -> containerDocsTemplate(text, *args) - else -> rawTemplate(text + "\n", *args) - } + private fun RustWriter.template( + asComments: Boolean, + text: String, + vararg args: Pair, + ) = when (asComments) { + true -> containerDocsTemplate(text, *args) + else -> rawTemplate(text + "\n", *args) + } internal fun docText( includeHeader: Boolean, includeLicense: Boolean, asComments: Boolean, - ): Writable = writable { - val moduleVersion = codegenContext.settings.moduleVersion - check(moduleVersion.isNotEmpty() && moduleVersion[0].isDigit()) - - val moduleName = codegenContext.settings.moduleName - val stableVersion = !moduleVersion.startsWith("0.") - val description = normalizeDescription( - codegenContext.moduleName, - codegenContext.settings.getService(codegenContext.model).getTrait()?.value ?: "", - ) - val snakeCaseModuleName = moduleName.replace('-', '_') - val shortModuleName = moduleName.removePrefix("aws-sdk-") + ): Writable = + writable { + val moduleVersion = codegenContext.settings.moduleVersion + check(moduleVersion.isNotEmpty() && moduleVersion[0].isDigit()) + + val moduleName = codegenContext.settings.moduleName + val stableVersion = !moduleVersion.startsWith("0.") + val description = + normalizeDescription( + codegenContext.moduleName, + codegenContext.settings.getService(codegenContext.model).getTrait()?.value ?: "", + ) + val snakeCaseModuleName = moduleName.replace('-', '_') + val shortModuleName = moduleName.removePrefix("aws-sdk-") + + if (includeHeader) { + template(asComments, escape("# $moduleName\n")) + } - if (includeHeader) { - template(asComments, escape("# $moduleName\n")) - } + // TODO(PostGA): Remove warning banner conditionals. + // NOTE: when you change this, you must also change SDK_README.md.hb + if (!stableVersion) { + template( + asComments, + """ + **Please Note: The SDK is currently released as a developer preview, without support or assistance for use + on production workloads. Any use in production is at your own risk.**${"\n"} + """.trimIndent(), + ) + } - // TODO(PostGA): Remove warning banner conditionals. - // NOTE: when you change this, you must also change SDK_README.md.hb - if (!stableVersion) { + if (description.isNotBlank()) { + template(asComments, escape("$description\n")) + } + + val compileExample = AwsDocs.canRelyOnAwsConfig(codegenContext) + val exampleMode = if (compileExample) "no_run" else "ignore" template( asComments, """ - **Please Note: The SDK is currently released as a developer preview, without support or assistance for use - on production workloads. Any use in production is at your own risk.**${"\n"} - """.trimIndent(), - ) - } - - if (description.isNotBlank()) { - template(asComments, escape("$description\n")) - } - - val compileExample = AwsDocs.canRelyOnAwsConfig(codegenContext) - val exampleMode = if (compileExample) "no_run" else "ignore" - template( - asComments, - """ - #### Getting Started + #### Getting Started - > Examples are available for many services and operations, check out the - > [examples folder in GitHub](https://github.com/awslabs/aws-sdk-rust/tree/main/examples). + > Examples are available for many services and operations, check out the + > [examples folder in GitHub](https://github.com/awslabs/aws-sdk-rust/tree/main/examples). - The SDK provides one crate per AWS service. You must add [Tokio](https://crates.io/crates/tokio) - as a dependency within your Rust project to execute asynchronous code. To add `$moduleName` to - your project, add the following to your **Cargo.toml** file: + The SDK provides one crate per AWS service. You must add [Tokio](https://crates.io/crates/tokio) + as a dependency within your Rust project to execute asynchronous code. To add `$moduleName` to + your project, add the following to your **Cargo.toml** file: - ```toml - [dependencies] - aws-config = { version = "$awsConfigVersion", features = ["behavior-version-latest"] } - $moduleName = "$moduleVersion" - tokio = { version = "1", features = ["full"] } - ``` + ```toml + [dependencies] + aws-config = { version = "$awsConfigVersion", features = ["behavior-version-latest"] } + $moduleName = "$moduleVersion" + tokio = { version = "1", features = ["full"] } + ``` - Then in code, a client can be created with the following: + Then in code, a client can be created with the following: - ```rust,$exampleMode - use $snakeCaseModuleName as $shortModuleName; + ```rust,$exampleMode + use $snakeCaseModuleName as $shortModuleName; - ##[#{tokio}::main] - async fn main() -> Result<(), $shortModuleName::Error> { - #{constructClient} + ##[#{tokio}::main] + async fn main() -> Result<(), $shortModuleName::Error> { + #{constructClient} - // ... make some calls with the client + // ... make some calls with the client - Ok(()) - } - ``` - - See the [client documentation](https://docs.rs/$moduleName/latest/$snakeCaseModuleName/client/struct.Client.html) - for information on what calls can be made, and the inputs and outputs for each of those calls.${"\n"} - """.trimIndent().trimStart(), - "tokio" to CargoDependency.Tokio.toDevDependency().toType(), - "aws_config" to when (compileExample) { - true -> AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toDevDependency().toType() - else -> writable { rust("aws_config") } - }, - "constructClient" to AwsDocs.constructClient(codegenContext, indent = " "), - ) - - template( - asComments, - """ - #### Using the SDK - - Until the SDK is released, we will be adding information about using the SDK to the - [Developer Guide](https://docs.aws.amazon.com/sdk-for-rust/latest/dg/welcome.html). Feel free to suggest - additional sections for the guide by opening an issue and describing what you are trying to do.${"\n"} - """.trimIndent(), - ) + Ok(()) + } + ``` + + See the [client documentation](https://docs.rs/$moduleName/latest/$snakeCaseModuleName/client/struct.Client.html) + for information on what calls can be made, and the inputs and outputs for each of those calls.${"\n"} + """.trimIndent().trimStart(), + "tokio" to CargoDependency.Tokio.toDevDependency().toType(), + "aws_config" to + when (compileExample) { + true -> AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toDevDependency().toType() + else -> writable { rust("aws_config") } + }, + "constructClient" to AwsDocs.constructClient(codegenContext, indent = " "), + ) - template( - asComments, - """ - #### Getting Help + template( + asComments, + """ + #### Using the SDK - * [GitHub discussions](https://github.com/awslabs/aws-sdk-rust/discussions) - For ideas, RFCs & general questions - * [GitHub issues](https://github.com/awslabs/aws-sdk-rust/issues/new/choose) - For bug reports & feature requests - * [Generated Docs (latest version)](https://awslabs.github.io/aws-sdk-rust/) - * [Usage examples](https://github.com/awslabs/aws-sdk-rust/tree/main/examples)${"\n"} - """.trimIndent(), - ) + Until the SDK is released, we will be adding information about using the SDK to the + [Developer Guide](https://docs.aws.amazon.com/sdk-for-rust/latest/dg/welcome.html). Feel free to suggest + additional sections for the guide by opening an issue and describing what you are trying to do.${"\n"} + """.trimIndent(), + ) - if (includeLicense) { template( asComments, """ - #### License + #### Getting Help - This project is licensed under the Apache-2.0 License. + * [GitHub discussions](https://github.com/awslabs/aws-sdk-rust/discussions) - For ideas, RFCs & general questions + * [GitHub issues](https://github.com/awslabs/aws-sdk-rust/issues/new/choose) - For bug reports & feature requests + * [Generated Docs (latest version)](https://awslabs.github.io/aws-sdk-rust/) + * [Usage examples](https://github.com/awslabs/aws-sdk-rust/tree/main/examples)${"\n"} """.trimIndent(), ) + + if (includeLicense) { + template( + asComments, + """ + #### License + + This project is licensed under the Apache-2.0 License. + """.trimIndent(), + ) + } } - } internal fun generateCrateDocComment(): Writable = docText(includeHeader = false, includeLicense = false, asComments = true) - internal fun generateReadme(rustCrate: RustCrate) = rustCrate.withFile("README.md") { - docText(includeHeader = true, includeLicense = true, asComments = false)(this) - } + internal fun generateReadme(rustCrate: RustCrate) = + rustCrate.withFile("README.md") { + docText(includeHeader = true, includeLicense = true, asComments = false)(this) + } /** * Strips HTML from the description and makes it human-readable Markdown. */ - internal fun normalizeDescription(moduleName: String, input: String): String { + internal fun normalizeDescription( + moduleName: String, + input: String, + ): String { val doc = Jsoup.parse(input) doc.body().apply { // The order of operations here is important: @@ -281,7 +301,10 @@ internal class AwsCrateDocGenerator(private val codegenContext: ClientCodegenCon getElementsByTag("i").forEach { normalizeInlineStyleTag("_", it) } } - private fun normalizeInlineStyleTag(surround: String, tag: Element) { + private fun normalizeInlineStyleTag( + surround: String, + tag: Element, + ) { tag.replaceWith( Element("span").also { span -> span.append(surround) diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsDocs.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsDocs.kt index 6263d54497d..61294802986 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsDocs.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsDocs.kt @@ -26,7 +26,10 @@ object AwsDocs { ShapeId.from("com.amazonaws.sts#AWSSecurityTokenServiceV20110615"), ).contains(codegenContext.serviceShape.id) - fun constructClient(codegenContext: ClientCodegenContext, indent: String): Writable { + fun constructClient( + codegenContext: ClientCodegenContext, + indent: String, + ): Writable { val crateName = codegenContext.moduleUseName() return writable { writeCustomizationsOrElse( @@ -46,60 +49,61 @@ object AwsDocs { } } - fun clientConstructionDocs(codegenContext: ClientCodegenContext): Writable = { - if (canRelyOnAwsConfig(codegenContext)) { - val crateName = codegenContext.moduleUseName() - docsTemplate( - """ - #### Constructing a `Client` + fun clientConstructionDocs(codegenContext: ClientCodegenContext): Writable = + { + if (canRelyOnAwsConfig(codegenContext)) { + val crateName = codegenContext.moduleUseName() + docsTemplate( + """ + #### Constructing a `Client` - A [`Config`] is required to construct a client. For most use cases, the [`aws-config`] - crate should be used to automatically resolve this config using - [`aws_config::load_from_env()`], since this will resolve an [`SdkConfig`] which can be shared - across multiple different AWS SDK clients. This config resolution process can be customized - by calling [`aws_config::from_env()`] instead, which returns a [`ConfigLoader`] that uses - the [builder pattern] to customize the default config. + A [`Config`] is required to construct a client. For most use cases, the [`aws-config`] + crate should be used to automatically resolve this config using + [`aws_config::load_from_env()`], since this will resolve an [`SdkConfig`] which can be shared + across multiple different AWS SDK clients. This config resolution process can be customized + by calling [`aws_config::from_env()`] instead, which returns a [`ConfigLoader`] that uses + the [builder pattern] to customize the default config. - In the simplest case, creating a client looks as follows: - ```rust,no_run - ## async fn wrapper() { - #{constructClient} - ## } - ``` + In the simplest case, creating a client looks as follows: + ```rust,no_run + ## async fn wrapper() { + #{constructClient} + ## } + ``` - Occasionally, SDKs may have additional service-specific that can be set on the [`Config`] that - is absent from [`SdkConfig`], or slightly different settings for a specific client may be desired. - The [`Config`] struct implements `From<&SdkConfig>`, so setting these specific settings can be - done as follows: + Occasionally, SDKs may have additional service-specific that can be set on the [`Config`] that + is absent from [`SdkConfig`], or slightly different settings for a specific client may be desired. + The [`Config`] struct implements `From<&SdkConfig>`, so setting these specific settings can be + done as follows: - ```rust,no_run - ## async fn wrapper() { - let sdk_config = #{aws_config}::load_from_env().await; - let config = $crateName::config::Builder::from(&sdk_config) - ## /* - .some_service_specific_setting("value") - ## */ - .build(); - ## } - ``` + ```rust,no_run + ## async fn wrapper() { + let sdk_config = #{aws_config}::load_from_env().await; + let config = $crateName::config::Builder::from(&sdk_config) + ## /* + .some_service_specific_setting("value") + ## */ + .build(); + ## } + ``` - See the [`aws-config` docs] and [`Config`] for more information on customizing configuration. + See the [`aws-config` docs] and [`Config`] for more information on customizing configuration. - _Note:_ Client construction is expensive due to connection thread pool initialization, and should - be done once at application start-up. + _Note:_ Client construction is expensive due to connection thread pool initialization, and should + be done once at application start-up. - [`Config`]: crate::Config - [`ConfigLoader`]: https://docs.rs/aws-config/*/aws_config/struct.ConfigLoader.html - [`SdkConfig`]: https://docs.rs/aws-config/*/aws_config/struct.SdkConfig.html - [`aws-config` docs]: https://docs.rs/aws-config/* - [`aws-config`]: https://crates.io/crates/aws-config - [`aws_config::from_env()`]: https://docs.rs/aws-config/*/aws_config/fn.from_env.html - [`aws_config::load_from_env()`]: https://docs.rs/aws-config/*/aws_config/fn.load_from_env.html - [builder pattern]: https://rust-lang.github.io/api-guidelines/type-safety.html##builders-enable-construction-of-complex-values-c-builder - """.trimIndent(), - "aws_config" to AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toDevDependency().toType(), - "constructClient" to constructClient(codegenContext, indent = ""), - ) + [`Config`]: crate::Config + [`ConfigLoader`]: https://docs.rs/aws-config/*/aws_config/struct.ConfigLoader.html + [`SdkConfig`]: https://docs.rs/aws-config/*/aws_config/struct.SdkConfig.html + [`aws-config` docs]: https://docs.rs/aws-config/* + [`aws-config`]: https://crates.io/crates/aws-config + [`aws_config::from_env()`]: https://docs.rs/aws-config/*/aws_config/fn.from_env.html + [`aws_config::load_from_env()`]: https://docs.rs/aws-config/*/aws_config/fn.load_from_env.html + [builder pattern]: https://rust-lang.github.io/api-guidelines/type-safety.html##builders-enable-construction-of-complex-values-c-builder + """.trimIndent(), + "aws_config" to AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toDevDependency().toType(), + "constructClient" to constructClient(codegenContext, indent = ""), + ) + } } - } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt index a118823e725..4ca0fb81101 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt @@ -43,15 +43,19 @@ class AwsFluentClientDecorator : ClientCodegenDecorator { // Must run after the AwsPresigningDecorator so that the presignable trait is correctly added to operations override val order: Byte = (AwsPresigningDecorator.ORDER + 1).toByte() - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { val runtimeConfig = codegenContext.runtimeConfig val types = Types(runtimeConfig) FluentClientGenerator( codegenContext, - customizations = listOf( - AwsPresignedFluentBuilderMethod(codegenContext), - AwsFluentClientDocs(codegenContext), - ), + customizations = + listOf( + AwsPresignedFluentBuilderMethod(codegenContext), + AwsFluentClientDocs(codegenContext), + ), ).render(rustCrate, emptyList()) rustCrate.withModule(ClientRustModule.client) { AwsFluentClientExtensions(codegenContext, types).render(this) @@ -63,47 +67,52 @@ class AwsFluentClientDecorator : ClientCodegenDecorator { codegenContext: ClientCodegenContext, baseCustomizations: List, ): List { - return baseCustomizations + object : LibRsCustomization() { - override fun section(section: LibRsSection) = when (section) { - is LibRsSection.Body -> writable { - Attribute.DocInline.render(this) - rust("pub use client::Client;") - } - - else -> emptySection + return baseCustomizations + + object : LibRsCustomization() { + override fun section(section: LibRsSection) = + when (section) { + is LibRsSection.Body -> + writable { + Attribute.DocInline.render(this) + rust("pub use client::Client;") + } + + else -> emptySection + } } - } } override fun protocolTestGenerator( codegenContext: ClientCodegenContext, baseGenerator: ProtocolTestGenerator, - ): ProtocolTestGenerator = DefaultProtocolTestGenerator( - codegenContext, - baseGenerator.protocolSupport, - baseGenerator.operationShape, - renderClientCreation = { params -> - rustTemplate( - """ - let mut ${params.configBuilderName} = ${params.configBuilderName}; - ${params.configBuilderName}.set_region(Some(crate::config::Region::new("us-east-1"))); + ): ProtocolTestGenerator = + DefaultProtocolTestGenerator( + codegenContext, + baseGenerator.protocolSupport, + baseGenerator.operationShape, + renderClientCreation = { params -> + rustTemplate( + """ + let mut ${params.configBuilderName} = ${params.configBuilderName}; + ${params.configBuilderName}.set_region(Some(crate::config::Region::new("us-east-1"))); - let config = ${params.configBuilderName}.http_client(${params.httpClientName}).build(); - let ${params.clientName} = #{Client}::from_conf(config); - """, - "Client" to ClientRustModule.root.toType().resolve("Client"), - ) - }, - ) + let config = ${params.configBuilderName}.http_client(${params.httpClientName}).build(); + let ${params.clientName} = #{Client}::from_conf(config); + """, + "Client" to ClientRustModule.root.toType().resolve("Client"), + ) + }, + ) } private class AwsFluentClientExtensions(private val codegenContext: ClientCodegenContext, private val types: Types) { - private val codegenScope = arrayOf( - "Arc" to RuntimeType.Arc, - "RetryConfig" to types.retryConfig, - "TimeoutConfig" to types.timeoutConfig, - "aws_types" to types.awsTypes, - ) + private val codegenScope = + arrayOf( + "Arc" to RuntimeType.Arc, + "RetryConfig" to types.retryConfig, + "TimeoutConfig" to types.timeoutConfig, + "aws_types" to types.awsTypes, + ) fun render(writer: RustWriter) { writer.rustBlockTemplate("impl Client", *codegenScope) { @@ -134,17 +143,18 @@ private class AwsFluentClientDocs(private val codegenContext: ClientCodegenConte override fun section(section: FluentClientSection): Writable { return when (section) { - is FluentClientSection.FluentClientDocs -> writable { - rustTemplate( - """ - /// Client for $serviceName - /// - /// Client for invoking operations on $serviceName. Each operation on $serviceName is a method on this + is FluentClientSection.FluentClientDocs -> + writable { + rustTemplate( + """ + /// Client for $serviceName + /// + /// Client for invoking operations on $serviceName. Each operation on $serviceName is a method on this /// this struct. `.send()` MUST be invoked on the generated operations to dispatch the request to the service.""", - ) - AwsDocs.clientConstructionDocs(codegenContext)(this) - FluentClientDocs.clientUsageDocs(codegenContext)(this) - } + ) + AwsDocs.clientConstructionDocs(codegenContext)(this) + FluentClientDocs.clientUsageDocs(codegenContext)(this) + } else -> emptySection } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt index 4d7e3d8bdec..5362a3c721f 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt @@ -38,10 +38,11 @@ import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rustsdk.traits.PresignableTrait import kotlin.streams.toList -private val presigningTypes: List> = listOf( - "PresignedRequest" to AwsRuntimeType.presigning().resolve("PresignedRequest"), - "PresigningConfig" to AwsRuntimeType.presigning().resolve("PresigningConfig"), -) +private val presigningTypes: List> = + listOf( + "PresignedRequest" to AwsRuntimeType.presigning().resolve("PresignedRequest"), + "PresigningConfig" to AwsRuntimeType.presigning().resolve("PresigningConfig"), + ) internal enum class PayloadSigningType { EMPTY, @@ -68,17 +69,18 @@ internal val PRESIGNABLE_OPERATIONS by lazy { ShapeId.from("com.amazonaws.s3#PutObject") to PresignableOperation(PayloadSigningType.UNSIGNED_PAYLOAD), ShapeId.from("com.amazonaws.s3#UploadPart") to PresignableOperation(PayloadSigningType.UNSIGNED_PAYLOAD), ShapeId.from("com.amazonaws.s3#DeleteObject") to PresignableOperation(PayloadSigningType.UNSIGNED_PAYLOAD), - // Polly - SYNTHESIZE_SPEECH_OP to PresignableOperation( - PayloadSigningType.EMPTY, - // Polly's SynthesizeSpeech operation has the HTTP method overridden to GET, - // and the document members changed to query param members. - modelTransforms = listOf( - OverrideHttpMethodTransform(mapOf(SYNTHESIZE_SPEECH_OP to "GET")), - MoveDocumentMembersToQueryParamsTransform(listOf(SYNTHESIZE_SPEECH_OP)), + SYNTHESIZE_SPEECH_OP to + PresignableOperation( + PayloadSigningType.EMPTY, + // Polly's SynthesizeSpeech operation has the HTTP method overridden to GET, + // and the document members changed to query param members. + modelTransforms = + listOf( + OverrideHttpMethodTransform(mapOf(SYNTHESIZE_SPEECH_OP to "GET")), + MoveDocumentMembersToQueryParamsTransform(listOf(SYNTHESIZE_SPEECH_OP)), + ), ), - ), ) } @@ -95,25 +97,31 @@ class AwsPresigningDecorator internal constructor( /** * Adds presignable trait to known presignable operations and creates synthetic presignable shapes for codegen */ - override fun transformModel(service: ServiceShape, model: Model, settings: ClientRustSettings): Model { + override fun transformModel( + service: ServiceShape, + model: Model, + settings: ClientRustSettings, + ): Model { val modelWithSynthetics = addSyntheticOperations(model) val presignableTransforms = mutableListOf() - val intermediate = ModelTransformer.create().mapShapes(modelWithSynthetics) { shape -> - if (shape is OperationShape && presignableOperations.containsKey(shape.id)) { - presignableTransforms.addAll(presignableOperations.getValue(shape.id).modelTransforms) - shape.toBuilder().addTrait(PresignableTrait(syntheticShapeId(shape))).build() - } else { - shape + val intermediate = + ModelTransformer.create().mapShapes(modelWithSynthetics) { shape -> + if (shape is OperationShape && presignableOperations.containsKey(shape.id)) { + presignableTransforms.addAll(presignableOperations.getValue(shape.id).modelTransforms) + shape.toBuilder().addTrait(PresignableTrait(syntheticShapeId(shape))).build() + } else { + shape + } } - } // Apply operation-specific model transformations return presignableTransforms.fold(intermediate) { m, t -> t.transform(m) } } private fun addSyntheticOperations(model: Model): Model { - val presignableOps = model.shapes() - .filter { shape -> shape is OperationShape && presignableOperations.containsKey(shape.id) } - .toList() + val presignableOps = + model.shapes() + .filter { shape -> shape is OperationShape && presignableOperations.containsKey(shape.id) } + .toList() return model.toBuilder().also { builder -> for (op in presignableOps) { builder.cloneOperation(model, op, ::syntheticShapeId) @@ -126,12 +134,14 @@ class AwsPresignedFluentBuilderMethod( private val codegenContext: ClientCodegenContext, ) : FluentClientCustomization() { private val runtimeConfig = codegenContext.runtimeConfig - private val codegenScope = ( - presigningTypes + arrayOf( - *RuntimeType.preludeScope, - "Error" to AwsRuntimeType.presigning().resolve("config::Error"), - "SdkError" to RuntimeType.sdkError(runtimeConfig), - ) + private val codegenScope = + ( + presigningTypes + + arrayOf( + *RuntimeType.preludeScope, + "Error" to AwsRuntimeType.presigning().resolve("config::Error"), + "SdkError" to RuntimeType.sdkError(runtimeConfig), + ) ).toTypedArray() override fun section(section: FluentClientSection): Writable = @@ -158,11 +168,12 @@ class AwsPresignedFluentBuilderMethod( private fun RustWriter.renderPresignedMethodBody(section: FluentClientSection.FluentBuilderImpl) { val presignableOp = PRESIGNABLE_OPERATIONS.getValue(section.operationShape.id) - val operationShape = if (presignableOp.hasModelTransforms()) { - codegenContext.model.expectShape(syntheticShapeId(section.operationShape.id), OperationShape::class.java) - } else { - section.operationShape - } + val operationShape = + if (presignableOp.hasModelTransforms()) { + codegenContext.model.expectShape(syntheticShapeId(section.operationShape.id), OperationShape::class.java) + } else { + section.operationShape + } rustTemplate( """ @@ -191,52 +202,58 @@ class AwsPresignedFluentBuilderMethod( "Operation" to codegenContext.symbolProvider.toSymbol(section.operationShape), "OperationError" to section.operationErrorType, "RuntimePlugins" to RuntimeType.runtimePlugins(runtimeConfig), - "SharedInterceptor" to RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::interceptors") - .resolve("SharedInterceptor"), - "SigV4PresigningRuntimePlugin" to AwsRuntimeType.presigningInterceptor(runtimeConfig) - .resolve("SigV4PresigningRuntimePlugin"), + "SharedInterceptor" to + RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::interceptors") + .resolve("SharedInterceptor"), + "SigV4PresigningRuntimePlugin" to + AwsRuntimeType.presigningInterceptor(runtimeConfig) + .resolve("SigV4PresigningRuntimePlugin"), "StopPoint" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::StopPoint"), "USER_AGENT" to CargoDependency.Http.toType().resolve("header::USER_AGENT"), - "alternate_presigning_serializer" to writable { - if (presignableOp.hasModelTransforms()) { - val smithyTypes = RuntimeType.smithyTypes(codegenContext.runtimeConfig) - rustTemplate( - """ - ##[derive(::std::fmt::Debug)] - struct AlternatePresigningSerializerRuntimePlugin; - impl #{RuntimePlugin} for AlternatePresigningSerializerRuntimePlugin { - fn config(&self) -> #{Option}<#{FrozenLayer}> { - let mut cfg = #{Layer}::new("presigning_serializer"); - cfg.store_put(#{SharedRequestSerializer}::new(#{AlternateSerializer})); - #{Some}(cfg.freeze()) + "alternate_presigning_serializer" to + writable { + if (presignableOp.hasModelTransforms()) { + val smithyTypes = RuntimeType.smithyTypes(codegenContext.runtimeConfig) + rustTemplate( + """ + ##[derive(::std::fmt::Debug)] + struct AlternatePresigningSerializerRuntimePlugin; + impl #{RuntimePlugin} for AlternatePresigningSerializerRuntimePlugin { + fn config(&self) -> #{Option}<#{FrozenLayer}> { + let mut cfg = #{Layer}::new("presigning_serializer"); + cfg.store_put(#{SharedRequestSerializer}::new(#{AlternateSerializer})); + #{Some}(cfg.freeze()) + } } - } - """, - *preludeScope, - "AlternateSerializer" to alternateSerializer(operationShape), - "FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"), - "Layer" to smithyTypes.resolve("config_bag::Layer"), - "RuntimePlugin" to RuntimeType.runtimePlugin(codegenContext.runtimeConfig), - "SharedRequestSerializer" to RuntimeType.smithyRuntimeApiClient(codegenContext.runtimeConfig) - .resolve("client::ser_de::SharedRequestSerializer"), + """, + *preludeScope, + "AlternateSerializer" to alternateSerializer(operationShape), + "FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"), + "Layer" to smithyTypes.resolve("config_bag::Layer"), + "RuntimePlugin" to RuntimeType.runtimePlugin(codegenContext.runtimeConfig), + "SharedRequestSerializer" to + RuntimeType.smithyRuntimeApiClient(codegenContext.runtimeConfig) + .resolve("client::ser_de::SharedRequestSerializer"), + ) + } + }, + "alternate_presigning_serializer_registration" to + writable { + if (presignableOp.hasModelTransforms()) { + rust(".with_operation_plugin(AlternatePresigningSerializerRuntimePlugin)") + } + }, + "payload_override" to + writable { + rustTemplate( + "#{aws_sigv4}::http_request::SignableBody::" + + when (presignableOp.payloadSigningType) { + PayloadSigningType.EMPTY -> "Bytes(b\"\")" + PayloadSigningType.UNSIGNED_PAYLOAD -> "UnsignedPayload" + }, + "aws_sigv4" to AwsRuntimeType.awsSigv4(runtimeConfig), ) - } - }, - "alternate_presigning_serializer_registration" to writable { - if (presignableOp.hasModelTransforms()) { - rust(".with_operation_plugin(AlternatePresigningSerializerRuntimePlugin)") - } - }, - "payload_override" to writable { - rustTemplate( - "#{aws_sigv4}::http_request::SignableBody::" + - when (presignableOp.payloadSigningType) { - PayloadSigningType.EMPTY -> "Bytes(b\"\")" - PayloadSigningType.UNSIGNED_PAYLOAD -> "UnsignedPayload" - }, - "aws_sigv4" to AwsRuntimeType.awsSigv4(runtimeConfig), - ) - }, + }, ) } @@ -305,19 +322,21 @@ class MoveDocumentMembersToQueryParamsTransform( ) : PresignModelTransform { override fun transform(model: Model): Model { val index = HttpBindingIndex(model) - val operations = presignableOperations.map { id -> - model.expectShape(syntheticShapeId(id), OperationShape::class.java).also { shape -> - check(shape.hasTrait(HttpTrait.ID)) { - "MoveDocumentMembersToQueryParamsTransform can only be used with REST protocols" + val operations = + presignableOperations.map { id -> + model.expectShape(syntheticShapeId(id), OperationShape::class.java).also { shape -> + check(shape.hasTrait(HttpTrait.ID)) { + "MoveDocumentMembersToQueryParamsTransform can only be used with REST protocols" + } } } - } // Find document members of the presignable operations - val membersToUpdate = operations.map { operation -> - val payloadBindings = index.getRequestBindings(operation, HttpBinding.Location.DOCUMENT) - payloadBindings.map { binding -> binding.member } - }.flatten() + val membersToUpdate = + operations.map { operation -> + val payloadBindings = index.getRequestBindings(operation, HttpBinding.Location.DOCUMENT) + payloadBindings.map { binding -> binding.member } + }.flatten() // Transform found shapes for presigning return ModelTransformer.create().mapShapes(model) { shape -> @@ -331,11 +350,12 @@ class MoveDocumentMembersToQueryParamsTransform( } private fun RustWriter.documentPresignedMethod(hasConfigArg: Boolean) { - val configBlurb = if (hasConfigArg) { - "The credentials provider from the `config` will be used to generate the request's signature.\n" - } else { - "" - } + val configBlurb = + if (hasConfigArg) { + "The credentials provider from the `config` will be used to generate the request's signature.\n" + } else { + "" + } docs( """ Creates a presigned request for this operation. diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsRuntimeType.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsRuntimeType.kt index 9fea90bf893..b7042a90ee4 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsRuntimeType.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsRuntimeType.kt @@ -14,16 +14,18 @@ import java.io.File import java.nio.file.Path fun RuntimeConfig.awsRoot(): RuntimeCrateLocation { - val updatedPath = runtimeCrateLocation.path?.let { cratePath -> - val asPath = Path.of(cratePath) - val path = if (asPath.isAbsolute) { - asPath.parent.resolve("aws/rust-runtime").toAbsolutePath().toString() - } else { - cratePath + val updatedPath = + runtimeCrateLocation.path?.let { cratePath -> + val asPath = Path.of(cratePath) + val path = + if (asPath.isAbsolute) { + asPath.parent.resolve("aws/rust-runtime").toAbsolutePath().toString() + } else { + cratePath + } + check(File(path).exists()) { "$path must exist to generate a working SDK" } + path } - check(File(path).exists()) { "$path must exist to generate a working SDK" } - path - } return runtimeCrateLocation.copy( path = updatedPath, versions = runtimeCrateLocation.versions, ) @@ -32,6 +34,7 @@ fun RuntimeConfig.awsRoot(): RuntimeCrateLocation { object AwsRuntimeType { fun presigning(): RuntimeType = RuntimeType.forInlineDependency(InlineAwsDependency.forRustFile("presigning", visibility = Visibility.PUBLIC)) + fun presigningInterceptor(runtimeConfig: RuntimeConfig): RuntimeType = RuntimeType.forInlineDependency( InlineAwsDependency.forRustFile( @@ -50,8 +53,10 @@ object AwsRuntimeType { fun awsHttp(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsHttp(runtimeConfig).toType() fun awsSigv4(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsSigv4(runtimeConfig).toType() + fun awsTypes(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsTypes(runtimeConfig).toType() fun awsRuntime(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsRuntime(runtimeConfig).toType() + fun awsRuntimeApi(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsRuntimeApi(runtimeConfig).toType() } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/BaseRequestIdDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/BaseRequestIdDecorator.kt index 96d61b4f9f3..91b3a1352eb 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/BaseRequestIdDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/BaseRequestIdDecorator.kt @@ -39,7 +39,9 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait abstract class BaseRequestIdDecorator : ClientCodegenDecorator { abstract val accessorFunctionName: String abstract val fieldName: String + abstract fun accessorTrait(codegenContext: ClientCodegenContext): RuntimeType + abstract fun applyToError(codegenContext: ClientCodegenContext): RuntimeType override fun operationCustomizations( @@ -51,8 +53,7 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator { override fun errorCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = - baseCustomizations + listOf(RequestIdErrorCustomization(codegenContext)) + ): List = baseCustomizations + listOf(RequestIdErrorCustomization(codegenContext)) override fun errorImplCustomizations( codegenContext: ClientCodegenContext, @@ -69,7 +70,10 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator { baseCustomizations: List, ): List = baseCustomizations + listOf(RequestIdBuilderCustomization()) - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { rustCrate.withModule(ClientRustModule.Operation) { // Re-export RequestId in generated crate rust("pub use #T;", accessorTrait(codegenContext)) @@ -82,160 +86,167 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator { private inner class RequestIdOperationCustomization(private val codegenContext: ClientCodegenContext) : OperationCustomization() { - override fun section(section: OperationSection): Writable = writable { - when (section) { - is OperationSection.PopulateErrorMetadataExtras -> { - rustTemplate( - "${section.builderName} = #{apply_to_error}(${section.builderName}, ${section.responseHeadersName});", - "apply_to_error" to applyToError(codegenContext), - ) - } + override fun section(section: OperationSection): Writable = + writable { + when (section) { + is OperationSection.PopulateErrorMetadataExtras -> { + rustTemplate( + "${section.builderName} = #{apply_to_error}(${section.builderName}, ${section.responseHeadersName});", + "apply_to_error" to applyToError(codegenContext), + ) + } - is OperationSection.MutateOutput -> { - rust( - "output._set_$fieldName(#T::$accessorFunctionName(${section.responseHeadersName}).map(str::to_string));", - accessorTrait(codegenContext), - ) - } + is OperationSection.MutateOutput -> { + rust( + "output._set_$fieldName(#T::$accessorFunctionName(${section.responseHeadersName}).map(str::to_string));", + accessorTrait(codegenContext), + ) + } - is OperationSection.BeforeParseResponse -> { - rustTemplate( - "#{tracing}::debug!($fieldName = ?#{trait}::$accessorFunctionName(${section.responseName}));", - "tracing" to RuntimeType.Tracing, - "trait" to accessorTrait(codegenContext), - ) - } + is OperationSection.BeforeParseResponse -> { + rustTemplate( + "#{tracing}::debug!($fieldName = ?#{trait}::$accessorFunctionName(${section.responseName}));", + "tracing" to RuntimeType.Tracing, + "trait" to accessorTrait(codegenContext), + ) + } - else -> {} + else -> {} + } } - } } private inner class RequestIdErrorCustomization(private val codegenContext: ClientCodegenContext) : ErrorCustomization() { - override fun section(section: ErrorSection): Writable = writable { - when (section) { - is ErrorSection.OperationErrorAdditionalTraitImpls -> { - rustTemplate( - """ - impl #{AccessorTrait} for #{error} { - fn $accessorFunctionName(&self) -> Option<&str> { - self.meta().$accessorFunctionName() + override fun section(section: ErrorSection): Writable = + writable { + when (section) { + is ErrorSection.OperationErrorAdditionalTraitImpls -> { + rustTemplate( + """ + impl #{AccessorTrait} for #{error} { + fn $accessorFunctionName(&self) -> Option<&str> { + self.meta().$accessorFunctionName() + } } - } - """, - "AccessorTrait" to accessorTrait(codegenContext), - "error" to section.errorSymbol, - ) - } + """, + "AccessorTrait" to accessorTrait(codegenContext), + "error" to section.errorSymbol, + ) + } - is ErrorSection.ServiceErrorAdditionalTraitImpls -> { - rustBlock("impl #T for Error", accessorTrait(codegenContext)) { - rustBlock("fn $accessorFunctionName(&self) -> Option<&str>") { - rustBlock("match self") { - section.allErrors.forEach { error -> - val optional = asMemberShape(error)?.let { member -> - codegenContext.symbolProvider.toSymbol(member).isOptional() - } ?: true - val wrapped = writable { - when (optional) { - false -> rustTemplate("#{Some}(e.$accessorFunctionName())", *preludeScope) - true -> rustTemplate("e.$accessorFunctionName()") - } + is ErrorSection.ServiceErrorAdditionalTraitImpls -> { + rustBlock("impl #T for Error", accessorTrait(codegenContext)) { + rustBlock("fn $accessorFunctionName(&self) -> Option<&str>") { + rustBlock("match self") { + section.allErrors.forEach { error -> + val optional = + asMemberShape(error)?.let { member -> + codegenContext.symbolProvider.toSymbol(member).isOptional() + } ?: true + val wrapped = + writable { + when (optional) { + false -> rustTemplate("#{Some}(e.$accessorFunctionName())", *preludeScope) + true -> rustTemplate("e.$accessorFunctionName()") + } + } + val sym = codegenContext.symbolProvider.toSymbol(error) + rust("Self::${sym.name}(e) => #T,", wrapped) } - val sym = codegenContext.symbolProvider.toSymbol(error) - rust("Self::${sym.name}(e) => #T,", wrapped) + rust("Self::Unhandled(e) => e.meta.$accessorFunctionName(),") } - rust("Self::Unhandled(e) => e.meta.$accessorFunctionName(),") } } } } } - } } private inner class RequestIdErrorImplCustomization(private val codegenContext: ClientCodegenContext) : ErrorImplCustomization() { - override fun section(section: ErrorImplSection): Writable = writable { - when (section) { - is ErrorImplSection.ErrorAdditionalTraitImpls -> { - rustBlock("impl #1T for #2T", accessorTrait(codegenContext), section.errorType) { - rustBlock("fn $accessorFunctionName(&self) -> Option<&str>") { - rust("use #T;", RuntimeType.provideErrorMetadataTrait(codegenContext.runtimeConfig)) - rust("self.meta().$accessorFunctionName()") + override fun section(section: ErrorImplSection): Writable = + writable { + when (section) { + is ErrorImplSection.ErrorAdditionalTraitImpls -> { + rustBlock("impl #1T for #2T", accessorTrait(codegenContext), section.errorType) { + rustBlock("fn $accessorFunctionName(&self) -> Option<&str>") { + rust("use #T;", RuntimeType.provideErrorMetadataTrait(codegenContext.runtimeConfig)) + rust("self.meta().$accessorFunctionName()") + } } } - } - else -> {} + else -> {} + } } - } } private inner class RequestIdStructureCustomization(private val codegenContext: ClientCodegenContext) : StructureCustomization() { - override fun section(section: StructureSection): Writable = writable { - if (section.shape.hasTrait()) { - when (section) { - is StructureSection.AdditionalFields -> { - rust("_$fieldName: Option,") - } + override fun section(section: StructureSection): Writable = + writable { + if (section.shape.hasTrait()) { + when (section) { + is StructureSection.AdditionalFields -> { + rust("_$fieldName: Option,") + } - is StructureSection.AdditionalTraitImpls -> { - rustTemplate( - """ - impl #{AccessorTrait} for ${section.structName} { - fn $accessorFunctionName(&self) -> Option<&str> { - self._$fieldName.as_deref() + is StructureSection.AdditionalTraitImpls -> { + rustTemplate( + """ + impl #{AccessorTrait} for ${section.structName} { + fn $accessorFunctionName(&self) -> Option<&str> { + self._$fieldName.as_deref() + } } - } - """, - "AccessorTrait" to accessorTrait(codegenContext), - ) - } + """, + "AccessorTrait" to accessorTrait(codegenContext), + ) + } - is StructureSection.AdditionalDebugFields -> { - rust("""${section.formatterName}.field("_$fieldName", &self._$fieldName);""") + is StructureSection.AdditionalDebugFields -> { + rust("""${section.formatterName}.field("_$fieldName", &self._$fieldName);""") + } } } } - } } private inner class RequestIdBuilderCustomization : BuilderCustomization() { - override fun section(section: BuilderSection): Writable = writable { - if (section.shape.hasTrait()) { - when (section) { - is BuilderSection.AdditionalFields -> { - rust("_$fieldName: Option,") - } + override fun section(section: BuilderSection): Writable = + writable { + if (section.shape.hasTrait()) { + when (section) { + is BuilderSection.AdditionalFields -> { + rust("_$fieldName: Option,") + } - is BuilderSection.AdditionalMethods -> { - rust( - """ - pub(crate) fn _$fieldName(mut self, $fieldName: impl Into) -> Self { - self._$fieldName = Some($fieldName.into()); - self - } + is BuilderSection.AdditionalMethods -> { + rust( + """ + pub(crate) fn _$fieldName(mut self, $fieldName: impl Into) -> Self { + self._$fieldName = Some($fieldName.into()); + self + } - pub(crate) fn _set_$fieldName(&mut self, $fieldName: Option) -> &mut Self { - self._$fieldName = $fieldName; - self - } - """, - ) - } + pub(crate) fn _set_$fieldName(&mut self, $fieldName: Option) -> &mut Self { + self._$fieldName = $fieldName; + self + } + """, + ) + } - is BuilderSection.AdditionalDebugFields -> { - rust("""${section.formatterName}.field("_$fieldName", &self._$fieldName);""") - } + is BuilderSection.AdditionalDebugFields -> { + rust("""${section.formatterName}.field("_$fieldName", &self._$fieldName);""") + } - is BuilderSection.AdditionalFieldsInBuild -> { - rust("_$fieldName: self._$fieldName,") + is BuilderSection.AdditionalFieldsInBuild -> { + rust("_$fieldName: self._$fieldName,") + } } } } - } } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CrateLicenseDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CrateLicenseDecorator.kt index 22ebd0ff1e4..e842d731dc0 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CrateLicenseDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CrateLicenseDecorator.kt @@ -14,7 +14,10 @@ class CrateLicenseDecorator : ClientCodegenDecorator { override val name: String = "CrateLicense" override val order: Byte = 0 - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { rustCrate.withFile("LICENSE") { val license = this::class.java.getResource("/LICENSE").readText() raw(license) diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CredentialProviders.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CredentialProviders.kt index 331babe611a..ce8610e51ba 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CredentialProviders.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CredentialProviders.kt @@ -41,7 +41,10 @@ class CredentialsProviderDecorator : ClientCodegenDecorator { }, ) - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { rustCrate.mergeFeature(TestUtilFeature.copy(deps = listOf("aws-credential-types/test-util"))) rustCrate.withModule(ClientRustModule.config) { @@ -58,87 +61,94 @@ class CredentialsProviderDecorator : ClientCodegenDecorator { */ class CredentialProviderConfig(private val codegenContext: ClientCodegenContext) : ConfigCustomization() { private val runtimeConfig = codegenContext.runtimeConfig - private val codegenScope = arrayOf( - *preludeScope, - "Credentials" to configReexport(AwsRuntimeType.awsCredentialTypes(runtimeConfig).resolve("Credentials")), - "ProvideCredentials" to configReexport( - AwsRuntimeType.awsCredentialTypes(runtimeConfig) - .resolve("provider::ProvideCredentials"), - ), - "SharedCredentialsProvider" to configReexport( - AwsRuntimeType.awsCredentialTypes(runtimeConfig) - .resolve("provider::SharedCredentialsProvider"), - ), - "SIGV4A_SCHEME_ID" to AwsRuntimeType.awsRuntime(runtimeConfig) - .resolve("auth::sigv4a::SCHEME_ID"), - "SIGV4_SCHEME_ID" to AwsRuntimeType.awsRuntime(runtimeConfig) - .resolve("auth::sigv4::SCHEME_ID"), - "TestCredentials" to AwsRuntimeType.awsCredentialTypesTestUtil(runtimeConfig).resolve("Credentials"), - ) + private val codegenScope = + arrayOf( + *preludeScope, + "Credentials" to configReexport(AwsRuntimeType.awsCredentialTypes(runtimeConfig).resolve("Credentials")), + "ProvideCredentials" to + configReexport( + AwsRuntimeType.awsCredentialTypes(runtimeConfig) + .resolve("provider::ProvideCredentials"), + ), + "SharedCredentialsProvider" to + configReexport( + AwsRuntimeType.awsCredentialTypes(runtimeConfig) + .resolve("provider::SharedCredentialsProvider"), + ), + "SIGV4A_SCHEME_ID" to + AwsRuntimeType.awsRuntime(runtimeConfig) + .resolve("auth::sigv4a::SCHEME_ID"), + "SIGV4_SCHEME_ID" to + AwsRuntimeType.awsRuntime(runtimeConfig) + .resolve("auth::sigv4::SCHEME_ID"), + "TestCredentials" to AwsRuntimeType.awsCredentialTypesTestUtil(runtimeConfig).resolve("Credentials"), + ) - override fun section(section: ServiceConfig) = writable { - when (section) { - ServiceConfig.ConfigImpl -> { - rustTemplate( - """ - /// This function was intended to be removed, and has been broken since release-2023-11-15 as it always returns a `None`. Do not use. - ##[deprecated(note = "This function was intended to be removed, and has been broken since release-2023-11-15 as it always returns a `None`. Do not use.")] - pub fn credentials_provider(&self) -> Option<#{SharedCredentialsProvider}> { - #{None} - } - """, - *codegenScope, - ) - } + override fun section(section: ServiceConfig) = + writable { + when (section) { + ServiceConfig.ConfigImpl -> { + rustTemplate( + """ + /// This function was intended to be removed, and has been broken since release-2023-11-15 as it always returns a `None`. Do not use. + ##[deprecated(note = "This function was intended to be removed, and has been broken since release-2023-11-15 as it always returns a `None`. Do not use.")] + pub fn credentials_provider(&self) -> Option<#{SharedCredentialsProvider}> { + #{None} + } + """, + *codegenScope, + ) + } - ServiceConfig.BuilderImpl -> { - rustTemplate( - """ - /// Sets the credentials provider for this service - pub fn credentials_provider(mut self, credentials_provider: impl #{ProvideCredentials} + 'static) -> Self { - self.set_credentials_provider(#{Some}(#{SharedCredentialsProvider}::new(credentials_provider))); - self - } - """, - *codegenScope, - ) + ServiceConfig.BuilderImpl -> { + rustTemplate( + """ + /// Sets the credentials provider for this service + pub fn credentials_provider(mut self, credentials_provider: impl #{ProvideCredentials} + 'static) -> Self { + self.set_credentials_provider(#{Some}(#{SharedCredentialsProvider}::new(credentials_provider))); + self + } + """, + *codegenScope, + ) - rustBlockTemplate( - """ - /// Sets the credentials provider for this service - pub fn set_credentials_provider(&mut self, credentials_provider: #{Option}<#{SharedCredentialsProvider}>) -> &mut Self - """, - *codegenScope, - ) { rustBlockTemplate( """ - if let Some(credentials_provider) = credentials_provider + /// Sets the credentials provider for this service + pub fn set_credentials_provider(&mut self, credentials_provider: #{Option}<#{SharedCredentialsProvider}>) -> &mut Self """, *codegenScope, ) { - if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) { - featureGateBlock("sigv4a") { - rustTemplate( - "self.runtime_components.set_identity_resolver(#{SIGV4A_SCHEME_ID}, credentials_provider.clone());", - *codegenScope, - ) + rustBlockTemplate( + """ + if let Some(credentials_provider) = credentials_provider + """, + *codegenScope, + ) { + if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) { + featureGateBlock("sigv4a") { + rustTemplate( + "self.runtime_components.set_identity_resolver(#{SIGV4A_SCHEME_ID}, credentials_provider.clone());", + *codegenScope, + ) + } } + rustTemplate( + "self.runtime_components.set_identity_resolver(#{SIGV4_SCHEME_ID}, credentials_provider);", + *codegenScope, + ) } - rustTemplate( - "self.runtime_components.set_identity_resolver(#{SIGV4_SCHEME_ID}, credentials_provider);", - *codegenScope, - ) + rust("self") } - rust("self") } - } - is ServiceConfig.DefaultForTests -> rustTemplate( - "${section.configBuilderRef}.set_credentials_provider(Some(#{SharedCredentialsProvider}::new(#{TestCredentials}::for_tests())));", - *codegenScope, - ) + is ServiceConfig.DefaultForTests -> + rustTemplate( + "${section.configBuilderRef}.set_credentials_provider(Some(#{SharedCredentialsProvider}::new(#{TestCredentials}::for_tests())));", + *codegenScope, + ) - else -> emptySection + else -> emptySection + } } - } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/EndpointBuiltInsDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/EndpointBuiltInsDecorator.kt index 636f1951596..3a9b5815494 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/EndpointBuiltInsDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/EndpointBuiltInsDecorator.kt @@ -50,7 +50,9 @@ fun EndpointRuleSet.getBuiltIn(builtIn: String) = parameters.toList().find { it. /** load a builtIn parameter from a ruleset. The returned builtIn is the one defined in the ruleset (including latest docs, etc.) */ fun EndpointRuleSet.getBuiltIn(builtIn: Parameter) = getBuiltIn(builtIn.builtIn.orNull()!!) + fun ClientCodegenContext.getBuiltIn(builtIn: Parameter): Parameter? = getBuiltIn(builtIn.builtIn.orNull()!!) + fun ClientCodegenContext.getBuiltIn(builtIn: String): Parameter? { val idx = EndpointRulesetIndex.of(model) val rules = idx.endpointRulesForService(serviceShape) ?: return null @@ -62,24 +64,35 @@ private fun promotedBuiltins(parameter: Parameter) = parameter.builtIn == AwsBuiltIns.DUALSTACK.builtIn || parameter.builtIn == BuiltIns.SDK_ENDPOINT.builtIn -private fun configParamNewtype(parameter: Parameter, name: String, runtimeConfig: RuntimeConfig): RuntimeType { +private fun configParamNewtype( + parameter: Parameter, + name: String, + runtimeConfig: RuntimeConfig, +): RuntimeType { val type = parameter.symbol().mapRustType { t -> t.stripOuter() } return when (promotedBuiltins(parameter)) { - true -> AwsRuntimeType.awsTypes(runtimeConfig) - .resolve("endpoint_config::${name.toPascalCase()}") + true -> + AwsRuntimeType.awsTypes(runtimeConfig) + .resolve("endpoint_config::${name.toPascalCase()}") false -> configParamNewtype(name.toPascalCase(), type, runtimeConfig) } } -private fun ConfigParam.Builder.toConfigParam(parameter: Parameter, runtimeConfig: RuntimeConfig): ConfigParam = +private fun ConfigParam.Builder.toConfigParam( + parameter: Parameter, + runtimeConfig: RuntimeConfig, +): ConfigParam = this.name(this.name ?: parameter.name.rustName()) .type(parameter.symbol().mapRustType { t -> t.stripOuter() }) .newtype(configParamNewtype(parameter, this.name!!, runtimeConfig)) .setterDocs(this.setterDocs ?: parameter.documentation.orNull()?.let { writable { docs(it) } }) .build() -fun Model.loadBuiltIn(serviceId: ShapeId, builtInSrc: Parameter): Parameter? { +fun Model.loadBuiltIn( + serviceId: ShapeId, + builtInSrc: Parameter, +): Parameter? { val model = this val idx = EndpointRulesetIndex.of(model) val service = model.expectShape(serviceId, ServiceShape::class.java) @@ -96,10 +109,11 @@ fun Model.sdkConfigSetter( val builtIn = loadBuiltIn(serviceId, builtInSrc) ?: return null val fieldName = configParameterNameOverride ?: builtIn.name.rustName() - val map = when (builtIn.type!!) { - ParameterType.STRING -> writable { rust("|s|s.to_string()") } - ParameterType.BOOLEAN -> null - } + val map = + when (builtIn.type!!) { + ParameterType.STRING -> writable { rust("|s|s.to_string()") } + ParameterType.BOOLEAN -> null + } return SdkConfigCustomization.copyField(fieldName, map) } @@ -139,49 +153,59 @@ fun decoratorForBuiltIn( } } - override fun endpointCustomizations(codegenContext: ClientCodegenContext): List = listOf( - object : EndpointCustomization { - override fun loadBuiltInFromServiceConfig(parameter: Parameter, configRef: String): Writable? = - when (parameter.builtIn) { - builtIn.builtIn -> writable { - val newtype = configParamNewtype(parameter, name, codegenContext.runtimeConfig) - val symbol = parameter.symbol().mapRustType { t -> t.stripOuter() } + override fun endpointCustomizations(codegenContext: ClientCodegenContext): List = + listOf( + object : EndpointCustomization { + override fun loadBuiltInFromServiceConfig( + parameter: Parameter, + configRef: String, + ): Writable? = + when (parameter.builtIn) { + builtIn.builtIn -> + writable { + val newtype = configParamNewtype(parameter, name, codegenContext.runtimeConfig) + val symbol = parameter.symbol().mapRustType { t -> t.stripOuter() } + rustTemplate( + """$configRef.#{load_from_service_config_layer}""", + "load_from_service_config_layer" to loadFromConfigBag(symbol.name, newtype), + ) + } + + else -> null + } + + override fun setBuiltInOnServiceConfig( + name: String, + value: Node, + configBuilderRef: String, + ): Writable? { + if (name != builtIn.builtIn.get()) { + return null + } + return writable { rustTemplate( - """$configRef.#{load_from_service_config_layer}""", - "load_from_service_config_layer" to loadFromConfigBag(symbol.name, newtype), + "let $configBuilderRef = $configBuilderRef.${nameOverride ?: builtIn.name.rustName()}(#{value});", + "value" to value.toWritable(), ) } - - else -> null - } - - override fun setBuiltInOnServiceConfig(name: String, value: Node, configBuilderRef: String): Writable? { - if (name != builtIn.builtIn.get()) { - return null } - return writable { - rustTemplate( - "let $configBuilderRef = $configBuilderRef.${nameOverride ?: builtIn.name.rustName()}(#{value});", - "value" to value.toWritable(), - ) - } - } - }, - ) + }, + ) } } -private val endpointUrlDocs = writable { - rust( - """ - /// Sets the endpoint URL used to communicate with this service +private val endpointUrlDocs = + writable { + rust( + """ + /// Sets the endpoint URL used to communicate with this service - /// Note: this is used in combination with other endpoint rules, e.g. an API that applies a host-label prefix - /// will be prefixed onto this URL. To fully override the endpoint resolver, use - /// [`Builder::endpoint_resolver`]. - """.trimIndent(), - ) -} + /// Note: this is used in combination with other endpoint rules, e.g. an API that applies a host-label prefix + /// will be prefixed onto this URL. To fully override the endpoint resolver, use + /// [`Builder::endpoint_resolver`]. + """.trimIndent(), + ) + } fun Node.toWritable(): Writable { val node = this diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpRequestChecksumDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpRequestChecksumDecorator.kt index f2648ee914c..37bff546c94 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpRequestChecksumDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpRequestChecksumDecorator.kt @@ -26,19 +26,20 @@ import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.orNull -private fun RuntimeConfig.awsInlineableHttpRequestChecksum() = RuntimeType.forInlineDependency( - InlineAwsDependency.forRustFile( - "http_request_checksum", visibility = Visibility.PUBCRATE, - CargoDependency.Bytes, - CargoDependency.Http, - CargoDependency.HttpBody, - CargoDependency.Tracing, - CargoDependency.smithyChecksums(this), - CargoDependency.smithyHttp(this), - CargoDependency.smithyRuntimeApiClient(this), - CargoDependency.smithyTypes(this), - ), -) +private fun RuntimeConfig.awsInlineableHttpRequestChecksum() = + RuntimeType.forInlineDependency( + InlineAwsDependency.forRustFile( + "http_request_checksum", visibility = Visibility.PUBCRATE, + CargoDependency.Bytes, + CargoDependency.Http, + CargoDependency.HttpBody, + CargoDependency.Tracing, + CargoDependency.smithyChecksums(this), + CargoDependency.smithyHttp(this), + CargoDependency.smithyRuntimeApiClient(this), + CargoDependency.smithyTypes(this), + ), + ) class HttpRequestChecksumDecorator : ClientCodegenDecorator { override val name: String = "HttpRequestChecksum" @@ -48,8 +49,7 @@ class HttpRequestChecksumDecorator : ClientCodegenDecorator { codegenContext: ClientCodegenContext, operation: OperationShape, baseCustomizations: List, - ): List = - baseCustomizations + HttpRequestChecksumCustomization(codegenContext, operation) + ): List = baseCustomizations + HttpRequestChecksumCustomization(codegenContext, operation) } private fun HttpChecksumTrait.requestAlgorithmMember( @@ -112,41 +112,44 @@ class HttpRequestChecksumCustomization( ) : OperationCustomization() { private val runtimeConfig = codegenContext.runtimeConfig - override fun section(section: OperationSection): Writable = writable { - // Get the `HttpChecksumTrait`, returning early if this `OperationShape` doesn't have one - val checksumTrait = operationShape.getTrait() ?: return@writable - val requestAlgorithmMember = checksumTrait.requestAlgorithmMember(codegenContext, operationShape) - val inputShape = codegenContext.model.expectShape(operationShape.inputShape) + override fun section(section: OperationSection): Writable = + writable { + // Get the `HttpChecksumTrait`, returning early if this `OperationShape` doesn't have one + val checksumTrait = operationShape.getTrait() ?: return@writable + val requestAlgorithmMember = checksumTrait.requestAlgorithmMember(codegenContext, operationShape) + val inputShape = codegenContext.model.expectShape(operationShape.inputShape) - when (section) { - is OperationSection.AdditionalInterceptors -> { - if (requestAlgorithmMember != null) { - section.registerInterceptor(runtimeConfig, this) { - val runtimeApi = RuntimeType.smithyRuntimeApiClient(runtimeConfig) - rustTemplate( - """ - #{RequestChecksumInterceptor}::new(|input: &#{Input}| { - let input: &#{OperationInput} = input.downcast_ref().expect("correct type"); - let checksum_algorithm = input.$requestAlgorithmMember(); - #{checksum_algorithm_to_str} - #{Result}::<_, #{BoxError}>::Ok(checksum_algorithm) - }) - """, - *preludeScope, - "BoxError" to RuntimeType.boxError(runtimeConfig), - "Input" to runtimeApi.resolve("client::interceptors::context::Input"), - "OperationInput" to codegenContext.symbolProvider.toSymbol(inputShape), - "RequestChecksumInterceptor" to runtimeConfig.awsInlineableHttpRequestChecksum() - .resolve("RequestChecksumInterceptor"), - "checksum_algorithm_to_str" to checksumTrait.checksumAlgorithmToStr( - codegenContext, - operationShape, - ), - ) + when (section) { + is OperationSection.AdditionalInterceptors -> { + if (requestAlgorithmMember != null) { + section.registerInterceptor(runtimeConfig, this) { + val runtimeApi = RuntimeType.smithyRuntimeApiClient(runtimeConfig) + rustTemplate( + """ + #{RequestChecksumInterceptor}::new(|input: &#{Input}| { + let input: &#{OperationInput} = input.downcast_ref().expect("correct type"); + let checksum_algorithm = input.$requestAlgorithmMember(); + #{checksum_algorithm_to_str} + #{Result}::<_, #{BoxError}>::Ok(checksum_algorithm) + }) + """, + *preludeScope, + "BoxError" to RuntimeType.boxError(runtimeConfig), + "Input" to runtimeApi.resolve("client::interceptors::context::Input"), + "OperationInput" to codegenContext.symbolProvider.toSymbol(inputShape), + "RequestChecksumInterceptor" to + runtimeConfig.awsInlineableHttpRequestChecksum() + .resolve("RequestChecksumInterceptor"), + "checksum_algorithm_to_str" to + checksumTrait.checksumAlgorithmToStr( + codegenContext, + operationShape, + ), + ) + } } } + else -> { } } - else -> { } } - } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpResponseChecksumDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpResponseChecksumDecorator.kt index e49e834ddd7..f5d5f67bd48 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpResponseChecksumDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpResponseChecksumDecorator.kt @@ -27,19 +27,20 @@ import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rust.codegen.core.util.orNull -private fun RuntimeConfig.awsInlineableHttpResponseChecksum() = RuntimeType.forInlineDependency( - InlineAwsDependency.forRustFile( - "http_response_checksum", visibility = Visibility.PUBCRATE, - CargoDependency.Bytes, - CargoDependency.Http, - CargoDependency.HttpBody, - CargoDependency.Tracing, - CargoDependency.smithyChecksums(this), - CargoDependency.smithyHttp(this), - CargoDependency.smithyRuntimeApiClient(this), - CargoDependency.smithyTypes(this), - ), -) +private fun RuntimeConfig.awsInlineableHttpResponseChecksum() = + RuntimeType.forInlineDependency( + InlineAwsDependency.forRustFile( + "http_response_checksum", visibility = Visibility.PUBCRATE, + CargoDependency.Bytes, + CargoDependency.Http, + CargoDependency.HttpBody, + CargoDependency.Tracing, + CargoDependency.smithyChecksums(this), + CargoDependency.smithyHttp(this), + CargoDependency.smithyRuntimeApiClient(this), + CargoDependency.smithyTypes(this), + ), + ) fun HttpChecksumTrait.requestValidationModeMember( codegenContext: ClientCodegenContext, @@ -60,9 +61,10 @@ class HttpResponseChecksumDecorator : ClientCodegenDecorator { codegenContext: ClientCodegenContext, operation: OperationShape, baseCustomizations: List, - ): List = baseCustomizations.letIf(applies(operation)) { - it + HttpResponseChecksumCustomization(codegenContext, operation) - } + ): List = + baseCustomizations.letIf(applies(operation)) { + it + HttpResponseChecksumCustomization(codegenContext, operation) + } } // This generator was implemented based on this spec: @@ -71,50 +73,52 @@ class HttpResponseChecksumCustomization( private val codegenContext: ClientCodegenContext, private val operationShape: OperationShape, ) : OperationCustomization() { - override fun section(section: OperationSection): Writable = writable { - val checksumTrait = operationShape.getTrait() ?: return@writable - val requestValidationModeMember = - checksumTrait.requestValidationModeMember(codegenContext, operationShape) ?: return@writable - val requestValidationModeMemberInner = if (requestValidationModeMember.isOptional) { - codegenContext.model.expectShape(requestValidationModeMember.target) - } else { - requestValidationModeMember - } - val validationModeName = codegenContext.symbolProvider.toMemberName(requestValidationModeMember) - val inputShape = codegenContext.model.expectShape(operationShape.inputShape) + override fun section(section: OperationSection): Writable = + writable { + val checksumTrait = operationShape.getTrait() ?: return@writable + val requestValidationModeMember = + checksumTrait.requestValidationModeMember(codegenContext, operationShape) ?: return@writable + val requestValidationModeMemberInner = + if (requestValidationModeMember.isOptional) { + codegenContext.model.expectShape(requestValidationModeMember.target) + } else { + requestValidationModeMember + } + val validationModeName = codegenContext.symbolProvider.toMemberName(requestValidationModeMember) + val inputShape = codegenContext.model.expectShape(operationShape.inputShape) - when (section) { - is OperationSection.AdditionalInterceptors -> { - section.registerInterceptor(codegenContext.runtimeConfig, this) { - // CRC32, CRC32C, SHA256, SHA1 -> "crc32", "crc32c", "sha256", "sha1" - val responseAlgorithms = checksumTrait.responseAlgorithms - .map { algorithm -> algorithm.lowercase() }.joinToString(", ") { algorithm -> "\"$algorithm\"" } - val runtimeApi = RuntimeType.smithyRuntimeApiClient(codegenContext.runtimeConfig) - rustTemplate( - """ - #{ResponseChecksumInterceptor}::new( - [$responseAlgorithms].as_slice(), - |input: &#{Input}| { - ${""/* - Per [the spec](https://smithy.io/2.0/aws/aws-core.html#http-response-checksums), - we check to see if it's the `ENABLED` variant - */} - let input: &#{OperationInput} = input.downcast_ref().expect("correct type"); - matches!(input.$validationModeName(), #{Some}(#{ValidationModeShape}::Enabled)) - } + when (section) { + is OperationSection.AdditionalInterceptors -> { + section.registerInterceptor(codegenContext.runtimeConfig, this) { + // CRC32, CRC32C, SHA256, SHA1 -> "crc32", "crc32c", "sha256", "sha1" + val responseAlgorithms = + checksumTrait.responseAlgorithms + .map { algorithm -> algorithm.lowercase() }.joinToString(", ") { algorithm -> "\"$algorithm\"" } + val runtimeApi = RuntimeType.smithyRuntimeApiClient(codegenContext.runtimeConfig) + rustTemplate( + """ + #{ResponseChecksumInterceptor}::new( + [$responseAlgorithms].as_slice(), + |input: &#{Input}| { + ${""/* Per [the spec](https://smithy.io/2.0/aws/aws-core.html#http-response-checksums), + we check to see if it's the `ENABLED` variant */} + let input: &#{OperationInput} = input.downcast_ref().expect("correct type"); + matches!(input.$validationModeName(), #{Some}(#{ValidationModeShape}::Enabled)) + } + ) + """, + *preludeScope, + "ResponseChecksumInterceptor" to + codegenContext.runtimeConfig.awsInlineableHttpResponseChecksum() + .resolve("ResponseChecksumInterceptor"), + "Input" to runtimeApi.resolve("client::interceptors::context::Input"), + "OperationInput" to codegenContext.symbolProvider.toSymbol(inputShape), + "ValidationModeShape" to codegenContext.symbolProvider.toSymbol(requestValidationModeMemberInner), ) - """, - *preludeScope, - "ResponseChecksumInterceptor" to codegenContext.runtimeConfig.awsInlineableHttpResponseChecksum() - .resolve("ResponseChecksumInterceptor"), - "Input" to runtimeApi.resolve("client::interceptors::context::Input"), - "OperationInput" to codegenContext.symbolProvider.toSymbol(inputShape), - "ValidationModeShape" to codegenContext.symbolProvider.toSymbol(requestValidationModeMemberInner), - ) + } } - } - else -> {} + else -> {} + } } - } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/InlineAwsDependency.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/InlineAwsDependency.kt index 4f34f3a7503..db78816a319 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/InlineAwsDependency.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/InlineAwsDependency.kt @@ -11,10 +11,18 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.Visibility object InlineAwsDependency { - fun forRustFile(file: String, visibility: Visibility = Visibility.PRIVATE, vararg additionalDependency: RustDependency): InlineDependency = - forRustFileAs(file, file, visibility, *additionalDependency) + fun forRustFile( + file: String, + visibility: Visibility = Visibility.PRIVATE, + vararg additionalDependency: RustDependency, + ): InlineDependency = forRustFileAs(file, file, visibility, *additionalDependency) - fun forRustFileAs(file: String, moduleName: String, visibility: Visibility = Visibility.PRIVATE, vararg additionalDependency: RustDependency): InlineDependency = + fun forRustFileAs( + file: String, + moduleName: String, + visibility: Visibility = Visibility.PRIVATE, + vararg additionalDependency: RustDependency, + ): InlineDependency = InlineDependency.Companion.forRustFile( RustModule.new(moduleName, visibility, documentationOverride = ""), "/aws-inlineable/src/$file.rs", diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt index ad1224546aa..7883b7f3e33 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt @@ -60,12 +60,13 @@ class IntegrationTestDecorator : ClientCodegenDecorator { return if (Files.exists(testPackagePath) && Files.exists(testPackagePath.resolve("Cargo.toml"))) { val hasTests = Files.exists(testPackagePath.resolve("tests")) val hasBenches = Files.exists(testPackagePath.resolve("benches")) - baseCustomizations + IntegrationTestDependencies( - codegenContext, - moduleName, - hasTests, - hasBenches, - ) + baseCustomizations + + IntegrationTestDependencies( + codegenContext, + moduleName, + hasTests, + hasBenches, + ) } else { baseCustomizations } @@ -79,42 +80,48 @@ class IntegrationTestDependencies( private val hasBenches: Boolean, ) : LibRsCustomization() { private val runtimeConfig = codegenContext.runtimeConfig - override fun section(section: LibRsSection) = when (section) { - is LibRsSection.Body -> testDependenciesOnly { - if (hasTests) { - val smithyAsync = CargoDependency.smithyAsync(codegenContext.runtimeConfig) - .copy(features = setOf("test-util"), scope = DependencyScope.Dev) - val smithyTypes = CargoDependency.smithyTypes(codegenContext.runtimeConfig) - .copy(features = setOf("test-util"), scope = DependencyScope.Dev) - addDependency(awsRuntime(runtimeConfig).toDevDependency().withFeature("test-util")) - addDependency(FuturesUtil) - addDependency(SerdeJson) - addDependency(smithyAsync) - addDependency(smithyProtocolTestHelpers(codegenContext.runtimeConfig)) - addDependency(smithyRuntime(runtimeConfig).copy(features = setOf("test-util", "wire-mock"), scope = DependencyScope.Dev)) - addDependency(smithyRuntimeApiTestUtil(runtimeConfig)) - addDependency(smithyTypes) - addDependency(Tokio) - addDependency(Tracing.toDevDependency()) - addDependency(TracingSubscriber) - } - if (hasBenches) { - addDependency(Criterion) - } - for (serviceSpecific in serviceSpecificCustomizations()) { - serviceSpecific.section(section)(this) - } - } - else -> emptySection - } + override fun section(section: LibRsSection) = + when (section) { + is LibRsSection.Body -> + testDependenciesOnly { + if (hasTests) { + val smithyAsync = + CargoDependency.smithyAsync(codegenContext.runtimeConfig) + .copy(features = setOf("test-util"), scope = DependencyScope.Dev) + val smithyTypes = + CargoDependency.smithyTypes(codegenContext.runtimeConfig) + .copy(features = setOf("test-util"), scope = DependencyScope.Dev) + addDependency(awsRuntime(runtimeConfig).toDevDependency().withFeature("test-util")) + addDependency(FuturesUtil) + addDependency(SerdeJson) + addDependency(smithyAsync) + addDependency(smithyProtocolTestHelpers(codegenContext.runtimeConfig)) + addDependency(smithyRuntime(runtimeConfig).copy(features = setOf("test-util", "wire-mock"), scope = DependencyScope.Dev)) + addDependency(smithyRuntimeApiTestUtil(runtimeConfig)) + addDependency(smithyTypes) + addDependency(Tokio) + addDependency(Tracing.toDevDependency()) + addDependency(TracingSubscriber) + } + if (hasBenches) { + addDependency(Criterion) + } + for (serviceSpecific in serviceSpecificCustomizations()) { + serviceSpecific.section(section)(this) + } + } - private fun serviceSpecificCustomizations(): List = when (moduleName) { - "transcribestreaming" -> listOf(TranscribeTestDependencies()) - "s3" -> listOf(S3TestDependencies(codegenContext)) - "dynamodb" -> listOf(DynamoDbTestDependencies()) - else -> emptyList() - } + else -> emptySection + } + + private fun serviceSpecificCustomizations(): List = + when (moduleName) { + "transcribestreaming" -> listOf(TranscribeTestDependencies()) + "s3" -> listOf(S3TestDependencies(codegenContext)) + "dynamodb" -> listOf(DynamoDbTestDependencies()) + else -> emptyList() + } } class TranscribeTestDependencies : LibRsCustomization() { diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/InvocationIdDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/InvocationIdDecorator.kt index def448bd55d..321bedee9dd 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/InvocationIdDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/InvocationIdDecorator.kt @@ -30,8 +30,7 @@ class InvocationIdDecorator : ClientCodegenDecorator { override fun configCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = - baseCustomizations + InvocationIdConfigCustomization(codegenContext) + ): List = baseCustomizations + InvocationIdConfigCustomization(codegenContext) } private class InvocationIdRuntimePluginCustomization( @@ -39,17 +38,19 @@ private class InvocationIdRuntimePluginCustomization( ) : ServiceRuntimePluginCustomization() { private val runtimeConfig = codegenContext.runtimeConfig private val awsRuntime = AwsRuntimeType.awsRuntime(runtimeConfig) - private val codegenScope = arrayOf( - "InvocationIdInterceptor" to awsRuntime.resolve("invocation_id::InvocationIdInterceptor"), - ) + private val codegenScope = + arrayOf( + "InvocationIdInterceptor" to awsRuntime.resolve("invocation_id::InvocationIdInterceptor"), + ) - override fun section(section: ServiceRuntimePluginSection): Writable = writable { - if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) { - section.registerInterceptor(this) { - rustTemplate("#{InvocationIdInterceptor}::new()", *codegenScope) + override fun section(section: ServiceRuntimePluginSection): Writable = + writable { + if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) { + section.registerInterceptor(this) { + rustTemplate("#{InvocationIdInterceptor}::new()", *codegenScope) + } } } - } } const val GENERATOR_DOCS: String = @@ -61,51 +62,53 @@ private class InvocationIdConfigCustomization( codegenContext: ClientCodegenContext, ) : ConfigCustomization() { private val awsRuntime = AwsRuntimeType.awsRuntime(codegenContext.runtimeConfig) - private val codegenScope = arrayOf( - *preludeScope, - "InvocationIdGenerator" to awsRuntime.resolve("invocation_id::InvocationIdGenerator"), - "SharedInvocationIdGenerator" to awsRuntime.resolve("invocation_id::SharedInvocationIdGenerator"), - ) + private val codegenScope = + arrayOf( + *preludeScope, + "InvocationIdGenerator" to awsRuntime.resolve("invocation_id::InvocationIdGenerator"), + "SharedInvocationIdGenerator" to awsRuntime.resolve("invocation_id::SharedInvocationIdGenerator"), + ) - override fun section(section: ServiceConfig): Writable = writable { - when (section) { - is ServiceConfig.BuilderImpl -> { - docs("Overrides the default invocation ID generator.\n\n$GENERATOR_DOCS") - rustTemplate( - """ - pub fn invocation_id_generator(mut self, gen: impl #{InvocationIdGenerator} + 'static) -> Self { - self.set_invocation_id_generator(#{Some}(#{SharedInvocationIdGenerator}::new(gen))); - self - } - """, - *codegenScope, - ) + override fun section(section: ServiceConfig): Writable = + writable { + when (section) { + is ServiceConfig.BuilderImpl -> { + docs("Overrides the default invocation ID generator.\n\n$GENERATOR_DOCS") + rustTemplate( + """ + pub fn invocation_id_generator(mut self, gen: impl #{InvocationIdGenerator} + 'static) -> Self { + self.set_invocation_id_generator(#{Some}(#{SharedInvocationIdGenerator}::new(gen))); + self + } + """, + *codegenScope, + ) - docs("Overrides the default invocation ID generator.\n\n$GENERATOR_DOCS") - rustTemplate( - """ - pub fn set_invocation_id_generator(&mut self, gen: #{Option}<#{SharedInvocationIdGenerator}>) -> &mut Self { - self.config.store_or_unset(gen); - self - } - """, - *codegenScope, - ) - } + docs("Overrides the default invocation ID generator.\n\n$GENERATOR_DOCS") + rustTemplate( + """ + pub fn set_invocation_id_generator(&mut self, gen: #{Option}<#{SharedInvocationIdGenerator}>) -> &mut Self { + self.config.store_or_unset(gen); + self + } + """, + *codegenScope, + ) + } - is ServiceConfig.ConfigImpl -> { - docs("Returns the invocation ID generator if one was given in config.\n\n$GENERATOR_DOCS") - rustTemplate( - """ - pub fn invocation_id_generator(&self) -> #{Option}<#{SharedInvocationIdGenerator}> { - self.config.load::<#{SharedInvocationIdGenerator}>().cloned() - } - """, - *codegenScope, - ) - } + is ServiceConfig.ConfigImpl -> { + docs("Returns the invocation ID generator if one was given in config.\n\n$GENERATOR_DOCS") + rustTemplate( + """ + pub fn invocation_id_generator(&self) -> #{Option}<#{SharedInvocationIdGenerator}> { + self.config.load::<#{SharedInvocationIdGenerator}>().cloned() + } + """, + *codegenScope, + ) + } - else -> {} + else -> {} + } } - } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RecursionDetectionDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RecursionDetectionDecorator.kt index 5809d8b4b39..d56fb680d6b 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RecursionDetectionDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RecursionDetectionDecorator.kt @@ -27,15 +27,16 @@ class RecursionDetectionDecorator : ClientCodegenDecorator { private class RecursionDetectionRuntimePluginCustomization( private val codegenContext: ClientCodegenContext, ) : ServiceRuntimePluginCustomization() { - override fun section(section: ServiceRuntimePluginSection): Writable = writable { - if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) { - section.registerInterceptor(this) { - rust( - "#T::new()", - AwsRuntimeType.awsRuntime(codegenContext.runtimeConfig) - .resolve("recursion_detection::RecursionDetectionInterceptor"), - ) + override fun section(section: ServiceRuntimePluginSection): Writable = + writable { + if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) { + section.registerInterceptor(this) { + rust( + "#T::new()", + AwsRuntimeType.awsRuntime(codegenContext.runtimeConfig) + .resolve("recursion_detection::RecursionDetectionInterceptor"), + ) + } } } - } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt index 4e77e1c49b6..c877d4aac54 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt @@ -28,53 +28,54 @@ import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.extendIf import software.amazon.smithy.rust.codegen.core.util.thenSingletonListOf -/* Example Generated Code */ -/* -pub struct Config { - pub(crate) region: Option, -} - -impl std::fmt::Debug for Config { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut config = f.debug_struct("Config"); - config.finish() - } -} - -impl Config { - pub fn builder() -> Builder { - Builder::default() - } -} - -#[derive(Default)] -pub struct Builder { - region: Option, -} - -impl Builder { - pub fn new() -> Self { - Self::default() - } - - pub fn region(mut self, region: impl Into>) -> Self { - self.region = region.into(); - self - } - - pub fn build(self) -> Config { - Config { - region: self.region, - } - } -} - -#[test] -fn test_1() { - fn assert_send_sync() {} - assert_send_sync::(); -} - */ +// Example Generated Code +// ---------------------- +// +// pub struct Config { +// pub(crate) region: Option, +// } +// +// impl std::fmt::Debug for Config { +// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// let mut config = f.debug_struct("Config"); +// config.finish() +// } +// } +// +// impl Config { +// pub fn builder() -> Builder { +// Builder::default() +// } +// } +// +// #[derive(Default)] +// pub struct Builder { +// region: Option, +// } +// +// impl Builder { +// pub fn new() -> Self { +// Self::default() +// } +// +// pub fn region(mut self, region: impl Into>) -> Self { +// self.region = region.into(); +// self +// } +// +// pub fn build(self) -> Config { +// Config { +// region: self.region, +// } +// } +// } +// +// #[test] +// fn test_1() { +// fn assert_send_sync() {} +// assert_send_sync::(); +// } +// class RegionDecorator : ClientCodegenDecorator { override val name: String = "Region" @@ -83,8 +84,9 @@ class RegionDecorator : ClientCodegenDecorator { // Services that have an endpoint ruleset that references the SDK::Region built in, or // that use SigV4, both need a configurable region. private fun usesRegion(codegenContext: ClientCodegenContext) = - codegenContext.getBuiltIn(AwsBuiltIns.REGION) != null || ServiceIndex.of(codegenContext.model) - .getEffectiveAuthSchemes(codegenContext.serviceShape).containsKey(SigV4Trait.ID) + codegenContext.getBuiltIn(AwsBuiltIns.REGION) != null || + ServiceIndex.of(codegenContext.model) + .getEffectiveAuthSchemes(codegenContext.serviceShape).containsKey(SigV4Trait.ID) override fun configCustomizations( codegenContext: ClientCodegenContext, @@ -114,20 +116,28 @@ class RegionDecorator : ClientCodegenDecorator { } return listOf( object : EndpointCustomization { - override fun loadBuiltInFromServiceConfig(parameter: Parameter, configRef: String): Writable? { + override fun loadBuiltInFromServiceConfig( + parameter: Parameter, + configRef: String, + ): Writable? { return when (parameter.builtIn) { - AwsBuiltIns.REGION.builtIn -> writable { - rustTemplate( - "$configRef.load::<#{Region}>().map(|r|r.as_ref().to_owned())", - "Region" to region(codegenContext.runtimeConfig).resolve("Region"), - ) - } + AwsBuiltIns.REGION.builtIn -> + writable { + rustTemplate( + "$configRef.load::<#{Region}>().map(|r|r.as_ref().to_owned())", + "Region" to region(codegenContext.runtimeConfig).resolve("Region"), + ) + } else -> null } } - override fun setBuiltInOnServiceConfig(name: String, value: Node, configBuilderRef: String): Writable? { + override fun setBuiltInOnServiceConfig( + name: String, + value: Node, + configBuilderRef: String, + ): Writable? { if (name != AwsBuiltIns.REGION.builtIn.get()) { return null } @@ -146,62 +156,64 @@ class RegionDecorator : ClientCodegenDecorator { class RegionProviderConfig(codegenContext: ClientCodegenContext) : ConfigCustomization() { private val region = region(codegenContext.runtimeConfig) private val moduleUseName = codegenContext.moduleUseName() - private val codegenScope = arrayOf( - *preludeScope, - "Region" to configReexport(region.resolve("Region")), - ) - - override fun section(section: ServiceConfig) = writable { - when (section) { - ServiceConfig.ConfigImpl -> { - rustTemplate( - """ - /// Returns the AWS region, if it was provided. - pub fn region(&self) -> #{Option}<&#{Region}> { - self.config.load::<#{Region}>() - } - """, - *codegenScope, - ) - } + private val codegenScope = + arrayOf( + *preludeScope, + "Region" to configReexport(region.resolve("Region")), + ) - ServiceConfig.BuilderImpl -> { - rustTemplate( - """ - /// Sets the AWS region to use when making requests. - /// - /// ## Examples - /// ```no_run - /// use aws_types::region::Region; - /// use $moduleUseName::config::{Builder, Config}; - /// - /// let config = $moduleUseName::Config::builder() - /// .region(Region::new("us-east-1")) - /// .build(); - /// ``` - pub fn region(mut self, region: impl #{Into}<#{Option}<#{Region}>>) -> Self { - self.set_region(region.into()); - self - } - """, - *codegenScope, - ) + override fun section(section: ServiceConfig) = + writable { + when (section) { + ServiceConfig.ConfigImpl -> { + rustTemplate( + """ + /// Returns the AWS region, if it was provided. + pub fn region(&self) -> #{Option}<&#{Region}> { + self.config.load::<#{Region}>() + } + """, + *codegenScope, + ) + } - rustTemplate( - """ - /// Sets the AWS region to use when making requests. - pub fn set_region(&mut self, region: #{Option}<#{Region}>) -> &mut Self { - self.config.store_or_unset(region); - self - } - """, - *codegenScope, - ) - } + ServiceConfig.BuilderImpl -> { + rustTemplate( + """ + /// Sets the AWS region to use when making requests. + /// + /// ## Examples + /// ```no_run + /// use aws_types::region::Region; + /// use $moduleUseName::config::{Builder, Config}; + /// + /// let config = $moduleUseName::Config::builder() + /// .region(Region::new("us-east-1")) + /// .build(); + /// ``` + pub fn region(mut self, region: impl #{Into}<#{Option}<#{Region}>>) -> Self { + self.set_region(region.into()); + self + } + """, + *codegenScope, + ) + + rustTemplate( + """ + /// Sets the AWS region to use when making requests. + pub fn set_region(&mut self, region: #{Option}<#{Region}>) -> &mut Self { + self.config.store_or_unset(region); + self + } + """, + *codegenScope, + ) + } - else -> emptySection + else -> emptySection + } } - } } fun region(runtimeConfig: RuntimeConfig) = AwsRuntimeType.awsTypes(runtimeConfig).resolve("region") diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryClassifierDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryClassifierDecorator.kt index 6b3c0185e34..1ea17d13d91 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryClassifierDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryClassifierDecorator.kt @@ -21,8 +21,9 @@ class RetryClassifierDecorator : ClientCodegenDecorator { codegenContext: ClientCodegenContext, operation: OperationShape, baseCustomizations: List, - ): List = baseCustomizations + - OperationRetryClassifiersFeature(codegenContext, operation) + ): List = + baseCustomizations + + OperationRetryClassifiersFeature(codegenContext, operation) } class OperationRetryClassifiersFeature( @@ -32,17 +33,19 @@ class OperationRetryClassifiersFeature( private val runtimeConfig = codegenContext.runtimeConfig private val symbolProvider = codegenContext.symbolProvider - override fun section(section: OperationSection) = when (section) { - is OperationSection.RetryClassifiers -> writable { - section.registerRetryClassifier(this) { - rustTemplate( - "#{AwsErrorCodeClassifier}::<#{OperationError}>::new()", - "AwsErrorCodeClassifier" to AwsRuntimeType.awsRuntime(runtimeConfig).resolve("retries::classifiers::AwsErrorCodeClassifier"), - "OperationError" to symbolProvider.symbolForOperationError(operation), - ) - } - } + override fun section(section: OperationSection) = + when (section) { + is OperationSection.RetryClassifiers -> + writable { + section.registerRetryClassifier(this) { + rustTemplate( + "#{AwsErrorCodeClassifier}::<#{OperationError}>::new()", + "AwsErrorCodeClassifier" to AwsRuntimeType.awsRuntime(runtimeConfig).resolve("retries::classifiers::AwsErrorCodeClassifier"), + "OperationError" to symbolProvider.symbolForOperationError(operation), + ) + } + } - else -> emptySection - } + else -> emptySection + } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryInformationHeaderDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryInformationHeaderDecorator.kt index 2a354327f60..4b9077af763 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryInformationHeaderDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryInformationHeaderDecorator.kt @@ -29,20 +29,21 @@ private class AddRetryInformationHeaderInterceptors(codegenContext: ClientCodege private val runtimeConfig = codegenContext.runtimeConfig private val awsRuntime = AwsRuntimeType.awsRuntime(runtimeConfig) - override fun section(section: ServiceRuntimePluginSection): Writable = writable { - if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) { - // Track the latency between client and server. - section.registerInterceptor(this) { - rust( - "#T::new()", - awsRuntime.resolve("service_clock_skew::ServiceClockSkewInterceptor"), - ) - } + override fun section(section: ServiceRuntimePluginSection): Writable = + writable { + if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) { + // Track the latency between client and server. + section.registerInterceptor(this) { + rust( + "#T::new()", + awsRuntime.resolve("service_clock_skew::ServiceClockSkewInterceptor"), + ) + } - // Add request metadata to outgoing requests. Sets a header. - section.registerInterceptor(this) { - rust("#T::new()", awsRuntime.resolve("request_info::RequestInfoInterceptor")) + // Add request metadata to outgoing requests. Sets a header. + section.registerInterceptor(this) { + rust("#T::new()", awsRuntime.resolve("request_info::RequestInfoInterceptor")) + } } } - } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SdkConfigDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SdkConfigDecorator.kt index 69a1299d738..f071835fc4e 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SdkConfigDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SdkConfigDecorator.kt @@ -49,14 +49,16 @@ object SdkConfigCustomization { * SdkConfigCustomization.copyField("some_string_field") { rust("|s|s.to_to_string()") } * ``` */ - fun copyField(fieldName: String, map: Writable?) = - adhocCustomization { section -> - val mapBlock = map?.let { writable { rust(".map(#W)", it) } } ?: writable { } - rustTemplate( - "${section.serviceConfigBuilder}.set_$fieldName(${section.sdkConfig}.$fieldName()#{map});", - "map" to mapBlock, - ) - } + fun copyField( + fieldName: String, + map: Writable?, + ) = adhocCustomization { section -> + val mapBlock = map?.let { writable { rust(".map(#W)", it) } } ?: writable { } + rustTemplate( + "${section.serviceConfigBuilder}.set_$fieldName(${section.sdkConfig}.$fieldName()#{map});", + "map" to mapBlock, + ) + } } /** @@ -110,10 +112,14 @@ class SdkConfigDecorator : ClientCodegenDecorator { return baseCustomizations + NewFromShared(codegenContext.runtimeConfig) } - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { - val codegenScope = arrayOf( - "SdkConfig" to AwsRuntimeType.awsTypes(codegenContext.runtimeConfig).resolve("sdk_config::SdkConfig"), - ) + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { + val codegenScope = + arrayOf( + "SdkConfig" to AwsRuntimeType.awsTypes(codegenContext.runtimeConfig).resolve("sdk_config::SdkConfig"), + ) rustCrate.withModule(ClientRustModule.config) { rustTemplate( @@ -133,15 +139,16 @@ class SdkConfigDecorator : ClientCodegenDecorator { } } """, - "augmentBuilder" to writable { - writeCustomizations( - codegenContext.rootDecorator.extraSections(codegenContext), - SdkConfigSection.CopySdkConfigToClientConfig( - sdkConfig = "input", - serviceConfigBuilder = "builder", - ), - ) - }, + "augmentBuilder" to + writable { + writeCustomizations( + codegenContext.rootDecorator.extraSections(codegenContext), + SdkConfigSection.CopySdkConfigToClientConfig( + sdkConfig = "input", + serviceConfigBuilder = "builder", + ), + ) + }, *codegenScope, ) } @@ -149,23 +156,25 @@ class SdkConfigDecorator : ClientCodegenDecorator { } class NewFromShared(runtimeConfig: RuntimeConfig) : ConfigCustomization() { - private val codegenScope = arrayOf( - "SdkConfig" to AwsRuntimeType.awsTypes(runtimeConfig).resolve("sdk_config::SdkConfig"), - ) + private val codegenScope = + arrayOf( + "SdkConfig" to AwsRuntimeType.awsTypes(runtimeConfig).resolve("sdk_config::SdkConfig"), + ) override fun section(section: ServiceConfig): Writable { return when (section) { - ServiceConfig.ConfigImpl -> writable { - rustTemplate( - """ - /// Creates a new [service config](crate::Config) from a [shared `config`](#{SdkConfig}). - pub fn new(config: &#{SdkConfig}) -> Self { - Builder::from(config).build() - } - """, - *codegenScope, - ) - } + ServiceConfig.ConfigImpl -> + writable { + rustTemplate( + """ + /// Creates a new [service config](crate::Config) from a [shared `config`](#{SdkConfig}). + pub fn new(config: &#{SdkConfig}) -> Self { + Builder::from(config).build() + } + """, + *codegenScope, + ) + } else -> emptySection } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecorator.kt index a6a0bc7b80e..fb1385bdc62 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecorator.kt @@ -43,29 +43,33 @@ class SigV4AuthDecorator : ClientCodegenDecorator { private val sigv4a = "sigv4a" - private fun sigv4(runtimeConfig: RuntimeConfig) = writable { - val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth") - rust("#T", awsRuntimeAuthModule.resolve("sigv4::SCHEME_ID")) - } + private fun sigv4(runtimeConfig: RuntimeConfig) = + writable { + val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth") + rust("#T", awsRuntimeAuthModule.resolve("sigv4::SCHEME_ID")) + } - private fun sigv4a(runtimeConfig: RuntimeConfig) = writable { - val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth") - featureGateBlock(sigv4a) { - rust("#T", awsRuntimeAuthModule.resolve("sigv4a::SCHEME_ID")) + private fun sigv4a(runtimeConfig: RuntimeConfig) = + writable { + val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth") + featureGateBlock(sigv4a) { + rust("#T", awsRuntimeAuthModule.resolve("sigv4a::SCHEME_ID")) + } } - } override fun authOptions( codegenContext: ClientCodegenContext, operationShape: OperationShape, baseAuthSchemeOptions: List, ): List { - val supportsSigV4a = codegenContext.serviceShape.supportedAuthSchemes().contains(sigv4a) - .thenSingletonListOf { sigv4a(codegenContext.runtimeConfig) } - return baseAuthSchemeOptions + AuthSchemeOption.StaticAuthSchemeOption( - SigV4Trait.ID, - listOf(sigv4(codegenContext.runtimeConfig)) + supportsSigV4a, - ) + val supportsSigV4a = + codegenContext.serviceShape.supportedAuthSchemes().contains(sigv4a) + .thenSingletonListOf { sigv4a(codegenContext.runtimeConfig) } + return baseAuthSchemeOptions + + AuthSchemeOption.StaticAuthSchemeOption( + SigV4Trait.ID, + listOf(sigv4(codegenContext.runtimeConfig)) + supportsSigV4a, + ) } override fun serviceRuntimePluginCustomizations( @@ -86,7 +90,10 @@ class SigV4AuthDecorator : ClientCodegenDecorator { ): List = baseCustomizations + SigV4SigningConfig(codegenContext.runtimeConfig, codegenContext.serviceShape.getTrait()) - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) { // Add optional feature for SigV4a support rustCrate.mergeFeature(Feature("sigv4a", true, listOf("aws-runtime/sigv4a"))) @@ -98,55 +105,57 @@ private class SigV4SigningConfig( runtimeConfig: RuntimeConfig, private val sigV4Trait: SigV4Trait?, ) : ConfigCustomization() { - private val codegenScope = arrayOf( - "Region" to AwsRuntimeType.awsTypes(runtimeConfig).resolve("region::Region"), - "SigningName" to AwsRuntimeType.awsTypes(runtimeConfig).resolve("SigningName"), - "SigningRegion" to AwsRuntimeType.awsTypes(runtimeConfig).resolve("region::SigningRegion"), - ) + private val codegenScope = + arrayOf( + "Region" to AwsRuntimeType.awsTypes(runtimeConfig).resolve("region::Region"), + "SigningName" to AwsRuntimeType.awsTypes(runtimeConfig).resolve("SigningName"), + "SigningRegion" to AwsRuntimeType.awsTypes(runtimeConfig).resolve("region::SigningRegion"), + ) - override fun section(section: ServiceConfig): Writable = writable { - if (sigV4Trait != null) { - when (section) { - ServiceConfig.ConfigImpl -> { - rust( - """ - /// The signature version 4 service signing name to use in the credential scope when signing requests. - /// - /// The signing service may be overridden by the `Endpoint`, or by specifying a custom - /// [`SigningName`](aws_types::SigningName) during operation construction - pub fn signing_name(&self) -> &'static str { - ${sigV4Trait.name.dq()} - } - """, - ) - } + override fun section(section: ServiceConfig): Writable = + writable { + if (sigV4Trait != null) { + when (section) { + ServiceConfig.ConfigImpl -> { + rust( + """ + /// The signature version 4 service signing name to use in the credential scope when signing requests. + /// + /// The signing service may be overridden by the `Endpoint`, or by specifying a custom + /// [`SigningName`](aws_types::SigningName) during operation construction + pub fn signing_name(&self) -> &'static str { + ${sigV4Trait.name.dq()} + } + """, + ) + } - ServiceConfig.BuilderBuild -> { - rustTemplate( - """ - layer.store_put(#{SigningName}::from_static(${sigV4Trait.name.dq()})); - layer.load::<#{Region}>().cloned().map(|r| layer.store_put(#{SigningRegion}::from(r))); - """, - *codegenScope, - ) - } + ServiceConfig.BuilderBuild -> { + rustTemplate( + """ + layer.store_put(#{SigningName}::from_static(${sigV4Trait.name.dq()})); + layer.load::<#{Region}>().cloned().map(|r| layer.store_put(#{SigningRegion}::from(r))); + """, + *codegenScope, + ) + } - is ServiceConfig.OperationConfigOverride -> { - rustTemplate( - """ - resolver.config_mut() - .load::<#{Region}>() - .cloned() - .map(|r| resolver.config_mut().store_put(#{SigningRegion}::from(r))); - """, - *codegenScope, - ) - } + is ServiceConfig.OperationConfigOverride -> { + rustTemplate( + """ + resolver.config_mut() + .load::<#{Region}>() + .cloned() + .map(|r| resolver.config_mut().store_put(#{SigningRegion}::from(r))); + """, + *codegenScope, + ) + } - else -> {} + else -> {} + } } } - } } private class AuthServiceRuntimePluginCustomization(private val codegenContext: ClientCodegenContext) : @@ -161,49 +170,53 @@ private class AuthServiceRuntimePluginCustomization(private val codegenContext: ) } - override fun section(section: ServiceRuntimePluginSection): Writable = writable { - when (section) { - is ServiceRuntimePluginSection.RegisterRuntimeComponents -> { - val serviceHasEventStream = codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model) - if (serviceHasEventStream) { - // enable the aws-runtime `sign-eventstream` feature - addDependency( - AwsCargoDependency.awsRuntime(runtimeConfig).withFeature("event-stream").toType().toSymbol(), - ) - } - section.registerAuthScheme(this) { - rustTemplate("#{SharedAuthScheme}::new(#{SigV4AuthScheme}::new())", *codegenScope) - } + override fun section(section: ServiceRuntimePluginSection): Writable = + writable { + when (section) { + is ServiceRuntimePluginSection.RegisterRuntimeComponents -> { + val serviceHasEventStream = codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model) + if (serviceHasEventStream) { + // enable the aws-runtime `sign-eventstream` feature + addDependency( + AwsCargoDependency.awsRuntime(runtimeConfig).withFeature("event-stream").toType().toSymbol(), + ) + } + section.registerAuthScheme(this) { + rustTemplate("#{SharedAuthScheme}::new(#{SigV4AuthScheme}::new())", *codegenScope) + } - if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) { - featureGateBlock("sigv4a") { - section.registerAuthScheme(this) { - rustTemplate("#{SharedAuthScheme}::new(#{SigV4aAuthScheme}::new())", *codegenScope) + if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) { + featureGateBlock("sigv4a") { + section.registerAuthScheme(this) { + rustTemplate("#{SharedAuthScheme}::new(#{SigV4aAuthScheme}::new())", *codegenScope) + } } } } - } - else -> {} + else -> {} + } } - } } -fun needsAmzSha256(service: ServiceShape) = when (service.id) { - ShapeId.from("com.amazonaws.s3#AmazonS3") -> true - ShapeId.from("com.amazonaws.s3control#AWSS3ControlServiceV20180820") -> true - else -> false -} +fun needsAmzSha256(service: ServiceShape) = + when (service.id) { + ShapeId.from("com.amazonaws.s3#AmazonS3") -> true + ShapeId.from("com.amazonaws.s3control#AWSS3ControlServiceV20180820") -> true + else -> false + } -fun disableDoubleEncode(service: ServiceShape) = when (service.id) { - ShapeId.from("com.amazonaws.s3#AmazonS3") -> true - else -> false -} +fun disableDoubleEncode(service: ServiceShape) = + when (service.id) { + ShapeId.from("com.amazonaws.s3#AmazonS3") -> true + else -> false + } -fun disableUriPathNormalization(service: ServiceShape) = when (service.id) { - ShapeId.from("com.amazonaws.s3#AmazonS3") -> true - else -> false -} +fun disableUriPathNormalization(service: ServiceShape) = + when (service.id) { + ShapeId.from("com.amazonaws.s3#AmazonS3") -> true + else -> false + } private class AuthOperationCustomization(private val codegenContext: ClientCodegenContext) : OperationCustomization() { private val runtimeConfig = codegenContext.runtimeConfig @@ -218,45 +231,47 @@ private class AuthOperationCustomization(private val codegenContext: ClientCodeg } private val serviceIndex = ServiceIndex.of(codegenContext.model) - override fun section(section: OperationSection): Writable = writable { - when (section) { - is OperationSection.AdditionalRuntimePluginConfig -> { - val authSchemes = - serviceIndex.getEffectiveAuthSchemes(codegenContext.serviceShape, section.operationShape) - if (authSchemes.containsKey(SigV4Trait.ID)) { - val unsignedPayload = section.operationShape.hasTrait() - val doubleUriEncode = unsignedPayload || !disableDoubleEncode(codegenContext.serviceShape) - val contentSha256Header = needsAmzSha256(codegenContext.serviceShape) || unsignedPayload - val normalizeUrlPath = !disableUriPathNormalization(codegenContext.serviceShape) - rustTemplate( - """ - let mut signing_options = #{SigningOptions}::default(); - signing_options.double_uri_encode = $doubleUriEncode; - signing_options.content_sha256_header = $contentSha256Header; - signing_options.normalize_uri_path = $normalizeUrlPath; - signing_options.payload_override = #{payload_override}; + override fun section(section: OperationSection): Writable = + writable { + when (section) { + is OperationSection.AdditionalRuntimePluginConfig -> { + val authSchemes = + serviceIndex.getEffectiveAuthSchemes(codegenContext.serviceShape, section.operationShape) + if (authSchemes.containsKey(SigV4Trait.ID)) { + val unsignedPayload = section.operationShape.hasTrait() + val doubleUriEncode = unsignedPayload || !disableDoubleEncode(codegenContext.serviceShape) + val contentSha256Header = needsAmzSha256(codegenContext.serviceShape) || unsignedPayload + val normalizeUrlPath = !disableUriPathNormalization(codegenContext.serviceShape) + rustTemplate( + """ + let mut signing_options = #{SigningOptions}::default(); + signing_options.double_uri_encode = $doubleUriEncode; + signing_options.content_sha256_header = $contentSha256Header; + signing_options.normalize_uri_path = $normalizeUrlPath; + signing_options.payload_override = #{payload_override}; - ${section.newLayerName}.store_put(#{SigV4OperationSigningConfig} { - signing_options, - ..#{Default}::default() - }); - """, - *codegenScope, - "payload_override" to writable { - if (unsignedPayload) { - rustTemplate("Some(#{SignableBody}::UnsignedPayload)", *codegenScope) - } else if (section.operationShape.isInputEventStream(codegenContext.model)) { - // TODO(EventStream): Is this actually correct for all Event Stream operations? - rustTemplate("Some(#{SignableBody}::Bytes(&[]))", *codegenScope) - } else { - rust("None") - } - }, - ) + ${section.newLayerName}.store_put(#{SigV4OperationSigningConfig} { + signing_options, + ..#{Default}::default() + }); + """, + *codegenScope, + "payload_override" to + writable { + if (unsignedPayload) { + rustTemplate("Some(#{SignableBody}::UnsignedPayload)", *codegenScope) + } else if (section.operationShape.isInputEventStream(codegenContext.model)) { + // TODO(EventStream): Is this actually correct for all Event Stream operations? + rustTemplate("Some(#{SignableBody}::Bytes(&[]))", *codegenScope) + } else { + rust("None") + } + }, + ) + } } - } - else -> {} + else -> {} + } } - } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/UserAgentDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/UserAgentDecorator.kt index 058acab9cce..b8b69f275cf 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/UserAgentDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/UserAgentDecorator.kt @@ -42,8 +42,7 @@ class UserAgentDecorator : ClientCodegenDecorator { override fun serviceRuntimePluginCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = - baseCustomizations + AddApiMetadataIntoConfigBag(codegenContext) + ): List = baseCustomizations + AddApiMetadataIntoConfigBag(codegenContext) override fun extraSections(codegenContext: ClientCodegenContext): List { return listOf( @@ -56,7 +55,10 @@ class UserAgentDecorator : ClientCodegenDecorator { /** * Adds a static `API_METADATA` variable to the crate `config` containing the serviceId & the version of the crate for this individual service */ - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { val runtimeConfig = codegenContext.runtimeConfig // We are generating an AWS SDK, the service needs to have the AWS service trait @@ -88,85 +90,91 @@ class UserAgentDecorator : ClientCodegenDecorator { private val runtimeConfig = codegenContext.runtimeConfig private val awsRuntime = AwsRuntimeType.awsRuntime(runtimeConfig) - override fun section(section: ServiceRuntimePluginSection): Writable = writable { - when (section) { - is ServiceRuntimePluginSection.RegisterRuntimeComponents -> { - section.registerInterceptor(this) { - rust("#T::new()", awsRuntime.resolve("user_agent::UserAgentInterceptor")) + override fun section(section: ServiceRuntimePluginSection): Writable = + writable { + when (section) { + is ServiceRuntimePluginSection.RegisterRuntimeComponents -> { + section.registerInterceptor(this) { + rust("#T::new()", awsRuntime.resolve("user_agent::UserAgentInterceptor")) + } } + else -> emptySection } - else -> emptySection } - } } private class AppNameCustomization(codegenContext: ClientCodegenContext) : ConfigCustomization() { private val runtimeConfig = codegenContext.runtimeConfig - private val codegenScope = arrayOf( - *preludeScope, - "AppName" to AwsRuntimeType.awsTypes(runtimeConfig).resolve("app_name::AppName"), - "AwsUserAgent" to AwsRuntimeType.awsHttp(runtimeConfig).resolve("user_agent::AwsUserAgent"), - ) + private val codegenScope = + arrayOf( + *preludeScope, + "AppName" to AwsRuntimeType.awsTypes(runtimeConfig).resolve("app_name::AppName"), + "AwsUserAgent" to AwsRuntimeType.awsHttp(runtimeConfig).resolve("user_agent::AwsUserAgent"), + ) override fun section(section: ServiceConfig): Writable = when (section) { - is ServiceConfig.BuilderImpl -> writable { - rustTemplate( - """ - /// Sets the name of the app that is using the client. - /// - /// This _optional_ name is used to identify the application in the user agent that - /// gets sent along with requests. - pub fn app_name(mut self, app_name: #{AppName}) -> Self { - self.set_app_name(Some(app_name)); - self - } - """, - *codegenScope, - ) - - rustTemplate( - """ - /// Sets the name of the app that is using the client. - /// - /// This _optional_ name is used to identify the application in the user agent that - /// gets sent along with requests. - pub fn set_app_name(&mut self, app_name: #{Option}<#{AppName}>) -> &mut Self { - self.config.store_or_unset(app_name); - self - } - """, - *codegenScope, - ) - } + is ServiceConfig.BuilderImpl -> + writable { + rustTemplate( + """ + /// Sets the name of the app that is using the client. + /// + /// This _optional_ name is used to identify the application in the user agent that + /// gets sent along with requests. + pub fn app_name(mut self, app_name: #{AppName}) -> Self { + self.set_app_name(Some(app_name)); + self + } + """, + *codegenScope, + ) + + rustTemplate( + """ + /// Sets the name of the app that is using the client. + /// + /// This _optional_ name is used to identify the application in the user agent that + /// gets sent along with requests. + pub fn set_app_name(&mut self, app_name: #{Option}<#{AppName}>) -> &mut Self { + self.config.store_or_unset(app_name); + self + } + """, + *codegenScope, + ) + } - is ServiceConfig.BuilderBuild -> writable { - rust("layer.store_put(#T.clone());", ClientRustModule.Meta.toType().resolve("API_METADATA")) - } + is ServiceConfig.BuilderBuild -> + writable { + rust("layer.store_put(#T.clone());", ClientRustModule.Meta.toType().resolve("API_METADATA")) + } - is ServiceConfig.ConfigImpl -> writable { - rustTemplate( - """ - /// Returns the name of the app that is using the client, if it was provided. - /// - /// This _optional_ name is used to identify the application in the user agent that - /// gets sent along with requests. - pub fn app_name(&self) -> #{Option}<&#{AppName}> { - self.config.load::<#{AppName}>() - } - """, - *codegenScope, - ) - } + is ServiceConfig.ConfigImpl -> + writable { + rustTemplate( + """ + /// Returns the name of the app that is using the client, if it was provided. + /// + /// This _optional_ name is used to identify the application in the user agent that + /// gets sent along with requests. + pub fn app_name(&self) -> #{Option}<&#{AppName}> { + self.config.load::<#{AppName}>() + } + """, + *codegenScope, + ) + } - is ServiceConfig.DefaultForTests -> writable { - rustTemplate( - """ - self.config.store_put(#{AwsUserAgent}::for_tests()); - """, - *codegenScope, - ) - } + is ServiceConfig.DefaultForTests -> + writable { + rustTemplate( + """ + self.config.store_put(#{AwsUserAgent}::for_tests()); + """, + *codegenScope, + ) + } else -> emptySection } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/DisabledAuthDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/DisabledAuthDecorator.kt index 4dfc75c30be..b9c319df128 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/DisabledAuthDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/DisabledAuthDecorator.kt @@ -28,10 +28,13 @@ class DisabledAuthDecorator : ClientCodegenDecorator { ), ) - private fun applies(service: ServiceShape) = - optionalAuth.containsKey(service.id) + private fun applies(service: ServiceShape) = optionalAuth.containsKey(service.id) - override fun transformModel(service: ServiceShape, model: Model, settings: ClientRustSettings): Model { + override fun transformModel( + service: ServiceShape, + model: Model, + settings: ClientRustSettings, + ): Model { if (!applies(service)) { return model } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/RemoveDefaults.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/RemoveDefaults.kt index 45bf8bc1987..e12c1ba8aa9 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/RemoveDefaults.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/RemoveDefaults.kt @@ -24,15 +24,19 @@ import java.util.logging.Logger object RemoveDefaults { private val logger: Logger = Logger.getLogger(javaClass.name) - fun processModel(model: Model, removeDefaultsFrom: Set): Model { + fun processModel( + model: Model, + removeDefaultsFrom: Set, + ): Model { val removedRootDefaults: MutableSet = HashSet() - val removedRootDefaultsModel = ModelTransformer.create().mapShapes(model) { shape -> - shape.letIf(shouldRemoveRootDefault(shape, removeDefaultsFrom)) { - logger.info("Removing default trait from root $shape") - removedRootDefaults.add(shape.id) - removeDefault(shape) + val removedRootDefaultsModel = + ModelTransformer.create().mapShapes(model) { shape -> + shape.letIf(shouldRemoveRootDefault(shape, removeDefaultsFrom)) { + logger.info("Removing default trait from root $shape") + removedRootDefaults.add(shape.id) + removeDefault(shape) + } } - } return ModelTransformer.create().mapShapes(removedRootDefaultsModel) { shape -> shape.letIf(shouldRemoveMemberDefault(shape, removedRootDefaults, removeDefaultsFrom)) { @@ -42,7 +46,10 @@ object RemoveDefaults { } } - private fun shouldRemoveRootDefault(shape: Shape, removeDefaultsFrom: Set): Boolean { + private fun shouldRemoveRootDefault( + shape: Shape, + removeDefaultsFrom: Set, + ): Boolean { return shape !is MemberShape && removeDefaultsFrom.contains(shape.id) && shape.hasTrait() } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/RemoveDefaultsDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/RemoveDefaultsDecorator.kt index 1e1e205f45e..88b1ada994c 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/RemoveDefaultsDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/RemoveDefaultsDecorator.kt @@ -25,56 +25,68 @@ class RemoveDefaultsDecorator : ClientCodegenDecorator { // Service shape id -> Shape id of each root shape to remove the default from. // TODO(https://github.com/smithy-lang/smithy-rs/issues/3220): Remove this customization after model updates. - private val removeDefaults: Map> = mapOf( - "com.amazonaws.amplifyuibuilder#AmplifyUIBuilder" to setOf( - "com.amazonaws.amplifyuibuilder#ListComponentsLimit", - "com.amazonaws.amplifyuibuilder#ListFormsLimit", - "com.amazonaws.amplifyuibuilder#ListThemesLimit", - ), - "com.amazonaws.drs#ElasticDisasterRecoveryService" to setOf( - "com.amazonaws.drs#Validity", - "com.amazonaws.drs#CostOptimizationConfiguration\$burstBalanceThreshold", - "com.amazonaws.drs#CostOptimizationConfiguration\$burstBalanceDeltaThreshold", - "com.amazonaws.drs.synthetic#ListStagingAccountsInput\$maxResults", - "com.amazonaws.drs#StrictlyPositiveInteger", - "com.amazonaws.drs#MaxResultsType", - "com.amazonaws.drs#MaxResultsReplicatingSourceServers", - "com.amazonaws.drs#LaunchActionOrder", - ), - "com.amazonaws.evidently#Evidently" to setOf( - "com.amazonaws.evidently#ResultsPeriod", - ), - "com.amazonaws.location#LocationService" to setOf( - "com.amazonaws.location.synthetic#ListPlaceIndexesInput\$MaxResults", - "com.amazonaws.location.synthetic#SearchPlaceIndexForSuggestionsInput\$MaxResults", - "com.amazonaws.location#PlaceIndexSearchResultLimit", - ), - "com.amazonaws.paymentcryptographydata#PaymentCryptographyDataPlane" to setOf( - "com.amazonaws.paymentcryptographydata#IntegerRangeBetween4And12", - ), - "com.amazonaws.emrserverless#AwsToledoWebService" to setOf( - "com.amazonaws.emrserverless#WorkerCounts", - ), - "com.amazonaws.s3control#AWSS3ControlServiceV20180820" to setOf( - "com.amazonaws.s3control#PublicAccessBlockConfiguration\$BlockPublicAcls", - "com.amazonaws.s3control#PublicAccessBlockConfiguration\$IgnorePublicAcls", - "com.amazonaws.s3control#PublicAccessBlockConfiguration\$BlockPublicPolicy", - "com.amazonaws.s3control#PublicAccessBlockConfiguration\$RestrictPublicBuckets", - ), - "com.amazonaws.iot#AWSIotService" to setOf( - "com.amazonaws.iot#ThingConnectivity\$connected", - "com.amazonaws.iot.synthetic#UpdateProvisioningTemplateInput\$enabled", - "com.amazonaws.iot.synthetic#CreateProvisioningTemplateInput\$enabled", - "com.amazonaws.iot.synthetic#DescribeProvisioningTemplateOutput\$enabled", - "com.amazonaws.iot.synthetic#DescribeProvisioningTemplateOutput\$enabled", - "com.amazonaws.iot#ProvisioningTemplateSummary\$enabled", - ), - ).map { (k, v) -> k.shapeId() to v.map { it.shapeId() }.toSet() }.toMap() + private val removeDefaults: Map> = + mapOf( + "com.amazonaws.amplifyuibuilder#AmplifyUIBuilder" to + setOf( + "com.amazonaws.amplifyuibuilder#ListComponentsLimit", + "com.amazonaws.amplifyuibuilder#ListFormsLimit", + "com.amazonaws.amplifyuibuilder#ListThemesLimit", + ), + "com.amazonaws.drs#ElasticDisasterRecoveryService" to + setOf( + "com.amazonaws.drs#Validity", + "com.amazonaws.drs#CostOptimizationConfiguration\$burstBalanceThreshold", + "com.amazonaws.drs#CostOptimizationConfiguration\$burstBalanceDeltaThreshold", + "com.amazonaws.drs.synthetic#ListStagingAccountsInput\$maxResults", + "com.amazonaws.drs#StrictlyPositiveInteger", + "com.amazonaws.drs#MaxResultsType", + "com.amazonaws.drs#MaxResultsReplicatingSourceServers", + "com.amazonaws.drs#LaunchActionOrder", + ), + "com.amazonaws.evidently#Evidently" to + setOf( + "com.amazonaws.evidently#ResultsPeriod", + ), + "com.amazonaws.location#LocationService" to + setOf( + "com.amazonaws.location.synthetic#ListPlaceIndexesInput\$MaxResults", + "com.amazonaws.location.synthetic#SearchPlaceIndexForSuggestionsInput\$MaxResults", + "com.amazonaws.location#PlaceIndexSearchResultLimit", + ), + "com.amazonaws.paymentcryptographydata#PaymentCryptographyDataPlane" to + setOf( + "com.amazonaws.paymentcryptographydata#IntegerRangeBetween4And12", + ), + "com.amazonaws.emrserverless#AwsToledoWebService" to + setOf( + "com.amazonaws.emrserverless#WorkerCounts", + ), + "com.amazonaws.s3control#AWSS3ControlServiceV20180820" to + setOf( + "com.amazonaws.s3control#PublicAccessBlockConfiguration\$BlockPublicAcls", + "com.amazonaws.s3control#PublicAccessBlockConfiguration\$IgnorePublicAcls", + "com.amazonaws.s3control#PublicAccessBlockConfiguration\$BlockPublicPolicy", + "com.amazonaws.s3control#PublicAccessBlockConfiguration\$RestrictPublicBuckets", + ), + "com.amazonaws.iot#AWSIotService" to + setOf( + "com.amazonaws.iot#ThingConnectivity\$connected", + "com.amazonaws.iot.synthetic#UpdateProvisioningTemplateInput\$enabled", + "com.amazonaws.iot.synthetic#CreateProvisioningTemplateInput\$enabled", + "com.amazonaws.iot.synthetic#DescribeProvisioningTemplateOutput\$enabled", + "com.amazonaws.iot.synthetic#DescribeProvisioningTemplateOutput\$enabled", + "com.amazonaws.iot#ProvisioningTemplateSummary\$enabled", + ), + ).map { (k, v) -> k.shapeId() to v.map { it.shapeId() }.toSet() }.toMap() - private fun applies(service: ServiceShape) = - removeDefaults.containsKey(service.id) + private fun applies(service: ServiceShape) = removeDefaults.containsKey(service.id) - override fun transformModel(service: ServiceShape, model: Model, settings: ClientRustSettings): Model { + override fun transformModel( + service: ServiceShape, + model: Model, + settings: ClientRustSettings, + ): Model { if (!applies(service)) { return model } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ServiceSpecificDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ServiceSpecificDecorator.kt index 4850d299efb..7a493f702d1 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ServiceSpecificDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ServiceSpecificDecorator.kt @@ -50,7 +50,10 @@ class ServiceSpecificDecorator( /** Decorator order */ override val order: Byte = 0, ) : ClientCodegenDecorator { - private fun T.maybeApply(serviceId: ToShapeId, delegatedValue: () -> T): T = + private fun T.maybeApply( + serviceId: ToShapeId, + delegatedValue: () -> T, + ): T = if (appliesToServiceId == serviceId.toShapeId()) { delegatedValue() } else { @@ -64,23 +67,26 @@ class ServiceSpecificDecorator( codegenContext: ClientCodegenContext, operationShape: OperationShape, baseAuthSchemeOptions: List, - ): List = baseAuthSchemeOptions.maybeApply(codegenContext.serviceShape) { - delegateTo.authOptions(codegenContext, operationShape, baseAuthSchemeOptions) - } + ): List = + baseAuthSchemeOptions.maybeApply(codegenContext.serviceShape) { + delegateTo.authOptions(codegenContext, operationShape, baseAuthSchemeOptions) + } override fun builderCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { - delegateTo.builderCustomizations(codegenContext, baseCustomizations) - } + ): List = + baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.builderCustomizations(codegenContext, baseCustomizations) + } override fun configCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { - delegateTo.configCustomizations(codegenContext, baseCustomizations) - } + ): List = + baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.configCustomizations(codegenContext, baseCustomizations) + } override fun crateManifestCustomizations(codegenContext: ClientCodegenContext): ManifestCustomizations = emptyMap().maybeApply(codegenContext.serviceShape) { @@ -95,18 +101,23 @@ class ServiceSpecificDecorator( override fun errorCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { - delegateTo.errorCustomizations(codegenContext, baseCustomizations) - } + ): List = + baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.errorCustomizations(codegenContext, baseCustomizations) + } override fun errorImplCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { - delegateTo.errorImplCustomizations(codegenContext, baseCustomizations) - } + ): List = + baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.errorImplCustomizations(codegenContext, baseCustomizations) + } - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { maybeApply(codegenContext.serviceShape) { delegateTo.extras(codegenContext, rustCrate) } @@ -115,19 +126,24 @@ class ServiceSpecificDecorator( override fun libRsCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { - delegateTo.libRsCustomizations(codegenContext, baseCustomizations) - } + ): List = + baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.libRsCustomizations(codegenContext, baseCustomizations) + } override fun operationCustomizations( codegenContext: ClientCodegenContext, operation: OperationShape, baseCustomizations: List, - ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { - delegateTo.operationCustomizations(codegenContext, operation, baseCustomizations) - } + ): List = + baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.operationCustomizations(codegenContext, operation, baseCustomizations) + } - override fun protocols(serviceId: ShapeId, currentProtocols: ClientProtocolMap): ClientProtocolMap = + override fun protocols( + serviceId: ShapeId, + currentProtocols: ClientProtocolMap, + ): ClientProtocolMap = currentProtocols.maybeApply(serviceId) { delegateTo.protocols(serviceId, currentProtocols) } @@ -135,11 +151,16 @@ class ServiceSpecificDecorator( override fun structureCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { - delegateTo.structureCustomizations(codegenContext, baseCustomizations) - } + ): List = + baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.structureCustomizations(codegenContext, baseCustomizations) + } - override fun transformModel(service: ServiceShape, model: Model, settings: ClientRustSettings): Model = + override fun transformModel( + service: ServiceShape, + model: Model, + settings: ClientRustSettings, + ): Model = model.maybeApply(service) { delegateTo.transformModel(service, model, settings) } @@ -147,16 +168,18 @@ class ServiceSpecificDecorator( override fun serviceRuntimePluginCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { - delegateTo.serviceRuntimePluginCustomizations(codegenContext, baseCustomizations) - } + ): List = + baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.serviceRuntimePluginCustomizations(codegenContext, baseCustomizations) + } override fun protocolTestGenerator( codegenContext: ClientCodegenContext, baseGenerator: ProtocolTestGenerator, - ): ProtocolTestGenerator = baseGenerator.maybeApply(codegenContext.serviceShape) { - delegateTo.protocolTestGenerator(codegenContext, baseGenerator) - } + ): ProtocolTestGenerator = + baseGenerator.maybeApply(codegenContext.serviceShape) { + delegateTo.protocolTestGenerator(codegenContext, baseGenerator) + } override fun extraSections(codegenContext: ClientCodegenContext): List = listOf().maybeApply(codegenContext.serviceShape) { diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/apigateway/ApiGatewayDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/apigateway/ApiGatewayDecorator.kt index 4b9cc3616aa..7d13d1bf759 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/apigateway/ApiGatewayDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/apigateway/ApiGatewayDecorator.kt @@ -29,21 +29,24 @@ class ApiGatewayDecorator : ClientCodegenDecorator { private class ApiGatewayAcceptHeaderInterceptorCustomization(private val codegenContext: ClientCodegenContext) : ServiceRuntimePluginCustomization() { - override fun section(section: ServiceRuntimePluginSection): Writable = writable { - if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) { - section.registerInterceptor(this) { - rustTemplate( - "#{Interceptor}::default()", - "Interceptor" to RuntimeType.forInlineDependency( - InlineAwsDependency.forRustFile( - "apigateway_interceptors", - additionalDependency = arrayOf( - CargoDependency.smithyRuntimeApiClient(codegenContext.runtimeConfig), - ), - ), - ).resolve("AcceptHeaderInterceptor"), - ) + override fun section(section: ServiceRuntimePluginSection): Writable = + writable { + if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) { + section.registerInterceptor(this) { + rustTemplate( + "#{Interceptor}::default()", + "Interceptor" to + RuntimeType.forInlineDependency( + InlineAwsDependency.forRustFile( + "apigateway_interceptors", + additionalDependency = + arrayOf( + CargoDependency.smithyRuntimeApiClient(codegenContext.runtimeConfig), + ), + ), + ).resolve("AcceptHeaderInterceptor"), + ) + } } } - } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ec2/Ec2Decorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ec2/Ec2Decorator.kt index 693b2b572d9..9d7e9e8b9bd 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ec2/Ec2Decorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ec2/Ec2Decorator.kt @@ -16,6 +16,9 @@ class Ec2Decorator : ClientCodegenDecorator { // EC2 incorrectly models primitive shapes as unboxed when they actually // need to be boxed for the API to work properly - override fun transformModel(service: ServiceShape, model: Model, settings: ClientRustSettings): Model = - EC2MakePrimitivesOptional.processModel(model) + override fun transformModel( + service: ServiceShape, + model: Model, + settings: ClientRustSettings, + ): Model = EC2MakePrimitivesOptional.processModel(model) } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/GlacierDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/GlacierDecorator.kt index 364faa8c6a5..9b4a79172ef 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/GlacierDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/GlacierDecorator.kt @@ -64,37 +64,39 @@ class GlacierDecorator : ClientCodegenDecorator { /** Implements the `GlacierAccountId` trait for inputs that have an `account_id` field */ private class GlacierAccountIdCustomization(private val codegenContext: ClientCodegenContext) : StructureCustomization() { - override fun section(section: StructureSection): Writable = writable { - if (section is StructureSection.AdditionalTraitImpls && section.shape.inputWithAccountId()) { - val inlineModule = inlineModule(codegenContext.runtimeConfig) - rustTemplate( - """ - impl #{GlacierAccountId} for ${section.structName} { - fn account_id_mut(&mut self) -> &mut Option { - &mut self.account_id + override fun section(section: StructureSection): Writable = + writable { + if (section is StructureSection.AdditionalTraitImpls && section.shape.inputWithAccountId()) { + val inlineModule = inlineModule(codegenContext.runtimeConfig) + rustTemplate( + """ + impl #{GlacierAccountId} for ${section.structName} { + fn account_id_mut(&mut self) -> &mut Option { + &mut self.account_id + } } - } - """, - "GlacierAccountId" to inlineModule.resolve("GlacierAccountId"), - ) + """, + "GlacierAccountId" to inlineModule.resolve("GlacierAccountId"), + ) + } } - } } /** Adds the `x-amz-glacier-version` header to all requests */ private class GlacierApiVersionCustomization(private val codegenContext: ClientCodegenContext) : ServiceRuntimePluginCustomization() { - override fun section(section: ServiceRuntimePluginSection): Writable = writable { - if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) { - val apiVersion = codegenContext.serviceShape.version - section.registerInterceptor(this) { - rustTemplate( - "#{Interceptor}::new(${apiVersion.dq()})", - "Interceptor" to inlineModule(codegenContext.runtimeConfig).resolve("GlacierApiVersionInterceptor"), - ) + override fun section(section: ServiceRuntimePluginSection): Writable = + writable { + if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) { + val apiVersion = codegenContext.serviceShape.version + section.registerInterceptor(this) { + rustTemplate( + "#{Interceptor}::new(${apiVersion.dq()})", + "Interceptor" to inlineModule(codegenContext.runtimeConfig).resolve("GlacierApiVersionInterceptor"), + ) + } } } - } } /** @@ -105,29 +107,30 @@ private class GlacierApiVersionCustomization(private val codegenContext: ClientC */ private class GlacierOperationInterceptorsCustomization(private val codegenContext: ClientCodegenContext) : OperationCustomization() { - override fun section(section: OperationSection): Writable = writable { - if (section is OperationSection.AdditionalInterceptors) { - val inputShape = codegenContext.model.expectShape(section.operationShape.inputShape) as StructureShape - val inlineModule = inlineModule(codegenContext.runtimeConfig) - if (inputShape.inputWithAccountId()) { - section.registerInterceptor(codegenContext.runtimeConfig, this) { - rustTemplate( - "#{Interceptor}::<#{Input}>::new()", - "Interceptor" to inlineModule.resolve("GlacierAccountIdAutofillInterceptor"), - "Input" to codegenContext.symbolProvider.toSymbol(inputShape), - ) + override fun section(section: OperationSection): Writable = + writable { + if (section is OperationSection.AdditionalInterceptors) { + val inputShape = codegenContext.model.expectShape(section.operationShape.inputShape) as StructureShape + val inlineModule = inlineModule(codegenContext.runtimeConfig) + if (inputShape.inputWithAccountId()) { + section.registerInterceptor(codegenContext.runtimeConfig, this) { + rustTemplate( + "#{Interceptor}::<#{Input}>::new()", + "Interceptor" to inlineModule.resolve("GlacierAccountIdAutofillInterceptor"), + "Input" to codegenContext.symbolProvider.toSymbol(inputShape), + ) + } } - } - if (section.operationShape.requiresTreeHashHeader()) { - section.registerInterceptor(codegenContext.runtimeConfig, this) { - rustTemplate( - "#{Interceptor}::default()", - "Interceptor" to inlineModule.resolve("GlacierTreeHashHeaderInterceptor"), - ) + if (section.operationShape.requiresTreeHashHeader()) { + section.registerInterceptor(codegenContext.runtimeConfig, this) { + rustTemplate( + "#{Interceptor}::default()", + "Interceptor" to inlineModule.resolve("GlacierTreeHashHeaderInterceptor"), + ) + } } } } - } } /** True when the operation requires tree hash headers */ @@ -138,19 +141,21 @@ private fun OperationShape.requiresTreeHashHeader(): Boolean = private fun StructureShape.inputWithAccountId(): Boolean = hasTrait() && members().any { it.memberName.lowercase() == "accountid" } -private fun inlineModule(runtimeConfig: RuntimeConfig) = RuntimeType.forInlineDependency( - InlineAwsDependency.forRustFile( - "glacier_interceptors", - additionalDependency = glacierInterceptorDependencies(runtimeConfig).toTypedArray(), - ), -) +private fun inlineModule(runtimeConfig: RuntimeConfig) = + RuntimeType.forInlineDependency( + InlineAwsDependency.forRustFile( + "glacier_interceptors", + additionalDependency = glacierInterceptorDependencies(runtimeConfig).toTypedArray(), + ), + ) -private fun glacierInterceptorDependencies(runtimeConfig: RuntimeConfig) = listOf( - AwsCargoDependency.awsRuntime(runtimeConfig), - AwsCargoDependency.awsSigv4(runtimeConfig), - CargoDependency.Bytes, - CargoDependency.Hex, - CargoDependency.Ring, - CargoDependency.smithyHttp(runtimeConfig), - CargoDependency.smithyRuntimeApiClient(runtimeConfig), -) +private fun glacierInterceptorDependencies(runtimeConfig: RuntimeConfig) = + listOf( + AwsCargoDependency.awsRuntime(runtimeConfig), + AwsCargoDependency.awsSigv4(runtimeConfig), + CargoDependency.Bytes, + CargoDependency.Hex, + CargoDependency.Ring, + CargoDependency.smithyHttp(runtimeConfig), + CargoDependency.smithyRuntimeApiClient(runtimeConfig), + ) diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/route53/Route53Decorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/route53/Route53Decorator.kt index d4aa5bf78e2..19dd0e5e57e 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/route53/Route53Decorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/route53/Route53Decorator.kt @@ -32,9 +32,14 @@ class Route53Decorator : ClientCodegenDecorator { override val name: String = "Route53" override val order: Byte = 0 private val logger: Logger = Logger.getLogger(javaClass.name) - private val resourceShapes = setOf(ShapeId.from("com.amazonaws.route53#ResourceId"), ShapeId.from("com.amazonaws.route53#ChangeId")) + private val resourceShapes = + setOf(ShapeId.from("com.amazonaws.route53#ResourceId"), ShapeId.from("com.amazonaws.route53#ChangeId")) - override fun transformModel(service: ServiceShape, model: Model, settings: ClientRustSettings): Model = + override fun transformModel( + service: ServiceShape, + model: Model, + settings: ClientRustSettings, + ): Model = ModelTransformer.create().mapShapes(model) { shape -> shape.letIf(isResourceId(shape)) { logger.info("Adding TrimResourceId trait to $shape") @@ -50,11 +55,12 @@ class Route53Decorator : ClientCodegenDecorator { val inputShape = operation.inputShape(codegenContext.model) val hostedZoneMember = inputShape.members().find { it.hasTrait() } return if (hostedZoneMember != null) { - baseCustomizations + TrimResourceIdCustomization( - codegenContext, - inputShape, - codegenContext.symbolProvider.toMemberName(hostedZoneMember), - ) + baseCustomizations + + TrimResourceIdCustomization( + codegenContext, + inputShape, + codegenContext.symbolProvider.toMemberName(hostedZoneMember), + ) } else { baseCustomizations } @@ -70,29 +76,29 @@ class TrimResourceIdCustomization( private val inputShape: StructureShape, private val fieldName: String, ) : OperationCustomization() { - - override fun section(section: OperationSection): Writable = writable { - when (section) { - is OperationSection.AdditionalInterceptors -> { - section.registerInterceptor(codegenContext.runtimeConfig, this) { - val smithyRuntimeApi = RuntimeType.smithyRuntimeApiClient(codegenContext.runtimeConfig) - val interceptor = - RuntimeType.forInlineDependency( - InlineAwsDependency.forRustFile("route53_resource_id_preprocessor"), - ).resolve("Route53ResourceIdInterceptor") - rustTemplate( - """ - #{Route53ResourceIdInterceptor}::new(|input: &mut #{Input}| { - &mut input.$fieldName - }) - """, - "Input" to codegenContext.symbolProvider.toSymbol(inputShape), - "Route53ResourceIdInterceptor" to interceptor, - "SharedInterceptor" to smithyRuntimeApi.resolve("client::interceptors::SharedInterceptor"), - ) + override fun section(section: OperationSection): Writable = + writable { + when (section) { + is OperationSection.AdditionalInterceptors -> { + section.registerInterceptor(codegenContext.runtimeConfig, this) { + val smithyRuntimeApi = RuntimeType.smithyRuntimeApiClient(codegenContext.runtimeConfig) + val interceptor = + RuntimeType.forInlineDependency( + InlineAwsDependency.forRustFile("route53_resource_id_preprocessor"), + ).resolve("Route53ResourceIdInterceptor") + rustTemplate( + """ + #{Route53ResourceIdInterceptor}::new(|input: &mut #{Input}| { + &mut input.$fieldName + }) + """, + "Input" to codegenContext.symbolProvider.toSymbol(inputShape), + "Route53ResourceIdInterceptor" to interceptor, + "SharedInterceptor" to smithyRuntimeApi.resolve("client::interceptors::SharedInterceptor"), + ) + } } + else -> {} } - else -> {} } - } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt index 3a003398111..5c91ead09ae 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt @@ -50,21 +50,29 @@ class S3Decorator : ClientCodegenDecorator { override val name: String = "S3" override val order: Byte = 0 private val logger: Logger = Logger.getLogger(javaClass.name) - private val invalidXmlRootAllowList = setOf( - // API returns GetObjectAttributes_Response_ instead of Output - ShapeId.from("com.amazonaws.s3#GetObjectAttributesOutput"), - ) + private val invalidXmlRootAllowList = + setOf( + // API returns GetObjectAttributes_Response_ instead of Output + ShapeId.from("com.amazonaws.s3#GetObjectAttributesOutput"), + ) override fun protocols( serviceId: ShapeId, currentProtocols: ProtocolMap, - ): ProtocolMap = currentProtocols + mapOf( - RestXmlTrait.ID to ClientRestXmlFactory { protocolConfig -> - S3ProtocolOverride(protocolConfig) - }, - ) + ): ProtocolMap = + currentProtocols + + mapOf( + RestXmlTrait.ID to + ClientRestXmlFactory { protocolConfig -> + S3ProtocolOverride(protocolConfig) + }, + ) - override fun transformModel(service: ServiceShape, model: Model, settings: ClientRustSettings): Model = + override fun transformModel( + service: ServiceShape, + model: Model, + settings: ClientRustSettings, + ): Model = ModelTransformer.create().mapShapes(model) { shape -> shape.letIf(isInInvalidXmlRootAllowList(shape)) { logger.info("Adding AllowInvalidXmlRoot trait to $it") @@ -93,7 +101,11 @@ class S3Decorator : ClientCodegenDecorator { override fun endpointCustomizations(codegenContext: ClientCodegenContext): List { return listOf( object : EndpointCustomization { - override fun setBuiltInOnServiceConfig(name: String, value: Node, configBuilderRef: String): Writable? { + override fun setBuiltInOnServiceConfig( + name: String, + value: Node, + configBuilderRef: String, + ): Writable? { if (!name.startsWith("AWS::S3")) { return null } @@ -114,27 +126,28 @@ class S3Decorator : ClientCodegenDecorator { operation: OperationShape, baseCustomizations: List, ): List { - return baseCustomizations + object : OperationCustomization() { - override fun section(section: OperationSection): Writable { - return writable { - when (section) { - is OperationSection.BeforeParseResponse -> { - section.body?.also { body -> - rustTemplate( - """ - if matches!(#{errors}::body_is_error($body), Ok(true)) { - ${section.forceError} = true; - } - """, - "errors" to RuntimeType.unwrappedXmlErrors(codegenContext.runtimeConfig), - ) + return baseCustomizations + + object : OperationCustomization() { + override fun section(section: OperationSection): Writable { + return writable { + when (section) { + is OperationSection.BeforeParseResponse -> { + section.body?.also { body -> + rustTemplate( + """ + if matches!(#{errors}::body_is_error($body), Ok(true)) { + ${section.forceError} = true; + } + """, + "errors" to RuntimeType.unwrappedXmlErrors(codegenContext.runtimeConfig), + ) + } } + else -> {} } - else -> {} } } } - } } private fun isInInvalidXmlRootAllowList(shape: Shape): Boolean { @@ -154,41 +167,45 @@ class FilterEndpointTests( } } - fun transform(model: Model): Model = ModelTransformer.create().mapTraits(model) { _, trait -> - when (trait) { - is EndpointTestsTrait -> EndpointTestsTrait.builder().testCases(updateEndpointTests(trait.testCases)) - .version(trait.version).build() + fun transform(model: Model): Model = + ModelTransformer.create().mapTraits(model) { _, trait -> + when (trait) { + is EndpointTestsTrait -> + EndpointTestsTrait.builder().testCases(updateEndpointTests(trait.testCases)) + .version(trait.version).build() - else -> trait + else -> trait + } } - } } // TODO(P96049742): This model transform may need to change depending on if and how the S3 model is updated. private class AddOptionalAuth { - fun transform(model: Model): Model = ModelTransformer.create().mapShapes(model) { shape -> - // Add @optionalAuth to all S3 operations - if (shape is OperationShape && !shape.hasTrait()) { - shape.toBuilder() - .addTrait(OptionalAuthTrait()) - .build() - } else { - shape + fun transform(model: Model): Model = + ModelTransformer.create().mapShapes(model) { shape -> + // Add @optionalAuth to all S3 operations + if (shape is OperationShape && !shape.hasTrait()) { + shape.toBuilder() + .addTrait(OptionalAuthTrait()) + .build() + } else { + shape + } } - } } class S3ProtocolOverride(codegenContext: CodegenContext) : RestXml(codegenContext) { private val runtimeConfig = codegenContext.runtimeConfig - private val errorScope = arrayOf( - *RuntimeType.preludeScope, - "Bytes" to RuntimeType.Bytes, - "ErrorMetadata" to RuntimeType.errorMetadata(runtimeConfig), - "ErrorBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), - "Headers" to RuntimeType.headers(runtimeConfig), - "XmlDecodeError" to RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"), - "base_errors" to restXmlErrors, - ) + private val errorScope = + arrayOf( + *RuntimeType.preludeScope, + "Bytes" to RuntimeType.Bytes, + "ErrorMetadata" to RuntimeType.errorMetadata(runtimeConfig), + "ErrorBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), + "Headers" to RuntimeType.headers(runtimeConfig), + "XmlDecodeError" to RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"), + "base_errors" to restXmlErrors, + ) override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType { return ProtocolFunctions.crossOperationFn("parse_http_error_metadata") { fnName -> diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/StripBucketFromPath.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/StripBucketFromPath.kt index f2559b05c72..8b6de837b91 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/StripBucketFromPath.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/StripBucketFromPath.kt @@ -14,18 +14,20 @@ import software.amazon.smithy.rust.codegen.core.util.letIf class StripBucketFromHttpPath { private val transformer = ModelTransformer.create() + fun transform(model: Model): Model { // Remove `/{Bucket}` from the path (http trait) // The endpoints 2.0 rules handle either placing the bucket into the virtual host or adding it to the path return transformer.mapTraits(model) { shape, trait -> when (trait) { is HttpTrait -> { - val appliedToOperation = shape - .asOperationShape() - .map { operation -> - model.expectShape(operation.inputShape, StructureShape::class.java) - .getMember("Bucket").isPresent - }.orElse(false) + val appliedToOperation = + shape + .asOperationShape() + .map { operation -> + model.expectShape(operation.inputShape, StructureShape::class.java) + .getMember("Bucket").isPresent + }.orElse(false) trait.letIf(appliedToOperation) { it.toBuilder().uri(UriPattern.parse(transformUri(trait.uri.toString()))).build() } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3control/S3ControlDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3control/S3ControlDecorator.kt index ce96a33b827..6246eb1e524 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3control/S3ControlDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3control/S3ControlDecorator.kt @@ -24,24 +24,32 @@ class S3ControlDecorator : ClientCodegenDecorator { override val name: String = "S3Control" override val order: Byte = 0 - override fun transformModel(service: ServiceShape, model: Model, settings: ClientRustSettings): Model = - stripEndpointTrait("AccountId")(model) + override fun transformModel( + service: ServiceShape, + model: Model, + settings: ClientRustSettings, + ): Model = stripEndpointTrait("AccountId")(model) override fun endpointCustomizations(codegenContext: ClientCodegenContext): List { - return listOf(object : EndpointCustomization { - override fun setBuiltInOnServiceConfig(name: String, value: Node, configBuilderRef: String): Writable? { - if (!name.startsWith("AWS::S3Control")) { - return null + return listOf( + object : EndpointCustomization { + override fun setBuiltInOnServiceConfig( + name: String, + value: Node, + configBuilderRef: String, + ): Writable? { + if (!name.startsWith("AWS::S3Control")) { + return null + } + val builtIn = codegenContext.getBuiltIn(name) ?: return null + return writable { + rustTemplate( + "let $configBuilderRef = $configBuilderRef.${builtIn.name.rustName()}(#{value});", + "value" to value.toWritable(), + ) + } } - val builtIn = codegenContext.getBuiltIn(name) ?: return null - return writable { - rustTemplate( - "let $configBuilderRef = $configBuilderRef.${builtIn.name.rustName()}(#{value});", - "value" to value.toWritable(), - ) - } - } - }, + }, ) } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/sso/SSODecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/sso/SSODecorator.kt index 8ead07d3bce..db13a1bd86b 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/sso/SSODecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/sso/SSODecorator.kt @@ -24,7 +24,11 @@ class SSODecorator : ClientCodegenDecorator { private fun isAwsCredentials(shape: Shape): Boolean = shape.id == ShapeId.from("com.amazonaws.sso#RoleCredentials") - override fun transformModel(service: ServiceShape, model: Model, settings: ClientRustSettings): Model = + override fun transformModel( + service: ServiceShape, + model: Model, + settings: ClientRustSettings, + ): Model = ModelTransformer.create().mapShapes(model) { shape -> shape.letIf(isAwsCredentials(shape)) { (shape as StructureShape).toBuilder().addTrait(SensitiveTrait()).build() diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/sts/STSDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/sts/STSDecorator.kt index ed23c99d8f9..c4c0f06ff88 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/sts/STSDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/sts/STSDecorator.kt @@ -30,7 +30,11 @@ class STSDecorator : ClientCodegenDecorator { private fun isAwsCredentials(shape: Shape): Boolean = shape.id == ShapeId.from("com.amazonaws.sts#Credentials") - override fun transformModel(service: ServiceShape, model: Model, settings: ClientRustSettings): Model = + override fun transformModel( + service: ServiceShape, + model: Model, + settings: ClientRustSettings, + ): Model = ModelTransformer.create().mapShapes(model) { shape -> shape.letIf(isIdpCommunicationError(shape)) { logger.info("Adding @retryable trait to $shape and setting its error type to 'server'") diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/timestream/TimestreamDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/timestream/TimestreamDecorator.kt index eca18011871..3e4705a743f 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/timestream/TimestreamDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/timestream/TimestreamDecorator.kt @@ -31,26 +31,31 @@ class TimestreamDecorator : ClientCodegenDecorator { override val name: String = "Timestream" override val order: Byte = -1 - override fun extraSections(codegenContext: ClientCodegenContext): List = listOf( - adhocCustomization { - addDependency(AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toDevDependency()) - rustTemplate( - """ - let config = aws_config::load_from_env().await; - // You MUST call `with_endpoint_discovery_enabled` to produce a working client for this service. - let ${it.clientName} = ${it.crateName}::Client::new(&config).with_endpoint_discovery_enabled().await; - """.replaceIndent(it.indent), - ) - }, - ) - - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { - val endpointDiscovery = InlineAwsDependency.forRustFile( - "endpoint_discovery", - Visibility.PUBLIC, - CargoDependency.Tokio.copy(scope = DependencyScope.Compile, features = setOf("sync")), - CargoDependency.smithyAsync(codegenContext.runtimeConfig).toDevDependency().withFeature("test-util"), + override fun extraSections(codegenContext: ClientCodegenContext): List = + listOf( + adhocCustomization { + addDependency(AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toDevDependency()) + rustTemplate( + """ + let config = aws_config::load_from_env().await; + // You MUST call `with_endpoint_discovery_enabled` to produce a working client for this service. + let ${it.clientName} = ${it.crateName}::Client::new(&config).with_endpoint_discovery_enabled().await; + """.replaceIndent(it.indent), + ) + }, ) + + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { + val endpointDiscovery = + InlineAwsDependency.forRustFile( + "endpoint_discovery", + Visibility.PUBLIC, + CargoDependency.Tokio.copy(scope = DependencyScope.Compile, features = setOf("sync")), + CargoDependency.smithyAsync(codegenContext.runtimeConfig).toDevDependency().withFeature("test-util"), + ) rustCrate.withModule(ClientRustModule.client) { // helper function to resolve an endpoint given a base client rustTemplate( diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/AwsEndpointsStdLib.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/AwsEndpointsStdLib.kt index 5d3fb010f9e..184e757b1d4 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/AwsEndpointsStdLib.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/AwsEndpointsStdLib.kt @@ -29,14 +29,16 @@ class AwsEndpointsStdLib() : ClientCodegenDecorator { private fun partitionMetadata(sdkSettings: SdkSettings): ObjectNode { if (partitionsCache == null) { - val partitionsJson = when (val path = sdkSettings.partitionsConfigPath) { - null -> ( - javaClass.getResource("/default-partitions.json") - ?: throw IllegalStateException("Failed to find default-partitions.json in the JAR") - ).readText() + val partitionsJson = + when (val path = sdkSettings.partitionsConfigPath) { + null -> + ( + javaClass.getResource("/default-partitions.json") + ?: throw IllegalStateException("Failed to find default-partitions.json in the JAR") + ).readText() - else -> path.readText() - } + else -> path.readText() + } partitionsCache = Node.parse(partitionsJson).expectObjectNode() } return partitionsCache!! diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/OperationInputTestGenerator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/OperationInputTestGenerator.kt index b11236dfd14..5bcc870ae5b 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/OperationInputTestGenerator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/OperationInputTestGenerator.kt @@ -41,16 +41,20 @@ class OperationInputTestDecorator : ClientCodegenDecorator { override val name: String = "OperationInputTest" override val order: Byte = 0 - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { val endpointTests = EndpointTypesGenerator.fromContext(codegenContext).tests.orNullIfEmpty() ?: return rustCrate.integrationTest("endpoint_tests") { Attribute(Attribute.cfg(Attribute.feature("test-util"))).render(this, AttributeKind.Inner) - val tests = endpointTests.flatMap { test -> - val generator = OperationInputTestGenerator(codegenContext, test) - test.operationInputs.filterNot { usesDeprecatedBuiltIns(it) }.map { operationInput -> - generator.generateInput(operationInput) + val tests = + endpointTests.flatMap { test -> + val generator = OperationInputTestGenerator(codegenContext, test) + test.operationInputs.filterNot { usesDeprecatedBuiltIns(it) }.map { operationInput -> + generator.generateInput(operationInput) + } } - } tests.join("\n")(this) } } @@ -122,89 +126,94 @@ class OperationInputTestGenerator(_ctx: ClientCodegenContext, private val test: private val model = ctx.model private val instantiator = ClientInstantiator(ctx) - fun generateInput(testOperationInput: EndpointTestOperationInput) = writable { - val operationName = testOperationInput.operationName.toSnakeCase() - tokioTest(safeName("operation_input_test_$operationName")) { - rustTemplate( - """ - /* builtIns: ${escape(Node.prettyPrintJson(testOperationInput.builtInParams))} */ - /* clientParams: ${escape(Node.prettyPrintJson(testOperationInput.clientParams))} */ - let (http_client, rcvr) = #{capture_request}(None); - let conf = #{conf}; - let client = $moduleName::Client::from_conf(conf); - let _result = dbg!(#{invoke_operation}); - #{assertion} - """, - "capture_request" to RuntimeType.captureRequest(runtimeConfig), - "conf" to config(testOperationInput), - "invoke_operation" to operationInvocation(testOperationInput), - "assertion" to writable { - test.expect.endpoint.ifPresent { endpoint -> - val uri = escape(endpoint.url) - rustTemplate( - """ - let req = rcvr.expect_request(); - let uri = req.uri().to_string(); - assert!(uri.starts_with(${uri.dq()}), "expected URI to start with `$uri` but it was `{}`", uri); - """, - ) - } - test.expect.error.ifPresent { error -> - val expectedError = - escape("expected error: $error [${test.documentation.orNull() ?: "no docs"}]") - val escapedError = escape(error) - rustTemplate( - """ - rcvr.expect_no_request(); - let error = _result.expect_err(${expectedError.dq()}); - assert!( - format!("{:?}", error).contains(${escapedError.dq()}), - "expected error to contain `$escapedError` but it was {:?}", error - ); - """, - ) - } - }, - ) + fun generateInput(testOperationInput: EndpointTestOperationInput) = + writable { + val operationName = testOperationInput.operationName.toSnakeCase() + tokioTest(safeName("operation_input_test_$operationName")) { + rustTemplate( + """ + /* builtIns: ${escape(Node.prettyPrintJson(testOperationInput.builtInParams))} */ + /* clientParams: ${escape(Node.prettyPrintJson(testOperationInput.clientParams))} */ + let (http_client, rcvr) = #{capture_request}(None); + let conf = #{conf}; + let client = $moduleName::Client::from_conf(conf); + let _result = dbg!(#{invoke_operation}); + #{assertion} + """, + "capture_request" to RuntimeType.captureRequest(runtimeConfig), + "conf" to config(testOperationInput), + "invoke_operation" to operationInvocation(testOperationInput), + "assertion" to + writable { + test.expect.endpoint.ifPresent { endpoint -> + val uri = escape(endpoint.url) + rustTemplate( + """ + let req = rcvr.expect_request(); + let uri = req.uri().to_string(); + assert!(uri.starts_with(${uri.dq()}), "expected URI to start with `$uri` but it was `{}`", uri); + """, + ) + } + test.expect.error.ifPresent { error -> + val expectedError = + escape("expected error: $error [${test.documentation.orNull() ?: "no docs"}]") + val escapedError = escape(error) + rustTemplate( + """ + rcvr.expect_no_request(); + let error = _result.expect_err(${expectedError.dq()}); + assert!( + format!("{:?}", error).contains(${escapedError.dq()}), + "expected error to contain `$escapedError` but it was {:?}", error + ); + """, + ) + } + }, + ) + } } - } - private fun operationInvocation(testOperationInput: EndpointTestOperationInput) = writable { - rust("client.${testOperationInput.operationName.toSnakeCase()}()") - val operationInput = - model.expectShape(ctx.operationId(testOperationInput), OperationShape::class.java).inputShape(model) - testOperationInput.operationParams.members.forEach { (key, value) -> - val member = operationInput.expectMember(key.value) - rustTemplate( - ".${member.setterName()}(#{value})", - "value" to instantiator.generate(member, value), - ) + private fun operationInvocation(testOperationInput: EndpointTestOperationInput) = + writable { + rust("client.${testOperationInput.operationName.toSnakeCase()}()") + val operationInput = + model.expectShape(ctx.operationId(testOperationInput), OperationShape::class.java).inputShape(model) + testOperationInput.operationParams.members.forEach { (key, value) -> + val member = operationInput.expectMember(key.value) + rustTemplate( + ".${member.setterName()}(#{value})", + "value" to instantiator.generate(member, value), + ) + } + rust(".send().await") } - rust(".send().await") - } /** initialize service config for test */ - private fun config(operationInput: EndpointTestOperationInput) = writable { - rustBlock("") { - Attribute.AllowUnusedMut.render(this) - rust("let mut builder = $moduleName::Config::builder().with_test_defaults().http_client(http_client);") - operationInput.builtInParams.members.forEach { (builtIn, value) -> - val setter = endpointCustomizations.firstNotNullOfOrNull { - it.setBuiltInOnServiceConfig( - builtIn.value, - value, - "builder", - ) - } - if (setter != null) { - setter(this) - } else { - Logger.getLogger("OperationTestGenerator").warning("No provider for ${builtIn.value}") + private fun config(operationInput: EndpointTestOperationInput) = + writable { + rustBlock("") { + Attribute.AllowUnusedMut.render(this) + rust("let mut builder = $moduleName::Config::builder().with_test_defaults().http_client(http_client);") + operationInput.builtInParams.members.forEach { (builtIn, value) -> + val setter = + endpointCustomizations.firstNotNullOfOrNull { + it.setBuiltInOnServiceConfig( + builtIn.value, + value, + "builder", + ) + } + if (setter != null) { + setter(this) + } else { + Logger.getLogger("OperationTestGenerator").warning("No provider for ${builtIn.value}") + } } + rust("builder.build()") } - rust("builder.build()") } - } } fun ClientCodegenContext.operationId(testOperationInput: EndpointTestOperationInput): ShapeId = diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/RequireEndpointRules.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/RequireEndpointRules.kt index 3b3b81aeccf..f0aed0add0e 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/RequireEndpointRules.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/RequireEndpointRules.kt @@ -15,7 +15,11 @@ import software.amazon.smithy.rustsdk.sdkSettings class RequireEndpointRules : ClientCodegenDecorator { override val name: String = "RequireEndpointRules" override val order: Byte = 100 - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { if (!codegenContext.sdkSettings().requireEndpointResolver) { return } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/StripEndpointTrait.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/StripEndpointTrait.kt index 6dccf3f6b3f..9fd6f40822f 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/StripEndpointTrait.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/StripEndpointTrait.kt @@ -13,9 +13,10 @@ fun stripEndpointTrait(hostPrefix: String): (Model) -> Model { return { model: Model -> ModelTransformer.create() .removeTraitsIf(model) { _, trait -> - trait is EndpointTrait && trait.hostPrefix.labels.any { - it.isLabel && it.content == hostPrefix - } + trait is EndpointTrait && + trait.hostPrefix.labels.any { + it.isLabel && it.content == hostPrefix + } } } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/traits/PresignableTrait.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/traits/PresignableTrait.kt index ea3c2db0663..84cc95dbbd6 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/traits/PresignableTrait.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/traits/PresignableTrait.kt @@ -9,8 +9,9 @@ import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.AnnotationTrait -/** Synthetic trait that indicates an operation is presignable. */ // TODO(https://github.com/awslabs/smithy/pull/897) this can be replaced when the trait is added to Smithy. + +/** Synthetic trait that indicates an operation is presignable. */ class PresignableTrait(val syntheticOperationId: ShapeId) : AnnotationTrait(ID, Node.objectNode()) { companion object { val ID = ShapeId.from("smithy.api.aws.internal#presignable") diff --git a/aws/sdk-codegen/src/test/kotlin/AwsCrateDocsDecoratorTest.kt b/aws/sdk-codegen/src/test/kotlin/AwsCrateDocsDecoratorTest.kt index 7cd935e381d..bd8fc889a21 100644 --- a/aws/sdk-codegen/src/test/kotlin/AwsCrateDocsDecoratorTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/AwsCrateDocsDecoratorTest.kt @@ -120,43 +120,47 @@ class AwsCrateDocsDecoratorTest { fun warningBanner() { val context = { version: String -> testClientCodegenContext( - model = """ + model = + """ namespace test service Foobaz { } - """.asSmithyModel(), - settings = testClientRustSettings( - moduleVersion = version, - service = ShapeId.from("test#Foobaz"), - runtimeConfig = AwsTestRuntimeConfig, - customizationConfig = - ObjectNode.parse( - """ - { "awsSdk": { - "awsConfigVersion": "dontcare" } } - """, - ) as ObjectNode, - ), + """.asSmithyModel(), + settings = + testClientRustSettings( + moduleVersion = version, + service = ShapeId.from("test#Foobaz"), + runtimeConfig = AwsTestRuntimeConfig, + customizationConfig = + ObjectNode.parse( + """ + { "awsSdk": { + "awsConfigVersion": "dontcare" } } + """, + ) as ObjectNode, + ), ) } // Test unstable versions first var codegenContext = context("0.36.0") - var result = AwsCrateDocGenerator(codegenContext).docText(includeHeader = false, includeLicense = false, asComments = true).let { writable -> - val writer = RustWriter.root() - writable(writer) - writer.toString() - } + var result = + AwsCrateDocGenerator(codegenContext).docText(includeHeader = false, includeLicense = false, asComments = true).let { writable -> + val writer = RustWriter.root() + writable(writer) + writer.toString() + } assertTrue(result.contains("The SDK is currently released as a developer preview")) // And now stable versions codegenContext = context("1.0.0") - result = AwsCrateDocGenerator(codegenContext).docText(includeHeader = false, includeLicense = false, asComments = true).let { writable -> - val writer = RustWriter.root() - writable(writer) - writer.toString() - } + result = + AwsCrateDocGenerator(codegenContext).docText(includeHeader = false, includeLicense = false, asComments = true).let { writable -> + val writer = RustWriter.root() + writable(writer) + writer.toString() + } assertFalse(result.contains("The SDK is currently released as a developer preview")) } } diff --git a/aws/sdk-codegen/src/test/kotlin/SdkCodegenIntegrationTest.kt b/aws/sdk-codegen/src/test/kotlin/SdkCodegenIntegrationTest.kt index d4bd5e9ab11..998674bf917 100644 --- a/aws/sdk-codegen/src/test/kotlin/SdkCodegenIntegrationTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/SdkCodegenIntegrationTest.kt @@ -12,7 +12,8 @@ import software.amazon.smithy.rustsdk.awsSdkIntegrationTest class SdkCodegenIntegrationTest { companion object { - val model = """ + val model = + """ namespace test use aws.api#service @@ -46,12 +47,14 @@ class SdkCodegenIntegrationTest { operation SomeOperation { output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() } @Test fun smokeTestSdkCodegen() { - awsSdkIntegrationTest(model) { _, _ -> /* it should compile */ } + awsSdkIntegrationTest(model) { _, _ -> + // it should compile + } } // TODO(PostGA): Remove warning banner conditionals. diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecoratorTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecoratorTest.kt index 0435e742594..c04a4ce5714 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecoratorTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecoratorTest.kt @@ -27,7 +27,11 @@ class AwsPresigningDecoratorTest { testTransform("com.amazonaws.s3", "GetObject", presignable = true) } - private fun testTransform(namespace: String, name: String, presignable: Boolean) { + private fun testTransform( + namespace: String, + name: String, + presignable: Boolean, + ) { val settings = testClientRustSettings() val decorator = AwsPresigningDecorator() val model = testOperation(namespace, name) @@ -35,7 +39,11 @@ class AwsPresigningDecoratorTest { hasPresignableTrait(transformed, namespace, name) shouldBe presignable } - private fun hasPresignableTrait(model: Model, namespace: String, name: String): Boolean = + private fun hasPresignableTrait( + model: Model, + namespace: String, + name: String, + ): Boolean = model.shapes().filter { shape -> shape is OperationShape && shape.id == ShapeId.fromParts(namespace, name) } .findFirst() .orNull()!! @@ -44,7 +52,10 @@ class AwsPresigningDecoratorTest { private fun serviceShape(model: Model): ServiceShape = model.shapes().filter { shape -> shape is ServiceShape }.findFirst().orNull()!! as ServiceShape - private fun testOperation(namespace: String, name: String): Model = + private fun testOperation( + namespace: String, + name: String, + ): Model = """ namespace $namespace use aws.protocols#restJson1 @@ -68,7 +79,8 @@ class OverrideHttpMethodTransformTest { @Test fun `it should override the HTTP method for the listed operations`() { val settings = testClientRustSettings() - val model = """ + val model = + """ namespace test use aws.protocols#restJson1 @@ -89,26 +101,28 @@ class OverrideHttpMethodTransformTest { @http(uri: "/three", method: "POST") operation Three { input: TestInput, output: TestOutput } - """.asSmithyModel() + """.asSmithyModel() val serviceShape = model.expectShape(ShapeId.from("test#TestService"), ServiceShape::class.java) - val presignableOp = PresignableOperation( - PayloadSigningType.EMPTY, - listOf( - OverrideHttpMethodTransform( - mapOf( - ShapeId.from("test#One") to "GET", - ShapeId.from("test#Two") to "POST", + val presignableOp = + PresignableOperation( + PayloadSigningType.EMPTY, + listOf( + OverrideHttpMethodTransform( + mapOf( + ShapeId.from("test#One") to "GET", + ShapeId.from("test#Two") to "POST", + ), ), ), - ), - ) - val transformed = AwsPresigningDecorator( - mapOf( - ShapeId.from("test#One") to presignableOp, - ShapeId.from("test#Two") to presignableOp, - ), - ).transformModel(serviceShape, model, settings) + ) + val transformed = + AwsPresigningDecorator( + mapOf( + ShapeId.from("test#One") to presignableOp, + ShapeId.from("test#Two") to presignableOp, + ), + ).transformModel(serviceShape, model, settings) val synthNamespace = "test.synthetic.aws.presigned" transformed.expectShape(ShapeId.from("$synthNamespace#One")).expectTrait().method shouldBe "GET" @@ -118,11 +132,11 @@ class OverrideHttpMethodTransformTest { } class MoveDocumentMembersToQueryParamsTransformTest { - @Test fun `it should move document members to query parameters for the listed operations`() { val settings = testClientRustSettings() - val model = """ + val model = + """ namespace test use aws.protocols#restJson1 @@ -156,39 +170,43 @@ class MoveDocumentMembersToQueryParamsTransformTest { @http(uri: "/two", method: "POST") operation Two { input: TwoInputOutput, output: TwoInputOutput } - """.asSmithyModel() + """.asSmithyModel() val serviceShape = model.expectShape(ShapeId.from("test#TestService"), ServiceShape::class.java) - val presignableOp = PresignableOperation( - PayloadSigningType.EMPTY, - listOf( - MoveDocumentMembersToQueryParamsTransform( - listOf(ShapeId.from("test#One")), + val presignableOp = + PresignableOperation( + PayloadSigningType.EMPTY, + listOf( + MoveDocumentMembersToQueryParamsTransform( + listOf(ShapeId.from("test#One")), + ), ), - ), - ) - val transformed = AwsPresigningDecorator( - mapOf(ShapeId.from("test#One") to presignableOp), - ).transformModel(serviceShape, model, settings) + ) + val transformed = + AwsPresigningDecorator( + mapOf(ShapeId.from("test#One") to presignableOp), + ).transformModel(serviceShape, model, settings) val index = HttpBindingIndex(transformed) index.getRequestBindings(ShapeId.from("test.synthetic.aws.presigned#One")).map { (key, value) -> key to value.location - }.toMap() shouldBe mapOf( - "one" to HttpBinding.Location.HEADER, - "two" to HttpBinding.Location.QUERY, - "three" to HttpBinding.Location.QUERY, - "four" to HttpBinding.Location.QUERY, - ) + }.toMap() shouldBe + mapOf( + "one" to HttpBinding.Location.HEADER, + "two" to HttpBinding.Location.QUERY, + "three" to HttpBinding.Location.QUERY, + "four" to HttpBinding.Location.QUERY, + ) transformed.getShape(ShapeId.from("test.synthetic.aws.presigned#Two")).orNull() shouldBe null index.getRequestBindings(ShapeId.from("test#Two")).map { (key, value) -> key to value.location - }.toMap() shouldBe mapOf( - "one" to HttpBinding.Location.HEADER, - "two" to HttpBinding.Location.QUERY, - "three" to HttpBinding.Location.DOCUMENT, - "four" to HttpBinding.Location.DOCUMENT, - ) + }.toMap() shouldBe + mapOf( + "one" to HttpBinding.Location.HEADER, + "two" to HttpBinding.Location.QUERY, + "three" to HttpBinding.Location.DOCUMENT, + "four" to HttpBinding.Location.DOCUMENT, + ) } } diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/CredentialProviderConfigTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/CredentialProviderConfigTest.kt index aec0806d06f..eb759fe3af7 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/CredentialProviderConfigTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/CredentialProviderConfigTest.kt @@ -17,13 +17,15 @@ internal class CredentialProviderConfigTest { fun `configuring credentials provider at operation level should work`() { awsSdkIntegrationTest(SdkCodegenIntegrationTest.model) { ctx, rustCrate -> val rc = ctx.runtimeConfig - val codegenScope = arrayOf( - *RuntimeType.preludeScope, - "capture_request" to RuntimeType.captureRequest(rc), - "Credentials" to AwsRuntimeType.awsCredentialTypesTestUtil(rc) - .resolve("Credentials"), - "Region" to AwsRuntimeType.awsTypes(rc).resolve("region::Region"), - ) + val codegenScope = + arrayOf( + *RuntimeType.preludeScope, + "capture_request" to RuntimeType.captureRequest(rc), + "Credentials" to + AwsRuntimeType.awsCredentialTypesTestUtil(rc) + .resolve("Credentials"), + "Region" to AwsRuntimeType.awsTypes(rc).resolve("region::Region"), + ) rustCrate.integrationTest("credentials_provider") { // per https://github.com/awslabs/aws-sdk-rust/issues/901 tokioTest("configuring_credentials_provider_at_operation_level_should_work") { @@ -67,16 +69,19 @@ internal class CredentialProviderConfigTest { fun `configuring credentials provider on builder should replace what was previously set`() { awsSdkIntegrationTest(SdkCodegenIntegrationTest.model) { ctx, rustCrate -> val rc = ctx.runtimeConfig - val codegenScope = arrayOf( - *RuntimeType.preludeScope, - "capture_request" to RuntimeType.captureRequest(rc), - "Credentials" to AwsRuntimeType.awsCredentialTypesTestUtil(rc) - .resolve("Credentials"), - "Region" to AwsRuntimeType.awsTypes(rc).resolve("region::Region"), - "SdkConfig" to AwsRuntimeType.awsTypes(rc).resolve("sdk_config::SdkConfig"), - "SharedCredentialsProvider" to AwsRuntimeType.awsCredentialTypes(rc) - .resolve("provider::SharedCredentialsProvider"), - ) + val codegenScope = + arrayOf( + *RuntimeType.preludeScope, + "capture_request" to RuntimeType.captureRequest(rc), + "Credentials" to + AwsRuntimeType.awsCredentialTypesTestUtil(rc) + .resolve("Credentials"), + "Region" to AwsRuntimeType.awsTypes(rc).resolve("region::Region"), + "SdkConfig" to AwsRuntimeType.awsTypes(rc).resolve("sdk_config::SdkConfig"), + "SharedCredentialsProvider" to + AwsRuntimeType.awsCredentialTypes(rc) + .resolve("provider::SharedCredentialsProvider"), + ) rustCrate.integrationTest("credentials_provider") { // per https://github.com/awslabs/aws-sdk-rust/issues/973 tokioTest("configuring_credentials_provider_on_builder_should_replace_what_was_previously_set") { diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointBuiltInsDecoratorTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointBuiltInsDecoratorTest.kt index 68bfab92fef..c5dd5581600 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointBuiltInsDecoratorTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointBuiltInsDecoratorTest.kt @@ -13,7 +13,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.integrationTest class EndpointBuiltInsDecoratorTest { - private val endpointUrlModel = """ + private val endpointUrlModel = + """ namespace test use aws.api#service @@ -74,7 +75,7 @@ class EndpointBuiltInsDecoratorTest { operation SomeOperation { output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() @Test fun endpointUrlBuiltInWorksEndToEnd() { @@ -107,10 +108,12 @@ class EndpointBuiltInsDecoratorTest { } """, "tokio" to CargoDependency.Tokio.toDevDependency().withFeature("rt").withFeature("macros").toType(), - "StaticReplayClient" to CargoDependency.smithyRuntimeTestUtil(codegenContext.runtimeConfig).toType() - .resolve("client::http::test_util::StaticReplayClient"), - "ReplayEvent" to CargoDependency.smithyRuntimeTestUtil(codegenContext.runtimeConfig).toType() - .resolve("client::http::test_util::ReplayEvent"), + "StaticReplayClient" to + CargoDependency.smithyRuntimeTestUtil(codegenContext.runtimeConfig).toType() + .resolve("client::http::test_util::StaticReplayClient"), + "ReplayEvent" to + CargoDependency.smithyRuntimeTestUtil(codegenContext.runtimeConfig).toType() + .resolve("client::http::test_util::ReplayEvent"), "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig), ) } diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointsCredentialsTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointsCredentialsTest.kt index 21494c8d58c..0beb50db35d 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointsCredentialsTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointsCredentialsTest.kt @@ -21,7 +21,8 @@ class EndpointsCredentialsTest { // 1. A rule that sets no authentication scheme—in this case, we should be using the default from the service // 2. A rule that sets a custom authentication scheme and that configures signing // The chosen path is controlled by static context parameters set on the operation - private val model = """ + private val model = + """ namespace aws.fooBaz use aws.api#service @@ -73,7 +74,7 @@ class EndpointsCredentialsTest { @http(uri: "/custom", method: "GET") @staticContextParams({ AuthMode: { value: "custom-auth" } }) operation CustomAuth { } - """.asSmithyModel() + """.asSmithyModel() @Test fun `endpoint rules configure auth in default and non-default case`() { @@ -97,8 +98,9 @@ class EndpointsCredentialsTest { assert!(auth_header.contains("/us-west-2/foobaz/aws4_request"), "{}", auth_header); """, "capture_request" to RuntimeType.captureRequest(context.runtimeConfig), - "Credentials" to AwsRuntimeType.awsCredentialTypesTestUtil(context.runtimeConfig) - .resolve("Credentials"), + "Credentials" to + AwsRuntimeType.awsCredentialTypesTestUtil(context.runtimeConfig) + .resolve("Credentials"), "Region" to AwsRuntimeType.awsTypes(context.runtimeConfig).resolve("region::Region"), ) } @@ -120,8 +122,9 @@ class EndpointsCredentialsTest { assert!(auth_header.contains("/region-custom-auth/name-custom-auth/aws4_request"), "{}", auth_header); """, "capture_request" to RuntimeType.captureRequest(context.runtimeConfig), - "Credentials" to AwsRuntimeType.awsCredentialTypesTestUtil(context.runtimeConfig) - .resolve("Credentials"), + "Credentials" to + AwsRuntimeType.awsCredentialTypesTestUtil(context.runtimeConfig) + .resolve("Credentials"), "Region" to AwsRuntimeType.awsTypes(context.runtimeConfig).resolve("region::Region"), ) } diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/InvocationIdDecoratorTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/InvocationIdDecoratorTest.kt index be0861a045b..849db468525 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/InvocationIdDecoratorTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/InvocationIdDecoratorTest.kt @@ -48,10 +48,12 @@ class InvocationIdDecoratorTest { """, *preludeScope, "tokio" to CargoDependency.Tokio.toType(), - "InvocationIdGenerator" to AwsRuntimeType.awsRuntime(rc) - .resolve("invocation_id::InvocationIdGenerator"), - "InvocationId" to AwsRuntimeType.awsRuntime(rc) - .resolve("invocation_id::InvocationId"), + "InvocationIdGenerator" to + AwsRuntimeType.awsRuntime(rc) + .resolve("invocation_id::InvocationIdGenerator"), + "InvocationId" to + AwsRuntimeType.awsRuntime(rc) + .resolve("invocation_id::InvocationId"), "BoxError" to RuntimeType.boxError(rc), "capture_request" to RuntimeType.captureRequest(rc), ) diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/OperationInputTestGeneratorTests.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/OperationInputTestGeneratorTests.kt index 91c7f6be91a..472d8159956 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/OperationInputTestGeneratorTests.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/OperationInputTestGeneratorTests.kt @@ -18,13 +18,15 @@ class OperationInputTestGeneratorTests { @Test fun `finds operation shape by name`() { val prefix = "\$version: \"2\"" - val operationModel = """ + val operationModel = + """ $prefix namespace operations operation Ping {} - """.trimIndent() - val serviceModel = """ + """.trimIndent() + val serviceModel = + """ $prefix namespace service @@ -33,19 +35,21 @@ class OperationInputTestGeneratorTests { service MyService { operations: [Ping] } - """.trimIndent() + """.trimIndent() - val model = Model.assembler() - .discoverModels() - .addUnparsedModel("operation.smithy", operationModel) - .addUnparsedModel("main.smithy", serviceModel) - .assemble() - .unwrap() + val model = + Model.assembler() + .discoverModels() + .addUnparsedModel("operation.smithy", operationModel) + .addUnparsedModel("main.smithy", serviceModel) + .assemble() + .unwrap() val context = testClientCodegenContext(model) - val testOperationInput = EndpointTestOperationInput.builder() - .operationName("Ping") - .build() + val testOperationInput = + EndpointTestOperationInput.builder() + .operationName("Ping") + .build() val operationId = context.operationId(testOperationInput) assertEquals("operations#Ping", operationId.toString()) @@ -53,18 +57,20 @@ class OperationInputTestGeneratorTests { @Test fun `fails for operation name not found`() { - val model = """ + val model = + """ namespace test operation Ping {} service MyService { operations: [Ping] } - """.trimIndent().asSmithyModel() + """.trimIndent().asSmithyModel() val context = testClientCodegenContext(model) - val testOperationInput = EndpointTestOperationInput.builder() - .operationName("Pong") - .build() + val testOperationInput = + EndpointTestOperationInput.builder() + .operationName("Pong") + .build() assertThrows { context.operationId(testOperationInput) } } diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/RegionDecoratorTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/RegionDecoratorTest.kt index cc003fc0864..61362caea7b 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/RegionDecoratorTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/RegionDecoratorTest.kt @@ -11,7 +11,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import kotlin.io.path.readText class RegionDecoratorTest { - private val modelWithoutRegionParamOrSigV4AuthScheme = """ + private val modelWithoutRegionParamOrSigV4AuthScheme = + """ namespace test use aws.api#service @@ -28,9 +29,10 @@ class RegionDecoratorTest { service TestService { version: "2023-01-01", operations: [SomeOperation] } structure SomeOutput { something: String } operation SomeOperation { output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() - private val modelWithRegionParam = """ + private val modelWithRegionParam = + """ namespace test use aws.api#service @@ -49,9 +51,10 @@ class RegionDecoratorTest { service TestService { version: "2023-01-01", operations: [SomeOperation] } structure SomeOutput { something: String } operation SomeOperation { output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() - private val modelWithSigV4AuthScheme = """ + private val modelWithSigV4AuthScheme = + """ namespace test use aws.auth#sigv4 @@ -71,31 +74,34 @@ class RegionDecoratorTest { service TestService { version: "2023-01-01", operations: [SomeOperation] } structure SomeOutput { something: String } operation SomeOperation { output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() @Test fun `models without region built-in params or SigV4 should not have configurable regions`() { - val path = awsSdkIntegrationTest(modelWithoutRegionParamOrSigV4AuthScheme) { _, _ -> - // it should generate and compile successfully - } + val path = + awsSdkIntegrationTest(modelWithoutRegionParamOrSigV4AuthScheme) { _, _ -> + // it should generate and compile successfully + } val configContents = path.resolve("src/config.rs").readText() assertFalse(configContents.contains("fn set_region(")) } @Test fun `models with region built-in params should have configurable regions`() { - val path = awsSdkIntegrationTest(modelWithRegionParam) { _, _ -> - // it should generate and compile successfully - } + val path = + awsSdkIntegrationTest(modelWithRegionParam) { _, _ -> + // it should generate and compile successfully + } val configContents = path.resolve("src/config.rs").readText() assertTrue(configContents.contains("fn set_region(")) } @Test fun `models with SigV4 should have configurable regions`() { - val path = awsSdkIntegrationTest(modelWithSigV4AuthScheme) { _, _ -> - // it should generate and compile successfully - } + val path = + awsSdkIntegrationTest(modelWithSigV4AuthScheme) { _, _ -> + // it should generate and compile successfully + } val configContents = path.resolve("src/config.rs").readText() assertTrue(configContents.contains("fn set_region(")) } diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecoratorTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecoratorTest.kt index 64f35d197fa..5183d4051a2 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecoratorTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecoratorTest.kt @@ -8,7 +8,8 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel class SigV4AuthDecoratorTest { - private val modelWithSigV4AuthScheme = """ + private val modelWithSigV4AuthScheme = + """ namespace test use aws.auth#sigv4 @@ -55,7 +56,7 @@ class SigV4AuthDecoratorTest { @unsignedPayload @http(uri: "/", method: "POST") operation SomeOperation { input: SomeInput, output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() @Test fun unsignedPayloadSetsCorrectHeader() { diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/TestUtil.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/TestUtil.kt index ad08e25aebf..aa2a18a2e44 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/TestUtil.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/TestUtil.kt @@ -21,50 +21,55 @@ import java.io.File // In aws-sdk-codegen, the working dir when gradle runs tests is actually `./aws`. So, to find the smithy runtime, we need // to go up one more level -val AwsTestRuntimeConfig = TestRuntimeConfig.copy( - runtimeCrateLocation = run { - val path = File("../../rust-runtime") - check(path.exists()) { "$path must exist to generate a working SDK" } - RuntimeCrateLocation.Path(path.absolutePath) - }, -) - -fun awsTestCodegenContext(model: Model? = null, settings: ClientRustSettings? = null) = - testClientCodegenContext( - model ?: "namespace test".asSmithyModel(), - settings = settings ?: testClientRustSettings(runtimeConfig = AwsTestRuntimeConfig), +val AwsTestRuntimeConfig = + TestRuntimeConfig.copy( + runtimeCrateLocation = + run { + val path = File("../../rust-runtime") + check(path.exists()) { "$path must exist to generate a working SDK" } + RuntimeCrateLocation.path(path.absolutePath) + }, ) +fun awsTestCodegenContext( + model: Model? = null, + settings: ClientRustSettings? = null, +) = testClientCodegenContext( + model ?: "namespace test".asSmithyModel(), + settings = settings ?: testClientRustSettings(runtimeConfig = AwsTestRuntimeConfig), +) + fun awsSdkIntegrationTest( model: Model, params: IntegrationTestParams = awsIntegrationTestParams(), test: (ClientCodegenContext, RustCrate) -> Unit = { _, _ -> }, -) = - clientIntegrationTest( - model, - awsIntegrationTestParams(), - test = test, - ) +) = clientIntegrationTest( + model, + awsIntegrationTestParams(), + test = test, +) -fun awsIntegrationTestParams() = IntegrationTestParams( - cargoCommand = "cargo test --features test-util behavior-version-latest", - runtimeConfig = AwsTestRuntimeConfig, - additionalSettings = ObjectNode.builder().withMember( - "customizationConfig", - ObjectNode.builder() - .withMember( - "awsSdk", +fun awsIntegrationTestParams() = + IntegrationTestParams( + cargoCommand = "cargo test --features test-util behavior-version-latest", + runtimeConfig = AwsTestRuntimeConfig, + additionalSettings = + ObjectNode.builder().withMember( + "customizationConfig", ObjectNode.builder() - .withMember("generateReadme", false) - .withMember("integrationTestPath", "../sdk/integration-tests") - .build(), - ).build(), + .withMember( + "awsSdk", + ObjectNode.builder() + .withMember("generateReadme", false) + .withMember("integrationTestPath", "../sdk/integration-tests") + .build(), + ).build(), + ) + .withMember( + "codegen", + ObjectNode.builder() + .withMember("includeFluentClient", false) + .withMember("includeEndpointUrlConfig", false) + .build(), + ).build(), ) - .withMember( - "codegen", - ObjectNode.builder() - .withMember("includeFluentClient", false) - .withMember("includeEndpointUrlConfig", false) - .build(), - ).build(), -) diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/customize/RemoveDefaultsTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/customize/RemoveDefaultsTest.kt index 1ecc0e650e2..1443b5c3457 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/customize/RemoveDefaultsTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/customize/RemoveDefaultsTest.kt @@ -18,11 +18,13 @@ import software.amazon.smithy.rust.codegen.core.util.shapeId internal class RemoveDefaultsTest { @Test fun `defaults should be removed`() { - val removeDefaults = setOf( - "test#Bar".shapeId(), - "test#Foo\$baz".shapeId(), - ) - val baseModel = """ + val removeDefaults = + setOf( + "test#Bar".shapeId(), + "test#Foo\$baz".shapeId(), + ) + val baseModel = + """ namespace test structure Foo { @@ -33,7 +35,7 @@ internal class RemoveDefaultsTest { @default(0) integer Bar - """.asSmithyModel(smithyVersion = "2.0") + """.asSmithyModel(smithyVersion = "2.0") val model = RemoveDefaults.processModel(baseModel, removeDefaults) val barMember = model.lookup("test#Foo\$bar") barMember.hasTrait() shouldBe false diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/customize/ec2/EC2MakePrimitivesOptionalTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/customize/ec2/EC2MakePrimitivesOptionalTest.kt index ae919497f0f..9a6510e6c01 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/customize/ec2/EC2MakePrimitivesOptionalTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/customize/ec2/EC2MakePrimitivesOptionalTest.kt @@ -22,7 +22,8 @@ internal class EC2MakePrimitivesOptionalTest { "CLIENT_ZERO_VALUE_V1_NO_INPUT", ) fun `primitive shapes are boxed`(nullabilityCheckMode: NullableIndex.CheckMode) { - val baseModel = """ + val baseModel = + """ namespace test structure Primitives { int: PrimitiveInteger, @@ -38,7 +39,7 @@ internal class EC2MakePrimitivesOptionalTest { structure Other {} - """.asSmithyModel() + """.asSmithyModel() val model = EC2MakePrimitivesOptional.processModel(baseModel) val nullableIndex = NullableIndex(model) val struct = model.lookup("test#Primitives") diff --git a/build.gradle.kts b/build.gradle.kts index 7439d73b146..d62cca61da4 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -18,7 +18,7 @@ plugins { } allprojects { repositories { - /* mavenLocal() */ + // mavenLocal() mavenCentral() google() } @@ -33,31 +33,30 @@ allprojects.forEach { } } -val ktlint by configurations.creating { - // https://github.com/pinterest/ktlint/issues/1114#issuecomment-805793163 - attributes { - attribute(Bundling.BUNDLING_ATTRIBUTE, objects.named(Bundling.EXTERNAL)) - } -} +val ktlint by configurations.creating val ktlintVersion: String by project dependencies { - ktlint("com.pinterest:ktlint:$ktlintVersion") - ktlint("com.pinterest.ktlint:ktlint-ruleset-standard:$ktlintVersion") + ktlint("com.pinterest.ktlint:ktlint-cli:$ktlintVersion") { + attributes { + attribute(Bundling.BUNDLING_ATTRIBUTE, objects.named(Bundling.EXTERNAL)) + } + } } -val lintPaths = listOf( - "**/*.kt", - // Exclude build output directories - "!**/build/**", - "!**/node_modules/**", - "!**/target/**", -) +val lintPaths = + listOf( + "**/*.kt", + // Exclude build output directories + "!**/build/**", + "!**/node_modules/**", + "!**/target/**", + ) tasks.register("ktlint") { description = "Check Kotlin code style." - group = "Verification" - classpath = configurations.getByName("ktlint") + group = LifecycleBasePlugin.VERIFICATION_GROUP + classpath = ktlint mainClass.set("com.pinterest.ktlint.Main") args = listOf("--log-level=info", "--relative", "--") + lintPaths // https://github.com/pinterest/ktlint/issues/1195#issuecomment-1009027802 @@ -66,10 +65,24 @@ tasks.register("ktlint") { tasks.register("ktlintFormat") { description = "Auto fix Kotlin code style violations" - group = "formatting" - classpath = configurations.getByName("ktlint") + group = LifecycleBasePlugin.VERIFICATION_GROUP + classpath = ktlint mainClass.set("com.pinterest.ktlint.Main") args = listOf("--log-level=info", "--relative", "--format", "--") + lintPaths // https://github.com/pinterest/ktlint/issues/1195#issuecomment-1009027802 jvmArgs("--add-opens", "java.base/java.lang=ALL-UNNAMED") } + +tasks.register("ktlintPreCommit") { + description = "Check Kotlin code style (for the pre-commit hooks)." + group = LifecycleBasePlugin.VERIFICATION_GROUP + classpath = ktlint + mainClass.set("com.pinterest.ktlint.Main") + args = listOf("--log-level=warn", "--color", "--relative", "--format", "--") + + System.getProperty("ktlintPreCommitArgs").let { args -> + check(args.isNotBlank()) { "need to pass in -DktlintPreCommitArgs=" } + args.split(" ") + } + // https://github.com/pinterest/ktlint/issues/1195#issuecomment-1009027802 + jvmArgs("--add-opens", "java.base/java.lang=ALL-UNNAMED") +} diff --git a/buildSrc/src/main/kotlin/CodegenTestCommon.kt b/buildSrc/src/main/kotlin/CodegenTestCommon.kt index 3a7008f8970..ac9d8db6c9b 100644 --- a/buildSrc/src/main/kotlin/CodegenTestCommon.kt +++ b/buildSrc/src/main/kotlin/CodegenTestCommon.kt @@ -22,44 +22,52 @@ data class CodegenTest( val imports: List = emptyList(), ) -fun generateImports(imports: List): String = if (imports.isEmpty()) { - "" -} else { - "\"imports\": [${imports.map { "\"$it\"" }.joinToString(", ")}]," -} +fun generateImports(imports: List): String = + if (imports.isEmpty()) { + "" + } else { + "\"imports\": [${imports.map { "\"$it\"" }.joinToString(", ")}]," + } -private fun generateSmithyBuild(projectDir: String, pluginName: String, tests: List): String { - val projections = tests.joinToString(",\n") { - """ - "${it.module}": { - ${generateImports(it.imports)} - "plugins": { - "$pluginName": { - "runtimeConfig": { - "relativePath": "$projectDir/rust-runtime" - }, - "codegen": { - ${it.extraCodegenConfig ?: ""} - }, - "service": "${it.service}", - "module": "${it.module}", - "moduleVersion": "0.0.1", - "moduleDescription": "test", - "moduleAuthors": ["protocoltest@example.com"] - ${it.extraConfig ?: ""} +private fun generateSmithyBuild( + projectDir: String, + pluginName: String, + tests: List, +): String { + val projections = + tests.joinToString(",\n") { + """ + "${it.module}": { + ${generateImports(it.imports)} + "plugins": { + "$pluginName": { + "runtimeConfig": { + "relativePath": "$projectDir/rust-runtime" + }, + "codegen": { + ${it.extraCodegenConfig ?: ""} + }, + "service": "${it.service}", + "module": "${it.module}", + "moduleVersion": "0.0.1", + "moduleDescription": "test", + "moduleAuthors": ["protocoltest@example.com"] + ${it.extraConfig ?: ""} + } } } + """.trimIndent() } - """.trimIndent() - } - return """ + return ( + """ { "version": "1.0", "projections": { $projections } } - """.trimIndent() + """.trimIndent() + ) } enum class Cargo(val toString: String) { @@ -69,26 +77,34 @@ enum class Cargo(val toString: String) { CLIPPY("cargoClippy"), } -private fun generateCargoWorkspace(pluginName: String, tests: List) = +private fun generateCargoWorkspace( + pluginName: String, + tests: List, +) = ( """ [workspace] members = [ ${tests.joinToString(",") { "\"${it.module}/$pluginName\"" }} ] """.trimIndent() +) /** * Filter the service integration tests for which to generate Rust crates in [allTests] using the given [properties]. */ -private fun codegenTests(properties: PropertyRetriever, allTests: List): List { +private fun codegenTests( + properties: PropertyRetriever, + allTests: List, +): List { val modulesOverride = properties.get("modules")?.split(",")?.map { it.trim() } - val ret = if (modulesOverride != null) { - println("modulesOverride: $modulesOverride") - allTests.filter { modulesOverride.contains(it.module) } - } else { - allTests - } + val ret = + if (modulesOverride != null) { + println("modulesOverride: $modulesOverride") + allTests.filter { modulesOverride.contains(it.module) } + } else { + allTests + } require(ret.isNotEmpty()) { "None of the provided module overrides (`$modulesOverride`) are valid test services (`${ allTests.map { @@ -106,22 +122,24 @@ val AllCargoCommands = listOf(Cargo.CHECK, Cargo.TEST, Cargo.CLIPPY, Cargo.DOCS) * The list of Cargo commands that is run by default is defined in [AllCargoCommands]. */ fun cargoCommands(properties: PropertyRetriever): List { - val cargoCommandsOverride = properties.get("cargoCommands")?.split(",")?.map { it.trim() }?.map { - when (it) { - "check" -> Cargo.CHECK - "test" -> Cargo.TEST - "doc" -> Cargo.DOCS - "clippy" -> Cargo.CLIPPY - else -> throw IllegalArgumentException("Unexpected Cargo command `$it` (valid commands are `check`, `test`, `doc`, `clippy`)") + val cargoCommandsOverride = + properties.get("cargoCommands")?.split(",")?.map { it.trim() }?.map { + when (it) { + "check" -> Cargo.CHECK + "test" -> Cargo.TEST + "doc" -> Cargo.DOCS + "clippy" -> Cargo.CLIPPY + else -> throw IllegalArgumentException("Unexpected Cargo command `$it` (valid commands are `check`, `test`, `doc`, `clippy`)") + } } - } - val ret = if (cargoCommandsOverride != null) { - println("cargoCommandsOverride: $cargoCommandsOverride") - AllCargoCommands.filter { cargoCommandsOverride.contains(it) } - } else { - AllCargoCommands - } + val ret = + if (cargoCommandsOverride != null) { + println("cargoCommandsOverride: $cargoCommandsOverride") + AllCargoCommands.filter { cargoCommandsOverride.contains(it) } + } else { + AllCargoCommands + } require(ret.isNotEmpty()) { "None of the provided cargo commands (`$cargoCommandsOverride`) are valid cargo commands (`${ AllCargoCommands.map { @@ -156,12 +174,13 @@ fun Project.registerGenerateSmithyBuildTask( // If this is a rebuild, cache all the hashes of the generated Rust files. These are later used by the // `modifyMtime` task. - project.extra[previousBuildHashesKey] = project.buildDir.walk() - .filter { it.isFile } - .map { - getChecksumForFile(it) to it.lastModified() - } - .toMap() + project.extra[previousBuildHashesKey] = + project.buildDir.walk() + .filter { it.isFile } + .map { + getChecksumForFile(it) to it.lastModified() + } + .toMap() } } } @@ -182,9 +201,7 @@ fun Project.registerGenerateCargoWorkspaceTask( } } -fun Project.registerGenerateCargoConfigTomlTask( - outputDir: File, -) { +fun Project.registerGenerateCargoConfigTomlTask(outputDir: File) { this.tasks.register("generateCargoConfigToml") { description = "generate `.cargo/config.toml`" doFirst { diff --git a/buildSrc/src/main/kotlin/CrateSet.kt b/buildSrc/src/main/kotlin/CrateSet.kt index 765503e5a16..d616d619f0d 100644 --- a/buildSrc/src/main/kotlin/CrateSet.kt +++ b/buildSrc/src/main/kotlin/CrateSet.kt @@ -16,21 +16,21 @@ object CrateSet { * stable = true */ - val StableCrates = setOf( - // AWS crates - "aws-config", - "aws-credential-types", - "aws-runtime", - "aws-runtime-api", - "aws-sigv4", - "aws-types", - - // smithy crates - "aws-smithy-async", - "aws-smithy-runtime-api", - "aws-smithy-runtime", - "aws-smithy-types", - ) + val StableCrates = + setOf( + // AWS crates + "aws-config", + "aws-credential-types", + "aws-runtime", + "aws-runtime-api", + "aws-sigv4", + "aws-types", + // smithy crates + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-runtime", + "aws-smithy-types", + ) val version = { name: String -> when { @@ -39,44 +39,48 @@ object CrateSet { } } - val AWS_SDK_RUNTIME = listOf( - "aws-config", - "aws-credential-types", - "aws-endpoint", - "aws-http", - "aws-hyper", - "aws-runtime", - "aws-runtime-api", - "aws-sig-auth", - "aws-sigv4", - "aws-types", - ).map { Crate(it, version(it)) } + val AWS_SDK_RUNTIME = + listOf( + "aws-config", + "aws-credential-types", + "aws-endpoint", + "aws-http", + "aws-hyper", + "aws-runtime", + "aws-runtime-api", + "aws-sig-auth", + "aws-sigv4", + "aws-types", + ).map { Crate(it, version(it)) } - val SMITHY_RUNTIME_COMMON = listOf( - "aws-smithy-async", - "aws-smithy-checksums", - "aws-smithy-client", - "aws-smithy-eventstream", - "aws-smithy-http", - "aws-smithy-http-auth", - "aws-smithy-http-tower", - "aws-smithy-json", - "aws-smithy-protocol-test", - "aws-smithy-query", - "aws-smithy-runtime", - "aws-smithy-runtime-api", - "aws-smithy-types", - "aws-smithy-types-convert", - "aws-smithy-xml", - ).map { Crate(it, version(it)) } + val SMITHY_RUNTIME_COMMON = + listOf( + "aws-smithy-async", + "aws-smithy-checksums", + "aws-smithy-client", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-http-auth", + "aws-smithy-http-tower", + "aws-smithy-json", + "aws-smithy-protocol-test", + "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-types-convert", + "aws-smithy-xml", + ).map { Crate(it, version(it)) } val AWS_SDK_SMITHY_RUNTIME = SMITHY_RUNTIME_COMMON - val SERVER_SMITHY_RUNTIME = SMITHY_RUNTIME_COMMON + listOf( - Crate("aws-smithy-http-server", UNSTABLE_VERSION_PROP_NAME), - Crate("aws-smithy-http-server-python", UNSTABLE_VERSION_PROP_NAME), - Crate("aws-smithy-http-server-typescript", UNSTABLE_VERSION_PROP_NAME), - ) + val SERVER_SMITHY_RUNTIME = + SMITHY_RUNTIME_COMMON + + listOf( + Crate("aws-smithy-http-server", UNSTABLE_VERSION_PROP_NAME), + Crate("aws-smithy-http-server-python", UNSTABLE_VERSION_PROP_NAME), + Crate("aws-smithy-http-server-typescript", UNSTABLE_VERSION_PROP_NAME), + ) val ENTIRE_SMITHY_RUNTIME = (AWS_SDK_SMITHY_RUNTIME + SERVER_SMITHY_RUNTIME).toSortedSet(compareBy { it.name }) diff --git a/buildSrc/src/main/kotlin/HashUtils.kt b/buildSrc/src/main/kotlin/HashUtils.kt index fb093c15712..d92f328b9c1 100644 --- a/buildSrc/src/main/kotlin/HashUtils.kt +++ b/buildSrc/src/main/kotlin/HashUtils.kt @@ -8,5 +8,7 @@ import java.security.MessageDigest fun ByteArray.toHex() = joinToString(separator = "") { byte -> "%02x".format(byte) } -fun getChecksumForFile(file: File, digest: MessageDigest = MessageDigest.getInstance("SHA-256")): String = - digest.digest(file.readText().toByteArray()).toHex() +fun getChecksumForFile( + file: File, + digest: MessageDigest = MessageDigest.getInstance("SHA-256"), +): String = digest.digest(file.readText().toByteArray()).toHex() diff --git a/buildSrc/src/main/kotlin/ManifestPatcher.kt b/buildSrc/src/main/kotlin/ManifestPatcher.kt index aabe25a3db8..813c137527c 100644 --- a/buildSrc/src/main/kotlin/ManifestPatcher.kt +++ b/buildSrc/src/main/kotlin/ManifestPatcher.kt @@ -5,20 +5,29 @@ import java.io.File -fun rewriteCrateVersion(line: String, version: String): String = line.replace( - """^\s*version\s*=\s*"0.0.0-smithy-rs-head"$""".toRegex(), - "version = \"$version\"", -) +fun rewriteCrateVersion( + line: String, + version: String, +): String = + line.replace( + """^\s*version\s*=\s*"0.0.0-smithy-rs-head"$""".toRegex(), + "version = \"$version\"", + ) /** * Smithy runtime crate versions in smithy-rs are all `0.0.0-smithy-rs-head`. When copying over to the AWS SDK, * these should be changed to the smithy-rs version. */ -fun rewriteRuntimeCrateVersion(version: String, line: String): String = - rewriteCrateVersion(line, version) +fun rewriteRuntimeCrateVersion( + version: String, + line: String, +): String = rewriteCrateVersion(line, version) /** Patches a file with the result of the given `operation` being run on each line */ -fun patchFile(path: File, operation: (String) -> String) { +fun patchFile( + path: File, + operation: (String) -> String, +) { val patchedContents = path.readLines().joinToString("\n", transform = operation) path.writeText(patchedContents) } diff --git a/buildSrc/src/main/kotlin/aws/sdk/CrateVersioner.kt b/buildSrc/src/main/kotlin/aws/sdk/CrateVersioner.kt index 08145202032..e26f66a640b 100644 --- a/buildSrc/src/main/kotlin/aws/sdk/CrateVersioner.kt +++ b/buildSrc/src/main/kotlin/aws/sdk/CrateVersioner.kt @@ -14,13 +14,17 @@ import java.security.MessageDigest const val LOCAL_DEV_VERSION: String = "0.0.0-local" object CrateVersioner { - fun defaultFor(rootProject: Project, properties: PropertyRetriever): VersionCrate = + fun defaultFor( + rootProject: Project, + properties: PropertyRetriever, + ): VersionCrate = when (val versionsManifestPath = properties.get("aws.sdk.previous.release.versions.manifest")) { // In local dev, use special `0.0.0-local` version number for all SDK crates null -> SynchronizedCrateVersioner(properties, sdkVersion = LOCAL_DEV_VERSION) else -> { - val modelMetadataPath = properties.get("aws.sdk.model.metadata") - ?: throw IllegalArgumentException("Property `aws.sdk.model.metadata` required for independent crate version builds") + val modelMetadataPath = + properties.get("aws.sdk.model.metadata") + ?: throw IllegalArgumentException("Property `aws.sdk.model.metadata` required for independent crate version builds") IndependentCrateVersioner( VersionsManifest.fromFile(versionsManifestPath), ModelMetadata.fromFile(modelMetadataPath), @@ -32,21 +36,28 @@ object CrateVersioner { } interface VersionCrate { - fun decideCrateVersion(moduleName: String, service: AwsService): String + fun decideCrateVersion( + moduleName: String, + service: AwsService, + ): String fun independentVersioningEnabled(): Boolean } class SynchronizedCrateVersioner( properties: PropertyRetriever, - private val sdkVersion: String = properties.get(CrateSet.STABLE_VERSION_PROP_NAME) - ?: throw Exception("SDK runtime crate version missing"), + private val sdkVersion: String = + properties.get(CrateSet.STABLE_VERSION_PROP_NAME) + ?: throw Exception("SDK runtime crate version missing"), ) : VersionCrate { init { LoggerFactory.getLogger(javaClass).info("Using synchronized SDK crate versioning with version `$sdkVersion`") } - override fun decideCrateVersion(moduleName: String, service: AwsService): String = sdkVersion + override fun decideCrateVersion( + moduleName: String, + service: AwsService, + ): String = sdkVersion override fun independentVersioningEnabled(): Boolean = sdkVersion == LOCAL_DEV_VERSION } @@ -73,7 +84,9 @@ private data class SemVer( } fun bumpMajor(): SemVer = copy(major = major + 1, minor = 0, patch = 0) + fun bumpMinor(): SemVer = copy(minor = minor + 1, patch = 0) + fun bumpPatch(): SemVer = copy(patch = patch + 1) override fun toString(): String { @@ -119,32 +132,36 @@ class IndependentCrateVersioner( override fun independentVersioningEnabled(): Boolean = true - override fun decideCrateVersion(moduleName: String, service: AwsService): String { + override fun decideCrateVersion( + moduleName: String, + service: AwsService, + ): String { var previousVersion: SemVer? = null - val (reason, newVersion) = when (val existingCrate = versionsManifest.crates.get(moduleName)) { - // The crate didn't exist before, so create a new major version - null -> "new service" to newMajorVersion() - else -> { - previousVersion = SemVer.parse(existingCrate.version) - if (smithyRsChanged) { - "smithy-rs changed" to previousVersion.bumpCodegenChanged() - } else { - when (modelMetadata.changeType(moduleName)) { - ChangeType.FEATURE -> "its API changed" to previousVersion.bumpModelChanged() - ChangeType.DOCUMENTATION -> "it has new docs" to previousVersion.bumpDocsChanged() - ChangeType.UNCHANGED -> { - val currentModelsHash = hashModelsFn(service) - val previousModelsHash = existingCrate.modelHash - if (currentModelsHash != previousModelsHash) { - "its model(s) changed" to previousVersion.bumpModelChanged() - } else { - "no change" to previousVersion + val (reason, newVersion) = + when (val existingCrate = versionsManifest.crates.get(moduleName)) { + // The crate didn't exist before, so create a new major version + null -> "new service" to newMajorVersion() + else -> { + previousVersion = SemVer.parse(existingCrate.version) + if (smithyRsChanged) { + "smithy-rs changed" to previousVersion.bumpCodegenChanged() + } else { + when (modelMetadata.changeType(moduleName)) { + ChangeType.FEATURE -> "its API changed" to previousVersion.bumpModelChanged() + ChangeType.DOCUMENTATION -> "it has new docs" to previousVersion.bumpDocsChanged() + ChangeType.UNCHANGED -> { + val currentModelsHash = hashModelsFn(service) + val previousModelsHash = existingCrate.modelHash + if (currentModelsHash != previousModelsHash) { + "its model(s) changed" to previousVersion.bumpModelChanged() + } else { + "no change" to previousVersion + } } } } } } - } if (previousVersion == null) { logger.info("`$moduleName` is a new service. Starting it at `$newVersion`") } else if (previousVersion != newVersion) { @@ -155,28 +172,35 @@ class IndependentCrateVersioner( return newVersion.toString() } - private fun newMajorVersion(): SemVer = when (devPreview) { - true -> SemVer.parse("0.1.0") - else -> SemVer.parse("1.0.0") - } + private fun newMajorVersion(): SemVer = + when (devPreview) { + true -> SemVer.parse("0.1.0") + else -> SemVer.parse("1.0.0") + } private fun SemVer.bumpCodegenChanged(): SemVer = bumpMinor() - private fun SemVer.bumpModelChanged(): SemVer = when (devPreview) { - true -> bumpPatch() - else -> bumpMinor() - } + + private fun SemVer.bumpModelChanged(): SemVer = + when (devPreview) { + true -> bumpPatch() + else -> bumpMinor() + } private fun SemVer.bumpDocsChanged(): SemVer = bumpPatch() } private fun ByteArray.toLowerHex(): String = joinToString("") { byte -> "%02x".format(byte) } -fun hashModels(awsService: AwsService, loadFile: (File) -> ByteArray = File::readBytes): String { +fun hashModels( + awsService: AwsService, + loadFile: (File) -> ByteArray = File::readBytes, +): String { // Needs to match hashing done in the `generate-version-manifest` tool: val sha256 = MessageDigest.getInstance("SHA-256") - val hashes = awsService.modelFiles().fold("") { hashes, file -> - val fileHash = sha256.digest(loadFile(file)).toLowerHex() - hashes + fileHash + "\n" - } + val hashes = + awsService.modelFiles().fold("") { hashes, file -> + val fileHash = sha256.digest(loadFile(file)).toLowerHex() + hashes + fileHash + "\n" + } return sha256.digest(hashes.toByteArray(Charsets.UTF_8)).toLowerHex() } diff --git a/buildSrc/src/main/kotlin/aws/sdk/DocsLandingPage.kt b/buildSrc/src/main/kotlin/aws/sdk/DocsLandingPage.kt index 715917c4db2..097c47f8e07 100644 --- a/buildSrc/src/main/kotlin/aws/sdk/DocsLandingPage.kt +++ b/buildSrc/src/main/kotlin/aws/sdk/DocsLandingPage.kt @@ -15,7 +15,10 @@ import java.io.File * The generated docs will include links to crates.io, docs.rs and GitHub examples for all generated services. The generated docs will * be written to `docs.md` in the provided [outputDir]. */ -fun Project.docsLandingPage(awsServices: AwsServices, outputPath: File) { +fun Project.docsLandingPage( + awsServices: AwsServices, + outputPath: File, +) { val project = this val writer = SimpleCodeWriter() with(writer) { @@ -28,7 +31,7 @@ fun Project.docsLandingPage(awsServices: AwsServices, outputPath: File) { writer.write("## AWS Services") writer.write("") // empty line between header and table - /* generate a basic markdown table */ + // generate a basic markdown table writer.write("| Service | Package |") writer.write("| ------- | ------- |") awsServices.services.sortedBy { it.humanName }.forEach { @@ -44,7 +47,10 @@ fun Project.docsLandingPage(awsServices: AwsServices, outputPath: File) { /** * Generate a link to the examples for a given service */ -private fun examplesLink(service: AwsService, project: Project) = service.examplesUri(project)?.let { +private fun examplesLink( + service: AwsService, + project: Project, +) = service.examplesUri(project)?.let { "([examples]($it))" } @@ -52,6 +58,9 @@ private fun examplesLink(service: AwsService, project: Project) = service.exampl * Generate a link to the docs */ private fun docsRs(service: AwsService) = docsRs(service.crate()) + private fun docsRs(crate: String) = "([docs](https://docs.rs/$crate))" + private fun cratesIo(service: AwsService) = cratesIo(service.crate()) + private fun cratesIo(crate: String) = "[$crate](https://crates.io/crates/$crate)" diff --git a/buildSrc/src/main/kotlin/aws/sdk/ModelMetadata.kt b/buildSrc/src/main/kotlin/aws/sdk/ModelMetadata.kt index 95d65fd106d..90a6337bfcb 100644 --- a/buildSrc/src/main/kotlin/aws/sdk/ModelMetadata.kt +++ b/buildSrc/src/main/kotlin/aws/sdk/ModelMetadata.kt @@ -27,17 +27,20 @@ data class ModelMetadata( fun fromString(value: String): ModelMetadata { val toml = Toml().read(value) return ModelMetadata( - crates = toml.getTable("crates")?.entrySet()?.map { entry -> - entry.key to when (val kind = (entry.value as Toml).getString("kind")) { - "Feature" -> ChangeType.FEATURE - "Documentation" -> ChangeType.DOCUMENTATION - else -> throw IllegalArgumentException("Unrecognized change type: $kind") - } - }?.toMap() ?: emptyMap(), + crates = + toml.getTable("crates")?.entrySet()?.map { entry -> + entry.key to + when (val kind = (entry.value as Toml).getString("kind")) { + "Feature" -> ChangeType.FEATURE + "Documentation" -> ChangeType.DOCUMENTATION + else -> throw IllegalArgumentException("Unrecognized change type: $kind") + } + }?.toMap() ?: emptyMap(), ) } } fun hasCrates(): Boolean = crates.isNotEmpty() + fun changeType(moduleName: String): ChangeType = crates[moduleName] ?: ChangeType.UNCHANGED } diff --git a/buildSrc/src/main/kotlin/aws/sdk/ServiceLoader.kt b/buildSrc/src/main/kotlin/aws/sdk/ServiceLoader.kt index f31211f4db0..52113f6bc05 100644 --- a/buildSrc/src/main/kotlin/aws/sdk/ServiceLoader.kt +++ b/buildSrc/src/main/kotlin/aws/sdk/ServiceLoader.kt @@ -40,7 +40,7 @@ class AwsServices( // Root tests should not be included since they can't be part of the root Cargo workspace // in order to test differences in Cargo features. Examples should not be included either // because each example itself is a workspace. - ).toSortedSet() + ).toSortedSet() } val examples: List by lazy { @@ -68,11 +68,12 @@ class AwsServices( private fun manifestCompatibleWithGeneratedServices(path: File) = File(path, "Cargo.toml").let { cargoToml -> if (cargoToml.exists()) { - val usedModules = cargoToml.readLines() - .map { line -> line.substringBefore('=').trim() } - .filter { line -> line.startsWith("aws-sdk-") } - .map { line -> line.substringAfter("aws-sdk-") } - .toSet() + val usedModules = + cargoToml.readLines() + .map { line -> line.substringBefore('=').trim() } + .filter { line -> line.startsWith("aws-sdk-") } + .map { line -> line.substringAfter("aws-sdk-") } + .toSet() moduleNames.containsAll(usedModules) } else { false @@ -96,47 +97,53 @@ class AwsServices( * Since this function parses all models, it is relatively expensive to call. The result should be cached in a property * during build. */ -fun Project.discoverServices(awsModelsPath: String?, serviceMembership: Membership): AwsServices { +fun Project.discoverServices( + awsModelsPath: String?, + serviceMembership: Membership, +): AwsServices { val models = awsModelsPath?.let { File(it) } ?: project.file("aws-models") logger.info("Using model path: $models") - val baseServices = fileTree(models) - .sortedBy { file -> file.name } - .mapNotNull { file -> - val model = Model.assembler().addImport(file.absolutePath).assemble().result.get() - val services: List = model.shapes(ServiceShape::class.java).sorted().toList() - if (services.size > 1) { - throw Exception("There must be exactly one service in each aws model file") - } - if (services.isEmpty()) { - logger.info("${file.name} has no services") - null - } else { - val service = services[0] - val title = service.expectTrait(TitleTrait::class.java).value - val sdkId = service.expectTrait(ServiceTrait::class.java).sdkId - .toLowerCase() - .replace(" ", "") - // The smithy models should not include the suffix "service" but currently they do - .removeSuffix("service") - .removeSuffix("api") - val testFile = file.parentFile.resolve("$sdkId-tests.smithy") - val extras = if (testFile.exists()) { - logger.warn("Discovered protocol tests for ${file.name}") - listOf(testFile) + val baseServices = + fileTree(models) + .sortedBy { file -> file.name } + .mapNotNull { file -> + val model = Model.assembler().addImport(file.absolutePath).assemble().result.get() + val services: List = model.shapes(ServiceShape::class.java).sorted().toList() + if (services.size > 1) { + throw Exception("There must be exactly one service in each aws model file") + } + if (services.isEmpty()) { + logger.info("${file.name} has no services") + null } else { - listOf() + val service = services[0] + val title = service.expectTrait(TitleTrait::class.java).value + val sdkId = + service.expectTrait(ServiceTrait::class.java).sdkId + .toLowerCase() + .replace(" ", "") + // The smithy models should not include the suffix "service" but currently they do + .removeSuffix("service") + .removeSuffix("api") + val testFile = file.parentFile.resolve("$sdkId-tests.smithy") + val extras = + if (testFile.exists()) { + logger.warn("Discovered protocol tests for ${file.name}") + listOf(testFile) + } else { + listOf() + } + AwsService( + service = service.id.toString(), + module = sdkId, + moduleDescription = "AWS SDK for $title", + modelFile = file, + // Order is important for the versions.toml model hash calculation + extraFiles = extras.sorted(), + humanName = title, + ) } - AwsService( - service = service.id.toString(), - module = sdkId, - moduleDescription = "AWS SDK for $title", - modelFile = file, - // Order is important for the versions.toml model hash calculation - extraFiles = extras.sorted(), - humanName = title, - ) } - } val baseModules = baseServices.map { it.module }.toSet() logger.info("Discovered base service modules to generate: $baseModules") @@ -182,26 +189,29 @@ data class AwsService( val humanName: String, ) { fun modelFiles(): List = listOf(modelFile) + extraFiles + fun Project.examples(): File = projectDir.resolve("examples").resolve(module) /** * Generate a link to the examples for a given service */ - fun examplesUri(project: Project) = if (project.examples().exists()) { - "https://github.com/awslabs/aws-sdk-rust/tree/main/examples/$module" - } else { - null - } + fun examplesUri(project: Project) = + if (project.examples().exists()) { + "https://github.com/awslabs/aws-sdk-rust/tree/main/examples/$module" + } else { + null + } } fun AwsService.crate(): String = "aws-sdk-$module" -private fun Membership.isMember(member: String): Boolean = when { - exclusions.contains(member) -> false - inclusions.contains(member) -> true - inclusions.isEmpty() -> true - else -> false -} +private fun Membership.isMember(member: String): Boolean = + when { + exclusions.contains(member) -> false + inclusions.contains(member) -> true + inclusions.isEmpty() -> true + else -> false + } fun parseMembership(rawList: String): Membership { val inclusions = mutableSetOf() diff --git a/buildSrc/src/main/kotlin/aws/sdk/VersionsManifest.kt b/buildSrc/src/main/kotlin/aws/sdk/VersionsManifest.kt index 4ff726ab98d..01f97867c1d 100644 --- a/buildSrc/src/main/kotlin/aws/sdk/VersionsManifest.kt +++ b/buildSrc/src/main/kotlin/aws/sdk/VersionsManifest.kt @@ -32,15 +32,17 @@ data class VersionsManifest( return VersionsManifest( smithyRsRevision = toml.getString("smithy_rs_revision"), awsDocSdkExamplesRevision = toml.getString("aws_doc_sdk_examples_revision"), - crates = toml.getTable("crates").entrySet().map { entry -> - val crate = (entry.value as Toml) - entry.key to CrateVersion( - category = crate.getString("category"), - version = crate.getString("version"), - sourceHash = crate.getString("source_hash"), - modelHash = crate.getString("model_hash"), - ) - }.toMap(), + crates = + toml.getTable("crates").entrySet().map { entry -> + val crate = (entry.value as Toml) + entry.key to + CrateVersion( + category = crate.getString("category"), + version = crate.getString("version"), + sourceHash = crate.getString("source_hash"), + modelHash = crate.getString("model_hash"), + ) + }.toMap(), ) } } diff --git a/buildSrc/src/test/kotlin/CrateSetTest.kt b/buildSrc/src/test/kotlin/CrateSetTest.kt index dacc75cc329..53b1340db2f 100644 --- a/buildSrc/src/test/kotlin/CrateSetTest.kt +++ b/buildSrc/src/test/kotlin/CrateSetTest.kt @@ -22,7 +22,11 @@ class CrateSetTest { * matches what `package.metadata.smithy-rs-release-tooling` says in the `Cargo.toml` * for the corresponding crate. */ - private fun sutStabilityMatchesManifestStability(versionPropertyName: String, stabilityInManifest: Boolean, crate: String) { + private fun sutStabilityMatchesManifestStability( + versionPropertyName: String, + stabilityInManifest: Boolean, + crate: String, + ) { when (stabilityInManifest) { true -> assertEquals(STABLE_VERSION_PROP_NAME, versionPropertyName, "Crate: $crate") false -> assertEquals(UNSTABLE_VERSION_PROP_NAME, versionPropertyName, "Crate: $crate") @@ -37,16 +41,20 @@ class CrateSetTest { * If `package.metadata.smithy-rs-release-tooling` does not exist in a `Cargo.toml`, the implementation * will treat that crate as unstable. */ - private fun crateSetStabilitiesMatchManifestStabilities(crateSet: List, relativePathToRustRuntime: String) { + private fun crateSetStabilitiesMatchManifestStabilities( + crateSet: List, + relativePathToRustRuntime: String, + ) { crateSet.forEach { val path = "$relativePathToRustRuntime/${it.name}/Cargo.toml" val contents = File(path).readText() - val isStable = try { - Toml().read(contents).getTable("package.metadata.smithy-rs-release-tooling")?.getBoolean("stable") ?: false - } catch (e: java.lang.IllegalStateException) { - // sigv4 doesn't parse but it's stable now, hax hax hax - contents.trim().endsWith("[package.metadata.smithy-rs-release-tooling]\nstable = true") - } + val isStable = + try { + Toml().read(contents).getTable("package.metadata.smithy-rs-release-tooling")?.getBoolean("stable") ?: false + } catch (e: java.lang.IllegalStateException) { + // sigv4 doesn't parse but it's stable now, hax hax hax + contents.trim().endsWith("[package.metadata.smithy-rs-release-tooling]\nstable = true") + } sutStabilityMatchesManifestStability(it.versionPropertyName, isStable, it.name) } } diff --git a/buildSrc/src/test/kotlin/aws/sdk/IndependentCrateVersionerTest.kt b/buildSrc/src/test/kotlin/aws/sdk/IndependentCrateVersionerTest.kt index f4057c63144..0ec9f8051ae 100644 --- a/buildSrc/src/test/kotlin/aws/sdk/IndependentCrateVersionerTest.kt +++ b/buildSrc/src/test/kotlin/aws/sdk/IndependentCrateVersionerTest.kt @@ -9,15 +9,16 @@ import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test import java.io.File -private fun service(name: String): AwsService = AwsService( - name, - "test", - "test", - File("testmodel"), - null, - emptyList(), - name, -) +private fun service(name: String): AwsService = + AwsService( + name, + "test", + "test", + File("testmodel"), + null, + emptyList(), + name, + ) class IndependentCrateVersionerTest { @Test @@ -27,45 +28,51 @@ class IndependentCrateVersionerTest { val s3 = service("s3") val someNewService = service("somenewservice") - val versioner = IndependentCrateVersioner( - VersionsManifest( - smithyRsRevision = "smithy-rs-1", - awsDocSdkExamplesRevision = "dontcare", - crates = mapOf( - "aws-sdk-dynamodb" to CrateVersion( - category = "AwsSdk", - version = "0.11.3", - modelHash = "dynamodb-hash", - ), - "aws-sdk-ec2" to CrateVersion( - category = "AwsSdk", - version = "0.10.1", - modelHash = "ec2-hash", - ), - "aws-sdk-s3" to CrateVersion( - category = "AwsSdk", - version = "0.12.0", - modelHash = "s3-hash", - ), + val versioner = + IndependentCrateVersioner( + VersionsManifest( + smithyRsRevision = "smithy-rs-1", + awsDocSdkExamplesRevision = "dontcare", + crates = + mapOf( + "aws-sdk-dynamodb" to + CrateVersion( + category = "AwsSdk", + version = "0.11.3", + modelHash = "dynamodb-hash", + ), + "aws-sdk-ec2" to + CrateVersion( + category = "AwsSdk", + version = "0.10.1", + modelHash = "ec2-hash", + ), + "aws-sdk-s3" to + CrateVersion( + category = "AwsSdk", + version = "0.12.0", + modelHash = "s3-hash", + ), + ), ), - ), - ModelMetadata( - crates = mapOf( - "aws-sdk-dynamodb" to ChangeType.FEATURE, - "aws-sdk-ec2" to ChangeType.DOCUMENTATION, + ModelMetadata( + crates = + mapOf( + "aws-sdk-dynamodb" to ChangeType.FEATURE, + "aws-sdk-ec2" to ChangeType.DOCUMENTATION, + ), ), - ), - devPreview = true, - smithyRsVersion = "smithy-rs-2", - hashModelsFn = { service -> - when (service) { - dynamoDb -> "dynamodb-hash" - ec2 -> "ec2-hash" - s3 -> "s3-hash" - else -> throw IllegalStateException("unreachable") - } - }, - ) + devPreview = true, + smithyRsVersion = "smithy-rs-2", + hashModelsFn = { service -> + when (service) { + dynamoDb -> "dynamodb-hash" + ec2 -> "ec2-hash" + s3 -> "s3-hash" + else -> throw IllegalStateException("unreachable") + } + }, + ) // The code generator changed, so all minor versions should bump assertEquals("0.12.0", versioner.decideCrateVersion("aws-sdk-dynamodb", dynamoDb)) @@ -82,52 +89,59 @@ class IndependentCrateVersionerTest { val s3 = service("s3") val someNewService = service("somenewservice") - val versioner = IndependentCrateVersioner( - VersionsManifest( - smithyRsRevision = "smithy-rs-1", - awsDocSdkExamplesRevision = "dontcare", - crates = mapOf( - "aws-sdk-dynamodb" to CrateVersion( - category = "AwsSdk", - version = "0.11.3", - modelHash = "dynamodb-hash", - ), - "aws-sdk-ec2" to CrateVersion( - category = "AwsSdk", - version = "0.10.1", - modelHash = "ec2-hash", - ), - "aws-sdk-polly" to CrateVersion( - category = "AwsSdk", - version = "0.9.0", - modelHash = "old-polly-hash", - ), - "aws-sdk-s3" to CrateVersion( - category = "AwsSdk", - version = "0.12.0", - modelHash = "s3-hash", - ), + val versioner = + IndependentCrateVersioner( + VersionsManifest( + smithyRsRevision = "smithy-rs-1", + awsDocSdkExamplesRevision = "dontcare", + crates = + mapOf( + "aws-sdk-dynamodb" to + CrateVersion( + category = "AwsSdk", + version = "0.11.3", + modelHash = "dynamodb-hash", + ), + "aws-sdk-ec2" to + CrateVersion( + category = "AwsSdk", + version = "0.10.1", + modelHash = "ec2-hash", + ), + "aws-sdk-polly" to + CrateVersion( + category = "AwsSdk", + version = "0.9.0", + modelHash = "old-polly-hash", + ), + "aws-sdk-s3" to + CrateVersion( + category = "AwsSdk", + version = "0.12.0", + modelHash = "s3-hash", + ), + ), ), - ), - ModelMetadata( - crates = mapOf( - "aws-sdk-dynamodb" to ChangeType.FEATURE, - "aws-sdk-ec2" to ChangeType.DOCUMENTATION, - // polly has a model change, but is absent from the model metadata file + ModelMetadata( + crates = + mapOf( + "aws-sdk-dynamodb" to ChangeType.FEATURE, + "aws-sdk-ec2" to ChangeType.DOCUMENTATION, + // polly has a model change, but is absent from the model metadata file + ), ), - ), - devPreview = true, - smithyRsVersion = "smithy-rs-1", - hashModelsFn = { service -> - when (service) { - dynamoDb -> "dynamodb-hash" - ec2 -> "ec2-hash" - polly -> "NEW-polly-hash" - s3 -> "s3-hash" - else -> throw IllegalStateException("unreachable") - } - }, - ) + devPreview = true, + smithyRsVersion = "smithy-rs-1", + hashModelsFn = { service -> + when (service) { + dynamoDb -> "dynamodb-hash" + ec2 -> "ec2-hash" + polly -> "NEW-polly-hash" + s3 -> "s3-hash" + else -> throw IllegalStateException("unreachable") + } + }, + ) assertEquals("0.11.4", versioner.decideCrateVersion("aws-sdk-dynamodb", dynamoDb)) assertEquals("0.10.2", versioner.decideCrateVersion("aws-sdk-ec2", ec2)) @@ -143,34 +157,40 @@ class IndependentCrateVersionerTest { val s3 = service("s3") val someNewService = service("somenewservice") - val versioner = IndependentCrateVersioner( - VersionsManifest( - smithyRsRevision = "smithy-rs-1", - awsDocSdkExamplesRevision = "dontcare", - crates = mapOf( - "aws-sdk-dynamodb" to CrateVersion( - category = "AwsSdk", - version = "1.11.3", - ), - "aws-sdk-ec2" to CrateVersion( - category = "AwsSdk", - version = "1.10.1", - ), - "aws-sdk-s3" to CrateVersion( - category = "AwsSdk", - version = "1.12.0", - ), + val versioner = + IndependentCrateVersioner( + VersionsManifest( + smithyRsRevision = "smithy-rs-1", + awsDocSdkExamplesRevision = "dontcare", + crates = + mapOf( + "aws-sdk-dynamodb" to + CrateVersion( + category = "AwsSdk", + version = "1.11.3", + ), + "aws-sdk-ec2" to + CrateVersion( + category = "AwsSdk", + version = "1.10.1", + ), + "aws-sdk-s3" to + CrateVersion( + category = "AwsSdk", + version = "1.12.0", + ), + ), ), - ), - ModelMetadata( - crates = mapOf( - "aws-sdk-dynamodb" to ChangeType.FEATURE, - "aws-sdk-ec2" to ChangeType.DOCUMENTATION, + ModelMetadata( + crates = + mapOf( + "aws-sdk-dynamodb" to ChangeType.FEATURE, + "aws-sdk-ec2" to ChangeType.DOCUMENTATION, + ), ), - ), - devPreview = false, - smithyRsVersion = "smithy-rs-2", - ) + devPreview = false, + smithyRsVersion = "smithy-rs-2", + ) // The code generator changed, so all minor versions should bump assertEquals("1.12.0", versioner.decideCrateVersion("aws-sdk-dynamodb", dynamoDb)) @@ -187,52 +207,59 @@ class IndependentCrateVersionerTest { val s3 = service("s3") val someNewService = service("somenewservice") - val versioner = IndependentCrateVersioner( - VersionsManifest( - smithyRsRevision = "smithy-rs-1", - awsDocSdkExamplesRevision = "dontcare", - crates = mapOf( - "aws-sdk-dynamodb" to CrateVersion( - category = "AwsSdk", - version = "1.11.3", - modelHash = "dynamodb-hash", - ), - "aws-sdk-ec2" to CrateVersion( - category = "AwsSdk", - version = "1.10.1", - modelHash = "ec2-hash", - ), - "aws-sdk-polly" to CrateVersion( - category = "AwsSdk", - version = "1.9.0", - modelHash = "old-polly-hash", - ), - "aws-sdk-s3" to CrateVersion( - category = "AwsSdk", - version = "1.12.0", - modelHash = "s3-hash", - ), + val versioner = + IndependentCrateVersioner( + VersionsManifest( + smithyRsRevision = "smithy-rs-1", + awsDocSdkExamplesRevision = "dontcare", + crates = + mapOf( + "aws-sdk-dynamodb" to + CrateVersion( + category = "AwsSdk", + version = "1.11.3", + modelHash = "dynamodb-hash", + ), + "aws-sdk-ec2" to + CrateVersion( + category = "AwsSdk", + version = "1.10.1", + modelHash = "ec2-hash", + ), + "aws-sdk-polly" to + CrateVersion( + category = "AwsSdk", + version = "1.9.0", + modelHash = "old-polly-hash", + ), + "aws-sdk-s3" to + CrateVersion( + category = "AwsSdk", + version = "1.12.0", + modelHash = "s3-hash", + ), + ), ), - ), - ModelMetadata( - crates = mapOf( - "aws-sdk-dynamodb" to ChangeType.FEATURE, - "aws-sdk-ec2" to ChangeType.DOCUMENTATION, - // polly has a model change, but is absent from the model metadata file + ModelMetadata( + crates = + mapOf( + "aws-sdk-dynamodb" to ChangeType.FEATURE, + "aws-sdk-ec2" to ChangeType.DOCUMENTATION, + // polly has a model change, but is absent from the model metadata file + ), ), - ), - devPreview = false, - smithyRsVersion = "smithy-rs-1", - hashModelsFn = { service -> - when (service) { - dynamoDb -> "dynamodb-hash" - ec2 -> "ec2-hash" - polly -> "NEW-polly-hash" - s3 -> "s3-hash" - else -> throw IllegalStateException("unreachable") - } - }, - ) + devPreview = false, + smithyRsVersion = "smithy-rs-1", + hashModelsFn = { service -> + when (service) { + dynamoDb -> "dynamodb-hash" + ec2 -> "ec2-hash" + polly -> "NEW-polly-hash" + s3 -> "s3-hash" + else -> throw IllegalStateException("unreachable") + } + }, + ) assertEquals("1.12.0", versioner.decideCrateVersion("aws-sdk-dynamodb", dynamoDb)) assertEquals("1.10.2", versioner.decideCrateVersion("aws-sdk-ec2", ec2)) @@ -245,17 +272,19 @@ class IndependentCrateVersionerTest { class HashModelsTest { @Test fun testHashModels() { - val service = service("test").copy( - modelFile = File("model1a"), - extraFiles = listOf(File("model1b")), - ) - val hash = hashModels(service) { file -> - when (file.toString()) { - "model1a" -> "foo".toByteArray(Charsets.UTF_8) - "model1b" -> "bar".toByteArray(Charsets.UTF_8) - else -> throw IllegalStateException("unreachable") + val service = + service("test").copy( + modelFile = File("model1a"), + extraFiles = listOf(File("model1b")), + ) + val hash = + hashModels(service) { file -> + when (file.toString()) { + "model1a" -> "foo".toByteArray(Charsets.UTF_8) + "model1b" -> "bar".toByteArray(Charsets.UTF_8) + else -> throw IllegalStateException("unreachable") + } } - } assertEquals("964021077fb6c3d42ae162ab2e2255be64c6d96a6d77bca089569774d54ef69b", hash) } } diff --git a/buildSrc/src/test/kotlin/aws/sdk/ModelMetadataTest.kt b/buildSrc/src/test/kotlin/aws/sdk/ModelMetadataTest.kt index d157ed0adc2..dba74d5a90e 100644 --- a/buildSrc/src/test/kotlin/aws/sdk/ModelMetadataTest.kt +++ b/buildSrc/src/test/kotlin/aws/sdk/ModelMetadataTest.kt @@ -18,13 +18,14 @@ class ModelMetadataTest { @Test fun `it should parse`() { - val contents = """ + val contents = + """ [crates.aws-sdk-someservice] kind = "Feature" [crates.aws-sdk-s3] kind = "Documentation" - """.trimIndent() + """.trimIndent() val result = ModelMetadata.fromString(contents) assertEquals(ChangeType.FEATURE, result.changeType("aws-sdk-someservice")) diff --git a/buildSrc/src/test/kotlin/aws/sdk/VersionsManifestTest.kt b/buildSrc/src/test/kotlin/aws/sdk/VersionsManifestTest.kt index 9ca9c954366..1aad2e53554 100644 --- a/buildSrc/src/test/kotlin/aws/sdk/VersionsManifestTest.kt +++ b/buildSrc/src/test/kotlin/aws/sdk/VersionsManifestTest.kt @@ -11,39 +11,42 @@ import org.junit.jupiter.api.Test class VersionsManifestTest { @Test fun `it should parse versions toml`() { - val manifest = VersionsManifest.fromString( - """ - smithy_rs_revision = 'some-smithy-rs-revision' - aws_doc_sdk_examples_revision = 'some-doc-revision' + val manifest = + VersionsManifest.fromString( + """ + smithy_rs_revision = 'some-smithy-rs-revision' + aws_doc_sdk_examples_revision = 'some-doc-revision' - [crates.aws-config] - category = 'AwsRuntime' - version = '0.12.0' - source_hash = '12d172094a2576e6f4d00a8ba58276c0d4abc4e241bb75f0d3de8ac3412e8e47' + [crates.aws-config] + category = 'AwsRuntime' + version = '0.12.0' + source_hash = '12d172094a2576e6f4d00a8ba58276c0d4abc4e241bb75f0d3de8ac3412e8e47' - [crates.aws-sdk-account] - category = 'AwsSdk' - version = '0.12.0' - source_hash = 'a0dfc080638b1d803745f0bd66b610131783cf40ab88fd710dce906fc69b983e' - model_hash = '179bbfd915093dc3bec5444771da2b20d99a37d104ba25f0acac9aa0d5bb758a' - """.trimIndent(), - ) + [crates.aws-sdk-account] + category = 'AwsSdk' + version = '0.12.0' + source_hash = 'a0dfc080638b1d803745f0bd66b610131783cf40ab88fd710dce906fc69b983e' + model_hash = '179bbfd915093dc3bec5444771da2b20d99a37d104ba25f0acac9aa0d5bb758a' + """.trimIndent(), + ) assertEquals("some-smithy-rs-revision", manifest.smithyRsRevision) assertEquals("some-doc-revision", manifest.awsDocSdkExamplesRevision) assertEquals( mapOf( - "aws-config" to CrateVersion( - category = "AwsRuntime", - version = "0.12.0", - sourceHash = "12d172094a2576e6f4d00a8ba58276c0d4abc4e241bb75f0d3de8ac3412e8e47", - ), - "aws-sdk-account" to CrateVersion( - category = "AwsSdk", - version = "0.12.0", - sourceHash = "a0dfc080638b1d803745f0bd66b610131783cf40ab88fd710dce906fc69b983e", - modelHash = "179bbfd915093dc3bec5444771da2b20d99a37d104ba25f0acac9aa0d5bb758a", - ), + "aws-config" to + CrateVersion( + category = "AwsRuntime", + version = "0.12.0", + sourceHash = "12d172094a2576e6f4d00a8ba58276c0d4abc4e241bb75f0d3de8ac3412e8e47", + ), + "aws-sdk-account" to + CrateVersion( + category = "AwsSdk", + version = "0.12.0", + sourceHash = "a0dfc080638b1d803745f0bd66b610131783cf40ab88fd710dce906fc69b983e", + modelHash = "179bbfd915093dc3bec5444771da2b20d99a37d104ba25f0acac9aa0d5bb758a", + ), ), manifest.crates, ) diff --git a/codegen-client/build.gradle.kts b/codegen-client/build.gradle.kts index baed9a9b914..e7dfc0f0958 100644 --- a/codegen-client/build.gradle.kts +++ b/codegen-client/build.gradle.kts @@ -35,7 +35,7 @@ dependencies { } tasks.compileKotlin { - kotlinOptions.jvmTarget = "1.8" + kotlinOptions.jvmTarget = "11" } // Reusable license copySpec @@ -71,7 +71,7 @@ if (isTestingEnabled.toBoolean()) { } tasks.compileTestKotlin { - kotlinOptions.jvmTarget = "1.8" + kotlinOptions.jvmTarget = "11" } tasks.test { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenContext.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenContext.kt index 64b3a2b2b36..60e6847be37 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenContext.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenContext.kt @@ -35,9 +35,10 @@ data class ClientCodegenContext( val rootDecorator: ClientCodegenDecorator, val protocolImpl: Protocol? = null, ) : CodegenContext( - model, symbolProvider, moduleDocProvider, serviceShape, protocol, settings, CodegenTarget.CLIENT, -) { + model, symbolProvider, moduleDocProvider, serviceShape, protocol, settings, CodegenTarget.CLIENT, + ) { val enableUserConfigurableRuntimePlugins: Boolean get() = settings.codegenConfig.enableUserConfigurableRuntimePlugins + override fun builderInstantiator(): BuilderInstantiator { return ClientBuilderInstantiator(this) } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt index ab26c7ae9ec..d057b7e6f56 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt @@ -73,49 +73,55 @@ class ClientCodegenVisitor( private val operationGenerator: OperationGenerator init { - val rustSymbolProviderConfig = RustSymbolProviderConfig( - runtimeConfig = settings.runtimeConfig, - renameExceptions = settings.codegenConfig.renameExceptions, - nullabilityCheckMode = settings.codegenConfig.nullabilityCheckMode, - moduleProvider = ClientModuleProvider, - nameBuilderFor = { symbol -> "${symbol.name}Builder" }, - ) + val rustSymbolProviderConfig = + RustSymbolProviderConfig( + runtimeConfig = settings.runtimeConfig, + renameExceptions = settings.codegenConfig.renameExceptions, + nullabilityCheckMode = settings.codegenConfig.nullabilityCheckMode, + moduleProvider = ClientModuleProvider, + nameBuilderFor = { symbol -> "${symbol.name}Builder" }, + ) val baseModel = baselineTransform(context.model) val untransformedService = settings.getService(baseModel) - val (protocol, generator) = ClientProtocolLoader( - codegenDecorator.protocols(untransformedService.id, ClientProtocolLoader.DefaultProtocols), - ).protocolFor(context.model, untransformedService) + val (protocol, generator) = + ClientProtocolLoader( + codegenDecorator.protocols(untransformedService.id, ClientProtocolLoader.DefaultProtocols), + ).protocolFor(context.model, untransformedService) protocolGeneratorFactory = generator model = codegenDecorator.transformModel(untransformedService, baseModel, settings) // the model transformer _might_ change the service shape val service = settings.getService(model) symbolProvider = RustClientCodegenPlugin.baseSymbolProvider(settings, model, service, rustSymbolProviderConfig, codegenDecorator) - codegenContext = ClientCodegenContext( - model, - symbolProvider, - null, - service, - protocol, - settings, - codegenDecorator, - ) + codegenContext = + ClientCodegenContext( + model, + symbolProvider, + null, + service, + protocol, + settings, + codegenDecorator, + ) - codegenContext = codegenContext.copy( - moduleDocProvider = codegenDecorator.moduleDocumentationCustomization( - codegenContext, - ClientModuleDocProvider(codegenContext, service.serviceNameOrDefault("the service")), - ), - protocolImpl = protocolGeneratorFactory.protocol(codegenContext), - ) + codegenContext = + codegenContext.copy( + moduleDocProvider = + codegenDecorator.moduleDocumentationCustomization( + codegenContext, + ClientModuleDocProvider(codegenContext, service.serviceNameOrDefault("the service")), + ), + protocolImpl = protocolGeneratorFactory.protocol(codegenContext), + ) - rustCrate = RustCrate( - context.fileManifest, - symbolProvider, - codegenContext.settings.codegenConfig, - codegenContext.expectModuleDocProvider(), - ) + rustCrate = + RustCrate( + context.fileManifest, + symbolProvider, + codegenContext.settings.codegenConfig, + codegenContext.expectModuleDocProvider(), + ) operationGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext) } @@ -214,44 +220,46 @@ class ClientCodegenVisitor( * This function _does not_ generate any serializers */ override fun structureShape(shape: StructureShape) { - val (renderStruct, renderBuilder) = when (val errorTrait = shape.getTrait()) { - null -> { - val struct: Writable = { - StructureGenerator( - model, - symbolProvider, - this, - shape, - codegenDecorator.structureCustomizations(codegenContext, emptyList()), - structSettings = codegenContext.structSettings(), - ).render() + val (renderStruct, renderBuilder) = + when (val errorTrait = shape.getTrait()) { + null -> { + val struct: Writable = { + StructureGenerator( + model, + symbolProvider, + this, + shape, + codegenDecorator.structureCustomizations(codegenContext, emptyList()), + structSettings = codegenContext.structSettings(), + ).render() - implBlock(symbolProvider.toSymbol(shape)) { - BuilderGenerator.renderConvenienceMethod(this, symbolProvider, shape) + implBlock(symbolProvider.toSymbol(shape)) { + BuilderGenerator.renderConvenienceMethod(this, symbolProvider, shape) + } } + val builder: Writable = { + BuilderGenerator( + codegenContext.model, + codegenContext.symbolProvider, + shape, + codegenDecorator.builderCustomizations(codegenContext, emptyList()), + ).render(this) + } + struct to builder } - val builder: Writable = { - BuilderGenerator( - codegenContext.model, - codegenContext.symbolProvider, - shape, - codegenDecorator.builderCustomizations(codegenContext, emptyList()), - ).render(this) + else -> { + val errorGenerator = + ErrorGenerator( + model, + symbolProvider, + shape, + errorTrait, + codegenDecorator.errorImplCustomizations(codegenContext, emptyList()), + codegenContext.structSettings(), + ) + errorGenerator::renderStruct to errorGenerator::renderBuilder } - struct to builder } - else -> { - val errorGenerator = ErrorGenerator( - model, - symbolProvider, - shape, - errorTrait, - codegenDecorator.errorImplCustomizations(codegenContext, emptyList()), - codegenContext.structSettings(), - ) - errorGenerator::renderStruct to errorGenerator::renderBuilder - } - } val privateModule = privateModule(shape) rustCrate.inPrivateModuleWithReexport(privateModule, symbolProvider.toSymbol(shape)) { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientReExports.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientReExports.kt index 6aaf638680c..e9154df9ff1 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientReExports.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientReExports.kt @@ -12,6 +12,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType * Although it is not always possible to use this, this is the preferred method for using types in config customizations * and ensures that your type will be re-exported if it is used. */ -fun configReexport(type: RuntimeType): RuntimeType = RuntimeType.forInlineFun(type.name, module = ClientRustModule.config) { - rustTemplate("pub use #{type};", "type" to type) -} +fun configReexport(type: RuntimeType): RuntimeType = + RuntimeType.forInlineFun(type.name, module = ClientRustModule.config) { + rustTemplate("pub use #{type};", "type" to type) + } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientReservedWords.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientReservedWords.kt index add9f53f6df..e2ff13dd4f4 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientReservedWords.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientReservedWords.kt @@ -9,27 +9,31 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordConfig import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator -val ClientReservedWords = RustReservedWordConfig( - structureMemberMap = StructureGenerator.structureMemberNameMap + - mapOf( - "send" to "send_value", - // To avoid conflicts with the `make_operation` and `presigned` functions on generated inputs - "make_operation" to "make_operation_value", - "presigned" to "presigned_value", - "customize" to "customize_value", - // To avoid conflicts with the error metadata `meta` field - "meta" to "meta_value", - ), - unionMemberMap = mapOf( - // Unions contain an `Unknown` variant. This exists to support parsing data returned from the server - // that represent union variants that have been added since this SDK was generated. - UnionGenerator.UnknownVariantName to "${UnionGenerator.UnknownVariantName}Value", - "${UnionGenerator.UnknownVariantName}Value" to "${UnionGenerator.UnknownVariantName}Value_", - ), - enumMemberMap = mapOf( - // Unknown is used as the name of the variant containing unexpected values - "Unknown" to "UnknownValue", - // Real models won't end in `_` so it's safe to stop here - "UnknownValue" to "UnknownValue_", - ), -) +val ClientReservedWords = + RustReservedWordConfig( + structureMemberMap = + StructureGenerator.structureMemberNameMap + + mapOf( + "send" to "send_value", + // To avoid conflicts with the `make_operation` and `presigned` functions on generated inputs + "make_operation" to "make_operation_value", + "presigned" to "presigned_value", + "customize" to "customize_value", + // To avoid conflicts with the error metadata `meta` field + "meta" to "meta_value", + ), + unionMemberMap = + mapOf( + // Unions contain an `Unknown` variant. This exists to support parsing data returned from the server + // that represent union variants that have been added since this SDK was generated. + UnionGenerator.UnknownVariantName to "${UnionGenerator.UnknownVariantName}Value", + "${UnionGenerator.UnknownVariantName}Value" to "${UnionGenerator.UnknownVariantName}Value_", + ), + enumMemberMap = + mapOf( + // Unknown is used as the name of the variant containing unexpected values + "Unknown" to "UnknownValue", + // Real models won't end in `_` so it's safe to stop here + "UnknownValue" to "UnknownValue_", + ), + ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt index 0216b2618cc..9d33c1b1c58 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt @@ -46,6 +46,7 @@ object ClientRustModule { /** crate::client */ val client = Client.self + object Client { /** crate::client */ val self = RustModule.public("client") @@ -56,6 +57,7 @@ object ClientRustModule { /** crate::config */ val config = Config.self + object Config { /** crate::client */ val self = RustModule.public("config") @@ -81,6 +83,7 @@ object ClientRustModule { /** crate::primitives */ val primitives = Primitives.self + object Primitives { /** crate::primitives */ val self = RustModule.public("primitives") @@ -91,6 +94,7 @@ object ClientRustModule { /** crate::types */ val types = Types.self + object Types { /** crate::types */ val self = RustModule.public("types") @@ -127,71 +131,80 @@ class ClientModuleDocProvider( } } - private fun clientModuleDoc(): Writable = writable { - val genericClientConstructionDocs = FluentClientDocs.clientConstructionDocs(codegenContext) - val writeClientConstructionDocs = codegenContext.rootDecorator - .clientConstructionDocs(codegenContext, genericClientConstructionDocs) + private fun clientModuleDoc(): Writable = + writable { + val genericClientConstructionDocs = FluentClientDocs.clientConstructionDocs(codegenContext) + val writeClientConstructionDocs = + codegenContext.rootDecorator + .clientConstructionDocs(codegenContext, genericClientConstructionDocs) - writeClientConstructionDocs(this) - FluentClientDocs.clientUsageDocs(codegenContext)(this) - } + writeClientConstructionDocs(this) + FluentClientDocs.clientUsageDocs(codegenContext)(this) + } - private fun customizeModuleDoc(): Writable = writable { - val model = codegenContext.model - docs("Operation customization and supporting types.\n") - if (codegenContext.serviceShape.operations.isNotEmpty()) { - val opFnName = FluentClientGenerator.clientOperationFnName( - codegenContext.serviceShape.operations.minOf { it } - .let { model.expectShape(it, OperationShape::class.java) }, - codegenContext.symbolProvider, - ) - val moduleUseName = codegenContext.moduleUseName() - docsTemplate( - """ - The underlying HTTP requests made during an operation can be customized - by calling the `customize()` method on the builder returned from a client - operation call. For example, this can be used to add an additional HTTP header: - - ```ignore - ## async fn wrapper() -> #{Result}<(), $moduleUseName::Error> { - ## let client: $moduleUseName::Client = unimplemented!(); - use #{http}::header::{HeaderName, HeaderValue}; - - let result = client.$opFnName() - .customize() - .mutate_request(|req| { - // Add `x-example-header` with value - req.headers_mut() - .insert( - HeaderName::from_static("x-example-header"), - HeaderValue::from_static("1"), - ); - }) - .send() - .await; - ## } - ``` - """.trimIndent(), - *RuntimeType.preludeScope, - "http" to CargoDependency.Http.toDevDependency().toType(), - ) + private fun customizeModuleDoc(): Writable = + writable { + val model = codegenContext.model + docs("Operation customization and supporting types.\n") + if (codegenContext.serviceShape.operations.isNotEmpty()) { + val opFnName = + FluentClientGenerator.clientOperationFnName( + codegenContext.serviceShape.operations.minOf { it } + .let { model.expectShape(it, OperationShape::class.java) }, + codegenContext.symbolProvider, + ) + val moduleUseName = codegenContext.moduleUseName() + docsTemplate( + """ + The underlying HTTP requests made during an operation can be customized + by calling the `customize()` method on the builder returned from a client + operation call. For example, this can be used to add an additional HTTP header: + + ```ignore + ## async fn wrapper() -> #{Result}<(), $moduleUseName::Error> { + ## let client: $moduleUseName::Client = unimplemented!(); + use #{http}::header::{HeaderName, HeaderValue}; + + let result = client.$opFnName() + .customize() + .mutate_request(|req| { + // Add `x-example-header` with value + req.headers_mut() + .insert( + HeaderName::from_static("x-example-header"), + HeaderValue::from_static("1"), + ); + }) + .send() + .await; + ## } + ``` + """.trimIndent(), + *RuntimeType.preludeScope, + "http" to CargoDependency.Http.toDevDependency().toType(), + ) + } } - } } object ClientModuleProvider : ModuleProvider { - override fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule = when (shape) { - is OperationShape -> perOperationModule(context, shape) - is StructureShape -> when { - shape.hasTrait() -> ClientRustModule.Types.Error - shape.hasTrait() -> perOperationModule(context, shape) - shape.hasTrait() -> perOperationModule(context, shape) + override fun moduleForShape( + context: ModuleProviderContext, + shape: Shape, + ): RustModule.LeafModule = + when (shape) { + is OperationShape -> perOperationModule(context, shape) + is StructureShape -> + when { + shape.hasTrait() -> ClientRustModule.Types.Error + shape.hasTrait() -> perOperationModule(context, shape) + shape.hasTrait() -> perOperationModule(context, shape) + else -> ClientRustModule.types + } + else -> ClientRustModule.types } - else -> ClientRustModule.types - } - override fun moduleForOperationError( context: ModuleProviderContext, operation: OperationShape, @@ -202,7 +215,11 @@ object ClientModuleProvider : ModuleProvider { eventStream: UnionShape, ): RustModule.LeafModule = ClientRustModule.Types.Error - override fun moduleForBuilder(context: ModuleProviderContext, shape: Shape, symbol: Symbol): RustModule.LeafModule = + override fun moduleForBuilder( + context: ModuleProviderContext, + shape: Shape, + symbol: Symbol, + ): RustModule.LeafModule = RustModule.public("builders", parent = symbol.module(), documentationOverride = "Builders") private fun Shape.findOperation(model: Model): OperationShape { @@ -216,7 +233,10 @@ object ClientModuleProvider : ModuleProvider { } } - private fun perOperationModule(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule { + private fun perOperationModule( + context: ModuleProviderContext, + shape: Shape, + ): RustModule.LeafModule { val operationShape = shape.findOperation(context.model) val contextName = operationShape.contextName(context.serviceShape) val operationModuleName = diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustSettings.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustSettings.kt index 67b76f0862b..ebb5db6a0a2 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustSettings.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustSettings.kt @@ -16,7 +16,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.util.orNull import java.util.Optional -/** +/* * [ClientRustSettings] and [ClientCodegenConfig] classes. * * These are specializations of [CoreRustSettings] and [CodegenConfig] for the `rust-client-codegen` @@ -39,20 +39,23 @@ data class ClientRustSettings( override val examplesUri: String?, override val customizationConfig: ObjectNode?, ) : CoreRustSettings( - service, - moduleName, - moduleVersion, - moduleAuthors, - moduleDescription, - moduleRepository, - runtimeConfig, - codegenConfig, - license, - examplesUri, - customizationConfig, -) { + service, + moduleName, + moduleVersion, + moduleAuthors, + moduleDescription, + moduleRepository, + runtimeConfig, + codegenConfig, + license, + examplesUri, + customizationConfig, + ) { companion object { - fun from(model: Model, config: ObjectNode): ClientRustSettings { + fun from( + model: Model, + config: ObjectNode, + ): ClientRustSettings { val coreRustSettings = CoreRustSettings.from(model, config) val codegenSettingsNode = config.getObjectMember(CODEGEN_SETTINGS) val coreCodegenConfig = CoreCodegenConfig.fromNode(codegenSettingsNode) @@ -93,8 +96,8 @@ data class ClientCodegenConfig( val includeEndpointUrlConfig: Boolean = defaultIncludeEndpointUrlConfig, val enableUserConfigurableRuntimePlugins: Boolean = defaultEnableUserConfigurableRuntimePlugins, ) : CoreCodegenConfig( - formatTimeoutSeconds, debugMode, defaultFlattenAccessors, -) { + formatTimeoutSeconds, debugMode, defaultFlattenAccessors, + ) { companion object { private const val defaultRenameExceptions = true private const val defaultIncludeFluentClient = true @@ -107,30 +110,33 @@ data class ClientCodegenConfig( // Note: only clients default to true, servers default to false private const val defaultFlattenAccessors = true - fun fromCodegenConfigAndNode(coreCodegenConfig: CoreCodegenConfig, node: Optional) = - if (node.isPresent) { - ClientCodegenConfig( - formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, - flattenCollectionAccessors = node.get().getBooleanMemberOrDefault("flattenCollectionAccessors", defaultFlattenAccessors), - debugMode = coreCodegenConfig.debugMode, - eventStreamAllowList = node.get().getArrayMember("eventStreamAllowList").map { array -> + fun fromCodegenConfigAndNode( + coreCodegenConfig: CoreCodegenConfig, + node: Optional, + ) = if (node.isPresent) { + ClientCodegenConfig( + formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, + flattenCollectionAccessors = node.get().getBooleanMemberOrDefault("flattenCollectionAccessors", defaultFlattenAccessors), + debugMode = coreCodegenConfig.debugMode, + eventStreamAllowList = + node.get().getArrayMember("eventStreamAllowList").map { array -> array.toList().mapNotNull { node -> node.asStringNode().orNull()?.value } }.orNull()?.toSet() ?: defaultEventStreamAllowList, - renameExceptions = node.get().getBooleanMemberOrDefault("renameErrors", defaultRenameExceptions), - includeFluentClient = node.get().getBooleanMemberOrDefault("includeFluentClient", defaultIncludeFluentClient), - addMessageToErrors = node.get().getBooleanMemberOrDefault("addMessageToErrors", defaultAddMessageToErrors), - includeEndpointUrlConfig = node.get().getBooleanMemberOrDefault("includeEndpointUrlConfig", defaultIncludeEndpointUrlConfig), - enableUserConfigurableRuntimePlugins = node.get().getBooleanMemberOrDefault("enableUserConfigurableRuntimePlugins", defaultEnableUserConfigurableRuntimePlugins), - nullabilityCheckMode = NullableIndex.CheckMode.valueOf(node.get().getStringMemberOrDefault("nullabilityCheckMode", defaultNullabilityCheckMode)), - ) - } else { - ClientCodegenConfig( - formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, - debugMode = coreCodegenConfig.debugMode, - nullabilityCheckMode = NullableIndex.CheckMode.valueOf(defaultNullabilityCheckMode), - ) - } + renameExceptions = node.get().getBooleanMemberOrDefault("renameErrors", defaultRenameExceptions), + includeFluentClient = node.get().getBooleanMemberOrDefault("includeFluentClient", defaultIncludeFluentClient), + addMessageToErrors = node.get().getBooleanMemberOrDefault("addMessageToErrors", defaultAddMessageToErrors), + includeEndpointUrlConfig = node.get().getBooleanMemberOrDefault("includeEndpointUrlConfig", defaultIncludeEndpointUrlConfig), + enableUserConfigurableRuntimePlugins = node.get().getBooleanMemberOrDefault("enableUserConfigurableRuntimePlugins", defaultEnableUserConfigurableRuntimePlugins), + nullabilityCheckMode = NullableIndex.CheckMode.valueOf(node.get().getStringMemberOrDefault("nullabilityCheckMode", defaultNullabilityCheckMode)), + ) + } else { + ClientCodegenConfig( + formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, + debugMode = coreCodegenConfig.debugMode, + nullabilityCheckMode = NullableIndex.CheckMode.valueOf(defaultNullabilityCheckMode), + ) + } } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt index 9f75a10ddcd..7f9c07c2888 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt @@ -90,20 +90,19 @@ class RustClientCodegenPlugin : ClientDecoratableBuildPlugin() { serviceShape: ServiceShape, rustSymbolProviderConfig: RustSymbolProviderConfig, codegenDecorator: ClientCodegenDecorator, - ) = - SymbolVisitor(settings, model, serviceShape = serviceShape, config = rustSymbolProviderConfig) - // Generate different types for EventStream shapes (e.g. transcribe streaming) - .let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.CLIENT) } - // Generate `ByteStream` instead of `Blob` for streaming binary shapes (e.g. S3 GetObject) - .let { StreamingShapeSymbolProvider(it) } - // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes - .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf(NonExhaustive)) } - // Streaming shapes need different derives (e.g. they cannot derive `PartialEq`) - .let { StreamingShapeMetadataProvider(it) } - // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot - // be the name of an operation input - .let { RustReservedWordSymbolProvider(it, ClientReservedWords) } - // Allows decorators to inject a custom symbol provider - .let { codegenDecorator.symbolProvider(it) } + ) = SymbolVisitor(settings, model, serviceShape = serviceShape, config = rustSymbolProviderConfig) + // Generate different types for EventStream shapes (e.g. transcribe streaming) + .let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.CLIENT) } + // Generate `ByteStream` instead of `Blob` for streaming binary shapes (e.g. S3 GetObject) + .let { StreamingShapeSymbolProvider(it) } + // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes + .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf(NonExhaustive)) } + // Streaming shapes need different derives (e.g. they cannot derive `PartialEq`) + .let { StreamingShapeMetadataProvider(it) } + // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot + // be the name of an operation input + .let { RustReservedWordSymbolProvider(it, ClientReservedWords) } + // Allows decorators to inject a custom symbol provider + .let { codegenDecorator.symbolProvider(it) } } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ClientDocsGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ClientDocsGenerator.kt index 6c55a4e682d..fb709ac61ef 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ClientDocsGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ClientDocsGenerator.kt @@ -18,11 +18,12 @@ import software.amazon.smithy.rust.codegen.core.util.getTrait class ClientDocsGenerator(private val codegenContext: ClientCodegenContext) : LibRsCustomization() { override fun section(section: LibRsSection): Writable { return when (section) { - is LibRsSection.ModuleDoc -> if (section.subsection is ModuleDocSection.CrateOrganization) { - crateLayout() - } else { - emptySection - } + is LibRsSection.ModuleDoc -> + if (section.subsection is ModuleDocSection.CrateOrganization) { + crateLayout() + } else { + emptySection + } else -> emptySection } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ConnectionPoisoningConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ConnectionPoisoningConfigCustomization.kt index 5f1688037e2..e611e4832f6 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ConnectionPoisoningConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ConnectionPoisoningConfigCustomization.kt @@ -18,20 +18,21 @@ class ConnectionPoisoningRuntimePluginCustomization( ) : ServiceRuntimePluginCustomization() { private val runtimeConfig = codegenContext.runtimeConfig - override fun section(section: ServiceRuntimePluginSection): Writable = writable { - when (section) { - is ServiceRuntimePluginSection.RegisterRuntimeComponents -> { - // This interceptor assumes that a compatible Connector is set. Otherwise, connection poisoning - // won't work and an error message will be logged. - section.registerInterceptor(this) { - rust( - "#T::new()", - smithyRuntime(runtimeConfig).resolve("client::http::connection_poisoning::ConnectionPoisoningInterceptor"), - ) + override fun section(section: ServiceRuntimePluginSection): Writable = + writable { + when (section) { + is ServiceRuntimePluginSection.RegisterRuntimeComponents -> { + // This interceptor assumes that a compatible Connector is set. Otherwise, connection poisoning + // won't work and an error message will be logged. + section.registerInterceptor(this) { + rust( + "#T::new()", + smithyRuntime(runtimeConfig).resolve("client::http::connection_poisoning::ConnectionPoisoningInterceptor"), + ) + } } - } - else -> emptySection + else -> emptySection + } } - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/DocsRsMetadataDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/DocsRsMetadataDecorator.kt index 442cedcf23d..b6f5517d758 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/DocsRsMetadataDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/DocsRsMetadataDecorator.kt @@ -26,16 +26,17 @@ data class DocsRsMetadataSettings( ) fun DocsRsMetadataSettings.asMap(): Map { - val inner = listOfNotNull( - features?.let { "features" to it }, - allFeatures?.let { "all-features" to it }, - noDefaultFeatures?.let { "no-default-features" to it }, - defaultTarget?.let { "no-default-target" to it }, - targets?.let { "targets" to it }, - rustcArgs?.let { "rustc-args" to it }, - rustdocArgs?.let { "rustdoc-args" to it }, - cargoArgs?.let { "cargo-args" to it }, - ).toMap() + custom + val inner = + listOfNotNull( + features?.let { "features" to it }, + allFeatures?.let { "all-features" to it }, + noDefaultFeatures?.let { "no-default-features" to it }, + defaultTarget?.let { "no-default-target" to it }, + targets?.let { "targets" to it }, + rustcArgs?.let { "rustc-args" to it }, + rustdocArgs?.let { "rustdoc-args" to it }, + cargoArgs?.let { "cargo-args" to it }, + ).toMap() + custom return mapOf("package" to mapOf("metadata" to mapOf("docs" to mapOf("rs" to inner)))) } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpAuthDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpAuthDecorator.kt index 1f11f991643..676700e6d68 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpAuthDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpAuthDecorator.kt @@ -40,7 +40,6 @@ private fun codegenScope(runtimeConfig: RuntimeConfig): Array> "Token" to configReexport(smithyRuntimeApi.resolve("client::identity::http::Token")), "Login" to configReexport(smithyRuntimeApi.resolve("client::identity::http::Login")), "ResolveIdentity" to configReexport(smithyRuntimeApi.resolve("client::identity::ResolveIdentity")), - "AuthSchemeId" to smithyRuntimeApi.resolve("client::auth::AuthSchemeId"), "ApiKeyAuthScheme" to authHttp.resolve("ApiKeyAuthScheme"), "ApiKeyLocation" to authHttp.resolve("ApiKeyLocation"), @@ -75,7 +74,9 @@ private data class HttpAuthSchemes( } fun anyEnabled(): Boolean = isTokenBased() || isLoginBased() + fun isTokenBased(): Boolean = apiKey || bearer + fun isLoginBased(): Boolean = basic || digest } @@ -93,7 +94,10 @@ class HttpAuthDecorator : ClientCodegenDecorator { val codegenScope = codegenScope(codegenContext.runtimeConfig) val options = ArrayList() for (authScheme in authSchemes.keys) { - fun addOption(schemeShapeId: ShapeId, name: String) { + fun addOption( + schemeShapeId: ShapeId, + name: String, + ) { options.add( StaticAuthSchemeOption( schemeShapeId, @@ -144,61 +148,63 @@ private class HttpAuthServiceRuntimePluginCustomization( private val serviceShape = codegenContext.serviceShape private val codegenScope = codegenScope(codegenContext.runtimeConfig) - override fun section(section: ServiceRuntimePluginSection): Writable = writable { - when (section) { - is ServiceRuntimePluginSection.RegisterRuntimeComponents -> { - fun registerAuthScheme(scheme: Writable) { - section.registerAuthScheme(this) { - rustTemplate("#{SharedAuthScheme}::new(#{Scheme})", *codegenScope, "Scheme" to scheme) + override fun section(section: ServiceRuntimePluginSection): Writable = + writable { + when (section) { + is ServiceRuntimePluginSection.RegisterRuntimeComponents -> { + fun registerAuthScheme(scheme: Writable) { + section.registerAuthScheme(this) { + rustTemplate("#{SharedAuthScheme}::new(#{Scheme})", *codegenScope, "Scheme" to scheme) + } } - } - fun registerNamedAuthScheme(name: String) { - registerAuthScheme { - rustTemplate("#{$name}::new()", *codegenScope) + fun registerNamedAuthScheme(name: String) { + registerAuthScheme { + rustTemplate("#{$name}::new()", *codegenScope) + } } - } - if (authSchemes.apiKey) { - val trait = serviceShape.getTrait()!! - val location = when (trait.`in`!!) { - HttpApiKeyAuthTrait.Location.HEADER -> { - check(trait.scheme.isPresent) { - "A scheme is required for `@httpApiKey` when `in` is set to `header`" - } - "Header" - } + if (authSchemes.apiKey) { + val trait = serviceShape.getTrait()!! + val location = + when (trait.`in`!!) { + HttpApiKeyAuthTrait.Location.HEADER -> { + check(trait.scheme.isPresent) { + "A scheme is required for `@httpApiKey` when `in` is set to `header`" + } + "Header" + } - HttpApiKeyAuthTrait.Location.QUERY -> "Query" - } + HttpApiKeyAuthTrait.Location.QUERY -> "Query" + } - registerAuthScheme { - rustTemplate( - """ - #{ApiKeyAuthScheme}::new( - ${trait.scheme.orElse("").dq()}, - #{ApiKeyLocation}::$location, - ${trait.name.dq()}, + registerAuthScheme { + rustTemplate( + """ + #{ApiKeyAuthScheme}::new( + ${trait.scheme.orElse("").dq()}, + #{ApiKeyLocation}::$location, + ${trait.name.dq()}, + ) + """, + *codegenScope, ) - """, - *codegenScope, - ) + } + } + if (authSchemes.basic) { + registerNamedAuthScheme("BasicAuthScheme") + } + if (authSchemes.bearer) { + registerNamedAuthScheme("BearerAuthScheme") + } + if (authSchemes.digest) { + registerNamedAuthScheme("DigestAuthScheme") } } - if (authSchemes.basic) { - registerNamedAuthScheme("BasicAuthScheme") - } - if (authSchemes.bearer) { - registerNamedAuthScheme("BearerAuthScheme") - } - if (authSchemes.digest) { - registerNamedAuthScheme("DigestAuthScheme") - } - } - else -> emptySection + else -> emptySection + } } - } } private class HttpAuthConfigCustomization( @@ -207,92 +213,93 @@ private class HttpAuthConfigCustomization( ) : ConfigCustomization() { private val codegenScope = codegenScope(codegenContext.runtimeConfig) - override fun section(section: ServiceConfig): Writable = writable { - when (section) { - is ServiceConfig.BuilderImpl -> { - if (authSchemes.apiKey) { - rustTemplate( - """ - /// Sets the API key that will be used for authentication. - pub fn api_key(self, api_key: #{Token}) -> Self { - self.api_key_resolver(api_key) - } + override fun section(section: ServiceConfig): Writable = + writable { + when (section) { + is ServiceConfig.BuilderImpl -> { + if (authSchemes.apiKey) { + rustTemplate( + """ + /// Sets the API key that will be used for authentication. + pub fn api_key(self, api_key: #{Token}) -> Self { + self.api_key_resolver(api_key) + } - /// Sets an API key resolver will be used for authentication. - pub fn api_key_resolver(mut self, api_key_resolver: impl #{ResolveIdentity} + 'static) -> Self { - self.runtime_components.set_identity_resolver( - #{HTTP_API_KEY_AUTH_SCHEME_ID}, - #{SharedIdentityResolver}::new(api_key_resolver) - ); - self - } - """, - *codegenScope, - ) - } - if (authSchemes.bearer) { - rustTemplate( - """ - /// Sets the bearer token that will be used for HTTP bearer auth. - pub fn bearer_token(self, bearer_token: #{Token}) -> Self { - self.bearer_token_resolver(bearer_token) - } + /// Sets an API key resolver will be used for authentication. + pub fn api_key_resolver(mut self, api_key_resolver: impl #{ResolveIdentity} + 'static) -> Self { + self.runtime_components.set_identity_resolver( + #{HTTP_API_KEY_AUTH_SCHEME_ID}, + #{SharedIdentityResolver}::new(api_key_resolver) + ); + self + } + """, + *codegenScope, + ) + } + if (authSchemes.bearer) { + rustTemplate( + """ + /// Sets the bearer token that will be used for HTTP bearer auth. + pub fn bearer_token(self, bearer_token: #{Token}) -> Self { + self.bearer_token_resolver(bearer_token) + } - /// Sets a bearer token provider that will be used for HTTP bearer auth. - pub fn bearer_token_resolver(mut self, bearer_token_resolver: impl #{ResolveIdentity} + 'static) -> Self { - self.runtime_components.set_identity_resolver( - #{HTTP_BEARER_AUTH_SCHEME_ID}, - #{SharedIdentityResolver}::new(bearer_token_resolver) - ); - self - } - """, - *codegenScope, - ) - } - if (authSchemes.basic) { - rustTemplate( - """ - /// Sets the login that will be used for HTTP basic auth. - pub fn basic_auth_login(self, basic_auth_login: #{Login}) -> Self { - self.basic_auth_login_resolver(basic_auth_login) - } + /// Sets a bearer token provider that will be used for HTTP bearer auth. + pub fn bearer_token_resolver(mut self, bearer_token_resolver: impl #{ResolveIdentity} + 'static) -> Self { + self.runtime_components.set_identity_resolver( + #{HTTP_BEARER_AUTH_SCHEME_ID}, + #{SharedIdentityResolver}::new(bearer_token_resolver) + ); + self + } + """, + *codegenScope, + ) + } + if (authSchemes.basic) { + rustTemplate( + """ + /// Sets the login that will be used for HTTP basic auth. + pub fn basic_auth_login(self, basic_auth_login: #{Login}) -> Self { + self.basic_auth_login_resolver(basic_auth_login) + } - /// Sets a login resolver that will be used for HTTP basic auth. - pub fn basic_auth_login_resolver(mut self, basic_auth_resolver: impl #{ResolveIdentity} + 'static) -> Self { - self.runtime_components.set_identity_resolver( - #{HTTP_BASIC_AUTH_SCHEME_ID}, - #{SharedIdentityResolver}::new(basic_auth_resolver) - ); - self - } - """, - *codegenScope, - ) - } - if (authSchemes.digest) { - rustTemplate( - """ - /// Sets the login that will be used for HTTP digest auth. - pub fn digest_auth_login(self, digest_auth_login: #{Login}) -> Self { - self.digest_auth_login_resolver(digest_auth_login) - } + /// Sets a login resolver that will be used for HTTP basic auth. + pub fn basic_auth_login_resolver(mut self, basic_auth_resolver: impl #{ResolveIdentity} + 'static) -> Self { + self.runtime_components.set_identity_resolver( + #{HTTP_BASIC_AUTH_SCHEME_ID}, + #{SharedIdentityResolver}::new(basic_auth_resolver) + ); + self + } + """, + *codegenScope, + ) + } + if (authSchemes.digest) { + rustTemplate( + """ + /// Sets the login that will be used for HTTP digest auth. + pub fn digest_auth_login(self, digest_auth_login: #{Login}) -> Self { + self.digest_auth_login_resolver(digest_auth_login) + } - /// Sets a login resolver that will be used for HTTP digest auth. - pub fn digest_auth_login_resolver(mut self, digest_auth_resolver: impl #{ResolveIdentity} + 'static) -> Self { - self.runtime_components.set_identity_resolver( - #{HTTP_DIGEST_AUTH_SCHEME_ID}, - #{SharedIdentityResolver}::new(digest_auth_resolver) - ); - self - } - """, - *codegenScope, - ) + /// Sets a login resolver that will be used for HTTP digest auth. + pub fn digest_auth_login_resolver(mut self, digest_auth_resolver: impl #{ResolveIdentity} + 'static) -> Self { + self.runtime_components.set_identity_resolver( + #{HTTP_DIGEST_AUTH_SCHEME_ID}, + #{SharedIdentityResolver}::new(digest_auth_resolver) + ); + self + } + """, + *codegenScope, + ) + } } - } - else -> {} + else -> {} + } } - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpChecksumRequiredGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpChecksumRequiredGenerator.kt index aba470973da..c4826763e55 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpChecksumRequiredGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpChecksumRequiredGenerator.kt @@ -35,22 +35,23 @@ class HttpChecksumRequiredGenerator( throw CodegenException("HttpChecksum required cannot be applied to a streaming shape") } return when (section) { - is OperationSection.AdditionalRuntimePlugins -> writable { - section.addOperationRuntimePlugin(this) { - rustTemplate( - "#{HttpChecksumRequiredRuntimePlugin}::new()", - "HttpChecksumRequiredRuntimePlugin" to - InlineDependency.forRustFile( - RustModule.pubCrate("client_http_checksum_required", parent = ClientRustModule.root), - "/inlineable/src/client_http_checksum_required.rs", - CargoDependency.smithyRuntimeApiClient(codegenContext.runtimeConfig), - CargoDependency.smithyTypes(codegenContext.runtimeConfig), - CargoDependency.Http, - CargoDependency.Md5, - ).toType().resolve("HttpChecksumRequiredRuntimePlugin"), - ) + is OperationSection.AdditionalRuntimePlugins -> + writable { + section.addOperationRuntimePlugin(this) { + rustTemplate( + "#{HttpChecksumRequiredRuntimePlugin}::new()", + "HttpChecksumRequiredRuntimePlugin" to + InlineDependency.forRustFile( + RustModule.pubCrate("client_http_checksum_required", parent = ClientRustModule.root), + "/inlineable/src/client_http_checksum_required.rs", + CargoDependency.smithyRuntimeApiClient(codegenContext.runtimeConfig), + CargoDependency.smithyTypes(codegenContext.runtimeConfig), + CargoDependency.Http, + CargoDependency.Md5, + ).toType().resolve("HttpChecksumRequiredRuntimePlugin"), + ) + } } - } else -> emptySection } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpConnectorConfigDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpConnectorConfigDecorator.kt index 343292055f3..c8a7a7e211e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpConnectorConfigDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpConnectorConfigDecorator.kt @@ -31,105 +31,110 @@ private class HttpConnectorConfigCustomization( ) : ConfigCustomization() { private val runtimeConfig = codegenContext.runtimeConfig private val moduleUseName = codegenContext.moduleUseName() - private val codegenScope = arrayOf( - *preludeScope, - "HttpClient" to configReexport( - RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::http::HttpClient"), - ), - "IntoShared" to configReexport(RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("shared::IntoShared")), - "SharedHttpClient" to configReexport( - RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::http::SharedHttpClient"), - ), - ) + private val codegenScope = + arrayOf( + *preludeScope, + "HttpClient" to + configReexport( + RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::http::HttpClient"), + ), + "IntoShared" to configReexport(RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("shared::IntoShared")), + "SharedHttpClient" to + configReexport( + RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::http::SharedHttpClient"), + ), + ) override fun section(section: ServiceConfig): Writable { return when (section) { - is ServiceConfig.ConfigImpl -> writable { - rustTemplate( - """ - /// Return the [`SharedHttpClient`](#{SharedHttpClient}) to use when making requests, if any. - pub fn http_client(&self) -> Option<#{SharedHttpClient}> { - self.runtime_components.http_client() - } - """, - *codegenScope, - ) - } + is ServiceConfig.ConfigImpl -> + writable { + rustTemplate( + """ + /// Return the [`SharedHttpClient`](#{SharedHttpClient}) to use when making requests, if any. + pub fn http_client(&self) -> Option<#{SharedHttpClient}> { + self.runtime_components.http_client() + } + """, + *codegenScope, + ) + } - ServiceConfig.BuilderImpl -> writable { - rustTemplate( - """ - /// Sets the HTTP client to use when making requests. - /// - /// ## Examples - /// ```no_run - /// ## ##[cfg(test)] - /// ## mod tests { - /// ## ##[test] - /// ## fn example() { - /// use std::time::Duration; - /// use $moduleUseName::config::Config; - /// use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder; - /// - /// let https_connector = hyper_rustls::HttpsConnectorBuilder::new() - /// .with_webpki_roots() - /// .https_only() - /// .enable_http1() - /// .enable_http2() - /// .build(); - /// let hyper_client = HyperClientBuilder::new().build(https_connector); - /// - /// // This connector can then be given to a generated service Config - /// let config = my_service_client::Config::builder() - /// .endpoint_url("https://example.com") - /// .http_client(hyper_client) - /// .build(); - /// let client = my_service_client::Client::from_conf(config); - /// ## } - /// ## } - /// ``` - pub fn http_client(mut self, http_client: impl #{HttpClient} + 'static) -> Self { - self.set_http_client(#{Some}(#{IntoShared}::into_shared(http_client))); - self - } + ServiceConfig.BuilderImpl -> + writable { + rustTemplate( + """ + /// Sets the HTTP client to use when making requests. + /// + /// ## Examples + /// ```no_run + /// ## ##[cfg(test)] + /// ## mod tests { + /// ## ##[test] + /// ## fn example() { + /// use std::time::Duration; + /// use $moduleUseName::config::Config; + /// use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder; + /// + /// let https_connector = hyper_rustls::HttpsConnectorBuilder::new() + /// .with_webpki_roots() + /// .https_only() + /// .enable_http1() + /// .enable_http2() + /// .build(); + /// let hyper_client = HyperClientBuilder::new().build(https_connector); + /// + /// // This connector can then be given to a generated service Config + /// let config = my_service_client::Config::builder() + /// .endpoint_url("https://example.com") + /// .http_client(hyper_client) + /// .build(); + /// let client = my_service_client::Client::from_conf(config); + /// ## } + /// ## } + /// ``` + pub fn http_client(mut self, http_client: impl #{HttpClient} + 'static) -> Self { + self.set_http_client(#{Some}(#{IntoShared}::into_shared(http_client))); + self + } - /// Sets the HTTP client to use when making requests. - /// - /// ## Examples - /// ```no_run - /// ## ##[cfg(test)] - /// ## mod tests { - /// ## ##[test] - /// ## fn example() { - /// use std::time::Duration; - /// use $moduleUseName::config::{Builder, Config}; - /// use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder; - /// - /// fn override_http_client(builder: &mut Builder) { - /// let https_connector = hyper_rustls::HttpsConnectorBuilder::new() - /// .with_webpki_roots() - /// .https_only() - /// .enable_http1() - /// .enable_http2() - /// .build(); - /// let hyper_client = HyperClientBuilder::new().build(https_connector); - /// builder.set_http_client(Some(hyper_client)); - /// } - /// - /// let mut builder = $moduleUseName::Config::builder(); - /// override_http_client(&mut builder); - /// let config = builder.build(); - /// ## } - /// ## } - /// ``` - pub fn set_http_client(&mut self, http_client: Option<#{SharedHttpClient}>) -> &mut Self { - self.runtime_components.set_http_client(http_client); - self - } - """, - *codegenScope, - ) - } + /// Sets the HTTP client to use when making requests. + /// + /// ## Examples + /// ```no_run + /// ## ##[cfg(test)] + /// ## mod tests { + /// ## ##[test] + /// ## fn example() { + /// use std::time::Duration; + /// use $moduleUseName::config::{Builder, Config}; + /// use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder; + /// + /// fn override_http_client(builder: &mut Builder) { + /// let https_connector = hyper_rustls::HttpsConnectorBuilder::new() + /// .with_webpki_roots() + /// .https_only() + /// .enable_http1() + /// .enable_http2() + /// .build(); + /// let hyper_client = HyperClientBuilder::new().build(https_connector); + /// builder.set_http_client(Some(hyper_client)); + /// } + /// + /// let mut builder = $moduleUseName::Config::builder(); + /// override_http_client(&mut builder); + /// let config = builder.build(); + /// ## } + /// ## } + /// ``` + pub fn set_http_client(&mut self, http_client: Option<#{SharedHttpClient}>) -> &mut Self { + self.runtime_components.set_http_client(http_client); + self + } + """, + *codegenScope, + ) + } else -> emptySection } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenDecorator.kt index b33ee8fbbc2..cd40269b1c8 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenDecorator.kt @@ -25,12 +25,14 @@ class IdempotencyTokenDecorator : ClientCodegenDecorator { override val order: Byte = 0 private fun enabled(ctx: ClientCodegenContext) = ctx.serviceShape.needsIdempotencyToken(ctx.model) + override fun configCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = baseCustomizations.extendIf(enabled(codegenContext)) { - IdempotencyTokenProviderCustomization(codegenContext) - } + ): List = + baseCustomizations.extendIf(enabled(codegenContext)) { + IdempotencyTokenProviderCustomization(codegenContext) + } override fun operationCustomizations( codegenContext: ClientCodegenContext, @@ -46,11 +48,12 @@ class IdempotencyTokenDecorator : ClientCodegenDecorator { ): List { return baseCustomizations.extendIf(enabled(codegenContext)) { object : ServiceRuntimePluginCustomization() { - override fun section(section: ServiceRuntimePluginSection) = writable { - if (section is ServiceRuntimePluginSection.AdditionalConfig) { - section.putConfigValue(this, defaultTokenProvider((codegenContext.runtimeConfig))) + override fun section(section: ServiceRuntimePluginSection) = + writable { + if (section is ServiceRuntimePluginSection.AdditionalConfig) { + section.putConfigValue(this, defaultTokenProvider((codegenContext.runtimeConfig))) + } } - } } } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenGenerator.kt index af2e21294ca..dddaa45514c 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenGenerator.kt @@ -39,40 +39,42 @@ class IdempotencyTokenGenerator( return emptySection } val memberName = symbolProvider.toMemberName(idempotencyTokenMember) - val codegenScope = arrayOf( - *preludeScope, - "Input" to symbolProvider.toSymbol(inputShape), - "IdempotencyTokenRuntimePlugin" to - InlineDependency.forRustFile( - RustModule.pubCrate("client_idempotency_token", parent = ClientRustModule.root), - "/inlineable/src/client_idempotency_token.rs", - CargoDependency.smithyRuntimeApiClient(runtimeConfig), - CargoDependency.smithyTypes(runtimeConfig), - InlineDependency.idempotencyToken(runtimeConfig), - ).toType().resolve("IdempotencyTokenRuntimePlugin"), - ) + val codegenScope = + arrayOf( + *preludeScope, + "Input" to symbolProvider.toSymbol(inputShape), + "IdempotencyTokenRuntimePlugin" to + InlineDependency.forRustFile( + RustModule.pubCrate("client_idempotency_token", parent = ClientRustModule.root), + "/inlineable/src/client_idempotency_token.rs", + CargoDependency.smithyRuntimeApiClient(runtimeConfig), + CargoDependency.smithyTypes(runtimeConfig), + InlineDependency.idempotencyToken(runtimeConfig), + ).toType().resolve("IdempotencyTokenRuntimePlugin"), + ) return when (section) { - is OperationSection.AdditionalRuntimePlugins -> writable { - section.addOperationRuntimePlugin(this) { - if (!symbolProvider.toSymbol(idempotencyTokenMember).isOptional()) { - UNREACHABLE("top level input members are always optional. $operationShape") + is OperationSection.AdditionalRuntimePlugins -> + writable { + section.addOperationRuntimePlugin(this) { + if (!symbolProvider.toSymbol(idempotencyTokenMember).isOptional()) { + UNREACHABLE("top level input members are always optional. $operationShape") + } + // An idempotency token is optional. If the user didn't specify a token + // then we'll generate one and set it. + rustTemplate( + """ + #{IdempotencyTokenRuntimePlugin}::new(|token_provider, input| { + let input: &mut #{Input} = input.downcast_mut().expect("correct type"); + if input.$memberName.is_none() { + input.$memberName = #{Some}(token_provider.make_idempotency_token()); + } + }) + """, + *codegenScope, + ) } - // An idempotency token is optional. If the user didn't specify a token - // then we'll generate one and set it. - rustTemplate( - """ - #{IdempotencyTokenRuntimePlugin}::new(|token_provider, input| { - let input: &mut #{Input} = input.downcast_mut().expect("correct type"); - if input.$memberName.is_none() { - input.$memberName = #{Some}(token_provider.make_idempotency_token()); - } - }) - """, - *codegenScope, - ) } - } else -> emptySection } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdentityCacheDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdentityCacheDecorator.kt index bdd3e267f6a..b5a31e402c6 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdentityCacheDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdentityCacheDecorator.kt @@ -18,91 +18,93 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.pre class IdentityCacheConfigCustomization(codegenContext: ClientCodegenContext) : ConfigCustomization() { private val moduleUseName = codegenContext.moduleUseName() - private val codegenScope = codegenContext.runtimeConfig.let { rc -> - val api = RuntimeType.smithyRuntimeApiClient(rc) - arrayOf( - *preludeScope, - "ResolveCachedIdentity" to configReexport(api.resolve("client::identity::ResolveCachedIdentity")), - "SharedIdentityCache" to configReexport(api.resolve("client::identity::SharedIdentityCache")), - ) - } + private val codegenScope = + codegenContext.runtimeConfig.let { rc -> + val api = RuntimeType.smithyRuntimeApiClient(rc) + arrayOf( + *preludeScope, + "ResolveCachedIdentity" to configReexport(api.resolve("client::identity::ResolveCachedIdentity")), + "SharedIdentityCache" to configReexport(api.resolve("client::identity::SharedIdentityCache")), + ) + } - override fun section(section: ServiceConfig): Writable = writable { - when (section) { - is ServiceConfig.BuilderImpl -> { - val docs = """ - /// Set the identity cache for auth. - /// - /// The identity cache defaults to a lazy caching implementation that will resolve - /// an identity when it is requested, and place it in the cache thereafter. Subsequent - /// requests will take the value from the cache while it is still valid. Once it expires, - /// the next request will result in refreshing the identity. - /// - /// This configuration allows you to disable or change the default caching mechanism. - /// To use a custom caching mechanism, implement the [`ResolveCachedIdentity`](#{ResolveCachedIdentity}) - /// trait and pass that implementation into this function. - /// - /// ## Examples - /// - /// Disabling identity caching: - /// ```no_run - /// use $moduleUseName::config::IdentityCache; - /// - /// let config = $moduleUseName::Config::builder() - /// .identity_cache(IdentityCache::no_cache()) - /// // ... - /// .build(); - /// let client = $moduleUseName::Client::from_conf(config); - /// ``` - /// - /// Customizing lazy caching: - /// ```no_run - /// use $moduleUseName::config::IdentityCache; - /// use std::time::Duration; - /// - /// let config = $moduleUseName::Config::builder() - /// .identity_cache( - /// IdentityCache::lazy() - /// // change the load timeout to 10 seconds - /// .load_timeout(Duration::from_secs(10)) - /// .build() - /// ) - /// // ... - /// .build(); - /// let client = $moduleUseName::Client::from_conf(config); - /// ``` - """ - rustTemplate( + override fun section(section: ServiceConfig): Writable = + writable { + when (section) { + is ServiceConfig.BuilderImpl -> { + val docs = """ + /// Set the identity cache for auth. + /// + /// The identity cache defaults to a lazy caching implementation that will resolve + /// an identity when it is requested, and place it in the cache thereafter. Subsequent + /// requests will take the value from the cache while it is still valid. Once it expires, + /// the next request will result in refreshing the identity. + /// + /// This configuration allows you to disable or change the default caching mechanism. + /// To use a custom caching mechanism, implement the [`ResolveCachedIdentity`](#{ResolveCachedIdentity}) + /// trait and pass that implementation into this function. + /// + /// ## Examples + /// + /// Disabling identity caching: + /// ```no_run + /// use $moduleUseName::config::IdentityCache; + /// + /// let config = $moduleUseName::Config::builder() + /// .identity_cache(IdentityCache::no_cache()) + /// // ... + /// .build(); + /// let client = $moduleUseName::Client::from_conf(config); + /// ``` + /// + /// Customizing lazy caching: + /// ```no_run + /// use $moduleUseName::config::IdentityCache; + /// use std::time::Duration; + /// + /// let config = $moduleUseName::Config::builder() + /// .identity_cache( + /// IdentityCache::lazy() + /// // change the load timeout to 10 seconds + /// .load_timeout(Duration::from_secs(10)) + /// .build() + /// ) + /// // ... + /// .build(); + /// let client = $moduleUseName::Client::from_conf(config); + /// ``` """ - $docs - pub fn identity_cache(mut self, identity_cache: impl #{ResolveCachedIdentity} + 'static) -> Self { - self.set_identity_cache(identity_cache); - self - } + rustTemplate( + """ + $docs + pub fn identity_cache(mut self, identity_cache: impl #{ResolveCachedIdentity} + 'static) -> Self { + self.set_identity_cache(identity_cache); + self + } - $docs - pub fn set_identity_cache(&mut self, identity_cache: impl #{ResolveCachedIdentity} + 'static) -> &mut Self { - self.runtime_components.set_identity_cache(#{Some}(identity_cache)); - self - } - """, - *codegenScope, - ) - } + $docs + pub fn set_identity_cache(&mut self, identity_cache: impl #{ResolveCachedIdentity} + 'static) -> &mut Self { + self.runtime_components.set_identity_cache(#{Some}(identity_cache)); + self + } + """, + *codegenScope, + ) + } - is ServiceConfig.ConfigImpl -> { - rustTemplate( - """ - /// Returns the configured identity cache for auth. - pub fn identity_cache(&self) -> #{Option}<#{SharedIdentityCache}> { - self.runtime_components.identity_cache() - } - """, - *codegenScope, - ) - } + is ServiceConfig.ConfigImpl -> { + rustTemplate( + """ + /// Returns the configured identity cache for auth. + pub fn identity_cache(&self) -> #{Option}<#{SharedIdentityCache}> { + self.runtime_components.identity_cache() + } + """, + *codegenScope, + ) + } - else -> {} + else -> {} + } } - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/InterceptorConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/InterceptorConfigCustomization.kt index eecd7bb4a6b..0afdb500774 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/InterceptorConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/InterceptorConfigCustomization.kt @@ -17,23 +17,25 @@ class InterceptorConfigCustomization(codegenContext: ClientCodegenContext) : Con private val moduleUseName = codegenContext.moduleUseName() private val runtimeConfig = codegenContext.runtimeConfig - private val codegenScope = arrayOf( - "Intercept" to configReexport(RuntimeType.intercept(runtimeConfig)), - "SharedInterceptor" to configReexport(RuntimeType.sharedInterceptor(runtimeConfig)), - ) + private val codegenScope = + arrayOf( + "Intercept" to configReexport(RuntimeType.intercept(runtimeConfig)), + "SharedInterceptor" to configReexport(RuntimeType.sharedInterceptor(runtimeConfig)), + ) override fun section(section: ServiceConfig) = writable { when (section) { - ServiceConfig.ConfigImpl -> rustTemplate( - """ - /// Returns interceptors currently registered by the user. - pub fn interceptors(&self) -> impl Iterator + '_ { - self.runtime_components.interceptors() - } - """, - *codegenScope, - ) + ServiceConfig.ConfigImpl -> + rustTemplate( + """ + /// Returns interceptors currently registered by the user. + pub fn interceptors(&self) -> impl Iterator + '_ { + self.runtime_components.interceptors() + } + """, + *codegenScope, + ) ServiceConfig.BuilderImpl -> rustTemplate( diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/MetadataCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/MetadataCustomization.kt index 03c62fb0493..513face076f 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/MetadataCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/MetadataCustomization.kt @@ -28,21 +28,22 @@ class MetadataCustomization( ) } - override fun section(section: OperationSection): Writable = writable { - when (section) { - is OperationSection.AdditionalRuntimePluginConfig -> { - rustTemplate( - """ - ${section.newLayerName}.store_put(#{Metadata}::new( - ${operationName.dq()}, - ${codegenContext.serviceShape.sdkId().dq()}, - )); - """, - *codegenScope, - ) - } + override fun section(section: OperationSection): Writable = + writable { + when (section) { + is OperationSection.AdditionalRuntimePluginConfig -> { + rustTemplate( + """ + ${section.newLayerName}.store_put(#{Metadata}::new( + ${operationName.dq()}, + ${codegenContext.serviceShape.sdkId().dq()}, + )); + """, + *codegenScope, + ) + } - else -> {} + else -> {} + } } - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/NoAuthDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/NoAuthDecorator.kt index bc68c958c88..4390f762588 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/NoAuthDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/NoAuthDecorator.kt @@ -30,16 +30,17 @@ class NoAuthDecorator : ClientCodegenDecorator { codegenContext: ClientCodegenContext, operationShape: OperationShape, baseAuthSchemeOptions: List, - ): List = baseAuthSchemeOptions + - AuthSchemeOption.StaticAuthSchemeOption( - noAuthSchemeShapeId, - listOf( - writable { - rustTemplate( - "#{NO_AUTH_SCHEME_ID}", - "NO_AUTH_SCHEME_ID" to noAuthModule(codegenContext).resolve("NO_AUTH_SCHEME_ID"), - ) - }, - ), - ) + ): List = + baseAuthSchemeOptions + + AuthSchemeOption.StaticAuthSchemeOption( + noAuthSchemeShapeId, + listOf( + writable { + rustTemplate( + "#{NO_AUTH_SCHEME_ID}", + "NO_AUTH_SCHEME_ID" to noAuthModule(codegenContext).resolve("NO_AUTH_SCHEME_ID"), + ) + }, + ), + ) } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomization.kt index 0f64fa0ccf7..8b6153b9a67 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomization.kt @@ -23,25 +23,26 @@ class ResiliencyConfigCustomization(codegenContext: ClientCodegenContext) : Conf private val timeoutModule = RuntimeType.smithyTypes(runtimeConfig).resolve("timeout") private val retries = RuntimeType.smithyRuntime(runtimeConfig).resolve("client::retries") private val moduleUseName = codegenContext.moduleUseName() - private val codegenScope = arrayOf( - *preludeScope, - "AsyncSleep" to configReexport(sleepModule.resolve("AsyncSleep")), - "SharedAsyncSleep" to configReexport(sleepModule.resolve("SharedAsyncSleep")), - "Sleep" to configReexport(sleepModule.resolve("Sleep")), - "ClientRateLimiter" to retries.resolve("ClientRateLimiter"), - "ClientRateLimiterPartition" to retries.resolve("ClientRateLimiterPartition"), - "debug" to RuntimeType.Tracing.resolve("debug"), - "IntoShared" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("shared::IntoShared"), - "RetryConfig" to retryConfig.resolve("RetryConfig"), - "RetryMode" to RuntimeType.smithyTypes(runtimeConfig).resolve("retry::RetryMode"), - "RetryPartition" to retries.resolve("RetryPartition"), - "SharedAsyncSleep" to configReexport(sleepModule.resolve("SharedAsyncSleep")), - "SharedRetryStrategy" to configReexport(RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::retries::SharedRetryStrategy")), - "SharedTimeSource" to configReexport(RuntimeType.smithyAsync(runtimeConfig).resolve("time::SharedTimeSource")), - "StandardRetryStrategy" to configReexport(retries.resolve("strategy::StandardRetryStrategy")), - "SystemTime" to RuntimeType.std.resolve("time::SystemTime"), - "TimeoutConfig" to timeoutModule.resolve("TimeoutConfig"), - ) + private val codegenScope = + arrayOf( + *preludeScope, + "AsyncSleep" to configReexport(sleepModule.resolve("AsyncSleep")), + "SharedAsyncSleep" to configReexport(sleepModule.resolve("SharedAsyncSleep")), + "Sleep" to configReexport(sleepModule.resolve("Sleep")), + "ClientRateLimiter" to retries.resolve("ClientRateLimiter"), + "ClientRateLimiterPartition" to retries.resolve("ClientRateLimiterPartition"), + "debug" to RuntimeType.Tracing.resolve("debug"), + "IntoShared" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("shared::IntoShared"), + "RetryConfig" to retryConfig.resolve("RetryConfig"), + "RetryMode" to RuntimeType.smithyTypes(runtimeConfig).resolve("retry::RetryMode"), + "RetryPartition" to retries.resolve("RetryPartition"), + "SharedAsyncSleep" to configReexport(sleepModule.resolve("SharedAsyncSleep")), + "SharedRetryStrategy" to configReexport(RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::retries::SharedRetryStrategy")), + "SharedTimeSource" to configReexport(RuntimeType.smithyAsync(runtimeConfig).resolve("time::SharedTimeSource")), + "StandardRetryStrategy" to configReexport(retries.resolve("strategy::StandardRetryStrategy")), + "SystemTime" to RuntimeType.std.resolve("time::SystemTime"), + "TimeoutConfig" to timeoutModule.resolve("TimeoutConfig"), + ) override fun section(section: ServiceConfig) = writable { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/RetryClassifierConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/RetryClassifierConfigCustomization.kt index 8b23e4efbac..34dede78cf7 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/RetryClassifierConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/RetryClassifierConfigCustomization.kt @@ -24,25 +24,27 @@ class RetryClassifierConfigCustomization(codegenContext: ClientCodegenContext) : private val retries = RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::retries") private val classifiers = retries.resolve("classifiers") - private val codegenScope = arrayOf( - "ClassifyRetry" to classifiers.resolve("ClassifyRetry"), - "RetryStrategy" to retries.resolve("RetryStrategy"), - "SharedRetryClassifier" to classifiers.resolve("SharedRetryClassifier"), - "RetryClassifierPriority" to classifiers.resolve("RetryClassifierPriority"), - ) + private val codegenScope = + arrayOf( + "ClassifyRetry" to classifiers.resolve("ClassifyRetry"), + "RetryStrategy" to retries.resolve("RetryStrategy"), + "SharedRetryClassifier" to classifiers.resolve("SharedRetryClassifier"), + "RetryClassifierPriority" to classifiers.resolve("RetryClassifierPriority"), + ) override fun section(section: ServiceConfig) = writable { when (section) { - ServiceConfig.ConfigImpl -> rustTemplate( - """ - /// Returns retry classifiers currently registered by the user. - pub fn retry_classifiers(&self) -> impl Iterator + '_ { - self.runtime_components.retry_classifiers() - } - """, - *codegenScope, - ) + ServiceConfig.ConfigImpl -> + rustTemplate( + """ + /// Returns retry classifiers currently registered by the user. + pub fn retry_classifiers(&self) -> impl Iterator + '_ { + self.runtime_components.retry_classifiers() + } + """, + *codegenScope, + ) ServiceConfig.BuilderImpl -> rustTemplate( @@ -243,20 +245,21 @@ class RetryClassifierServiceRuntimePluginCustomization(codegenContext: ClientCod private val runtimeConfig = codegenContext.runtimeConfig private val retries = RuntimeType.smithyRuntime(runtimeConfig).resolve("client::retries") - override fun section(section: ServiceRuntimePluginSection): Writable = writable { - when (section) { - is ServiceRuntimePluginSection.RegisterRuntimeComponents -> { - section.registerRetryClassifier(this) { - rustTemplate( - "#{HttpStatusCodeClassifier}::default()", - "HttpStatusCodeClassifier" to retries.resolve("classifiers::HttpStatusCodeClassifier"), - ) + override fun section(section: ServiceRuntimePluginSection): Writable = + writable { + when (section) { + is ServiceRuntimePluginSection.RegisterRuntimeComponents -> { + section.registerRetryClassifier(this) { + rustTemplate( + "#{HttpStatusCodeClassifier}::default()", + "HttpStatusCodeClassifier" to retries.resolve("classifiers::HttpStatusCodeClassifier"), + ) + } } - } - else -> emptySection + else -> emptySection + } } - } } class RetryClassifierOperationCustomization( @@ -266,32 +269,34 @@ class RetryClassifierOperationCustomization( private val runtimeConfig = codegenContext.runtimeConfig private val symbolProvider = codegenContext.symbolProvider - override fun section(section: OperationSection): Writable = writable { - val classifiers = RuntimeType.smithyRuntime(runtimeConfig).resolve("client::retries::classifiers") + override fun section(section: OperationSection): Writable = + writable { + val classifiers = RuntimeType.smithyRuntime(runtimeConfig).resolve("client::retries::classifiers") - val codegenScope = arrayOf( - *RuntimeType.preludeScope, - "TransientErrorClassifier" to classifiers.resolve("TransientErrorClassifier"), - "ModeledAsRetryableClassifier" to classifiers.resolve("ModeledAsRetryableClassifier"), - "OperationError" to symbolProvider.symbolForOperationError(operation), - ) + val codegenScope = + arrayOf( + *RuntimeType.preludeScope, + "TransientErrorClassifier" to classifiers.resolve("TransientErrorClassifier"), + "ModeledAsRetryableClassifier" to classifiers.resolve("ModeledAsRetryableClassifier"), + "OperationError" to symbolProvider.symbolForOperationError(operation), + ) - when (section) { - is OperationSection.RetryClassifiers -> { - section.registerRetryClassifier(this) { - rustTemplate( - "#{TransientErrorClassifier}::<#{OperationError}>::new()", - *codegenScope, - ) - } - section.registerRetryClassifier(this) { - rustTemplate( - "#{ModeledAsRetryableClassifier}::<#{OperationError}>::new()", - *codegenScope, - ) + when (section) { + is OperationSection.RetryClassifiers -> { + section.registerRetryClassifier(this) { + rustTemplate( + "#{TransientErrorClassifier}::<#{OperationError}>::new()", + *codegenScope, + ) + } + section.registerRetryClassifier(this) { + rustTemplate( + "#{ModeledAsRetryableClassifier}::<#{OperationError}>::new()", + *codegenScope, + ) + } } + else -> emptySection } - else -> emptySection } - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecorator.kt index 5e484d4c861..b7281f3e983 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecorator.kt @@ -33,15 +33,18 @@ private class SensitiveOutputCustomization( private val operation: OperationShape, ) : OperationCustomization() { private val sensitiveIndex = SensitiveIndex.of(codegenContext.model) - override fun section(section: OperationSection): Writable = writable { - if (section is OperationSection.AdditionalRuntimePluginConfig && sensitiveIndex.hasSensitiveOutput(operation)) { - rustTemplate( - """ - ${section.newLayerName}.store_put(#{SensitiveOutput}); - """, - "SensitiveOutput" to RuntimeType.smithyRuntimeApiClient(codegenContext.runtimeConfig) - .resolve("client::orchestrator::SensitiveOutput"), - ) + + override fun section(section: OperationSection): Writable = + writable { + if (section is OperationSection.AdditionalRuntimePluginConfig && sensitiveIndex.hasSensitiveOutput(operation)) { + rustTemplate( + """ + ${section.newLayerName}.store_put(#{SensitiveOutput}); + """, + "SensitiveOutput" to + RuntimeType.smithyRuntimeApiClient(codegenContext.runtimeConfig) + .resolve("client::orchestrator::SensitiveOutput"), + ) + } } - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/TimeSourceCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/TimeSourceCustomization.kt index 67b3664688c..498cbee4a36 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/TimeSourceCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/TimeSourceCustomization.kt @@ -16,15 +16,16 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope class TimeSourceCustomization(codegenContext: ClientCodegenContext) : ConfigCustomization() { - private val codegenScope = arrayOf( - *preludeScope, - "IntoShared" to RuntimeType.smithyRuntimeApi(codegenContext.runtimeConfig).resolve("shared::IntoShared"), - "SharedTimeSource" to RuntimeType.smithyAsync(codegenContext.runtimeConfig).resolve("time::SharedTimeSource"), - "StaticTimeSource" to RuntimeType.smithyAsync(codegenContext.runtimeConfig).resolve("time::StaticTimeSource"), - "TimeSource" to RuntimeType.smithyAsync(codegenContext.runtimeConfig).resolve("time::TimeSource"), - "UNIX_EPOCH" to RuntimeType.std.resolve("time::UNIX_EPOCH"), - "Duration" to RuntimeType.std.resolve("time::Duration"), - ) + private val codegenScope = + arrayOf( + *preludeScope, + "IntoShared" to RuntimeType.smithyRuntimeApi(codegenContext.runtimeConfig).resolve("shared::IntoShared"), + "SharedTimeSource" to RuntimeType.smithyAsync(codegenContext.runtimeConfig).resolve("time::SharedTimeSource"), + "StaticTimeSource" to RuntimeType.smithyAsync(codegenContext.runtimeConfig).resolve("time::StaticTimeSource"), + "TimeSource" to RuntimeType.smithyAsync(codegenContext.runtimeConfig).resolve("time::TimeSource"), + "UNIX_EPOCH" to RuntimeType.std.resolve("time::UNIX_EPOCH"), + "Duration" to RuntimeType.std.resolve("time::Duration"), + ) override fun section(section: ServiceConfig) = writable { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt index bfbbb056982..c676e5259a9 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt @@ -33,7 +33,9 @@ sealed interface AuthSchemeOption { val constructor: List, ) : AuthSchemeOption - class CustomResolver(/* unimplemented */) : AuthSchemeOption + class CustomResolver( + // unimplemented + ) : AuthSchemeOption } /** @@ -69,14 +71,20 @@ interface ClientCodegenDecorator : CoreCodegenDecorator, ): List = baseCustomizations - fun protocols(serviceId: ShapeId, currentProtocols: ClientProtocolMap): ClientProtocolMap = currentProtocols + fun protocols( + serviceId: ShapeId, + currentProtocols: ClientProtocolMap, + ): ClientProtocolMap = currentProtocols fun endpointCustomizations(codegenContext: ClientCodegenContext): List = listOf() /** * Hook to customize client construction documentation. */ - fun clientConstructionDocs(codegenContext: ClientCodegenContext, baseDocs: Writable): Writable = baseDocs + fun clientConstructionDocs( + codegenContext: ClientCodegenContext, + baseDocs: Writable, + ): Writable = baseDocs /** * Hooks to register additional service-level runtime plugins at codegen time @@ -111,33 +119,40 @@ open class CombinedClientCodegenDecorator(decorators: List, - ): List = combineCustomizations(baseAuthSchemeOptions) { decorator, authOptions -> - decorator.authOptions(codegenContext, operationShape, authOptions) - } + ): List = + combineCustomizations(baseAuthSchemeOptions) { decorator, authOptions -> + decorator.authOptions(codegenContext, operationShape, authOptions) + } override fun configCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = combineCustomizations(baseCustomizations) { decorator, customizations -> - decorator.configCustomizations(codegenContext, customizations) - } + ): List = + combineCustomizations(baseCustomizations) { decorator, customizations -> + decorator.configCustomizations(codegenContext, customizations) + } override fun operationCustomizations( codegenContext: ClientCodegenContext, operation: OperationShape, baseCustomizations: List, - ): List = combineCustomizations(baseCustomizations) { decorator, customizations -> - decorator.operationCustomizations(codegenContext, operation, customizations) - } + ): List = + combineCustomizations(baseCustomizations) { decorator, customizations -> + decorator.operationCustomizations(codegenContext, operation, customizations) + } override fun errorCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = combineCustomizations(baseCustomizations) { decorator, customizations -> - decorator.errorCustomizations(codegenContext, customizations) - } + ): List = + combineCustomizations(baseCustomizations) { decorator, customizations -> + decorator.errorCustomizations(codegenContext, customizations) + } - override fun protocols(serviceId: ShapeId, currentProtocols: ClientProtocolMap): ClientProtocolMap = + override fun protocols( + serviceId: ShapeId, + currentProtocols: ClientProtocolMap, + ): ClientProtocolMap = combineCustomizations(currentProtocols) { decorator, protocolMap -> decorator.protocols(serviceId, protocolMap) } @@ -145,7 +160,10 @@ open class CombinedClientCodegenDecorator(decorators: List = addCustomizations { decorator -> decorator.endpointCustomizations(codegenContext) } - override fun clientConstructionDocs(codegenContext: ClientCodegenContext, baseDocs: Writable): Writable = + override fun clientConstructionDocs( + codegenContext: ClientCodegenContext, + baseDocs: Writable, + ): Writable = combineCustomizations(baseDocs) { decorator, customizations -> decorator.clientConstructionDocs(codegenContext, customizations) } @@ -161,9 +179,10 @@ open class CombinedClientCodegenDecorator(decorators: List - decorator.protocolTestGenerator(codegenContext, gen) - } + ): ProtocolTestGenerator = + combineCustomizations(baseGenerator) { decorator, gen -> + decorator.protocolTestGenerator(codegenContext, gen) + } companion object { fun fromClasspath( @@ -171,16 +190,18 @@ open class CombinedClientCodegenDecorator(decorators: List, - ): List = baseCustomizations + - ResiliencyConfigCustomization(codegenContext) + - IdentityCacheConfigCustomization(codegenContext) + - InterceptorConfigCustomization(codegenContext) + - TimeSourceCustomization(codegenContext) + - RetryClassifierConfigCustomization(codegenContext) + ): List = + baseCustomizations + + ResiliencyConfigCustomization(codegenContext) + + IdentityCacheConfigCustomization(codegenContext) + + InterceptorConfigCustomization(codegenContext) + + TimeSourceCustomization(codegenContext) + + RetryClassifierConfigCustomization(codegenContext) override fun libRsCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = - baseCustomizations + AllowLintsCustomization() + ): List = baseCustomizations + AllowLintsCustomization() - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { val rc = codegenContext.runtimeConfig // Add rt-tokio feature for `ByteStream::from_path` @@ -126,7 +129,8 @@ class RequiredCustomizations : ClientCodegenDecorator { override fun serviceRuntimePluginCustomizations( codegenContext: ClientCodegenContext, baseCustomizations: List, - ): List = baseCustomizations + - ConnectionPoisoningRuntimePluginCustomization(codegenContext) + - RetryClassifierServiceRuntimePluginCustomization(codegenContext) + ): List = + baseCustomizations + + ConnectionPoisoningRuntimePluginCustomization(codegenContext) + + RetryClassifierServiceRuntimePluginCustomization(codegenContext) } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextConfigCustomization.kt index db842def8d0..96ddb2083c7 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextConfigCustomization.kt @@ -37,12 +37,16 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase */ class ClientContextConfigCustomization(ctx: ClientCodegenContext) : ConfigCustomization() { private val runtimeConfig = ctx.runtimeConfig - private val configParams = ctx.serviceShape.getTrait()?.parameters.orEmpty().toList() - .map { (key, value) -> fromClientParam(key, value, ctx.symbolProvider, runtimeConfig) } + private val configParams = + ctx.serviceShape.getTrait()?.parameters.orEmpty().toList() + .map { (key, value) -> fromClientParam(key, value, ctx.symbolProvider, runtimeConfig) } private val decorators = configParams.map { standardConfigParam(it, ctx) } companion object { - fun toSymbol(shapeType: ShapeType, symbolProvider: RustSymbolProvider): Symbol = + fun toSymbol( + shapeType: ShapeType, + symbolProvider: RustSymbolProvider, + ): Symbol = symbolProvider.toSymbol( when (shapeType) { ShapeType.STRING -> StringShape.builder().id("smithy.api#String").build() diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointConfigCustomization.kt index 1131c201d3a..f11a4bd92a2 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointConfigCustomization.kt @@ -29,15 +29,16 @@ internal class EndpointConfigCustomization( private val epModule = RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::endpoint") private val epRuntimeModule = RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::endpoints") - private val codegenScope = arrayOf( - *preludeScope, - "Params" to typesGenerator.paramsStruct(), - "Resolver" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::config_override::Resolver"), - "SharedEndpointResolver" to epModule.resolve("SharedEndpointResolver"), - "StaticUriEndpointResolver" to epRuntimeModule.resolve("StaticUriEndpointResolver"), - "ServiceSpecificResolver" to codegenContext.serviceSpecificEndpointResolver(), - "IntoShared" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("shared::IntoShared"), - ) + private val codegenScope = + arrayOf( + *preludeScope, + "Params" to typesGenerator.paramsStruct(), + "Resolver" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::config_override::Resolver"), + "SharedEndpointResolver" to epModule.resolve("SharedEndpointResolver"), + "StaticUriEndpointResolver" to epRuntimeModule.resolve("StaticUriEndpointResolver"), + "ServiceSpecificResolver" to codegenContext.serviceSpecificEndpointResolver(), + "IntoShared" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("shared::IntoShared"), + ) override fun section(section: ServiceConfig): Writable { return writable { @@ -55,17 +56,19 @@ internal class EndpointConfigCustomization( } ServiceConfig.BuilderImpl -> { - val endpointModule = ClientRustModule.Config.endpoint.fullyQualifiedPath() - .replace("crate::", "$moduleUseName::") + val endpointModule = + ClientRustModule.Config.endpoint.fullyQualifiedPath() + .replace("crate::", "$moduleUseName::") // if there are no rules, we don't generate a default resolver—we need to also suppress those docs. - val defaultResolverDocs = if (typesGenerator.defaultResolver() != null) { - """ - /// When unset, the client will used a generated endpoint resolver based on the endpoint resolution - /// rules for `$moduleUseName`. - """ - } else { - "/// This service does not define a default endpoint resolver." - } + val defaultResolverDocs = + if (typesGenerator.defaultResolver() != null) { + """ + /// When unset, the client will used a generated endpoint resolver based on the endpoint resolution + /// rules for `$moduleUseName`. + """ + } else { + "/// This service does not define a default endpoint resolver." + } if (codegenContext.settings.codegenConfig.includeEndpointUrlConfig) { rustTemplate( """ diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointParamsDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointParamsDecorator.kt index d752efef95c..6631cd5db94 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointParamsDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointParamsDecorator.kt @@ -37,13 +37,14 @@ private class EndpointParametersCustomization( private val codegenContext: ClientCodegenContext, private val operation: OperationShape, ) : OperationCustomization() { - override fun section(section: OperationSection): Writable = writable { - val symbolProvider = codegenContext.symbolProvider - val operationName = symbolProvider.toSymbol(operation).name - if (section is OperationSection.AdditionalInterceptors) { - section.registerInterceptor(codegenContext.runtimeConfig, this) { - rust("${operationName}EndpointParamsInterceptor") + override fun section(section: OperationSection): Writable = + writable { + val symbolProvider = codegenContext.symbolProvider + val operationName = symbolProvider.toSymbol(operation).name + if (section is OperationSection.AdditionalInterceptors) { + section.registerInterceptor(codegenContext.runtimeConfig, this) { + rust("${operationName}EndpointParamsInterceptor") + } } } - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointRulesetIndex.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointRulesetIndex.kt index a47aa2d8082..b6368624968 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointRulesetIndex.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointRulesetIndex.kt @@ -18,15 +18,15 @@ import java.util.concurrent.ConcurrentHashMap * Index to ensure that endpoint rulesets are parsed only once */ class EndpointRulesetIndex : KnowledgeIndex { - private val ruleSets: ConcurrentHashMap = ConcurrentHashMap() - fun endpointRulesForService(serviceShape: ServiceShape) = ruleSets.computeIfAbsent( - serviceShape, - ) { - serviceShape.getTrait()?.ruleSet?.let { EndpointRuleSet.fromNode(it) } - ?.also { it.typeCheck() } - } + fun endpointRulesForService(serviceShape: ServiceShape) = + ruleSets.computeIfAbsent( + serviceShape, + ) { + serviceShape.getTrait()?.ruleSet?.let { EndpointRuleSet.fromNode(it) } + ?.also { it.typeCheck() } + } fun endpointTests(serviceShape: ServiceShape) = serviceShape.getTrait()?.testCases ?: emptyList() diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointTypesGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointTypesGenerator.kt index d20e2d72b8f..081894b95e4 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointTypesGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointTypesGenerator.kt @@ -29,8 +29,9 @@ class EndpointTypesGenerator( val params: Parameters = rules?.parameters ?: Parameters.builder().build() private val runtimeConfig = codegenContext.runtimeConfig private val customizations = codegenContext.rootDecorator.endpointCustomizations(codegenContext) - private val stdlib = customizations - .flatMap { it.customRuntimeFunctions(codegenContext) } + private val stdlib = + customizations + .flatMap { it.customRuntimeFunctions(codegenContext) } companion object { fun fromContext(codegenContext: ClientCodegenContext): EndpointTypesGenerator { @@ -41,7 +42,9 @@ class EndpointTypesGenerator( } fun paramsStruct(): RuntimeType = EndpointParamsGenerator(codegenContext, params).paramsStruct() + fun paramsBuilder(): RuntimeType = EndpointParamsGenerator(codegenContext, params).paramsBuilder() + fun defaultResolver(): RuntimeType? = rules?.let { EndpointResolverGenerator(codegenContext, stdlib).defaultEndpointResolver(it) } @@ -63,9 +66,13 @@ class EndpointTypesGenerator( * * Exactly one endpoint customization must provide the value for this builtIn or null is returned. */ - fun builtInFor(parameter: Parameter, config: String): Writable? { - val defaultProviders = customizations - .mapNotNull { it.loadBuiltInFromServiceConfig(parameter, config) } + fun builtInFor( + parameter: Parameter, + config: String, + ): Writable? { + val defaultProviders = + customizations + .mapNotNull { it.loadBuiltInFromServiceConfig(parameter, config) } if (defaultProviders.size > 1) { error("Multiple providers provided a value for the builtin $parameter") } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecorator.kt index 2e53a29dccf..cb940ddfa5b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecorator.kt @@ -44,7 +44,10 @@ interface EndpointCustomization { * } * ``` */ - fun loadBuiltInFromServiceConfig(parameter: Parameter, configRef: String): Writable? = null + fun loadBuiltInFromServiceConfig( + parameter: Parameter, + configRef: String, + ): Writable? = null /** * Set a given builtIn value on the service config builder. If this builtIn is not recognized, return null @@ -65,7 +68,11 @@ interface EndpointCustomization { * ``` */ - fun setBuiltInOnServiceConfig(name: String, value: Node, configBuilderRef: String): Writable? = null + fun setBuiltInOnServiceConfig( + name: String, + value: Node, + configBuilderRef: String, + ): Writable? = null /** * Provide a list of additional endpoints standard library functions that rules can use @@ -113,22 +120,27 @@ class EndpointsDecorator : ClientCodegenDecorator { codegenContext: ClientCodegenContext, baseCustomizations: List, ): List { - return baseCustomizations + object : ServiceRuntimePluginCustomization() { - override fun section(section: ServiceRuntimePluginSection): Writable { - return when (section) { - is ServiceRuntimePluginSection.RegisterRuntimeComponents -> writable { - codegenContext.defaultEndpointResolver()?.also { resolver -> - section.registerEndpointResolver(this, resolver) - } - } + return baseCustomizations + + object : ServiceRuntimePluginCustomization() { + override fun section(section: ServiceRuntimePluginSection): Writable { + return when (section) { + is ServiceRuntimePluginSection.RegisterRuntimeComponents -> + writable { + codegenContext.defaultEndpointResolver()?.also { resolver -> + section.registerEndpointResolver(this, resolver) + } + } - else -> emptySection + else -> emptySection + } } } - } } - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { val generator = EndpointTypesGenerator.fromContext(codegenContext) rustCrate.withModule(ClientRustModule.Config.endpoint) { withInlineModule(endpointTestsModule(), rustCrate.moduleDocProvider) { @@ -146,7 +158,8 @@ class EndpointsDecorator : ClientCodegenDecorator { private fun ClientCodegenContext.defaultEndpointResolver(): Writable? { val generator = EndpointTypesGenerator.fromContext(this) val defaultResolver = generator.defaultResolver() ?: return null - val ctx = arrayOf("DefaultResolver" to defaultResolver, "ServiceSpecificResolver" to serviceSpecificEndpointResolver()) + val ctx = + arrayOf("DefaultResolver" to defaultResolver, "ServiceSpecificResolver" to serviceSpecificEndpointResolver()) return writable { rustTemplate( """{ diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/Util.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/Util.kt index 7d88408b738..24831d00272 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/Util.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/Util.kt @@ -49,7 +49,8 @@ fun Identifier.rustName(): String { */ object EndpointsLib { val DiagnosticCollector = endpointsLib("diagnostic").toType().resolve("DiagnosticCollector") - fun PartitionResolver(runtimeConfig: RuntimeConfig) = + + fun partitionResolver(runtimeConfig: RuntimeConfig) = endpointsLib("partition", CargoDependency.smithyJson(runtimeConfig), CargoDependency.RegexLite).toType() .resolve("PartitionResolver") @@ -63,7 +64,10 @@ object EndpointsLib { endpointsLib("s3", endpointsLib("host"), CargoDependency.OnceCell, CargoDependency.RegexLite).toType() .resolve("is_virtual_hostable_s3_bucket") - private fun endpointsLib(name: String, vararg additionalDependency: RustDependency) = InlineDependency.forRustFile( + private fun endpointsLib( + name: String, + vararg additionalDependency: RustDependency, + ) = InlineDependency.forRustFile( RustModule.pubCrate( name, parent = EndpointStdLib, @@ -81,13 +85,14 @@ class Types(runtimeConfig: RuntimeConfig) { private val endpointRtApi = RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::endpoint") val resolveEndpointError = smithyHttpEndpointModule.resolve("ResolveEndpointError") - fun toArray() = arrayOf( - "Endpoint" to smithyEndpoint, - "EndpointFuture" to endpointFuture, - "SharedEndpointResolver" to endpointRtApi.resolve("SharedEndpointResolver"), - "EndpointResolverParams" to endpointRtApi.resolve("EndpointResolverParams"), - "ResolveEndpoint" to endpointRtApi.resolve("ResolveEndpoint"), - ) + fun toArray() = + arrayOf( + "Endpoint" to smithyEndpoint, + "EndpointFuture" to endpointFuture, + "SharedEndpointResolver" to endpointRtApi.resolve("SharedEndpointResolver"), + "EndpointResolverParams" to endpointRtApi.resolve("EndpointResolverParams"), + "ResolveEndpoint" to endpointRtApi.resolve("ResolveEndpoint"), + ) } /** @@ -103,11 +108,12 @@ fun ContextParamTrait.memberName(): String = this.name.unsafeToRustName() * Returns the symbol for a given parameter. This enables [software.amazon.smithy.rust.codegen.core.rustlang.RustWriter] to generate the correct [RustType]. */ fun Parameter.symbol(): Symbol { - val rustType = when (this.type) { - ParameterType.STRING -> RustType.String - ParameterType.BOOLEAN -> RustType.Bool - else -> TODO("unexpected type: ${this.type}") - } + val rustType = + when (this.type) { + ParameterType.STRING -> RustType.String + ParameterType.BOOLEAN -> RustType.Bool + else -> TODO("unexpected type: ${this.type}") + } // Parameter return types are always optional return Symbol.builder().rustType(rustType).build().letIf(!this.isRequired) { it.makeOptional() } } @@ -126,10 +132,10 @@ class AuthSchemeLister : RuleValueVisitor> { return endpoint.properties.getOrDefault(Identifier.of("authSchemes"), Literal.tupleLiteral(listOf())) .asTupleLiteral() .orNull()?.let { - it.map { authScheme -> - authScheme.asRecordLiteral().get()[Identifier.of("name")]!!.asStringLiteral().get().expectLiteral() - } - }?.toHashSet() ?: hashSetOf() + it.map { authScheme -> + authScheme.asRecordLiteral().get()[Identifier.of("name")]!!.asStringLiteral().get().expectLiteral() + } + }?.toHashSet() ?: hashSetOf() } override fun visitTreeRule(rules: MutableList): Set { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsGenerator.kt index 57a8f31de85..caab3d1e14e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsGenerator.kt @@ -42,13 +42,14 @@ import software.amazon.smithy.rust.codegen.core.util.orNull // internals contains the actual resolver function fun endpointImplModule() = RustModule.private("internals", parent = ClientRustModule.Config.endpoint) -fun endpointTestsModule() = RustModule.new( - "test", - visibility = Visibility.PRIVATE, - parent = ClientRustModule.Config.endpoint, - inline = true, - documentationOverride = "", -).cfgTest() +fun endpointTestsModule() = + RustModule.new( + "test", + visibility = Visibility.PRIVATE, + parent = ClientRustModule.Config.endpoint, + inline = true, + documentationOverride = "", + ).cfgTest() // stdlib is isolated because it contains code generated names of stdlib functions–we want to ensure we avoid clashing val EndpointStdLib = RustModule.private("endpoint_lib") @@ -117,43 +118,47 @@ internal class EndpointParamsGenerator( ) { companion object { fun memberName(parameterName: String) = Identifier.of(parameterName).rustName() + fun setterName(parameterName: String) = "set_${memberName(parameterName)}" } - fun paramsStruct(): RuntimeType = RuntimeType.forInlineFun("Params", ClientRustModule.Config.endpoint) { - generateEndpointsStruct(this) - } + fun paramsStruct(): RuntimeType = + RuntimeType.forInlineFun("Params", ClientRustModule.Config.endpoint) { + generateEndpointsStruct(this) + } - internal fun paramsBuilder(): RuntimeType = RuntimeType.forInlineFun("ParamsBuilder", ClientRustModule.Config.endpoint) { - generateEndpointParamsBuilder(this) - } + internal fun paramsBuilder(): RuntimeType = + RuntimeType.forInlineFun("ParamsBuilder", ClientRustModule.Config.endpoint) { + generateEndpointParamsBuilder(this) + } - private fun paramsError(): RuntimeType = RuntimeType.forInlineFun("InvalidParams", ClientRustModule.Config.endpoint) { - rust( - """ - /// An error that occurred during endpoint resolution - ##[derive(Debug)] - pub struct InvalidParams { - field: std::borrow::Cow<'static, str> - } + private fun paramsError(): RuntimeType = + RuntimeType.forInlineFun("InvalidParams", ClientRustModule.Config.endpoint) { + rust( + """ + /// An error that occurred during endpoint resolution + ##[derive(Debug)] + pub struct InvalidParams { + field: std::borrow::Cow<'static, str> + } - impl InvalidParams { - ##[allow(dead_code)] - fn missing(field: &'static str) -> Self { - Self { field: field.into() } + impl InvalidParams { + ##[allow(dead_code)] + fn missing(field: &'static str) -> Self { + Self { field: field.into() } + } } - } - impl std::fmt::Display for InvalidParams { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "a required field was missing: `{}`", self.field) + impl std::fmt::Display for InvalidParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "a required field was missing: `{}`", self.field) + } } - } - impl std::error::Error for InvalidParams { } - """, - ) - } + impl std::error::Error for InvalidParams { } + """, + ) + } /** * Generates an endpoints struct based on the provided endpoint rules. The struct fields are `pub(crate)` @@ -165,11 +170,9 @@ internal class EndpointParamsGenerator( // Automatically implement standard Rust functionality Attribute(derive(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Clone)).render(writer) // Generate the struct block: - /* - pub struct Params { - ... members: pub(crate) field - } - */ + // pub struct Params { + // ... members: pub(crate) field + // } writer.docs("Configuration parameters for resolving the correct endpoint") writer.rustBlock("pub struct Params") { parameters.toList().forEach { parameter -> @@ -203,14 +206,15 @@ internal class EndpointParamsGenerator( """, "paramType" to type.makeOptional().mapRustType { t -> t.asDeref() }, - "param" to writable { - when { - type.isOptional() && type.rustType().isCopy() -> rust("self.$name") - type.isOptional() -> rust("self.$name.as_deref()") - type.rustType().isCopy() -> rust("Some(self.$name)") - else -> rust("Some(&self.$name)") - } - }, + "param" to + writable { + when { + type.isOptional() && type.rustType().isCopy() -> rust("self.$name") + type.isOptional() -> rust("self.$name.as_deref()") + type.rustType().isCopy() -> rust("Some(self.$name)") + else -> rust("Some(&self.$name)") + } + }, ) } } @@ -243,22 +247,23 @@ internal class EndpointParamsGenerator( "Params" to paramsStruct(), "ParamsError" to paramsError(), ) { - val params = writable { - Attribute.AllowClippyUnnecessaryLazyEvaluations.render(this) - rustBlockTemplate("#{Params}", "Params" to paramsStruct()) { - parameters.toList().forEach { parameter -> - rust("${parameter.memberName()}: self.${parameter.memberName()}") - parameter.default.orNull()?.also { default -> rust(".or_else(||Some(${value(default)}))") } - if (parameter.isRequired) { - rustTemplate( - ".ok_or_else(||#{Error}::missing(${parameter.memberName().dq()}))?", - "Error" to paramsError(), - ) + val params = + writable { + Attribute.AllowClippyUnnecessaryLazyEvaluations.render(this) + rustBlockTemplate("#{Params}", "Params" to paramsStruct()) { + parameters.toList().forEach { parameter -> + rust("${parameter.memberName()}: self.${parameter.memberName()}") + parameter.default.orNull()?.also { default -> rust(".or_else(||Some(${value(default)}))") } + if (parameter.isRequired) { + rustTemplate( + ".ok_or_else(||#{Error}::missing(${parameter.memberName().dq()}))?", + "Error" to paramsError(), + ) + } + rust(",") } - rust(",") } } - } rust("Ok(#W)", params) } parameters.toList().forEach { parameter -> @@ -282,15 +287,16 @@ internal class EndpointParamsGenerator( """, "nonOptionalType" to parameter.symbol().mapRustType { it.stripOuter() }, "type" to type, - "extraDocs" to writable { - if (parameter.default.isPresent || parameter.documentation.isPresent) { - docs("") - } - parameter.default.orNull()?.also { - docs("When unset, this parameter has a default value of `$it`.") - } - parameter.documentation.orNull()?.also { docs(it) } - }, + "extraDocs" to + writable { + if (parameter.default.isPresent || parameter.documentation.isPresent) { + docs("") + } + parameter.default.orNull()?.also { + docs("When unset, this parameter has a default value of `$it`.") + } + parameter.documentation.orNull()?.also { docs(it) } + }, ) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt index ec3b21417a1..458128842e3 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt @@ -41,31 +41,35 @@ class EndpointParamsInterceptorGenerator( private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider private val endpointTypesGenerator = EndpointTypesGenerator.fromContext(codegenContext) - private val codegenScope = codegenContext.runtimeConfig.let { rc -> - val endpointTypesGenerator = EndpointTypesGenerator.fromContext(codegenContext) - val runtimeApi = CargoDependency.smithyRuntimeApiClient(rc).toType() - val interceptors = runtimeApi.resolve("client::interceptors") - val orchestrator = runtimeApi.resolve("client::orchestrator") - arrayOf( - *preludeScope, - "BoxError" to RuntimeType.boxError(rc), - "ConfigBag" to RuntimeType.configBag(rc), - "ContextAttachedError" to interceptors.resolve("error::ContextAttachedError"), - "EndpointResolverParams" to runtimeApi.resolve("client::endpoint::EndpointResolverParams"), - "HttpRequest" to orchestrator.resolve("HttpRequest"), - "HttpResponse" to orchestrator.resolve("HttpResponse"), - "Intercept" to RuntimeType.intercept(rc), - "InterceptorContext" to RuntimeType.interceptorContext(rc), - "BeforeSerializationInterceptorContextRef" to RuntimeType.beforeSerializationInterceptorContextRef(rc), - "Input" to interceptors.resolve("context::Input"), - "Output" to interceptors.resolve("context::Output"), - "Error" to interceptors.resolve("context::Error"), - "InterceptorError" to interceptors.resolve("error::InterceptorError"), - "Params" to endpointTypesGenerator.paramsStruct(), - ) - } + private val codegenScope = + codegenContext.runtimeConfig.let { rc -> + val endpointTypesGenerator = EndpointTypesGenerator.fromContext(codegenContext) + val runtimeApi = CargoDependency.smithyRuntimeApiClient(rc).toType() + val interceptors = runtimeApi.resolve("client::interceptors") + val orchestrator = runtimeApi.resolve("client::orchestrator") + arrayOf( + *preludeScope, + "BoxError" to RuntimeType.boxError(rc), + "ConfigBag" to RuntimeType.configBag(rc), + "ContextAttachedError" to interceptors.resolve("error::ContextAttachedError"), + "EndpointResolverParams" to runtimeApi.resolve("client::endpoint::EndpointResolverParams"), + "HttpRequest" to orchestrator.resolve("HttpRequest"), + "HttpResponse" to orchestrator.resolve("HttpResponse"), + "Intercept" to RuntimeType.intercept(rc), + "InterceptorContext" to RuntimeType.interceptorContext(rc), + "BeforeSerializationInterceptorContextRef" to RuntimeType.beforeSerializationInterceptorContextRef(rc), + "Input" to interceptors.resolve("context::Input"), + "Output" to interceptors.resolve("context::Output"), + "Error" to interceptors.resolve("context::Error"), + "InterceptorError" to interceptors.resolve("error::InterceptorError"), + "Params" to endpointTypesGenerator.paramsStruct(), + ) + } - fun render(writer: RustWriter, operationShape: OperationShape) { + fun render( + writer: RustWriter, + operationShape: OperationShape, + ) { val operationName = symbolProvider.toSymbol(operationShape).name val operationInput = symbolProvider.toSymbol(operationShape.inputShape(model)) val interceptorName = "${operationName}EndpointParamsInterceptor" @@ -105,7 +109,10 @@ class EndpointParamsInterceptorGenerator( ) } - private fun paramSetters(operationShape: OperationShape, params: Parameters) = writable { + private fun paramSetters( + operationShape: OperationShape, + params: Parameters, + ) = writable { val idx = ContextIndex.of(codegenContext.model) val memberParams = idx.getContextParams(operationShape).toList().sortedBy { it.first.memberName } val builtInParams = params.toList().filter { it.isBuiltIn } @@ -154,26 +161,28 @@ class EndpointParamsInterceptorGenerator( } } - private fun endpointPrefix(operationShape: OperationShape): Writable = writable { - operationShape.getTrait(EndpointTrait::class.java).map { epTrait -> - val endpointTraitBindings = EndpointTraitBindings( - codegenContext.model, - symbolProvider, - codegenContext.runtimeConfig, - operationShape, - epTrait, - ) - withBlockTemplate( - "let endpoint_prefix = ", - """.map_err(|err| #{ContextAttachedError}::new("endpoint prefix could not be built", err))?;""", - *codegenScope, - ) { - endpointTraitBindings.render( - this, - "_input", - ) + private fun endpointPrefix(operationShape: OperationShape): Writable = + writable { + operationShape.getTrait(EndpointTrait::class.java).map { epTrait -> + val endpointTraitBindings = + EndpointTraitBindings( + codegenContext.model, + symbolProvider, + codegenContext.runtimeConfig, + operationShape, + epTrait, + ) + withBlockTemplate( + "let endpoint_prefix = ", + """.map_err(|err| #{ContextAttachedError}::new("endpoint prefix could not be built", err))?;""", + *codegenScope, + ) { + endpointTraitBindings.render( + this, + "_input", + ) + } + rust("cfg.interceptor_state().store_put(endpoint_prefix);") } - rust("cfg.interceptor_state().store_put(endpoint_prefix);") } - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointResolverGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointResolverGenerator.kt index 174bf8f5861..9aa441a6316 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointResolverGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointResolverGenerator.kt @@ -80,6 +80,7 @@ abstract class CustomRuntimeFunction { class FunctionRegistry(private val functions: List) { private var usedFunctions = mutableSetOf() + fun fnFor(id: String): CustomRuntimeFunction? = functions.firstOrNull { it.id == id }?.also { usedFunctions.add(it) } @@ -129,29 +130,31 @@ internal class EndpointResolverGenerator( private val runtimeConfig = codegenContext.runtimeConfig private val registry: FunctionRegistry = FunctionRegistry(stdlib) private val types = Types(runtimeConfig) - private val codegenScope = arrayOf( - "BoxError" to RuntimeType.boxError(runtimeConfig), - "endpoint" to types.smithyHttpEndpointModule, - "SmithyEndpoint" to types.smithyEndpoint, - "EndpointFuture" to types.endpointFuture, - "ResolveEndpointError" to types.resolveEndpointError, - "EndpointError" to types.resolveEndpointError, - "ServiceSpecificEndpointResolver" to codegenContext.serviceSpecificEndpointResolver(), - "DiagnosticCollector" to EndpointsLib.DiagnosticCollector, - ) - - private val allowLintsForResolver = listOf( - // we generate if x { if y { if z { ... } } } - "clippy::collapsible_if", - // we generate `if (true) == expr { ... }` - "clippy::bool_comparison", - // we generate `if !(a == b)` - "clippy::nonminimal_bool", - // we generate `if x == "" { ... }` - "clippy::comparison_to_empty", - // we generate `if let Some(_) = ... { ... }` - "clippy::redundant_pattern_matching", - ) + private val codegenScope = + arrayOf( + "BoxError" to RuntimeType.boxError(runtimeConfig), + "endpoint" to types.smithyHttpEndpointModule, + "SmithyEndpoint" to types.smithyEndpoint, + "EndpointFuture" to types.endpointFuture, + "ResolveEndpointError" to types.resolveEndpointError, + "EndpointError" to types.resolveEndpointError, + "ServiceSpecificEndpointResolver" to codegenContext.serviceSpecificEndpointResolver(), + "DiagnosticCollector" to EndpointsLib.DiagnosticCollector, + ) + + private val allowLintsForResolver = + listOf( + // we generate if x { if y { if z { ... } } } + "clippy::collapsible_if", + // we generate `if (true) == expr { ... }` + "clippy::bool_comparison", + // we generate `if !(a == b)` + "clippy::nonminimal_bool", + // we generate `if x == "" { ... }` + "clippy::comparison_to_empty", + // we generate `if let Some(_) = ... { ... }` + "clippy::redundant_pattern_matching", + ) private val context = Context(registry, runtimeConfig) companion object { @@ -234,36 +237,40 @@ internal class EndpointResolverGenerator( } } - private fun resolverFnBody(endpointRuleSet: EndpointRuleSet) = writable { - endpointRuleSet.parameters.toList().forEach { - Attribute.AllowUnusedVariables.render(this) - rust("let ${it.memberName()} = &$ParamsName.${it.memberName()};") + private fun resolverFnBody(endpointRuleSet: EndpointRuleSet) = + writable { + endpointRuleSet.parameters.toList().forEach { + Attribute.AllowUnusedVariables.render(this) + rust("let ${it.memberName()} = &$ParamsName.${it.memberName()};") + } + generateRulesList(endpointRuleSet.rules)(this) } - generateRulesList(endpointRuleSet.rules)(this) - } - private fun generateRulesList(rules: List) = writable { - rules.forEach { rule -> - rule.documentation.orNull()?.also { comment(escape(it)) } - generateRule(rule)(this) - } - if (!isExhaustive(rules.last())) { - // it's hard to figure out if these are always needed or not - Attribute.AllowUnreachableCode.render(this) - rustTemplate( - """return Err(#{EndpointError}::message(format!("No rules matched these parameters. This is a bug. {:?}", $ParamsName)));""", - *codegenScope, - ) + private fun generateRulesList(rules: List) = + writable { + rules.forEach { rule -> + rule.documentation.orNull()?.also { comment(escape(it)) } + generateRule(rule)(this) + } + if (!isExhaustive(rules.last())) { + // it's hard to figure out if these are always needed or not + Attribute.AllowUnreachableCode.render(this) + rustTemplate( + """return Err(#{EndpointError}::message(format!("No rules matched these parameters. This is a bug. {:?}", $ParamsName)));""", + *codegenScope, + ) + } } - } - private fun isExhaustive(rule: Rule): Boolean = rule.conditions.isEmpty() || rule.conditions.all { - when (it.function.type()) { - is BooleanType -> false - is OptionalType -> false - else -> true - } - } + private fun isExhaustive(rule: Rule): Boolean = + rule.conditions.isEmpty() || + rule.conditions.all { + when (it.function.type()) { + is BooleanType -> false + is OptionalType -> false + else -> true + } + } private fun generateRule(rule: Rule): Writable { return generateRuleInternal(rule, rule.conditions) @@ -284,7 +291,10 @@ internal class EndpointResolverGenerator( * * The resulting generated code is a series of nested-if statements, nesting each condition inside the previous. */ - private fun generateRuleInternal(rule: Rule, conditions: List): Writable { + private fun generateRuleInternal( + rule: Rule, + conditions: List, + ): Writable { if (conditions.isEmpty()) { return rule.accept(RuleVisitor()) } else { @@ -322,11 +332,12 @@ internal class EndpointResolverGenerator( "target" to target, "next" to next, // handle the rare but possible case where we bound the name of a variable to a boolean condition - "binding" to writable { - if (resultName != "_") { - rust("let $resultName = true;") - } - }, + "binding" to + writable { + if (resultName != "_") { + rust("let $resultName = true;") + } + }, ) } @@ -349,17 +360,19 @@ internal class EndpointResolverGenerator( inner class RuleVisitor : RuleValueVisitor { override fun visitTreeRule(rules: List) = generateRulesList(rules) - override fun visitErrorRule(error: Expression) = writable { - rustTemplate( - "return Err(#{EndpointError}::message(#{message:W}));", - *codegenScope, - "message" to ExpressionGenerator(Ownership.Owned, context).generate(error), - ) - } + override fun visitErrorRule(error: Expression) = + writable { + rustTemplate( + "return Err(#{EndpointError}::message(#{message:W}));", + *codegenScope, + "message" to ExpressionGenerator(Ownership.Owned, context).generate(error), + ) + } - override fun visitEndpointRule(endpoint: Endpoint): Writable = writable { - rust("return Ok(#W);", generateEndpoint(endpoint)) - } + override fun visitEndpointRule(endpoint: Endpoint): Writable = + writable { + rust("return Ok(#W);", generateEndpoint(endpoint)) + } } /** @@ -382,7 +395,8 @@ internal class EndpointResolverGenerator( fun ClientCodegenContext.serviceSpecificEndpointResolver(): RuntimeType { val generator = EndpointTypesGenerator.fromContext(this) return RuntimeType.forInlineFun("ResolveEndpoint", ClientRustModule.Config.endpoint) { - val ctx = arrayOf(*preludeScope, "Params" to generator.paramsStruct(), *Types(runtimeConfig).toArray(), "Debug" to RuntimeType.Debug) + val ctx = + arrayOf(*preludeScope, "Params" to generator.paramsStruct(), *Types(runtimeConfig).toArray(), "Debug" to RuntimeType.Debug) rustTemplate( """ /// Endpoint resolver trait specific to ${serviceShape.serviceNameOrDefault("this service")} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointTestGenerator.kt index d14c0f3d224..3e5f0ef3352 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointTestGenerator.kt @@ -41,13 +41,14 @@ internal class EndpointTestGenerator( ) { private val runtimeConfig = codegenContext.runtimeConfig private val types = Types(runtimeConfig) - private val codegenScope = arrayOf( - "Endpoint" to types.smithyEndpoint, - "Error" to types.resolveEndpointError, - "Document" to RuntimeType.document(runtimeConfig), - "HashMap" to RuntimeType.HashMap, - "capture_request" to RuntimeType.captureRequest(runtimeConfig), - ) + private val codegenScope = + arrayOf( + "Endpoint" to types.smithyEndpoint, + "Error" to types.resolveEndpointError, + "Document" to RuntimeType.document(runtimeConfig), + "HashMap" to RuntimeType.HashMap, + "capture_request" to RuntimeType.captureRequest(runtimeConfig), + ) private val instantiator = ClientInstantiator(codegenContext) @@ -56,64 +57,71 @@ internal class EndpointTestGenerator( return writable { docs(self.documentation.orElse("no docs")) } } - private fun generateBaseTest(testCase: EndpointTestCase, id: Int): Writable = writable { - rustTemplate( - """ - #{docs:W} - ##[test] - fn test_$id() { - let params = #{params:W}; - let resolver = #{resolver}::new(); - let endpoint = resolver.resolve_endpoint(¶ms); - #{assertion:W} - } - """, - *codegenScope, - "docs" to testCase.docs(), - "params" to params(testCase), - "resolver" to resolverType, - "assertion" to writable { - testCase.expect.endpoint.ifPresent { endpoint -> - rustTemplate( - """ - let endpoint = endpoint.expect("Expected valid endpoint: ${escape(endpoint.url)}"); - assert_eq!(endpoint, #{expected:W}); - """, - *codegenScope, "expected" to generateEndpoint(endpoint), - ) + private fun generateBaseTest( + testCase: EndpointTestCase, + id: Int, + ): Writable = + writable { + rustTemplate( + """ + #{docs:W} + ##[test] + fn test_$id() { + let params = #{params:W}; + let resolver = #{resolver}::new(); + let endpoint = resolver.resolve_endpoint(¶ms); + #{assertion:W} } - testCase.expect.error.ifPresent { error -> - val expectedError = - escape("expected error: $error [${testCase.documentation.orNull() ?: "no docs"}]") - rustTemplate( - """ - let error = endpoint.expect_err(${expectedError.dq()}); - assert_eq!(format!("{}", error), ${escape(error).dq()}) - """, - *codegenScope, - ) - } - }, - ) - } + """, + *codegenScope, + "docs" to testCase.docs(), + "params" to params(testCase), + "resolver" to resolverType, + "assertion" to + writable { + testCase.expect.endpoint.ifPresent { endpoint -> + rustTemplate( + """ + let endpoint = endpoint.expect("Expected valid endpoint: ${escape(endpoint.url)}"); + assert_eq!(endpoint, #{expected:W}); + """, + *codegenScope, "expected" to generateEndpoint(endpoint), + ) + } + testCase.expect.error.ifPresent { error -> + val expectedError = + escape("expected error: $error [${testCase.documentation.orNull() ?: "no docs"}]") + rustTemplate( + """ + let error = endpoint.expect_err(${expectedError.dq()}); + assert_eq!(format!("{}", error), ${escape(error).dq()}) + """, + *codegenScope, + ) + } + }, + ) + } - fun generate(): Writable = writable { - var id = 0 - testCases.forEach { testCase -> - id += 1 - generateBaseTest(testCase, id)(this) + fun generate(): Writable = + writable { + var id = 0 + testCases.forEach { testCase -> + id += 1 + generateBaseTest(testCase, id)(this) + } } - } - private fun params(testCase: EndpointTestCase) = writable { - rust("#T::builder()", paramsType) - testCase.params.members.forEach { (id, value) -> - if (params.get(Identifier.of(id)).isPresent) { - rust(".${Identifier.of(id).rustName()}(#W)", generateValue(Value.fromNode(value))) + private fun params(testCase: EndpointTestCase) = + writable { + rust("#T::builder()", paramsType) + testCase.params.members.forEach { (id, value) -> + if (params.get(Identifier.of(id)).isPresent) { + rust(".${Identifier.of(id).rustName()}(#W)", generateValue(Value.fromNode(value))) + } } + rust(""".build().expect("invalid params")""") } - rust(""".build().expect("invalid params")""") - } private fun generateValue(value: Value): Writable { return { @@ -161,19 +169,20 @@ internal class EndpointTestGenerator( } } - private fun generateEndpoint(value: ExpectedEndpoint) = writable { - rustTemplate("#{Endpoint}::builder().url(${escape(value.url).dq()})", *codegenScope) - value.headers.forEach { (headerName, values) -> - values.forEach { headerValue -> - rust(".header(${headerName.dq()}, ${headerValue.dq()})") + private fun generateEndpoint(value: ExpectedEndpoint) = + writable { + rustTemplate("#{Endpoint}::builder().url(${escape(value.url).dq()})", *codegenScope) + value.headers.forEach { (headerName, values) -> + values.forEach { headerValue -> + rust(".header(${headerName.dq()}, ${headerValue.dq()})") + } } + value.properties.forEach { (name, value) -> + rust( + ".property(${name.dq()}, #W)", + generateValue(Value.fromNode(value)), + ) + } + rust(".build()") } - value.properties.forEach { (name, value) -> - rust( - ".property(${name.dq()}, #W)", - generateValue(Value.fromNode(value)), - ) - } - rust(".build()") - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGenerator.kt index 562ba8e2018..6474d866b94 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGenerator.kt @@ -32,7 +32,6 @@ class ExpressionGenerator( private val ownership: Ownership, private val context: Context, ) { - @Contract(pure = true) fun generate(expr: Expression): Writable { return expr.accept(ExprGeneratorVisitor(ownership)) @@ -45,85 +44,99 @@ class ExpressionGenerator( private val ownership: Ownership, ) : ExpressionVisitor { - override fun visitLiteral(literal: Literal): Writable { return literal.accept(LiteralGenerator(ownership, context)) } - override fun visitRef(ref: Reference) = writable { - if (ownership == Ownership.Owned) { - when (ref.type()) { - is BooleanType -> rust("*${ref.name.rustName()}") - else -> rust("${ref.name.rustName()}.to_owned()") + override fun visitRef(ref: Reference) = + writable { + if (ownership == Ownership.Owned) { + when (ref.type()) { + is BooleanType -> rust("*${ref.name.rustName()}") + else -> rust("${ref.name.rustName()}.to_owned()") + } + } else { + rust(ref.name.rustName()) } - } else { - rust(ref.name.rustName()) } - } override fun visitGetAttr(getAttr: GetAttr): Writable { val target = ExpressionGenerator(Ownership.Borrowed, context).generate(getAttr.target) - val path = writable { - getAttr.path.toList().forEach { part -> - when (part) { - is GetAttr.Part.Key -> rust(".${part.key().rustName()}()") - is GetAttr.Part.Index -> { - if (part.index() == 0) { - // In this case, `.first()` is more idiomatic and `.get(0)` triggers lint warnings - rust(".first().cloned()") - } else { - rust(".get(${part.index()}).cloned()") // we end up with Option<&&T>, we need to get to Option<&T> + val path = + writable { + getAttr.path.toList().forEach { part -> + when (part) { + is GetAttr.Part.Key -> rust(".${part.key().rustName()}()") + is GetAttr.Part.Index -> { + if (part.index() == 0) { + // In this case, `.first()` is more idiomatic and `.get(0)` triggers lint warnings + rust(".first().cloned()") + } else { + rust(".get(${part.index()}).cloned()") // we end up with Option<&&T>, we need to get to Option<&T> + } } } } - } - if (ownership == Ownership.Owned && getAttr.type() != Type.booleanType()) { - if (getAttr.type() is OptionalType) { - rust(".map(|t|t.to_owned())") - } else { - rust(".to_owned()") + if (ownership == Ownership.Owned && getAttr.type() != Type.booleanType()) { + if (getAttr.type() is OptionalType) { + rust(".map(|t|t.to_owned())") + } else { + rust(".to_owned()") + } } } - } return writable { rust("#W#W", target, path) } } - override fun visitIsSet(fn: Expression) = writable { - val expressionGenerator = ExpressionGenerator(Ownership.Borrowed, context) - rust("#W.is_some()", expressionGenerator.generate(fn)) - } + override fun visitIsSet(fn: Expression) = + writable { + val expressionGenerator = ExpressionGenerator(Ownership.Borrowed, context) + rust("#W.is_some()", expressionGenerator.generate(fn)) + } - override fun visitNot(not: Expression) = writable { - rust("!(#W)", ExpressionGenerator(Ownership.Borrowed, context).generate(not)) - } + override fun visitNot(not: Expression) = + writable { + rust("!(#W)", ExpressionGenerator(Ownership.Borrowed, context).generate(not)) + } - override fun visitBoolEquals(left: Expression, right: Expression) = writable { + override fun visitBoolEquals( + left: Expression, + right: Expression, + ) = writable { val expressionGenerator = ExpressionGenerator(Ownership.Owned, context) rust("(#W) == (#W)", expressionGenerator.generate(left), expressionGenerator.generate(right)) } - override fun visitStringEquals(left: Expression, right: Expression) = writable { + override fun visitStringEquals( + left: Expression, + right: Expression, + ) = writable { val expressionGenerator = ExpressionGenerator(Ownership.Borrowed, context) rust("(#W) == (#W)", expressionGenerator.generate(left), expressionGenerator.generate(right)) } - override fun visitLibraryFunction(fn: FunctionDefinition, args: MutableList): Writable = writable { - val fnDefinition = context.functionRegistry.fnFor(fn.id) - ?: PANIC( - "no runtime function for ${fn.id} " + - "(hint: if this is a custom or aws-specific runtime function, ensure the relevant standard library has been loaded " + - "on the classpath)", + override fun visitLibraryFunction( + fn: FunctionDefinition, + args: MutableList, + ): Writable = + writable { + val fnDefinition = + context.functionRegistry.fnFor(fn.id) + ?: PANIC( + "no runtime function for ${fn.id} " + + "(hint: if this is a custom or aws-specific runtime function, ensure the relevant standard library has been loaded " + + "on the classpath)", + ) + val expressionGenerator = ExpressionGenerator(Ownership.Borrowed, context) + val argWritables = args.map { expressionGenerator.generate(it) } + rustTemplate( + "#{fn}(#{args}, ${EndpointResolverGenerator.DiagnosticCollector})", + "fn" to fnDefinition.usage(), + "args" to argWritables.join(","), ) - val expressionGenerator = ExpressionGenerator(Ownership.Borrowed, context) - val argWritables = args.map { expressionGenerator.generate(it) } - rustTemplate( - "#{fn}(#{args}, ${EndpointResolverGenerator.DiagnosticCollector})", - "fn" to fnDefinition.usage(), - "args" to argWritables.join(","), - ) - if (ownership == Ownership.Owned) { - rust(".to_owned()") + if (ownership == Ownership.Owned) { + rust(".to_owned()") + } } - } } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/LiteralGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/LiteralGenerator.kt index a986c50e6da..29337465c59 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/LiteralGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/LiteralGenerator.kt @@ -27,59 +27,69 @@ import java.util.stream.Stream class LiteralGenerator(private val ownership: Ownership, private val context: Context) : LiteralVisitor { private val runtimeConfig = context.runtimeConfig - private val codegenScope = arrayOf( - "Document" to RuntimeType.document(runtimeConfig), - "HashMap" to RuntimeType.HashMap, - ) - override fun visitBoolean(b: Boolean) = writable { - rust(b.toString()) - } - - override fun visitString(value: Template) = writable { - val parts: Stream = value.accept( - TemplateGenerator(ownership) { expr, ownership -> - ExpressionGenerator(ownership, context).generate(expr) - }, + private val codegenScope = + arrayOf( + "Document" to RuntimeType.document(runtimeConfig), + "HashMap" to RuntimeType.HashMap, ) - parts.forEach { part -> part(this) } - } - override fun visitRecord(members: MutableMap) = writable { - rustBlock("") { - rustTemplate( - "let mut out = #{HashMap}::::new();", - *codegenScope, - ) - members.keys.sortedBy { it.toString() }.map { k -> k to members[k]!! }.forEach { (identifier, literal) -> - rust( - "out.insert(${identifier.toString().dq()}.to_string(), #W.into());", - // When writing into the hashmap, it always needs to be an owned type - ExpressionGenerator(Ownership.Owned, context).generate(literal), + override fun visitBoolean(b: Boolean) = + writable { + rust(b.toString()) + } + + override fun visitString(value: Template) = + writable { + val parts: Stream = + value.accept( + TemplateGenerator(ownership) { expr, ownership -> + ExpressionGenerator(ownership, context).generate(expr) + }, ) - } - rustTemplate("out") + parts.forEach { part -> part(this) } } - } - override fun visitTuple(members: MutableList) = writable { - rustTemplate( - "vec![#{inner:W}]", *codegenScope, - "inner" to writable { - members.forEach { literal -> - rustTemplate( - "#{Document}::from(#{literal:W}),", - *codegenScope, - "literal" to ExpressionGenerator( - Ownership.Owned, - context, - ).generate(literal), + override fun visitRecord(members: MutableMap) = + writable { + rustBlock("") { + rustTemplate( + "let mut out = #{HashMap}::::new();", + *codegenScope, + ) + members.keys.sortedBy { it.toString() }.map { k -> k to members[k]!! }.forEach { (identifier, literal) -> + rust( + "out.insert(${identifier.toString().dq()}.to_string(), #W.into());", + // When writing into the hashmap, it always needs to be an owned type + ExpressionGenerator(Ownership.Owned, context).generate(literal), ) } - }, - ) - } + rustTemplate("out") + } + } - override fun visitInteger(value: Int) = writable { - rust("$value") - } + override fun visitTuple(members: MutableList) = + writable { + rustTemplate( + "vec![#{inner:W}]", *codegenScope, + "inner" to + writable { + members.forEach { literal -> + rustTemplate( + "#{Document}::from(#{literal:W}),", + *codegenScope, + "literal" to + ExpressionGenerator( + Ownership.Owned, + context, + ).generate(literal), + ) + } + }, + ) + } + + override fun visitInteger(value: Int) = + writable { + rust("$value") + } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/StdLib.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/StdLib.kt index 8bbb3cd9629..f9a55ae8479 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/StdLib.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/StdLib.kt @@ -21,19 +21,23 @@ import software.amazon.smithy.rust.codegen.core.util.dq /** * Standard library functions available to all generated crates (e.g. not `aws.` specific / prefixed) */ -internal val SmithyEndpointsStdLib: List = listOf( - SimpleRuntimeFunction("substring", EndpointsLib.substring), - SimpleRuntimeFunction("isValidHostLabel", EndpointsLib.isValidHostLabel), - SimpleRuntimeFunction("parseURL", EndpointsLib.parseUrl), - SimpleRuntimeFunction("uriEncode", EndpointsLib.uriEncode), -) +internal val SmithyEndpointsStdLib: List = + listOf( + SimpleRuntimeFunction("substring", EndpointsLib.substring), + SimpleRuntimeFunction("isValidHostLabel", EndpointsLib.isValidHostLabel), + SimpleRuntimeFunction("parseURL", EndpointsLib.parseUrl), + SimpleRuntimeFunction("uriEncode", EndpointsLib.uriEncode), + ) /** * AWS Standard library functions * * This is defined in client-codegen to support running tests—it is not used when generating smithy-native services. */ -fun awsStandardLib(runtimeConfig: RuntimeConfig, partitionsDotJson: Node) = listOf( +fun awsStandardLib( + runtimeConfig: RuntimeConfig, + partitionsDotJson: Node, +) = listOf( SimpleRuntimeFunction("aws.parseArn", EndpointsLib.awsParseArn), SimpleRuntimeFunction("aws.isVirtualHostableS3Bucket", EndpointsLib.awsIsVirtualHostableS3Bucket), AwsPartitionResolver(runtimeConfig, partitionsDotJson), @@ -47,45 +51,52 @@ fun awsStandardLib(runtimeConfig: RuntimeConfig, partitionsDotJson: Node) = list class AwsPartitionResolver(runtimeConfig: RuntimeConfig, private val partitionsDotJson: Node) : CustomRuntimeFunction() { override val id: String = "aws.partition" - private val codegenScope = arrayOf( - "PartitionResolver" to EndpointsLib.PartitionResolver(runtimeConfig), - "Lazy" to CargoDependency.OnceCell.toType().resolve("sync::Lazy"), - ) - - override fun structFieldInit() = writable { - val json = Node.printJson(partitionsDotJson).dq() - rustTemplate( - """partition_resolver: #{DEFAULT_PARTITION_RESOLVER}.clone()""", - *codegenScope, - "DEFAULT_PARTITION_RESOLVER" to RuntimeType.forInlineFun("DEFAULT_PARTITION_RESOLVER", EndpointStdLib) { - rustTemplate( - """ - // Loading the partition JSON is expensive since it involves many regex compilations, - // so cache the result so that it only need to be paid for the first constructed client. - pub(crate) static DEFAULT_PARTITION_RESOLVER: #{Lazy}<#{PartitionResolver}> = - #{Lazy}::new(|| #{PartitionResolver}::new_from_json(b$json).expect("valid JSON")); - """, - *codegenScope, - ) - }, + private val codegenScope = + arrayOf( + "PartitionResolver" to EndpointsLib.partitionResolver(runtimeConfig), + "Lazy" to CargoDependency.OnceCell.toType().resolve("sync::Lazy"), ) - } - override fun additionalArgsSignature(): Writable = writable { - rustTemplate("partition_resolver: &#{PartitionResolver}", *codegenScope) - } + override fun structFieldInit() = + writable { + val json = Node.printJson(partitionsDotJson).dq() + rustTemplate( + """partition_resolver: #{DEFAULT_PARTITION_RESOLVER}.clone()""", + *codegenScope, + "DEFAULT_PARTITION_RESOLVER" to + RuntimeType.forInlineFun("DEFAULT_PARTITION_RESOLVER", EndpointStdLib) { + rustTemplate( + """ + // Loading the partition JSON is expensive since it involves many regex compilations, + // so cache the result so that it only need to be paid for the first constructed client. + pub(crate) static DEFAULT_PARTITION_RESOLVER: #{Lazy}<#{PartitionResolver}> = + #{Lazy}::new(|| #{PartitionResolver}::new_from_json(b$json).expect("valid JSON")); + """, + *codegenScope, + ) + }, + ) + } + + override fun additionalArgsSignature(): Writable = + writable { + rustTemplate("partition_resolver: &#{PartitionResolver}", *codegenScope) + } - override fun additionalArgsInvocation(self: String) = writable { - rust("&$self.partition_resolver") - } + override fun additionalArgsInvocation(self: String) = + writable { + rust("&$self.partition_resolver") + } - override fun structField(): Writable = writable { - rustTemplate("partition_resolver: #{PartitionResolver}", *codegenScope) - } + override fun structField(): Writable = + writable { + rustTemplate("partition_resolver: #{PartitionResolver}", *codegenScope) + } - override fun usage() = writable { - rust("partition_resolver.resolve_partition") - } + override fun usage() = + writable { + rust("partition_resolver.resolve_partition") + } } /** diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/TemplateGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/TemplateGenerator.kt index 29b8a8d50f1..de8375444af 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/TemplateGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/TemplateGenerator.kt @@ -34,40 +34,45 @@ class TemplateGenerator( private val ownership: Ownership, private val exprGenerator: (Expression, Ownership) -> Writable, ) : TemplateVisitor { - override fun visitStaticTemplate(value: String) = writable { - // In the case of a static template, return the literal string, eg. `"foo"`. - rust(value.dq()) - if (ownership == Ownership.Owned) { - rust(".to_string()") + override fun visitStaticTemplate(value: String) = + writable { + // In the case of a static template, return the literal string, eg. `"foo"`. + rust(value.dq()) + if (ownership == Ownership.Owned) { + rust(".to_string()") + } } - } override fun visitSingleDynamicTemplate(expr: Expression): Writable { return exprGenerator(expr, ownership) } - override fun visitStaticElement(str: String) = writable { - when (str.length) { - 0 -> {} - 1 -> rust("out.push('$str');") - else -> rust("out.push_str(${str.dq()});") + override fun visitStaticElement(str: String) = + writable { + when (str.length) { + 0 -> {} + 1 -> rust("out.push('$str');") + else -> rust("out.push_str(${str.dq()});") + } } - } - override fun visitDynamicElement(expr: Expression) = writable { - // we don't need to own the argument to push_str - Attribute.AllowClippyNeedlessBorrow.render(this) - rust("out.push_str(&#W);", exprGenerator(expr, Ownership.Borrowed)) - } + override fun visitDynamicElement(expr: Expression) = + writable { + // we don't need to own the argument to push_str + Attribute.AllowClippyNeedlessBorrow.render(this) + rust("out.push_str(&#W);", exprGenerator(expr, Ownership.Borrowed)) + } - override fun startMultipartTemplate() = writable { - if (ownership == Ownership.Borrowed) { - rust("&") + override fun startMultipartTemplate() = + writable { + if (ownership == Ownership.Borrowed) { + rust("&") + } + rust("{ let mut out = String::new();") } - rust("{ let mut out = String::new();") - } - override fun finishMultipartTemplate() = writable { - rust(" out }") - } + override fun finishMultipartTemplate() = + writable { + rust(" out }") + } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/AuthOptionsPluginGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/AuthOptionsPluginGenerator.kt index 2fd61b38062..e58352a48ab 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/AuthOptionsPluginGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/AuthOptionsPluginGenerator.kt @@ -25,7 +25,11 @@ class AuthOptionsPluginGenerator(private val codegenContext: ClientCodegenContex private val runtimeConfig = codegenContext.runtimeConfig private val logger: Logger = Logger.getLogger(javaClass.name) - fun authPlugin(operationShape: OperationShape, authSchemeOptions: List) = writable { + + fun authPlugin( + operationShape: OperationShape, + authSchemeOptions: List, + ) = writable { rustTemplate( """ #{DefaultAuthOptionsPlugin}::new(vec![#{options}]) @@ -36,22 +40,27 @@ class AuthOptionsPluginGenerator(private val codegenContext: ClientCodegenContex ) } - private fun actualAuthSchemes(operationShape: OperationShape, authSchemeOptions: List): List { + private fun actualAuthSchemes( + operationShape: OperationShape, + authSchemeOptions: List, + ): List { val out: MutableList = mutableListOf() var noSupportedAuthSchemes = true - val authSchemes = ServiceIndex.of(codegenContext.model) - .getEffectiveAuthSchemes(codegenContext.serviceShape, operationShape) + val authSchemes = + ServiceIndex.of(codegenContext.model) + .getEffectiveAuthSchemes(codegenContext.serviceShape, operationShape) for (schemeShapeId in authSchemes.keys) { - val optionsForScheme = authSchemeOptions.filter { - when (it) { - is AuthSchemeOption.CustomResolver -> false - is AuthSchemeOption.StaticAuthSchemeOption -> { - it.schemeShapeId == schemeShapeId + val optionsForScheme = + authSchemeOptions.filter { + when (it) { + is AuthSchemeOption.CustomResolver -> false + is AuthSchemeOption.StaticAuthSchemeOption -> { + it.schemeShapeId == schemeShapeId + } } } - } if (optionsForScheme.isNotEmpty()) { out.addAll(optionsForScheme.flatMap { (it as AuthSchemeOption.StaticAuthSchemeOption).constructor }) @@ -64,10 +73,11 @@ class AuthOptionsPluginGenerator(private val codegenContext: ClientCodegenContex } } if (operationShape.hasTrait() || noSupportedAuthSchemes) { - val authOption = authSchemeOptions.find { - it is AuthSchemeOption.StaticAuthSchemeOption && it.schemeShapeId == noAuthSchemeShapeId - } - ?: throw IllegalStateException("Missing 'no auth' implementation. This is a codegen bug.") + val authOption = + authSchemeOptions.find { + it is AuthSchemeOption.StaticAuthSchemeOption && it.schemeShapeId == noAuthSchemeShapeId + } + ?: throw IllegalStateException("Missing 'no auth' implementation. This is a codegen bug.") out += (authOption as AuthSchemeOption.StaticAuthSchemeOption).constructor } if (out.any { it.isEmpty() }) { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientBuilderInstantiator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientBuilderInstantiator.kt index 453568c8c57..15911fbcd0b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientBuilderInstantiator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientBuilderInstantiator.kt @@ -17,36 +17,46 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerat import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator class ClientBuilderInstantiator(private val clientCodegenContext: ClientCodegenContext) : BuilderInstantiator { - override fun setField(builder: String, value: Writable, field: MemberShape): Writable { + override fun setField( + builder: String, + value: Writable, + field: MemberShape, + ): Writable { return setFieldWithSetter(builder, value, field) } /** * For the client, we finalize builders with error correction enabled */ - override fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable?): Writable = writable { - val correctErrors = clientCodegenContext.correctErrors(shape) - val builderW = writable { - when { - correctErrors != null -> rustTemplate("#{correctErrors}($builder)", "correctErrors" to correctErrors) - else -> rustTemplate(builder) - } - } - if (BuilderGenerator.hasFallibleBuilder(shape, clientCodegenContext.symbolProvider)) { - rustTemplate( - "#{builder}.build()#{mapErr}", - "builder" to builderW, - "mapErr" to ( - mapErr?.map { - rust(".map_err(#T)?", it) - } ?: writable { } + override fun finalizeBuilder( + builder: String, + shape: StructureShape, + mapErr: Writable?, + ): Writable = + writable { + val correctErrors = clientCodegenContext.correctErrors(shape) + val builderW = + writable { + when { + correctErrors != null -> rustTemplate("#{correctErrors}($builder)", "correctErrors" to correctErrors) + else -> rustTemplate(builder) + } + } + if (BuilderGenerator.hasFallibleBuilder(shape, clientCodegenContext.symbolProvider)) { + rustTemplate( + "#{builder}.build()#{mapErr}", + "builder" to builderW, + "mapErr" to ( + mapErr?.map { + rust(".map_err(#T)?", it) + } ?: writable { } ), - ) - } else { - rustTemplate( - "#{builder}.build()", - "builder" to builderW, - ) + ) + } else { + rustTemplate( + "#{builder}.build()", + "builder" to builderW, + ) + } } - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt index c665144c23a..e545e9102bc 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt @@ -37,84 +37,91 @@ data class InfallibleEnumType( const val UnknownVariantValue = "UnknownVariantValue" } - override fun implFromForStr(context: EnumGeneratorContext): Writable = writable { - rustTemplate( - """ - impl #{From}<&str> for ${context.enumName} { - fn from(s: &str) -> Self { - match s { - #{matchArms} + override fun implFromForStr(context: EnumGeneratorContext): Writable = + writable { + rustTemplate( + """ + impl #{From}<&str> for ${context.enumName} { + fn from(s: &str) -> Self { + match s { + #{matchArms} + } } } - } - """, - "From" to RuntimeType.From, - "matchArms" to writable { - context.sortedMembers.forEach { member -> - rust("${member.value.dq()} => ${context.enumName}::${member.derivedName()},") - } - rust( - "other => ${context.enumName}::$UnknownVariant(#T(other.to_owned()))", - unknownVariantValue(context), - ) - }, - ) - } - - override fun implFromStr(context: EnumGeneratorContext): Writable = writable { - rustTemplate( - """ - impl ::std::str::FromStr for ${context.enumName} { - type Err = ::std::convert::Infallible; - - fn from_str(s: &str) -> #{Result}::Err> { - #{Ok}(${context.enumName}::from(s)) - } - } - """, - *preludeScope, - ) - } + """, + "From" to RuntimeType.From, + "matchArms" to + writable { + context.sortedMembers.forEach { member -> + rust("${member.value.dq()} => ${context.enumName}::${member.derivedName()},") + } + rust( + "other => ${context.enumName}::$UnknownVariant(#T(other.to_owned()))", + unknownVariantValue(context), + ) + }, + ) + } - override fun additionalEnumImpls(context: EnumGeneratorContext): Writable = writable { - // `try_parse` isn't needed for unnamed enums - if (context.enumTrait.hasNames()) { + override fun implFromStr(context: EnumGeneratorContext): Writable = + writable { rustTemplate( """ - impl ${context.enumName} { - /// Parses the enum value while disallowing unknown variants. - /// - /// Unknown variants will result in an error. - pub fn try_parse(value: &str) -> #{Result} { - match Self::from(value) { - ##[allow(deprecated)] - Self::Unknown(_) => #{Err}(#{UnknownVariantError}::new(value)), - known => Ok(known), - } + impl ::std::str::FromStr for ${context.enumName} { + type Err = ::std::convert::Infallible; + + fn from_str(s: &str) -> #{Result}::Err> { + #{Ok}(${context.enumName}::from(s)) } } """, *preludeScope, - "UnknownVariantError" to unknownVariantError(), ) } - } - override fun additionalDocs(context: EnumGeneratorContext): Writable = writable { - renderForwardCompatibilityNote(context.enumName, context.sortedMembers, UnknownVariant, UnknownVariantValue) - } + override fun additionalEnumImpls(context: EnumGeneratorContext): Writable = + writable { + // `try_parse` isn't needed for unnamed enums + if (context.enumTrait.hasNames()) { + rustTemplate( + """ + impl ${context.enumName} { + /// Parses the enum value while disallowing unknown variants. + /// + /// Unknown variants will result in an error. + pub fn try_parse(value: &str) -> #{Result} { + match Self::from(value) { + ##[allow(deprecated)] + Self::Unknown(_) => #{Err}(#{UnknownVariantError}::new(value)), + known => Ok(known), + } + } + } + """, + *preludeScope, + "UnknownVariantError" to unknownVariantError(), + ) + } + } - override fun additionalEnumMembers(context: EnumGeneratorContext): Writable = writable { - docs("`$UnknownVariant` contains new variants that have been added since this code was generated.") - rust( - """##[deprecated(note = "Don't directly match on `$UnknownVariant`. See the docs on this enum for the correct way to handle unknown variants.")]""", - ) - rust("$UnknownVariant(#T)", unknownVariantValue(context)) - } + override fun additionalDocs(context: EnumGeneratorContext): Writable = + writable { + renderForwardCompatibilityNote(context.enumName, context.sortedMembers, UnknownVariant, UnknownVariantValue) + } - override fun additionalAsStrMatchArms(context: EnumGeneratorContext): Writable = writable { - rust("${context.enumName}::$UnknownVariant(value) => value.as_str()") - } + override fun additionalEnumMembers(context: EnumGeneratorContext): Writable = + writable { + docs("`$UnknownVariant` contains new variants that have been added since this code was generated.") + rust( + """##[deprecated(note = "Don't directly match on `$UnknownVariant`. See the docs on this enum for the correct way to handle unknown variants.")]""", + ) + rust("$UnknownVariant(#T)", unknownVariantValue(context)) + } + + override fun additionalAsStrMatchArms(context: EnumGeneratorContext): Writable = + writable { + rust("${context.enumName}::$UnknownVariant(value) => value.as_str()") + } private fun unknownVariantValue(context: EnumGeneratorContext): RuntimeType { return RuntimeType.forInlineFun(UnknownVariantValue, unknownVariantModule) { @@ -142,8 +149,10 @@ data class InfallibleEnumType( * forward-compatible way. */ private fun RustWriter.renderForwardCompatibilityNote( - enumName: String, sortedMembers: List, - unknownVariant: String, unknownVariantValue: String, + enumName: String, + sortedMembers: List, + unknownVariant: String, + unknownVariantValue: String, ) { docs( """ @@ -210,26 +219,27 @@ class ClientEnumGenerator(codegenContext: ClientCodegenContext, shape: StringSha ), ) -private fun unknownVariantError(): RuntimeType = RuntimeType.forInlineFun("UnknownVariantError", ClientRustModule.Error) { - rustTemplate( - """ - /// The given enum value failed to parse since it is not a known value. - ##[derive(Debug)] - pub struct UnknownVariantError { - value: #{String}, - } - impl UnknownVariantError { - pub(crate) fn new(value: impl #{Into}<#{String}>) -> Self { - Self { value: value.into() } +private fun unknownVariantError(): RuntimeType = + RuntimeType.forInlineFun("UnknownVariantError", ClientRustModule.Error) { + rustTemplate( + """ + /// The given enum value failed to parse since it is not a known value. + ##[derive(Debug)] + pub struct UnknownVariantError { + value: #{String}, } - } - impl ::std::fmt::Display for UnknownVariantError { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> #{Result}<(), ::std::fmt::Error> { - write!(f, "unknown enum variant: '{}'", self.value) + impl UnknownVariantError { + pub(crate) fn new(value: impl #{Into}<#{String}>) -> Self { + Self { value: value.into() } + } } - } - impl ::std::error::Error for UnknownVariantError {} - """, - *preludeScope, - ) -} + impl ::std::fmt::Display for UnknownVariantError { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> #{Result}<(), ::std::fmt::Error> { + write!(f, "unknown enum variant: '{}'", self.value) + } + } + impl ::std::error::Error for UnknownVariantError {} + """, + *preludeScope, + ) + } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ConfigOverrideRuntimePluginGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ConfigOverrideRuntimePluginGenerator.kt index 1c92575ea42..c5fabe49cef 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ConfigOverrideRuntimePluginGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ConfigOverrideRuntimePluginGenerator.kt @@ -18,23 +18,27 @@ class ConfigOverrideRuntimePluginGenerator( codegenContext: ClientCodegenContext, ) { private val moduleUseName = codegenContext.moduleUseName() - private val codegenScope = codegenContext.runtimeConfig.let { rc -> - val runtimeApi = RuntimeType.smithyRuntimeApiClient(rc) - val smithyTypes = RuntimeType.smithyTypes(rc) - arrayOf( - *RuntimeType.preludeScope, - "Cow" to RuntimeType.Cow, - "CloneableLayer" to smithyTypes.resolve("config_bag::CloneableLayer"), - "FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"), - "InterceptorRegistrar" to runtimeApi.resolve("client::interceptors::InterceptorRegistrar"), - "Layer" to smithyTypes.resolve("config_bag::Layer"), - "Resolver" to RuntimeType.smithyRuntime(rc).resolve("client::config_override::Resolver"), - "RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(rc), - "RuntimePlugin" to RuntimeType.runtimePlugin(rc), - ) - } + private val codegenScope = + codegenContext.runtimeConfig.let { rc -> + val runtimeApi = RuntimeType.smithyRuntimeApiClient(rc) + val smithyTypes = RuntimeType.smithyTypes(rc) + arrayOf( + *RuntimeType.preludeScope, + "Cow" to RuntimeType.Cow, + "CloneableLayer" to smithyTypes.resolve("config_bag::CloneableLayer"), + "FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"), + "InterceptorRegistrar" to runtimeApi.resolve("client::interceptors::InterceptorRegistrar"), + "Layer" to smithyTypes.resolve("config_bag::Layer"), + "Resolver" to RuntimeType.smithyRuntime(rc).resolve("client::config_override::Resolver"), + "RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(rc), + "RuntimePlugin" to RuntimeType.runtimePlugin(rc), + ) + } - fun render(writer: RustWriter, customizations: List) { + fun render( + writer: RustWriter, + customizations: List, + ) { writer.rustTemplate( """ /// A plugin that enables configuration for a single operation invocation @@ -81,12 +85,13 @@ class ConfigOverrideRuntimePluginGenerator( } """, *codegenScope, - "config" to writable { - writeCustomizations( - customizations, - ServiceConfig.OperationConfigOverride("layer"), - ) - }, + "config" to + writable { + writeCustomizations( + customizations, + ServiceConfig.OperationConfigOverride("layer"), + ) + }, ) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingGenerator.kt index fe373a13bb4..a97f57f7845 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingGenerator.kt @@ -61,31 +61,33 @@ class EndpointTraitBindings( // build a list of args: `labelname = "field"` // these eventually end up in the format! macro invocation: // ```format!("some.{endpoint}", endpoint = endpoint);``` - val args = endpointTrait.hostPrefix.labels.map { label -> - val memberShape = inputShape.getMember(label.content).get() - val field = symbolProvider.toMemberName(memberShape) - if (symbolProvider.toSymbol(memberShape).isOptional()) { - rust("let $field = $input.$field.as_deref().unwrap_or_default();") - } else { - // NOTE: this is dead code until we start respecting @required - rust("let $field = &$input.$field;") + val args = + endpointTrait.hostPrefix.labels.map { label -> + val memberShape = inputShape.getMember(label.content).get() + val field = symbolProvider.toMemberName(memberShape) + if (symbolProvider.toSymbol(memberShape).isOptional()) { + rust("let $field = $input.$field.as_deref().unwrap_or_default();") + } else { + // NOTE: this is dead code until we start respecting @required + rust("let $field = &$input.$field;") + } + if (generateValidation) { + val errorString = "$field was unset or empty but must be set as part of the endpoint prefix" + val contents = + """ + if $field.is_empty() { + return Err(#{InvalidEndpointError}::failed_to_construct_uri("$errorString").into()); + } + """ + rustTemplate( + contents, + "InvalidEndpointError" to + RuntimeType.smithyHttp(runtimeConfig) + .resolve("endpoint::error::InvalidEndpointError"), + ) + } + "${label.content} = $field" } - if (generateValidation) { - val errorString = "$field was unset or empty but must be set as part of the endpoint prefix" - val contents = - """ - if $field.is_empty() { - return Err(#{InvalidEndpointError}::failed_to_construct_uri("$errorString").into()); - } - """ - rustTemplate( - contents, - "InvalidEndpointError" to RuntimeType.smithyHttp(runtimeConfig) - .resolve("endpoint::error::InvalidEndpointError"), - ) - } - "${label.content} = $field" - } rustTemplate( "#{EndpointPrefix}::new(format!($formatLiteral, ${args.joinToString()}))", "EndpointPrefix" to endpointPrefix, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrection.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrection.kt index b05c0143836..07617d9301f 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrection.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrection.kt @@ -58,28 +58,32 @@ private fun ClientCodegenContext.errorCorrectedDefault(member: MemberShape): Wri val instantiator = PrimitiveInstantiator(runtimeConfig, symbolProvider) return writable { when { - target is EnumShape || target.hasTrait() -> rustTemplate( - """"no value was set".parse::<#{Shape}>().ok()""", - "Shape" to targetSymbol, - ) + target is EnumShape || target.hasTrait() -> + rustTemplate( + """"no value was set".parse::<#{Shape}>().ok()""", + "Shape" to targetSymbol, + ) - target is BooleanShape || target is NumberShape || target is StringShape || target is DocumentShape || target is ListShape || target is MapShape -> rust( - "Some(Default::default())", - ) + target is BooleanShape || target is NumberShape || target is StringShape || target is DocumentShape || target is ListShape || target is MapShape -> + rust( + "Some(Default::default())", + ) - target is StructureShape -> rustTemplate( - "{ let builder = #{Builder}::default(); #{instantiate} }", - "Builder" to symbolProvider.symbolForBuilder(target), - "instantiate" to builderInstantiator().finalizeBuilder("builder", target).map { - if (BuilderGenerator.hasFallibleBuilder(target, symbolProvider)) { - rust("#T.ok()", it) - } else { - it.some()(this) - } - }.letIf(memberSymbol.isRustBoxed()) { - it.plus { rustTemplate(".map(#{Box}::new)", *preludeScope) } - }, - ) + target is StructureShape -> + rustTemplate( + "{ let builder = #{Builder}::default(); #{instantiate} }", + "Builder" to symbolProvider.symbolForBuilder(target), + "instantiate" to + builderInstantiator().finalizeBuilder("builder", target).map { + if (BuilderGenerator.hasFallibleBuilder(target, symbolProvider)) { + rust("#T.ok()", it) + } else { + it.some()(this) + } + }.letIf(memberSymbol.isRustBoxed()) { + it.plus { rustTemplate(".map(#{Box}::new)", *preludeScope) } + }, + ) target is TimestampShape -> instantiator.instantiate(target, Node.from(0)).some()(this) target is BlobShape -> instantiator.instantiate(target, Node.from("")).some()(this) @@ -90,17 +94,18 @@ private fun ClientCodegenContext.errorCorrectedDefault(member: MemberShape): Wri fun ClientCodegenContext.correctErrors(shape: StructureShape): RuntimeType? { val name = symbolProvider.shapeFunctionName(serviceShape, shape) + "_correct_errors" - val corrections = writable { - shape.members().forEach { member -> - val memberName = symbolProvider.toMemberName(member) - errorCorrectedDefault(member)?.also { default -> - rustTemplate( - """if builder.$memberName.is_none() { builder.$memberName = #{default} }""", - "default" to default, - ) + val corrections = + writable { + shape.members().forEach { member -> + val memberName = symbolProvider.toMemberName(member) + errorCorrectedDefault(member)?.also { default -> + rustTemplate( + """if builder.$memberName.is_none() { builder.$memberName = #{default} }""", + "default" to default, + ) + } } } - } if (corrections.isEmpty()) { return null diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/NestedAccessorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/NestedAccessorGenerator.kt index 138c7eb1231..49937301aad 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/NestedAccessorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/NestedAccessorGenerator.kt @@ -29,7 +29,10 @@ class NestedAccessorGenerator(private val codegenContext: CodegenContext) { /** * Generate an accessor on [root] that consumes [root] and returns an `Option` for the nested item */ - fun generateOwnedAccessor(root: StructureShape, path: List): RuntimeType { + fun generateOwnedAccessor( + root: StructureShape, + path: List, + ): RuntimeType { check(path.isNotEmpty()) { "must not be called on an empty path" } val baseType = symbolProvider.toSymbol(path.last()) val fnName = symbolProvider.nestedAccessorName(codegenContext.serviceShape, "", root, path) @@ -48,7 +51,10 @@ class NestedAccessorGenerator(private val codegenContext: CodegenContext) { /** * Generate an accessor on [root] that takes a reference and returns an `Option<&T>` for the nested item */ - fun generateBorrowingAccessor(root: StructureShape, path: List): RuntimeType { + fun generateBorrowingAccessor( + root: StructureShape, + path: List, + ): RuntimeType { check(path.isNotEmpty()) { "must not be called on an empty path" } val baseType = symbolProvider.toSymbol(path.last()).makeOptional() val fnName = symbolProvider.nestedAccessorName(codegenContext.serviceShape, "ref", root, path) @@ -65,13 +71,17 @@ class NestedAccessorGenerator(private val codegenContext: CodegenContext) { } } - private fun generateBody(path: List, reference: Boolean): Writable = + private fun generateBody( + path: List, + reference: Boolean, + ): Writable = writable { - val ref = if (reference) { - "&" - } else { - "" - } + val ref = + if (reference) { + "&" + } else { + "" + } if (path.isEmpty()) { rustTemplate("#{Some}(input)", *preludeScope) } else { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationCustomization.kt index 176b85b1c8f..07ed4e51cba 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationCustomization.kt @@ -70,7 +70,11 @@ sealed class OperationSection(name: String) : Section(name) { override val customizations: List, val operationShape: OperationShape, ) : OperationSection("AdditionalInterceptors") { - fun registerInterceptor(runtimeConfig: RuntimeConfig, writer: RustWriter, interceptor: Writable) { + fun registerInterceptor( + runtimeConfig: RuntimeConfig, + writer: RustWriter, + interceptor: Writable, + ) { writer.rustTemplate( ".with_interceptor(#{interceptor})", "interceptor" to interceptor, @@ -95,11 +99,17 @@ sealed class OperationSection(name: String) : Section(name) { override val customizations: List, val operationShape: OperationShape, ) : OperationSection("AdditionalRuntimePlugins") { - fun addClientPlugin(writer: RustWriter, plugin: Writable) { + fun addClientPlugin( + writer: RustWriter, + plugin: Writable, + ) { writer.rustTemplate(".with_client_plugin(#{plugin})", "plugin" to plugin) } - fun addOperationRuntimePlugin(writer: RustWriter, plugin: Writable) { + fun addOperationRuntimePlugin( + writer: RustWriter, + plugin: Writable, + ) { writer.rustTemplate(".with_operation_plugin(#{plugin})", "plugin" to plugin) } } @@ -108,7 +118,10 @@ sealed class OperationSection(name: String) : Section(name) { override val customizations: List, val operationShape: OperationShape, ) : OperationSection("RetryClassifiers") { - fun registerRetryClassifier(writer: RustWriter, classifier: Writable) { + fun registerRetryClassifier( + writer: RustWriter, + classifier: Writable, + ) { writer.rustTemplate(".with_retry_classifier(#{classifier})", "classifier" to classifier) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationGenerator.kt index f96534ed9db..c2200003db0 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationGenerator.kt @@ -36,9 +36,10 @@ import software.amazon.smithy.rust.codegen.core.util.sdkId open class OperationGenerator( private val codegenContext: ClientCodegenContext, private val protocol: Protocol, - private val bodyGenerator: ProtocolPayloadGenerator = ClientHttpBoundProtocolPayloadGenerator( - codegenContext, protocol, - ), + private val bodyGenerator: ProtocolPayloadGenerator = + ClientHttpBoundProtocolPayloadGenerator( + codegenContext, protocol, + ), ) { private val model = codegenContext.model private val runtimeConfig = codegenContext.runtimeConfig @@ -87,31 +88,35 @@ open class OperationGenerator( val outputType = symbolProvider.toSymbol(operationShape.outputShape(model)) val errorType = symbolProvider.symbolForOperationError(operationShape) - val codegenScope = arrayOf( - *preludeScope, - "Arc" to RuntimeType.Arc, - "ConcreteInput" to symbolProvider.toSymbol(operationShape.inputShape(model)), - "Input" to RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::interceptors::context::Input"), - "Operation" to symbolProvider.toSymbol(operationShape), - "OperationError" to errorType, - "OperationOutput" to outputType, - "HttpResponse" to RuntimeType.smithyRuntimeApiClient(runtimeConfig) - .resolve("client::orchestrator::HttpResponse"), - "SdkError" to RuntimeType.sdkError(runtimeConfig), - ) - val additionalPlugins = writable { - writeCustomizations( - operationCustomizations, - OperationSection.AdditionalRuntimePlugins(operationCustomizations, operationShape), - ) - rustTemplate( - ".with_client_plugin(#{auth_plugin})", - "auth_plugin" to AuthOptionsPluginGenerator(codegenContext).authPlugin( - operationShape, - authSchemeOptions, - ), + val codegenScope = + arrayOf( + *preludeScope, + "Arc" to RuntimeType.Arc, + "ConcreteInput" to symbolProvider.toSymbol(operationShape.inputShape(model)), + "Input" to RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::interceptors::context::Input"), + "Operation" to symbolProvider.toSymbol(operationShape), + "OperationError" to errorType, + "OperationOutput" to outputType, + "HttpResponse" to + RuntimeType.smithyRuntimeApiClient(runtimeConfig) + .resolve("client::orchestrator::HttpResponse"), + "SdkError" to RuntimeType.sdkError(runtimeConfig), ) - } + val additionalPlugins = + writable { + writeCustomizations( + operationCustomizations, + OperationSection.AdditionalRuntimePlugins(operationCustomizations, operationShape), + ) + rustTemplate( + ".with_client_plugin(#{auth_plugin})", + "auth_plugin" to + AuthOptionsPluginGenerator(codegenContext).authPlugin( + operationShape, + authSchemeOptions, + ), + ) + } rustTemplate( """ pub(crate) async fn orchestrate( @@ -166,24 +171,27 @@ open class OperationGenerator( *codegenScope, "Error" to RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::interceptors::context::Error"), "InterceptorContext" to RuntimeType.interceptorContext(runtimeConfig), - "OrchestratorError" to RuntimeType.smithyRuntimeApiClient(runtimeConfig) - .resolve("client::orchestrator::error::OrchestratorError"), + "OrchestratorError" to + RuntimeType.smithyRuntimeApiClient(runtimeConfig) + .resolve("client::orchestrator::error::OrchestratorError"), "RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig), "RuntimePlugins" to RuntimeType.runtimePlugins(runtimeConfig), "StopPoint" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::StopPoint"), - "invoke_with_stop_point" to RuntimeType.smithyRuntime(runtimeConfig) - .resolve("client::orchestrator::invoke_with_stop_point"), - "additional_runtime_plugins" to writable { - if (additionalPlugins.isNotEmpty()) { - rustTemplate( - """ - runtime_plugins = runtime_plugins - #{additional_runtime_plugins}; - """, - "additional_runtime_plugins" to additionalPlugins, - ) - } - }, + "invoke_with_stop_point" to + RuntimeType.smithyRuntime(runtimeConfig) + .resolve("client::orchestrator::invoke_with_stop_point"), + "additional_runtime_plugins" to + writable { + if (additionalPlugins.isNotEmpty()) { + rustTemplate( + """ + runtime_plugins = runtime_plugins + #{additional_runtime_plugins}; + """, + "additional_runtime_plugins" to additionalPlugins, + ) + } + }, ) writeCustomizations(operationCustomizations, OperationSection.OperationImplBlock(operationCustomizations)) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationRuntimePluginGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationRuntimePluginGenerator.kt index 4d4d32a8dd2..e17fdc71706 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationRuntimePluginGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationRuntimePluginGenerator.kt @@ -21,28 +21,29 @@ import software.amazon.smithy.rust.codegen.core.util.dq class OperationRuntimePluginGenerator( private val codegenContext: ClientCodegenContext, ) { - private val codegenScope = codegenContext.runtimeConfig.let { rc -> - val runtimeApi = RuntimeType.smithyRuntimeApiClient(rc) - val smithyTypes = RuntimeType.smithyTypes(rc) - arrayOf( - *preludeScope, - "AuthSchemeOptionResolverParams" to runtimeApi.resolve("client::auth::AuthSchemeOptionResolverParams"), - "BoxError" to RuntimeType.boxError(codegenContext.runtimeConfig), - "ConfigBag" to RuntimeType.configBag(codegenContext.runtimeConfig), - "Cow" to RuntimeType.Cow, - "FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"), - "IntoShared" to runtimeApi.resolve("shared::IntoShared"), - "Layer" to smithyTypes.resolve("config_bag::Layer"), - "RetryClassifiers" to runtimeApi.resolve("client::retries::RetryClassifiers"), - "RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(codegenContext.runtimeConfig), - "RuntimePlugin" to RuntimeType.runtimePlugin(codegenContext.runtimeConfig), - "SharedAuthSchemeOptionResolver" to runtimeApi.resolve("client::auth::SharedAuthSchemeOptionResolver"), - "SharedRequestSerializer" to runtimeApi.resolve("client::ser_de::SharedRequestSerializer"), - "SharedResponseDeserializer" to runtimeApi.resolve("client::ser_de::SharedResponseDeserializer"), - "StaticAuthSchemeOptionResolver" to runtimeApi.resolve("client::auth::static_resolver::StaticAuthSchemeOptionResolver"), - "StaticAuthSchemeOptionResolverParams" to runtimeApi.resolve("client::auth::static_resolver::StaticAuthSchemeOptionResolverParams"), - ) - } + private val codegenScope = + codegenContext.runtimeConfig.let { rc -> + val runtimeApi = RuntimeType.smithyRuntimeApiClient(rc) + val smithyTypes = RuntimeType.smithyTypes(rc) + arrayOf( + *preludeScope, + "AuthSchemeOptionResolverParams" to runtimeApi.resolve("client::auth::AuthSchemeOptionResolverParams"), + "BoxError" to RuntimeType.boxError(codegenContext.runtimeConfig), + "ConfigBag" to RuntimeType.configBag(codegenContext.runtimeConfig), + "Cow" to RuntimeType.Cow, + "FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"), + "IntoShared" to runtimeApi.resolve("shared::IntoShared"), + "Layer" to smithyTypes.resolve("config_bag::Layer"), + "RetryClassifiers" to runtimeApi.resolve("client::retries::RetryClassifiers"), + "RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(codegenContext.runtimeConfig), + "RuntimePlugin" to RuntimeType.runtimePlugin(codegenContext.runtimeConfig), + "SharedAuthSchemeOptionResolver" to runtimeApi.resolve("client::auth::SharedAuthSchemeOptionResolver"), + "SharedRequestSerializer" to runtimeApi.resolve("client::ser_de::SharedRequestSerializer"), + "SharedResponseDeserializer" to runtimeApi.resolve("client::ser_de::SharedResponseDeserializer"), + "StaticAuthSchemeOptionResolver" to runtimeApi.resolve("client::auth::static_resolver::StaticAuthSchemeOptionResolver"), + "StaticAuthSchemeOptionResolverParams" to runtimeApi.resolve("client::auth::static_resolver::StaticAuthSchemeOptionResolverParams"), + ) + } fun render( writer: RustWriter, @@ -82,34 +83,38 @@ class OperationRuntimePluginGenerator( """, *codegenScope, *preludeScope, - "additional_config" to writable { - writeCustomizations( - customizations, - OperationSection.AdditionalRuntimePluginConfig( + "additional_config" to + writable { + writeCustomizations( + customizations, + OperationSection.AdditionalRuntimePluginConfig( + customizations, + newLayerName = "cfg", + operationShape, + ), + ) + }, + "runtime_plugin_supporting_types" to + writable { + writeCustomizations( + customizations, + OperationSection.RuntimePluginSupportingTypes(customizations, "cfg", operationShape), + ) + }, + "interceptors" to + writable { + writeCustomizations( + customizations, + OperationSection.AdditionalInterceptors(customizations, operationShape), + ) + }, + "retry_classifiers" to + writable { + writeCustomizations( customizations, - newLayerName = "cfg", - operationShape, - ), - ) - }, - "runtime_plugin_supporting_types" to writable { - writeCustomizations( - customizations, - OperationSection.RuntimePluginSupportingTypes(customizations, "cfg", operationShape), - ) - }, - "interceptors" to writable { - writeCustomizations( - customizations, - OperationSection.AdditionalInterceptors(customizations, operationShape), - ) - }, - "retry_classifiers" to writable { - writeCustomizations( - customizations, - OperationSection.RetryClassifiers(customizations, operationShape), - ) - }, + OperationSection.RetryClassifiers(customizations, operationShape), + ) + }, ) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt index c9c342a2a31..3d9700993f4 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt @@ -56,158 +56,163 @@ class PaginatorGenerator private constructor( private val runtimeConfig = codegenContext.runtimeConfig private val paginatorName = "${operation.id.name.toPascalCase()}Paginator" private val idx = PaginatedIndex.of(model) - private val paginationInfo = idx.getPaginationInfo(codegenContext.serviceShape, operation).orNull() - ?: PANIC("failed to load pagination info") - private val module = RustModule.public( - "paginator", - parent = symbolProvider.moduleForShape(operation), - documentationOverride = "Paginator for this operation", - ) + private val paginationInfo = + idx.getPaginationInfo(codegenContext.serviceShape, operation).orNull() + ?: PANIC("failed to load pagination info") + private val module = + RustModule.public( + "paginator", + parent = symbolProvider.moduleForShape(operation), + documentationOverride = "Paginator for this operation", + ) private val inputType = symbolProvider.toSymbol(operation.inputShape(model)) private val outputShape = operation.outputShape(model) private val outputType = symbolProvider.toSymbol(outputShape) private val errorType = symbolProvider.symbolForOperationError(operation) - private fun paginatorType(): RuntimeType = RuntimeType.forInlineFun( - paginatorName, - module, - generate(), - ) - - private val codegenScope = arrayOf( - *preludeScope, - "page_size_setter" to pageSizeSetter(), - - // Operation Types - "operation" to symbolProvider.toSymbol(operation), - "Input" to inputType, - "Output" to outputType, - "Error" to errorType, - "Builder" to symbolProvider.symbolForBuilder(operation.inputShape(model)), - - // SDK Types - "HttpResponse" to RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::orchestrator::HttpResponse"), - "SdkError" to RuntimeType.sdkError(runtimeConfig), - "pagination_stream" to RuntimeType.smithyAsync(runtimeConfig).resolve("future::pagination_stream"), - - // External Types - "Stream" to RuntimeType.TokioStream.resolve("Stream"), + private fun paginatorType(): RuntimeType = + RuntimeType.forInlineFun( + paginatorName, + module, + generate(), + ) - ) + private val codegenScope = + arrayOf( + *preludeScope, + "page_size_setter" to pageSizeSetter(), + // Operation Types + "operation" to symbolProvider.toSymbol(operation), + "Input" to inputType, + "Output" to outputType, + "Error" to errorType, + "Builder" to symbolProvider.symbolForBuilder(operation.inputShape(model)), + // SDK Types + "HttpResponse" to RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::orchestrator::HttpResponse"), + "SdkError" to RuntimeType.sdkError(runtimeConfig), + "pagination_stream" to RuntimeType.smithyAsync(runtimeConfig).resolve("future::pagination_stream"), + // External Types + "Stream" to RuntimeType.TokioStream.resolve("Stream"), + ) /** Generate the paginator struct & impl **/ - private fun generate() = writable { - val outputTokenLens = NestedAccessorGenerator(codegenContext).generateBorrowingAccessor( - outputShape, - paginationInfo.outputTokenMemberPath, - ) - val inputTokenMember = symbolProvider.toMemberName(paginationInfo.inputTokenMember) - rustTemplate( - """ - /// Paginator for #{operation:D} - pub struct $paginatorName { - handle: std::sync::Arc, - builder: #{Builder}, - stop_on_duplicate_token: bool, - } + private fun generate() = + writable { + val outputTokenLens = + NestedAccessorGenerator(codegenContext).generateBorrowingAccessor( + outputShape, + paginationInfo.outputTokenMemberPath, + ) + val inputTokenMember = symbolProvider.toMemberName(paginationInfo.inputTokenMember) + rustTemplate( + """ + /// Paginator for #{operation:D} + pub struct $paginatorName { + handle: std::sync::Arc, + builder: #{Builder}, + stop_on_duplicate_token: bool, + } - impl $paginatorName { - /// Create a new paginator-wrapper - pub(crate) fn new(handle: std::sync::Arc, builder: #{Builder}) -> Self { - Self { - handle, - builder, - stop_on_duplicate_token: true, + impl $paginatorName { + /// Create a new paginator-wrapper + pub(crate) fn new(handle: std::sync::Arc, builder: #{Builder}) -> Self { + Self { + handle, + builder, + stop_on_duplicate_token: true, + } } - } - #{page_size_setter:W} + #{page_size_setter:W} - #{items_fn:W} + #{items_fn:W} - /// Stop paginating when the service returns the same pagination token twice in a row. - /// - /// Defaults to true. - /// - /// For certain operations, it may be useful to continue on duplicate token. For example, - /// if an operation is for tailing a log file in real-time, then continuing may be desired. - /// This option can be set to `false` to accommodate these use cases. - pub fn stop_on_duplicate_token(mut self, stop_on_duplicate_token: bool) -> Self { - self.stop_on_duplicate_token = stop_on_duplicate_token; - self - } + /// Stop paginating when the service returns the same pagination token twice in a row. + /// + /// Defaults to true. + /// + /// For certain operations, it may be useful to continue on duplicate token. For example, + /// if an operation is for tailing a log file in real-time, then continuing may be desired. + /// This option can be set to `false` to accommodate these use cases. + pub fn stop_on_duplicate_token(mut self, stop_on_duplicate_token: bool) -> Self { + self.stop_on_duplicate_token = stop_on_duplicate_token; + self + } - /// Create the pagination stream - /// - /// _Note:_ No requests will be dispatched until the stream is used - /// (e.g. with the [`.next().await`](aws_smithy_async::future::pagination_stream::PaginationStream::next) method). - pub fn send(self) -> #{pagination_stream}::PaginationStream<#{item_type}> { - // Move individual fields out of self for the borrow checker - let builder = self.builder; - let handle = self.handle; - #{runtime_plugin_init} - #{pagination_stream}::PaginationStream::new(#{pagination_stream}::fn_stream::FnStream::new(move |tx| #{Box}::pin(async move { - // Build the input for the first time. If required fields are missing, this is where we'll produce an early error. - let mut input = match builder.build().map_err(#{SdkError}::construction_failure) { - #{Ok}(input) => input, - #{Err}(e) => { let _ = tx.send(#{Err}(e)).await; return; } - }; - loop { - let resp = #{orchestrate}; - // If the input member is None or it was an error - let done = match resp { - #{Ok}(ref resp) => { - let new_token = #{output_token}(resp); - let is_empty = new_token.map(|token| token.is_empty()).unwrap_or(true); - if !is_empty && new_token == input.$inputTokenMember.as_ref() && self.stop_on_duplicate_token { - true - } else { - input.$inputTokenMember = new_token.cloned(); - is_empty - } - }, - #{Err}(_) => true, + /// Create the pagination stream + /// + /// _Note:_ No requests will be dispatched until the stream is used + /// (e.g. with the [`.next().await`](aws_smithy_async::future::pagination_stream::PaginationStream::next) method). + pub fn send(self) -> #{pagination_stream}::PaginationStream<#{item_type}> { + // Move individual fields out of self for the borrow checker + let builder = self.builder; + let handle = self.handle; + #{runtime_plugin_init} + #{pagination_stream}::PaginationStream::new(#{pagination_stream}::fn_stream::FnStream::new(move |tx| #{Box}::pin(async move { + // Build the input for the first time. If required fields are missing, this is where we'll produce an early error. + let mut input = match builder.build().map_err(#{SdkError}::construction_failure) { + #{Ok}(input) => input, + #{Err}(e) => { let _ = tx.send(#{Err}(e)).await; return; } }; - if tx.send(resp).await.is_err() { - // receiving end was dropped - return + loop { + let resp = #{orchestrate}; + // If the input member is None or it was an error + let done = match resp { + #{Ok}(ref resp) => { + let new_token = #{output_token}(resp); + let is_empty = new_token.map(|token| token.is_empty()).unwrap_or(true); + if !is_empty && new_token == input.$inputTokenMember.as_ref() && self.stop_on_duplicate_token { + true + } else { + input.$inputTokenMember = new_token.cloned(); + is_empty + } + }, + #{Err}(_) => true, + }; + if tx.send(resp).await.is_err() { + // receiving end was dropped + return + } + if done { + return + } } - if done { - return - } - } - }))) + }))) + } } - } - """, - *codegenScope, - "items_fn" to itemsFn(), - "output_token" to outputTokenLens, - "item_type" to writable { - rustTemplate("#{Result}<#{Output}, #{SdkError}<#{Error}, #{HttpResponse}>>", *codegenScope) - }, - "orchestrate" to writable { - rustTemplate( - "#{operation}::orchestrate(&runtime_plugins, input.clone()).await", - *codegenScope, - ) - }, - "runtime_plugin_init" to writable { - rustTemplate( - """ - let runtime_plugins = #{operation}::operation_runtime_plugins( - handle.runtime_plugins.clone(), - &handle.conf, - #{None}, - ); - """, - *codegenScope, - "RuntimePlugins" to RuntimeType.runtimePlugins(runtimeConfig), - ) - }, - ) - } + """, + *codegenScope, + "items_fn" to itemsFn(), + "output_token" to outputTokenLens, + "item_type" to + writable { + rustTemplate("#{Result}<#{Output}, #{SdkError}<#{Error}, #{HttpResponse}>>", *codegenScope) + }, + "orchestrate" to + writable { + rustTemplate( + "#{operation}::orchestrate(&runtime_plugins, input.clone()).await", + *codegenScope, + ) + }, + "runtime_plugin_init" to + writable { + rustTemplate( + """ + let runtime_plugins = #{operation}::operation_runtime_plugins( + handle.runtime_plugins.clone(), + &handle.conf, + #{None}, + ); + """, + *codegenScope, + "RuntimePlugins" to RuntimeType.runtimePlugins(runtimeConfig), + ) + }, + ) + } /** Type of the inner item of the paginator */ private fun itemType(): String { @@ -243,59 +248,63 @@ class PaginatorGenerator private constructor( } /** Generate a struct with a `items()` method that flattens the paginator **/ - private fun itemsPaginator(): RuntimeType? = if (paginationInfo.itemsMemberPath.isEmpty()) { - null - } else { - RuntimeType.forInlineFun("${paginatorName}Items", module) { - rustTemplate( - """ - /// Flattened paginator for `$paginatorName` - /// - /// This is created with [`.items()`]($paginatorName::items) - pub struct ${paginatorName}Items($paginatorName); - - impl ${paginatorName}Items { - /// Create the pagination stream - /// - /// _Note_: No requests will be dispatched until the stream is used - /// (e.g. with the [`.next().await`](aws_smithy_async::future::pagination_stream::PaginationStream::next) method). + private fun itemsPaginator(): RuntimeType? = + if (paginationInfo.itemsMemberPath.isEmpty()) { + null + } else { + RuntimeType.forInlineFun("${paginatorName}Items", module) { + rustTemplate( + """ + /// Flattened paginator for `$paginatorName` /// - /// To read the entirety of the paginator, use [`.collect::, _>()`](aws_smithy_async::future::pagination_stream::PaginationStream::collect). - pub fn send(self) -> #{pagination_stream}::PaginationStream<#{item_type}> { - #{pagination_stream}::TryFlatMap::new(self.0.send()).flat_map(|page| #{extract_items}(page).unwrap_or_default().into_iter()) + /// This is created with [`.items()`]($paginatorName::items) + pub struct ${paginatorName}Items($paginatorName); + + impl ${paginatorName}Items { + /// Create the pagination stream + /// + /// _Note_: No requests will be dispatched until the stream is used + /// (e.g. with the [`.next().await`](aws_smithy_async::future::pagination_stream::PaginationStream::next) method). + /// + /// To read the entirety of the paginator, use [`.collect::, _>()`](aws_smithy_async::future::pagination_stream::PaginationStream::collect). + pub fn send(self) -> #{pagination_stream}::PaginationStream<#{item_type}> { + #{pagination_stream}::TryFlatMap::new(self.0.send()).flat_map(|page| #{extract_items}(page).unwrap_or_default().into_iter()) + } } - } - """, - "extract_items" to NestedAccessorGenerator(codegenContext).generateOwnedAccessor( - outputShape, - paginationInfo.itemsMemberPath, - ), - "item_type" to writable { - rustTemplate("#{Result}<${itemType()}, #{SdkError}<#{Error}, #{HttpResponse}>>", *codegenScope) - }, - *codegenScope, - ) + """, + "extract_items" to + NestedAccessorGenerator(codegenContext).generateOwnedAccessor( + outputShape, + paginationInfo.itemsMemberPath, + ), + "item_type" to + writable { + rustTemplate("#{Result}<${itemType()}, #{SdkError}<#{Error}, #{HttpResponse}>>", *codegenScope) + }, + *codegenScope, + ) + } } - } - private fun pageSizeSetter() = writable { - paginationInfo.pageSizeMember.orNull()?.also { - val memberName = symbolProvider.toMemberName(it) - val pageSizeT = - symbolProvider.toSymbol(it).rustType().stripOuter().render(true) - rustTemplate( - """ - /// Set the page size - /// - /// _Note: this method will override any previously set value for `$memberName`_ - pub fn page_size(mut self, limit: $pageSizeT) -> Self { - self.builder.$memberName = #{Some}(limit); - self - } - """, - *preludeScope, - ) + private fun pageSizeSetter() = + writable { + paginationInfo.pageSizeMember.orNull()?.also { + val memberName = symbolProvider.toMemberName(it) + val pageSizeT = + symbolProvider.toSymbol(it).rustType().stripOuter().render(true) + rustTemplate( + """ + /// Set the page size + /// + /// _Note: this method will override any previously set value for `$memberName`_ + pub fn page_size(mut self, limit: $pageSizeT) -> Self { + self.builder.$memberName = #{Some}(limit); + self + } + """, + *preludeScope, + ) + } } - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/SensitiveIndex.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/SensitiveIndex.kt index 15e8d5361ce..8198b43aaed 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/SensitiveIndex.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/SensitiveIndex.kt @@ -17,6 +17,7 @@ class SensitiveIndex(model: Model) : KnowledgeIndex { private val sensitiveOutputs = sensitiveOutputSelector.select(model).map { it.id }.toSet() fun hasSensitiveInput(operationShape: OperationShape): Boolean = sensitiveInputs.contains(operationShape.id) + fun hasSensitiveOutput(operationShape: OperationShape): Boolean = sensitiveOutputs.contains(operationShape.id) companion object { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceGenerator.kt index dfdd302a31e..88591ef30c3 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceGenerator.kt @@ -41,10 +41,11 @@ class ServiceGenerator( ).render(rustCrate) rustCrate.withModule(ClientRustModule.config) { - val serviceConfigGenerator = ServiceConfigGenerator.withBaseBehavior( - codegenContext, - extraCustomizations = decorator.configCustomizations(codegenContext, listOf()), - ) + val serviceConfigGenerator = + ServiceConfigGenerator.withBaseBehavior( + codegenContext, + extraCustomizations = decorator.configCustomizations(codegenContext, listOf()), + ) serviceConfigGenerator.render(this) // Enable users to opt in to the test-utils in the runtime crate diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt index b58b8e97641..66095c1555d 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt @@ -32,26 +32,41 @@ sealed class ServiceRuntimePluginSection(name: String) : Section(name) { */ data class AdditionalConfig(val newLayerName: String, val serviceConfigName: String) : ServiceRuntimePluginSection("AdditionalConfig") { /** Adds a value to the config bag */ - fun putConfigValue(writer: RustWriter, value: Writable) { + fun putConfigValue( + writer: RustWriter, + value: Writable, + ) { writer.rust("$newLayerName.store_put(#T);", value) } } data class RegisterRuntimeComponents(val serviceConfigName: String) : ServiceRuntimePluginSection("RegisterRuntimeComponents") { /** Generates the code to register an interceptor */ - fun registerInterceptor(writer: RustWriter, interceptor: Writable) { + fun registerInterceptor( + writer: RustWriter, + interceptor: Writable, + ) { writer.rust("runtime_components.push_interceptor(#T);", interceptor) } - fun registerAuthScheme(writer: RustWriter, authScheme: Writable) { + fun registerAuthScheme( + writer: RustWriter, + authScheme: Writable, + ) { writer.rust("runtime_components.push_auth_scheme(#T);", authScheme) } - fun registerEndpointResolver(writer: RustWriter, resolver: Writable) { + fun registerEndpointResolver( + writer: RustWriter, + resolver: Writable, + ) { writer.rust("runtime_components.set_endpoint_resolver(Some(#T));", resolver) } - fun registerRetryClassifier(writer: RustWriter, classifier: Writable) { + fun registerRetryClassifier( + writer: RustWriter, + classifier: Writable, + ) { writer.rust("runtime_components.push_retry_classifier(#T);", classifier) } } @@ -64,30 +79,32 @@ typealias ServiceRuntimePluginCustomization = NamedCustomization - val runtimeApi = RuntimeType.smithyRuntimeApiClient(rc) - val smithyTypes = RuntimeType.smithyTypes(rc) - arrayOf( - *preludeScope, - "Arc" to RuntimeType.Arc, - "BoxError" to RuntimeType.boxError(codegenContext.runtimeConfig), - "Cow" to RuntimeType.Cow, - "FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"), - "IntoShared" to runtimeApi.resolve("shared::IntoShared"), - "Layer" to smithyTypes.resolve("config_bag::Layer"), - "RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(rc), - "RuntimePlugin" to RuntimeType.runtimePlugin(rc), - "Order" to runtimeApi.resolve("client::runtime_plugin::Order"), - ) - } + private val codegenScope = + codegenContext.runtimeConfig.let { rc -> + val runtimeApi = RuntimeType.smithyRuntimeApiClient(rc) + val smithyTypes = RuntimeType.smithyTypes(rc) + arrayOf( + *preludeScope, + "Arc" to RuntimeType.Arc, + "BoxError" to RuntimeType.boxError(codegenContext.runtimeConfig), + "Cow" to RuntimeType.Cow, + "FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"), + "IntoShared" to runtimeApi.resolve("shared::IntoShared"), + "Layer" to smithyTypes.resolve("config_bag::Layer"), + "RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(rc), + "RuntimePlugin" to RuntimeType.runtimePlugin(rc), + "Order" to runtimeApi.resolve("client::runtime_plugin::Order"), + ) + } fun render( writer: RustWriter, customizations: List, ) { - val additionalConfig = writable { - writeCustomizations(customizations, ServiceRuntimePluginSection.AdditionalConfig("cfg", "_service_config")) - } + val additionalConfig = + writable { + writeCustomizations(customizations, ServiceRuntimePluginSection.AdditionalConfig("cfg", "_service_config")) + } writer.rustTemplate( """ ##[derive(::std::fmt::Debug)] @@ -123,27 +140,30 @@ class ServiceRuntimePluginGenerator( #{declare_singletons} """, *codegenScope, - "config" to writable { - if (additionalConfig.isNotEmpty()) { - rustTemplate( - """ - let mut cfg = #{Layer}::new(${codegenContext.serviceShape.id.name.dq()}); - #{additional_config} - #{Some}(cfg.freeze()) - """, - *codegenScope, - "additional_config" to additionalConfig, - ) - } else { - rust("None") - } - }, - "runtime_components" to writable { - writeCustomizations(customizations, ServiceRuntimePluginSection.RegisterRuntimeComponents("_service_config")) - }, - "declare_singletons" to writable { - writeCustomizations(customizations, ServiceRuntimePluginSection.DeclareSingletons()) - }, + "config" to + writable { + if (additionalConfig.isNotEmpty()) { + rustTemplate( + """ + let mut cfg = #{Layer}::new(${codegenContext.serviceShape.id.name.dq()}); + #{additional_config} + #{Some}(cfg.freeze()) + """, + *codegenScope, + "additional_config" to additionalConfig, + ) + } else { + rust("None") + } + }, + "runtime_components" to + writable { + writeCustomizations(customizations, ServiceRuntimePluginSection.RegisterRuntimeComponents("_service_config")) + }, + "declare_singletons" to + writable { + writeCustomizations(customizations, ServiceRuntimePluginSection.DeclareSingletons()) + }, ) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt index 68ed17a95b7..403c8b1d3b7 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt @@ -28,31 +28,40 @@ class CustomizableOperationGenerator( private val runtimeConfig = codegenContext.runtimeConfig fun render(crate: RustCrate) { - val codegenScope = arrayOf( - *preludeScope, - "CustomizableOperation" to ClientRustModule.Client.customize.toType() - .resolve("CustomizableOperation"), - "CustomizableSend" to ClientRustModule.Client.customize.toType() - .resolve("internal::CustomizableSend"), - "HttpRequest" to RuntimeType.smithyRuntimeApiClient(runtimeConfig) - .resolve("client::orchestrator::HttpRequest"), - "HttpResponse" to RuntimeType.smithyRuntimeApiClient(runtimeConfig) - .resolve("client::orchestrator::HttpResponse"), - "Intercept" to RuntimeType.intercept(runtimeConfig), - "MapRequestInterceptor" to RuntimeType.smithyRuntime(runtimeConfig) - .resolve("client::interceptors::MapRequestInterceptor"), - "MutateRequestInterceptor" to RuntimeType.smithyRuntime(runtimeConfig) - .resolve("client::interceptors::MutateRequestInterceptor"), - "PhantomData" to RuntimeType.Phantom, - "RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig), - "SharedRuntimePlugin" to RuntimeType.sharedRuntimePlugin(runtimeConfig), - "SendResult" to ClientRustModule.Client.customize.toType() - .resolve("internal::SendResult"), - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), - "SdkError" to RuntimeType.sdkError(runtimeConfig), - "SharedInterceptor" to RuntimeType.smithyRuntimeApiClient(runtimeConfig) - .resolve("client::interceptors::SharedInterceptor"), - ) + val codegenScope = + arrayOf( + *preludeScope, + "CustomizableOperation" to + ClientRustModule.Client.customize.toType() + .resolve("CustomizableOperation"), + "CustomizableSend" to + ClientRustModule.Client.customize.toType() + .resolve("internal::CustomizableSend"), + "HttpRequest" to + RuntimeType.smithyRuntimeApiClient(runtimeConfig) + .resolve("client::orchestrator::HttpRequest"), + "HttpResponse" to + RuntimeType.smithyRuntimeApiClient(runtimeConfig) + .resolve("client::orchestrator::HttpResponse"), + "Intercept" to RuntimeType.intercept(runtimeConfig), + "MapRequestInterceptor" to + RuntimeType.smithyRuntime(runtimeConfig) + .resolve("client::interceptors::MapRequestInterceptor"), + "MutateRequestInterceptor" to + RuntimeType.smithyRuntime(runtimeConfig) + .resolve("client::interceptors::MutateRequestInterceptor"), + "PhantomData" to RuntimeType.Phantom, + "RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig), + "SharedRuntimePlugin" to RuntimeType.sharedRuntimePlugin(runtimeConfig), + "SendResult" to + ClientRustModule.Client.customize.toType() + .resolve("internal::SendResult"), + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + "SdkError" to RuntimeType.sdkError(runtimeConfig), + "SharedInterceptor" to + RuntimeType.smithyRuntimeApiClient(runtimeConfig) + .resolve("client::interceptors::SharedInterceptor"), + ) val customizeModule = ClientRustModule.Client.customize crate.withModule(customizeModule) { @@ -174,17 +183,21 @@ class CustomizableOperationGenerator( } """, *codegenScope, - "additional_methods" to writable { - writeCustomizations( - customizations, - CustomizableOperationSection.CustomizableOperationImpl, - ) - }, + "additional_methods" to + writable { + writeCustomizations( + customizations, + CustomizableOperationSection.CustomizableOperationImpl, + ) + }, ) } } - private fun renderConvenienceAliases(parentModule: RustModule, writer: RustWriter) { + private fun renderConvenienceAliases( + parentModule: RustModule, + writer: RustWriter, + ) { writer.withInlineModule(RustModule.new("internal", Visibility.PUBCRATE, true, parentModule), null) { rustTemplate( """ @@ -206,8 +219,9 @@ class CustomizableOperationGenerator( } """, *preludeScope, - "HttpResponse" to RuntimeType.smithyRuntimeApiClient(runtimeConfig) - .resolve("client::orchestrator::HttpResponse"), + "HttpResponse" to + RuntimeType.smithyRuntimeApiClient(runtimeConfig) + .resolve("client::orchestrator::HttpResponse"), "SdkError" to RuntimeType.sdkError(runtimeConfig), ) } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientCore.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientCore.kt index 3073413efe6..b997142a990 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientCore.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientCore.kt @@ -20,7 +20,11 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName class FluentClientCore(private val model: Model) { /** Generate and write Rust code for a builder method that sets a Vec */ - fun RustWriter.renderVecHelper(member: MemberShape, memberName: String, coreType: RustType.Vec) { + fun RustWriter.renderVecHelper( + member: MemberShape, + memberName: String, + coreType: RustType.Vec, + ) { docs("Appends an item to `${member.memberName}`.") rust("///") docs("To override the contents of this collection use [`${member.setterName()}`](Self::${member.setterName()}).") @@ -36,7 +40,11 @@ class FluentClientCore(private val model: Model) { } /** Generate and write Rust code for a builder method that sets a HashMap */ - fun RustWriter.renderMapHelper(member: MemberShape, memberName: String, coreType: RustType.HashMap) { + fun RustWriter.renderMapHelper( + member: MemberShape, + memberName: String, + coreType: RustType.HashMap, + ) { docs("Adds a key-value pair to `${member.memberName}`.") rust("///") docs("To override the contents of this collection use [`${member.setterName()}`](Self::${member.setterName()}).") @@ -58,7 +66,11 @@ class FluentClientCore(private val model: Model) { * `renderInputHelper(memberShape, "foo", RustType.String)` -> `pub fn foo(mut self, input: impl Into) -> Self { ... }` * `renderInputHelper(memberShape, "set_bar", RustType.Option)` -> `pub fn set_bar(mut self, input: Option) -> Self { ... }` */ - fun RustWriter.renderInputHelper(member: MemberShape, memberName: String, coreType: RustType) { + fun RustWriter.renderInputHelper( + member: MemberShape, + memberName: String, + coreType: RustType, + ) { val functionInput = coreType.asArgument("input") documentShape(member, model) @@ -72,7 +84,11 @@ class FluentClientCore(private val model: Model) { /** * Generate and write Rust code for a getter method that returns a reference to the inner data. */ - fun RustWriter.renderGetterHelper(member: MemberShape, memberName: String, coreType: RustType) { + fun RustWriter.renderGetterHelper( + member: MemberShape, + memberName: String, + coreType: RustType, + ) { documentShape(member, model) deprecatedShape(member) withBlockTemplate("pub fn $memberName(&self) -> &#{CoreType} {", "}", "CoreType" to coreType) { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientDecorator.kt index 73296eb3378..76ed95ed1ce 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientDecorator.kt @@ -29,7 +29,10 @@ class FluentClientDecorator : ClientCodegenDecorator { private fun applies(codegenContext: ClientCodegenContext): Boolean = codegenContext.settings.codegenConfig.includeFluentClient - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { if (!applies(codegenContext)) { return } @@ -50,14 +53,17 @@ class FluentClientDecorator : ClientCodegenDecorator { return baseCustomizations } - return baseCustomizations + object : LibRsCustomization() { - override fun section(section: LibRsSection) = when (section) { - is LibRsSection.Body -> writable { - rust("pub use client::Client;") - } - else -> emptySection + return baseCustomizations + + object : LibRsCustomization() { + override fun section(section: LibRsSection) = + when (section) { + is LibRsSection.Body -> + writable { + rust("pub use client::Client;") + } + else -> emptySection + } } - } } } @@ -77,20 +83,21 @@ abstract class FluentClientCustomization : NamedCustomization writable { - val serviceName = codegenContext.serviceShape.serviceNameOrDefault("the service") - docs( - """ - An ergonomic client for $serviceName. + is FluentClientSection.FluentClientDocs -> + writable { + val serviceName = codegenContext.serviceShape.serviceNameOrDefault("the service") + docs( + """ + An ergonomic client for $serviceName. - This client allows ergonomic access to $serviceName. - Each method corresponds to an API defined in the service's Smithy model, - and the request and response shapes are auto-generated from that same model. - """, - ) - FluentClientDocs.clientConstructionDocs(codegenContext)(this) - FluentClientDocs.clientUsageDocs(codegenContext)(this) - } + This client allows ergonomic access to $serviceName. + Each method corresponds to an API defined in the service's Smithy model, + and the request and response shapes are auto-generated from that same model. + """, + ) + FluentClientDocs.clientConstructionDocs(codegenContext)(this) + FluentClientDocs.clientUsageDocs(codegenContext)(this) + } else -> emptySection } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientDocs.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientDocs.kt index 448b576a9d3..be499c6ca6d 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientDocs.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientDocs.kt @@ -14,85 +14,89 @@ import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.serviceNameOrDefault object FluentClientDocs { - fun clientConstructionDocs(codegenContext: ClientCodegenContext) = writable { - val serviceName = codegenContext.serviceShape.serviceNameOrDefault("the service") - val moduleUseName = codegenContext.moduleUseName() - docsTemplate( - """ - Client for calling $serviceName. + fun clientConstructionDocs(codegenContext: ClientCodegenContext) = + writable { + val serviceName = codegenContext.serviceShape.serviceNameOrDefault("the service") + val moduleUseName = codegenContext.moduleUseName() + docsTemplate( + """ + Client for calling $serviceName. - #### Constructing a `Client` + #### Constructing a `Client` - A `Client` requires a config in order to be constructed. With the default set of Cargo features, - this config will only require an endpoint to produce a functioning client. However, some Smithy - features will require additional configuration. For example, `@auth` requires some kind of identity - or identity resolver to be configured. The config is used to customize various aspects of the client, - such as: + A `Client` requires a config in order to be constructed. With the default set of Cargo features, + this config will only require an endpoint to produce a functioning client. However, some Smithy + features will require additional configuration. For example, `@auth` requires some kind of identity + or identity resolver to be configured. The config is used to customize various aspects of the client, + such as: - - [HTTP Connector](crate::config::Builder::http_connector) - - [Retry](crate::config::Builder::retry_config) - - [Timeouts](crate::config::Builder::timeout_config) - - [... and more](crate::config::Builder) + - [HTTP Connector](crate::config::Builder::http_connector) + - [Retry](crate::config::Builder::retry_config) + - [Timeouts](crate::config::Builder::timeout_config) + - [... and more](crate::config::Builder) - Below is a minimal example of how to create a client: + Below is a minimal example of how to create a client: - ```rust,no_run - let config = $moduleUseName::Config::builder() - .endpoint_url("http://localhost:1234") - .build(); - let client = $moduleUseName::Client::from_conf(config); - ``` + ```rust,no_run + let config = $moduleUseName::Config::builder() + .endpoint_url("http://localhost:1234") + .build(); + let client = $moduleUseName::Client::from_conf(config); + ``` - _Note:_ Client construction is expensive due to connection thread pool initialization, and should be done - once at application start-up. Cloning a client is cheap (it's just an [`Arc`](std::sync::Arc) under the hood), - so creating it once at start-up and cloning it around the application as needed is recommended. - """.trimIndent(), - ) - } + _Note:_ Client construction is expensive due to connection thread pool initialization, and should be done + once at application start-up. Cloning a client is cheap (it's just an [`Arc`](std::sync::Arc) under the hood), + so creating it once at start-up and cloning it around the application as needed is recommended. + """.trimIndent(), + ) + } - fun clientUsageDocs(codegenContext: ClientCodegenContext) = writable { - val model = codegenContext.model - val symbolProvider = codegenContext.symbolProvider - if (model.operationShapes.isNotEmpty()) { - // Find an operation with a simple string member shape - val (operation, member) = codegenContext.serviceShape.operations - .map { id -> - val operationShape = model.expectShape(id, OperationShape::class.java) - val member = operationShape.inputShape(model) - .members() - .firstOrNull { model.expectShape(it.target) is StringShape } - operationShape to member - } - .sortedBy { it.first.id } - .firstOrNull { (_, member) -> member != null } ?: (null to null) - if (operation != null && member != null) { - val operationSymbol = symbolProvider.toSymbol(operation) - val memberSymbol = symbolProvider.toSymbol(member) - val operationFnName = FluentClientGenerator.clientOperationFnName(operation, symbolProvider) - docsTemplate( - """ - ## Using the `Client` + fun clientUsageDocs(codegenContext: ClientCodegenContext) = + writable { + val model = codegenContext.model + val symbolProvider = codegenContext.symbolProvider + if (model.operationShapes.isNotEmpty()) { + // Find an operation with a simple string member shape + val (operation, member) = + codegenContext.serviceShape.operations + .map { id -> + val operationShape = model.expectShape(id, OperationShape::class.java) + val member = + operationShape.inputShape(model) + .members() + .firstOrNull { model.expectShape(it.target) is StringShape } + operationShape to member + } + .sortedBy { it.first.id } + .firstOrNull { (_, member) -> member != null } ?: (null to null) + if (operation != null && member != null) { + val operationSymbol = symbolProvider.toSymbol(operation) + val memberSymbol = symbolProvider.toSymbol(member) + val operationFnName = FluentClientGenerator.clientOperationFnName(operation, symbolProvider) + docsTemplate( + """ + ## Using the `Client` - A client has a function for every operation that can be performed by the service. - For example, the [`${operationSymbol.name}`](${operationSymbol.namespace}) operation has - a [`Client::$operationFnName`], function which returns a builder for that operation. - The fluent builder ultimately has a `send()` function that returns an async future that - returns a result, as illustrated below: + A client has a function for every operation that can be performed by the service. + For example, the [`${operationSymbol.name}`](${operationSymbol.namespace}) operation has + a [`Client::$operationFnName`], function which returns a builder for that operation. + The fluent builder ultimately has a `send()` function that returns an async future that + returns a result, as illustrated below: - ```rust,ignore - let result = client.$operationFnName() - .${memberSymbol.name}("example") - .send() - .await; - ``` + ```rust,ignore + let result = client.$operationFnName() + .${memberSymbol.name}("example") + .send() + .await; + ``` - The underlying HTTP requests that get made by this can be modified with the `customize_operation` - function on the fluent builder. See the [`customize`](crate::client::customize) module for more - information. - """.trimIndent(), - "operation" to operationSymbol, - ) + The underlying HTTP requests that get made by this can be modified with the `customize_operation` + function on the fluent builder. See the [`customize`](crate::client::customize) module for more + information. + """.trimIndent(), + "operation" to operationSymbol, + ) + } } } - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt index 3865d5a0ab2..660f8e982b5 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt @@ -60,16 +60,21 @@ import software.amazon.smithy.rust.codegen.core.util.sdkId import software.amazon.smithy.rust.codegen.core.util.toSnakeCase private val BehaviorVersionLatest = Feature("behavior-version-latest", false, listOf()) + class FluentClientGenerator( private val codegenContext: ClientCodegenContext, private val customizations: List = emptyList(), ) { - companion object { - fun clientOperationFnName(operationShape: OperationShape, symbolProvider: RustSymbolProvider): String = - RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(operationShape).name.toSnakeCase()) - - fun clientOperationModuleName(operationShape: OperationShape, symbolProvider: RustSymbolProvider): String = + fun clientOperationFnName( + operationShape: OperationShape, + symbolProvider: RustSymbolProvider, + ): String = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(operationShape).name.toSnakeCase()) + + fun clientOperationModuleName( + operationShape: OperationShape, + symbolProvider: RustSymbolProvider, + ): String = RustReservedWords.escapeIfNeeded( symbolProvider.toSymbol(operationShape).name.toSnakeCase(), EscapeFor.ModuleName, @@ -84,10 +89,14 @@ class FluentClientGenerator( private val runtimeConfig = codegenContext.runtimeConfig private val core = FluentClientCore(model) - fun render(crate: RustCrate, customizableOperationCustomizations: List = emptyList()) { + fun render( + crate: RustCrate, + customizableOperationCustomizations: List = emptyList(), + ) { renderFluentClient(crate) - val customizableOperationGenerator = CustomizableOperationGenerator(codegenContext, customizableOperationCustomizations) + val customizableOperationGenerator = + CustomizableOperationGenerator(codegenContext, customizableOperationCustomizations) operations.forEach { operation -> crate.withModule(symbolProvider.moduleForBuilder(operation)) { renderFluentBuilder(operation) @@ -159,15 +168,16 @@ class FluentClientGenerator( "Arc" to RuntimeType.Arc, "base_client_runtime_plugins" to baseClientRuntimePluginsFn(codegenContext), "BoxError" to RuntimeType.boxError(runtimeConfig), - "client_docs" to writable { - customizations.forEach { - it.section( - FluentClientSection.FluentClientDocs( - serviceShape, - ), - )(this) - } - }, + "client_docs" to + writable { + customizations.forEach { + it.section( + FluentClientSection.FluentClientDocs( + serviceShape, + ), + )(this) + } + }, "ConfigBag" to RuntimeType.configBag(runtimeConfig), "RuntimePlugins" to RuntimeType.runtimePlugins(runtimeConfig), "tracing" to CargoDependency.Tracing.toType(), @@ -183,24 +193,27 @@ class FluentClientGenerator( crate.withModule(privateModule) { rustBlock("impl super::Client") { val fullPath = operation.fullyQualifiedFluentBuilder(symbolProvider) - val maybePaginated = if (operation.isPaginated(model)) { - "\n/// This operation supports pagination; See [`into_paginator()`]($fullPath::into_paginator)." - } else { - "" - } + val maybePaginated = + if (operation.isPaginated(model)) { + "\n/// This operation supports pagination; See [`into_paginator()`]($fullPath::into_paginator)." + } else { + "" + } val output = operation.outputShape(model) val operationOk = symbolProvider.toSymbol(output) val operationErr = symbolProvider.symbolForOperationError(operation) - val inputFieldsBody = generateOperationShapeDocs(this, symbolProvider, operation, model) - .joinToString("\n") { "/// - $it" } + val inputFieldsBody = + generateOperationShapeDocs(this, symbolProvider, operation, model) + .joinToString("\n") { "/// - $it" } - val inputFieldsHead = if (inputFieldsBody.isNotEmpty()) { - "The fluent builder is configurable:\n" - } else { - "The fluent builder takes no input, just [`send`]($fullPath::send) it." - } + val inputFieldsHead = + if (inputFieldsBody.isNotEmpty()) { + "The fluent builder is configurable:\n" + } else { + "The fluent builder takes no input, just [`send`]($fullPath::send) it." + } val outputFieldsBody = generateShapeMemberDocs(this, symbolProvider, output, model).joinToString("\n") { @@ -347,20 +360,24 @@ class FluentClientGenerator( write("&self.inner") } - val orchestratorScope = arrayOf( - *preludeScope, - "CustomizableOperation" to ClientRustModule.Client.customize.toType() - .resolve("CustomizableOperation"), - "HttpResponse" to RuntimeType.smithyRuntimeApiClient(runtimeConfig) - .resolve("client::orchestrator::HttpResponse"), - "Operation" to operationSymbol, - "OperationError" to errorType, - "OperationOutput" to outputType, - "RuntimePlugins" to RuntimeType.runtimePlugins(runtimeConfig), - "SendResult" to ClientRustModule.Client.customize.toType() - .resolve("internal::SendResult"), - "SdkError" to RuntimeType.sdkError(runtimeConfig), - ) + val orchestratorScope = + arrayOf( + *preludeScope, + "CustomizableOperation" to + ClientRustModule.Client.customize.toType() + .resolve("CustomizableOperation"), + "HttpResponse" to + RuntimeType.smithyRuntimeApiClient(runtimeConfig) + .resolve("client::orchestrator::HttpResponse"), + "Operation" to operationSymbol, + "OperationError" to errorType, + "OperationOutput" to outputType, + "RuntimePlugins" to RuntimeType.runtimePlugins(runtimeConfig), + "SendResult" to + ClientRustModule.Client.customize.toType() + .resolve("internal::SendResult"), + "SdkError" to RuntimeType.sdkError(runtimeConfig), + ) rustTemplate( """ /// Sends the request and returns the response. @@ -454,67 +471,70 @@ class FluentClientGenerator( } } -private fun baseClientRuntimePluginsFn(codegenContext: ClientCodegenContext): RuntimeType = codegenContext.runtimeConfig.let { rc -> - RuntimeType.forInlineFun("base_client_runtime_plugins", ClientRustModule.config) { - val api = RuntimeType.smithyRuntimeApiClient(rc) - val rt = RuntimeType.smithyRuntime(rc) - val behaviorVersionError = "Invalid client configuration: A behavior major version must be set when sending a " + - "request or constructing a client. You must set it during client construction or by enabling the " + - "`${BehaviorVersionLatest.name}` cargo feature." - rustTemplate( - """ - pub(crate) fn base_client_runtime_plugins( - mut config: crate::Config, - ) -> #{RuntimePlugins} { - let mut configured_plugins = #{Vec}::new(); - ::std::mem::swap(&mut config.runtime_plugins, &mut configured_plugins); - ##[allow(unused_mut)] - let mut behavior_version = config.behavior_version.clone(); - #{update_bmv} - - let mut plugins = #{RuntimePlugins}::new() - // defaults - .with_client_plugins(#{default_plugins}( - #{DefaultPluginParams}::new() - .with_retry_partition_name(${codegenContext.serviceShape.sdkId().dq()}) - .with_behavior_version(behavior_version.expect(${behaviorVersionError.dq()})) - )) - // user config - .with_client_plugin( - #{StaticRuntimePlugin}::new() - .with_config(config.config.clone()) - .with_runtime_components(config.runtime_components.clone()) - ) - // codegen config - .with_client_plugin(crate::config::ServiceRuntimePlugin::new(config)) - .with_client_plugin(#{NoAuthRuntimePlugin}::new()); - - for plugin in configured_plugins { - plugins = plugins.with_client_plugin(plugin); - } - plugins - } - """, - *preludeScope, - "DefaultPluginParams" to rt.resolve("client::defaults::DefaultPluginParams"), - "default_plugins" to rt.resolve("client::defaults::default_plugins"), - "NoAuthRuntimePlugin" to rt.resolve("client::auth::no_auth::NoAuthRuntimePlugin"), - "RuntimePlugins" to RuntimeType.runtimePlugins(rc), - "StaticRuntimePlugin" to api.resolve("client::runtime_plugin::StaticRuntimePlugin"), - "update_bmv" to featureGatedBlock(BehaviorVersionLatest) { - rustTemplate( - """ - if behavior_version.is_none() { - behavior_version = Some(#{BehaviorVersion}::latest()); +private fun baseClientRuntimePluginsFn(codegenContext: ClientCodegenContext): RuntimeType = + codegenContext.runtimeConfig.let { rc -> + RuntimeType.forInlineFun("base_client_runtime_plugins", ClientRustModule.config) { + val api = RuntimeType.smithyRuntimeApiClient(rc) + val rt = RuntimeType.smithyRuntime(rc) + val behaviorVersionError = + "Invalid client configuration: A behavior major version must be set when sending a " + + "request or constructing a client. You must set it during client construction or by enabling the " + + "`${BehaviorVersionLatest.name}` cargo feature." + rustTemplate( + """ + pub(crate) fn base_client_runtime_plugins( + mut config: crate::Config, + ) -> #{RuntimePlugins} { + let mut configured_plugins = #{Vec}::new(); + ::std::mem::swap(&mut config.runtime_plugins, &mut configured_plugins); + ##[allow(unused_mut)] + let mut behavior_version = config.behavior_version.clone(); + #{update_bmv} + + let mut plugins = #{RuntimePlugins}::new() + // defaults + .with_client_plugins(#{default_plugins}( + #{DefaultPluginParams}::new() + .with_retry_partition_name(${codegenContext.serviceShape.sdkId().dq()}) + .with_behavior_version(behavior_version.expect(${behaviorVersionError.dq()})) + )) + // user config + .with_client_plugin( + #{StaticRuntimePlugin}::new() + .with_config(config.config.clone()) + .with_runtime_components(config.runtime_components.clone()) + ) + // codegen config + .with_client_plugin(crate::config::ServiceRuntimePlugin::new(config)) + .with_client_plugin(#{NoAuthRuntimePlugin}::new()); + + for plugin in configured_plugins { + plugins = plugins.with_client_plugin(plugin); } - - """, - "BehaviorVersion" to api.resolve("client::behavior_version::BehaviorVersion"), - ) - }, - ) + plugins + } + """, + *preludeScope, + "DefaultPluginParams" to rt.resolve("client::defaults::DefaultPluginParams"), + "default_plugins" to rt.resolve("client::defaults::default_plugins"), + "NoAuthRuntimePlugin" to rt.resolve("client::auth::no_auth::NoAuthRuntimePlugin"), + "RuntimePlugins" to RuntimeType.runtimePlugins(rc), + "StaticRuntimePlugin" to api.resolve("client::runtime_plugin::StaticRuntimePlugin"), + "update_bmv" to + featureGatedBlock(BehaviorVersionLatest) { + rustTemplate( + """ + if behavior_version.is_none() { + behavior_version = Some(#{BehaviorVersion}::latest()); + } + + """, + "BehaviorVersion" to api.resolve("client::behavior_version::BehaviorVersion"), + ) + }, + ) + } } -} /** * For a given `operation` shape, return a list of strings where each string describes the name and input type of one of @@ -537,10 +557,11 @@ private fun generateOperationShapeDocs( val builderSetterLink = docLink("$fluentBuilderFullyQualifiedName::${memberShape.setterName()}") val docTrait = memberShape.getMemberTrait(model, DocumentationTrait::class.java).orNull() - val docs = when (docTrait?.value?.isNotBlank()) { - true -> normalizeHtml(writer.escape(docTrait.value)).replace("\n", " ") - else -> "(undocumented)" - } + val docs = + when (docTrait?.value?.isNotBlank()) { + true -> normalizeHtml(writer.escape(docTrait.value)).replace("\n", " ") + else -> "(undocumented)" + } "[`$builderInputDoc`]($builderInputLink) / [`$builderSetterDoc`]($builderSetterLink):
required: **${memberShape.isRequired}**
$docs
" } @@ -563,10 +584,11 @@ private fun generateShapeMemberDocs( val name = symbolProvider.toMemberName(memberShape) val member = symbolProvider.toSymbol(memberShape).rustType().render(fullyQualified = false) val docTrait = memberShape.getMemberTrait(model, DocumentationTrait::class.java).orNull() - val docs = when (docTrait?.value?.isNotBlank()) { - true -> normalizeHtml(writer.escape(docTrait.value)).replace("\n", " ") - else -> "(undocumented)" - } + val docs = + when (docTrait?.value?.isNotBlank()) { + true -> normalizeHtml(writer.escape(docTrait.value)).replace("\n", " ") + else -> "(undocumented)" + } "[`$name($member)`](${docLink("$structName::$name")}): $docs" } @@ -582,9 +604,8 @@ internal fun OperationShape.fluentBuilderType(symbolProvider: RustSymbolProvider * * * _NOTE: This function generates the links that appear under **"The fluent builder is configurable:"**_ */ -private fun OperationShape.fullyQualifiedFluentBuilder( - symbolProvider: RustSymbolProvider, -): String = fluentBuilderType(symbolProvider).fullyQualifiedName() +private fun OperationShape.fullyQualifiedFluentBuilder(symbolProvider: RustSymbolProvider): String = + fluentBuilderType(symbolProvider).fullyQualifiedName() /** * Generate a string that looks like a Rust function pointer for documenting a fluent builder method e.g. @@ -596,11 +617,12 @@ internal fun MemberShape.asFluentBuilderInputDoc(symbolProvider: SymbolProvider) val memberName = symbolProvider.toMemberName(this) val outerType = symbolProvider.toSymbol(this).rustType().stripOuter() // We generate Vec/HashMap helpers - val renderedType = when (outerType) { - is RustType.Vec -> listOf(outerType.member) - is RustType.HashMap -> listOf(outerType.key, outerType.member) - else -> listOf(outerType) - } + val renderedType = + when (outerType) { + is RustType.Vec -> listOf(outerType.member) + is RustType.HashMap -> listOf(outerType.key, outerType.member) + else -> listOf(outerType) + } val args = renderedType.joinToString { it.asArgumentType(fullyQualified = false) } return "$memberName($args)" diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/IdempotencyTokenProviderCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/IdempotencyTokenProviderCustomization.kt index fd04a062248..c41764dfa81 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/IdempotencyTokenProviderCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/IdempotencyTokenProviderCustomization.kt @@ -19,40 +19,43 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomizat */ class IdempotencyTokenProviderCustomization(codegenContext: ClientCodegenContext) : NamedCustomization() { private val runtimeConfig = codegenContext.runtimeConfig - private val codegenScope = arrayOf( - *preludeScope, - "IdempotencyTokenProvider" to RuntimeType.idempotencyToken(runtimeConfig).resolve("IdempotencyTokenProvider"), - ) + private val codegenScope = + arrayOf( + *preludeScope, + "IdempotencyTokenProvider" to RuntimeType.idempotencyToken(runtimeConfig).resolve("IdempotencyTokenProvider"), + ) override fun section(section: ServiceConfig): Writable { return when (section) { - ServiceConfig.BuilderImpl -> writable { - rustTemplate( - """ - /// Sets the idempotency token provider to use for service calls that require tokens. - pub fn idempotency_token_provider(mut self, idempotency_token_provider: impl #{Into}<#{IdempotencyTokenProvider}>) -> Self { - self.set_idempotency_token_provider(#{Some}(idempotency_token_provider.into())); - self - } - """, - *codegenScope, - ) + ServiceConfig.BuilderImpl -> + writable { + rustTemplate( + """ + /// Sets the idempotency token provider to use for service calls that require tokens. + pub fn idempotency_token_provider(mut self, idempotency_token_provider: impl #{Into}<#{IdempotencyTokenProvider}>) -> Self { + self.set_idempotency_token_provider(#{Some}(idempotency_token_provider.into())); + self + } + """, + *codegenScope, + ) - rustTemplate( - """ - /// Sets the idempotency token provider to use for service calls that require tokens. - pub fn set_idempotency_token_provider(&mut self, idempotency_token_provider: #{Option}<#{IdempotencyTokenProvider}>) -> &mut Self { - self.config.store_or_unset(idempotency_token_provider); - self - } - """, - *codegenScope, - ) - } + rustTemplate( + """ + /// Sets the idempotency token provider to use for service calls that require tokens. + pub fn set_idempotency_token_provider(&mut self, idempotency_token_provider: #{Option}<#{IdempotencyTokenProvider}>) -> &mut Self { + self.config.store_or_unset(idempotency_token_provider); + self + } + """, + *codegenScope, + ) + } - is ServiceConfig.DefaultForTests -> writable { - rust("""${section.configBuilderRef}.set_idempotency_token_provider(Some("00000000-0000-4000-8000-000000000000".into()));""") - } + is ServiceConfig.DefaultForTests -> + writable { + rust("""${section.configBuilderRef}.set_idempotency_token_provider(Some("00000000-0000-4000-8000-000000000000".into()));""") + } else -> writable { } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/ServiceConfigGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/ServiceConfigGenerator.kt index 26f4cd7a634..38fecb7e486 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/ServiceConfigGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/ServiceConfigGenerator.kt @@ -118,7 +118,6 @@ data class ConfigParam( val getterDocs: Writable? = null, val optional: Boolean = true, ) { - data class Builder( var name: String? = null, var type: Symbol? = null, @@ -128,11 +127,17 @@ data class ConfigParam( var optional: Boolean = true, ) { fun name(name: String) = apply { this.name = name } + fun type(type: Symbol) = apply { this.type = type } + fun newtype(newtype: RuntimeType) = apply { this.newtype = newtype } + fun setterDocs(setterDocs: Writable?) = apply { this.setterDocs = setterDocs } + fun getterDocs(getterDocs: Writable?) = apply { this.getterDocs = getterDocs } + fun optional(optional: Boolean) = apply { this.optional = optional } + fun build() = ConfigParam(name!!, type!!, newtype, setterDocs, getterDocs, optional) } } @@ -143,23 +148,27 @@ data class ConfigParam( * When config parameters are stored in a config map in Rust, stored parameters are keyed by type. * Therefore, primitive types, such as bool and String, need to be wrapped in newtypes to make them distinct. */ -fun configParamNewtype(newtypeName: String, inner: Symbol, runtimeConfig: RuntimeConfig) = - RuntimeType.forInlineFun(newtypeName, ClientRustModule.config) { - val codegenScope = arrayOf( +fun configParamNewtype( + newtypeName: String, + inner: Symbol, + runtimeConfig: RuntimeConfig, +) = RuntimeType.forInlineFun(newtypeName, ClientRustModule.config) { + val codegenScope = + arrayOf( "Storable" to RuntimeType.smithyTypes(runtimeConfig).resolve("config_bag::Storable"), "StoreReplace" to RuntimeType.smithyTypes(runtimeConfig).resolve("config_bag::StoreReplace"), ) - rustTemplate( - """ - ##[derive(Debug, Clone)] - pub(crate) struct $newtypeName(pub(crate) $inner); - impl #{Storable} for $newtypeName { - type Storer = #{StoreReplace}; - } - """, - *codegenScope, - ) - } + rustTemplate( + """ + ##[derive(Debug, Clone)] + pub(crate) struct $newtypeName(pub(crate) $inner); + impl #{Storable} for $newtypeName { + type Storer = #{StoreReplace}; + } + """, + *codegenScope, + ) +} /** * Render an expression that loads a value from a config bag. @@ -167,21 +176,26 @@ fun configParamNewtype(newtypeName: String, inner: Symbol, runtimeConfig: Runtim * The expression to be rendered handles a case where a newtype is stored in the config bag, but the user expects * the underlying raw type after the newtype has been loaded from the bag. */ -fun loadFromConfigBag(innerTypeName: String, newtype: RuntimeType): Writable = writable { - rustTemplate( - """ - load::<#{newtype}>().map(#{f}) - """, - "newtype" to newtype, - "f" to writable { - if (innerTypeName == "bool") { - rust("|ty| ty.0") - } else { - rust("|ty| ty.0.clone()") - } - }, - ) -} +fun loadFromConfigBag( + innerTypeName: String, + newtype: RuntimeType, +): Writable = + writable { + rustTemplate( + """ + load::<#{newtype}>().map(#{f}) + """, + "newtype" to newtype, + "f" to + writable { + if (innerTypeName == "bool") { + rust("|ty| ty.0") + } else { + rust("|ty| ty.0.clone()") + } + }, + ) + } /** * Config customization for a config param with no special behavior: @@ -189,33 +203,37 @@ fun loadFromConfigBag(innerTypeName: String, newtype: RuntimeType): Writable = w * 2. convenience setter (non-optional) * 3. standard setter (&mut self) */ -fun standardConfigParam(param: ConfigParam, codegenContext: ClientCodegenContext): ConfigCustomization = +fun standardConfigParam( + param: ConfigParam, + codegenContext: ClientCodegenContext, +): ConfigCustomization = object : ConfigCustomization() { override fun section(section: ServiceConfig): Writable { return when (section) { - ServiceConfig.BuilderImpl -> writable { - docsOrFallback(param.setterDocs) - rust( - """ - pub fn ${param.name}(mut self, ${param.name}: impl Into<#T>) -> Self { - self.set_${param.name}(Some(${param.name}.into())); - self + ServiceConfig.BuilderImpl -> + writable { + docsOrFallback(param.setterDocs) + rust( + """ + pub fn ${param.name}(mut self, ${param.name}: impl Into<#T>) -> Self { + self.set_${param.name}(Some(${param.name}.into())); + self }""", - param.type, - ) - - docsOrFallback(param.setterDocs) - rustTemplate( - """ - pub fn set_${param.name}(&mut self, ${param.name}: Option<#{T}>) -> &mut Self { - self.config.store_or_unset(${param.name}.map(#{newtype})); - self - } - """, - "T" to param.type, - "newtype" to param.newtype!!, - ) - } + param.type, + ) + + docsOrFallback(param.setterDocs) + rustTemplate( + """ + pub fn set_${param.name}(&mut self, ${param.name}: Option<#{T}>) -> &mut Self { + self.config.store_or_unset(${param.name}.map(#{newtype})); + self + } + """, + "T" to param.type, + "newtype" to param.newtype!!, + ) + } else -> emptySection } @@ -267,81 +285,84 @@ class ServiceConfigGenerator( private val runtimeConfig = codegenContext.runtimeConfig private val enableUserConfigurableRuntimePlugins = codegenContext.enableUserConfigurableRuntimePlugins private val smithyTypes = RuntimeType.smithyTypes(runtimeConfig) - val codegenScope = arrayOf( - *preludeScope, - "BoxError" to RuntimeType.boxError(runtimeConfig), - "CloneableLayer" to smithyTypes.resolve("config_bag::CloneableLayer"), - "ConfigBag" to RuntimeType.configBag(codegenContext.runtimeConfig), - "Cow" to RuntimeType.Cow, - "FrozenLayer" to configReexport(smithyTypes.resolve("config_bag::FrozenLayer")), - "Layer" to configReexport(smithyTypes.resolve("config_bag::Layer")), - "Resolver" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::config_override::Resolver"), - "RuntimeComponentsBuilder" to configReexport(RuntimeType.runtimeComponentsBuilder(runtimeConfig)), - "RuntimePlugin" to configReexport(RuntimeType.runtimePlugin(runtimeConfig)), - "SharedRuntimePlugin" to configReexport(RuntimeType.sharedRuntimePlugin(runtimeConfig)), - "runtime_plugin" to RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::runtime_plugin"), - "BehaviorVersion" to configReexport( - RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::behavior_version::BehaviorVersion"), - ), - ) + val codegenScope = + arrayOf( + *preludeScope, + "BoxError" to RuntimeType.boxError(runtimeConfig), + "CloneableLayer" to smithyTypes.resolve("config_bag::CloneableLayer"), + "ConfigBag" to RuntimeType.configBag(codegenContext.runtimeConfig), + "Cow" to RuntimeType.Cow, + "FrozenLayer" to configReexport(smithyTypes.resolve("config_bag::FrozenLayer")), + "Layer" to configReexport(smithyTypes.resolve("config_bag::Layer")), + "Resolver" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::config_override::Resolver"), + "RuntimeComponentsBuilder" to configReexport(RuntimeType.runtimeComponentsBuilder(runtimeConfig)), + "RuntimePlugin" to configReexport(RuntimeType.runtimePlugin(runtimeConfig)), + "SharedRuntimePlugin" to configReexport(RuntimeType.sharedRuntimePlugin(runtimeConfig)), + "runtime_plugin" to RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::runtime_plugin"), + "BehaviorVersion" to + configReexport( + RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::behavior_version::BehaviorVersion"), + ), + ) - private fun behaviorMv() = writable { - val docs = """ - /// Sets the [`behavior major version`](crate::config::BehaviorVersion). - /// - /// Over time, new best-practice behaviors are introduced. However, these behaviors might not be backwards - /// compatible. For example, a change which introduces new default timeouts or a new retry-mode for - /// all operations might be the ideal behavior but could break existing applications. - /// - /// ## Examples - /// - /// Set the behavior major version to `latest`. This is equivalent to enabling the `behavior-version-latest` cargo feature. - /// ```no_run - /// use $moduleUseName::config::BehaviorVersion; - /// - /// let config = $moduleUseName::Config::builder() - /// .behavior_version(BehaviorVersion::latest()) - /// // ... - /// .build(); - /// let client = $moduleUseName::Client::from_conf(config); - /// ``` - /// - /// Customizing behavior major version: - /// ```no_run - /// use $moduleUseName::config::BehaviorVersion; - /// - /// let config = $moduleUseName::Config::builder() - /// .behavior_version(BehaviorVersion::v2023_11_09()) - /// // ... - /// .build(); - /// let client = $moduleUseName::Client::from_conf(config); - /// ``` - """ - rustTemplate( + private fun behaviorMv() = + writable { + val docs = """ + /// Sets the [`behavior major version`](crate::config::BehaviorVersion). + /// + /// Over time, new best-practice behaviors are introduced. However, these behaviors might not be backwards + /// compatible. For example, a change which introduces new default timeouts or a new retry-mode for + /// all operations might be the ideal behavior but could break existing applications. + /// + /// ## Examples + /// + /// Set the behavior major version to `latest`. This is equivalent to enabling the `behavior-version-latest` cargo feature. + /// ```no_run + /// use $moduleUseName::config::BehaviorVersion; + /// + /// let config = $moduleUseName::Config::builder() + /// .behavior_version(BehaviorVersion::latest()) + /// // ... + /// .build(); + /// let client = $moduleUseName::Client::from_conf(config); + /// ``` + /// + /// Customizing behavior major version: + /// ```no_run + /// use $moduleUseName::config::BehaviorVersion; + /// + /// let config = $moduleUseName::Config::builder() + /// .behavior_version(BehaviorVersion::v2023_11_09()) + /// // ... + /// .build(); + /// let client = $moduleUseName::Client::from_conf(config); + /// ``` """ - $docs - pub fn behavior_version(mut self, behavior_version: crate::config::BehaviorVersion) -> Self { - self.set_behavior_version(Some(behavior_version)); - self - } + rustTemplate( + """ + $docs + pub fn behavior_version(mut self, behavior_version: crate::config::BehaviorVersion) -> Self { + self.set_behavior_version(Some(behavior_version)); + self + } - $docs - pub fn set_behavior_version(&mut self, behavior_version: Option) -> &mut Self { - self.behavior_version = behavior_version; - self - } + $docs + pub fn set_behavior_version(&mut self, behavior_version: Option) -> &mut Self { + self.behavior_version = behavior_version; + self + } - /// Convenience method to set the latest behavior major version - /// - /// This is equivalent to enabling the `behavior-version-latest` Cargo feature - pub fn behavior_version_latest(mut self) -> Self { - self.set_behavior_version(Some(crate::config::BehaviorVersion::latest())); - self - } - """, - *codegenScope, - ) - } + /// Convenience method to set the latest behavior major version + /// + /// This is equivalent to enabling the `behavior-version-latest` Cargo feature + pub fn behavior_version_latest(mut self) -> Self { + self.set_behavior_version(Some(crate::config::BehaviorVersion::latest())); + self + } + """, + *codegenScope, + ) + } fun render(writer: RustWriter) { writer.docs("Configuration for a $moduleUseName service client.\n") @@ -435,11 +456,12 @@ class ServiceConfigGenerator( } behaviorMv()(this) - val visibility = if (enableUserConfigurableRuntimePlugins) { - "pub" - } else { - "pub(crate)" - } + val visibility = + if (enableUserConfigurableRuntimePlugins) { + "pub" + } else { + "pub(crate)" + } docs("Adds a runtime plugin to the config.") if (!enableUserConfigurableRuntimePlugins) { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt index 03164cbfab1..0c67acbf43c 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt @@ -43,55 +43,58 @@ class StalledStreamProtectionDecorator : ClientCodegenDecorator { */ class StalledStreamProtectionConfigCustomization(codegenContext: ClientCodegenContext) : NamedCustomization() { private val rc = codegenContext.runtimeConfig - private val codegenScope = arrayOf( - *preludeScope, - "StalledStreamProtectionConfig" to configReexport(RuntimeType.smithyRuntimeApi(rc).resolve("client::stalled_stream_protection::StalledStreamProtectionConfig")), - ) + private val codegenScope = + arrayOf( + *preludeScope, + "StalledStreamProtectionConfig" to configReexport(RuntimeType.smithyRuntimeApi(rc).resolve("client::stalled_stream_protection::StalledStreamProtectionConfig")), + ) override fun section(section: ServiceConfig): Writable { return when (section) { - ServiceConfig.ConfigImpl -> writable { - rustTemplate( - """ - /// Return a reference to the stalled stream protection configuration contained in this config, if any. - pub fn stalled_stream_protection(&self) -> #{Option}<&#{StalledStreamProtectionConfig}> { - self.config.load::<#{StalledStreamProtectionConfig}>() - } - """, - *codegenScope, - ) - } - ServiceConfig.BuilderImpl -> writable { - rustTemplate( - """ - /// Set the [`StalledStreamProtectionConfig`](#{StalledStreamProtectionConfig}) - /// to configure protection for stalled streams. - pub fn stalled_stream_protection( - mut self, - stalled_stream_protection_config: #{StalledStreamProtectionConfig} - ) -> Self { - self.set_stalled_stream_protection(#{Some}(stalled_stream_protection_config)); - self - } - """, - *codegenScope, - ) + ServiceConfig.ConfigImpl -> + writable { + rustTemplate( + """ + /// Return a reference to the stalled stream protection configuration contained in this config, if any. + pub fn stalled_stream_protection(&self) -> #{Option}<&#{StalledStreamProtectionConfig}> { + self.config.load::<#{StalledStreamProtectionConfig}>() + } + """, + *codegenScope, + ) + } + ServiceConfig.BuilderImpl -> + writable { + rustTemplate( + """ + /// Set the [`StalledStreamProtectionConfig`](#{StalledStreamProtectionConfig}) + /// to configure protection for stalled streams. + pub fn stalled_stream_protection( + mut self, + stalled_stream_protection_config: #{StalledStreamProtectionConfig} + ) -> Self { + self.set_stalled_stream_protection(#{Some}(stalled_stream_protection_config)); + self + } + """, + *codegenScope, + ) - rustTemplate( - """ - /// Set the [`StalledStreamProtectionConfig`](#{StalledStreamProtectionConfig}) - /// to configure protection for stalled streams. - pub fn set_stalled_stream_protection( - &mut self, - stalled_stream_protection_config: #{Option}<#{StalledStreamProtectionConfig}> - ) -> &mut Self { - self.config.store_or_unset(stalled_stream_protection_config); - self - } - """, - *codegenScope, - ) - } + rustTemplate( + """ + /// Set the [`StalledStreamProtectionConfig`](#{StalledStreamProtectionConfig}) + /// to configure protection for stalled streams. + pub fn set_stalled_stream_protection( + &mut self, + stalled_stream_protection_config: #{Option}<#{StalledStreamProtectionConfig}> + ) -> &mut Self { + self.config.store_or_unset(stalled_stream_protection_config); + self + } + """, + *codegenScope, + ) + } else -> emptySection } @@ -103,24 +106,25 @@ class StalledStreamProtectionOperationCustomization( ) : OperationCustomization() { private val rc = codegenContext.runtimeConfig - override fun section(section: OperationSection): Writable = writable { - when (section) { - is OperationSection.AdditionalInterceptors -> { - val stalledStreamProtectionModule = RuntimeType.smithyRuntime(rc).resolve("client::stalled_stream_protection") - section.registerInterceptor(rc, this) { - // Currently, only response bodies are protected/supported because - // we can't count on hyper to poll a request body on wake. - rustTemplate( - """ - #{StalledStreamProtectionInterceptor}::new(#{Kind}::ResponseBody) - """, - *preludeScope, - "StalledStreamProtectionInterceptor" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptor"), - "Kind" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptorKind"), - ) + override fun section(section: OperationSection): Writable = + writable { + when (section) { + is OperationSection.AdditionalInterceptors -> { + val stalledStreamProtectionModule = RuntimeType.smithyRuntime(rc).resolve("client::stalled_stream_protection") + section.registerInterceptor(rc, this) { + // Currently, only response bodies are protected/supported because + // we can't count on hyper to poll a request body on wake. + rustTemplate( + """ + #{StalledStreamProtectionInterceptor}::new(#{Kind}::ResponseBody) + """, + *preludeScope, + "StalledStreamProtectionInterceptor" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptor"), + "Kind" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptorKind"), + ) + } } + else -> { } } - else -> { } } - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ErrorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ErrorGenerator.kt index d90fd441529..92d716d329f 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ErrorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ErrorGenerator.kt @@ -46,19 +46,20 @@ class ErrorGenerator( model, symbolProvider, this, shape, listOf( object : StructureCustomization() { - override fun section(section: StructureSection): Writable = writable { - when (section) { - is StructureSection.AdditionalFields -> { - rust("pub(crate) meta: #T,", errorMetadata(runtimeConfig)) - } + override fun section(section: StructureSection): Writable = + writable { + when (section) { + is StructureSection.AdditionalFields -> { + rust("pub(crate) meta: #T,", errorMetadata(runtimeConfig)) + } - is StructureSection.AdditionalDebugFields -> { - rust("""${section.formatterName}.field("meta", &self.meta);""") - } + is StructureSection.AdditionalDebugFields -> { + rust("""${section.formatterName}.field("meta", &self.meta);""") + } - else -> {} + else -> {} + } } - } }, ), structSettings, @@ -89,38 +90,39 @@ class ErrorGenerator( model, symbolProvider, shape, listOf( object : BuilderCustomization() { - override fun section(section: BuilderSection): Writable = writable { - when (section) { - is BuilderSection.AdditionalFields -> { - rust("meta: std::option::Option<#T>,", errorMetadata(runtimeConfig)) - } + override fun section(section: BuilderSection): Writable = + writable { + when (section) { + is BuilderSection.AdditionalFields -> { + rust("meta: std::option::Option<#T>,", errorMetadata(runtimeConfig)) + } - is BuilderSection.AdditionalMethods -> { - rustTemplate( - """ - /// Sets error metadata - pub fn meta(mut self, meta: #{error_metadata}) -> Self { - self.meta = Some(meta); - self - } + is BuilderSection.AdditionalMethods -> { + rustTemplate( + """ + /// Sets error metadata + pub fn meta(mut self, meta: #{error_metadata}) -> Self { + self.meta = Some(meta); + self + } - /// Sets error metadata - pub fn set_meta(&mut self, meta: std::option::Option<#{error_metadata}>) -> &mut Self { - self.meta = meta; - self - } - """, - "error_metadata" to errorMetadata(runtimeConfig), - ) - } + /// Sets error metadata + pub fn set_meta(&mut self, meta: std::option::Option<#{error_metadata}>) -> &mut Self { + self.meta = meta; + self + } + """, + "error_metadata" to errorMetadata(runtimeConfig), + ) + } - is BuilderSection.AdditionalFieldsInBuild -> { - rust("meta: self.meta.unwrap_or_default(),") - } + is BuilderSection.AdditionalFieldsInBuild -> { + rust("meta: self.meta.unwrap_or_default(),") + } - else -> {} + else -> {} + } } - } }, ), ).render(this) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGenerator.kt index 280b06c2cac..03b97fffaaa 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGenerator.kt @@ -56,22 +56,25 @@ class OperationErrorGenerator( private fun operationErrors(): List = (operationOrEventStream as OperationShape).operationErrors(model).map { it.asStructureShape().get() } + private fun eventStreamErrors(): List = (operationOrEventStream as UnionShape).eventStreamErrors() .map { model.expectShape(it.asMemberShape().get().target, StructureShape::class.java) } fun render(writer: RustWriter) { - val (errorSymbol, errors) = when (operationOrEventStream) { - is OperationShape -> symbolProvider.symbolForOperationError(operationOrEventStream) to operationErrors() - is UnionShape -> symbolProvider.symbolForEventStreamError(operationOrEventStream) to eventStreamErrors() - else -> UNREACHABLE("OperationErrorGenerator only supports operation or event stream shapes") - } + val (errorSymbol, errors) = + when (operationOrEventStream) { + is OperationShape -> symbolProvider.symbolForOperationError(operationOrEventStream) to operationErrors() + is UnionShape -> symbolProvider.symbolForEventStreamError(operationOrEventStream) to eventStreamErrors() + else -> UNREACHABLE("OperationErrorGenerator only supports operation or event stream shapes") + } - val meta = RustMetadata( - derives = setOf(RuntimeType.Debug), - additionalAttributes = listOf(Attribute.NonExhaustive), - visibility = Visibility.PUBLIC, - ) + val meta = + RustMetadata( + derives = setOf(RuntimeType.Debug), + additionalAttributes = listOf(Attribute.NonExhaustive), + visibility = Visibility.PUBLIC, + ) writer.rust("/// Error type for the `${errorSymbol.name}` operation.") meta.render(writer) @@ -123,24 +126,28 @@ class OperationErrorGenerator( } } - private fun RustWriter.renderImplDisplay(errorSymbol: Symbol, errors: List) { + private fun RustWriter.renderImplDisplay( + errorSymbol: Symbol, + errors: List, + ) { rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.Display) { rustBlock("fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result") { delegateToVariants(errors) { variantMatch -> when (variantMatch) { - is VariantMatch.Unhandled -> writable { - rustTemplate( - """ - if let #{Some}(code) = #{ProvideErrorMetadata}::code(self) { - write!(f, "unhandled error ({code})") - } else { - f.write_str("unhandled error") - } - """, - *preludeScope, - "ProvideErrorMetadata" to RuntimeType.provideErrorMetadataTrait(runtimeConfig), - ) - } + is VariantMatch.Unhandled -> + writable { + rustTemplate( + """ + if let #{Some}(code) = #{ProvideErrorMetadata}::code(self) { + write!(f, "unhandled error ({code})") + } else { + f.write_str("unhandled error") + } + """, + *preludeScope, + "ProvideErrorMetadata" to RuntimeType.provideErrorMetadataTrait(runtimeConfig), + ) + } is VariantMatch.Modeled -> writable { rust("_inner.fmt(f)") } } } @@ -148,7 +155,10 @@ class OperationErrorGenerator( } } - private fun RustWriter.renderImplProvideErrorMetadata(errorSymbol: Symbol, errors: List) { + private fun RustWriter.renderImplProvideErrorMetadata( + errorSymbol: Symbol, + errors: List, + ) { val errorMetadataTrait = RuntimeType.provideErrorMetadataTrait(runtimeConfig) rustBlock("impl #T for ${errorSymbol.name}", errorMetadataTrait) { rustBlock("fn meta(&self) -> &#T", errorMetadata(runtimeConfig)) { @@ -164,7 +174,10 @@ class OperationErrorGenerator( } } - private fun RustWriter.renderImplProvideErrorKind(errorSymbol: Symbol, errors: List) { + private fun RustWriter.renderImplProvideErrorKind( + errorSymbol: Symbol, + errors: List, + ) { val retryErrorKindT = RuntimeType.retryErrorKind(symbolProvider.config.runtimeConfig) rustBlock( "impl #T for ${errorSymbol.name}", @@ -198,7 +211,10 @@ class OperationErrorGenerator( } } - private fun RustWriter.renderImpl(errorSymbol: Symbol, errors: List) { + private fun RustWriter.renderImpl( + errorSymbol: Symbol, + errors: List, + ) { rustBlock("impl ${errorSymbol.name}") { rustTemplate( """ @@ -246,7 +262,10 @@ class OperationErrorGenerator( } } - private fun RustWriter.renderImplStdError(errorSymbol: Symbol, errors: List) { + private fun RustWriter.renderImplStdError( + errorSymbol: Symbol, + errors: List, + ) { rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.StdError) { rustBlockTemplate( "fn source(&self) -> #{Option}<&(dyn #{StdError} + 'static)>", @@ -255,12 +274,14 @@ class OperationErrorGenerator( ) { delegateToVariants(errors) { variantMatch -> when (variantMatch) { - is VariantMatch.Unhandled -> writable { - rustTemplate("#{Some}(&*_inner.source)", *preludeScope) - } - is VariantMatch.Modeled -> writable { - rustTemplate("#{Some}(_inner)", *preludeScope) - } + is VariantMatch.Unhandled -> + writable { + rustTemplate("#{Some}(&*_inner.source)", *preludeScope) + } + is VariantMatch.Modeled -> + writable { + rustTemplate("#{Some}(_inner)", *preludeScope) + } } } } @@ -269,6 +290,7 @@ class OperationErrorGenerator( sealed class VariantMatch(name: String) : Section(name) { object Unhandled : VariantMatch("Unhandled") + data class Modeled(val symbol: Symbol, val shape: Shape) : VariantMatch("Modeled") } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGenerator.kt index 5a33941845b..0f8789d4e6f 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGenerator.kt @@ -57,11 +57,12 @@ class ServiceErrorGenerator( private val symbolProvider = codegenContext.symbolProvider private val model = codegenContext.model - private val allErrors = operations.flatMap { - it.allErrors(model) - }.map { it.id }.distinctBy { it.getName(codegenContext.serviceShape) } - .map { codegenContext.model.expectShape(it, StructureShape::class.java) } - .sortedBy { it.id.getName(codegenContext.serviceShape) } + private val allErrors = + operations.flatMap { + it.allErrors(model) + }.map { it.id }.distinctBy { it.getName(codegenContext.serviceShape) } + .map { codegenContext.model.expectShape(it, StructureShape::class.java) } + .sortedBy { it.id.getName(codegenContext.serviceShape) } private val sdkError = RuntimeType.sdkError(codegenContext.runtimeConfig) @@ -140,7 +141,10 @@ class ServiceErrorGenerator( ) } - private fun RustWriter.renderImplFrom(errorSymbol: Symbol, errors: List) { + private fun RustWriter.renderImplFrom( + errorSymbol: Symbol, + errors: List, + ) { if (errors.isNotEmpty() || CodegenTarget.CLIENT == codegenContext.target) { val operationErrors = errors.map { model.expectShape(it) } rustBlock( @@ -205,16 +209,19 @@ class ServiceErrorGenerator( } """, *preludeScope, - "ErrorMetadata" to RuntimeType.smithyTypes(codegenContext.runtimeConfig) - .resolve("error::metadata::ErrorMetadata"), - "ProvideErrorMetadata" to RuntimeType.smithyTypes(codegenContext.runtimeConfig) - .resolve("error::metadata::ProvideErrorMetadata"), - "matchers" to writable { - allErrors.forEach { errorShape -> - val errSymbol = symbolProvider.toSymbol(errorShape) - rust("Self::${errSymbol.name}(inner) => inner.meta(),") - } - }, + "ErrorMetadata" to + RuntimeType.smithyTypes(codegenContext.runtimeConfig) + .resolve("error::metadata::ErrorMetadata"), + "ProvideErrorMetadata" to + RuntimeType.smithyTypes(codegenContext.runtimeConfig) + .resolve("error::metadata::ProvideErrorMetadata"), + "matchers" to + writable { + allErrors.forEach { errorShape -> + val errSymbol = symbolProvider.toSymbol(errorShape) + rust("Self::${errSymbol.name}(inner) => inner.meta(),") + } + }, ) } @@ -238,47 +245,53 @@ class ServiceErrorGenerator( } } -fun unhandledError(rc: RuntimeConfig): RuntimeType = RuntimeType.forInlineFun( - "Unhandled", - // Place in a sealed module so that it can't be referenced at all - RustModule.pubCrate("sealed_unhandled", ClientRustModule.Error), +fun unhandledError(rc: RuntimeConfig): RuntimeType = + RuntimeType.forInlineFun( + "Unhandled", + // Place in a sealed module so that it can't be referenced at all + RustModule.pubCrate("sealed_unhandled", ClientRustModule.Error), + ) { + rustTemplate( + """ + /// This struct is not intended to be used. + /// + /// This struct holds information about an unhandled error, + /// but that information should be obtained by using the + /// [`ProvideErrorMetadata`](#{ProvideErrorMetadata}) trait + /// on the error type. + /// + /// This struct intentionally doesn't yield any useful information itself. + #{deprecation} + ##[derive(Debug)] + pub struct Unhandled { + pub(crate) source: #{BoxError}, + pub(crate) meta: #{ErrorMetadata}, + } + """, + "BoxError" to RuntimeType.smithyRuntimeApi(rc).resolve("box_error::BoxError"), + "deprecation" to writable { renderUnhandledErrorDeprecation(rc) }, + "ErrorMetadata" to RuntimeType.smithyTypes(rc).resolve("error::metadata::ErrorMetadata"), + "ProvideErrorMetadata" to RuntimeType.smithyTypes(rc).resolve("error::metadata::ProvideErrorMetadata"), + ) + } + +fun RustWriter.renderUnhandledErrorDeprecation( + rc: RuntimeConfig, + errorName: String? = null, ) { - rustTemplate( - """ - /// This struct is not intended to be used. - /// - /// This struct holds information about an unhandled error, - /// but that information should be obtained by using the - /// [`ProvideErrorMetadata`](#{ProvideErrorMetadata}) trait - /// on the error type. - /// - /// This struct intentionally doesn't yield any useful information itself. - #{deprecation} - ##[derive(Debug)] - pub struct Unhandled { - pub(crate) source: #{BoxError}, - pub(crate) meta: #{ErrorMetadata}, + val link = + if (errorName != null) { + "##impl-ProvideErrorMetadata-for-$errorName" + } else { + "#{ProvideErrorMetadata}" } - """, - "BoxError" to RuntimeType.smithyRuntimeApi(rc).resolve("box_error::BoxError"), - "deprecation" to writable { renderUnhandledErrorDeprecation(rc) }, - "ErrorMetadata" to RuntimeType.smithyTypes(rc).resolve("error::metadata::ErrorMetadata"), - "ProvideErrorMetadata" to RuntimeType.smithyTypes(rc).resolve("error::metadata::ProvideErrorMetadata"), - ) -} - -fun RustWriter.renderUnhandledErrorDeprecation(rc: RuntimeConfig, errorName: String? = null) { - val link = if (errorName != null) { - "##impl-ProvideErrorMetadata-for-$errorName" - } else { - "#{ProvideErrorMetadata}" - } - val message = """ + val message = + """ Matching `Unhandled` directly is not forwards compatible. Instead, match using a variable wildcard pattern and check `.code()`:
   `err if err.code() == Some("SpecificExceptionCode") => { /* handle the error */ }`
See [`ProvideErrorMetadata`]($link) for what information is available for the error. - """.trimIndent() + """.trimIndent() // `.dq()` doesn't quite do what we want here since we actually want a Rust multi-line string val messageEscaped = message.replace("\"", "\\\"").replace("\n", " \\\n").replace("
", "\n") rustTemplate( diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGenerator.kt index 0a4e5fb09b8..24aa79d6cd7 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGenerator.kt @@ -39,13 +39,17 @@ fun HttpTrait.uriFormatString(): String { return uri.rustFormatString("/", "/") } -fun SmithyPattern.rustFormatString(prefix: String, separator: String): String { - val base = segments.joinToString(separator = separator, prefix = prefix) { - when { - it.isLabel -> "{${it.content}}" - else -> it.content +fun SmithyPattern.rustFormatString( + prefix: String, + separator: String, +): String { + val base = + segments.joinToString(separator = separator, prefix = prefix) { + when { + it.isLabel -> "{${it.content}}" + else -> it.content + } } - } return base.dq() } @@ -71,12 +75,13 @@ class RequestBindingGenerator( private val index = HttpBindingIndex.of(model) private val encoder = RuntimeType.smithyTypes(runtimeConfig).resolve("primitive::Encoder") - private val codegenScope = arrayOf( - *preludeScope, - "BuildError" to runtimeConfig.operationBuildError(), - "HttpRequestBuilder" to RuntimeType.HttpRequestBuilder, - "Input" to symbolProvider.toSymbol(inputShape), - ) + private val codegenScope = + arrayOf( + *preludeScope, + "BuildError" to runtimeConfig.operationBuildError(), + "HttpRequestBuilder" to RuntimeType.HttpRequestBuilder, + "Input" to symbolProvider.toSymbol(inputShape), + ) /** * Generates `update_http_builder` and all necessary dependency functions into the impl block provided by @@ -113,7 +118,7 @@ class RequestBindingGenerator( } } - /** URI Generation **/ + // URI Generation ** /** * Generate a function to build the request URI @@ -122,10 +127,11 @@ class RequestBindingGenerator( val formatString = httpTrait.uriFormatString() // name of a local variable containing this member's component of the URI val local = { member: MemberShape -> symbolProvider.toMemberName(member) } - val args = httpTrait.uri.labels.map { label -> - val member = inputShape.expectMember(label.content) - "${label.content} = ${local(member)}" - } + val args = + httpTrait.uri.labels.map { label -> + val member = inputShape.expectMember(label.content) + "${label.content} = ${local(member)}" + } val combinedArgs = listOf(formatString, *args.toTypedArray()) writer.rustBlockTemplate( "fn uri_base(_input: &#{Input}, output: &mut #{String}) -> #{Result}<(), #{BuildError}>", @@ -213,13 +219,15 @@ class RequestBindingGenerator( val target = model.expectShape(memberShape.target) if (memberShape.isRequired) { - val codegenScope = arrayOf( - *preludeScope, - "BuildError" to OperationBuildError(runtimeConfig).missingField( - memberName, - "cannot be empty or unset", - ), - ) + val codegenScope = + arrayOf( + *preludeScope, + "BuildError" to + OperationBuildError(runtimeConfig).missingField( + memberName, + "cannot be empty or unset", + ), + ) val derefName = safeName("inner") rust("let $derefName = &_input.$memberName;") if (memberSymbol.isOptional()) { @@ -266,7 +274,12 @@ class RequestBindingGenerator( /** * Format [member] when used as a queryParam */ - private fun paramFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String): String { + private fun paramFmtFun( + writer: RustWriter, + target: Shape, + member: MemberShape, + targetName: String, + ): String { return when { target.isStringShape -> { val func = writer.format(RuntimeType.queryFormat(runtimeConfig, "fmt_string")) @@ -291,13 +304,18 @@ class RequestBindingGenerator( } } - private fun RustWriter.serializeLabel(member: MemberShape, label: SmithyPattern.Segment, outputVar: String) { + private fun RustWriter.serializeLabel( + member: MemberShape, + label: SmithyPattern.Segment, + outputVar: String, + ) { val target = model.expectShape(member.target) val symbol = symbolProvider.toSymbol(member) - val buildError = OperationBuildError(runtimeConfig).missingField( - symbolProvider.toMemberName(member), - "cannot be empty or unset", - ) + val buildError = + OperationBuildError(runtimeConfig).missingField( + symbolProvider.toMemberName(member), + "cannot be empty or unset", + ) val input = safeName("input") rust("let $input = &_input.${symbolProvider.toMemberName(member)};") if (symbol.isOptional()) { @@ -306,11 +324,12 @@ class RequestBindingGenerator( when { target.isStringShape -> { val func = format(RuntimeType.labelFormat(runtimeConfig, "fmt_string")) - val encodingStrategy = if (label.isGreedyLabel) { - RuntimeType.labelFormat(runtimeConfig, "EncodingStrategy::Greedy") - } else { - RuntimeType.labelFormat(runtimeConfig, "EncodingStrategy::Default") - } + val encodingStrategy = + if (label.isGreedyLabel) { + RuntimeType.labelFormat(runtimeConfig, "EncodingStrategy::Greedy") + } else { + RuntimeType.labelFormat(runtimeConfig, "EncodingStrategy::Default") + } rust("let $outputVar = $func($input, #T);", encodingStrategy) } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt index 8ce713b457d..3b34aaed056 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt @@ -50,14 +50,15 @@ class ProtocolParserGenerator( private val protocolFunctions = ProtocolFunctions(codegenContext) private val symbolProvider: RustSymbolProvider = codegenContext.symbolProvider - private val codegenScope = arrayOf( - "Bytes" to RuntimeType.Bytes, - "Headers" to RuntimeType.headers(codegenContext.runtimeConfig), - "Response" to RuntimeType.smithyRuntimeApi(codegenContext.runtimeConfig).resolve("http::Response"), - "http" to RuntimeType.Http, - "operation" to RuntimeType.operationModule(codegenContext.runtimeConfig), - "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig), - ) + private val codegenScope = + arrayOf( + "Bytes" to RuntimeType.Bytes, + "Headers" to RuntimeType.headers(codegenContext.runtimeConfig), + "Response" to RuntimeType.smithyRuntimeApi(codegenContext.runtimeConfig).resolve("http::Response"), + "http" to RuntimeType.Http, + "operation" to RuntimeType.operationModule(codegenContext.runtimeConfig), + "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig), + ) fun parseResponseFn( operationShape: OperationShape, @@ -151,11 +152,12 @@ class ProtocolParserGenerator( errorSymbol, listOf( object : OperationCustomization() { - override fun section(section: OperationSection): Writable = { - if (section is OperationSection.MutateOutput) { - rust("let output = output.meta(generic);") + override fun section(section: OperationSection): Writable = + { + if (section is OperationSection.MutateOutput) { + rust("let output = output.meta(generic);") + } } - } }, ), ) @@ -264,9 +266,10 @@ class ProtocolParserGenerator( } } - val mapErr = writable { - rust("#T::unhandled", errorSymbol) - } + val mapErr = + writable { + rust("#T::unhandled", errorSymbol) + } writeCustomizations( customizations, @@ -289,16 +292,17 @@ class ProtocolParserGenerator( val errorSymbol = symbolProvider.symbolForOperationError(operationShape) val member = binding.member return when (binding.location) { - HttpLocation.HEADER -> writable { - val fnName = httpBindingGenerator.generateDeserializeHeaderFn(binding) - rust( - """ - #T(_response_headers) - .map_err(|_|#T::unhandled("Failed to parse ${member.memberName} from header `${binding.locationName}"))? - """, - fnName, errorSymbol, - ) - } + HttpLocation.HEADER -> + writable { + val fnName = httpBindingGenerator.generateDeserializeHeaderFn(binding) + rust( + """ + #T(_response_headers) + .map_err(|_|#T::unhandled("Failed to parse ${member.memberName} from header `${binding.locationName}"))? + """, + fnName, errorSymbol, + ) + } HttpLocation.DOCUMENT -> { // document is handled separately null @@ -307,20 +311,22 @@ class ProtocolParserGenerator( val payloadParser: RustWriter.(String) -> Unit = { body -> rust("#T($body).map_err(#T::unhandled)", structuredDataParser.payloadParser(member), errorSymbol) } - val deserializer = httpBindingGenerator.generateDeserializePayloadFn( - binding, - errorSymbol, - payloadParser = payloadParser, - ) + val deserializer = + httpBindingGenerator.generateDeserializePayloadFn( + binding, + errorSymbol, + payloadParser = payloadParser, + ) return if (binding.member.isStreaming(model)) { writable { rust("Some(#T(_response_body)?)", deserializer) } } else { writable { rust("#T(_response_body)?", deserializer) } } } - HttpLocation.RESPONSE_CODE -> writable { - rust("Some(_response_status as _)") - } + HttpLocation.RESPONSE_CODE -> + writable { + rust("Some(_response_status as _)") + } HttpLocation.PREFIX_HEADERS -> { val sym = httpBindingGenerator.generateDeserializePrefixHeaderFn(binding) writable { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt index 133b0dddf72..5f0fa8464ed 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -67,7 +67,6 @@ class DefaultProtocolTestGenerator( override val codegenContext: ClientCodegenContext, override val protocolSupport: ProtocolSupport, override val operationShape: OperationShape, - private val renderClientCreation: RustWriter.(ClientCreationParams) -> Unit = { params -> rustTemplate( """ @@ -91,37 +90,44 @@ class DefaultProtocolTestGenerator( private val instantiator = ClientInstantiator(codegenContext) - private val codegenScope = arrayOf( - "SmithyHttp" to RT.smithyHttp(rc), - "AssertEq" to RT.PrettyAssertions.resolve("assert_eq!"), - "Uri" to RT.Http.resolve("Uri"), - ) + private val codegenScope = + arrayOf( + "SmithyHttp" to RT.smithyHttp(rc), + "AssertEq" to RT.PrettyAssertions.resolve("assert_eq!"), + "Uri" to RT.Http.resolve("Uri"), + ) sealed class TestCase { abstract val testCase: HttpMessageTestCase data class RequestTest(override val testCase: HttpRequestTestCase) : TestCase() + data class ResponseTest(override val testCase: HttpResponseTestCase, val targetShape: StructureShape) : TestCase() } override fun render(writer: RustWriter) { - val requestTests = operationShape.getTrait() - ?.getTestCasesFor(AppliesTo.CLIENT).orEmpty().map { TestCase.RequestTest(it) } - val responseTests = operationShape.getTrait() - ?.getTestCasesFor(AppliesTo.CLIENT).orEmpty().map { TestCase.ResponseTest(it, outputShape) } - val errorTests = operationIndex.getErrors(operationShape).flatMap { error -> - val testCases = error.getTrait() - ?.getTestCasesFor(AppliesTo.CLIENT).orEmpty() - testCases.map { TestCase.ResponseTest(it, error) } - } + val requestTests = + operationShape.getTrait() + ?.getTestCasesFor(AppliesTo.CLIENT).orEmpty().map { TestCase.RequestTest(it) } + val responseTests = + operationShape.getTrait() + ?.getTestCasesFor(AppliesTo.CLIENT).orEmpty().map { TestCase.ResponseTest(it, outputShape) } + val errorTests = + operationIndex.getErrors(operationShape).flatMap { error -> + val testCases = + error.getTrait() + ?.getTestCasesFor(AppliesTo.CLIENT).orEmpty() + testCases.map { TestCase.ResponseTest(it, error) } + } val allTests: List = (requestTests + responseTests + errorTests).filterMatching() if (allTests.isNotEmpty()) { val operationName = operationSymbol.name val testModuleName = "${operationName.toSnakeCase()}_request_test" - val additionalAttributes = listOf( - Attribute(allow("unreachable_code", "unused_variables")), - ) + val additionalAttributes = + listOf( + Attribute(allow("unreachable_code", "unused_variables")), + ) writer.withInlineModule( RustModule.inlineTests(testModuleName, additionalAttributes = additionalAttributes), null, @@ -168,41 +174,42 @@ class DefaultProtocolTestGenerator( testModuleWriter.write("Test ID: ${testCase.id}") testModuleWriter.newlinePrefix = "" Attribute.TokioTest.render(testModuleWriter) - val action = when (testCase) { - is HttpResponseTestCase -> Action.Response - is HttpRequestTestCase -> Action.Request - else -> throw CodegenException("unknown test case type") - } + val action = + when (testCase) { + is HttpResponseTestCase -> Action.Response + is HttpRequestTestCase -> Action.Request + else -> throw CodegenException("unknown test case type") + } if (expectFail(testCase)) { testModuleWriter.writeWithNoFormatting("#[should_panic]") } - val fnName = when (action) { - is Action.Response -> "_response" - is Action.Request -> "_request" - } + val fnName = + when (action) { + is Action.Response -> "_response" + is Action.Request -> "_request" + } Attribute.AllowUnusedMut.render(testModuleWriter) testModuleWriter.rustBlock("async fn ${testCase.id.toSnakeCase()}$fnName()") { block(this) } } - private fun RustWriter.renderHttpRequestTestCase( - httpRequestTestCase: HttpRequestTestCase, - ) { + private fun RustWriter.renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase) { if (!protocolSupport.requestSerialization) { rust("/* test case disabled for this protocol (not yet supported) */") return } - val customParams = httpRequestTestCase.vendorParams.getObjectMember("endpointParams").orNull()?.let { params -> - writable { - val customizations = codegenContext.rootDecorator.endpointCustomizations(codegenContext) - params.getObjectMember("builtInParams").orNull()?.members?.forEach { (name, value) -> - customizations.firstNotNullOf { - it.setBuiltInOnServiceConfig(name.value, value, "config_builder") - }(this) + val customParams = + httpRequestTestCase.vendorParams.getObjectMember("endpointParams").orNull()?.let { params -> + writable { + val customizations = codegenContext.rootDecorator.endpointCustomizations(codegenContext) + params.getObjectMember("builtInParams").orNull()?.members?.forEach { (name, value) -> + customizations.firstNotNullOf { + it.setBuiltInOnServiceConfig(name.value, value, "config_builder") + }(this) + } } - } - } ?: writable { } + } ?: writable { } // support test cases that set the host value, e.g: https://github.com/smithy-lang/smithy/blob/be68f3bbdfe5bf50a104b387094d40c8069f16b1/smithy-aws-protocol-tests/model/restJson1/endpoint-paths.smithy#L19 val host = "https://${httpRequestTestCase.host.orNull() ?: "example.com"}".dq() rustTemplate( @@ -212,8 +219,9 @@ class DefaultProtocolTestGenerator( #{customParams} """, - "capture_request" to CargoDependency.smithyRuntimeTestUtil(rc).toType() - .resolve("client::http::test_util::capture_request"), + "capture_request" to + CargoDependency.smithyRuntimeTestUtil(rc).toType() + .resolve("client::http::test_util::capture_request"), "config" to ClientRustModule.config, "customParams" to customParams, ) @@ -268,25 +276,28 @@ class DefaultProtocolTestGenerator( } } - private fun HttpMessageTestCase.action(): Action = when (this) { - is HttpRequestTestCase -> Action.Request - is HttpResponseTestCase -> Action.Response - else -> throw CodegenException("Unknown test case type") - } + private fun HttpMessageTestCase.action(): Action = + when (this) { + is HttpRequestTestCase -> Action.Request + is HttpResponseTestCase -> Action.Response + else -> throw CodegenException("Unknown test case type") + } - private fun expectFail(testCase: HttpMessageTestCase): Boolean = ExpectFail.find { - it.id == testCase.id && it.action == testCase.action() && it.service == codegenContext.serviceShape.id.toString() - } != null + private fun expectFail(testCase: HttpMessageTestCase): Boolean = + ExpectFail.find { + it.id == testCase.id && it.action == testCase.action() && it.service == codegenContext.serviceShape.id.toString() + } != null private fun RustWriter.renderHttpResponseTestCase( testCase: HttpResponseTestCase, expectedShape: StructureShape, ) { if (!protocolSupport.responseDeserialization || ( - !protocolSupport.errorDeserialization && expectedShape.hasTrait( - ErrorTrait::class.java, - ) - ) + !protocolSupport.errorDeserialization && + expectedShape.hasTrait( + ErrorTrait::class.java, + ) + ) ) { rust("/* test case disabled for this protocol (not yet supported) */") return @@ -329,8 +340,9 @@ class DefaultProtocolTestGenerator( }); """, "copy_from_slice" to RT.Bytes.resolve("copy_from_slice"), - "SharedResponseDeserializer" to RT.smithyRuntimeApiClient(rc) - .resolve("client::ser_de::SharedResponseDeserializer"), + "SharedResponseDeserializer" to + RT.smithyRuntimeApiClient(rc) + .resolve("client::ser_de::SharedResponseDeserializer"), "Operation" to codegenContext.symbolProvider.toSymbol(operationShape), "DeserializeResponse" to RT.smithyRuntimeApiClient(rc).resolve("client::ser_de::DeserializeResponse"), "RuntimePlugin" to RT.runtimePlugin(rc), @@ -394,7 +406,11 @@ class DefaultProtocolTestGenerator( } } - private fun checkBody(rustWriter: RustWriter, body: String, mediaType: String?) { + private fun checkBody( + rustWriter: RustWriter, + body: String, + mediaType: String?, + ) { rustWriter.write("""let body = http_request.body().bytes().expect("body should be strict");""") if (body == "") { rustWriter.rustTemplate( @@ -418,7 +434,11 @@ class DefaultProtocolTestGenerator( } } - private fun checkRequiredHeaders(rustWriter: RustWriter, actualExpression: String, requireHeaders: List) { + private fun checkRequiredHeaders( + rustWriter: RustWriter, + actualExpression: String, + requireHeaders: List, + ) { basicCheck( requireHeaders, rustWriter, @@ -428,7 +448,11 @@ class DefaultProtocolTestGenerator( ) } - private fun checkForbidHeaders(rustWriter: RustWriter, actualExpression: String, forbidHeaders: List) { + private fun checkForbidHeaders( + rustWriter: RustWriter, + actualExpression: String, + forbidHeaders: List, + ) { basicCheck( forbidHeaders, rustWriter, @@ -438,7 +462,11 @@ class DefaultProtocolTestGenerator( ) } - private fun checkHeaders(rustWriter: RustWriter, actualExpression: String, headers: Map) { + private fun checkHeaders( + rustWriter: RustWriter, + actualExpression: String, + headers: Map, + ) { if (headers.isEmpty()) { return } @@ -516,13 +544,19 @@ class DefaultProtocolTestGenerator( * wraps `inner` in a call to `aws_smithy_protocol_test::assert_ok`, a convenience wrapper * for pretty printing protocol test helper results */ - private fun assertOk(rustWriter: RustWriter, inner: Writable) { + private fun assertOk( + rustWriter: RustWriter, + inner: Writable, + ) { rustWriter.write("#T(", RT.protocolTest(rc, "assert_ok")) inner(rustWriter) rustWriter.write(");") } - private fun strSlice(writer: RustWriter, args: List) { + private fun strSlice( + writer: RustWriter, + args: List, + ) { writer.withBlock("&[", "]") { write(args.joinToString(",") { it.dq() }) } @@ -531,6 +565,7 @@ class DefaultProtocolTestGenerator( companion object { sealed class Action { object Request : Action() + object Response : Action() } @@ -551,20 +586,21 @@ class DefaultProtocolTestGenerator( // These tests are not even attempted to be generated, either because they will not compile // or because they are flaky - private val DisableTests = setOf( - // TODO(https://github.com/smithy-lang/smithy-rs/issues/2891): Implement support for `@requestCompression` - "SDKAppendedGzipAfterProvidedEncoding_restJson1", - "SDKAppendedGzipAfterProvidedEncoding_restXml", - "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_0", - "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_1", - "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsQuery", - "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_ec2Query", - "SDKAppliedContentEncoding_awsJson1_0", - "SDKAppliedContentEncoding_awsJson1_1", - "SDKAppliedContentEncoding_awsQuery", - "SDKAppliedContentEncoding_ec2Query", - "SDKAppliedContentEncoding_restJson1", - "SDKAppliedContentEncoding_restXml", - ) + private val DisableTests = + setOf( + // TODO(https://github.com/smithy-lang/smithy-rs/issues/2891): Implement support for `@requestCompression` + "SDKAppendedGzipAfterProvidedEncoding_restJson1", + "SDKAppendedGzipAfterProvidedEncoding_restXml", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_0", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_1", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsQuery", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_ec2Query", + "SDKAppliedContentEncoding_awsJson1_0", + "SDKAppliedContentEncoding_awsJson1_1", + "SDKAppliedContentEncoding_awsQuery", + "SDKAppliedContentEncoding_ec2Query", + "SDKAppliedContentEncoding_restJson1", + "SDKAppliedContentEncoding_restXml", + ) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt index f079cf1026a..43cc46e8dbf 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt @@ -49,15 +49,19 @@ class RequestSerializerGenerator( "operation" to RuntimeType.operationModule(codegenContext.runtimeConfig), "SerializeRequest" to runtimeApi.resolve("client::ser_de::SerializeRequest"), "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig), - "HeaderSerializationSettings" to RuntimeType.forInlineDependency( - InlineDependency.serializationSettings( - codegenContext.runtimeConfig, - ), - ).resolve("HeaderSerializationSettings"), + "HeaderSerializationSettings" to + RuntimeType.forInlineDependency( + InlineDependency.serializationSettings( + codegenContext.runtimeConfig, + ), + ).resolve("HeaderSerializationSettings"), ) } - fun render(writer: RustWriter, operationShape: OperationShape) { + fun render( + writer: RustWriter, + operationShape: OperationShape, + ) { val inputShape = operationShape.inputShape(codegenContext.model) val operationName = symbolProvider.toSymbol(operationShape).name val inputSymbol = symbolProvider.toSymbol(inputShape) @@ -83,39 +87,42 @@ class RequestSerializerGenerator( *codegenScope, "ConcreteInput" to inputSymbol, "create_http_request" to createHttpRequest(operationShape), - "generate_body" to writable { - if (bodyGenerator != null) { - val body = writable { - bodyGenerator.generatePayload(this, "input", operationShape) - } - val streamingMember = inputShape.findStreamingMember(codegenContext.model) - val isBlobStreaming = - streamingMember != null && codegenContext.model.expectShape(streamingMember.target) is BlobShape - if (isBlobStreaming) { - // Consume the `ByteStream` into its inner `SdkBody`. - rust("#T.into_inner()", body) + "generate_body" to + writable { + if (bodyGenerator != null) { + val body = + writable { + bodyGenerator.generatePayload(this, "input", operationShape) + } + val streamingMember = inputShape.findStreamingMember(codegenContext.model) + val isBlobStreaming = + streamingMember != null && codegenContext.model.expectShape(streamingMember.target) is BlobShape + if (isBlobStreaming) { + // Consume the `ByteStream` into its inner `SdkBody`. + rust("#T.into_inner()", body) + } else { + rustTemplate("#{SdkBody}::from(#{body})", *codegenScope, "body" to body) + } } else { - rustTemplate("#{SdkBody}::from(#{body})", *codegenScope, "body" to body) + rustTemplate("#{SdkBody}::empty()", *codegenScope) + } + }, + "add_content_length" to + if (needsContentLength(operationShape)) { + writable { + rustTemplate( + """ + if let Some(content_length) = body.content_length() { + let content_length = content_length.to_string(); + request_builder = _header_serialization_settings.set_default_header(request_builder, #{http}::header::CONTENT_LENGTH, &content_length); + } + """, + *codegenScope, + ) } } else { - rustTemplate("#{SdkBody}::empty()", *codegenScope) - } - }, - "add_content_length" to if (needsContentLength(operationShape)) { - writable { - rustTemplate( - """ - if let Some(content_length) = body.content_length() { - let content_length = content_length.to_string(); - request_builder = _header_serialization_settings.set_default_header(request_builder, #{http}::header::CONTENT_LENGTH, &content_length); - } - """, - *codegenScope, - ) - } - } else { - writable { } - }, + writable { } + }, ) } @@ -124,34 +131,36 @@ class RequestSerializerGenerator( .any { it.location == HttpLocation.DOCUMENT || it.location == HttpLocation.PAYLOAD } } - private fun createHttpRequest(operationShape: OperationShape): Writable = writable { - val httpBindingGenerator = RequestBindingGenerator( - codegenContext, - protocol, - operationShape, - ) - httpBindingGenerator.renderUpdateHttpBuilder(this) - val contentType = httpBindingResolver.requestContentType(operationShape) + private fun createHttpRequest(operationShape: OperationShape): Writable = + writable { + val httpBindingGenerator = + RequestBindingGenerator( + codegenContext, + protocol, + operationShape, + ) + httpBindingGenerator.renderUpdateHttpBuilder(this) + val contentType = httpBindingResolver.requestContentType(operationShape) - rustTemplate("let mut builder = update_http_builder(&input, #{HttpRequestBuilder}::new())?;", *codegenScope) - if (contentType != null) { - rustTemplate( - "builder = _header_serialization_settings.set_default_header(builder, #{http}::header::CONTENT_TYPE, ${contentType.dq()});", - *codegenScope, - ) - } - for (header in protocol.additionalRequestHeaders(operationShape)) { - rustTemplate( - """ - builder = _header_serialization_settings.set_default_header( - builder, - #{http}::header::HeaderName::from_static(${header.first.dq()}), - ${header.second.dq()} - ); - """, - *codegenScope, - ) + rustTemplate("let mut builder = update_http_builder(&input, #{HttpRequestBuilder}::new())?;", *codegenScope) + if (contentType != null) { + rustTemplate( + "builder = _header_serialization_settings.set_default_header(builder, #{http}::header::CONTENT_TYPE, ${contentType.dq()});", + *codegenScope, + ) + } + for (header in protocol.additionalRequestHeaders(operationShape)) { + rustTemplate( + """ + builder = _header_serialization_settings.set_default_header( + builder, + #{http}::header::HeaderName::from_static(${header.first.dq()}), + ${header.second.dq()} + ); + """, + *codegenScope, + ) + } + rust("builder") } - rust("builder") - } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt index 2f183c3444e..ac797e1409e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt @@ -52,7 +52,11 @@ class ResponseDeserializerGenerator( ) } - fun render(writer: RustWriter, operationShape: OperationShape, customizations: List) { + fun render( + writer: RustWriter, + operationShape: OperationShape, + customizations: List, + ) { val outputSymbol = symbolProvider.toSymbol(operationShape.outputShape(model)) val operationName = symbolProvider.toSymbol(operationShape).name val streaming = operationShape.outputShape(model).hasStreamingMember(model) @@ -72,17 +76,19 @@ class ResponseDeserializerGenerator( *codegenScope, "O" to outputSymbol, "E" to symbolProvider.symbolForOperationError(operationShape), - "deserialize_streaming" to writable { - if (streaming) { - deserializeStreaming(operationShape, customizations) - } - }, - "deserialize_nonstreaming" to writable { - when (streaming) { - true -> deserializeStreamingError(operationShape, customizations) - else -> deserializeNonStreaming(operationShape, customizations) - } - }, + "deserialize_streaming" to + writable { + if (streaming) { + deserializeStreaming(operationShape, customizations) + } + }, + "deserialize_nonstreaming" to + writable { + when (streaming) { + true -> deserializeStreamingError(operationShape, customizations) + else -> deserializeNonStreaming(operationShape, customizations) + } + }, ) } @@ -107,9 +113,10 @@ class ResponseDeserializerGenerator( """, *codegenScope, "parse_streaming_response" to parserGenerator.parseStreamingResponseFn(operationShape, customizations), - "BeforeParseResponse" to writable { - writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response", "force_error", body = null)) - }, + "BeforeParseResponse" to + writable { + writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response", "force_error", body = null)) + }, ) } @@ -151,26 +158,28 @@ class ResponseDeserializerGenerator( *codegenScope, "parse_error" to parserGenerator.parseErrorFn(operationShape, customizations), "parse_response" to parserGenerator.parseResponseFn(operationShape, customizations), - "BeforeParseResponse" to writable { - writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response", "force_error", "body")) - }, + "BeforeParseResponse" to + writable { + writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response", "force_error", "body")) + }, ) } - private fun typeEraseResult(): RuntimeType = ProtocolFunctions.crossOperationFn("type_erase_result") { fnName -> - rustTemplate( - """ - pub(crate) fn $fnName(result: #{Result}) -> #{Result}<#{Output}, #{OrchestratorError}<#{Error}>> - where - O: ::std::fmt::Debug + #{Send} + #{Sync} + 'static, - E: ::std::error::Error + std::fmt::Debug + #{Send} + #{Sync} + 'static, - { - result.map(|output| #{Output}::erase(output)) - .map_err(|error| #{Error}::erase(error)) - .map_err(#{Into}::into) - } - """, - *codegenScope, - ) - } + private fun typeEraseResult(): RuntimeType = + ProtocolFunctions.crossOperationFn("type_erase_result") { fnName -> + rustTemplate( + """ + pub(crate) fn $fnName(result: #{Result}) -> #{Result}<#{Output}, #{OrchestratorError}<#{Error}>> + where + O: ::std::fmt::Debug + #{Send} + #{Sync} + 'static, + E: ::std::error::Error + std::fmt::Debug + #{Send} + #{Sync} + 'static, + { + result.map(|output| #{Output}::erase(output)) + .map_err(|error| #{Error}::erase(error)) + .map_err(#{Into}::into) + } + """, + *codegenScope, + ) + } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt index cb9c7db7e5a..5c0f7e6e1a6 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt @@ -32,32 +32,33 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait class ClientProtocolLoader(supportedProtocols: ProtocolMap) : ProtocolLoader(supportedProtocols) { - companion object { - val DefaultProtocols = mapOf( - AwsJson1_0Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json10), - AwsJson1_1Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json11), - AwsQueryTrait.ID to ClientAwsQueryFactory(), - Ec2QueryTrait.ID to ClientEc2QueryFactory(), - RestJson1Trait.ID to ClientRestJsonFactory(), - RestXmlTrait.ID to ClientRestXmlFactory(), - ) + val DefaultProtocols = + mapOf( + AwsJson1_0Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json10), + AwsJson1_1Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json11), + AwsQueryTrait.ID to ClientAwsQueryFactory(), + Ec2QueryTrait.ID to ClientEc2QueryFactory(), + RestJson1Trait.ID to ClientRestJsonFactory(), + RestXmlTrait.ID to ClientRestXmlFactory(), + ) val Default = ClientProtocolLoader(DefaultProtocols) } } -private val CLIENT_PROTOCOL_SUPPORT = ProtocolSupport( - /* Client protocol codegen enabled */ - requestSerialization = true, - requestBodySerialization = true, - responseDeserialization = true, - errorDeserialization = true, - /* Server protocol codegen disabled */ - requestDeserialization = false, - requestBodyDeserialization = false, - responseSerialization = false, - errorSerialization = false, -) +private val CLIENT_PROTOCOL_SUPPORT = + ProtocolSupport( + // Client protocol codegen enabled + requestSerialization = true, + requestBodySerialization = true, + responseDeserialization = true, + errorDeserialization = true, + // Server protocol codegen disabled + requestDeserialization = false, + requestBodyDeserialization = false, + responseSerialization = false, + errorSerialization = false, + ) private class ClientAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFactory { @@ -73,8 +74,10 @@ private class ClientAwsJsonFactory(private val version: AwsJsonVersion) : override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT - private fun compatibleWithAwsQuery(serviceShape: ServiceShape, version: AwsJsonVersion) = - serviceShape.hasTrait() && version == AwsJsonVersion.Json10 + private fun compatibleWithAwsQuery( + serviceShape: ServiceShape, + version: AwsJsonVersion, + ) = serviceShape.hasTrait() && version == AwsJsonVersion.Json10 } private class ClientAwsQueryFactory : ProtocolGeneratorFactory { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index f3cc0abad08..c2cbc7c8572 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -18,27 +18,28 @@ class ClientHttpBoundProtocolPayloadGenerator( codegenContext: ClientCodegenContext, protocol: Protocol, ) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator( - codegenContext, protocol, HttpMessageType.REQUEST, - renderEventStreamBody = { writer, params -> - writer.rustTemplate( - """ - { - let error_marshaller = #{errorMarshallerConstructorFn}(); - let marshaller = #{marshallerConstructorFn}(); - let (signer, signer_sender) = #{DeferredSigner}::new(); - _cfg.interceptor_state().store_put(signer_sender); - let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> = - ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer); - #{SdkBody}::from_body_0_4(#{hyper}::Body::wrap_stream(adapter)) - } - """, - "hyper" to CargoDependency.HyperWithStream.toType(), - "SdkBody" to CargoDependency.smithyTypes(codegenContext.runtimeConfig).withFeature("http-body-0-4-x") - .toType().resolve("body::SdkBody"), - "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), - "DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::DeferredSigner"), - "marshallerConstructorFn" to params.marshallerConstructorFn, - "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, - ) - }, -) + codegenContext, protocol, HttpMessageType.REQUEST, + renderEventStreamBody = { writer, params -> + writer.rustTemplate( + """ + { + let error_marshaller = #{errorMarshallerConstructorFn}(); + let marshaller = #{marshallerConstructorFn}(); + let (signer, signer_sender) = #{DeferredSigner}::new(); + _cfg.interceptor_state().store_put(signer_sender); + let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> = + ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer); + #{SdkBody}::from_body_0_4(#{hyper}::Body::wrap_stream(adapter)) + } + """, + "hyper" to CargoDependency.HyperWithStream.toType(), + "SdkBody" to + CargoDependency.smithyTypes(codegenContext.runtimeConfig).withFeature("http-body-0-4-x") + .toType().resolve("body::SdkBody"), + "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), + "DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::DeferredSigner"), + "marshallerConstructorFn" to params.marshallerConstructorFn, + "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, + ) + }, + ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/RemoveEventStreamOperations.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/RemoveEventStreamOperations.kt index 3d757e4bdb1..863f9daa88e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/RemoveEventStreamOperations.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/RemoveEventStreamOperations.kt @@ -15,11 +15,15 @@ import software.amazon.smithy.rust.codegen.core.util.orNull import java.util.logging.Logger // TODO(EventStream): [CLEANUP] Remove this class once the Event Stream implementation is stable + /** Transformer to REMOVE operations that use EventStreaming until event streaming is supported */ object RemoveEventStreamOperations { private val logger = Logger.getLogger(javaClass.name) - fun transform(model: Model, settings: ClientRustSettings): Model { + fun transform( + model: Model, + settings: ClientRustSettings, + ): Model { // If Event Stream is allowed in build config, then don't remove the operations val allowList = settings.codegenConfig.eventStreamAllowList if (allowList.isEmpty() || allowList.contains(settings.moduleName)) { @@ -30,16 +34,18 @@ object RemoveEventStreamOperations { if (parentShape !is OperationShape) { true } else { - val ioShapes = listOfNotNull(parentShape.output.orNull(), parentShape.input.orNull()).map { - model.expectShape( - it, - StructureShape::class.java, - ) - } - val hasEventStream = ioShapes.any { ioShape -> - val streamingMember = ioShape.findStreamingMember(model)?.let { model.expectShape(it.target) } - streamingMember?.isUnionShape ?: false - } + val ioShapes = + listOfNotNull(parentShape.output.orNull(), parentShape.input.orNull()).map { + model.expectShape( + it, + StructureShape::class.java, + ) + } + val hasEventStream = + ioShapes.any { ioShape -> + val streamingMember = ioShape.findStreamingMember(model)?.let { model.expectShape(it.target) } + streamingMember?.isUnionShape ?: false + } // If a streaming member has a union trait, it is an event stream. Event Streams are not currently supported // by the SDK, so if we generate this API it won't work. (!hasEventStream).also { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/ClientCodegenIntegrationTest.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/ClientCodegenIntegrationTest.kt index 5d6de735b99..5846a44221d 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/ClientCodegenIntegrationTest.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/ClientCodegenIntegrationTest.kt @@ -18,21 +18,26 @@ import java.nio.file.Path fun clientIntegrationTest( model: Model, - params: IntegrationTestParams = IntegrationTestParams(cargoCommand = "cargo test --features behavior-version-latest"), + params: IntegrationTestParams = + IntegrationTestParams(cargoCommand = "cargo test --features behavior-version-latest"), additionalDecorators: List = listOf(), test: (ClientCodegenContext, RustCrate) -> Unit = { _, _ -> }, ): Path { fun invokeRustCodegenPlugin(ctx: PluginContext) { - val codegenDecorator = object : ClientCodegenDecorator { - override val name: String = "Add tests" - override val order: Byte = 0 + val codegenDecorator = + object : ClientCodegenDecorator { + override val name: String = "Add tests" + override val order: Byte = 0 - override fun classpathDiscoverable(): Boolean = false + override fun classpathDiscoverable(): Boolean = false - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { - test(codegenContext, rustCrate) + override fun extras( + codegenContext: ClientCodegenContext, + rustCrate: RustCrate, + ) { + test(codegenContext, rustCrate) + } } - } RustClientCodegenPlugin().executeWithDecorator(ctx, codegenDecorator, *additionalDecorators.toTypedArray()) } return codegenIntegrationTest(model, params, invokePlugin = ::invokeRustCodegenPlugin) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/TestHelpers.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/TestHelpers.kt index a65d734085d..e4b9e11ac5e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/TestHelpers.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/TestHelpers.kt @@ -51,19 +51,23 @@ fun testClientRustSettings( customizationConfig, ) -val TestClientRustSymbolProviderConfig = RustSymbolProviderConfig( - runtimeConfig = TestRuntimeConfig, - renameExceptions = true, - nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1, - moduleProvider = ClientModuleProvider, -) +val TestClientRustSymbolProviderConfig = + RustSymbolProviderConfig( + runtimeConfig = TestRuntimeConfig, + renameExceptions = true, + nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1, + moduleProvider = ClientModuleProvider, + ) private class ClientTestCodegenDecorator : ClientCodegenDecorator { override val name = "test" override val order: Byte = 0 } -fun testSymbolProvider(model: Model, serviceShape: ServiceShape? = null): RustSymbolProvider = +fun testSymbolProvider( + model: Model, + serviceShape: ServiceShape? = null, +): RustSymbolProvider = RustClientCodegenPlugin.baseSymbolProvider( testClientRustSettings(), model, @@ -78,19 +82,22 @@ fun testClientCodegenContext( serviceShape: ServiceShape? = null, settings: ClientRustSettings = testClientRustSettings(), rootDecorator: ClientCodegenDecorator? = null, -): ClientCodegenContext = ClientCodegenContext( - model, - symbolProvider ?: testSymbolProvider(model), - TestModuleDocProvider, - serviceShape - ?: model.serviceShapes.firstOrNull() - ?: ServiceShape.builder().version("test").id("test#Service").build(), - ShapeId.from("test#Protocol"), - settings, - rootDecorator ?: CombinedClientCodegenDecorator(emptyList()), -) +): ClientCodegenContext = + ClientCodegenContext( + model, + symbolProvider ?: testSymbolProvider(model), + TestModuleDocProvider, + serviceShape + ?: model.serviceShapes.firstOrNull() + ?: ServiceShape.builder().version("test").id("test#Service").build(), + ShapeId.from("test#Protocol"), + settings, + rootDecorator ?: CombinedClientCodegenDecorator(emptyList()), + ) -fun ClientCodegenContext.withEnableUserConfigurableRuntimePlugins(enableUserConfigurableRuntimePlugins: Boolean): ClientCodegenContext = +fun ClientCodegenContext.withEnableUserConfigurableRuntimePlugins( + enableUserConfigurableRuntimePlugins: Boolean, +): ClientCodegenContext = copy(settings = settings.copy(codegenConfig = settings.codegenConfig.copy(enableUserConfigurableRuntimePlugins = enableUserConfigurableRuntimePlugins))) fun TestWriterDelegator.clientRustSettings() = @@ -100,4 +107,5 @@ fun TestWriterDelegator.clientRustSettings() = codegenConfig = codegenConfig as ClientCodegenConfig, ) -fun TestWriterDelegator.clientCodegenContext(model: Model) = testClientCodegenContext(model, settings = clientRustSettings()) +fun TestWriterDelegator.clientCodegenContext(model: Model) = + testClientCodegenContext(model, settings = clientRustSettings()) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt index f9cf52375b6..16e89f9feaa 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitorTest.kt @@ -18,7 +18,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.generatePluginContext class ClientCodegenVisitorTest { @Test fun `baseline transform verify mixins removed`() { - val model = """ + val model = + """ namespace com.example use aws.protocols#awsJson1_0 @@ -43,7 +44,7 @@ class ClientCodegenVisitorTest { ] { greeting: String } - """.asSmithyModel(smithyVersion = "2.0") + """.asSmithyModel(smithyVersion = "2.0") val (ctx, _) = generatePluginContext(model) val codegenDecorator = CombinedClientCodegenDecorator.fromClasspath( diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProviderTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProviderTest.kt index 6b8c28826cc..27ebff53cab 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProviderTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProviderTest.kt @@ -26,33 +26,35 @@ class EventStreamSymbolProviderTest { @Test fun `it should adjust types for operations with event streams`() { // Transform the model so that it has synthetic inputs/outputs - val model = OperationNormalizer.transform( - """ - namespace test - - structure Something { stuff: Blob } - - @streaming - union SomeStream { - Something: Something, - } - - structure TestInput { inputStream: SomeStream } - structure TestOutput { outputStream: SomeStream } - operation TestOperation { - input: TestInput, - output: TestOutput, - } - service TestService { version: "123", operations: [TestOperation] } - """.asSmithyModel(), - ) + val model = + OperationNormalizer.transform( + """ + namespace test + + structure Something { stuff: Blob } + + @streaming + union SomeStream { + Something: Something, + } + + structure TestInput { inputStream: SomeStream } + structure TestOutput { outputStream: SomeStream } + operation TestOperation { + input: TestInput, + output: TestOutput, + } + service TestService { version: "123", operations: [TestOperation] } + """.asSmithyModel(), + ) val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape - val provider = EventStreamSymbolProvider( - TestRuntimeConfig, - SymbolVisitor(testClientRustSettings(), model, service, TestClientRustSymbolProviderConfig), - CodegenTarget.CLIENT, - ) + val provider = + EventStreamSymbolProvider( + TestRuntimeConfig, + SymbolVisitor(testClientRustSettings(), model, service, TestClientRustSymbolProviderConfig), + CodegenTarget.CLIENT, + ) // Look up the synthetic input/output rather than the original input/output val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape @@ -64,44 +66,48 @@ class EventStreamSymbolProviderTest { val someStream = RustType.Opaque("SomeStream", "crate::types") val someStreamError = RustType.Opaque("SomeStreamError", "crate::types::error") - inputType shouldBe RustType.Application( - RuntimeType.eventStreamSender(TestRuntimeConfig).toSymbol().rustType(), - listOf(someStream, someStreamError), - ) - outputType shouldBe RustType.Application( - RuntimeType.eventReceiver(TestRuntimeConfig).toSymbol().rustType(), - listOf(someStream, someStreamError), - ) + inputType shouldBe + RustType.Application( + RuntimeType.eventStreamSender(TestRuntimeConfig).toSymbol().rustType(), + listOf(someStream, someStreamError), + ) + outputType shouldBe + RustType.Application( + RuntimeType.eventReceiver(TestRuntimeConfig).toSymbol().rustType(), + listOf(someStream, someStreamError), + ) } @Test fun `it should leave alone types for operations without event streams`() { - val model = OperationNormalizer.transform( - """ - namespace test - - structure Something { stuff: Blob } - - union NotStreaming { - Something: Something, - } - - structure TestInput { inputStream: NotStreaming } - structure TestOutput { outputStream: NotStreaming } - operation TestOperation { - input: TestInput, - output: TestOutput, - } - service TestService { version: "123", operations: [TestOperation] } - """.asSmithyModel(), - ) + val model = + OperationNormalizer.transform( + """ + namespace test + + structure Something { stuff: Blob } + + union NotStreaming { + Something: Something, + } + + structure TestInput { inputStream: NotStreaming } + structure TestOutput { outputStream: NotStreaming } + operation TestOperation { + input: TestInput, + output: TestOutput, + } + service TestService { version: "123", operations: [TestOperation] } + """.asSmithyModel(), + ) val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape - val provider = EventStreamSymbolProvider( - TestRuntimeConfig, - SymbolVisitor(testClientRustSettings(), model, service, TestClientRustSymbolProviderConfig), - CodegenTarget.CLIENT, - ) + val provider = + EventStreamSymbolProvider( + TestRuntimeConfig, + SymbolVisitor(testClientRustSettings(), model, service, TestClientRustSymbolProviderConfig), + CodegenTarget.CLIENT, + ) // Look up the synthetic input/output rather than the original input/output val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/StreamingShapeSymbolProviderTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/StreamingShapeSymbolProviderTest.kt index 9846d9b0c00..61bf6562b1c 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/StreamingShapeSymbolProviderTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/StreamingShapeSymbolProviderTest.kt @@ -18,7 +18,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.util.lookup internal class StreamingShapeSymbolProviderTest { - val model = """ + val model = + """ namespace test operation GenerateSpeech { output: GenerateSpeechOutput, @@ -32,7 +33,7 @@ internal class StreamingShapeSymbolProviderTest { @streaming blob BlobStream - """.asSmithyModel() + """.asSmithyModel() @Test fun `generates a byte stream on streaming output`() { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpAuthDecoratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpAuthDecoratorTest.kt index e235144f2d7..e81a2d2ee83 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpAuthDecoratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpAuthDecoratorTest.kt @@ -17,15 +17,18 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.integrationTest class HttpAuthDecoratorTest { - private fun codegenScope(runtimeConfig: RuntimeConfig): Array> = arrayOf( - "ReplayEvent" to CargoDependency.smithyRuntime(runtimeConfig) - .toDevDependency().withFeature("test-util").toType() - .resolve("client::http::test_util::ReplayEvent"), - "StaticReplayClient" to CargoDependency.smithyRuntime(runtimeConfig) - .toDevDependency().withFeature("test-util").toType() - .resolve("client::http::test_util::StaticReplayClient"), - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), - ) + private fun codegenScope(runtimeConfig: RuntimeConfig): Array> = + arrayOf( + "ReplayEvent" to + CargoDependency.smithyRuntime(runtimeConfig) + .toDevDependency().withFeature("test-util").toType() + .resolve("client::http::test_util::ReplayEvent"), + "StaticReplayClient" to + CargoDependency.smithyRuntime(runtimeConfig) + .toDevDependency().withFeature("test-util").toType() + .resolve("client::http::test_util::StaticReplayClient"), + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + ) @Test fun multipleAuthSchemesSchemeSelection() { @@ -225,7 +228,6 @@ class HttpAuthDecoratorTest { fn compile() {} """, - ) Attribute.TokioTest.render(this) rustTemplate( @@ -255,8 +257,9 @@ class HttpAuthDecoratorTest { http_client.assert_requests_match(&[]); } """, - "capture_test_logs" to CargoDependency.smithyRuntimeTestUtil(ctx.runtimeConfig).toType() - .resolve("test_util::capture_test_logs::capture_test_logs"), + "capture_test_logs" to + CargoDependency.smithyRuntimeTestUtil(ctx.runtimeConfig).toType() + .resolve("test_util::capture_test_logs::capture_test_logs"), *codegenScope(ctx.runtimeConfig), ) } @@ -465,7 +468,8 @@ class HttpAuthDecoratorTest { } private object TestModels { - val allSchemes = """ + val allSchemes = + """ namespace test use aws.api#service @@ -492,9 +496,10 @@ private object TestModels { operation SomeOperation { output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() - val noSchemes = """ + val noSchemes = + """ namespace test use aws.api#service @@ -517,7 +522,8 @@ private object TestModels { output: SomeOutput }""".asSmithyModel() - val apiKeyInQueryString = """ + val apiKeyInQueryString = + """ namespace test use aws.api#service @@ -541,9 +547,10 @@ private object TestModels { operation SomeOperation { output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() - val apiKeyInHeaders = """ + val apiKeyInHeaders = + """ namespace test use aws.api#service @@ -567,9 +574,10 @@ private object TestModels { operation SomeOperation { output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() - val basicAuth = """ + val basicAuth = + """ namespace test use aws.api#service @@ -593,9 +601,10 @@ private object TestModels { operation SomeOperation { output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() - val bearerAuth = """ + val bearerAuth = + """ namespace test use aws.api#service @@ -619,9 +628,10 @@ private object TestModels { operation SomeOperation { output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() - val optionalAuth = """ + val optionalAuth = + """ namespace test use aws.api#service @@ -646,5 +656,5 @@ private object TestModels { operation SomeOperation { output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/MetadataCustomizationTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/MetadataCustomizationTest.kt index cac24f5b962..d96c25dca09 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/MetadataCustomizationTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/MetadataCustomizationTest.kt @@ -16,22 +16,23 @@ import software.amazon.smithy.rust.codegen.core.testutil.testModule import software.amazon.smithy.rust.codegen.core.testutil.tokioTest class MetadataCustomizationTest { - @Test fun `extract metadata via customizable operation`() { clientIntegrationTest(BasicTestModels.AwsJson10TestModel) { clientCodegenContext, rustCrate -> val runtimeConfig = clientCodegenContext.runtimeConfig - val codegenScope = arrayOf( - *preludeScope, - "BeforeTransmitInterceptorContextMut" to RuntimeType.beforeTransmitInterceptorContextMut(runtimeConfig), - "BoxError" to RuntimeType.boxError(runtimeConfig), - "ConfigBag" to RuntimeType.configBag(runtimeConfig), - "Intercept" to RuntimeType.intercept(runtimeConfig), - "Metadata" to RuntimeType.operationModule(runtimeConfig).resolve("Metadata"), - "capture_request" to RuntimeType.captureRequest(runtimeConfig), - "RuntimeComponents" to RuntimeType.smithyRuntimeApiClient(runtimeConfig) - .resolve("client::runtime_components::RuntimeComponents"), - ) + val codegenScope = + arrayOf( + *preludeScope, + "BeforeTransmitInterceptorContextMut" to RuntimeType.beforeTransmitInterceptorContextMut(runtimeConfig), + "BoxError" to RuntimeType.boxError(runtimeConfig), + "ConfigBag" to RuntimeType.configBag(runtimeConfig), + "Intercept" to RuntimeType.intercept(runtimeConfig), + "Metadata" to RuntimeType.operationModule(runtimeConfig).resolve("Metadata"), + "capture_request" to RuntimeType.captureRequest(runtimeConfig), + "RuntimeComponents" to + RuntimeType.smithyRuntimeApiClient(runtimeConfig) + .resolve("client::runtime_components::RuntimeComponents"), + ) rustCrate.testModule { addDependency(CargoDependency.Tokio.toDevDependency().withFeature("test-util")) tokioTest("test_extract_metadata_via_customizable_operation") { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomizationTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomizationTest.kt index 30d046ff635..ddc636afbfd 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomizationTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomizationTest.kt @@ -12,7 +12,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.BasicTestModels import software.amazon.smithy.rust.codegen.core.testutil.unitTest internal class ResiliencyConfigCustomizationTest { - @Test fun `generates a valid config`() { clientIntegrationTest(BasicTestModels.AwsJson10TestModel) { _, crate -> diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecoratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecoratorTest.kt index e97ad2921b9..e78e84e4846 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecoratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecoratorTest.kt @@ -16,14 +16,17 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.integrationTest class SensitiveOutputDecoratorTest { - private fun codegenScope(runtimeConfig: RuntimeConfig): Array> = arrayOf( - "capture_test_logs" to CargoDependency.smithyRuntimeTestUtil(runtimeConfig).toType() - .resolve("test_util::capture_test_logs::capture_test_logs"), - "capture_request" to RuntimeType.captureRequest(runtimeConfig), - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), - ) + private fun codegenScope(runtimeConfig: RuntimeConfig): Array> = + arrayOf( + "capture_test_logs" to + CargoDependency.smithyRuntimeTestUtil(runtimeConfig).toType() + .resolve("test_util::capture_test_logs::capture_test_logs"), + "capture_request" to RuntimeType.captureRequest(runtimeConfig), + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + ) - private val model = """ + private val model = + """ namespace com.example use aws.protocols#awsJson1_0 @awsJson1_0 @@ -43,7 +46,7 @@ class SensitiveOutputDecoratorTest { structure TestOutput { credentials: Credentials, } - """.asSmithyModel() + """.asSmithyModel() @Test fun `sensitive output in model should redact response body`() { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextConfigCustomizationTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextConfigCustomizationTest.kt index 15b6f0e5953..8a7c26ad4be 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextConfigCustomizationTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextConfigCustomizationTest.kt @@ -12,7 +12,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.unitTest class ClientContextConfigCustomizationTest { - val model = """ + val model = + """ namespace test use smithy.rules#clientContextParams use aws.protocols#awsJson1_0 @@ -27,7 +28,7 @@ class ClientContextConfigCustomizationTest { }) @awsJson1_0 service TestService { operations: [] } - """.asSmithyModel() + """.asSmithyModel() @Test fun `client params generate a valid customization`() { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointResolverGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointResolverGeneratorTest.kt index aaa8cfa284f..110baeba6c2 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointResolverGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointResolverGeneratorTest.kt @@ -14,19 +14,20 @@ import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest class EndpointResolverGeneratorTest { companion object { - val testCases = listOf( - "default-values.smithy", - "deprecated-param.smithy", - "duplicate-param.smithy", - "get-attr-type-inference.smithy", - "headers.smithy", - "minimal-ruleset.smithy", - "parse-url.smithy", - "substring.smithy", - "uri-encode.smithy", - "valid-hostlabel.smithy", - "valid-model.smithy", - ) + val testCases = + listOf( + "default-values.smithy", + "deprecated-param.smithy", + "duplicate-param.smithy", + "get-attr-type-inference.smithy", + "headers.smithy", + "minimal-ruleset.smithy", + "parse-url.smithy", + "substring.smithy", + "uri-encode.smithy", + "valid-hostlabel.smithy", + "valid-model.smithy", + ) @JvmStatic fun testSuites(): List { @@ -40,11 +41,12 @@ class EndpointResolverGeneratorTest { } // for tests, load partitions.json from smithy—for real usage, this file will be inserted at codegen time - /*private val partitionsJson = - Node.parse( - this::class.java.getResource("/software/amazon/smithy/rulesengine/language/partitions.json")?.readText() - ?: throw CodegenException("partitions.json was not present in smithy bundle"), - )*/ + // + // private val partitionsJson = + // Node.parse( + // this::class.java.getResource("/software/amazon/smithy/rulesengine/language/partitions.json")?.readText() + // ?: throw CodegenException("partitions.json was not present in smithy bundle"), + // ) @ParameterizedTest(name = "{0}") @MethodSource("testSuites") diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecoratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecoratorTest.kt index 789315e2189..0e5ab3ff780 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecoratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointsDecoratorTest.kt @@ -23,7 +23,8 @@ import software.amazon.smithy.rust.codegen.core.util.runCommand * End-to-end test of endpoint resolvers, attaching a real resolver to a fully generated service */ class EndpointsDecoratorTest { - val model = """ + val model = + """ namespace test use smithy.rules#endpointRuleSet @@ -122,111 +123,114 @@ class EndpointsDecoratorTest { structure NestedStructure { field: String } - """.asSmithyModel(disableValidation = true) + """.asSmithyModel(disableValidation = true) @Test fun `resolve endpoint`() { - val testDir = clientIntegrationTest( - model, - // Just run integration tests. - IntegrationTestParams(command = { "cargo test --all-features --test *".runCommand(it) }), - ) { clientCodegenContext, rustCrate -> - rustCrate.integrationTest("endpoint_params_test") { - val moduleName = clientCodegenContext.moduleUseName() - Attribute.TokioTest.render(this) - rustTemplate( - """ - async fn endpoint_params_are_set() { - use #{NeverClient}; - use #{TokioSleep}; - use aws_smithy_runtime_api::box_error::BoxError; - use aws_smithy_runtime_api::client::endpoint::EndpointResolverParams; - use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; - use aws_smithy_types::config_bag::ConfigBag; - use aws_smithy_types::endpoint::Endpoint; - use aws_smithy_types::timeout::TimeoutConfig; - use std::sync::atomic::AtomicBool; - use std::sync::atomic::Ordering; - use std::sync::Arc; - use std::time::Duration; - use $moduleName::{ - config::endpoint::Params, config::interceptors::BeforeTransmitInterceptorContextRef, - config::Intercept, config::SharedAsyncSleep, Client, Config, - }; - - ##[derive(Clone, Debug, Default)] - struct TestInterceptor { - called: Arc, - } - impl Intercept for TestInterceptor { - fn name(&self) -> &'static str { - "TestInterceptor" + val testDir = + clientIntegrationTest( + model, + // Just run integration tests. + IntegrationTestParams(command = { "cargo test --all-features --test *".runCommand(it) }), + ) { clientCodegenContext, rustCrate -> + rustCrate.integrationTest("endpoint_params_test") { + val moduleName = clientCodegenContext.moduleUseName() + Attribute.TokioTest.render(this) + rustTemplate( + """ + async fn endpoint_params_are_set() { + use #{NeverClient}; + use #{TokioSleep}; + use aws_smithy_runtime_api::box_error::BoxError; + use aws_smithy_runtime_api::client::endpoint::EndpointResolverParams; + use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; + use aws_smithy_types::config_bag::ConfigBag; + use aws_smithy_types::endpoint::Endpoint; + use aws_smithy_types::timeout::TimeoutConfig; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::Ordering; + use std::sync::Arc; + use std::time::Duration; + use $moduleName::{ + config::endpoint::Params, config::interceptors::BeforeTransmitInterceptorContextRef, + config::Intercept, config::SharedAsyncSleep, Client, Config, + }; + + ##[derive(Clone, Debug, Default)] + struct TestInterceptor { + called: Arc, } - - fn read_before_transmit( - &self, - _context: &BeforeTransmitInterceptorContextRef<'_>, - _runtime_components: &RuntimeComponents, - cfg: &mut ConfigBag, - ) -> Result<(), BoxError> { - let params = cfg - .load::() - .expect("params set in config"); - let params: &Params = params.get().expect("correct type"); - assert_eq!( - params, - &Params::builder() - .bucket("bucket-name".to_string()) - .built_in_with_default("some-default") - .bool_built_in_with_default(true) - .a_bool_param(false) - .a_string_param("hello".to_string()) - .region("us-east-2".to_string()) - .build() - .unwrap() - ); - - let endpoint = cfg.load::().expect("endpoint set in config"); - assert_eq!(endpoint.url(), "https://www.us-east-2.example.com"); - - self.called.store(true, Ordering::Relaxed); - Ok(()) + impl Intercept for TestInterceptor { + fn name(&self) -> &'static str { + "TestInterceptor" + } + + fn read_before_transmit( + &self, + _context: &BeforeTransmitInterceptorContextRef<'_>, + _runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + let params = cfg + .load::() + .expect("params set in config"); + let params: &Params = params.get().expect("correct type"); + assert_eq!( + params, + &Params::builder() + .bucket("bucket-name".to_string()) + .built_in_with_default("some-default") + .bool_built_in_with_default(true) + .a_bool_param(false) + .a_string_param("hello".to_string()) + .region("us-east-2".to_string()) + .build() + .unwrap() + ); + + let endpoint = cfg.load::().expect("endpoint set in config"); + assert_eq!(endpoint.url(), "https://www.us-east-2.example.com"); + + self.called.store(true, Ordering::Relaxed); + Ok(()) + } } - } - let interceptor = TestInterceptor::default(); - let config = Config::builder() - .http_client(NeverClient::new()) - .interceptor(interceptor.clone()) - .timeout_config( - TimeoutConfig::builder() - .operation_timeout(Duration::from_millis(30)) - .build(), - ) - .sleep_impl(SharedAsyncSleep::new(TokioSleep::new())) - .a_string_param("hello") - .a_bool_param(false) - .build(); - let client = Client::from_conf(config); - - let _ = dbg!(client.test_operation().bucket("bucket-name").send().await); - assert!( - interceptor.called.load(Ordering::Relaxed), - "the interceptor should have been called" - ); - - // bucket_name is unset and marked as required on the model, so we'll refuse to construct this request - let err = client.test_operation().send().await.expect_err("param missing"); - assert_eq!(format!("{}", err), "failed to construct request"); - } - """, - "NeverClient" to CargoDependency.smithyRuntimeTestUtil(clientCodegenContext.runtimeConfig) - .toType().resolve("client::http::test_util::NeverClient"), - "TokioSleep" to CargoDependency.smithyAsync(clientCodegenContext.runtimeConfig) - .withFeature("rt-tokio").toType().resolve("rt::sleep::TokioSleep"), - ) + let interceptor = TestInterceptor::default(); + let config = Config::builder() + .http_client(NeverClient::new()) + .interceptor(interceptor.clone()) + .timeout_config( + TimeoutConfig::builder() + .operation_timeout(Duration::from_millis(30)) + .build(), + ) + .sleep_impl(SharedAsyncSleep::new(TokioSleep::new())) + .a_string_param("hello") + .a_bool_param(false) + .build(); + let client = Client::from_conf(config); + + let _ = dbg!(client.test_operation().bucket("bucket-name").send().await); + assert!( + interceptor.called.load(Ordering::Relaxed), + "the interceptor should have been called" + ); + + // bucket_name is unset and marked as required on the model, so we'll refuse to construct this request + let err = client.test_operation().send().await.expect_err("param missing"); + assert_eq!(format!("{}", err), "failed to construct request"); + } + """, + "NeverClient" to + CargoDependency.smithyRuntimeTestUtil(clientCodegenContext.runtimeConfig) + .toType().resolve("client::http::test_util::NeverClient"), + "TokioSleep" to + CargoDependency.smithyAsync(clientCodegenContext.runtimeConfig) + .withFeature("rt-tokio").toType().resolve("rt::sleep::TokioSleep"), + ) + } } - } // the model has an intentionally failing test—ensure it fails val failure = shouldThrow { "cargo test".runWithWarnings(testDir) } failure.output shouldContain "endpoint::test::test_1" diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGeneratorTest.kt index 58e23cb28b0..89e75fc36df 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGeneratorTest.kt @@ -43,14 +43,16 @@ internal class ExprGeneratorTest { @Test fun generateLiterals1() { - val literal = Literal.recordLiteral( - mutableMapOf( - Identifier.of("this") to Literal.integerLiteral(5), - Identifier.of("that") to Literal.stringLiteral( - Template.fromString("static"), + val literal = + Literal.recordLiteral( + mutableMapOf( + Identifier.of("this") to Literal.integerLiteral(5), + Identifier.of("that") to + Literal.stringLiteral( + Template.fromString("static"), + ), ), - ), - ) + ) TestWorkspace.testProject().unitTest { val generator = ExpressionGenerator(Ownership.Borrowed, testContext) @@ -61,13 +63,14 @@ internal class ExprGeneratorTest { @Test fun generateLiterals2() { val project = TestWorkspace.testProject() - val gen = ExpressionGenerator( - Ownership.Borrowed, - Context( - FunctionRegistry(listOf()), - TestRuntimeConfig, - ), - ) + val gen = + ExpressionGenerator( + Ownership.Borrowed, + Context( + FunctionRegistry(listOf()), + TestRuntimeConfig, + ), + ) project.unitTest { rust("""let extra = "helloworld";""") rust("assert_eq!(true, #W);", gen.generate(Expression.of(true))) @@ -83,12 +86,13 @@ internal class ExprGeneratorTest { assert_eq!(expected, #{actual:W}); """, "Document" to RuntimeType.document(TestRuntimeConfig), - "actual" to gen.generate( - Literal.fromNode( - Node.objectNode().withMember("a", true).withMember("b", "hello") - .withMember("c", ArrayNode.arrayNode(BooleanNode.from(true))), + "actual" to + gen.generate( + Literal.fromNode( + Node.objectNode().withMember("a", true).withMember("b", "hello") + .withMember("c", ArrayNode.arrayNode(BooleanNode.from(true))), + ), ), - ), ) } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/TemplateGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/TemplateGeneratorTest.kt index 95188a9b5dd..282643d6813 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/TemplateGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/TemplateGeneratorTest.kt @@ -19,7 +19,10 @@ internal class TemplateGeneratorTest { /** * helper to assert that a template string is templated to the expected result */ - private fun assertTemplateEquals(template: String, result: String) { + private fun assertTemplateEquals( + template: String, + result: String, + ) { val literalTemplate = Template.fromString(template) // For testing, val exprFn = { expr: Expression, ownership: Ownership -> diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGeneratorTest.kt index 9c0817e5c55..27b59311bbe 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGeneratorTest.kt @@ -19,7 +19,11 @@ import software.amazon.smithy.rust.codegen.core.util.lookup class ClientEnumGeneratorTest { @Test fun `matching on enum should be forward-compatible`() { - fun expectMatchExpressionCompiles(model: Model, shapeId: String, enumToMatchOn: String) { + fun expectMatchExpressionCompiles( + model: Model, + shapeId: String, + enumToMatchOn: String, + ) { val shape = model.lookup(shapeId) val context = testClientCodegenContext(model) val project = TestWorkspace.testProject(context.symbolProvider) @@ -40,7 +44,8 @@ class ClientEnumGeneratorTest { project.compileAndTest() } - val modelV1 = """ + val modelV1 = + """ namespace test @enum([ @@ -48,11 +53,12 @@ class ClientEnumGeneratorTest { { name: "Variant2", value: "Variant2" }, ]) string SomeEnum - """.asSmithyModel() + """.asSmithyModel() val variant3AsUnknown = """SomeEnum::from("Variant3")""" expectMatchExpressionCompiles(modelV1, "test#SomeEnum", variant3AsUnknown) - val modelV2 = """ + val modelV2 = + """ namespace test @enum([ @@ -61,21 +67,22 @@ class ClientEnumGeneratorTest { { name: "Variant3", value: "Variant3" }, ]) string SomeEnum - """.asSmithyModel() + """.asSmithyModel() val variant3AsVariant3 = "SomeEnum::Variant3" expectMatchExpressionCompiles(modelV2, "test#SomeEnum", variant3AsVariant3) } @Test fun `impl debug for non-sensitive enum should implement the derived debug trait`() { - val model = """ + val model = + """ namespace test @enum([ { name: "Foo", value: "Foo" }, { name: "Bar", value: "Bar" }, ]) string SomeEnum - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#SomeEnum") val context = testClientCodegenContext(model) @@ -99,7 +106,8 @@ class ClientEnumGeneratorTest { @Test fun `it escapes the Unknown variant if the enum has an unknown value in the model`() { - val model = """ + val model = + """ namespace test @enum([ { name: "Known", value: "Known" }, @@ -107,7 +115,7 @@ class ClientEnumGeneratorTest { { name: "UnknownValue", value: "UnknownValue" }, ]) string SomeEnum - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#SomeEnum") val context = testClientCodegenContext(model) @@ -131,14 +139,15 @@ class ClientEnumGeneratorTest { @Test fun `generated named enums can roundtrip between string and enum value on the unknown variant`() { - val model = """ + val model = + """ namespace test @enum([ { value: "t2.nano", name: "T2_NANO" }, { value: "t2.micro", name: "T2_MICRO" }, ]) string InstanceType - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#InstanceType") val context = testClientCodegenContext(model) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt index 49905fdeb9c..00fa942961d 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt @@ -19,7 +19,8 @@ import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.lookup internal class ClientInstantiatorTest { - private val model = """ + private val model = + """ namespace com.test @enum([ @@ -39,7 +40,7 @@ internal class ClientInstantiatorTest { }, ]) string NamedEnum - """.asSmithyModel() + """.asSmithyModel() private val codegenContext = testClientCodegenContext(model) private val symbolProvider = codegenContext.symbolProvider diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ConfigOverrideRuntimePluginGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ConfigOverrideRuntimePluginGeneratorTest.kt index b60f99a67de..fb1130784d3 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ConfigOverrideRuntimePluginGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ConfigOverrideRuntimePluginGeneratorTest.kt @@ -18,7 +18,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.tokioTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest internal class ConfigOverrideRuntimePluginGeneratorTest { - private val model = """ + private val model = + """ namespace com.example use aws.protocols#awsJson1_0 @@ -33,20 +34,22 @@ internal class ConfigOverrideRuntimePluginGeneratorTest { structure TestInput { foo: String, } - """.asSmithyModel() + """.asSmithyModel() @Test fun `operation overrides endpoint resolver`() { clientIntegrationTest(model) { clientCodegenContext, rustCrate -> val runtimeConfig = clientCodegenContext.runtimeConfig - val codegenScope = arrayOf( - *preludeScope, - "EndpointResolverParams" to RuntimeType.smithyRuntimeApi(runtimeConfig) - .resolve("client::endpoint::EndpointResolverParams"), - "RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig), - "RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(runtimeConfig), - "capture_request" to RuntimeType.captureRequest(runtimeConfig), - ) + val codegenScope = + arrayOf( + *preludeScope, + "EndpointResolverParams" to + RuntimeType.smithyRuntimeApi(runtimeConfig) + .resolve("client::endpoint::EndpointResolverParams"), + "RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig), + "RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(runtimeConfig), + "capture_request" to RuntimeType.captureRequest(runtimeConfig), + ) rustCrate.testModule { addDependency(CargoDependency.Tokio.toDevDependency().withFeature("test-util")) tokioTest("test_operation_overrides_endpoint_resolver") { @@ -72,10 +75,11 @@ internal class ConfigOverrideRuntimePluginGeneratorTest { fun `operation overrides http connector`() { clientIntegrationTest(model) { clientCodegenContext, rustCrate -> val runtimeConfig = clientCodegenContext.runtimeConfig - val codegenScope = arrayOf( - *preludeScope, - "RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig), - ) + val codegenScope = + arrayOf( + *preludeScope, + "RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig), + ) rustCrate.testModule { addDependency(CargoDependency.Tokio.toDevDependency().withFeature("test-util")) tokioTest("test_operation_overrides_http_client") { @@ -130,11 +134,13 @@ internal class ConfigOverrideRuntimePluginGeneratorTest { *codegenScope, "AsyncSleep" to RuntimeType.smithyAsync(runtimeConfig).resolve("rt::sleep::AsyncSleep"), "capture_request" to RuntimeType.captureRequest(runtimeConfig), - "NeverClient" to CargoDependency.smithyRuntimeTestUtil(runtimeConfig).toType() - .resolve("client::http::test_util::NeverClient"), + "NeverClient" to + CargoDependency.smithyRuntimeTestUtil(runtimeConfig).toType() + .resolve("client::http::test_util::NeverClient"), "Timeout" to RuntimeType.smithyAsync(runtimeConfig).resolve("future::timeout::Timeout"), - "TokioSleep" to CargoDependency.smithyAsync(runtimeConfig).withFeature("rt-tokio") - .toType().resolve("rt::sleep::TokioSleep"), + "TokioSleep" to + CargoDependency.smithyAsync(runtimeConfig).withFeature("rt-tokio") + .toType().resolve("rt::sleep::TokioSleep"), ) } } @@ -145,30 +151,38 @@ internal class ConfigOverrideRuntimePluginGeneratorTest { fun `operation overrides retry config`() { clientIntegrationTest(model) { clientCodegenContext, rustCrate -> val runtimeConfig = clientCodegenContext.runtimeConfig - val codegenScope = arrayOf( - *preludeScope, - "AlwaysRetry" to RuntimeType.smithyRuntimeApi(runtimeConfig) - .resolve("client::retries::AlwaysRetry"), - "ConfigBag" to RuntimeType.smithyTypes(runtimeConfig).resolve("config_bag::ConfigBag"), - "ErrorKind" to RuntimeType.smithyTypes(runtimeConfig).resolve("retry::ErrorKind"), - "Input" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::interceptors::context::Input"), - "InterceptorContext" to RuntimeType.interceptorContext(runtimeConfig), - "Layer" to RuntimeType.smithyTypes(runtimeConfig).resolve("config_bag::Layer"), - "OrchestratorError" to RuntimeType.smithyRuntimeApi(runtimeConfig) - .resolve("client::orchestrator::OrchestratorError"), - "RetryConfig" to RuntimeType.smithyTypes(clientCodegenContext.runtimeConfig) - .resolve("retry::RetryConfig"), - "RequestAttempts" to smithyRuntimeApiTestUtil(runtimeConfig).toType() - .resolve("client::retries::RequestAttempts"), - "RetryClassifiers" to RuntimeType.smithyRuntimeApi(runtimeConfig) - .resolve("client::retries::RetryClassifiers"), - "RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(runtimeConfig), - "RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig), - "StandardRetryStrategy" to RuntimeType.smithyRuntime(runtimeConfig) - .resolve("client::retries::strategy::StandardRetryStrategy"), - "ShouldAttempt" to RuntimeType.smithyRuntimeApi(runtimeConfig) - .resolve("client::retries::ShouldAttempt"), - ) + val codegenScope = + arrayOf( + *preludeScope, + "AlwaysRetry" to + RuntimeType.smithyRuntimeApi(runtimeConfig) + .resolve("client::retries::AlwaysRetry"), + "ConfigBag" to RuntimeType.smithyTypes(runtimeConfig).resolve("config_bag::ConfigBag"), + "ErrorKind" to RuntimeType.smithyTypes(runtimeConfig).resolve("retry::ErrorKind"), + "Input" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::interceptors::context::Input"), + "InterceptorContext" to RuntimeType.interceptorContext(runtimeConfig), + "Layer" to RuntimeType.smithyTypes(runtimeConfig).resolve("config_bag::Layer"), + "OrchestratorError" to + RuntimeType.smithyRuntimeApi(runtimeConfig) + .resolve("client::orchestrator::OrchestratorError"), + "RetryConfig" to + RuntimeType.smithyTypes(clientCodegenContext.runtimeConfig) + .resolve("retry::RetryConfig"), + "RequestAttempts" to + smithyRuntimeApiTestUtil(runtimeConfig).toType() + .resolve("client::retries::RequestAttempts"), + "RetryClassifiers" to + RuntimeType.smithyRuntimeApi(runtimeConfig) + .resolve("client::retries::RetryClassifiers"), + "RuntimeComponentsBuilder" to RuntimeType.runtimeComponentsBuilder(runtimeConfig), + "RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig), + "StandardRetryStrategy" to + RuntimeType.smithyRuntime(runtimeConfig) + .resolve("client::retries::strategy::StandardRetryStrategy"), + "ShouldAttempt" to + RuntimeType.smithyRuntimeApi(runtimeConfig) + .resolve("client::retries::ShouldAttempt"), + ) rustCrate.testModule { unitTest("test_operation_overrides_retry_config") { rustTemplate( diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingsTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingsTest.kt index c645759040f..8a10b735a3e 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingsTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingsTest.kt @@ -37,7 +37,8 @@ internal class EndpointTraitBindingsTest { @Test fun `generate endpoint prefixes`() { - val model = """ + val model = + """ namespace test @readonly @endpoint(hostPrefix: "{foo}a.data.") @@ -49,16 +50,17 @@ internal class EndpointTraitBindingsTest { @hostLabel foo: String } - """.asSmithyModel() + """.asSmithyModel() val operationShape: OperationShape = model.lookup("test#GetStatus") val symbolProvider = testSymbolProvider(model) - val endpointBindingGenerator = EndpointTraitBindings( - model, - symbolProvider, - TestRuntimeConfig, - operationShape, - operationShape.expectTrait(EndpointTrait::class.java), - ) + val endpointBindingGenerator = + EndpointTraitBindings( + model, + symbolProvider, + TestRuntimeConfig, + operationShape, + operationShape.expectTrait(EndpointTrait::class.java), + ) val project = TestWorkspace.testProject() project.withModule(RustModule.private("test")) { rust( @@ -118,7 +120,8 @@ internal class EndpointTraitBindingsTest { @ExperimentalPathApi @Test fun `endpoint integration test`() { - val model = """ + val model = + """ namespace com.example use aws.protocols#awsJson1_0 use smithy.rules#endpointRuleSet @@ -150,7 +153,7 @@ internal class EndpointTraitBindingsTest { @hostLabel greeting: String } - """.asSmithyModel() + """.asSmithyModel() clientIntegrationTest(model) { clientCodegenContext, rustCrate -> val moduleName = clientCodegenContext.moduleUseName() rustCrate.integrationTest("test_endpoint_prefix") { @@ -244,8 +247,9 @@ internal class EndpointTraitBindingsTest { ); } """, - "capture_request" to CargoDependency.smithyRuntimeTestUtil(clientCodegenContext.runtimeConfig) - .toType().resolve("client::http::test_util::capture_request"), + "capture_request" to + CargoDependency.smithyRuntimeTestUtil(clientCodegenContext.runtimeConfig) + .toType().resolve("client::http::test_util::capture_request"), ) } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrectionTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrectionTest.kt index 69691b476e5..5f0baf7816f 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrectionTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ErrorCorrectionTest.kt @@ -13,7 +13,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup class ErrorCorrectionTest { - private val model = """ + private val model = + """ namespace com.example use aws.protocols#awsJson1_0 @@ -78,7 +79,7 @@ class ErrorCorrectionTest { key: String, value: StringList } - """.asSmithyModel(smithyVersion = "2.0") + """.asSmithyModel(smithyVersion = "2.0") @Test fun correctMissingFields() { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGeneratorTest.kt index 4b3d7f3a5f7..8c8a4784e68 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGeneratorTest.kt @@ -13,7 +13,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.integrationTest internal class PaginatorGeneratorTest { - private val model = """ + private val model = + """ namespace test use aws.protocols#awsJson1_1 @@ -67,7 +68,7 @@ internal class PaginatorGeneratorTest { key: String, value: Integer } - """.asSmithyModel() + """.asSmithyModel() @Test fun `generate paginators that compile`() { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/SensitiveIndexTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/SensitiveIndexTest.kt index cecf4235885..7cc3f01016e 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/SensitiveIndexTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/SensitiveIndexTest.kt @@ -11,7 +11,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.util.lookup class SensitiveIndexTest { - val model = """ + val model = + """ namespace com.example service TestService { operations: [ @@ -58,7 +59,7 @@ class SensitiveIndexTest { structure Inner { credentials: Credentials } - """.asSmithyModel(smithyVersion = "2.0") + """.asSmithyModel(smithyVersion = "2.0") @Test fun `correctly identify operations`() { @@ -66,13 +67,14 @@ class SensitiveIndexTest { data class TestCase(val shape: String, val sensitiveInput: Boolean, val sensitiveOutput: Boolean) - val testCases = listOf( - TestCase("NotSensitive", sensitiveInput = false, sensitiveOutput = false), - TestCase("SensitiveInput", sensitiveInput = true, sensitiveOutput = false), - TestCase("SensitiveOutput", sensitiveInput = false, sensitiveOutput = true), - TestCase("NestedSensitiveInput", sensitiveInput = true, sensitiveOutput = false), - TestCase("NestedSensitiveOutput", sensitiveInput = false, sensitiveOutput = true), - ) + val testCases = + listOf( + TestCase("NotSensitive", sensitiveInput = false, sensitiveOutput = false), + TestCase("SensitiveInput", sensitiveInput = true, sensitiveOutput = false), + TestCase("SensitiveOutput", sensitiveInput = false, sensitiveOutput = true), + TestCase("NestedSensitiveInput", sensitiveInput = true, sensitiveOutput = false), + TestCase("NestedSensitiveOutput", sensitiveInput = false, sensitiveOutput = true), + ) testCases.forEach { tc -> assertEquals(tc.sensitiveInput, index.hasSensitiveInput(model.lookup("com.example#${tc.shape}")), "input: $tc") assertEquals(tc.sensitiveOutput, index.hasSensitiveOutput(model.lookup("com.example#${tc.shape}")), "output: $tc ") diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGeneratorTest.kt index c14d6b21fbe..30090104f9e 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGeneratorTest.kt @@ -15,7 +15,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.integrationTest class CustomizableOperationGeneratorTest { - val model = """ + val model = + """ namespace com.example use aws.protocols#awsJson1_0 @@ -30,7 +31,7 @@ class CustomizableOperationGeneratorTest { structure TestInput { foo: String, } - """.asSmithyModel() + """.asSmithyModel() @Test fun `CustomizableOperation is send and sync`() { @@ -51,8 +52,9 @@ class CustomizableOperationGeneratorTest { check_send_and_sync(client.say_hello().customize()); } """, - "NeverClient" to CargoDependency.smithyRuntimeTestUtil(codegenContext.runtimeConfig).toType() - .resolve("client::http::test_util::NeverClient"), + "NeverClient" to + CargoDependency.smithyRuntimeTestUtil(codegenContext.runtimeConfig).toType() + .resolve("client::http::test_util::NeverClient"), ) } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGeneratorTest.kt index f36c0dbfdcc..b300dfa8c2f 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGeneratorTest.kt @@ -17,7 +17,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.integrationTest import software.amazon.smithy.rust.codegen.core.util.lookup class FluentClientGeneratorTest { - val model = """ + val model = + """ namespace com.example use aws.protocols#awsJson1_0 @@ -49,16 +50,17 @@ class FluentClientGeneratorTest { key: String, value: StringList } - """.asSmithyModel() + """.asSmithyModel() @Test fun `generate correct input docs`() { - val expectations = mapOf( - "listValue" to "list_value(impl Into)", - "doubleListValue" to "double_list_value(Vec::)", - "mapValue" to "map_value(impl Into, Vec::)", - "byteValue" to "byte_value(i8)", - ) + val expectations = + mapOf( + "listValue" to "list_value(impl Into)", + "doubleListValue" to "double_list_value(Vec::)", + "mapValue" to "map_value(impl Into, Vec::)", + "byteValue" to "byte_value(i8)", + ) expectations.forEach { (name, expect) -> val member = model.lookup("com.example#TestInput\$$name") member.asFluentBuilderInputDoc(testSymbolProvider(model)) shouldBe expect @@ -84,8 +86,9 @@ class FluentClientGeneratorTest { check_send(client.say_hello().send()); } """, - "NeverClient" to CargoDependency.smithyRuntimeTestUtil(codegenContext.runtimeConfig).toType() - .resolve("client::http::test_util::NeverClient"), + "NeverClient" to + CargoDependency.smithyRuntimeTestUtil(codegenContext.runtimeConfig).toType() + .resolve("client::http::test_util::NeverClient"), ) } } @@ -112,8 +115,9 @@ class FluentClientGeneratorTest { assert_eq!(*input.get_byte_value(), Some(4)); } """, - "NeverClient" to CargoDependency.smithyRuntimeTestUtil(codegenContext.runtimeConfig).toType() - .resolve("client::http::test_util::NeverClient"), + "NeverClient" to + CargoDependency.smithyRuntimeTestUtil(codegenContext.runtimeConfig).toType() + .resolve("client::http::test_util::NeverClient"), ) } } @@ -121,7 +125,8 @@ class FluentClientGeneratorTest { @Test fun `dead-code warning should not be issued when a service has no operations`() { - val model = """ + val model = + """ namespace com.example use aws.protocols#awsJson1_0 @@ -129,7 +134,7 @@ class FluentClientGeneratorTest { service HelloService { version: "1" } - """.asSmithyModel() + """.asSmithyModel() clientIntegrationTest(model) } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/ServiceConfigGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/ServiceConfigGeneratorTest.kt index 7cdd0f7b890..4e9e88f414e 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/ServiceConfigGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/ServiceConfigGeneratorTest.kt @@ -26,7 +26,8 @@ import software.amazon.smithy.rust.codegen.core.util.toPascalCase internal class ServiceConfigGeneratorTest { @Test fun `idempotency token when used`() { - fun model(trait: String) = """ + fun model(trait: String) = + """ namespace com.example use aws.protocols#restJson1 @@ -47,7 +48,7 @@ internal class ServiceConfigGeneratorTest { $trait tok: String } - """.asSmithyModel() + """.asSmithyModel() val withToken = model("@idempotencyToken") val withoutToken = model("") @@ -57,7 +58,8 @@ internal class ServiceConfigGeneratorTest { @Test fun `find idempotency token via resources`() { - val model = """ + val model = + """ namespace com.example service ResourceService { resources: [Resource], @@ -75,7 +77,7 @@ internal class ServiceConfigGeneratorTest { @idempotencyToken tok: String } - """.asSmithyModel() + """.asSmithyModel() model.lookup("com.example#ResourceService").needsIdempotencyToken(model) shouldBe true } @@ -83,57 +85,62 @@ internal class ServiceConfigGeneratorTest { fun `generate customizations as specified`() { class ServiceCustomizer(private val codegenContext: ClientCodegenContext) : NamedCustomization() { - override fun section(section: ServiceConfig): Writable { return when (section) { ServiceConfig.ConfigStructAdditionalDocs -> emptySection - ServiceConfig.ConfigImpl -> writable { - rustTemplate( - """ - ##[allow(missing_docs)] - pub fn config_field(&self) -> u64 { - self.config.load::<#{T}>().map(|u| u.0).unwrap() - } - """, - "T" to configParamNewtype( - "config_field".toPascalCase(), RuntimeType.U64.toSymbol(), - codegenContext.runtimeConfig, - ), - ) - } + ServiceConfig.ConfigImpl -> + writable { + rustTemplate( + """ + ##[allow(missing_docs)] + pub fn config_field(&self) -> u64 { + self.config.load::<#{T}>().map(|u| u.0).unwrap() + } + """, + "T" to + configParamNewtype( + "config_field".toPascalCase(), RuntimeType.U64.toSymbol(), + codegenContext.runtimeConfig, + ), + ) + } - ServiceConfig.BuilderImpl -> writable { - rustTemplate( - """ - ##[allow(missing_docs)] - pub fn config_field(mut self, config_field: u64) -> Self { - self.config.store_put(#{T}(config_field)); - self - } - """, - "T" to configParamNewtype( - "config_field".toPascalCase(), RuntimeType.U64.toSymbol(), - codegenContext.runtimeConfig, - ), - ) - } + ServiceConfig.BuilderImpl -> + writable { + rustTemplate( + """ + ##[allow(missing_docs)] + pub fn config_field(mut self, config_field: u64) -> Self { + self.config.store_put(#{T}(config_field)); + self + } + """, + "T" to + configParamNewtype( + "config_field".toPascalCase(), RuntimeType.U64.toSymbol(), + codegenContext.runtimeConfig, + ), + ) + } else -> emptySection } } } - val serviceDecorator = object : ClientCodegenDecorator { - override val name: String = "Add service plugin" - override val order: Byte = 0 - override fun configCustomizations( - codegenContext: ClientCodegenContext, - baseCustomizations: List, - ): List { - return baseCustomizations + ServiceCustomizer(codegenContext) + val serviceDecorator = + object : ClientCodegenDecorator { + override val name: String = "Add service plugin" + override val order: Byte = 0 + + override fun configCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List { + return baseCustomizations + ServiceCustomizer(codegenContext) + } } - } clientIntegrationTest(BasicTestModels.AwsJson10TestModel, additionalDecorators = listOf(serviceDecorator)) { ctx, rustCrate -> rustCrate.withModule(ClientRustModule.config) { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGeneratorTest.kt index 1f1b7032411..656eb80acfa 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGeneratorTest.kt @@ -13,7 +13,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup class OperationErrorGeneratorTest { - private val model = """ + private val model = + """ namespace error @aws.protocols#awsJson1_0 @@ -43,7 +44,7 @@ class OperationErrorGeneratorTest { @error("server") @deprecated structure Deprecated { } - """.asSmithyModel() + """.asSmithyModel() @Test fun `generates combined error enums`() { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGeneratorTest.kt index a8ce5b832a0..a88079eee68 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGeneratorTest.kt @@ -15,7 +15,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup internal class ServiceErrorGeneratorTest { - private val model = """ + private val model = + """ namespace com.example use aws.protocols#restJson1 @@ -44,7 +45,7 @@ internal class ServiceErrorGeneratorTest { @error("client") @deprecated structure MeDeprecated { } - """.asSmithyModel() + """.asSmithyModel() @Test fun `top level errors are send + sync`() { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGeneratorTest.kt index a769504f134..c05b65b0cba 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGeneratorTest.kt @@ -30,7 +30,8 @@ import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait class RequestBindingGeneratorTest { - private val baseModel = """ + private val baseModel = + """ namespace smithy.example @idempotent @@ -120,7 +121,7 @@ class RequestBindingGeneratorTest { @sensitive string SensitiveStringHeader - """.asSmithyModel() + """.asSmithyModel() private val model = OperationNormalizer.transform(baseModel) private val symbolProvider = testSymbolProvider(model) private val operationShape = model.expectShape(ShapeId.from("smithy.example#PutObject"), OperationShape::class.java) @@ -131,12 +132,13 @@ class RequestBindingGeneratorTest { inputShape.renderWithModelBuilder(model, symbolProvider, rustCrate) rustCrate.withModule(operationModule) { val codegenContext = testClientCodegenContext(model) - val bindingGen = RequestBindingGenerator( - codegenContext, - // Any protocol is fine for this test. - RestJson(codegenContext), - operationShape, - ) + val bindingGen = + RequestBindingGenerator( + codegenContext, + // Any protocol is fine for this test. + RestJson(codegenContext), + operationShape, + ) rustBlock("impl PutObjectInput") { // RequestBindingGenerator's functions expect to be rendered inside a function, // but the unit test needs to call some of these functions individually. This generates diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/ResponseBindingGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/ResponseBindingGeneratorTest.kt index b02e6879b6d..e0a9f063470 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/ResponseBindingGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/ResponseBindingGeneratorTest.kt @@ -27,7 +27,8 @@ import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.core.util.outputShape class ResponseBindingGeneratorTest { - private val baseModel = """ + private val baseModel = + """ namespace smithy.example @idempotent @@ -64,7 +65,7 @@ class ResponseBindingGeneratorTest { // Sent in the body additional: String, } - """.asSmithyModel() + """.asSmithyModel() private val model = OperationNormalizer.transform(baseModel) private val operationShape: OperationShape = model.lookup("smithy.example#PutObject") private val outputShape: StructureShape = operationShape.outputShape(model) @@ -75,15 +76,17 @@ class ResponseBindingGeneratorTest { operationShape.outputShape(model).renderWithModelBuilder(model, symbolProvider, this) withModule(symbolProvider.moduleForShape(outputShape)) { rustBlock("impl PutObjectOutput") { - val bindings = HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("dont-care")) - .responseBindings(operationShape) - .filter { it.location == HttpLocation.HEADER } + val bindings = + HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("dont-care")) + .responseBindings(operationShape) + .filter { it.location == HttpLocation.HEADER } bindings.forEach { binding -> - val runtimeType = ResponseBindingGenerator( - RestJson(codegenContext), - codegenContext, - operationShape, - ).generateDeserializeHeaderFn(binding) + val runtimeType = + ResponseBindingGenerator( + RestJson(codegenContext), + codegenContext, + operationShape, + ).generateDeserializeHeaderFn(binding) // little hack to force these functions to be generated rust("// use #T;", runtimeType) } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGeneratorTest.kt index 760eced3874..86b2cacfc06 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGeneratorTest.kt @@ -10,19 +10,20 @@ import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel class ProtocolParserGeneratorTest { - private val model = """ + private val model = + """ ${'$'}version: "2.0" namespace test - + use aws.protocols#restJson1 - + @restJson1 service TestService { version: "2019-12-16", operations: [SomeOperation] errors: [SomeTopLevelError] } - + @http(uri: "/SomeOperation", method: "POST") operation SomeOperation { input: SomeOperationInputOutput, @@ -35,7 +36,7 @@ class ProtocolParserGeneratorTest { a: String, b: Integer } - + @error("server") structure SomeTopLevelError { @required @@ -48,7 +49,7 @@ class ProtocolParserGeneratorTest { context: String } - + @error("client") structure SomeOperationError { @required @@ -61,8 +62,8 @@ class ProtocolParserGeneratorTest { context: String } - """ - .asSmithyModel() + """ + .asSmithyModel() @Test fun `generate an complex error structure that compiles`() { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt index 401e7b8007f..83a6f52e9ae 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt @@ -31,97 +31,100 @@ private class TestServiceRuntimePluginCustomization( private val fakeRequestBuilder: String, private val fakeRequestBody: String, ) : ServiceRuntimePluginCustomization() { - override fun section(section: ServiceRuntimePluginSection): Writable = writable { - if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) { - val rc = context.runtimeConfig - section.registerInterceptor(this) { - rustTemplate( - """ - { - ##[derive(::std::fmt::Debug)] - struct TestInterceptor; - impl #{Intercept} for TestInterceptor { - fn name(&self) -> &'static str { - "TestInterceptor" - } + override fun section(section: ServiceRuntimePluginSection): Writable = + writable { + if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) { + val rc = context.runtimeConfig + section.registerInterceptor(this) { + rustTemplate( + """ + { + ##[derive(::std::fmt::Debug)] + struct TestInterceptor; + impl #{Intercept} for TestInterceptor { + fn name(&self) -> &'static str { + "TestInterceptor" + } - fn modify_before_retry_loop( - &self, - context: &mut #{BeforeTransmitInterceptorContextMut}<'_>, - _rc: &#{RuntimeComponents}, - _cfg: &mut #{ConfigBag}, - ) -> #{Result}<(), #{BoxError}> { - // Replace the serialized request - let mut fake_req = ::http::Request::builder() - $fakeRequestBuilder - .body(#{SdkBody}::from($fakeRequestBody)) - .expect("valid request").try_into().unwrap(); - ::std::mem::swap( - context.request_mut(), - &mut fake_req, - ); - Ok(()) + fn modify_before_retry_loop( + &self, + context: &mut #{BeforeTransmitInterceptorContextMut}<'_>, + _rc: &#{RuntimeComponents}, + _cfg: &mut #{ConfigBag}, + ) -> #{Result}<(), #{BoxError}> { + // Replace the serialized request + let mut fake_req = ::http::Request::builder() + $fakeRequestBuilder + .body(#{SdkBody}::from($fakeRequestBody)) + .expect("valid request").try_into().unwrap(); + ::std::mem::swap( + context.request_mut(), + &mut fake_req, + ); + Ok(()) + } } - } - TestInterceptor - } - """, - *preludeScope, - "BeforeTransmitInterceptorContextMut" to RT.beforeTransmitInterceptorContextMut(rc), - "BoxError" to RT.boxError(rc), - "ConfigBag" to RT.configBag(rc), - "Intercept" to RT.intercept(rc), - "RuntimeComponents" to RT.runtimeComponents(rc), - "SdkBody" to RT.sdkBody(rc), - ) + TestInterceptor + } + """, + *preludeScope, + "BeforeTransmitInterceptorContextMut" to RT.beforeTransmitInterceptorContextMut(rc), + "BoxError" to RT.boxError(rc), + "ConfigBag" to RT.configBag(rc), + "Intercept" to RT.intercept(rc), + "RuntimeComponents" to RT.runtimeComponents(rc), + "SdkBody" to RT.sdkBody(rc), + ) + } } } - } } private class TestOperationCustomization( private val context: ClientCodegenContext, private val fakeOutput: String, ) : OperationCustomization() { - override fun section(section: OperationSection): Writable = writable { - val rc = context.runtimeConfig - if (section is OperationSection.AdditionalRuntimePluginConfig) { - rustTemplate( - """ - // Override the default response deserializer with our fake output - ##[derive(::std::fmt::Debug)] - struct TestDeser; - impl #{DeserializeResponse} for TestDeser { - fn deserialize_nonstreaming( - &self, - _response: &#{HttpResponse}, - ) -> #{Result}<#{Output}, #{OrchestratorError}<#{Error}>> { - let fake_out: #{Result}< - crate::operation::say_hello::SayHelloOutput, - crate::operation::say_hello::SayHelloError, - > = $fakeOutput; - fake_out - .map(|o| #{Output}::erase(o)) - .map_err(|e| #{OrchestratorError}::operation(#{Error}::erase(e))) + override fun section(section: OperationSection): Writable = + writable { + val rc = context.runtimeConfig + if (section is OperationSection.AdditionalRuntimePluginConfig) { + rustTemplate( + """ + // Override the default response deserializer with our fake output + ##[derive(::std::fmt::Debug)] + struct TestDeser; + impl #{DeserializeResponse} for TestDeser { + fn deserialize_nonstreaming( + &self, + _response: &#{HttpResponse}, + ) -> #{Result}<#{Output}, #{OrchestratorError}<#{Error}>> { + let fake_out: #{Result}< + crate::operation::say_hello::SayHelloOutput, + crate::operation::say_hello::SayHelloError, + > = $fakeOutput; + fake_out + .map(|o| #{Output}::erase(o)) + .map_err(|e| #{OrchestratorError}::operation(#{Error}::erase(e))) + } } - } - cfg.store_put(#{SharedResponseDeserializer}::new(TestDeser)); - """, - *preludeScope, - "SharedResponseDeserializer" to RT.smithyRuntimeApi(rc).resolve("client::ser_de::SharedResponseDeserializer"), - "Error" to RT.smithyRuntimeApi(rc).resolve("client::interceptors::context::Error"), - "HttpResponse" to RT.smithyRuntimeApi(rc).resolve("client::orchestrator::HttpResponse"), - "OrchestratorError" to RT.smithyRuntimeApi(rc).resolve("client::orchestrator::OrchestratorError"), - "Output" to RT.smithyRuntimeApi(rc).resolve("client::interceptors::context::Output"), - "DeserializeResponse" to RT.smithyRuntimeApi(rc).resolve("client::ser_de::DeserializeResponse"), - ) + cfg.store_put(#{SharedResponseDeserializer}::new(TestDeser)); + """, + *preludeScope, + "SharedResponseDeserializer" to RT.smithyRuntimeApi(rc).resolve("client::ser_de::SharedResponseDeserializer"), + "Error" to RT.smithyRuntimeApi(rc).resolve("client::interceptors::context::Error"), + "HttpResponse" to RT.smithyRuntimeApi(rc).resolve("client::orchestrator::HttpResponse"), + "OrchestratorError" to RT.smithyRuntimeApi(rc).resolve("client::orchestrator::OrchestratorError"), + "Output" to RT.smithyRuntimeApi(rc).resolve("client::interceptors::context::Output"), + "DeserializeResponse" to RT.smithyRuntimeApi(rc).resolve("client::ser_de::DeserializeResponse"), + ) + } } - } } class ProtocolTestGeneratorTest { - private val model = """ + private val model = + """ namespace com.example use aws.protocols#restJson1 @@ -194,7 +197,7 @@ class ProtocolTestGeneratorTest { name: String } - """.asSmithyModel() + """.asSmithyModel() private val correctBody = """{"name": "Teddy"}""" /** @@ -207,26 +210,31 @@ class ProtocolTestGeneratorTest { fakeRequestBody: String = "${correctBody.dq()}.to_string()", fakeOutput: String = """Ok(crate::operation::say_hello::SayHelloOutput::builder().value("hey there!").build())""", ): Path { - val codegenDecorator = object : ClientCodegenDecorator { - override val name: String = "mock" - override val order: Byte = 0 - override fun classpathDiscoverable(): Boolean = false - - override fun serviceRuntimePluginCustomizations( - codegenContext: ClientCodegenContext, - baseCustomizations: List, - ): List = baseCustomizations + TestServiceRuntimePluginCustomization( - codegenContext, - fakeRequestBuilder, - fakeRequestBody, - ) - - override fun operationCustomizations( - codegenContext: ClientCodegenContext, - operation: OperationShape, - baseCustomizations: List, - ): List = baseCustomizations + TestOperationCustomization(codegenContext, fakeOutput) - } + val codegenDecorator = + object : ClientCodegenDecorator { + override val name: String = "mock" + override val order: Byte = 0 + + override fun classpathDiscoverable(): Boolean = false + + override fun serviceRuntimePluginCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List = + baseCustomizations + + TestServiceRuntimePluginCustomization( + codegenContext, + fakeRequestBuilder, + fakeRequestBody, + ) + + override fun operationCustomizations( + codegenContext: ClientCodegenContext, + operation: OperationShape, + baseCustomizations: List, + ): List = + baseCustomizations + TestOperationCustomization(codegenContext, fakeOutput) + } return clientIntegrationTest( model, additionalDecorators = listOf(codegenDecorator), @@ -246,32 +254,34 @@ class ProtocolTestGeneratorTest { @Test fun `test incorrect response parsing`() { - val err = assertThrows { - testService( - """ - .uri("/?Hi=Hello%20there&required") - .header("X-Greeting", "Hi") - .method("POST") - """, - fakeOutput = "Ok(crate::operation::say_hello::SayHelloOutput::builder().build())", - ) - } + val err = + assertThrows { + testService( + """ + .uri("/?Hi=Hello%20there&required") + .header("X-Greeting", "Hi") + .method("POST") + """, + fakeOutput = "Ok(crate::operation::say_hello::SayHelloOutput::builder().build())", + ) + } err.message shouldContain "basic_response_test_response ... FAILED" } @Test fun `test invalid body`() { - val err = assertThrows { - testService( - """ - .uri("/?Hi=Hello%20there&required") - .header("X-Greeting", "Hi") - .method("POST") - """, - """"{}".to_string()""", - ) - } + val err = + assertThrows { + testService( + """ + .uri("/?Hi=Hello%20there&required") + .header("X-Greeting", "Hi") + .method("POST") + """, + """"{}".to_string()""", + ) + } err.message shouldContain "say_hello_request ... FAILED" err.message shouldContain "body did not match" @@ -279,15 +289,16 @@ class ProtocolTestGeneratorTest { @Test fun `test invalid url parameter`() { - val err = assertThrows { - testService( - """ - .uri("/?Hi=INCORRECT&required") - .header("X-Greeting", "Hi") - .method("POST") - """, - ) - } + val err = + assertThrows { + testService( + """ + .uri("/?Hi=INCORRECT&required") + .header("X-Greeting", "Hi") + .method("POST") + """, + ) + } // Verify the test actually ran err.message shouldContain "say_hello_request ... FAILED" err.message shouldContain "missing query param" @@ -295,15 +306,16 @@ class ProtocolTestGeneratorTest { @Test fun `test forbidden url parameter`() { - val err = assertThrows { - testService( - """ - .uri("/?goodbye&Hi=Hello%20there&required") - .header("X-Greeting", "Hi") - .method("POST") - """, - ) - } + val err = + assertThrows { + testService( + """ + .uri("/?goodbye&Hi=Hello%20there&required") + .header("X-Greeting", "Hi") + .method("POST") + """, + ) + } // Verify the test actually ran err.message shouldContain "say_hello_request ... FAILED" err.message shouldContain "forbidden query param" @@ -312,15 +324,16 @@ class ProtocolTestGeneratorTest { @Test fun `test required url parameter`() { // Hard coded implementation for this 1 test - val err = assertThrows { - testService( - """ - .uri("/?Hi=Hello%20there") - .header("X-Greeting", "Hi") - .method("POST") - """, - ) - } + val err = + assertThrows { + testService( + """ + .uri("/?Hi=Hello%20there") + .header("X-Greeting", "Hi") + .method("POST") + """, + ) + } // Verify the test actually ran err.message shouldContain "say_hello_request ... FAILED" @@ -329,15 +342,16 @@ class ProtocolTestGeneratorTest { @Test fun `test invalid path`() { - val err = assertThrows { - testService( - """ - .uri("/incorrect-path?required&Hi=Hello%20there") - .header("X-Greeting", "Hi") - .method("POST") - """, - ) - } + val err = + assertThrows { + testService( + """ + .uri("/incorrect-path?required&Hi=Hello%20there") + .header("X-Greeting", "Hi") + .method("POST") + """, + ) + } // Verify the test actually ran err.message shouldContain "say_hello_request ... FAILED" @@ -346,16 +360,17 @@ class ProtocolTestGeneratorTest { @Test fun `invalid header`() { - val err = assertThrows { - testService( - """ - .uri("/?Hi=Hello%20there&required") - // should be "Hi" - .header("X-Greeting", "Hey") - .method("POST") - """, - ) - } + val err = + assertThrows { + testService( + """ + .uri("/?Hi=Hello%20there&required") + // should be "Hi" + .header("X-Greeting", "Hey") + .method("POST") + """, + ) + } err.message shouldContain "say_hello_request ... FAILED" err.message shouldContain "invalid header value" diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsQueryCompatibleTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsQueryCompatibleTest.kt index 706e7b81c01..dcb322b4472 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsQueryCompatibleTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsQueryCompatibleTest.kt @@ -16,7 +16,8 @@ import software.amazon.smithy.rust.codegen.core.util.lookup class AwsQueryCompatibleTest { @Test fun `aws-query-compatible json with aws query error should allow for retrieving error code and type from custom header`() { - val model = """ + val model = + """ namespace test use aws.protocols#awsJson1_0 use aws.protocols#awsQueryCompatible @@ -48,7 +49,7 @@ class AwsQueryCompatibleTest { structure InvalidThingException { message: String } - """.asSmithyModel() + """.asSmithyModel() clientIntegrationTest(model) { context, rustCrate -> val operation: OperationShape = context.model.lookup("test#SomeOperation") @@ -91,8 +92,9 @@ class AwsQueryCompatibleTest { assert_eq!(Some("Sender"), error.meta().extra("type")); } """, - "infallible_client_fn" to CargoDependency.smithyRuntimeTestUtil(context.runtimeConfig) - .toType().resolve("client::http::test_util::infallible_client_fn"), + "infallible_client_fn" to + CargoDependency.smithyRuntimeTestUtil(context.runtimeConfig) + .toType().resolve("client::http::test_util::infallible_client_fn"), "tokio" to CargoDependency.Tokio.toType(), ) } @@ -101,7 +103,8 @@ class AwsQueryCompatibleTest { @Test fun `aws-query-compatible json without aws query error should allow for retrieving error code from payload`() { - val model = """ + val model = + """ namespace test use aws.protocols#awsJson1_0 use aws.protocols#awsQueryCompatible @@ -128,7 +131,7 @@ class AwsQueryCompatibleTest { structure InvalidThingException { message: String } - """.asSmithyModel() + """.asSmithyModel() clientIntegrationTest(model) { context, rustCrate -> val operation: OperationShape = context.model.lookup("test#SomeOperation") @@ -164,8 +167,9 @@ class AwsQueryCompatibleTest { assert_eq!(None, error.meta().extra("type")); } """, - "infallible_client_fn" to CargoDependency.smithyRuntimeTestUtil(context.runtimeConfig) - .toType().resolve("client::http::test_util::infallible_client_fn"), + "infallible_client_fn" to + CargoDependency.smithyRuntimeTestUtil(context.runtimeConfig) + .toType().resolve("client::http::test_util::infallible_client_fn"), "tokio" to CargoDependency.Tokio.toType(), ) } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsQueryTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsQueryTest.kt index 0e5f22ec70e..9f811158951 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsQueryTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsQueryTest.kt @@ -10,7 +10,8 @@ import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel class AwsQueryTest { - private val model = """ + private val model = + """ namespace test use aws.protocols#awsQuery @@ -31,7 +32,7 @@ class AwsQueryTest { a: String, b: Integer } - """.asSmithyModel() + """.asSmithyModel() @Test fun `generate an aws query service that compiles`() { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/Ec2QueryTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/Ec2QueryTest.kt index 13779ca09f5..9a258254ab9 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/Ec2QueryTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/Ec2QueryTest.kt @@ -10,7 +10,8 @@ import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel class Ec2QueryTest { - private val model = """ + private val model = + """ namespace test use aws.protocols#ec2Query @@ -31,7 +32,7 @@ class Ec2QueryTest { a: String, b: Integer } - """.asSmithyModel() + """.asSmithyModel() @Test fun `generate an aws query service that compiles`() { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestJsonTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestJsonTest.kt index 99a6b3a01c1..200019adb12 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestJsonTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestJsonTest.kt @@ -10,7 +10,8 @@ import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel internal class RestJsonTest { - val model = """ + val model = + """ namespace test use aws.protocols#restJson1 use aws.api#service @@ -36,7 +37,7 @@ internal class RestJsonTest { a: String, b: Integer } - """.asSmithyModel() + """.asSmithyModel() @Test fun `generate a rest json service that compiles`() { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestXmlTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestXmlTest.kt index aad9e1d17f4..279bf964edd 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestXmlTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestXmlTest.kt @@ -10,8 +10,8 @@ import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel internal class RestXmlTest { - - private val model = """ + private val model = + """ namespace test use aws.protocols#restXml use aws.api#service @@ -81,7 +81,7 @@ internal class RestXmlTest { renamedWithPrefix: String } - """.asSmithyModel() + """.asSmithyModel() @Test fun `generate a rest xml service that compiles`() { diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/RemoveEventStreamOperationsTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/RemoveEventStreamOperationsTest.kt index 4cc9c594aaa..2396689b3af 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/RemoveEventStreamOperationsTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/RemoveEventStreamOperationsTest.kt @@ -16,7 +16,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import java.util.Optional internal class RemoveEventStreamOperationsTest { - private val model = """ + private val model = + """ namespace test operation EventStream { input: StreamingInput, @@ -43,28 +44,30 @@ internal class RemoveEventStreamOperationsTest { } structure Foo {} - """.asSmithyModel() + """.asSmithyModel() @Test fun `remove event stream ops from services that are not in the allow list`() { - val transformed = RemoveEventStreamOperations.transform( - model, - testClientRustSettings( - codegenConfig = ClientCodegenConfig(eventStreamAllowList = setOf("not-test-module")), - ), - ) + val transformed = + RemoveEventStreamOperations.transform( + model, + testClientRustSettings( + codegenConfig = ClientCodegenConfig(eventStreamAllowList = setOf("not-test-module")), + ), + ) transformed.expectShape(ShapeId.from("test#BlobStream")) transformed.getShape(ShapeId.from("test#EventStream")) shouldBe Optional.empty() } @Test fun `keep event stream ops from services that are in the allow list`() { - val transformed = RemoveEventStreamOperations.transform( - model, - testClientRustSettings( - codegenConfig = ClientCodegenConfig(eventStreamAllowList = setOf("test-module")), - ), - ) + val transformed = + RemoveEventStreamOperations.transform( + model, + testClientRustSettings( + codegenConfig = ClientCodegenConfig(eventStreamAllowList = setOf("test-module")), + ), + ) transformed.expectShape(ShapeId.from("test#BlobStream")) transformed.getShape(ShapeId.from("test#EventStream")) shouldNotBe Optional.empty() } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/testutil/Matchers.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/testutil/Matchers.kt index 53b9a20b64c..1d3c7eae5b5 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/testutil/Matchers.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/testutil/Matchers.kt @@ -7,7 +7,10 @@ package software.amazon.smithy.rust.codegen.client.testutil import io.kotest.matchers.shouldBe -fun String.shouldMatchResource(clazz: Class, resourceName: String) { +fun String.shouldMatchResource( + clazz: Class, + resourceName: String, +) { val resource = clazz.getResource(resourceName).readText() this.trim().shouldBe(resource.trim()) } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/tool/TimeTestSuiteGenerator.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/tool/TimeTestSuiteGenerator.kt index 681bcbb406a..44835f11843 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/tool/TimeTestSuiteGenerator.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/tool/TimeTestSuiteGenerator.kt @@ -21,21 +21,29 @@ import java.util.Locale import kotlin.math.absoluteValue private val UTC = ZoneId.of("UTC") -private val YEARS = listOf(-9999, -100, -1, /* year 0 doesn't exist */ 1, 100, 1969, 1970, 2037, 2038, 9999) -private val DAYS_IN_MONTH = mapOf( - 1 to 31, - 2 to 28, - 3 to 31, - 4 to 30, - 5 to 31, - 6 to 30, - 7 to 31, - 8 to 31, - 9 to 30, - 10 to 31, - 11 to 30, - 12 to 31, -) +private val YEARS = + listOf( + -9999, -100, -1, + // year 0 doesn't exist + 1, + 100, 1969, 1970, 2037, 2038, 9999, + ) + +private val DAYS_IN_MONTH = + mapOf( + 1 to 31, + 2 to 28, + 3 to 31, + 4 to 30, + 5 to 31, + 6 to 30, + 7 to 31, + 8 to 31, + 9 to 30, + 10 to 31, + 11 to 30, + 12 to 31, + ) private val MILLI_FRACTIONS = listOf(0, 1_000_000, 10_000_000, 100_000_000, 234_000_000) private val MICRO_FRACTIONS = listOf(0, 1_000, 10_000, 100_000, 234_000) private val NANO_FRACTIONS = @@ -47,13 +55,14 @@ private data class TestCase( ) { fun toNode(): Node = time.toInstant().let { instant -> - val map = mutableMapOf( - "iso8601" to Node.from(DateTimeFormatter.ISO_OFFSET_DATE_TIME.format(time)), - // JSON numbers have 52 bits of precision, and canonical seconds needs 64 bits - "canonical_seconds" to Node.from(instant.epochSecond.toString()), - "canonical_nanos" to NumberNode(instant.nano, SourceLocation.NONE), - "error" to BooleanNode(formatted == null, SourceLocation.NONE), - ) + val map = + mutableMapOf( + "iso8601" to Node.from(DateTimeFormatter.ISO_OFFSET_DATE_TIME.format(time)), + // JSON numbers have 52 bits of precision, and canonical seconds needs 64 bits + "canonical_seconds" to Node.from(instant.epochSecond.toString()), + "canonical_nanos" to NumberNode(instant.nano, SourceLocation.NONE), + "error" to BooleanNode(formatted == null, SourceLocation.NONE), + ) if (formatted != null) { map["smithy_format_value"] = Node.from(formatted) } @@ -76,11 +85,12 @@ private fun generateTestTimes(allowed: AllowedSubseconds): List { val hour = i % 24 val minute = i % 60 val second = (i * 233).absoluteValue % 60 - val nanoOfSecond = when (allowed) { - AllowedSubseconds.NANOS -> NANO_FRACTIONS[i % NANO_FRACTIONS.size] - AllowedSubseconds.MICROS -> MICRO_FRACTIONS[i % MICRO_FRACTIONS.size] - AllowedSubseconds.MILLIS -> MILLI_FRACTIONS[i % MILLI_FRACTIONS.size] - } + val nanoOfSecond = + when (allowed) { + AllowedSubseconds.NANOS -> NANO_FRACTIONS[i % NANO_FRACTIONS.size] + AllowedSubseconds.MICROS -> MICRO_FRACTIONS[i % MICRO_FRACTIONS.size] + AllowedSubseconds.MILLIS -> MILLI_FRACTIONS[i % MILLI_FRACTIONS.size] + } result.add(ZonedDateTime.of(year, month, dayOfMonth, hour, minute, second, nanoOfSecond, UTC)) i += 1 } @@ -95,25 +105,27 @@ private fun generateTestTimes(allowed: AllowedSubseconds): List { } private fun generateEpochSecondsTests(): List { - val formatter = DateTimeFormatterBuilder() - .appendValue(ChronoField.INSTANT_SECONDS, 1, 19, SignStyle.NORMAL) - .optionalStart() - .appendFraction(ChronoField.MICRO_OF_SECOND, 0, 6, true) - .optionalEnd() - .toFormatter() + val formatter = + DateTimeFormatterBuilder() + .appendValue(ChronoField.INSTANT_SECONDS, 1, 19, SignStyle.NORMAL) + .optionalStart() + .appendFraction(ChronoField.MICRO_OF_SECOND, 0, 6, true) + .optionalEnd() + .toFormatter() return generateTestTimes(AllowedSubseconds.MICROS).map { time -> TestCase(time, formatter.format(time)) } } private fun generateHttpDateTests(parsing: Boolean): List { - val formatter = DateTimeFormatterBuilder() - .appendPattern("EEE, dd MMM yyyy HH:mm:ss") - .optionalStart() - .appendFraction(ChronoField.MILLI_OF_SECOND, 0, 3, true) - .optionalEnd() - .appendLiteral(" GMT") - .toFormatter(Locale.ENGLISH) + val formatter = + DateTimeFormatterBuilder() + .appendPattern("EEE, dd MMM yyyy HH:mm:ss") + .optionalStart() + .appendFraction(ChronoField.MILLI_OF_SECOND, 0, 3, true) + .optionalEnd() + .appendLiteral(" GMT") + .toFormatter(Locale.ENGLISH) return generateTestTimes(if (parsing) AllowedSubseconds.MILLIS else AllowedSubseconds.NANOS).map { time -> TestCase( time, @@ -126,13 +138,14 @@ private fun generateHttpDateTests(parsing: Boolean): List { } private fun generateDateTimeTests(parsing: Boolean): List { - val formatter = DateTimeFormatterBuilder() - .appendPattern("yyyy-MM-dd'T'HH:mm:ss") - .optionalStart() - .appendFraction(ChronoField.MICRO_OF_SECOND, 0, 6, true) - .optionalEnd() - .appendLiteral("Z") - .toFormatter(Locale.ENGLISH) + val formatter = + DateTimeFormatterBuilder() + .appendPattern("yyyy-MM-dd'T'HH:mm:ss") + .optionalStart() + .appendFraction(ChronoField.MICRO_OF_SECOND, 0, 6, true) + .optionalEnd() + .appendLiteral("Z") + .toFormatter(Locale.ENGLISH) return generateTestTimes(if (parsing) AllowedSubseconds.MICROS else AllowedSubseconds.NANOS).map { time -> TestCase( time, @@ -146,78 +159,85 @@ private fun generateDateTimeTests(parsing: Boolean): List { fun main() { val none = SourceLocation.NONE - val topLevels = mapOf( - "description" to ArrayNode( - """ - This file holds format and parse test cases for Smithy's built-in `epoch-seconds`, - `http-date`, and `date-time` timestamp formats. - - There are six top-level sections: - - `format_epoch_seconds`: Test cases for formatting timestamps into `epoch-seconds` - - `format_http_date`: Test cases for formatting timestamps into `http-date` - - `format_date_time`: Test cases for formatting timestamps into `date-time` - - `parse_epoch_seconds`: Test cases for parsing timestamps from `epoch-seconds` - - `parse_http_date`: Test cases for parsing timestamps from `http-date` - - `parse_date_time`: Test cases for parsing timestamps from `date-time` - - Each top-level section is an array of the same test case data structure: - ```typescript - type TestCase = { - // Human-readable ISO-8601 representation of the canonical date-time. This should not - // be used by tests, and is only present to make test failures more human readable. - iso8601: string, - - // The canonical number of seconds since the Unix epoch in UTC. - canonical_seconds: string, - - // The canonical nanosecond adjustment to the canonical number of seconds. - // If conversion from (canonical_seconds, canonical_nanos) into a 128-bit integer is required, - // DO NOT just add the two together as this will yield an incorrect value when - // canonical_seconds is negative. - canonical_nanos: number, - - // Will be true if this test case is expected to result in an error or exception - error: boolean, - - // String value of the timestamp in the Smithy format. For the `format_epoch_seconds` top-level, - // this will be in the `epoch-seconds` format, and for `parse_http_date`, it will be in the - // `http-date` format (and so on). - // - // For parsing tests, parse this value and compare the result against canonical_seconds - // and canonical_nanos. - // - // For formatting tests, form the canonical_seconds and canonical_nanos, and then compare - // the result against this value. - // - // This value will not be set for formatting tests if `error` is set to `true`. - smithy_format_value: string, - } - ``` - """.trimIndent().split("\n").map { Node.from(it) }, - none, - ), - "format_epoch_seconds" to ArrayNode(generateEpochSecondsTests().map(TestCase::toNode), none), - "format_http_date" to ArrayNode(generateHttpDateTests(parsing = false).map(TestCase::toNode), none), - "format_date_time" to ArrayNode(generateDateTimeTests(parsing = false).map(TestCase::toNode), none), - "parse_epoch_seconds" to ArrayNode( - generateEpochSecondsTests() - .filter { it.formatted != null } - .map(TestCase::toNode), - none, - ), - "parse_http_date" to ArrayNode( - generateHttpDateTests(parsing = true) - .filter { it.formatted != null } - .map(TestCase::toNode), - none, - ), - "parse_date_time" to ArrayNode( - generateDateTimeTests(parsing = true) - .filter { it.formatted != null } - .map(TestCase::toNode), - none, - ), - ).mapKeys { Node.from(it.key) } + val topLevels = + mapOf( + "description" to + ArrayNode( + """ + This file holds format and parse test cases for Smithy's built-in `epoch-seconds`, + `http-date`, and `date-time` timestamp formats. + + There are six top-level sections: + - `format_epoch_seconds`: Test cases for formatting timestamps into `epoch-seconds` + - `format_http_date`: Test cases for formatting timestamps into `http-date` + - `format_date_time`: Test cases for formatting timestamps into `date-time` + - `parse_epoch_seconds`: Test cases for parsing timestamps from `epoch-seconds` + - `parse_http_date`: Test cases for parsing timestamps from `http-date` + - `parse_date_time`: Test cases for parsing timestamps from `date-time` + + Each top-level section is an array of the same test case data structure: + ```typescript + type TestCase = { + // Human-readable ISO-8601 representation of the canonical date-time. This should not + // be used by tests, and is only present to make test failures more human readable. + iso8601: string, + + // The canonical number of seconds since the Unix epoch in UTC. + canonical_seconds: string, + + // The canonical nanosecond adjustment to the canonical number of seconds. + // If conversion from (canonical_seconds, canonical_nanos) into a 128-bit integer is required, + // DO NOT just add the two together as this will yield an incorrect value when + // canonical_seconds is negative. + canonical_nanos: number, + + // Will be true if this test case is expected to result in an error or exception + error: boolean, + + // String value of the timestamp in the Smithy format. For the `format_epoch_seconds` top-level, + // this will be in the `epoch-seconds` format, and for `parse_http_date`, it will be in the + // `http-date` format (and so on). + // + // For parsing tests, parse this value and compare the result against canonical_seconds + // and canonical_nanos. + // + // For formatting tests, form the canonical_seconds and canonical_nanos, and then compare + // the result against this value. + // + // This value will not be set for formatting tests if `error` is set to `true`. + smithy_format_value: string, + } + ``` + """.trimIndent().split("\n").map { + Node.from(it) + }, + none, + ), + "format_epoch_seconds" to ArrayNode(generateEpochSecondsTests().map(TestCase::toNode), none), + "format_http_date" to ArrayNode(generateHttpDateTests(parsing = false).map(TestCase::toNode), none), + "format_date_time" to ArrayNode(generateDateTimeTests(parsing = false).map(TestCase::toNode), none), + "parse_epoch_seconds" to + ArrayNode( + generateEpochSecondsTests() + .filter { it.formatted != null } + .map(TestCase::toNode), + none, + ), + "parse_http_date" to + ArrayNode( + generateHttpDateTests(parsing = true) + .filter { it.formatted != null } + .map(TestCase::toNode), + none, + ), + "parse_date_time" to + ArrayNode( + generateDateTimeTests(parsing = true) + .filter { it.formatted != null } + .map(TestCase::toNode), + none, + ), + ).mapKeys { Node.from(it.key) } println(Node.prettyPrintJson(ObjectNode(topLevels, none))) } diff --git a/codegen-core/build.gradle.kts b/codegen-core/build.gradle.kts index e72667f8984..0358b430a23 100644 --- a/codegen-core/build.gradle.kts +++ b/codegen-core/build.gradle.kts @@ -98,7 +98,7 @@ val generateSmithyRuntimeCrateVersion by tasks.registering { } tasks.compileKotlin { - kotlinOptions.jvmTarget = "1.8" + kotlinOptions.jvmTarget = "11" dependsOn(generateSmithyRuntimeCrateVersion) } @@ -135,7 +135,7 @@ if (isTestingEnabled.toBoolean()) { } tasks.compileTestKotlin { - kotlinOptions.jvmTarget = "1.8" + kotlinOptions.jvmTarget = "11" } tasks.test { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/Version.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/Version.kt index 071f0a2089b..5e3bf35db5e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/Version.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/Version.kt @@ -35,22 +35,21 @@ data class Version( } // Returns full version in the "{smithy_rs_version}-{git_commit_hash}" format - fun fullVersion(): String = - fromDefaultResource().fullVersion + fun fullVersion(): String = fromDefaultResource().fullVersion - fun stableCrateVersion(): String = - fromDefaultResource().stableCrateVersion + fun stableCrateVersion(): String = fromDefaultResource().stableCrateVersion - fun unstableCrateVersion(): String = - fromDefaultResource().unstableCrateVersion + fun unstableCrateVersion(): String = fromDefaultResource().unstableCrateVersion fun crateVersion(crate: String): String { val version = fromDefaultResource() return version.crates[crate] ?: version.unstableCrateVersion } - fun fromDefaultResource(): Version = parse( - Version::class.java.getResource(VERSION_FILENAME)?.readText() - ?: throw CodegenException("$VERSION_FILENAME does not exist"), - ) + + fun fromDefaultResource(): Version = + parse( + Version::class.java.getResource(VERSION_FILENAME)?.readText() + ?: throw CodegenException("$VERSION_FILENAME does not exist"), + ) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index 88416e45c45..23ce4428afb 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -21,12 +21,16 @@ enum class DependencyScope { } sealed class DependencyLocation + data class CratesIo(val version: String) : DependencyLocation() + data class Local(val basePath: String, val version: String? = null) : DependencyLocation() sealed class RustDependency(open val name: String) : SymbolDependencyContainer { abstract fun version(): String + open fun dependencies(): List = listOf() + override fun getDependencies(): List { return listOf( SymbolDependency @@ -39,6 +43,7 @@ sealed class RustDependency(open val name: String) : SymbolDependencyContainer { companion object { private const val PropertyKey = "rustdep" + fun fromSymbolDependency(symbolDependency: SymbolDependency) = symbolDependency.getProperty(PropertyKey, RustDependency::class.java).get() } @@ -86,8 +91,10 @@ class InlineDependency( } } - private fun forInlineableRustFile(name: String, vararg additionalDependencies: RustDependency) = - forRustFile(RustModule.private(name), "/inlineable/src/$name.rs", *additionalDependencies) + private fun forInlineableRustFile( + name: String, + vararg additionalDependencies: RustDependency, + ) = forRustFile(RustModule.private(name), "/inlineable/src/$name.rs", *additionalDependencies) fun eventReceiver(runtimeConfig: RuntimeConfig) = forInlineableRustFile( @@ -97,7 +104,8 @@ class InlineDependency( CargoDependency.smithyTypes(runtimeConfig), ) - fun defaultAuthPlugin(runtimeConfig: RuntimeConfig) = forInlineableRustFile("auth_plugin", CargoDependency.smithyRuntimeApiClient(runtimeConfig)) + fun defaultAuthPlugin(runtimeConfig: RuntimeConfig) = + forInlineableRustFile("auth_plugin", CargoDependency.smithyRuntimeApiClient(runtimeConfig)) fun jsonErrors(runtimeConfig: RuntimeConfig) = forInlineableRustFile( @@ -130,12 +138,13 @@ class InlineDependency( fun unwrappedXmlErrors(runtimeConfig: RuntimeConfig): InlineDependency = forInlineableRustFile("rest_xml_unwrapped_errors", CargoDependency.smithyXml(runtimeConfig)) - fun serializationSettings(runtimeConfig: RuntimeConfig): InlineDependency = forInlineableRustFile( - "serialization_settings", - CargoDependency.Http, - CargoDependency.smithyHttp(runtimeConfig), - CargoDependency.smithyTypes(runtimeConfig), - ) + fun serializationSettings(runtimeConfig: RuntimeConfig): InlineDependency = + forInlineableRustFile( + "serialization_settings", + CargoDependency.Http, + CargoDependency.smithyHttp(runtimeConfig), + CargoDependency.smithyTypes(runtimeConfig), + ) fun constrained(): InlineDependency = InlineDependency.forRustFile(ConstrainedModule, "/inlineable/src/constrained.rs") @@ -174,10 +183,11 @@ data class CargoDependency( fun toDevDependency() = copy(scope = DependencyScope.Dev) - override fun version(): String = when (location) { - is CratesIo -> location.version - is Local -> "local" - } + override fun version(): String = + when (location) { + is CratesIo -> location.version + is Local -> "local" + } fun toMap(): Map { val attribs = mutableMapOf() @@ -274,45 +284,63 @@ data class CargoDependency( DependencyScope.Dev, features = setOf("macros", "test-util", "rt-multi-thread"), ) - val TracingAppender: CargoDependency = CargoDependency( - "tracing-appender", - CratesIo("0.2.2"), - DependencyScope.Dev, - ) - val TracingSubscriber: CargoDependency = CargoDependency( - "tracing-subscriber", - CratesIo("0.3.16"), - DependencyScope.Dev, - features = setOf("env-filter", "json"), - ) - val TracingTest: CargoDependency = CargoDependency( - "tracing-test", - CratesIo("0.2.4"), - DependencyScope.Dev, - features = setOf("no-env-filter"), - ) + val TracingAppender: CargoDependency = + CargoDependency( + "tracing-appender", + CratesIo("0.2.2"), + DependencyScope.Dev, + ) + val TracingSubscriber: CargoDependency = + CargoDependency( + "tracing-subscriber", + CratesIo("0.3.16"), + DependencyScope.Dev, + features = setOf("env-filter", "json"), + ) + val TracingTest: CargoDependency = + CargoDependency( + "tracing-test", + CratesIo("0.2.4"), + DependencyScope.Dev, + features = setOf("no-env-filter"), + ) fun smithyAsync(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-async") + fun smithyChecksums(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-checksums") fun smithyEventStream(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-eventstream") + fun smithyHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http") + fun smithyJson(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-json") + fun smithyProtocolTestHelpers(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-protocol-test", scope = DependencyScope.Dev) fun smithyQuery(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-query") - fun smithyRuntime(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-runtime") - .withFeature("client") - fun smithyRuntimeTestUtil(runtimeConfig: RuntimeConfig) = smithyRuntime(runtimeConfig).toDevDependency().withFeature("test-util") + + fun smithyRuntime(runtimeConfig: RuntimeConfig) = + runtimeConfig.smithyRuntimeCrate("smithy-runtime") + .withFeature("client") + + fun smithyRuntimeTestUtil(runtimeConfig: RuntimeConfig) = + smithyRuntime(runtimeConfig).toDevDependency().withFeature("test-util") + fun smithyRuntimeApi(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-runtime-api") - fun smithyRuntimeApiClient(runtimeConfig: RuntimeConfig) = smithyRuntimeApi(runtimeConfig).withFeature("client").withFeature("http-02x") + + fun smithyRuntimeApiClient(runtimeConfig: RuntimeConfig) = + smithyRuntimeApi(runtimeConfig).withFeature("client").withFeature("http-02x") + fun smithyRuntimeApiTestUtil(runtimeConfig: RuntimeConfig) = smithyRuntimeApi(runtimeConfig).toDevDependency().withFeature("test-util") + fun smithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-types") + fun smithyXml(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-xml") // behind feature-gate - val Serde = CargoDependency("serde", CratesIo("1.0"), features = setOf("derive"), scope = DependencyScope.CfgUnstable) + val Serde = + CargoDependency("serde", CratesIo("1.0"), features = setOf("derive"), scope = DependencyScope.CfgUnstable) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustGenerics.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustGenerics.kt index b5de87b76ca..80b33e260db 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustGenerics.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustGenerics.kt @@ -118,14 +118,15 @@ class RustGenerics(vararg genericTypeArgs: GenericTypeArg) { * // } * ``` */ - fun bounds() = writable { - // Only write bounds for generic type params with a bound - for ((typeArg, bound) in typeArgs) { - if (bound != null) { - rustTemplate("$typeArg: #{bound},\n", "bound" to bound) + fun bounds() = + writable { + // Only write bounds for generic type params with a bound + for ((typeArg, bound) in typeArgs) { + if (bound != null) { + rustTemplate("$typeArg: #{bound},\n", "bound" to bound) + } } } - } /** * Combine two `GenericsGenerator`s into one. Type args for the first `GenericsGenerator` will appear before diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt index 78dee92daef..6b0a725925b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt @@ -16,7 +16,6 @@ import software.amazon.smithy.rust.codegen.core.util.PANIC * - There is no guarantee _which_ module will be rendered. */ sealed class RustModule { - /** lib.rs */ object LibRs : RustModule() @@ -33,11 +32,10 @@ sealed class RustModule { val rustMetadata: RustMetadata, val parent: RustModule = LibRs, val inline: Boolean = false, - /* module is a cfg(test) module */ + // module is a cfg(test) module val tests: Boolean = false, val documentationOverride: String? = null, ) : RustModule() { - init { check(!name.contains("::")) { "Module names CANNOT contain `::`—modules must be nested with parent (name was: `$name`)" @@ -52,14 +50,14 @@ sealed class RustModule { } /** Convert a module into a module gated with `#[cfg(test)]` */ - fun cfgTest(): LeafModule = this.copy( - rustMetadata = rustMetadata.copy(additionalAttributes = rustMetadata.additionalAttributes + Attribute.CfgTest), - tests = true, - ) + fun cfgTest(): LeafModule = + this.copy( + rustMetadata = rustMetadata.copy(additionalAttributes = rustMetadata.additionalAttributes + Attribute.CfgTest), + tests = true, + ) } companion object { - /** Creates a new module with the specified visibility */ fun new( name: String, @@ -84,29 +82,33 @@ sealed class RustModule { parent: RustModule = LibRs, documentationOverride: String? = null, additionalAttributes: List = emptyList(), - ): LeafModule = new( - name, - visibility = Visibility.PUBLIC, - inline = false, - parent = parent, - documentationOverride = documentationOverride, - additionalAttributes = additionalAttributes, - ) + ): LeafModule = + new( + name, + visibility = Visibility.PUBLIC, + inline = false, + parent = parent, + documentationOverride = documentationOverride, + additionalAttributes = additionalAttributes, + ) /** Creates a new private module */ - fun private(name: String, parent: RustModule = LibRs): LeafModule = - new(name, visibility = Visibility.PRIVATE, inline = false, parent = parent) + fun private( + name: String, + parent: RustModule = LibRs, + ): LeafModule = new(name, visibility = Visibility.PRIVATE, inline = false, parent = parent) fun pubCrate( name: String, parent: RustModule = LibRs, additionalAttributes: List = emptyList(), - ): LeafModule = new( - name, visibility = Visibility.PUBCRATE, - inline = false, - parent = parent, - additionalAttributes = additionalAttributes, - ) + ): LeafModule = + new( + name, visibility = Visibility.PUBCRATE, + inline = false, + parent = parent, + additionalAttributes = additionalAttributes, + ) fun inlineTests( name: String = "test", @@ -121,29 +123,32 @@ sealed class RustModule { ).cfgTest() } - fun isInline(): Boolean = when (this) { - is LibRs -> false - is LeafModule -> this.inline - } + fun isInline(): Boolean = + when (this) { + is LibRs -> false + is LeafModule -> this.inline + } /** * Fully qualified path to this module, e.g. `crate::grandparent::parent::child` */ - fun fullyQualifiedPath(): String = when (this) { - is LibRs -> "crate" - is LeafModule -> parent.fullyQualifiedPath() + "::" + name - } + fun fullyQualifiedPath(): String = + when (this) { + is LibRs -> "crate" + is LeafModule -> parent.fullyQualifiedPath() + "::" + name + } /** * The file this module is homed in, e.g. `src/grandparent/parent/child.rs` */ - fun definitionFile(): String = when (this) { - is LibRs -> "src/lib.rs" - is LeafModule -> { - val path = fullyQualifiedPath().split("::").drop(1).joinToString("/") - "src/$path.rs" + fun definitionFile(): String = + when (this) { + is LibRs -> "src/lib.rs" + is LeafModule -> { + val path = fullyQualifiedPath().split("::").drop(1).joinToString("/") + "src/$path.rs" + } } - } /** * Renders the usage statement, approximately: @@ -152,7 +157,10 @@ sealed class RustModule { * pub mod my_module_name * ``` */ - fun renderModStatement(writer: RustWriter, moduleDocProvider: ModuleDocProvider) { + fun renderModStatement( + writer: RustWriter, + moduleDocProvider: ModuleDocProvider, + ) { when (this) { is LeafModule -> { if (name.startsWith("r#")) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt index 260cb62295a..cc37f891eda 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt @@ -97,85 +97,92 @@ enum class EscapeFor { } object RustReservedWords : ReservedWords { - private val RustKeywords = setOf( - "as", - "break", - "const", - "continue", - "crate", - "else", - "enum", - "extern", - "false", - "fn", - "for", - "if", - "impl", - "in", - "let", - "loop", - "match", - "mod", - "move", - "mut", - "pub", - "ref", - "return", - "self", - "Self", - "static", - "struct", - "super", - "trait", - "true", - "type", - "unsafe", - "use", - "where", - "while", - - "async", - "await", - "dyn", - - "abstract", - "become", - "box", - "do", - "final", - "macro", - "override", - "priv", - "typeof", - "unsized", - "virtual", - "yield", - "try", - ) + private val RustKeywords = + setOf( + "as", + "break", + "const", + "continue", + "crate", + "else", + "enum", + "extern", + "false", + "fn", + "for", + "if", + "impl", + "in", + "let", + "loop", + "match", + "mod", + "move", + "mut", + "pub", + "ref", + "return", + "self", + "Self", + "static", + "struct", + "super", + "trait", + "true", + "type", + "unsafe", + "use", + "where", + "while", + "async", + "await", + "dyn", + "abstract", + "become", + "box", + "do", + "final", + "macro", + "override", + "priv", + "typeof", + "unsized", + "virtual", + "yield", + "try", + ) // Some things can't be used as a raw identifier, so we can't use the normal escaping strategy // https://internals.rust-lang.org/t/raw-identifiers-dont-work-for-all-identifiers/9094/4 - private val keywordEscapingMap = mapOf( - "crate" to "crate_", - "super" to "super_", - "self" to "self_", - "Self" to "SelfValue", - // Real models won't end in `_` so it's safe to stop here - "SelfValue" to "SelfValue_", - ) + private val keywordEscapingMap = + mapOf( + "crate" to "crate_", + "super" to "super_", + "self" to "self_", + "Self" to "SelfValue", + // Real models won't end in `_` so it's safe to stop here + "SelfValue" to "SelfValue_", + ) override fun escape(word: String): String = doEscape(word, EscapeFor.TypeName) - private fun doEscape(word: String, escapeFor: EscapeFor = EscapeFor.TypeName): String = + private fun doEscape( + word: String, + escapeFor: EscapeFor = EscapeFor.TypeName, + ): String = when (val mapped = keywordEscapingMap[word]) { - null -> when (escapeFor) { - EscapeFor.TypeName -> "r##$word" - EscapeFor.ModuleName -> "${word}_" - } + null -> + when (escapeFor) { + EscapeFor.TypeName -> "r##$word" + EscapeFor.ModuleName -> "${word}_" + } else -> mapped } - fun escapeIfNeeded(word: String, escapeFor: EscapeFor = EscapeFor.TypeName): String = + fun escapeIfNeeded( + word: String, + escapeFor: EscapeFor = EscapeFor.TypeName, + ): String = when (isReserved(word)) { true -> doEscape(word, escapeFor) else -> word diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt index 75d97a2721b..8ac936d6793 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt @@ -15,11 +15,12 @@ import software.amazon.smithy.rust.codegen.core.util.dq * * Clippy is upset about `*&`, so if [input] is already referenced, simply strip the leading '&' */ -fun autoDeref(input: String) = if (input.startsWith("&")) { - input.removePrefix("&") -} else { - "*$input" -} +fun autoDeref(input: String) = + if (input.startsWith("&")) { + input.removePrefix("&") + } else { + "*$input" + } /** * A hierarchy of types handled by Smithy codegen @@ -212,20 +213,22 @@ fun RustType.asArgumentType(fullyQualified: Boolean = true): String { } /** Format this Rust type so that it may be used as an argument type in a function definition */ -fun RustType.asArgumentValue(name: String) = when (this) { - is RustType.String, is RustType.Box -> "$name.into()" - else -> name -} +fun RustType.asArgumentValue(name: String) = + when (this) { + is RustType.String, is RustType.Box -> "$name.into()" + else -> name + } /** * For a given name, generate an `Argument` data class containing pre-formatted strings for using this type when * writing a Rust function. */ -fun RustType.asArgument(name: String) = Argument( - "$name: ${this.asArgumentType()}", - this.asArgumentValue(name), - this.render(), -) +fun RustType.asArgument(name: String) = + Argument( + "$name: ${this.asArgumentType()}", + this.asArgumentValue(name), + this.render(), + ) /** * Render this type, including references and generic parameters. @@ -233,38 +236,40 @@ fun RustType.asArgument(name: String) = Argument( * - To generate something like `std::collections::HashMap`, use [qualifiedName] */ fun RustType.render(fullyQualified: Boolean = true): String { - val namespace = if (fullyQualified) { - this.namespace?.let { "$it::" } ?: "" - } else { - "" - } - val base = when (this) { - is RustType.Unit -> this.name - is RustType.Bool -> this.name - is RustType.Float -> this.name - is RustType.Integer -> this.name - is RustType.String -> this.name - is RustType.Vec -> "${this.name}::<${this.member.render(fullyQualified)}>" - is RustType.Slice -> "[${this.member.render(fullyQualified)}]" - is RustType.HashMap -> "${this.name}::<${this.key.render(fullyQualified)}, ${this.member.render(fullyQualified)}>" - is RustType.HashSet -> "${this.name}::<${this.member.render(fullyQualified)}>" - is RustType.Reference -> { - if (this.lifetime == "&") { - "&${this.member.render(fullyQualified)}" - } else { - "&${this.lifetime?.let { "'$it" } ?: ""} ${this.member.render(fullyQualified)}" - } + val namespace = + if (fullyQualified) { + this.namespace?.let { "$it::" } ?: "" + } else { + "" } - is RustType.Application -> { - val args = this.args.joinToString(", ") { it.render(fullyQualified) } - "${this.name}<$args>" + val base = + when (this) { + is RustType.Unit -> this.name + is RustType.Bool -> this.name + is RustType.Float -> this.name + is RustType.Integer -> this.name + is RustType.String -> this.name + is RustType.Vec -> "${this.name}::<${this.member.render(fullyQualified)}>" + is RustType.Slice -> "[${this.member.render(fullyQualified)}]" + is RustType.HashMap -> "${this.name}::<${this.key.render(fullyQualified)}, ${this.member.render(fullyQualified)}>" + is RustType.HashSet -> "${this.name}::<${this.member.render(fullyQualified)}>" + is RustType.Reference -> { + if (this.lifetime == "&") { + "&${this.member.render(fullyQualified)}" + } else { + "&${this.lifetime?.let { "'$it" } ?: ""} ${this.member.render(fullyQualified)}" + } + } + is RustType.Application -> { + val args = this.args.joinToString(", ") { it.render(fullyQualified) } + "${this.name}<$args>" + } + is RustType.Option -> "${this.name}<${this.member.render(fullyQualified)}>" + is RustType.Box -> "${this.name}<${this.member.render(fullyQualified)}>" + is RustType.Dyn -> "${this.name} ${this.member.render(fullyQualified)}" + is RustType.Opaque -> this.name + is RustType.MaybeConstrained -> "${this.name}<${this.member.render(fullyQualified)}>" } - is RustType.Option -> "${this.name}<${this.member.render(fullyQualified)}>" - is RustType.Box -> "${this.name}<${this.member.render(fullyQualified)}>" - is RustType.Dyn -> "${this.name} ${this.member.render(fullyQualified)}" - is RustType.Opaque -> this.name - is RustType.MaybeConstrained -> "${this.name}<${this.member.render(fullyQualified)}>" - } return "$namespace$base" } @@ -273,29 +278,33 @@ fun RustType.render(fullyQualified: Boolean = true): String { * Option.contains(DateTime) would return true. * Option.contains(Blob) would return false. */ -fun RustType.contains(t: T): Boolean = when (this) { - t -> true - is RustType.Container -> this.member.contains(t) - else -> false -} +fun RustType.contains(t: T): Boolean = + when (this) { + t -> true + is RustType.Container -> this.member.contains(t) + else -> false + } -inline fun RustType.stripOuter(): RustType = when (this) { - is T -> this.member - else -> this -} +inline fun RustType.stripOuter(): RustType = + when (this) { + is T -> this.member + else -> this + } /** Extracts the inner Reference type */ -fun RustType.innerReference(): RustType? = when (this) { - is RustType.Reference -> this - is RustType.Container -> this.member.innerReference() - else -> null -} +fun RustType.innerReference(): RustType? = + when (this) { + is RustType.Reference -> this + is RustType.Container -> this.member.innerReference() + else -> null + } /** Wraps a type in Option if it isn't already */ -fun RustType.asOptional(): RustType = when (this) { - is RustType.Option -> this - else -> RustType.Option(this) -} +fun RustType.asOptional(): RustType = + when (this) { + is RustType.Option -> this + else -> RustType.Option(this) + } /** * Converts type to a reference @@ -304,11 +313,12 @@ fun RustType.asOptional(): RustType = when (this) { * - `String` -> `&String` * - `Option` -> `Option<&T>` */ -fun RustType.asRef(): RustType = when (this) { - is RustType.Reference -> this - is RustType.Option -> RustType.Option(member.asRef()) - else -> RustType.Reference(null, this) -} +fun RustType.asRef(): RustType = + when (this) { + is RustType.Reference -> this + is RustType.Option -> RustType.Option(member.asRef()) + else -> RustType.Reference(null, this) + } /** * Converts type to its Deref target @@ -318,64 +328,77 @@ fun RustType.asRef(): RustType = when (this) { * - `Option` -> `Option<&str>` * - `Box` -> `&Something` */ -fun RustType.asDeref(): RustType = when (this) { - is RustType.Option -> if (member.isDeref()) { - RustType.Option(member.asDeref().asRef()) - } else { - this - } +fun RustType.asDeref(): RustType = + when (this) { + is RustType.Option -> + if (member.isDeref()) { + RustType.Option(member.asDeref().asRef()) + } else { + this + } - is RustType.Box -> RustType.Reference(null, member) - is RustType.String -> RustType.Opaque("str") - is RustType.Vec -> RustType.Slice(member) - else -> this -} + is RustType.Box -> RustType.Reference(null, member) + is RustType.String -> RustType.Opaque("str") + is RustType.Vec -> RustType.Slice(member) + else -> this + } /** Returns true if the type implements Deref */ -fun RustType.isDeref(): Boolean = when (this) { - is RustType.Box -> true - is RustType.String -> true - is RustType.Vec -> true - else -> false -} +fun RustType.isDeref(): Boolean = + when (this) { + is RustType.Box -> true + is RustType.String -> true + is RustType.Vec -> true + else -> false + } /** Returns true if the type implements Copy */ -fun RustType.isCopy(): Boolean = when (this) { - is RustType.Float -> true - is RustType.Integer -> true - is RustType.Reference -> true - is RustType.Bool -> true - is RustType.Slice -> true - is RustType.Option -> this.member.isCopy() - else -> false -} +fun RustType.isCopy(): Boolean = + when (this) { + is RustType.Float -> true + is RustType.Integer -> true + is RustType.Reference -> true + is RustType.Bool -> true + is RustType.Slice -> true + is RustType.Option -> this.member.isCopy() + else -> false + } /** Returns true if the type implements Eq */ -fun RustType.isEq(): Boolean = when (this) { - is RustType.Integer -> true - is RustType.Bool -> true - is RustType.String -> true - is RustType.Unit -> true - is RustType.Container -> this.member.isEq() - else -> false -} +fun RustType.isEq(): Boolean = + when (this) { + is RustType.Integer -> true + is RustType.Bool -> true + is RustType.String -> true + is RustType.Unit -> true + is RustType.Container -> this.member.isEq() + else -> false + } enum class Visibility { - PRIVATE, PUBCRATE, PUBLIC; + PRIVATE, + PUBCRATE, + PUBLIC, + ; companion object { - fun publicIf(condition: Boolean, ifNot: Visibility): Visibility = if (condition) { - PUBLIC - } else { - ifNot - } + fun publicIf( + condition: Boolean, + ifNot: Visibility, + ): Visibility = + if (condition) { + PUBLIC + } else { + ifNot + } } - fun toRustQualifier(): String = when (this) { - PRIVATE -> "" - PUBCRATE -> "pub(crate)" - PUBLIC -> "pub" - } + fun toRustQualifier(): String = + when (this) { + PRIVATE -> "" + PUBCRATE -> "pub(crate)" + PUBLIC -> "pub" + } } /** @@ -386,11 +409,9 @@ data class RustMetadata( val additionalAttributes: List = listOf(), val visibility: Visibility = Visibility.PRIVATE, ) { - fun withDerives(vararg newDerives: RuntimeType): RustMetadata = - this.copy(derives = derives + newDerives) + fun withDerives(vararg newDerives: RuntimeType): RustMetadata = this.copy(derives = derives + newDerives) - fun withoutDerives(vararg withoutDerives: RuntimeType) = - this.copy(derives = derives - withoutDerives.toSet()) + fun withoutDerives(vararg withoutDerives: RuntimeType) = this.copy(derives = derives - withoutDerives.toSet()) fun renderAttributes(writer: RustWriter): RustMetadata { val (deriveHelperAttrs, otherAttrs) = additionalAttributes.partition { it.isDeriveHelper } @@ -473,7 +494,10 @@ class Attribute(val inner: Writable, val isDeriveHelper: Boolean = false) { constructor(str: String, isDeriveHelper: Boolean) : this(writable(str), isDeriveHelper) constructor(runtimeType: RuntimeType) : this(runtimeType.writable) - fun render(writer: RustWriter, attributeKind: AttributeKind = AttributeKind.Outer) { + fun render( + writer: RustWriter, + attributeKind: AttributeKind = AttributeKind.Outer, + ) { // Writing "#[]" with nothing inside it is meaningless if (inner.isNotEmpty()) { when (attributeKind) { @@ -486,17 +510,19 @@ class Attribute(val inner: Writable, val isDeriveHelper: Boolean = false) { // These were supposed to be a part of companion object but we decided to move it out to here to avoid NPE // You can find the discussion here. // https://github.com/smithy-lang/smithy-rs/discussions/2248 - public fun SerdeSerialize(): Attribute { + fun serdeSerialize(): Attribute { return Attribute(cfgAttr(all(writable("aws_sdk_unstable"), feature("serde-serialize")), derive(RuntimeType.SerdeSerialize))) } - public fun SerdeDeserialize(): Attribute { + + fun serdeDeserialize(): Attribute { return Attribute(cfgAttr(all(writable("aws_sdk_unstable"), feature("serde-deserialize")), derive(RuntimeType.SerdeDeserialize))) } - public fun SerdeSkip(): Attribute { + + fun serdeSkip(): Attribute { return Attribute(cfgAttr(all(writable("aws_sdk_unstable"), any(feature("serde-serialize"), feature("serde-deserialize"))), serde("skip"))) } - public fun SerdeSerializeOrDeserialize(): Attribute { + fun serdeSerializeOrDeserialize(): Attribute { return Attribute(cfg(all(writable("aws_sdk_unstable"), any(feature("serde-serialize"), feature("serde-deserialize"))))) } @@ -526,6 +552,7 @@ class Attribute(val inner: Writable, val isDeriveHelper: Boolean = false) { val DocHidden = Attribute(doc("hidden")) val DocInline = Attribute(doc("inline")) val NoImplicitPrelude = Attribute("no_implicit_prelude") + fun shouldPanic(expectedMessage: String) = Attribute(macroWithArgs("should_panic", "expected = ${expectedMessage.dq()}")) @@ -545,40 +572,62 @@ class Attribute(val inner: Writable, val isDeriveHelper: Boolean = false) { */ val Deprecated = Attribute("deprecated") - private fun macroWithArgs(name: String, vararg args: RustWriter.() -> Unit): Writable = { - // Macros that require args can't be empty - if (args.isNotEmpty()) { - rustInline("$name(#W)", args.toList().join(", ")) + private fun macroWithArgs( + name: String, + vararg args: RustWriter.() -> Unit, + ): Writable = + { + // Macros that require args can't be empty + if (args.isNotEmpty()) { + rustInline("$name(#W)", args.toList().join(", ")) + } } - } - private fun macroWithArgs(name: String, vararg args: String): Writable = { - // Macros that require args can't be empty - if (args.isNotEmpty()) { - rustInline("$name(${args.joinToString(", ")})") + private fun macroWithArgs( + name: String, + vararg args: String, + ): Writable = + { + // Macros that require args can't be empty + if (args.isNotEmpty()) { + rustInline("$name(${args.joinToString(", ")})") + } } - } fun all(vararg attrMacros: Writable): Writable = macroWithArgs("all", *attrMacros) + fun cfgAttr(vararg attrMacros: Writable): Writable = macroWithArgs("cfg_attr", *attrMacros) fun allow(lints: Collection): Writable = macroWithArgs("allow", *lints.toTypedArray()) + fun allow(vararg lints: String): Writable = macroWithArgs("allow", *lints) + fun deny(vararg lints: String): Writable = macroWithArgs("deny", *lints) + fun serde(vararg lints: String): Writable = macroWithArgs("serde", *lints) + fun any(vararg attrMacros: Writable): Writable = macroWithArgs("any", *attrMacros) + fun cfg(vararg attrMacros: Writable): Writable = macroWithArgs("cfg", *attrMacros) + fun cfg(vararg attrMacros: String): Writable = macroWithArgs("cfg", *attrMacros) + fun doc(vararg attrMacros: Writable): Writable = macroWithArgs("doc", *attrMacros) + fun doc(str: String): Writable = macroWithArgs("doc", writable(str)) + fun not(vararg attrMacros: Writable): Writable = macroWithArgs("not", *attrMacros) fun feature(feature: String) = writable("feature = ${feature.dq()}") + fun featureGate(featureName: String): Attribute { return Attribute(cfg(feature(featureName))) } - fun deprecated(since: String? = null, note: String? = null): Writable { + fun deprecated( + since: String? = null, + note: String? = null, + ): Writable { val optionalFields = mutableListOf() if (!note.isNullOrEmpty()) { optionalFields.add(pair("note" to note.dq())) @@ -596,21 +645,23 @@ class Attribute(val inner: Writable, val isDeriveHelper: Boolean = false) { } } - fun derive(vararg runtimeTypes: RuntimeType): Writable = { - // Empty derives are meaningless - if (runtimeTypes.isNotEmpty()) { - // Sorted derives look nicer than unsorted, and it makes test output easier to predict - val writables = runtimeTypes.sortedBy { it.path }.map { it.writable }.join(", ") - rustInline("derive(#W)", writables) + fun derive(vararg runtimeTypes: RuntimeType): Writable = + { + // Empty derives are meaningless + if (runtimeTypes.isNotEmpty()) { + // Sorted derives look nicer than unsorted, and it makes test output easier to predict + val writables = runtimeTypes.sortedBy { it.path }.map { it.writable }.join(", ") + rustInline("derive(#W)", writables) + } } - } fun derive(runtimeTypes: Collection): Writable = derive(*runtimeTypes.toTypedArray()) - fun pair(pair: Pair): Writable = { - val (key, value) = pair - rustInline("$key = $value") - } + fun pair(pair: Pair): Writable = + { + val (key, value) = pair + rustInline("$key = $value") + } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt index e8bf22ff7c7..05504e54c73 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt @@ -76,7 +76,11 @@ fun > T.withBlock( return conditionalBlock(textBeforeNewLine, textAfterNewLine, conditional = true, block = block, args = args) } -fun > T.assignment(variableName: String, vararg ctx: Pair, block: T.() -> Unit) { +fun > T.assignment( + variableName: String, + vararg ctx: Pair, + block: T.() -> Unit, +) { withBlockTemplate("let $variableName =", ";", *ctx) { block() } @@ -185,26 +189,31 @@ fun RustWriter.rustInline( this.writeInline(contents, *args) } -/* rewrite #{foo} to #{foo:T} (the smithy template format) */ -private fun transformTemplate(template: String, scope: Array>, trim: Boolean = true): String { +// rewrite #{foo} to #{foo:T} (the smithy template format) +private fun transformTemplate( + template: String, + scope: Array>, + trim: Boolean = true, +): String { check( scope.distinctBy { it.first.lowercase() }.size == scope.distinctBy { it.first }.size, ) { "Duplicate cased keys not supported" } - val output = template.replace(Regex("""#\{([a-zA-Z_0-9]+)(:\w)?\}""")) { matchResult -> - val keyName = matchResult.groupValues[1] - val templateType = matchResult.groupValues[2].ifEmpty { ":T" } - if (!scope.toMap().keys.contains(keyName)) { - throw CodegenException( - """ - Rust block template expected `$keyName` but was not present in template. - Hint: Template contains: ${scope.map { "`${it.first}`" }} - """.trimIndent(), - ) + val output = + template.replace(Regex("""#\{([a-zA-Z_0-9]+)(:\w)?\}""")) { matchResult -> + val keyName = matchResult.groupValues[1] + val templateType = matchResult.groupValues[2].ifEmpty { ":T" } + if (!scope.toMap().keys.contains(keyName)) { + throw CodegenException( + """ + Rust block template expected `$keyName` but was not present in template. + Hint: Template contains: ${scope.map { "`${it.first}`" }} + """.trimIndent(), + ) + } + "#{${keyName.lowercase()}$templateType}" } - "#{${keyName.lowercase()}$templateType}" - } return output.letIf(trim) { output.trim() } } @@ -298,13 +307,14 @@ fun > T.docsOrFallback( autoSuppressMissingDocs: Boolean = true, note: String? = null, ): T { - val htmlDocs: (T.() -> Unit)? = when (docString?.isNotBlank()) { - true -> { - { docs(normalizeHtml(escape(docString))) } - } + val htmlDocs: (T.() -> Unit)? = + when (docString?.isNotBlank()) { + true -> { + { docs(normalizeHtml(escape(docString))) } + } - else -> null - } + else -> null + } return docsOrFallback(htmlDocs, autoSuppressMissingDocs, note) } @@ -334,7 +344,11 @@ fun > T.docsOrFallback( * Document the containing entity (e.g. module, crate, etc.) * Instead of prefixing lines with `///` lines are prefixed with `//!` */ -fun RustWriter.containerDocs(text: String, vararg args: Any, trimStart: Boolean = true): RustWriter { +fun RustWriter.containerDocs( + text: String, + vararg args: Any, + trimStart: Boolean = true, +): RustWriter { return docs(text, newlinePrefix = "//! ", args = args, trimStart = trimStart) } @@ -366,13 +380,14 @@ fun > T.docs( this.ensureNewline() pushState() setNewlinePrefix(newlinePrefix) - val cleaned = text.lines() - .joinToString("\n") { - when (trimStart) { - true -> it.trimStart() - else -> it - }.replace("\t", " ") // Rustdoc warns on tabs in documentation - } + val cleaned = + text.lines() + .joinToString("\n") { + when (trimStart) { + true -> it.trimStart() + else -> it + }.replace("\t", " ") // Rustdoc warns on tabs in documentation + } write(cleaned, *args) popState() return this @@ -386,16 +401,20 @@ fun > T.docsTemplate( vararg args: Pair, newlinePrefix: String = "/// ", trimStart: Boolean = false, -): T = withTemplate(text, args, trim = false) { template -> - docs(template, newlinePrefix = newlinePrefix, trimStart = trimStart) -} +): T = + withTemplate(text, args, trim = false) { template -> + docs(template, newlinePrefix = newlinePrefix, trimStart = trimStart) + } /** * Writes a comment into the code * * Equivalent to [docs] but lines are preceded with `// ` instead of `///` */ -fun > T.comment(text: String, vararg args: Any): T { +fun > T.comment( + text: String, + vararg args: Any, +): T { return docs(text, *args, newlinePrefix = "// ") } @@ -441,19 +460,28 @@ private fun Element.changeInto(tagName: String) { } /** Write an `impl` block for the given symbol */ -fun RustWriter.implBlock(symbol: Symbol, block: Writable) { +fun RustWriter.implBlock( + symbol: Symbol, + block: Writable, +) { rustBlock("impl ${symbol.name}") { block() } } /** Write a `#[cfg(feature = "...")]` block for the given feature */ -fun RustWriter.featureGateBlock(feature: String, block: Writable) { +fun RustWriter.featureGateBlock( + feature: String, + block: Writable, +) { featureGatedBlock(feature, block)(this) } /** Write a `#[cfg(feature = "...")]` block for the given feature */ -fun featureGatedBlock(feature: String, block: Writable): Writable { +fun featureGatedBlock( + feature: String, + block: Writable, +): Writable { return writable { rustBlock("##[cfg(feature = ${feature.dq()})]") { block() @@ -461,7 +489,10 @@ fun featureGatedBlock(feature: String, block: Writable): Writable { } } -fun featureGatedBlock(feature: Feature, block: Writable): Writable { +fun featureGatedBlock( + feature: Feature, + block: Writable, +): Writable { return writable { rustBlock("##[cfg(feature = ${feature.name.dq()})]") { block() @@ -477,10 +508,12 @@ fun RustWriter.raw(text: String) = writeInline(escape(text)) /** * [rustTemplate] equivalent for `raw()`. Note: This function won't automatically escape formatter symbols. */ -fun RustWriter.rawTemplate(text: String, vararg args: Pair) = - withTemplate(text, args, trim = false) { templated -> - writeInline(templated) - } +fun RustWriter.rawTemplate( + text: String, + vararg args: Pair, +) = withTemplate(text, args, trim = false) { templated -> + writeInline(templated) +} /** * Rustdoc doesn't support `r#` for raw identifiers. @@ -499,353 +532,391 @@ class RustWriter private constructor( val devDependenciesOnly: Boolean = false, ) : SymbolWriter(UseDeclarations(namespace)) { + companion object { + fun root() = forModule(null) + + fun forModule(module: String?): RustWriter = + if (module == null) { + RustWriter("lib.rs", "crate") + } else { + RustWriter("$module.rs", "crate::$module") + } - companion object { - fun root() = forModule(null) - fun forModule(module: String?): RustWriter = if (module == null) { - RustWriter("lib.rs", "crate") - } else { - RustWriter("$module.rs", "crate::$module") - } + fun factory(debugMode: Boolean): Factory = + Factory { fileName: String, namespace: String -> + when { + fileName.endsWith(".toml") -> RustWriter(fileName, namespace, "#", debugMode = debugMode) + fileName.endsWith(".py") -> RustWriter(fileName, namespace, "#", debugMode = debugMode) + fileName.endsWith(".md") -> rawWriter(fileName, debugMode = debugMode) + fileName == "LICENSE" -> rawWriter(fileName, debugMode = debugMode) + fileName.startsWith("tests/") -> + RustWriter( + fileName, + namespace, + debugMode = debugMode, + devDependenciesOnly = true, + ) + + fileName == "package.json" -> rawWriter(fileName, debugMode = debugMode) + fileName == "stubgen.sh" -> rawWriter(fileName, debugMode = debugMode) + else -> RustWriter(fileName, namespace, debugMode = debugMode) + } + } - fun factory(debugMode: Boolean): Factory = Factory { fileName: String, namespace: String -> - when { - fileName.endsWith(".toml") -> RustWriter(fileName, namespace, "#", debugMode = debugMode) - fileName.endsWith(".py") -> RustWriter(fileName, namespace, "#", debugMode = debugMode) - fileName.endsWith(".md") -> rawWriter(fileName, debugMode = debugMode) - fileName == "LICENSE" -> rawWriter(fileName, debugMode = debugMode) - fileName.startsWith("tests/") -> RustWriter( + fun toml( + fileName: String, + debugMode: Boolean = false, + ): RustWriter = + RustWriter( fileName, - namespace, + namespace = "ignore", + commentCharacter = "#", + printWarning = false, debugMode = debugMode, - devDependenciesOnly = true, ) - fileName == "package.json" -> rawWriter(fileName, debugMode = debugMode) - fileName == "stubgen.sh" -> rawWriter(fileName, debugMode = debugMode) - else -> RustWriter(fileName, namespace, debugMode = debugMode) - } + private fun rawWriter( + fileName: String, + debugMode: Boolean, + ): RustWriter = + RustWriter( + fileName, + namespace = "ignore", + commentCharacter = "ignore", + printWarning = false, + debugMode = debugMode, + ) } - fun toml(fileName: String, debugMode: Boolean = false): RustWriter = - RustWriter( - fileName, - namespace = "ignore", - commentCharacter = "#", - printWarning = false, - debugMode = debugMode, - ) - - private fun rawWriter(fileName: String, debugMode: Boolean): RustWriter = - RustWriter( - fileName, - namespace = "ignore", - commentCharacter = "ignore", - printWarning = false, - debugMode = debugMode, - ) - } + override fun write( + content: Any?, + vararg args: Any?, + ): RustWriter { + // TODO(https://github.com/rust-lang/rustfmt/issues/5425): The second condition introduced here is to prevent + // this rustfmt bug + val contentIsNotJustAComma = (content as? String?)?.let { it.trim() != "," } ?: false + if (debugMode && contentIsNotJustAComma) { + val location = Thread.currentThread().stackTrace + location.first { it.isRelevant() }?.let { "/* ${it.fileName}:${it.lineNumber} */" } + ?.also { super.writeInline(it) } + } - override fun write(content: Any?, vararg args: Any?): RustWriter { - // TODO(https://github.com/rust-lang/rustfmt/issues/5425): The second condition introduced here is to prevent - // this rustfmt bug - val contentIsNotJustAComma = (content as? String?)?.let { it.trim() != "," } ?: false - if (debugMode && contentIsNotJustAComma) { - val location = Thread.currentThread().stackTrace - location.first { it.isRelevant() }?.let { "/* ${it.fileName}:${it.lineNumber} */" } - ?.also { super.writeInline(it) } + return super.write(content, *args) } - return super.write(content, *args) - } - - fun dirty() = super.toString().isNotBlank() || preamble.isNotEmpty() + fun dirty() = super.toString().isNotBlank() || preamble.isNotEmpty() - /** Helper function to determine if a stack frame is relevant for debug purposes */ - private fun StackTraceElement.isRelevant(): Boolean { - if (this.className.contains("AbstractCodeWriter") || this.className.startsWith("java.lang")) { - return false + /** Helper function to determine if a stack frame is relevant for debug purposes */ + private fun StackTraceElement.isRelevant(): Boolean { + if (this.className.contains("AbstractCodeWriter") || this.className.startsWith("java.lang")) { + return false + } + return this.fileName != "RustWriter.kt" } - return this.fileName != "RustWriter.kt" - } - private val preamble = mutableListOf() - private val formatter = RustSymbolFormatter() - private var n = 0 + private val preamble = mutableListOf() + private val formatter = RustSymbolFormatter() + private var n = 0 - init { - expressionStart = '#' - if (filename.endsWith(".rs")) { - require(namespace.startsWith("crate") || filename.startsWith("tests/") || filename == "build.rs") { - "We can only write into files in the crate (got $namespace)" + init { + expressionStart = '#' + if (filename.endsWith(".rs")) { + require(namespace.startsWith("crate") || filename.startsWith("tests/") || filename == "build.rs") { + "We can only write into files in the crate (got $namespace)" + } } + putFormatter('T', formatter) + putFormatter('D', RustDocLinker()) + putFormatter('W', RustWriteableInjector()) } - putFormatter('T', formatter) - putFormatter('D', RustDocLinker()) - putFormatter('W', RustWriteableInjector()) - } - fun module(): String? = if (filename.startsWith("src") && filename.endsWith(".rs")) { - filename.removeSuffix(".rs").substringAfterLast(File.separatorChar) - } else { - null - } - - fun safeName(prefix: String = "var"): String { - n += 1 - return "${prefix}_$n" - } - - fun first(preWriter: Writable) { - preamble.add(preWriter) - } - - private fun addDependencyTestAware(dependencyContainer: SymbolDependencyContainer): RustWriter { - if (!devDependenciesOnly) { - super.addDependency(dependencyContainer) - } else { - dependencyContainer.dependencies.forEach { dependency -> - super.addDependency( - when (val dep = RustDependency.fromSymbolDependency(dependency)) { - is CargoDependency -> dep.toDevDependency() - else -> dependencyContainer - }, - ) + fun module(): String? = + if (filename.startsWith("src") && filename.endsWith(".rs")) { + filename.removeSuffix(".rs").substringAfterLast(File.separatorChar) + } else { + null } - } - return this - } - /** - * Create an inline module. Instead of being in a new file, inline modules are written as a `mod { ... }` block - * directly into the parent. - * - * Callers must take care to use [this] when writing to ensure code is written to the right place: - * ```kotlin - * val writer = RustWriter.forModule("model") - * writer.withInlineModule(RustModule.public("nested")) { - * Generator(...).render(this) // GOOD - * Generator(...).render(writer) // WRONG! - * } - * ``` - * - * The returned writer will inject any local imports into the module as needed. - */ - fun withInlineModule( - module: RustModule.LeafModule, - moduleDocProvider: ModuleDocProvider?, - moduleWriter: Writable, - ): RustWriter { - check(module.isInline()) { - "Only inline modules may be used with `withInlineModule`: $module" + fun safeName(prefix: String = "var"): String { + n += 1 + return "${prefix}_$n" } - // In Rust, modules must specify their own imports—they don't have access to the parent scope. - // To easily handle this, create a new inner writer to collect imports, then dump it - // into an inline module. - val innerWriter = RustWriter( - this.filename, - "${this.namespace}::${module.name}", - printWarning = false, - devDependenciesOnly = devDependenciesOnly || module.tests, - ) - moduleWriter(innerWriter) - ModuleDocProvider.writeDocs(moduleDocProvider, module, this) - module.rustMetadata.render(this) - rustBlock("mod ${module.name}") { - writeWithNoFormatting(innerWriter.toString()) + fun first(preWriter: Writable) { + preamble.add(preWriter) } - innerWriter.dependencies.forEach { addDependencyTestAware(it) } - return this - } - /** - * Generate a wrapping if statement around a nullable value. - * The provided code block will only be called if the value is not `None`. - */ - fun ifSome(member: Symbol, value: ValueExpression, block: RustWriter.(value: ValueExpression) -> Unit) { - when { - member.isOptional() -> { - val innerValue = ValueExpression.Reference(safeName("inner")) - rustBlockTemplate("if let #{Some}(${innerValue.name}) = ${value.asRef()}", *preludeScope) { - block(innerValue) + private fun addDependencyTestAware(dependencyContainer: SymbolDependencyContainer): RustWriter { + if (!devDependenciesOnly) { + super.addDependency(dependencyContainer) + } else { + dependencyContainer.dependencies.forEach { dependency -> + super.addDependency( + when (val dep = RustDependency.fromSymbolDependency(dependency)) { + is CargoDependency -> dep.toDevDependency() + else -> dependencyContainer + }, + ) } } - - else -> this.block(value) + return this } - } - - /** - * Generate a wrapping if statement around a primitive field. - * The specified block will only be called if the field is not set to its default value - `0` for - * numbers, `false` for booleans. - */ - fun ifNotDefault(shape: Shape, variable: ValueExpression, block: RustWriter.(field: ValueExpression) -> Unit) { - when (shape) { - is FloatShape, is DoubleShape -> rustBlock("if ${variable.asValue()} != 0.0") { - block(variable) - } - - is NumberShape -> rustBlock("if ${variable.asValue()} != 0") { - block(variable) - } - is BooleanShape -> rustBlock("if ${variable.asValue()}") { - block(variable) + /** + * Create an inline module. Instead of being in a new file, inline modules are written as a `mod { ... }` block + * directly into the parent. + * + * Callers must take care to use [this] when writing to ensure code is written to the right place: + * ```kotlin + * val writer = RustWriter.forModule("model") + * writer.withInlineModule(RustModule.public("nested")) { + * Generator(...).render(this) // GOOD + * Generator(...).render(writer) // WRONG! + * } + * ``` + * + * The returned writer will inject any local imports into the module as needed. + */ + fun withInlineModule( + module: RustModule.LeafModule, + moduleDocProvider: ModuleDocProvider?, + moduleWriter: Writable, + ): RustWriter { + check(module.isInline()) { + "Only inline modules may be used with `withInlineModule`: $module" } - else -> rustBlock("") { - this.block(variable) + // In Rust, modules must specify their own imports—they don't have access to the parent scope. + // To easily handle this, create a new inner writer to collect imports, then dump it + // into an inline module. + val innerWriter = + RustWriter( + this.filename, + "${this.namespace}::${module.name}", + printWarning = false, + devDependenciesOnly = devDependenciesOnly || module.tests, + ) + moduleWriter(innerWriter) + ModuleDocProvider.writeDocs(moduleDocProvider, module, this) + module.rustMetadata.render(this) + rustBlock("mod ${module.name}") { + writeWithNoFormatting(innerWriter.toString()) } + innerWriter.dependencies.forEach { addDependencyTestAware(it) } + return this } - } - /** - * Generate a wrapping if statement around a field. - * - * - If the field is optional, it will only be called if the field is present - * - If the field is an unboxed primitive, it will only be called if the field is non-zero - * - * # Example - * - * For a nullable structure shape (e.g. `Option`), the following code will be generated: - * - * ``` - * if let Some(v) = my_nullable_struct { - * /* {block(variable)} */ - * } - * ``` - * - * # Example - * - * For a non-nullable integer shape, the following code will be generated: - * - * ``` - * if my_int != 0 { - * /* {block(variable)} */ - * } - * ``` - */ - fun ifSet( - shape: Shape, - member: Symbol, - variable: ValueExpression, - block: RustWriter.(field: ValueExpression) -> Unit, - ) { - ifSome(member, variable) { inner -> ifNotDefault(shape, inner, block) } - } + /** + * Generate a wrapping if statement around a nullable value. + * The provided code block will only be called if the value is not `None`. + */ + fun ifSome( + member: Symbol, + value: ValueExpression, + block: RustWriter.(value: ValueExpression) -> Unit, + ) { + when { + member.isOptional() -> { + val innerValue = ValueExpression.Reference(safeName("inner")) + rustBlockTemplate("if let #{Some}(${innerValue.name}) = ${value.asRef()}", *preludeScope) { + block(innerValue) + } + } - fun listForEach( - target: Shape, - outerField: String, - block: RustWriter.(field: String, target: ShapeId) -> Unit, - ) { - if (target is CollectionShape) { - val derefName = safeName("inner") - rustBlock("for $derefName in $outerField") { - block(derefName, target.member.target) + else -> this.block(value) } - } else { - this.block(outerField, target.toShapeId()) } - } - override fun toString(): String { - val contents = super.toString() - val preheader = if (preamble.isNotEmpty()) { - val prewriter = - RustWriter(filename, namespace, printWarning = false, devDependenciesOnly = devDependenciesOnly) - preamble.forEach { it(prewriter) } - prewriter.toString() - } else { - null + /** + * Generate a wrapping if statement around a primitive field. + * The specified block will only be called if the field is not set to its default value - `0` for + * numbers, `false` for booleans. + */ + fun ifNotDefault( + shape: Shape, + variable: ValueExpression, + block: RustWriter.(field: ValueExpression) -> Unit, + ) { + when (shape) { + is FloatShape, is DoubleShape -> + rustBlock("if ${variable.asValue()} != 0.0") { + block(variable) + } + + is NumberShape -> + rustBlock("if ${variable.asValue()} != 0") { + block(variable) + } + + is BooleanShape -> + rustBlock("if ${variable.asValue()}") { + block(variable) + } + + else -> + rustBlock("") { + this.block(variable) + } + } } - // Hack to support TOML: the [commentCharacter] is overridden to support writing TOML. - val header = if (printWarning) { - "$commentCharacter Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT." - } else { - null - } - val useDecls = importContainer.toString().ifEmpty { - null + /** + * Generate a wrapping if statement around a field. + * + * - If the field is optional, it will only be called if the field is present + * - If the field is an unboxed primitive, it will only be called if the field is non-zero + * + * # Example + * + * For a nullable structure shape (e.g. `Option`), the following code will be generated: + * + * ``` + * if let Some(v) = my_nullable_struct { + * /* {block(variable)} */ + * } + * ``` + * + * # Example + * + * For a non-nullable integer shape, the following code will be generated: + * + * ``` + * if my_int != 0 { + * /* {block(variable)} */ + * } + * ``` + */ + fun ifSet( + shape: Shape, + member: Symbol, + variable: ValueExpression, + block: RustWriter.(field: ValueExpression) -> Unit, + ) { + ifSome(member, variable) { inner -> ifNotDefault(shape, inner, block) } } - return listOfNotNull(preheader, header, useDecls, contents).joinToString(separator = "\n", postfix = "\n") - } - fun format(r: Any) = formatter.apply(r, "") - - fun addDepsRecursively(symbol: Symbol) { - addDependencyTestAware(symbol) - symbol.references.forEach { addDepsRecursively(it.symbol) } - } - - /** - * Generate RustDoc links, e.g. [`Abc`](crate::module::Abc) - */ - inner class RustDocLinker : BiFunction { - override fun apply(t: Any, u: String): String { - return when (t) { - is Symbol -> "[`${t.name}`](${docLink(t.rustType().qualifiedName())})" - else -> throw CodegenException("Invalid type provided to RustDocLinker ($t) expected Symbol") + fun listForEach( + target: Shape, + outerField: String, + block: RustWriter.(field: String, target: ShapeId) -> Unit, + ) { + if (target is CollectionShape) { + val derefName = safeName("inner") + rustBlock("for $derefName in $outerField") { + block(derefName, target.member.target) + } + } else { + this.block(outerField, target.toShapeId()) } } - } - - /** - * Formatter to enable formatting any [writable] with the #W formatter. - */ - inner class RustWriteableInjector : BiFunction { - override fun apply(t: Any, u: String): String { - @Suppress("UNCHECKED_CAST") - val func = - t as? Writable ?: throw CodegenException("RustWriteableInjector.apply choked on non-function t ($t)") - val innerWriter = - RustWriter(filename, namespace, printWarning = false, devDependenciesOnly = devDependenciesOnly) - func(innerWriter) - innerWriter.dependencies.forEach { addDependencyTestAware(it) } - return innerWriter.toString().trimEnd() - } - } - inner class RustSymbolFormatter : BiFunction { - override fun apply(t: Any, u: String): String { - return when (t) { - is RuntimeType -> { - t.dependency?.also { addDependencyTestAware(it) } - // for now, use the fully qualified type name - t.fullyQualifiedName() + override fun toString(): String { + val contents = super.toString() + val preheader = + if (preamble.isNotEmpty()) { + val prewriter = + RustWriter(filename, namespace, printWarning = false, devDependenciesOnly = devDependenciesOnly) + preamble.forEach { it(prewriter) } + prewriter.toString() + } else { + null } - is RustModule -> { - t.fullyQualifiedPath() + // Hack to support TOML: the [commentCharacter] is overridden to support writing TOML. + val header = + if (printWarning) { + "$commentCharacter Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT." + } else { + null } - - is Symbol -> { - addDepsRecursively(t) - t.rustType().render(fullyQualified = true) + val useDecls = + importContainer.toString().ifEmpty { + null } + return listOfNotNull(preheader, header, useDecls, contents).joinToString(separator = "\n", postfix = "\n") + } - is RustType -> { - t.render(fullyQualified = true) - } + fun format(r: Any) = formatter.apply(r, "") - is Function<*> -> { - @Suppress("UNCHECKED_CAST") - val func = - t as? Writable ?: throw CodegenException("Invalid function type (expected writable) ($t)") - val innerWriter = - RustWriter(filename, namespace, printWarning = false, devDependenciesOnly = devDependenciesOnly) - func(innerWriter) - innerWriter.dependencies.forEach { addDependencyTestAware(it) } - return innerWriter.toString().trimEnd() + fun addDepsRecursively(symbol: Symbol) { + addDependencyTestAware(symbol) + symbol.references.forEach { addDepsRecursively(it.symbol) } + } + + /** + * Generate RustDoc links, e.g. [`Abc`](crate::module::Abc) + */ + inner class RustDocLinker : BiFunction { + override fun apply( + t: Any, + u: String, + ): String { + return when (t) { + is Symbol -> "[`${t.name}`](${docLink(t.rustType().qualifiedName())})" + else -> throw CodegenException("Invalid type provided to RustDocLinker ($t) expected Symbol") } + } + } + + /** + * Formatter to enable formatting any [writable] with the #W formatter. + */ + inner class RustWriteableInjector : BiFunction { + override fun apply( + t: Any, + u: String, + ): String { + @Suppress("UNCHECKED_CAST") + val func = + t as? Writable ?: throw CodegenException("RustWriteableInjector.apply choked on non-function t ($t)") + val innerWriter = + RustWriter(filename, namespace, printWarning = false, devDependenciesOnly = devDependenciesOnly) + func(innerWriter) + innerWriter.dependencies.forEach { addDependencyTestAware(it) } + return innerWriter.toString().trimEnd() + } + } - else -> throw CodegenException("Invalid type provided to RustSymbolFormatter: $t") - // escaping generates `##` sequences for all the common cases where - // it will be run through templating, but in this context, we won't be escaped - }.replace("##", "#") + inner class RustSymbolFormatter : BiFunction { + override fun apply( + t: Any, + u: String, + ): String { + return when (t) { + is RuntimeType -> { + t.dependency?.also { addDependencyTestAware(it) } + // for now, use the fully qualified type name + t.fullyQualifiedName() + } + + is RustModule -> { + t.fullyQualifiedPath() + } + + is Symbol -> { + addDepsRecursively(t) + t.rustType().render(fullyQualified = true) + } + + is RustType -> { + t.render(fullyQualified = true) + } + + is Function<*> -> { + @Suppress("UNCHECKED_CAST") + val func = + t as? Writable ?: throw CodegenException("Invalid function type (expected writable) ($t)") + val innerWriter = + RustWriter(filename, namespace, printWarning = false, devDependenciesOnly = devDependenciesOnly) + func(innerWriter) + innerWriter.dependencies.forEach { addDependencyTestAware(it) } + return innerWriter.toString().trimEnd() + } + + else -> throw CodegenException("Invalid type provided to RustSymbolFormatter: $t") + // escaping generates `##` sequences for all the common cases where + // it will be run through templating, but in this context, we won't be escaped + }.replace("##", "#") + } } } -} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/UseDeclarations.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/UseDeclarations.kt index 4e409bf126e..bbfc8696ba3 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/UseDeclarations.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/UseDeclarations.kt @@ -10,7 +10,12 @@ import software.amazon.smithy.codegen.core.Symbol class UseDeclarations(private val namespace: String) : ImportContainer { private val imports: MutableSet = mutableSetOf() - fun addImport(moduleName: String, symbolName: String, alias: String = symbolName) { + + fun addImport( + moduleName: String, + symbolName: String, + alias: String = symbolName, + ) { imports.add(UseStatement(moduleName, symbolName, alias)) } @@ -18,7 +23,10 @@ class UseDeclarations(private val namespace: String) : ImportContainer { return imports.map { it.toString() }.sorted().joinToString(separator = "\n") } - override fun importSymbol(symbol: Symbol, alias: String?) { + override fun importSymbol( + symbol: Symbol, + alias: String?, + ) { if (symbol.namespace.isNotEmpty() && symbol.namespace != namespace) { addImport(symbol.namespace, symbol.name, alias ?: symbol.name) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt index 8ef68e9ef25..ad4b07c2308 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt @@ -105,31 +105,33 @@ fun Array.join(separator: Writable) = asIterable().join(separator) * some_fn::(); * ``` */ -fun rustTypeParameters( - vararg typeParameters: Any, -): Writable = writable { - if (typeParameters.isNotEmpty()) { - val items = typeParameters.map { typeParameter -> - writable { - when (typeParameter) { - is Symbol, is RuntimeType, is RustType -> rustInlineTemplate("#{it}", "it" to typeParameter) - is String -> rustInlineTemplate(typeParameter) - is RustGenerics -> rustInlineTemplate( - "#{gg:W}", - "gg" to typeParameter.declaration(withAngleBrackets = false), - ) - - else -> { - // Check if it's a writer. If it is, invoke it; Else, throw a codegen error. - @Suppress("UNCHECKED_CAST") - val func = typeParameter as? Writable - ?: throw CodegenException("Unhandled type '$typeParameter' encountered by rustTypeParameters writer") - func.invoke(this) +fun rustTypeParameters(vararg typeParameters: Any): Writable = + writable { + if (typeParameters.isNotEmpty()) { + val items = + typeParameters.map { typeParameter -> + writable { + when (typeParameter) { + is Symbol, is RuntimeType, is RustType -> rustInlineTemplate("#{it}", "it" to typeParameter) + is String -> rustInlineTemplate(typeParameter) + is RustGenerics -> + rustInlineTemplate( + "#{gg:W}", + "gg" to typeParameter.declaration(withAngleBrackets = false), + ) + + else -> { + // Check if it's a writer. If it is, invoke it; Else, throw a codegen error. + @Suppress("UNCHECKED_CAST") + val func = + typeParameter as? Writable + ?: throw CodegenException("Unhandled type '$typeParameter' encountered by rustTypeParameters writer") + func.invoke(this) + } + } } } - } - } - rustInlineTemplate("<#{Items:W}>", "Items" to items.join(", ")) + rustInlineTemplate("<#{Items:W}>", "Items" to items.join(", ")) + } } -} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenContext.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenContext.kt index a08eddb5ed5..eba03b6f71d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenContext.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenContext.kt @@ -27,32 +27,26 @@ abstract class CodegenContext( * an entry point. */ open val model: Model, - /** * The "canonical" symbol provider to convert Smithy [Shape]s into [Symbol]s, which have an associated [RustType]. */ open val symbolProvider: RustSymbolProvider, - /** * Provider of documentation for generated Rust modules. */ open val moduleDocProvider: ModuleDocProvider?, - /** * Entrypoint service shape for code generation. */ open val serviceShape: ServiceShape, - /** * Shape indicating the protocol to generate, e.g. RestJson1. */ open val protocol: ShapeId, - /** * Settings loaded from `smithy-build.json`. */ open val settings: CoreRustSettings, - /** * Are we generating code for a smithy-rs client or server? * @@ -63,22 +57,23 @@ abstract class CodegenContext( */ open val target: CodegenTarget, ) { - /** * Configuration of the runtime package: * - Where are the runtime crates (smithy-*) located on the file system? Or are they versioned? * - What are they called? + * + * This is just a convenience. To avoid typing `context.settings.runtimeConfig`, you can simply write + * `context.runtimeConfig`. */ - // This is just a convenience. To avoid typing `context.settings.runtimeConfig`, you can simply write - // `context.runtimeConfig`. val runtimeConfig: RuntimeConfig by lazy { settings.runtimeConfig } /** * The name of the cargo crate to generate e.g. `aws-sdk-s3` * This is loaded from the smithy-build.json during codegen. + * + * This is just a convenience. To avoid typing `context.settings.moduleName`, you can simply write + * `context.moduleName`. */ - // This is just a convenience. To avoid typing `context.settings.moduleName`, you can simply write - // `context.moduleName`. val moduleName: String by lazy { settings.moduleName } /** @@ -88,9 +83,10 @@ abstract class CodegenContext( fun moduleUseName() = moduleName.replace("-", "_") /** Return a ModuleDocProvider or panic if one wasn't configured */ - fun expectModuleDocProvider(): ModuleDocProvider = checkNotNull(moduleDocProvider) { - "A ModuleDocProvider must be set on the CodegenContext" - } + fun expectModuleDocProvider(): ModuleDocProvider = + checkNotNull(moduleDocProvider) { + "A ModuleDocProvider must be set on the CodegenContext" + } fun structSettings() = StructSettings(settings.codegenConfig.flattenCollectionAccessors) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt index d52359d6c40..d494ea0caff 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt @@ -31,7 +31,11 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.ManifestCustom /** Provider of documentation for generated Rust modules */ interface ModuleDocProvider { companion object { - fun writeDocs(provider: ModuleDocProvider?, module: RustModule.LeafModule, writer: RustWriter) { + fun writeDocs( + provider: ModuleDocProvider?, + module: RustModule.LeafModule, + writer: RustWriter, + ) { check( provider != null || module.documentationOverride != null || @@ -90,7 +94,10 @@ open class RustCrate( /** * Write into the module that this shape is [locatedIn] */ - fun useShapeWriter(shape: Shape, f: Writable) { + fun useShapeWriter( + shape: Shape, + f: Writable, + ) { val module = symbolProvider.toSymbol(shape).module() check(!module.isInline()) { "Cannot use useShapeWriter with inline modules—use [RustWriter.withInlineModule] instead" @@ -191,9 +198,10 @@ open class RustCrate( } else { // Create a dependency which adds the mod statement for this module. This will be added to the writer // so that _usage_ of this module will generate _exactly one_ `mod ` with the correct modifiers. - val modStatement = RuntimeType.forInlineFun("mod_" + module.fullyQualifiedPath(), module.parent) { - module.renderModStatement(this, moduleDocProvider) - } + val modStatement = + RuntimeType.forInlineFun("mod_" + module.fullyQualifiedPath(), module.parent) { + module.renderModStatement(this, moduleDocProvider) + } val path = module.fullyQualifiedPath().split("::").drop(1).joinToString("/") inner.useFileWriter("src/$path.rs", module.fullyQualifiedPath()) { writer -> moduleWriter(writer) @@ -208,13 +216,18 @@ open class RustCrate( /** * Returns the module for a given Shape. */ - fun moduleFor(shape: Shape, moduleWriter: Writable): RustCrate = - withModule((symbolProvider as RustSymbolProvider).moduleForShape(shape), moduleWriter) + fun moduleFor( + shape: Shape, + moduleWriter: Writable, + ): RustCrate = withModule((symbolProvider as RustSymbolProvider).moduleForShape(shape), moduleWriter) /** * Create a new file directly */ - fun withFile(filename: String, fileWriter: Writable) { + fun withFile( + filename: String, + fileWriter: Writable, + ) { inner.useFileWriter(filename) { fileWriter(it) } @@ -227,7 +240,11 @@ open class RustCrate( * @param symbol: The symbol of the thing being rendered, which will be re-exported. This symbol * should be the public-facing symbol rather than the private symbol. */ - fun inPrivateModuleWithReexport(privateModule: RustModule.LeafModule, symbol: Symbol, writer: Writable) { + fun inPrivateModuleWithReexport( + privateModule: RustModule.LeafModule, + symbol: Symbol, + writer: Writable, + ) { withModule(privateModule, writer) privateModule.toType().resolve(symbol.name).toSymbol().also { privateSymbol -> withModule(symbol.module()) { @@ -264,13 +281,14 @@ fun WriterDelegator.finalize( .mergeDependencyFeatures() .mergeIdenticalTestDependencies() this.useFileWriter("Cargo.toml") { - val cargoToml = CargoTomlGenerator( - settings, - it, - manifestCustomizations, - cargoDependencies, - features, - ) + val cargoToml = + CargoTomlGenerator( + settings, + it, + manifestCustomizations, + cargoDependencies, + features, + ) cargoToml.render() } flushWriters() diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenTarget.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenTarget.kt index b1dba3e33f9..fff8f264a27 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenTarget.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenTarget.kt @@ -9,23 +9,27 @@ package software.amazon.smithy.rust.codegen.core.smithy * Code generation mode: In some situations, codegen has different behavior for client vs. server (eg. required fields) */ enum class CodegenTarget { - CLIENT, SERVER; + CLIENT, + SERVER, + ; /** * Convenience method to execute thunk if the target is for CLIENT */ - fun ifClient(thunk: () -> B): B? = if (this == CLIENT) { - thunk() - } else { - null - } + fun ifClient(thunk: () -> B): B? = + if (this == CLIENT) { + thunk() + } else { + null + } /** * Convenience method to execute thunk if the target is for SERVER */ - fun ifServer(thunk: () -> B): B? = if (this == SERVER) { - thunk() - } else { - null - } + fun ifServer(thunk: () -> B): B? = + if (this == SERVER) { + thunk() + } else { + null + } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CoreRustSettings.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CoreRustSettings.kt index 5e103f14625..18c9fa5d7a6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CoreRustSettings.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CoreRustSettings.kt @@ -79,7 +79,6 @@ open class CoreRustSettings( open val moduleAuthors: List, open val moduleDescription: String?, open val moduleRepository: String?, - /** * Configuration of the runtime package: * - Where are the runtime crates (smithy-*) located on the file system? Or are they versioned? @@ -91,7 +90,6 @@ open class CoreRustSettings( open val examplesUri: String? = null, open val customizationConfig: ObjectNode? = null, ) { - /** * Get the corresponding [ServiceShape] from a model. * @return Returns the found `Service` @@ -111,10 +109,11 @@ open class CoreRustSettings( // Infer the service to generate from a model. @JvmStatic protected fun inferService(model: Model): ShapeId { - val services = model.shapes(ServiceShape::class.java) - .map(Shape::getId) - .sorted() - .toList() + val services = + model.shapes(ServiceShape::class.java) + .map(Shape::getId) + .sorted() + .toList() when { services.isEmpty() -> { @@ -144,7 +143,10 @@ open class CoreRustSettings( * @param config Config object to load * @return Returns the extracted settings */ - fun from(model: Model, config: ObjectNode): CoreRustSettings { + fun from( + model: Model, + config: ObjectNode, + ): CoreRustSettings { val codegenSettings = config.getObjectMember(CODEGEN_SETTINGS) val coreCodegenConfig = CoreCodegenConfig.fromNode(codegenSettings) return fromCodegenConfig(model, config, coreCodegenConfig) @@ -158,7 +160,11 @@ open class CoreRustSettings( * @param coreCodegenConfig CodegenConfig object to use * @return Returns the extracted settings */ - private fun fromCodegenConfig(model: Model, config: ObjectNode, coreCodegenConfig: CoreCodegenConfig): CoreRustSettings { + private fun fromCodegenConfig( + model: Model, + config: ObjectNode, + coreCodegenConfig: CoreCodegenConfig, + ): CoreRustSettings { config.warnIfAdditionalProperties( arrayListOf( SERVICE, @@ -175,9 +181,10 @@ open class CoreRustSettings( ), ) - val service = config.getStringMember(SERVICE) - .map(StringNode::expectShapeId) - .orElseGet { inferService(model) } + val service = + config.getStringMember(SERVICE) + .map(StringNode::expectShapeId) + .orElseGet { inferService(model) } val runtimeConfig = config.getObjectMember(RUNTIME_CONFIG) return CoreRustSettings( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/DirectedWalker.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/DirectedWalker.kt index f48b996045e..5ce1a3bed7a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/DirectedWalker.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/DirectedWalker.kt @@ -21,6 +21,9 @@ class DirectedWalker(model: Model) { fun walkShapes(shape: Shape): Set = walkShapes(shape) { true } - fun walkShapes(shape: Shape, predicate: Predicate): Set = + fun walkShapes( + shape: Shape, + predicate: Predicate, + ): Set = inner.walkShapes(shape) { rel -> predicate.test(rel) && rel.direction == RelationshipDirection.DIRECTED } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/EventStreamSymbolProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/EventStreamSymbolProvider.kt index 97ef843fe7a..ee19295897f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/EventStreamSymbolProvider.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/EventStreamSymbolProvider.kt @@ -34,33 +34,38 @@ class EventStreamSymbolProvider( // We only want to wrap with Event Stream types when dealing with member shapes if (shape is MemberShape && shape.isEventStream(model)) { // Determine if the member has a container that is a synthetic input or output - val operationShape = model.expectShape(shape.container).let { maybeInputOutput -> - val operationId = maybeInputOutput.getTrait()?.operation - ?: maybeInputOutput.getTrait()?.operation - operationId?.let { model.expectShape(it, OperationShape::class.java) } - } + val operationShape = + model.expectShape(shape.container).let { maybeInputOutput -> + val operationId = + maybeInputOutput.getTrait()?.operation + ?: maybeInputOutput.getTrait()?.operation + operationId?.let { model.expectShape(it, OperationShape::class.java) } + } // If we find an operation shape, then we can wrap the type if (operationShape != null) { val unionShape = model.expectShape(shape.target).asUnionShape().get() - val error = if (target == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { - RuntimeType.smithyHttp(runtimeConfig).resolve("event_stream::MessageStreamError").toSymbol() - } else { - symbolForEventStreamError(unionShape) - } + val error = + if (target == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { + RuntimeType.smithyHttp(runtimeConfig).resolve("event_stream::MessageStreamError").toSymbol() + } else { + symbolForEventStreamError(unionShape) + } val errorT = error.rustType() val innerT = initial.rustType().stripOuter() - val isSender = (shape.isInputEventStream(model) && target == CodegenTarget.CLIENT) || - (shape.isOutputEventStream(model) && target == CodegenTarget.SERVER) - val outer = when (isSender) { - true -> RuntimeType.eventStreamSender(runtimeConfig).toSymbol().rustType() - else -> { - if (target == CodegenTarget.SERVER) { - RuntimeType.eventStreamReceiver(runtimeConfig).toSymbol().rustType() - } else { - RuntimeType.eventReceiver(runtimeConfig).toSymbol().rustType() + val isSender = + (shape.isInputEventStream(model) && target == CodegenTarget.CLIENT) || + (shape.isOutputEventStream(model) && target == CodegenTarget.SERVER) + val outer = + when (isSender) { + true -> RuntimeType.eventStreamSender(runtimeConfig).toSymbol().rustType() + else -> { + if (target == CodegenTarget.SERVER) { + RuntimeType.eventStreamReceiver(runtimeConfig).toSymbol().rustType() + } else { + RuntimeType.eventReceiver(runtimeConfig).toSymbol().rustType() + } } } - } val rustType = RustType.Application(outer, listOf(innerT, errorT)) return initial.toBuilder() .name(rustType.name) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index 6235af28447..760941498b8 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -36,14 +36,15 @@ private const val DEFAULT_KEY = "DEFAULT" */ data class RuntimeCrateLocation(val path: String?, val versions: CrateVersionMap) { companion object { - fun Path(path: String) = RuntimeCrateLocation(path, CrateVersionMap(emptyMap())) + fun path(path: String) = RuntimeCrateLocation(path, CrateVersionMap(emptyMap())) } } fun RuntimeCrateLocation.crateLocation(crateName: String): DependencyLocation { - val version = crateName.let { - versions.map[crateName] - } ?: Version.crateVersion(crateName) + val version = + crateName.let { + versions.map[crateName] + } ?: Version.crateVersion(crateName) return when (this.path) { // CratesIo needs an exact version. However, for local runtime crates we do not // provide a detected version unless the user explicitly sets one via the `versions` map. @@ -73,10 +74,9 @@ value class CrateVersionMap( */ data class RuntimeConfig( val cratePrefix: String = "aws", - val runtimeCrateLocation: RuntimeCrateLocation = RuntimeCrateLocation.Path("../"), + val runtimeCrateLocation: RuntimeCrateLocation = RuntimeCrateLocation.path("../"), ) { companion object { - /** * Load a `RuntimeConfig` from an [ObjectNode] (JSON) */ @@ -144,12 +144,13 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) /** * Get a writable for this `RuntimeType` */ - val writable = writable { - rustInlineTemplate( - "#{this:T}", - "this" to this@RuntimeType, - ) - } + val writable = + writable { + rustInlineTemplate( + "#{this:T}", + "this" to this@RuntimeType, + ) + } /** * Convert this [RuntimeType] into a [Symbol]. @@ -158,11 +159,12 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) * (e.g. when bringing a trait into scope). See [CodegenWriter.addUseImports]. */ fun toSymbol(): Symbol { - val builder = Symbol - .builder() - .name(name) - .namespace(namespace, "::") - .rustType(RustType.Opaque(name, namespace)) + val builder = + Symbol + .builder() + .name(name) + .namespace(namespace, "::") + .rustType(RustType.Opaque(name, namespace)) dependency?.run { builder.addDependency(this) } return builder.build() @@ -240,7 +242,6 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) "String" to String, "ToString" to std.resolve("string::ToString"), "Vec" to Vec, - // 2021 Edition "TryFrom" to std.resolve("convert::TryFrom"), "TryInto" to std.resolve("convert::TryInto"), @@ -315,19 +316,28 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) // smithy runtime types fun smithyAsync(runtimeConfig: RuntimeConfig) = CargoDependency.smithyAsync(runtimeConfig).toType() + fun smithyChecksums(runtimeConfig: RuntimeConfig) = CargoDependency.smithyChecksums(runtimeConfig).toType() fun smithyEventStream(runtimeConfig: RuntimeConfig) = CargoDependency.smithyEventStream(runtimeConfig).toType() + fun smithyHttp(runtimeConfig: RuntimeConfig) = CargoDependency.smithyHttp(runtimeConfig).toType() + fun smithyJson(runtimeConfig: RuntimeConfig) = CargoDependency.smithyJson(runtimeConfig).toType() + fun smithyQuery(runtimeConfig: RuntimeConfig) = CargoDependency.smithyQuery(runtimeConfig).toType() + fun smithyRuntime(runtimeConfig: RuntimeConfig) = CargoDependency.smithyRuntime(runtimeConfig).toType() + fun smithyRuntimeApi(runtimeConfig: RuntimeConfig) = CargoDependency.smithyRuntimeApi(runtimeConfig).toType() + fun smithyRuntimeApiClient(runtimeConfig: RuntimeConfig) = CargoDependency.smithyRuntimeApiClient(runtimeConfig).toType() fun smithyTypes(runtimeConfig: RuntimeConfig) = CargoDependency.smithyTypes(runtimeConfig).toType() + fun smithyXml(runtimeConfig: RuntimeConfig) = CargoDependency.smithyXml(runtimeConfig).toType() + private fun smithyProtocolTest(runtimeConfig: RuntimeConfig) = CargoDependency.smithyProtocolTestHelpers(runtimeConfig).toType() @@ -402,11 +412,17 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) smithyRuntimeApi(runtimeConfig).resolve("http::Headers") fun blob(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("Blob") + fun byteStream(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("byte_stream::ByteStream") + fun dateTime(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("DateTime") + fun document(runtimeConfig: RuntimeConfig): RuntimeType = smithyTypes(runtimeConfig).resolve("Document") + fun format(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("date_time::Format") + fun retryErrorKind(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("retry::ErrorKind") + fun eventStreamReceiver(runtimeConfig: RuntimeConfig): RuntimeType = smithyHttp(runtimeConfig).resolve("event_stream::Receiver") @@ -420,6 +436,7 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) smithyHttp(runtimeConfig).resolve("futures_stream_adapter::FuturesStreamCompatByteStream") fun errorMetadata(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("error::ErrorMetadata") + fun errorMetadataBuilder(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("error::metadata::Builder") @@ -427,6 +444,7 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) smithyTypes(runtimeConfig).resolve("error::metadata::ProvideErrorMetadata") fun jsonErrors(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.jsonErrors(runtimeConfig)) + fun awsQueryCompatibleErrors(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.awsQueryCompatibleErrors(runtimeConfig)) @@ -434,17 +452,28 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) RuntimeType.forInlineDependency(InlineDependency.defaultAuthPlugin(runtimeConfig)) .resolve("DefaultAuthOptionsPlugin") - fun labelFormat(runtimeConfig: RuntimeConfig, func: String) = smithyHttp(runtimeConfig).resolve("label::$func") + fun labelFormat( + runtimeConfig: RuntimeConfig, + func: String, + ) = smithyHttp(runtimeConfig).resolve("label::$func") + fun operation(runtimeConfig: RuntimeConfig) = smithyHttp(runtimeConfig).resolve("operation::Operation") + fun operationModule(runtimeConfig: RuntimeConfig) = smithyHttp(runtimeConfig).resolve("operation") - fun protocolTest(runtimeConfig: RuntimeConfig, func: String): RuntimeType = - smithyProtocolTest(runtimeConfig).resolve(func) + fun protocolTest( + runtimeConfig: RuntimeConfig, + func: String, + ): RuntimeType = smithyProtocolTest(runtimeConfig).resolve(func) fun provideErrorKind(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("retry::ProvideErrorKind") - fun queryFormat(runtimeConfig: RuntimeConfig, func: String) = smithyHttp(runtimeConfig).resolve("query::$func") + fun queryFormat( + runtimeConfig: RuntimeConfig, + func: String, + ) = smithyHttp(runtimeConfig).resolve("query::$func") + fun sdkBody(runtimeConfig: RuntimeConfig): RuntimeType = smithyTypes(runtimeConfig).resolve("body::SdkBody") fun parseTimestampFormat( @@ -452,13 +481,14 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) runtimeConfig: RuntimeConfig, format: TimestampFormatTrait.Format, ): RuntimeType { - val timestampFormat = when (format) { - TimestampFormatTrait.Format.EPOCH_SECONDS -> "EpochSeconds" - // clients allow offsets, servers do nt - TimestampFormatTrait.Format.DATE_TIME -> codegenTarget.ifClient { "DateTimeWithOffset" } ?: "DateTime" - TimestampFormatTrait.Format.HTTP_DATE -> "HttpDate" - TimestampFormatTrait.Format.UNKNOWN -> TODO() - } + val timestampFormat = + when (format) { + TimestampFormatTrait.Format.EPOCH_SECONDS -> "EpochSeconds" + // clients allow offsets, servers do nt + TimestampFormatTrait.Format.DATE_TIME -> codegenTarget.ifClient { "DateTimeWithOffset" } ?: "DateTime" + TimestampFormatTrait.Format.HTTP_DATE -> "HttpDate" + TimestampFormatTrait.Format.UNKNOWN -> TODO() + } return smithyTypes(runtimeConfig).resolve("date_time::Format::$timestampFormat") } @@ -467,24 +497,30 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) runtimeConfig: RuntimeConfig, format: TimestampFormatTrait.Format, ): RuntimeType { - val timestampFormat = when (format) { - TimestampFormatTrait.Format.EPOCH_SECONDS -> "EpochSeconds" - // clients allow offsets, servers do not - TimestampFormatTrait.Format.DATE_TIME -> "DateTime" - TimestampFormatTrait.Format.HTTP_DATE -> "HttpDate" - TimestampFormatTrait.Format.UNKNOWN -> TODO() - } + val timestampFormat = + when (format) { + TimestampFormatTrait.Format.EPOCH_SECONDS -> "EpochSeconds" + // clients allow offsets, servers do not + TimestampFormatTrait.Format.DATE_TIME -> "DateTime" + TimestampFormatTrait.Format.HTTP_DATE -> "HttpDate" + TimestampFormatTrait.Format.UNKNOWN -> TODO() + } return smithyTypes(runtimeConfig).resolve("date_time::Format::$timestampFormat") } - fun captureRequest(runtimeConfig: RuntimeConfig) = CargoDependency.smithyRuntimeTestUtil(runtimeConfig).toType() - .resolve("client::http::test_util::capture_request") + fun captureRequest(runtimeConfig: RuntimeConfig) = + CargoDependency.smithyRuntimeTestUtil(runtimeConfig).toType() + .resolve("client::http::test_util::capture_request") fun forInlineDependency(inlineDependency: InlineDependency) = RuntimeType("crate::${inlineDependency.name}", inlineDependency) - fun forInlineFun(name: String, module: RustModule, func: Writable) = RuntimeType( + fun forInlineFun( + name: String, + module: RustModule, + func: Writable, + ) = RuntimeType( "${module.fullyQualifiedPath()}::$name", dependency = InlineDependency(name, module, listOf(), func), ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RustSymbolProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RustSymbolProvider.kt index 2314007aa93..78ab4bfebcf 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RustSymbolProvider.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RustSymbolProvider.kt @@ -26,10 +26,13 @@ interface RustSymbolProvider : SymbolProvider { fun moduleForShape(shape: Shape): RustModule.LeafModule = config.moduleProvider.moduleForShape(moduleProviderContext, shape) + fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule = config.moduleProvider.moduleForOperationError(moduleProviderContext, operation) + fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule = config.moduleProvider.moduleForEventStreamError(moduleProviderContext, eventStream) + fun moduleForBuilder(shape: Shape): RustModule.LeafModule = config.moduleProvider.moduleForBuilder(moduleProviderContext, shape, toSymbol(shape)) @@ -61,16 +64,29 @@ fun CodegenContext.toModuleProviderContext(): ModuleProviderContext = */ interface ModuleProvider { /** Returns the module for a shape */ - fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule + fun moduleForShape( + context: ModuleProviderContext, + shape: Shape, + ): RustModule.LeafModule /** Returns the module for an operation error */ - fun moduleForOperationError(context: ModuleProviderContext, operation: OperationShape): RustModule.LeafModule + fun moduleForOperationError( + context: ModuleProviderContext, + operation: OperationShape, + ): RustModule.LeafModule /** Returns the module for an event stream error */ - fun moduleForEventStreamError(context: ModuleProviderContext, eventStream: UnionShape): RustModule.LeafModule + fun moduleForEventStreamError( + context: ModuleProviderContext, + eventStream: UnionShape, + ): RustModule.LeafModule /** Returns the module for a builder */ - fun moduleForBuilder(context: ModuleProviderContext, shape: Shape, symbol: Symbol): RustModule.LeafModule + fun moduleForBuilder( + context: ModuleProviderContext, + shape: Shape, + symbol: Symbol, + ): RustModule.LeafModule } /** @@ -93,9 +109,13 @@ open class WrappingSymbolProvider(private val base: RustSymbolProvider) : RustSy override val config: RustSymbolProviderConfig get() = base.config override fun toSymbol(shape: Shape): Symbol = base.toSymbol(shape) + override fun toMemberName(shape: MemberShape): String = base.toMemberName(shape) + override fun symbolForOperationError(operation: OperationShape): Symbol = base.symbolForOperationError(operation) + override fun symbolForEventStreamError(eventStream: UnionShape): Symbol = base.symbolForEventStreamError(eventStream) + override fun symbolForBuilder(shape: Shape): Symbol = base.symbolForBuilder(shape) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt index 051f3c3d1ee..646fb3d5ca9 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt @@ -77,11 +77,16 @@ class StreamingShapeMetadataProvider(private val base: RustSymbolProvider) : Sym } override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() + override fun enumMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata() override fun listMeta(listShape: ListShape) = base.toSymbol(listShape).expectRustMetadata() + override fun mapMeta(mapShape: MapShape) = base.toSymbol(mapShape).expectRustMetadata() + override fun stringMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata() + override fun numberMeta(numberShape: NumberShape) = base.toSymbol(numberShape).expectRustMetadata() + override fun blobMeta(blobShape: BlobShape) = base.toSymbol(blobShape).expectRustMetadata() } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt index 35a278a3182..78b348e6c2e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt @@ -118,10 +118,11 @@ fun Symbol.canUseDefault(): Boolean = this.defaultValue() != Default.NoDefault /** * True when [this] is will be represented by Option in Rust */ -fun Symbol.isOptional(): Boolean = when (this.rustType()) { - is RustType.Option -> true - else -> false -} +fun Symbol.isOptional(): Boolean = + when (this.rustType()) { + is RustType.Option -> true + else -> false + } fun Symbol.isRustBoxed(): Boolean = rustType().stripOuter() is RustType.Box @@ -133,12 +134,21 @@ private const val SYMBOL_DEFAULT = "symboldefault" // Symbols should _always_ be created with a Rust type & shape attached fun Symbol.rustType(): RustType = this.expectProperty(RUST_TYPE_KEY, RustType::class.java) + fun Symbol.Builder.rustType(rustType: RustType): Symbol.Builder = this.putProperty(RUST_TYPE_KEY, rustType) + fun Symbol.shape(): Shape = this.expectProperty(SHAPE_KEY, Shape::class.java) + fun Symbol.Builder.shape(shape: Shape?): Symbol.Builder = this.putProperty(SHAPE_KEY, shape) + fun Symbol.module(): RustModule.LeafModule = this.expectProperty(RUST_MODULE_KEY, RustModule.LeafModule::class.java) + fun Symbol.Builder.module(module: RustModule.LeafModule): Symbol.Builder = this.putProperty(RUST_MODULE_KEY, module) + fun Symbol.renamedFrom(): String? = this.getProperty(RENAMED_FROM_KEY, String::class.java).orNull() + fun Symbol.Builder.renamedFrom(name: String): Symbol.Builder = this.putProperty(RENAMED_FROM_KEY, name) + fun Symbol.defaultValue(): Default = this.getProperty(SYMBOL_DEFAULT, Default::class.java).orElse(Default.NoDefault) + fun Symbol.Builder.setDefault(default: Default): Symbol.Builder = this.putProperty(SYMBOL_DEFAULT, default) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt index 6476dd538e3..b69f56b37b6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt @@ -33,34 +33,43 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait abstract class SymbolMetadataProvider(private val base: RustSymbolProvider) : WrappingSymbolProvider(base) { override fun toSymbol(shape: Shape): Symbol { val baseSymbol = base.toSymbol(shape) - val meta = when (shape) { - is MemberShape -> memberMeta(shape) - is StructureShape -> structureMeta(shape) - is UnionShape -> unionMeta(shape) - is ListShape -> listMeta(shape) - is MapShape -> mapMeta(shape) - is NumberShape -> numberMeta(shape) - is BlobShape -> blobMeta(shape) - is StringShape -> if (shape.hasTrait()) { - enumMeta(shape) - } else { - stringMeta(shape) + val meta = + when (shape) { + is MemberShape -> memberMeta(shape) + is StructureShape -> structureMeta(shape) + is UnionShape -> unionMeta(shape) + is ListShape -> listMeta(shape) + is MapShape -> mapMeta(shape) + is NumberShape -> numberMeta(shape) + is BlobShape -> blobMeta(shape) + is StringShape -> + if (shape.hasTrait()) { + enumMeta(shape) + } else { + stringMeta(shape) + } + + else -> null } - - else -> null - } return baseSymbol.toBuilder().meta(meta).build() } abstract fun memberMeta(memberShape: MemberShape): RustMetadata + abstract fun structureMeta(structureShape: StructureShape): RustMetadata + abstract fun unionMeta(unionShape: UnionShape): RustMetadata + abstract fun enumMeta(stringShape: StringShape): RustMetadata abstract fun listMeta(listShape: ListShape): RustMetadata + abstract fun mapMeta(mapShape: MapShape): RustMetadata + abstract fun stringMeta(stringShape: StringShape): RustMetadata + abstract fun numberMeta(numberShape: NumberShape): RustMetadata + abstract fun blobMeta(blobShape: BlobShape): RustMetadata } @@ -71,12 +80,13 @@ fun containerDefaultMetadata( ): RustMetadata { val derives = mutableSetOf(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Clone) - val isSensitive = shape.hasTrait() || - // Checking the shape's direct members for the sensitive trait should suffice. - // Whether their descendants, i.e. a member's member, is sensitive does not - // affect the inclusion/exclusion of the derived `Debug` trait of _this_ container - // shape; any sensitive descendant should still be printed as redacted. - shape.members().any { it.getMemberTrait(model, SensitiveTrait::class.java).isPresent } + val isSensitive = + shape.hasTrait() || + // Checking the shape's direct members for the sensitive trait should suffice. + // Whether their descendants, i.e. a member's member, is sensitive does not + // affect the inclusion/exclusion of the derived `Debug` trait of _this_ container + // shape; any sensitive descendant should still be printed as redacted. + shape.members().any { it.getMemberTrait(model, SensitiveTrait::class.java).isPresent } if (isSensitive) { derives.remove(RuntimeType.Debug) @@ -95,7 +105,6 @@ class BaseSymbolMetadataProvider( base: RustSymbolProvider, private val additionalAttributes: List, ) : SymbolMetadataProvider(base) { - override fun memberMeta(memberShape: MemberShape): RustMetadata = when (val container = model.expectShape(memberShape.container)) { is StructureShape -> RustMetadata(visibility = Visibility.PUBLIC) @@ -109,7 +118,9 @@ class BaseSymbolMetadataProvider( else -> TODO("Unrecognized container type: $container") } - override fun structureMeta(structureShape: StructureShape) = containerDefaultMetadata(structureShape, model, additionalAttributes) + override fun structureMeta(structureShape: StructureShape) = + containerDefaultMetadata(structureShape, model, additionalAttributes) + override fun unionMeta(unionShape: UnionShape) = containerDefaultMetadata(unionShape, model, additionalAttributes) override fun enumMeta(stringShape: StringShape): RustMetadata = @@ -127,17 +138,23 @@ class BaseSymbolMetadataProvider( private fun defaultRustMetadata() = RustMetadata(visibility = Visibility.PRIVATE) override fun listMeta(listShape: ListShape) = defaultRustMetadata() + override fun mapMeta(mapShape: MapShape) = defaultRustMetadata() + override fun stringMeta(stringShape: StringShape) = defaultRustMetadata() + override fun numberMeta(numberShape: NumberShape) = defaultRustMetadata() + override fun blobMeta(blobShape: BlobShape) = defaultRustMetadata() } private const val META_KEY = "meta" + fun Symbol.Builder.meta(rustMetadata: RustMetadata?): Symbol.Builder = this.putProperty(META_KEY, rustMetadata) -fun Symbol.expectRustMetadata(): RustMetadata = this.getProperty(META_KEY, RustMetadata::class.java).orElseThrow { - CodegenException( - "Expected `$this` to have metadata attached but it did not.", - ) -} +fun Symbol.expectRustMetadata(): RustMetadata = + this.getProperty(META_KEY, RustMetadata::class.java).orElseThrow { + CodegenException( + "Expected `$this` to have metadata attached but it did not.", + ) + } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt index af74c80d9f6..89d50554486 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt @@ -56,17 +56,18 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import kotlin.reflect.KClass /** Map from Smithy Shapes to Rust Types */ -val SimpleShapes: Map, RustType> = mapOf( - BooleanShape::class to RustType.Bool, - FloatShape::class to RustType.Float(32), - DoubleShape::class to RustType.Float(64), - ByteShape::class to RustType.Integer(8), - ShortShape::class to RustType.Integer(16), - IntegerShape::class to RustType.Integer(32), - IntEnumShape::class to RustType.Integer(32), - LongShape::class to RustType.Integer(64), - StringShape::class to RustType.String, -) +val SimpleShapes: Map, RustType> = + mapOf( + BooleanShape::class to RustType.Bool, + FloatShape::class to RustType.Float(32), + DoubleShape::class to RustType.Float(64), + ByteShape::class to RustType.Integer(8), + ShortShape::class to RustType.Integer(16), + IntegerShape::class to RustType.Integer(32), + IntEnumShape::class to RustType.Integer(32), + LongShape::class to RustType.Integer(64), + StringShape::class to RustType.String, + ) /** * Track both the past and current name of a symbol @@ -82,7 +83,10 @@ data class MaybeRenamed(val name: String, val renamedFrom: String?) /** * Make the return [value] optional if the [member] symbol is as well optional. */ -fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String = +fun SymbolProvider.wrapOptional( + member: MemberShape, + value: String, +): String = value.letIf(toSymbol(member).isOptional()) { "Some($value)" } @@ -90,7 +94,10 @@ fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String = /** * Make the return [value] optional if the [member] symbol is not optional. */ -fun SymbolProvider.toOptional(member: MemberShape, value: String): String = +fun SymbolProvider.toOptional( + member: MemberShape, + value: String, +): String = value.letIf(!toSymbol(member).isOptional()) { "Some($value)" } @@ -139,10 +146,11 @@ open class SymbolVisitor( module.toType().resolve("${symbol.name}Error").toSymbol().toBuilder().locatedIn(module).build() } - override fun symbolForBuilder(shape: Shape): Symbol = toSymbol(shape).let { symbol -> - val module = moduleForBuilder(shape) - module.toType().resolve(config.nameBuilderFor(symbol)).toSymbol().toBuilder().locatedIn(module).build() - } + override fun symbolForBuilder(shape: Shape): Symbol = + toSymbol(shape).let { symbol -> + val module = moduleForBuilder(shape) + module.toType().resolve(config.nameBuilderFor(symbol)).toSymbol().toBuilder().locatedIn(module).build() + } override fun toMemberName(shape: MemberShape): String { val container = model.expectShape(shape.container) @@ -160,7 +168,10 @@ open class SymbolVisitor( /** * Produce `Box` when the shape has the `RustBoxTrait` */ - private fun handleRustBoxing(symbol: Symbol, shape: Shape): Symbol { + private fun handleRustBoxing( + symbol: Symbol, + shape: Shape, + ): Symbol { return if (shape.hasTrait()) { val rustType = RustType.Box(symbol.rustType()) with(Symbol.builder()) { @@ -179,14 +190,21 @@ open class SymbolVisitor( } override fun booleanShape(shape: BooleanShape): Symbol = simpleShape(shape) + override fun byteShape(shape: ByteShape): Symbol = simpleShape(shape) + override fun shortShape(shape: ShortShape): Symbol = simpleShape(shape) + override fun integerShape(shape: IntegerShape): Symbol = simpleShape(shape) + override fun longShape(shape: LongShape): Symbol = simpleShape(shape) + override fun floatShape(shape: FloatShape): Symbol = simpleShape(shape) + override fun doubleShape(shape: DoubleShape): Symbol = simpleShape(shape) override fun intEnumShape(shape: IntEnumShape): Symbol = simpleShape(shape) + override fun stringShape(shape: StringShape): Symbol { return if (shape.hasTrait()) { val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) @@ -203,12 +221,13 @@ open class SymbolVisitor( override fun setShape(shape: SetShape): Symbol { val inner = this.toSymbol(shape.member) - val builder = if (model.expectShape(shape.member.target).isStringShape) { - symbolBuilder(shape, RustType.HashSet(inner.rustType())) - } else { - // only strings get put into actual sets because floats are unhashable - symbolBuilder(shape, RustType.Vec(inner.rustType())) - } + val builder = + if (model.expectShape(shape.member.target).isStringShape) { + symbolBuilder(shape, RustType.HashSet(inner.rustType())) + } else { + // only strings get put into actual sets because floats are unhashable + symbolBuilder(shape, RustType.Vec(inner.rustType())) + } return builder.addReference(inner).build() } @@ -255,9 +274,10 @@ open class SymbolVisitor( override fun structureShape(shape: StructureShape): Symbol { val isError = shape.hasTrait() - val name = shape.contextName(serviceShape).toPascalCase().letIf(isError && config.renameExceptions) { - it.replace("Exception", "Error") - } + val name = + shape.contextName(serviceShape).toPascalCase().letIf(isError && config.renameExceptions) { + it.replace("Exception", "Error") + } return symbolBuilder(shape, RustType.Opaque(name)).locatedIn(moduleForShape(shape)).build() } @@ -268,17 +288,18 @@ open class SymbolVisitor( override fun memberShape(shape: MemberShape): Symbol { val target = model.expectShape(shape.target) - val defaultValue = shape.getMemberTrait(model, DefaultTrait::class.java).orNull()?.let { trait -> - if (target.isDocumentShape || target.isTimestampShape) { - Default.NonZeroDefault(trait.toNode()) - } else { - when (val value = trait.toNode()) { - Node.from(""), Node.from(0), Node.from(false), Node.arrayNode(), Node.objectNode() -> Default.RustDefault - Node.nullNode() -> Default.NoDefault - else -> Default.NonZeroDefault(value) + val defaultValue = + shape.getMemberTrait(model, DefaultTrait::class.java).orNull()?.let { trait -> + if (target.isDocumentShape || target.isTimestampShape) { + Default.NonZeroDefault(trait.toNode()) + } else { + when (val value = trait.toNode()) { + Node.from(""), Node.from(0), Node.from(false), Node.arrayNode(), Node.objectNode() -> Default.RustDefault + Node.nullNode() -> Default.NoDefault + else -> Default.NonZeroDefault(value) + } } - } - } ?: Default.NoDefault + } ?: Default.NoDefault // Handle boxing first, so we end up with Option>, not Box>. return handleOptionality( handleRustBoxing(toSymbol(target), shape), @@ -299,14 +320,20 @@ open class SymbolVisitor( * * See `RecursiveShapeBoxer.kt` for the model transformation pass that annotates model shapes with [RustBoxTrait]. */ -fun handleRustBoxing(symbol: Symbol, shape: MemberShape): Symbol = +fun handleRustBoxing( + symbol: Symbol, + shape: MemberShape, +): Symbol = if (shape.hasTrait()) { symbol.makeRustBoxed() } else { symbol } -fun symbolBuilder(shape: Shape?, rustType: RustType): Symbol.Builder = +fun symbolBuilder( + shape: Shape?, + rustType: RustType, +): Symbol.Builder = Symbol.builder().shape(shape).rustType(rustType) .name(rustType.name) // Every symbol that actually gets defined somewhere should set a definition file @@ -318,8 +345,7 @@ fun handleOptionality( member: MemberShape, nullableIndex: NullableIndex, nullabilityCheckMode: CheckMode, -): Symbol = - symbol.letIf(nullableIndex.isMemberNullable(member, nullabilityCheckMode)) { symbol.makeOptional() } +): Symbol = symbol.letIf(nullableIndex.isMemberNullable(member, nullabilityCheckMode)) { symbol.makeOptional() } /** * Creates a test module for this symbol. diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/AllowLintsCustomization.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/AllowLintsCustomization.kt index 4f19e270311..d7ee63fd273 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/AllowLintsCustomization.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/AllowLintsCustomization.kt @@ -12,78 +12,72 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection -private val allowedRustcLints = listOf( - // Deprecated items should be safe to compile, so don't block the compilation. - "deprecated", - - // Unknown lints need to be allowed since we use both nightly and our MSRV, and sometimes we need - // to disable lints that are in nightly but don't exist in the MSRV. - "unknown_lints", -) - -private val allowedClippyLints = listOf( - // Sometimes operations are named the same as our module e.g. output leading to `output::output`. - "module_inception", - - // Currently, we don't re-case acronyms in models, e.g. `SSEVersion`. - "upper_case_acronyms", - - // Large errors trigger this warning, we are unlikely to optimize this case currently. - "large_enum_variant", - - // Some models have members with `is` in the name, which leads to builder functions with the wrong self convention. - "wrong_self_convention", - - // Models like ecs use method names like `add()` which confuses Clippy. - "should_implement_trait", - - // Protocol tests use silly names like `baz`, don't flag that. - "disallowed_names", - - // Forcing use of `vec![]` can make codegen harder in some cases. - "vec_init_then_push", - - // Some models have shapes that generate complex Rust types (e.g. nested collection and map shapes). - "type_complexity", - - // Determining if the expression is the last one (to remove return) can make codegen harder in some cases. - "needless_return", - - // For backwards compatibility, we often don't derive Eq - "derive_partial_eq_without_eq", - - // Keeping errors small in a backwards compatible way is challenging - "result_large_err", -) - -private val allowedRustdocLints = listOf( - // Rust >=1.53.0 requires links to be wrapped in ``. This is extremely hard to enforce for - // docs that come from the modeled documentation, so we need to disable this lint - "bare_urls", - // Rustdoc warns about redundant explicit links in doc comments. This is fine for handwritten - // crates, but is impractical to manage for code generated crates. Thus, allow it. - "redundant_explicit_links", -) +private val allowedRustcLints = + listOf( + // Deprecated items should be safe to compile, so don't block the compilation. + "deprecated", + // Unknown lints need to be allowed since we use both nightly and our MSRV, and sometimes we need + // to disable lints that are in nightly but don't exist in the MSRV. + "unknown_lints", + ) + +private val allowedClippyLints = + listOf( + // Sometimes operations are named the same as our module e.g. output leading to `output::output`. + "module_inception", + // Currently, we don't re-case acronyms in models, e.g. `SSEVersion`. + "upper_case_acronyms", + // Large errors trigger this warning, we are unlikely to optimize this case currently. + "large_enum_variant", + // Some models have members with `is` in the name, which leads to builder functions with the wrong self convention. + "wrong_self_convention", + // Models like ecs use method names like `add()` which confuses Clippy. + "should_implement_trait", + // Protocol tests use silly names like `baz`, don't flag that. + "disallowed_names", + // Forcing use of `vec![]` can make codegen harder in some cases. + "vec_init_then_push", + // Some models have shapes that generate complex Rust types (e.g. nested collection and map shapes). + "type_complexity", + // Determining if the expression is the last one (to remove return) can make codegen harder in some cases. + "needless_return", + // For backwards compatibility, we often don't derive Eq + "derive_partial_eq_without_eq", + // Keeping errors small in a backwards compatible way is challenging + "result_large_err", + ) + +private val allowedRustdocLints = + listOf( + // Rust >=1.53.0 requires links to be wrapped in ``. This is extremely hard to enforce for + // docs that come from the modeled documentation, so we need to disable this lint + "bare_urls", + // Rustdoc warns about redundant explicit links in doc comments. This is fine for handwritten + // crates, but is impractical to manage for code generated crates. Thus, allow it. + "redundant_explicit_links", + ) class AllowLintsCustomization( private val rustcLints: List = allowedRustcLints, private val clippyLints: List = allowedClippyLints, private val rustdocLints: List = allowedRustdocLints, ) : LibRsCustomization() { - override fun section(section: LibRsSection) = when (section) { - is LibRsSection.Attributes -> writable { - rustcLints.forEach { - Attribute(allow(it)).render(this, AttributeKind.Inner) - } - clippyLints.forEach { - Attribute(allow("clippy::$it")).render(this, AttributeKind.Inner) - } - rustdocLints.forEach { - Attribute(allow("rustdoc::$it")).render(this, AttributeKind.Inner) - } - // add a newline at the end - this.write("") + override fun section(section: LibRsSection) = + when (section) { + is LibRsSection.Attributes -> + writable { + rustcLints.forEach { + Attribute(allow(it)).render(this, AttributeKind.Inner) + } + clippyLints.forEach { + Attribute(allow("clippy::$it")).render(this, AttributeKind.Inner) + } + rustdocLints.forEach { + Attribute(allow("rustdoc::$it")).render(this, AttributeKind.Inner) + } + // add a newline at the end + this.write("") + } + else -> emptySection } - else -> emptySection - } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/CrateVersionCustomization.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/CrateVersionCustomization.kt index 93db223c20e..310e125bd0d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/CrateVersionCustomization.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/CrateVersionCustomization.kt @@ -16,13 +16,15 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustCrate object CrateVersionCustomization { fun pkgVersion(module: RustModule): RuntimeType = RuntimeType(module.fullyQualifiedPath() + "::PKG_VERSION") - fun extras(rustCrate: RustCrate, module: RustModule) = - rustCrate.withModule(module) { - rust( - """ - /// Crate version number. - pub static PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); - """, - ) - } + fun extras( + rustCrate: RustCrate, + module: RustModule, + ) = rustCrate.withModule(module) { + rust( + """ + /// Crate version number. + pub static PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); + """, + ) + } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt index 2a285ca13f3..2c8176e1e95 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt @@ -30,12 +30,16 @@ private fun hasStreamingOperations(model: Model): Boolean { } // TODO(https://github.com/smithy-lang/smithy-rs/issues/2111): Fix this logic to consider collection/map shapes -private fun structUnionMembersMatchPredicate(model: Model, predicate: (Shape) -> Boolean): Boolean = +private fun structUnionMembersMatchPredicate( + model: Model, + predicate: (Shape) -> Boolean, +): Boolean = model.structureShapes.any { structure -> structure.members().any { member -> predicate(model.expectShape(member.target)) } - } || model.unionShapes.any { union -> - union.members().any { member -> predicate(model.expectShape(member.target)) } - } + } || + model.unionShapes.any { union -> + union.members().any { member -> predicate(model.expectShape(member.target)) } + } /** Returns true if the model uses any blob shapes */ private fun hasBlobs(model: Model): Boolean = structUnionMembersMatchPredicate(model, Shape::isBlobShape) @@ -44,59 +48,68 @@ private fun hasBlobs(model: Model): Boolean = structUnionMembersMatchPredicate(m private fun hasDateTimes(model: Model): Boolean = structUnionMembersMatchPredicate(model, Shape::isTimestampShape) /** Adds re-export statements for Smithy primitives */ -fun pubUseSmithyPrimitives(codegenContext: CodegenContext, model: Model, rustCrate: RustCrate): Writable = writable { - val rc = codegenContext.runtimeConfig - if (hasBlobs(model)) { - rustTemplate("pub use #{Blob};", "Blob" to RuntimeType.blob(rc)) - } - if (hasDateTimes(model)) { - rustTemplate( - """ - pub use #{DateTime}; - pub use #{Format} as DateTimeFormat; - """, - "DateTime" to RuntimeType.dateTime(rc), - "Format" to RuntimeType.format(rc), - ) - } - if (hasStreamingOperations(model)) { - rustCrate.mergeFeature( - Feature( - "rt-tokio", - true, - listOf("aws-smithy-types/rt-tokio"), - ), - ) - rustTemplate( - """ - pub use #{ByteStream}; - pub use #{AggregatedBytes}; - pub use #{Error} as ByteStreamError; - pub use #{SdkBody}; - """, - "ByteStream" to RuntimeType.smithyTypes(rc).resolve("byte_stream::ByteStream"), - "AggregatedBytes" to RuntimeType.smithyTypes(rc).resolve("byte_stream::AggregatedBytes"), - "Error" to RuntimeType.smithyTypes(rc).resolve("byte_stream::error::Error"), - "SdkBody" to RuntimeType.smithyTypes(rc).resolve("body::SdkBody"), - ) +fun pubUseSmithyPrimitives( + codegenContext: CodegenContext, + model: Model, + rustCrate: RustCrate, +): Writable = + writable { + val rc = codegenContext.runtimeConfig + if (hasBlobs(model)) { + rustTemplate("pub use #{Blob};", "Blob" to RuntimeType.blob(rc)) + } + if (hasDateTimes(model)) { + rustTemplate( + """ + pub use #{DateTime}; + pub use #{Format} as DateTimeFormat; + """, + "DateTime" to RuntimeType.dateTime(rc), + "Format" to RuntimeType.format(rc), + ) + } + if (hasStreamingOperations(model)) { + rustCrate.mergeFeature( + Feature( + "rt-tokio", + true, + listOf("aws-smithy-types/rt-tokio"), + ), + ) + rustTemplate( + """ + pub use #{ByteStream}; + pub use #{AggregatedBytes}; + pub use #{Error} as ByteStreamError; + pub use #{SdkBody}; + """, + "ByteStream" to RuntimeType.smithyTypes(rc).resolve("byte_stream::ByteStream"), + "AggregatedBytes" to RuntimeType.smithyTypes(rc).resolve("byte_stream::AggregatedBytes"), + "Error" to RuntimeType.smithyTypes(rc).resolve("byte_stream::error::Error"), + "SdkBody" to RuntimeType.smithyTypes(rc).resolve("body::SdkBody"), + ) + } } -} /** Adds re-export statements for event-stream-related Smithy primitives */ -fun pubUseSmithyPrimitivesEventStream(codegenContext: CodegenContext, model: Model): Writable = writable { - val rc = codegenContext.runtimeConfig - if (codegenContext.serviceShape.hasEventStreamOperations(model)) { - rustTemplate( - """ - pub use #{Header}; - pub use #{HeaderValue}; - pub use #{Message}; - pub use #{StrBytes}; - """, - "Header" to RuntimeType.smithyTypes(rc).resolve("event_stream::Header"), - "HeaderValue" to RuntimeType.smithyTypes(rc).resolve("event_stream::HeaderValue"), - "Message" to RuntimeType.smithyTypes(rc).resolve("event_stream::Message"), - "StrBytes" to RuntimeType.smithyTypes(rc).resolve("str_bytes::StrBytes"), - ) +fun pubUseSmithyPrimitivesEventStream( + codegenContext: CodegenContext, + model: Model, +): Writable = + writable { + val rc = codegenContext.runtimeConfig + if (codegenContext.serviceShape.hasEventStreamOperations(model)) { + rustTemplate( + """ + pub use #{Header}; + pub use #{HeaderValue}; + pub use #{Message}; + pub use #{StrBytes}; + """, + "Header" to RuntimeType.smithyTypes(rc).resolve("event_stream::Header"), + "HeaderValue" to RuntimeType.smithyTypes(rc).resolve("event_stream::HeaderValue"), + "Message" to RuntimeType.smithyTypes(rc).resolve("event_stream::Message"), + "StrBytes" to RuntimeType.smithyTypes(rc).resolve("str_bytes::StrBytes"), + ) + } } -} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/CoreCodegenDecorator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/CoreCodegenDecorator.kt index 923f4fce554..f5fd23f4fc5 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/CoreCodegenDecorator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/CoreCodegenDecorator.kt @@ -43,12 +43,19 @@ interface CoreCodegenDecorator { /** * Hook to transform the Smithy model before codegen takes place. */ - fun transformModel(service: ServiceShape, model: Model, settings: CodegenSettings): Model = model + fun transformModel( + service: ServiceShape, + model: Model, + settings: CodegenSettings, + ): Model = model /** * Hook to add additional modules to the generated crate. */ - fun extras(codegenContext: CodegenContext, rustCrate: RustCrate) {} + fun extras( + codegenContext: CodegenContext, + rustCrate: RustCrate, + ) {} /** * Customize the documentation provider for module documentation. @@ -84,6 +91,7 @@ interface CoreCodegenDecorator { ): List = baseCustomizations // TODO(https://github.com/smithy-lang/smithy-rs/issues/1401): Move builder customizations into `ClientCodegenDecorator` + /** * Hook to customize generated builders. */ @@ -124,11 +132,18 @@ abstract class CombinedCoreCodegenDecorator decorator.transformModel(otherModel.expectShape(service.id, ServiceShape::class.java), otherModel, settings) } @@ -136,9 +151,10 @@ abstract class CombinedCoreCodegenDecorator - decorator.moduleDocumentationCustomization(codegenContext, base) - } + ): ModuleDocProvider = + combineCustomizations(baseModuleDocProvider) { decorator, base -> + decorator.moduleDocumentationCustomization(codegenContext, base) + } final override fun libRsCustomizations( codegenContext: CodegenContext, @@ -151,23 +167,26 @@ abstract class CombinedCoreCodegenDecorator, - ): List = combineCustomizations(baseCustomizations) { decorator, customizations -> - decorator.structureCustomizations(codegenContext, customizations) - } + ): List = + combineCustomizations(baseCustomizations) { decorator, customizations -> + decorator.structureCustomizations(codegenContext, customizations) + } override fun builderCustomizations( codegenContext: CodegenContext, baseCustomizations: List, - ): List = combineCustomizations(baseCustomizations) { decorator, customizations -> - decorator.builderCustomizations(codegenContext, customizations) - } + ): List = + combineCustomizations(baseCustomizations) { decorator, customizations -> + decorator.builderCustomizations(codegenContext, customizations) + } override fun errorImplCustomizations( codegenContext: CodegenContext, baseCustomizations: List, - ): List = combineCustomizations(baseCustomizations) { decorator, customizations -> - decorator.errorImplCustomizations(codegenContext, customizations) - } + ): List = + combineCustomizations(baseCustomizations) { decorator, customizations -> + decorator.errorImplCustomizations(codegenContext, customizations) + } final override fun extraSections(codegenContext: CodegenContext): List = addCustomizations { decorator -> decorator.extraSections(codegenContext) } @@ -215,16 +234,18 @@ abstract class CombinedCoreCodegenDecorator { - val decorators = ServiceLoader.load( - decoratorClass, - context.pluginClassLoader.orElse(decoratorClass.classLoader), - ) - - val filteredDecorators = decorators.asSequence() - .onEach { logger.info("Discovered Codegen Decorator: ${it!!::class.java.name}") } - .filter { it!!.classpathDiscoverable() } - .onEach { logger.info("Adding Codegen Decorator: ${it!!::class.java.name}") } - .toList() + val decorators = + ServiceLoader.load( + decoratorClass, + context.pluginClassLoader.orElse(decoratorClass.classLoader), + ) + + val filteredDecorators = + decorators.asSequence() + .onEach { logger.info("Discovered Codegen Decorator: ${it!!::class.java.name}") } + .filter { it!!.classpathDiscoverable() } + .onEach { logger.info("Adding Codegen Decorator: ${it!!::class.java.name}") } + .toList() return filteredDecorators + extras } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/Customization.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/Customization.kt index c174c81d954..79ff1bdb6f4 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/Customization.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/Customization.kt @@ -37,11 +37,12 @@ inline fun adhocCustomization( crossinline customization: RustWriter.(T) -> Unit, ): AdHocCustomization = object : AdHocCustomization() { - override fun section(section: AdHocSection): Writable = writable { - if (section is T) { - customization(section) + override fun section(section: AdHocSection): Writable = + writable { + if (section is T) { + customization(section) + } } - } } /** @@ -51,11 +52,15 @@ inline fun adhocCustomization( */ abstract class NamedCustomization { abstract fun section(section: T): Writable + protected val emptySection = writable { } } /** Convenience for rendering a list of customizations for a given section */ -fun RustWriter.writeCustomizations(customizations: List>, section: T) { +fun RustWriter.writeCustomizations( + customizations: List>, + section: T, +) { for (customization in customizations) { customization.section(section)(this) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt index 76c8caf4efa..e273188cd5e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt @@ -76,6 +76,7 @@ sealed class BuilderSection(name: String) : Section(name) { abstract class BuilderCustomization : NamedCustomization() fun RuntimeConfig.operationBuildError() = RuntimeType.smithyTypes(this).resolve("error::operation::BuildError") + fun RuntimeConfig.serializationError() = RuntimeType.smithyTypes(this).resolve("error::operation::SerializationError") fun MemberShape.enforceRequired( @@ -89,30 +90,42 @@ fun MemberShape.enforceRequired( val shape = this val isOptional = codegenContext.symbolProvider.toSymbol(shape).isOptional() val field = field.letIf(!isOptional) { field.map { rust("Some(#T)", it) } } - val error = OperationBuildError(codegenContext.runtimeConfig).missingField( - codegenContext.symbolProvider.toMemberName(shape), "A required field was not set", - ) - val unwrapped = when (codegenContext.model.expectShape(this.target)) { - is StringShape -> writable { - rustTemplate( - "#{field}.filter(|f|!AsRef::::as_ref(f).trim().is_empty())", - "field" to field, - ) - } + val error = + OperationBuildError(codegenContext.runtimeConfig).missingField( + codegenContext.symbolProvider.toMemberName(shape), "A required field was not set", + ) + val unwrapped = + when (codegenContext.model.expectShape(this.target)) { + is StringShape -> + writable { + rustTemplate( + "#{field}.filter(|f|!AsRef::::as_ref(f).trim().is_empty())", + "field" to field, + ) + } - else -> field - }.map { base -> rustTemplate("#{base}.ok_or_else(||#{error})?", "base" to base, "error" to error) } + else -> field + }.map { base -> rustTemplate("#{base}.ok_or_else(||#{error})?", "base" to base, "error" to error) } return unwrapped.letIf(produceOption) { w -> w.map { rust("Some(#T)", it) } } } class OperationBuildError(private val runtimeConfig: RuntimeConfig) { - - fun missingField(field: String, details: String) = writable { + fun missingField( + field: String, + details: String, + ) = writable { rust("#T::missing_field(${field.dq()}, ${details.dq()})", runtimeConfig.operationBuildError()) } - fun invalidField(field: String, details: String) = invalidField(field) { rust(details.dq()) } - fun invalidField(field: String, details: Writable) = writable { + fun invalidField( + field: String, + details: String, + ) = invalidField(field) { rust(details.dq()) } + + fun invalidField( + field: String, + details: Writable, + ) = writable { rustTemplate( "#{error}::invalid_field(${field.dq()}, #{details:W})", "error" to runtimeConfig.operationBuildError(), @@ -138,7 +151,10 @@ class BuilderGenerator( * Returns whether a structure shape, whose builder has been generated with [BuilderGenerator], requires a * fallible builder to be constructed. */ - fun hasFallibleBuilder(structureShape: StructureShape, symbolProvider: SymbolProvider): Boolean = + fun hasFallibleBuilder( + structureShape: StructureShape, + symbolProvider: SymbolProvider, + ): Boolean = // All operation inputs should have fallible builders in case a new required field is added in the future. structureShape.hasTrait() || structureShape @@ -149,7 +165,11 @@ class BuilderGenerator( !it.isOptional() && !it.canUseDefault() } - fun renderConvenienceMethod(implBlock: RustWriter, symbolProvider: RustSymbolProvider, shape: StructureShape) { + fun renderConvenienceMethod( + implBlock: RustWriter, + symbolProvider: RustSymbolProvider, + shape: StructureShape, + ) { implBlock.docs("Creates a new builder-style object to manufacture #D.", symbolProvider.toSymbol(shape)) symbolProvider.symbolForBuilder(shape).also { builderSymbol -> implBlock.rustBlock("pub fn builder() -> #T", builderSymbol) { @@ -165,9 +185,10 @@ class BuilderGenerator( private val metadata = structureSymbol.expectRustMetadata() // Filter out any derive that isn't Debug, PartialEq, or Clone. Then add a Default derive - private val builderDerives = metadata.derives.filter { - it == RuntimeType.Debug || it == RuntimeType.PartialEq || it == RuntimeType.Clone - } + RuntimeType.Default + private val builderDerives = + metadata.derives.filter { + it == RuntimeType.Debug || it == RuntimeType.PartialEq || it == RuntimeType.Clone + } + RuntimeType.Default private val builderName = symbolProvider.symbolForBuilder(shape).name fun render(writer: RustWriter) { @@ -181,10 +202,11 @@ class BuilderGenerator( private fun renderBuildFn(implBlockWriter: RustWriter) { val fallibleBuilder = hasFallibleBuilder(shape, symbolProvider) val outputSymbol = symbolProvider.toSymbol(shape) - val returnType = when (fallibleBuilder) { - true -> "#{Result}<${implBlockWriter.format(outputSymbol)}, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>" - false -> implBlockWriter.format(outputSymbol) - } + val returnType = + when (fallibleBuilder) { + true -> "#{Result}<${implBlockWriter.format(outputSymbol)}, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>" + false -> implBlockWriter.format(outputSymbol) + } implBlockWriter.docs("Consumes the builder and constructs a #D.", outputSymbol) val trulyRequiredMembers = members.filter { trulyRequired(it) } if (trulyRequiredMembers.isNotEmpty()) { @@ -209,7 +231,11 @@ class BuilderGenerator( } // TODO(EventStream): [DX] Consider updating builders to take EventInputStream as Into - private fun renderBuilderMember(writer: RustWriter, memberName: String, memberSymbol: Symbol) { + private fun renderBuilderMember( + writer: RustWriter, + memberName: String, + memberSymbol: Symbol, + ) { // Builder members are crate-public to enable using them directly in serializers/deserializers. // During XML deserialization, `builder..take` is used to append to lists and maps. writer.write("pub(crate) $memberName: #T,", memberSymbol) @@ -333,7 +359,11 @@ class BuilderGenerator( } } - private fun RustWriter.renderVecHelper(member: MemberShape, memberName: String, coreType: RustType.Vec) { + private fun RustWriter.renderVecHelper( + member: MemberShape, + memberName: String, + coreType: RustType.Vec, + ) { docs("Appends an item to `$memberName`.") rust("///") docs("To override the contents of this collection use [`${member.setterName()}`](Self::${member.setterName()}).") @@ -355,7 +385,11 @@ class BuilderGenerator( } } - private fun RustWriter.renderMapHelper(member: MemberShape, memberName: String, coreType: RustType.HashMap) { + private fun RustWriter.renderMapHelper( + member: MemberShape, + memberName: String, + coreType: RustType.HashMap, + ) { docs("Adds a key-value pair to `$memberName`.") rust("///") docs("To override the contents of this collection use [`${member.setterName()}`](Self::${member.setterName()}).") @@ -380,9 +414,10 @@ class BuilderGenerator( } } - private fun trulyRequired(member: MemberShape) = symbolProvider.toSymbol(member).let { - !it.isOptional() && !it.canUseDefault() - } + private fun trulyRequired(member: MemberShape) = + symbolProvider.toSymbol(member).let { + !it.isOptional() && !it.canUseDefault() + } /** * The core builder of the inner type. If the structure requires a fallible builder, this may use `?` to return diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderInstantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderInstantiator.kt index fd62ced0668..f73e8d5a055 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderInstantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderInstantiator.kt @@ -17,16 +17,28 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable * */ interface BuilderInstantiator { /** Set a field on a builder. */ - fun setField(builder: String, value: Writable, field: MemberShape): Writable + fun setField( + builder: String, + value: Writable, + field: MemberShape, + ): Writable /** Finalize a builder, turning it into a built object * - In the case of builders-of-builders, the value should be returned directly * - If an error is returned, you MUST use `mapErr` to convert the error type */ - fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable? = null): Writable + fun finalizeBuilder( + builder: String, + shape: StructureShape, + mapErr: Writable? = null, + ): Writable /** Set a field on a builder using the `$setterName` method. $value will be passed directly. */ - fun setFieldWithSetter(builder: String, value: Writable, field: MemberShape) = writable { + fun setFieldWithSetter( + builder: String, + value: Writable, + field: MemberShape, + ) = writable { rustTemplate("$builder = $builder.${field.setterName()}(#{value})", "value" to value) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/CargoTomlGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/CargoTomlGenerator.kt index b5eca73acd8..5c2f8d9a938 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/CargoTomlGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/CargoTomlGenerator.kt @@ -79,29 +79,36 @@ class CargoTomlGenerator( cargoFeatures.add("default" to features.filter { it.default }.map { it.name }) } - val cargoToml = mapOf( - "package" to listOfNotNull( - "name" to moduleName, - "version" to moduleVersion, - "authors" to moduleAuthors, - moduleDescription?.let { "description" to it }, - "edition" to "2021", - "license" to moduleLicense, - "repository" to moduleRepository, - "metadata" to listOfNotNull( - "smithy" to listOfNotNull( - "codegen-version" to Version.fullVersion(), + val cargoToml = + mapOf( + "package" to + listOfNotNull( + "name" to moduleName, + "version" to moduleVersion, + "authors" to moduleAuthors, + moduleDescription?.let { "description" to it }, + "edition" to "2021", + "license" to moduleLicense, + "repository" to moduleRepository, + "metadata" to + listOfNotNull( + "smithy" to + listOfNotNull( + "codegen-version" to Version.fullVersion(), + ).toMap(), + ).toMap(), ).toMap(), - ).toMap(), - ).toMap(), - "dependencies" to dependencies.filter { it.scope == DependencyScope.Compile } - .associate { it.name to it.toMap() }, - "build-dependencies" to dependencies.filter { it.scope == DependencyScope.Build } - .associate { it.name to it.toMap() }, - "dev-dependencies" to dependencies.filter { it.scope == DependencyScope.Dev } - .associate { it.name to it.toMap() }, - "features" to cargoFeatures.toMap(), - ).deepMergeWith(manifestCustomizations) + "dependencies" to + dependencies.filter { it.scope == DependencyScope.Compile } + .associate { it.name to it.toMap() }, + "build-dependencies" to + dependencies.filter { it.scope == DependencyScope.Build } + .associate { it.name to it.toMap() }, + "dev-dependencies" to + dependencies.filter { it.scope == DependencyScope.Dev } + .associate { it.name to it.toMap() }, + "features" to cargoFeatures.toMap(), + ).deepMergeWith(manifestCustomizations) writer.writeWithNoFormatting(TomlWriter().write(cargoToml)) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt index 9c614c7c6e4..d28896f18f6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt @@ -138,7 +138,10 @@ class EnumMemberModel( } } -private fun RustWriter.docWithNote(doc: String?, note: String?) { +private fun RustWriter.docWithNote( + doc: String?, + note: String?, +) { if (doc.isNullOrBlank() && note.isNullOrBlank()) { // If the model doesn't have any documentation for the shape, then suppress the missing docs lint // since the lack of documentation is a modeling issue rather than a codegen issue. @@ -166,12 +169,13 @@ open class EnumGenerator( private val enumTrait: EnumTrait = shape.expectTrait() private val symbol: Symbol = symbolProvider.toSymbol(shape) - private val context = EnumGeneratorContext( - enumName = symbol.name, - enumMeta = symbol.expectRustMetadata(), - enumTrait = enumTrait, - sortedMembers = enumTrait.values.sortedBy { it.value }.map { EnumMemberModel(shape, it, symbolProvider) }, - ) + private val context = + EnumGeneratorContext( + enumName = symbol.name, + enumMeta = symbol.expectRustMetadata(), + enumTrait = enumTrait, + sortedMembers = enumTrait.values.sortedBy { it.value }.map { EnumMemberModel(shape, it, symbolProvider) }, + ) fun render(writer: RustWriter) { enumType.additionalEnumAttributes(context).forEach { attribute -> @@ -200,14 +204,15 @@ open class EnumGenerator( insertTrailingNewline() // impl Blah { pub fn as_str(&self) -> &str implBlock( - asStrImpl = writable { - rustBlock("match self") { - context.sortedMembers.forEach { member -> - rust("""${context.enumName}::${member.derivedName()} => ${member.value.dq()},""") + asStrImpl = + writable { + rustBlock("match self") { + context.sortedMembers.forEach { member -> + rust("""${context.enumName}::${member.derivedName()} => ${member.value.dq()},""") + } + enumType.additionalAsStrMatchArms(context)(this) } - enumType.additionalAsStrMatchArms(context)(this) - } - }, + }, ) rustTemplate( """ @@ -227,9 +232,10 @@ open class EnumGenerator( context.enumMeta.render(this) rust("struct ${context.enumName}(String);") implBlock( - asStrImpl = writable { - rust("&self.0") - }, + asStrImpl = + writable { + rust("&self.0") + }, ) // Add an infallible FromStr implementation for uniformity @@ -295,9 +301,10 @@ open class EnumGenerator( } """, "asStrImpl" to asStrImpl, - "Values" to writable { - rust(context.sortedMembers.joinToString(", ") { it.value.dq() }) - }, + "Values" to + writable { + rust(context.sortedMembers.joinToString(", ") { it.value.dq() }) + }, ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index 795a6b5a1dc..d0df70047c6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -120,6 +120,7 @@ open class Instantiator( // in the structure field's type. The latter's method name is the field's name, whereas the former is prefixed // with `set_`. Client instantiators call the `set_*` builder setters. fun setterName(memberShape: MemberShape): String + fun doesSetterTakeInOption(memberShape: MemberShape): Boolean } @@ -135,7 +136,12 @@ open class Instantiator( override fun generate(shape: Shape): Writable? = null } - fun generate(shape: Shape, data: Node, headers: Map = mapOf(), ctx: Ctx = Ctx()) = writable { + fun generate( + shape: Shape, + data: Node, + headers: Map = mapOf(), + ctx: Ctx = Ctx(), + ) = writable { render(this, shape, data, headers, ctx) } @@ -162,11 +168,12 @@ open class Instantiator( // Members, supporting potentially optional members is MemberShape -> renderMember(writer, shape, data, ctx) - is SimpleShape -> PrimitiveInstantiator(runtimeConfig, symbolProvider).instantiate( - shape, - data, - customWritable, - )(writer) + is SimpleShape -> + PrimitiveInstantiator(runtimeConfig, symbolProvider).instantiate( + shape, + data, + customWritable, + )(writer) else -> writer.writeWithNoFormatting("todo!() /* $shape $data */") } @@ -177,7 +184,12 @@ open class Instantiator( * If the shape is optional: `Some(inner)` or `None`. * Otherwise: `inner`. */ - private fun renderMember(writer: RustWriter, memberShape: MemberShape, data: Node, ctx: Ctx) { + private fun renderMember( + writer: RustWriter, + memberShape: MemberShape, + data: Node, + ctx: Ctx, + ) { val targetShape = model.expectShape(memberShape.target) val symbol = symbolProvider.toSymbol(memberShape) customWritable.generate(memberShape) @@ -195,11 +207,13 @@ open class Instantiator( "#{Some}(", ")", // The conditions are not commutative: note client builders always take in `Option`. - conditional = symbol.isOptional() || - ( - model.expectShape(memberShape.container) is StructureShape && builderKindBehavior.doesSetterTakeInOption( - memberShape, - ) + conditional = + symbol.isOptional() || + ( + model.expectShape(memberShape.container) is StructureShape && + builderKindBehavior.doesSetterTakeInOption( + memberShape, + ) ), *preludeScope, ) { @@ -225,8 +239,12 @@ open class Instantiator( } } - private fun renderSet(writer: RustWriter, shape: SetShape, data: ArrayNode, ctx: Ctx) = - renderList(writer, shape, data, ctx) + private fun renderSet( + writer: RustWriter, + shape: SetShape, + data: ArrayNode, + ctx: Ctx, + ) = renderList(writer, shape, data, ctx) /** * ```rust @@ -238,7 +256,12 @@ open class Instantiator( * } * ``` */ - private fun renderMap(writer: RustWriter, shape: MapShape, data: ObjectNode, ctx: Ctx) { + private fun renderMap( + writer: RustWriter, + shape: MapShape, + data: ObjectNode, + ctx: Ctx, + ) { if (data.members.isEmpty()) { writer.rust("#T::new()", RuntimeType.HashMap) } else { @@ -264,18 +287,24 @@ open class Instantiator( * MyUnion::Variant(...) * ``` */ - private fun renderUnion(writer: RustWriter, shape: UnionShape, data: ObjectNode, ctx: Ctx) { + private fun renderUnion( + writer: RustWriter, + shape: UnionShape, + data: ObjectNode, + ctx: Ctx, + ) { val unionSymbol = symbolProvider.toSymbol(shape) - val variant = if (defaultsForRequiredFields && data.members.isEmpty()) { - val (name, memberShape) = shape.allMembers.entries.first() - val targetShape = model.expectShape(memberShape.target) - Node.from(name) to fillDefaultValue(targetShape) - } else { - check(data.members.size == 1) - val entry = data.members.iterator().next() - entry.key to entry.value - } + val variant = + if (defaultsForRequiredFields && data.members.isEmpty()) { + val (name, memberShape) = shape.allMembers.entries.first() + val targetShape = model.expectShape(memberShape.target) + Node.from(name) to fillDefaultValue(targetShape) + } else { + check(data.members.size == 1) + val entry = data.members.iterator().next() + entry.key to entry.value + } val memberName = variant.first.value val member = shape.expectMember(memberName) @@ -293,7 +322,12 @@ open class Instantiator( * vec![..., ..., ...] * ``` */ - private fun renderList(writer: RustWriter, shape: CollectionShape, data: ArrayNode, ctx: Ctx) { + private fun renderList( + writer: RustWriter, + shape: CollectionShape, + data: ArrayNode, + ctx: Ctx, + ) { writer.withBlock("vec![", "]") { data.elements.forEach { v -> renderMember(this, shape.member, v, ctx) @@ -344,7 +378,11 @@ open class Instantiator( ctx: Ctx, ) { val renderedMembers = mutableSetOf() - fun renderMemberHelper(memberShape: MemberShape, value: Node) { + + fun renderMemberHelper( + memberShape: MemberShape, + value: Node, + ) { renderedMembers.add(memberShape) when (constructPattern) { InstantiatorConstructPattern.DIRECT -> { @@ -410,24 +448,25 @@ open class Instantiator( * * Warning: this method does not take into account any constraint traits attached to the shape. */ - private fun fillDefaultValue(shape: Shape): Node = when (shape) { - is MemberShape -> fillDefaultValue(model.expectShape(shape.target)) - - // Aggregate shapes. - is StructureShape -> Node.objectNode() - is UnionShape -> Node.objectNode() - is CollectionShape -> Node.arrayNode() - is MapShape -> Node.objectNode() - - // Simple Shapes - is TimestampShape -> Node.from(0) // Number node for timestamp - is BlobShape -> Node.from("") // String node for bytes - is StringShape -> Node.from("") - is NumberShape -> Node.from(0) - is BooleanShape -> Node.from(false) - is DocumentShape -> Node.objectNode() - else -> throw CodegenException("Unrecognized shape `$shape`") - } + private fun fillDefaultValue(shape: Shape): Node = + when (shape) { + is MemberShape -> fillDefaultValue(model.expectShape(shape.target)) + + // Aggregate shapes. + is StructureShape -> Node.objectNode() + is UnionShape -> Node.objectNode() + is CollectionShape -> Node.arrayNode() + is MapShape -> Node.objectNode() + + // Simple Shapes + is TimestampShape -> Node.from(0) // Number node for timestamp + is BlobShape -> Node.from("") // String node for bytes + is StringShape -> Node.from("") + is NumberShape -> Node.from(0) + is BooleanShape -> Node.from(false) + is DocumentShape -> Node.objectNode() + else -> throw CodegenException("Unrecognized shape `$shape`") + } } class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private val symbolProvider: SymbolProvider) { @@ -455,56 +494,63 @@ class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private va * Blob::new("arg") * ``` */ - is BlobShape -> if (shape.hasTrait()) { - rust( - "#T::from_static(b${(data as StringNode).value.dq()})", - RuntimeType.byteStream(runtimeConfig), - ) - } else { - rust( - "#T::new(${(data as StringNode).value.dq()})", - RuntimeType.blob(runtimeConfig), - ) - } - - is StringShape -> renderString(shape, data as StringNode)(this) - is NumberShape -> when (data) { - is StringNode -> { - val numberSymbol = symbolProvider.toSymbol(shape) - // support Smithy custom values, such as Infinity + is BlobShape -> + if (shape.hasTrait()) { rust( - """<#T as #T>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", - numberSymbol, - RuntimeType.smithyTypes(runtimeConfig).resolve("primitive::Parse"), + "#T::from_static(b${(data as StringNode).value.dq()})", + RuntimeType.byteStream(runtimeConfig), + ) + } else { + rust( + "#T::new(${(data as StringNode).value.dq()})", + RuntimeType.blob(runtimeConfig), ) } - is NumberNode -> write(data.value) - } + is StringShape -> renderString(shape, data as StringNode)(this) + is NumberShape -> + when (data) { + is StringNode -> { + val numberSymbol = symbolProvider.toSymbol(shape) + // support Smithy custom values, such as Infinity + rust( + """<#T as #T>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", + numberSymbol, + RuntimeType.smithyTypes(runtimeConfig).resolve("primitive::Parse"), + ) + } + + is NumberNode -> write(data.value) + } is BooleanShape -> rust(data.asBooleanNode().get().toString()) - is DocumentShape -> rustBlock("") { - val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType() - rustTemplate( - """ - let json_bytes = br##"${Node.prettyPrintJson(data)}"##; - let mut tokens = #{json_token_iter}(json_bytes).peekable(); - #{expect_document}(&mut tokens).expect("well formed json") - """, - "expect_document" to smithyJson.resolve("deserialize::token::expect_document"), - "json_token_iter" to smithyJson.resolve("deserialize::json_token_iter"), - ) - } + is DocumentShape -> + rustBlock("") { + val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType() + rustTemplate( + """ + let json_bytes = br##"${Node.prettyPrintJson(data)}"##; + let mut tokens = #{json_token_iter}(json_bytes).peekable(); + #{expect_document}(&mut tokens).expect("well formed json") + """, + "expect_document" to smithyJson.resolve("deserialize::token::expect_document"), + "json_token_iter" to smithyJson.resolve("deserialize::json_token_iter"), + ) + } } } - private fun renderString(shape: StringShape, arg: StringNode): Writable = { - val data = escape(arg.value).dq() - if (shape.hasTrait() || shape is EnumShape) { - val enumSymbol = symbolProvider.toSymbol(shape) - rust("""$data.parse::<#T>().expect("static value validated to member")""", enumSymbol) - } else { - rust("$data.to_owned()") + private fun renderString( + shape: StringShape, + arg: StringNode, + ): Writable = + { + val data = escape(arg.value).dq() + if (shape.hasTrait() || shape is EnumShape) { + val enumSymbol = symbolProvider.toSymbol(shape) + rust("""$data.parse::<#T>().expect("static value validated to member")""", enumSymbol) + } else { + rust("$data.to_owned()") + } } - } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/LibRsGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/LibRsGenerator.kt index ed05e0c1f07..c221d7d59c1 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/LibRsGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/LibRsGenerator.kt @@ -20,7 +20,9 @@ import software.amazon.smithy.rust.codegen.core.util.getTrait sealed class ModuleDocSection { data class ServiceDocs(val documentationTraitValue: String?) : ModuleDocSection() + object CrateOrganization : ModuleDocSection() + object Examples : ModuleDocSection() } @@ -40,9 +42,10 @@ class LibRsGenerator( private val customizations: List, private val requireDocs: Boolean, ) { - private fun docSection(section: ModuleDocSection): List = customizations - .map { customization -> customization.section(LibRsSection.ModuleDoc(section)) } - .filter { it.isNotEmpty() } + private fun docSection(section: ModuleDocSection): List = + customizations + .map { customization -> customization.section(LibRsSection.ModuleDoc(section)) } + .filter { it.isNotEmpty() } fun render(writer: RustWriter) { writer.first { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt index 64b7f55702d..61ea75bb35e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt @@ -68,20 +68,22 @@ open class StructureGenerator( ) { companion object { /** Reserved struct member names */ - val structureMemberNameMap: Map = mapOf( - "build" to "build_value", - "builder" to "builder_value", - "default" to "default_value", - ) + val structureMemberNameMap: Map = + mapOf( + "build" to "build_value", + "builder" to "builder_value", + "default" to "default_value", + ) } private val errorTrait = shape.getTrait() protected val members: List = shape.allMembers.values.toList() - private val accessorMembers: List = when (errorTrait) { - null -> members - // Let the ErrorGenerator render the error message accessor if this is an error struct - else -> members.filter { "message" != symbolProvider.toMemberName(it) } - } + private val accessorMembers: List = + when (errorTrait) { + null -> members + // Let the ErrorGenerator render the error message accessor if this is an error struct + else -> members.filter { "message" != symbolProvider.toMemberName(it) } + } protected val name: String = symbolProvider.toSymbol(shape).name fun render() { @@ -129,18 +131,19 @@ open class StructureGenerator( forEachMember(accessorMembers) { member, memberName, memberSymbol -> val memberType = memberSymbol.rustType() var unwrapOrDefault = false - val returnType = when { - // Automatically flatten vecs - structSettings.flattenVecAccessors && memberType is RustType.Option && memberType.stripOuter() is RustType.Vec -> { - unwrapOrDefault = true - memberType.stripOuter().asDeref().asRef() + val returnType = + when { + // Automatically flatten vecs + structSettings.flattenVecAccessors && memberType is RustType.Option && memberType.stripOuter() is RustType.Vec -> { + unwrapOrDefault = true + memberType.stripOuter().asDeref().asRef() + } + + memberType.isCopy() -> memberType + memberType is RustType.Option && memberType.member.isDeref() -> memberType.asDeref() + memberType.isDeref() -> memberType.asDeref().asRef() + else -> memberType.asRef() } - - memberType.isCopy() -> memberType - memberType is RustType.Option && memberType.member.isDeref() -> memberType.asDeref() - memberType.isDeref() -> memberType.asDeref().asRef() - else -> memberType.asRef() - } writer.renderMemberDoc(member, memberSymbol) if (unwrapOrDefault) { // Add a newline @@ -164,7 +167,12 @@ open class StructureGenerator( } } - open fun renderStructureMember(writer: RustWriter, member: MemberShape, memberName: String, memberSymbol: Symbol) { + open fun renderStructureMember( + writer: RustWriter, + member: MemberShape, + memberName: String, + memberSymbol: Symbol, + ) { writer.renderMemberDoc(member, memberSymbol) writer.deprecatedShape(member) memberSymbol.expectRustMetadata().render(writer) @@ -204,12 +212,16 @@ open class StructureGenerator( } } - private fun RustWriter.renderMemberDoc(member: MemberShape, memberSymbol: Symbol) { + private fun RustWriter.renderMemberDoc( + member: MemberShape, + memberSymbol: Symbol, + ) { documentShape( member, model, - note = memberSymbol.renamedFrom() - ?.let { oldName -> "This member has been renamed from `$oldName`." }, + note = + memberSymbol.renamedFrom() + ?.let { oldName -> "This member has been renamed from `$oldName`." }, ) } } @@ -219,10 +231,11 @@ open class StructureGenerator( * e.g. `<'a, 'b>` */ fun StructureShape.lifetimeDeclaration(symbolProvider: RustSymbolProvider): String { - val lifetimes = this.members() - .mapNotNull { symbolProvider.toSymbol(it).rustType().innerReference()?.let { it as RustType.Reference } } - .mapNotNull { it.lifetime } - .toSet().sorted() + val lifetimes = + this.members() + .mapNotNull { symbolProvider.toSymbol(it).rustType().innerReference()?.let { it as RustType.Reference } } + .mapNotNull { it.lifetime } + .toSet().sorted() return if (lifetimes.isNotEmpty()) { "<${lifetimes.joinToString { "'$it" }}>" } else { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGenerator.kt index 93bea4c0101..0b35ed99488 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGenerator.kt @@ -34,10 +34,11 @@ import software.amazon.smithy.rust.codegen.core.util.isTargetUnit import software.amazon.smithy.rust.codegen.core.util.shouldRedact import software.amazon.smithy.rust.codegen.core.util.toSnakeCase -fun CodegenTarget.renderUnknownVariant() = when (this) { - CodegenTarget.SERVER -> false - CodegenTarget.CLIENT -> true -} +fun CodegenTarget.renderUnknownVariant() = + when (this) { + CodegenTarget.SERVER -> false + CodegenTarget.CLIENT -> true + } /** * Generate an `enum` for a Smithy Union Shape @@ -177,7 +178,11 @@ fun unknownVariantError(union: String) = "The `Unknown` variant is intended for responses only. " + "It occurs when an outdated client is used after a new enum variant was added on the server side." -private fun RustWriter.renderVariant(symbolProvider: SymbolProvider, member: MemberShape, memberSymbol: Symbol) { +private fun RustWriter.renderVariant( + symbolProvider: SymbolProvider, + member: MemberShape, + memberSymbol: Symbol, +) { if (member.isTargetUnit()) { write("${symbolProvider.toMemberName(member)},") } else { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorImplGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorImplGenerator.kt index 049933bc459..bc6881521bf 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorImplGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorImplGenerator.kt @@ -108,21 +108,22 @@ class ErrorImplGenerator( val messageSymbol = symbolProvider.toSymbol(messageShape).mapRustType { t -> t.asDeref() } val messageType = messageSymbol.rustType() val memberName = symbolProvider.toMemberName(messageShape) - val (returnType, message) = if (messageType.stripOuter() is RustType.Opaque) { - // The string shape has a constraint trait that makes its symbol be a wrapper tuple struct. - if (messageSymbol.isOptional()) { - "Option<&${messageType.stripOuter().render()}>" to - "self.$memberName.as_ref()" - } else { - "&${messageType.render()}" to "&self.$memberName" - } - } else { - if (messageSymbol.isOptional()) { - messageType.render() to "self.$memberName.as_deref()" + val (returnType, message) = + if (messageType.stripOuter() is RustType.Opaque) { + // The string shape has a constraint trait that makes its symbol be a wrapper tuple struct. + if (messageSymbol.isOptional()) { + "Option<&${messageType.stripOuter().render()}>" to + "self.$memberName.as_ref()" + } else { + "&${messageType.render()}" to "&self.$memberName" + } } else { - messageType.render() to "self.$memberName.as_ref()" + if (messageSymbol.isOptional()) { + messageType.render() to "self.$memberName.as_deref()" + } else { + messageType.render() to "self.$memberName.as_ref()" + } } - } rust( """ @@ -153,9 +154,10 @@ class ErrorImplGenerator( rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") { // If the error id and the Rust name don't match, print the actual error id for easy debugging // Note: Exceptions cannot be renamed so it is OK to not call `getName(service)` here - val errorDesc = symbol.name.letIf(symbol.name != shape.id.name) { symbolName -> - "$symbolName [${shape.id.name}]" - } + val errorDesc = + symbol.name.letIf(symbol.name != shape.id.name) { symbolName -> + "$symbolName [${shape.id.name}]" + } write("::std::write!(f, ${errorDesc.dq()})?;") messageShape?.let { if (it.shouldRedact(model)) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt index 8ebf6082c86..68cdaadd1ff 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt @@ -75,7 +75,8 @@ import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary * - serializing data to an HTTP response (we are a server), */ enum class HttpMessageType { - REQUEST, RESPONSE + REQUEST, + RESPONSE, } /** @@ -164,16 +165,17 @@ class HttpBindingGenerator( val outputSymbol = symbolProvider.toSymbol(binding.member) val target = model.expectShape(binding.member.target) check(target is MapShape) - val inner = protocolFunctions.deserializeFn(binding.member, fnNameSuffix = "inner") { fnName -> - rustBlockTemplate( - "pub fn $fnName<'a>(headers: impl #{Iterator}) -> std::result::Result, #{header_util}::ParseError>", - *preludeScope, - "Value" to symbolProvider.toSymbol(model.expectShape(target.value.target)), - "header_util" to headerUtil, - ) { - deserializeFromHeader(model.expectShape(target.value.target), binding.member) + val inner = + protocolFunctions.deserializeFn(binding.member, fnNameSuffix = "inner") { fnName -> + rustBlockTemplate( + "pub fn $fnName<'a>(headers: impl #{Iterator}) -> std::result::Result, #{header_util}::ParseError>", + *preludeScope, + "Value" to symbolProvider.toSymbol(model.expectShape(target.value.target)), + "header_util" to headerUtil, + ) { + deserializeFromHeader(model.expectShape(target.value.target), binding.member) + } } - } val returnTypeSymbol = outputSymbol.mapRustType { it.asOptional() } return protocolFunctions.deserializeFn(binding.member, fnNameSuffix = "prefix_header") { fnName -> rustBlockTemplate( @@ -252,13 +254,18 @@ class HttpBindingGenerator( } } - private fun RustWriter.bindEventStreamOutput(operationShape: OperationShape, outputT: Symbol, targetShape: UnionShape) { - val unmarshallerConstructorFn = EventStreamUnmarshallerGenerator( - protocol, - codegenContext, - operationShape, - targetShape, - ).render() + private fun RustWriter.bindEventStreamOutput( + operationShape: OperationShape, + outputT: Symbol, + targetShape: UnionShape, + ) { + val unmarshallerConstructorFn = + EventStreamUnmarshallerGenerator( + protocol, + codegenContext, + operationShape, + targetShape, + ).render() rustTemplate( """ let unmarshaller = #{unmarshallerConstructorFn}(); @@ -267,17 +274,18 @@ class HttpBindingGenerator( """, "SdkBody" to RuntimeType.sdkBody(runtimeConfig), "unmarshallerConstructorFn" to unmarshallerConstructorFn, - "receiver" to writable { - if (codegenTarget == CodegenTarget.SERVER) { - rust("${outputT.rustType().qualifiedName()}::new(unmarshaller, body)") - } else { - rustTemplate( - "#{EventReceiver}::new(#{Receiver}::new(unmarshaller, body))", - "EventReceiver" to RuntimeType.eventReceiver(runtimeConfig), - "Receiver" to RuntimeType.eventStreamReceiver(runtimeConfig), - ) - } - }, + "receiver" to + writable { + if (codegenTarget == CodegenTarget.SERVER) { + rust("${outputT.rustType().qualifiedName()}::new(unmarshaller, body)") + } else { + rustTemplate( + "#{EventReceiver}::new(#{Receiver}::new(unmarshaller, body))", + "EventReceiver" to RuntimeType.eventReceiver(runtimeConfig), + "Receiver" to RuntimeType.eventStreamReceiver(runtimeConfig), + ) + } + }, ) } @@ -338,10 +346,11 @@ class HttpBindingGenerator( } } - is BlobShape -> rust( - "Ok(#T::new(body))", - symbolProvider.toSymbol(targetShape), - ) + is BlobShape -> + rust( + "Ok(#T::new(body))", + symbolProvider.toSymbol(targetShape), + ) // `httpPayload` can be applied to set/map/list shapes. // However, none of the AWS protocols support it. // Smithy CLI will refuse to build the model if you apply the trait to these shapes, so this branch @@ -355,7 +364,10 @@ class HttpBindingGenerator( * Parse a value from a header. * This function produces an expression which produces the precise type required by the target shape. */ - private fun RustWriter.deserializeFromHeader(targetShape: Shape, memberShape: MemberShape) { + private fun RustWriter.deserializeFromHeader( + targetShape: Shape, + memberShape: MemberShape, + ) { val rustType = symbolProvider.toSymbol(targetShape).rustType().stripOuter() // Normally, we go through a flow that looks for `,`s but that's wrong if the output // is just a single string (which might include `,`s.). @@ -364,12 +376,13 @@ class HttpBindingGenerator( rust("#T::one_or_none(headers)", headerUtil) return } - val (coreType, coreShape) = if (targetShape is CollectionShape) { - val coreShape = model.expectShape(targetShape.member.target) - symbolProvider.toSymbol(coreShape).rustType() to coreShape - } else { - rustType to targetShape - } + val (coreType, coreShape) = + if (targetShape is CollectionShape) { + val coreShape = model.expectShape(targetShape.member.target) + symbolProvider.toSymbol(coreShape).rustType() to coreShape + } else { + rustType to targetShape + } val parsedValue = safeName() if (coreShape.isTimestampShape()) { val timestampFormat = @@ -481,14 +494,17 @@ class HttpBindingGenerator( shape: Shape, httpMessageType: HttpMessageType = HttpMessageType.REQUEST, ): RuntimeType? { - val (headerBindings, prefixHeaderBinding) = when (httpMessageType) { - // Only a single structure member can be bound by `httpPrefixHeaders`, hence the `getOrNull(0)`. - HttpMessageType.REQUEST -> index.getRequestBindings(shape, HttpLocation.HEADER) to - index.getRequestBindings(shape, HttpLocation.PREFIX_HEADERS).getOrNull(0) - - HttpMessageType.RESPONSE -> index.getResponseBindings(shape, HttpLocation.HEADER) to - index.getResponseBindings(shape, HttpLocation.PREFIX_HEADERS).getOrNull(0) - } + val (headerBindings, prefixHeaderBinding) = + when (httpMessageType) { + // Only a single structure member can be bound by `httpPrefixHeaders`, hence the `getOrNull(0)`. + HttpMessageType.REQUEST -> + index.getRequestBindings(shape, HttpLocation.HEADER) to + index.getRequestBindings(shape, HttpLocation.PREFIX_HEADERS).getOrNull(0) + + HttpMessageType.RESPONSE -> + index.getResponseBindings(shape, HttpLocation.HEADER) to + index.getResponseBindings(shape, HttpLocation.PREFIX_HEADERS).getOrNull(0) + } if (headerBindings.isEmpty() && prefixHeaderBinding == null) { return null @@ -497,22 +513,24 @@ class HttpBindingGenerator( return protocolFunctions.serializeFn(shape, fnNameSuffix = "headers") { fnName -> // If the shape is an operation shape, the input symbol of the generated function is the input or output // shape, which is the shape holding the header-bound data. - val shapeSymbol = symbolProvider.toSymbol( - if (shape is OperationShape) { - when (httpMessageType) { - HttpMessageType.REQUEST -> shape.inputShape(model) - HttpMessageType.RESPONSE -> shape.outputShape(model) - } - } else { - shape - }, - ) - val codegenScope = arrayOf( - "BuildError" to runtimeConfig.operationBuildError(), - HttpMessageType.REQUEST.name to RuntimeType.HttpRequestBuilder, - HttpMessageType.RESPONSE.name to RuntimeType.HttpResponseBuilder, - "Shape" to shapeSymbol, - ) + val shapeSymbol = + symbolProvider.toSymbol( + if (shape is OperationShape) { + when (httpMessageType) { + HttpMessageType.REQUEST -> shape.inputShape(model) + HttpMessageType.RESPONSE -> shape.outputShape(model) + } + } else { + shape + }, + ) + val codegenScope = + arrayOf( + "BuildError" to runtimeConfig.operationBuildError(), + HttpMessageType.REQUEST.name to RuntimeType.HttpRequestBuilder, + HttpMessageType.RESPONSE.name to RuntimeType.HttpResponseBuilder, + "Shape" to shapeSymbol, + ) rustBlockTemplate( """ pub fn $fnName( @@ -642,13 +660,14 @@ class HttpBindingGenerator( val encoder = RuntimeType.smithyTypes(runtimeConfig).resolve("primitive::Encoder") rust("let mut encoder = #T::from(${variableName.asValue()});", encoder) } - val formatted = headerFmtFun( - this, - shape, - timestampFormat, - context.valueExpression.name, - isMultiValuedHeader = isMultiValuedHeader, - ) + val formatted = + headerFmtFun( + this, + shape, + timestampFormat, + context.valueExpression.name, + isMultiValuedHeader = isMultiValuedHeader, + ) val safeName = safeName("formatted") rustTemplate( """ @@ -712,20 +731,22 @@ class HttpBindingGenerator( """, "HeaderValue" to RuntimeType.Http.resolve("HeaderValue"), - "invalid_header_name" to OperationBuildError(runtimeConfig).invalidField(memberName) { - rust("""format!("`{k}` cannot be used as a header name: {err}")""") - }, - "invalid_header_value" to OperationBuildError(runtimeConfig).invalidField(memberName) { - rust( - """ - format!( - "`{}` cannot be used as a header value: {}", - ${memberShape.redactIfNecessary(model, "v")}, - err + "invalid_header_name" to + OperationBuildError(runtimeConfig).invalidField(memberName) { + rust("""format!("`{k}` cannot be used as a header name: {err}")""") + }, + "invalid_header_value" to + OperationBuildError(runtimeConfig).invalidField(memberName) { + rust( + """ + format!( + "`{}` cannot be used as a header value: {}", + ${memberShape.redactIfNecessary(model, "v")}, + err + ) + """, ) - """, - ) - }, + }, ) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolSupport.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolSupport.kt index c66dae4ce5b..b32b9c893c9 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolSupport.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolSupport.kt @@ -6,12 +6,12 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators.protocol data class ProtocolSupport( - /* Client support */ + // Client support val requestSerialization: Boolean, val requestBodySerialization: Boolean, val responseDeserialization: Boolean, val errorDeserialization: Boolean, - /* Server support */ + // Server support val requestDeserialization: Boolean, val requestBodyDeserialization: Boolean, val responseSerialization: Boolean, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt index 06b479506c1..1b54f4289f9 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt @@ -42,11 +42,12 @@ class AwsJsonHttpBindingResolver( private val model: Model, private val awsJsonVersion: AwsJsonVersion, ) : HttpBindingResolver { - private val httpTrait = HttpTrait.builder() - .code(200) - .method("POST") - .uri(UriPattern.parse("/")) - .build() + private val httpTrait = + HttpTrait.builder() + .code(200) + .method("POST") + .uri(UriPattern.parse("/")) + .build() private fun bindings(shape: ToShapeId): List { val members = shape.let { model.expectShape(it.toShapeId()) }.members() @@ -76,8 +77,7 @@ class AwsJsonHttpBindingResolver( override fun responseBindings(operationShape: OperationShape): List = bindings(operationShape.outputShape) - override fun errorResponseBindings(errorShape: ToShapeId): List = - bindings(errorShape) + override fun errorResponseBindings(errorShape: ToShapeId): List = bindings(errorShape) override fun requestContentType(operationShape: OperationShape): String = "application/x-amz-json-${awsJsonVersion.value}" @@ -96,24 +96,26 @@ class AwsJsonSerializerGenerator( JsonSerializerGenerator(codegenContext, httpBindingResolver, ::awsJsonFieldName), ) : StructuredDataSerializerGenerator by jsonSerializerGenerator { private val runtimeConfig = codegenContext.runtimeConfig - private val codegenScope = arrayOf( - "Error" to runtimeConfig.serializationError(), - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), - ) + private val codegenScope = + arrayOf( + "Error" to runtimeConfig.serializationError(), + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + ) private val protocolFunctions = ProtocolFunctions(codegenContext) override fun operationInputSerializer(operationShape: OperationShape): RuntimeType { var serializer = jsonSerializerGenerator.operationInputSerializer(operationShape) if (serializer == null) { val inputShape = operationShape.inputShape(codegenContext.model) - serializer = protocolFunctions.serializeFn(operationShape, fnNameSuffix = "input") { fnName -> - rustBlockTemplate( - "pub fn $fnName(_input: &#{target}) -> Result<#{SdkBody}, #{Error}>", - *codegenScope, "target" to codegenContext.symbolProvider.toSymbol(inputShape), - ) { - rustTemplate("""Ok(#{SdkBody}::from("{}"))""", *codegenScope) + serializer = + protocolFunctions.serializeFn(operationShape, fnNameSuffix = "input") { fnName -> + rustBlockTemplate( + "pub fn $fnName(_input: &#{target}) -> Result<#{SdkBody}, #{Error}>", + *codegenScope, "target" to codegenContext.symbolProvider.toSymbol(inputShape), + ) { + rustTemplate("""Ok(#{SdkBody}::from("{}"))""", *codegenScope) + } } - } } return serializer } @@ -124,14 +126,16 @@ open class AwsJson( val awsJsonVersion: AwsJsonVersion, ) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig - private val errorScope = arrayOf( - "Bytes" to RuntimeType.Bytes, - "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), - "Headers" to RuntimeType.headers(runtimeConfig), - "JsonError" to CargoDependency.smithyJson(runtimeConfig).toType() - .resolve("deserialize::error::DeserializeError"), - "json_errors" to RuntimeType.jsonErrors(runtimeConfig), - ) + private val errorScope = + arrayOf( + "Bytes" to RuntimeType.Bytes, + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), + "Headers" to RuntimeType.headers(runtimeConfig), + "JsonError" to + CargoDependency.smithyJson(runtimeConfig).toType() + .resolve("deserialize::error::DeserializeError"), + "json_errors" to RuntimeType.jsonErrors(runtimeConfig), + ) val version: AwsJsonVersion get() = awsJsonVersion diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt index da7ccd1ad12..7f62cd9f8af 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt @@ -22,11 +22,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.AwsQu import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.core.util.getTrait -private val awsQueryHttpTrait = HttpTrait.builder() - .code(200) - .method("POST") - .uri(UriPattern.parse("/")) - .build() +private val awsQueryHttpTrait = + HttpTrait.builder() + .code(200) + .method("POST") + .uri(UriPattern.parse("/")) + .build() class AwsQueryBindingResolver(private val model: Model) : StaticHttpBindingResolver(model, awsQueryHttpTrait, "application/x-www-form-urlencoded", "text/xml") { @@ -39,12 +40,13 @@ class AwsQueryBindingResolver(private val model: Model) : class AwsQueryProtocol(private val codegenContext: CodegenContext) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig private val awsQueryErrors: RuntimeType = RuntimeType.wrappedXmlErrors(runtimeConfig) - private val errorScope = arrayOf( - "Bytes" to RuntimeType.Bytes, - "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), - "Headers" to RuntimeType.headers(runtimeConfig), - "XmlDecodeError" to RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"), - ) + private val errorScope = + arrayOf( + "Bytes" to RuntimeType.Bytes, + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), + "Headers" to RuntimeType.headers(runtimeConfig), + "XmlDecodeError" to RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"), + ) override val httpBindingResolver: HttpBindingResolver = AwsQueryBindingResolver(codegenContext.model) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt index 61861a369a7..7b15e81051d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt @@ -31,8 +31,7 @@ class AwsQueryCompatibleHttpBindingResolver( override fun errorResponseBindings(errorShape: ToShapeId): List = awsJsonHttpBindingResolver.errorResponseBindings(errorShape) - override fun errorCode(errorShape: ToShapeId): String = - awsQueryBindingResolver.errorCode(errorShape) + override fun errorCode(errorShape: ToShapeId): String = awsQueryBindingResolver.errorCode(errorShape) override fun requestContentType(operationShape: OperationShape): String = awsJsonHttpBindingResolver.requestContentType(operationShape) @@ -46,15 +45,17 @@ class AwsQueryCompatible( private val awsJson: AwsJson, ) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig - private val errorScope = arrayOf( - "Bytes" to RuntimeType.Bytes, - "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), - "Headers" to RuntimeType.headers(runtimeConfig), - "JsonError" to CargoDependency.smithyJson(runtimeConfig).toType() - .resolve("deserialize::error::DeserializeError"), - "aws_query_compatible_errors" to RuntimeType.awsQueryCompatibleErrors(runtimeConfig), - "json_errors" to RuntimeType.jsonErrors(runtimeConfig), - ) + private val errorScope = + arrayOf( + "Bytes" to RuntimeType.Bytes, + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), + "Headers" to RuntimeType.headers(runtimeConfig), + "JsonError" to + CargoDependency.smithyJson(runtimeConfig).toType() + .resolve("deserialize::error::DeserializeError"), + "aws_query_compatible_errors" to RuntimeType.awsQueryCompatibleErrors(runtimeConfig), + "json_errors" to RuntimeType.jsonErrors(runtimeConfig), + ) override val httpBindingResolver: HttpBindingResolver = AwsQueryCompatibleHttpBindingResolver( @@ -64,11 +65,9 @@ class AwsQueryCompatible( override val defaultTimestampFormat = awsJson.defaultTimestampFormat - override fun structuredDataParser(): StructuredDataParserGenerator = - awsJson.structuredDataParser() + override fun structuredDataParser(): StructuredDataParserGenerator = awsJson.structuredDataParser() - override fun structuredDataSerializer(): StructuredDataSerializerGenerator = - awsJson.structuredDataSerializer() + override fun structuredDataSerializer(): StructuredDataSerializerGenerator = awsJson.structuredDataSerializer() override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = ProtocolFunctions.crossOperationFn("parse_http_error_metadata") { fnName -> diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt index 691736780b9..439cf35f031 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt @@ -21,23 +21,25 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.Struc class Ec2QueryProtocol(private val codegenContext: CodegenContext) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig private val ec2QueryErrors: RuntimeType = RuntimeType.ec2QueryErrors(runtimeConfig) - private val errorScope = arrayOf( - "Bytes" to RuntimeType.Bytes, - "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), - "Headers" to RuntimeType.headers(runtimeConfig), - "XmlDecodeError" to RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"), - ) + private val errorScope = + arrayOf( + "Bytes" to RuntimeType.Bytes, + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), + "Headers" to RuntimeType.headers(runtimeConfig), + "XmlDecodeError" to RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"), + ) - override val httpBindingResolver: HttpBindingResolver = StaticHttpBindingResolver( - codegenContext.model, - HttpTrait.builder() - .code(200) - .method("POST") - .uri(UriPattern.parse("/")) - .build(), - "application/x-www-form-urlencoded", - "text/xml", - ) + override val httpBindingResolver: HttpBindingResolver = + StaticHttpBindingResolver( + codegenContext.model, + HttpTrait.builder() + .code(200) + .method("POST") + .uri(UriPattern.parse("/")) + .build(), + "application/x-www-form-urlencoded", + "text/xml", + ) override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt index dbfe3de6103..8eb693c41fb 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt @@ -62,14 +62,18 @@ interface HttpBindingResolver { /** * Returns a list of member shapes bound to a given request [location] for a given [operationShape] */ - fun requestMembers(operationShape: OperationShape, location: HttpLocation): List = - requestBindings(operationShape).filter { it.location == location }.map { it.member } + fun requestMembers( + operationShape: OperationShape, + location: HttpLocation, + ): List = requestBindings(operationShape).filter { it.location == location }.map { it.member } /** * Returns a list of member shapes bound to a given response [location] for a given [operationShape] */ - fun responseMembers(operationShape: OperationShape, location: HttpLocation): List = - responseBindings(operationShape).filter { it.location == location }.map { it.member } + fun responseMembers( + operationShape: OperationShape, + location: HttpLocation, + ): List = responseBindings(operationShape).filter { it.location == location }.map { it.member } /** * Determine the timestamp format based on the input parameters. @@ -138,8 +142,7 @@ open class HttpTraitHttpBindingResolver( location: HttpLocation, defaultTimestampFormat: TimestampFormatTrait.Format, model: Model, - ): TimestampFormatTrait.Format = - httpIndex.determineTimestampFormat(memberShape, location, defaultTimestampFormat) + ): TimestampFormatTrait.Format = httpIndex.determineTimestampFormat(memberShape, location, defaultTimestampFormat) override fun requestContentType(operationShape: OperationShape): String? = httpIndex.determineRequestContentType( @@ -184,8 +187,7 @@ open class StaticHttpBindingResolver( override fun responseBindings(operationShape: OperationShape): List = bindings(operationShape.output.orNull()) - override fun errorResponseBindings(errorShape: ToShapeId): List = - bindings(errorShape) + override fun errorResponseBindings(errorShape: ToShapeId): List = bindings(errorShape) override fun requestContentType(operationShape: OperationShape): String = requestContentType diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index 5b1d9df988a..a8fa69a6c27 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -62,26 +62,30 @@ class HttpBoundProtocolPayloadGenerator( private val target = codegenContext.target private val httpBindingResolver = protocol.httpBindingResolver private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig) - private val codegenScope = arrayOf( - "hyper" to CargoDependency.HyperWithStream.toType(), - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), - "BuildError" to runtimeConfig.operationBuildError(), - "SmithyHttp" to RuntimeType.smithyHttp(runtimeConfig), - "NoOpSigner" to smithyEventStream.resolve("frame::NoOpSigner"), - ) + private val codegenScope = + arrayOf( + "hyper" to CargoDependency.HyperWithStream.toType(), + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + "BuildError" to runtimeConfig.operationBuildError(), + "SmithyHttp" to RuntimeType.smithyHttp(runtimeConfig), + "NoOpSigner" to smithyEventStream.resolve("frame::NoOpSigner"), + ) private val protocolFunctions = ProtocolFunctions(codegenContext) override fun payloadMetadata( operationShape: OperationShape, additionalPayloadContext: AdditionalPayloadContext, ): ProtocolPayloadGenerator.PayloadMetadata { - val (shape, payloadMemberName) = when (httpMessageType) { - HttpMessageType.RESPONSE -> operationShape.outputShape(model) to - httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName + val (shape, payloadMemberName) = + when (httpMessageType) { + HttpMessageType.RESPONSE -> + operationShape.outputShape(model) to + httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName - HttpMessageType.REQUEST -> operationShape.inputShape(model) to - httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName - } + HttpMessageType.REQUEST -> + operationShape.inputShape(model) to + httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName + } // Only: // - streaming operations (blob streaming and event streams), @@ -95,9 +99,10 @@ class HttpBoundProtocolPayloadGenerator( } else { val member = shape.expectMember(payloadMemberName) when (val type = model.expectShape(member.target)) { - is DocumentShape, is StructureShape, is UnionShape -> ProtocolPayloadGenerator.PayloadMetadata( - takesOwnership = false, - ) + is DocumentShape, is StructureShape, is UnionShape -> + ProtocolPayloadGenerator.PayloadMetadata( + takesOwnership = false, + ) is StringShape, is BlobShape -> ProtocolPayloadGenerator.PayloadMetadata(takesOwnership = true) else -> UNREACHABLE("Unexpected payload target type: $type") @@ -112,24 +117,28 @@ class HttpBoundProtocolPayloadGenerator( additionalPayloadContext: AdditionalPayloadContext, ) { when (httpMessageType) { - HttpMessageType.RESPONSE -> generateResponsePayload( - writer, - shapeName, - operationShape, - additionalPayloadContext, - ) + HttpMessageType.RESPONSE -> + generateResponsePayload( + writer, + shapeName, + operationShape, + additionalPayloadContext, + ) - HttpMessageType.REQUEST -> generateRequestPayload( - writer, - shapeName, - operationShape, - additionalPayloadContext, - ) + HttpMessageType.REQUEST -> + generateRequestPayload( + writer, + shapeName, + operationShape, + additionalPayloadContext, + ) } } private fun generateRequestPayload( - writer: RustWriter, shapeName: String, operationShape: OperationShape, + writer: RustWriter, + shapeName: String, + operationShape: OperationShape, additionalPayloadContext: AdditionalPayloadContext, ) { val payloadMemberName = @@ -150,7 +159,9 @@ class HttpBoundProtocolPayloadGenerator( } private fun generateResponsePayload( - writer: RustWriter, shapeName: String, operationShape: OperationShape, + writer: RustWriter, + shapeName: String, + operationShape: OperationShape, additionalPayloadContext: AdditionalPayloadContext, ) { val payloadMemberName = @@ -203,15 +214,20 @@ class HttpBoundProtocolPayloadGenerator( ) } else { val bodyMetadata = payloadMetadata(operationShape) - val payloadMember = when (httpMessageType) { - HttpMessageType.RESPONSE -> operationShape.outputShape(model).expectMember(payloadMemberName) - HttpMessageType.REQUEST -> operationShape.inputShape(model).expectMember(payloadMemberName) - } + val payloadMember = + when (httpMessageType) { + HttpMessageType.RESPONSE -> operationShape.outputShape(model).expectMember(payloadMemberName) + HttpMessageType.REQUEST -> operationShape.inputShape(model).expectMember(payloadMemberName) + } writer.serializeViaPayload(bodyMetadata, shapeName, payloadMember, serializerGenerator) } } - private fun generateStructureSerializer(writer: RustWriter, shapeName: String, serializer: RuntimeType?) { + private fun generateStructureSerializer( + writer: RustWriter, + shapeName: String, + serializer: RuntimeType?, + ) { if (serializer == null) { writer.rust("\"\"") } else { @@ -232,28 +248,31 @@ class HttpBoundProtocolPayloadGenerator( val memberName = symbolProvider.toMemberName(memberShape) val unionShape = model.expectShape(memberShape.target, UnionShape::class.java) - val contentType = when (target) { - CodegenTarget.CLIENT -> httpBindingResolver.requestContentType(operationShape) - CodegenTarget.SERVER -> httpBindingResolver.responseContentType(operationShape) - } - val errorMarshallerConstructorFn = EventStreamErrorMarshallerGenerator( - model, - target, - runtimeConfig, - symbolProvider, - unionShape, - serializerGenerator, - contentType ?: throw CodegenException("event streams must set a content type"), - ).render() - val marshallerConstructorFn = EventStreamMarshallerGenerator( - model, - target, - runtimeConfig, - symbolProvider, - unionShape, - serializerGenerator, - contentType, - ).render() + val contentType = + when (target) { + CodegenTarget.CLIENT -> httpBindingResolver.requestContentType(operationShape) + CodegenTarget.SERVER -> httpBindingResolver.responseContentType(operationShape) + } + val errorMarshallerConstructorFn = + EventStreamErrorMarshallerGenerator( + model, + target, + runtimeConfig, + symbolProvider, + unionShape, + serializerGenerator, + contentType ?: throw CodegenException("event streams must set a content type"), + ).render() + val marshallerConstructorFn = + EventStreamMarshallerGenerator( + model, + target, + runtimeConfig, + symbolProvider, + unionShape, + serializerGenerator, + contentType, + ).render() // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the // parameters that are not `@eventHeader` or `@eventPayload`. @@ -276,50 +295,53 @@ class HttpBoundProtocolPayloadGenerator( serializerGenerator: StructuredDataSerializerGenerator, ) { val ref = if (payloadMetadata.takesOwnership) "" else "&" - val serializer = protocolFunctions.serializeFn(member, fnNameSuffix = "http_payload") { fnName -> - val outputT = if (member.isStreaming(model)) { - symbolProvider.toSymbol(member) - } else { - RuntimeType.ByteSlab.toSymbol() - } - rustBlockTemplate( - "pub fn $fnName(payload: $ref#{Member}) -> Result<#{outputT}, #{BuildError}>", - "Member" to symbolProvider.toSymbol(member), - "outputT" to outputT, - *codegenScope, - ) { - val asRef = if (payloadMetadata.takesOwnership) "" else ".as_ref()" + val serializer = + protocolFunctions.serializeFn(member, fnNameSuffix = "http_payload") { fnName -> + val outputT = + if (member.isStreaming(model)) { + symbolProvider.toSymbol(member) + } else { + RuntimeType.ByteSlab.toSymbol() + } + rustBlockTemplate( + "pub fn $fnName(payload: $ref#{Member}) -> Result<#{outputT}, #{BuildError}>", + "Member" to symbolProvider.toSymbol(member), + "outputT" to outputT, + *codegenScope, + ) { + val asRef = if (payloadMetadata.takesOwnership) "" else ".as_ref()" - if (symbolProvider.toSymbol(member).isOptional()) { - withBlockTemplate( - """ - let payload = match payload$asRef { - Some(t) => t, - None => return Ok( - """, - ")};", - *codegenScope, - ) { - when (val targetShape = model.expectShape(member.target)) { - // Return an empty `Vec`. - is StringShape, is BlobShape, is DocumentShape -> rust( - """ - Vec::new() - """, - ) + if (symbolProvider.toSymbol(member).isOptional()) { + withBlockTemplate( + """ + let payload = match payload$asRef { + Some(t) => t, + None => return Ok( + """, + ")};", + *codegenScope, + ) { + when (val targetShape = model.expectShape(member.target)) { + // Return an empty `Vec`. + is StringShape, is BlobShape, is DocumentShape -> + rust( + """ + Vec::new() + """, + ) - is StructureShape -> rust("#T()", serializerGenerator.unsetStructure(targetShape)) - is UnionShape -> rust("#T()", serializerGenerator.unsetUnion(targetShape)) - else -> throw CodegenException("`httpPayload` on member shapes targeting shapes of type ${targetShape.type} is unsupported") + is StructureShape -> rust("#T()", serializerGenerator.unsetStructure(targetShape)) + is UnionShape -> rust("#T()", serializerGenerator.unsetUnion(targetShape)) + else -> throw CodegenException("`httpPayload` on member shapes targeting shapes of type ${targetShape.type} is unsupported") + } } } - } - withBlock("Ok(", ")") { - renderPayload(member, "payload", serializerGenerator) + withBlock("Ok(", ")") { + renderPayload(member, "payload", serializerGenerator) + } } } - } rust("#T($ref $shapeName.${symbolProvider.toMemberName(member)})?", serializer) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt index 4a1339ca9aa..236c297db92 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt @@ -67,6 +67,8 @@ typealias ProtocolMap = Map> interface ProtocolGeneratorFactory { fun protocol(codegenContext: C): Protocol + fun buildProtocolGenerator(codegenContext: C): T + fun support(): ProtocolSupport } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt index c468fdca0bc..e40046f1d84 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctions.kt @@ -41,7 +41,10 @@ class ProtocolFunctions( companion object { private val serDeModule = RustModule.pubCrate("protocol_serde") - fun crossOperationFn(fnName: String, block: ProtocolFnWritable): RuntimeType = + fun crossOperationFn( + fnName: String, + block: ProtocolFnWritable, + ): RuntimeType = RuntimeType.forInlineFun(fnName, serDeModule) { block(fnName) } @@ -100,12 +103,13 @@ class ProtocolFunctions( val moduleName = codegenContext.symbolProvider.shapeModuleName(codegenContext.serviceShape, shape) val fnBaseName = codegenContext.symbolProvider.shapeFunctionName(codegenContext.serviceShape, shape) val suffix = fnNameSuffix?.let { "_$it" } ?: "" - val fnName = RustReservedWords.escapeIfNeeded( - when (fnType) { - FnType.Deserialize -> "de_$fnBaseName$suffix" - FnType.Serialize -> "ser_$fnBaseName$suffix" - }, - ) + val fnName = + RustReservedWords.escapeIfNeeded( + when (fnType) { + FnType.Deserialize -> "de_$fnBaseName$suffix" + FnType.Serialize -> "ser_$fnBaseName$suffix" + }, + ) return serDeFn(moduleName, fnName, parentModule, block) } @@ -115,12 +119,13 @@ class ProtocolFunctions( parentModule: RustModule.LeafModule, block: ProtocolFnWritable, ): RuntimeType { - val additionalAttributes = when { - // Some SDK models have maps with names prefixed with `__mapOf__`, which become `__map_of__`, - // and the Rust compiler warning doesn't like multiple adjacent underscores. - moduleName.contains("__") || fnName.contains("__") -> listOf(Attribute.AllowNonSnakeCase) - else -> emptyList() - } + val additionalAttributes = + when { + // Some SDK models have maps with names prefixed with `__mapOf__`, which become `__map_of__`, + // and the Rust compiler warning doesn't like multiple adjacent underscores. + moduleName.contains("__") || fnName.contains("__") -> listOf(Attribute.AllowNonSnakeCase) + else -> emptyList() + } return RuntimeType.forInlineFun( fnName, RustModule.pubCrate(moduleName, parent = parentModule, additionalAttributes = additionalAttributes), @@ -131,7 +136,10 @@ class ProtocolFunctions( } /** Creates a module name for a ser/de function. */ -internal fun RustSymbolProvider.shapeModuleName(serviceShape: ServiceShape?, shape: Shape): String = +internal fun RustSymbolProvider.shapeModuleName( + serviceShape: ServiceShape?, + shape: Shape, +): String = RustReservedWords.escapeIfNeeded( "shape_" + when (shape) { @@ -142,14 +150,19 @@ internal fun RustSymbolProvider.shapeModuleName(serviceShape: ServiceShape?, sha ) /** Creates a unique name for a ser/de function. */ -fun RustSymbolProvider.shapeFunctionName(serviceShape: ServiceShape?, shape: Shape): String { - val extras = "".letIf(shape.hasTrait()) { - it + "_output" - }.letIf(shape.hasTrait()) { it + "_input" } - val containerName = when (shape) { - is MemberShape -> model.expectShape(shape.container).contextName(serviceShape).toSnakeCase() - else -> shape.contextName(serviceShape).toSnakeCase() - } + extras +fun RustSymbolProvider.shapeFunctionName( + serviceShape: ServiceShape?, + shape: Shape, +): String { + val extras = + "".letIf(shape.hasTrait()) { + it + "_output" + }.letIf(shape.hasTrait()) { it + "_input" } + val containerName = + when (shape) { + is MemberShape -> model.expectShape(shape.container).contextName(serviceShape).toSnakeCase() + else -> shape.contextName(serviceShape).toSnakeCase() + } + extras return when (shape) { is MemberShape -> shape.memberName.toSnakeCase() is DocumentShape -> "document" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt index 3e08ceca8e2..d839bac3c8d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt @@ -42,9 +42,10 @@ class RestJsonHttpBindingResolver( * overridden by a specific mechanism e.g. an output shape member is targeted with `httpPayload` or `mediaType` traits. */ override fun responseContentType(operationShape: OperationShape): String? { - val members = operationShape - .outputShape(model) - .members() + val members = + operationShape + .outputShape(model) + .members() // TODO(https://github.com/awslabs/smithy/issues/1259) // Temporary fix for https://github.com/awslabs/smithy/blob/df456a514f72f4e35f0fb07c7e26006ff03b2071/smithy-model/src/main/java/software/amazon/smithy/model/knowledge/HttpBindingIndex.java#L352 for (member in members) { @@ -61,14 +62,16 @@ class RestJsonHttpBindingResolver( open class RestJson(val codegenContext: CodegenContext) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig - private val errorScope = arrayOf( - "Bytes" to RuntimeType.Bytes, - "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), - "Headers" to RuntimeType.headers(runtimeConfig), - "JsonError" to CargoDependency.smithyJson(runtimeConfig).toType() - .resolve("deserialize::error::DeserializeError"), - "json_errors" to RuntimeType.jsonErrors(runtimeConfig), - ) + private val errorScope = + arrayOf( + "Bytes" to RuntimeType.Bytes, + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), + "Headers" to RuntimeType.headers(runtimeConfig), + "JsonError" to + CargoDependency.smithyJson(runtimeConfig).toType() + .resolve("deserialize::error::DeserializeError"), + "json_errors" to RuntimeType.jsonErrors(runtimeConfig), + ) override val httpBindingResolver: HttpBindingResolver = RestJsonHttpBindingResolver(codegenContext.model, ProtocolContentTypes("application/json", "application/json", "application/vnd.amazon.eventstream")) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt index 15d294c0356..700fe60775c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt @@ -21,17 +21,19 @@ import software.amazon.smithy.rust.codegen.core.util.expectTrait open class RestXml(val codegenContext: CodegenContext) : Protocol { private val restXml = codegenContext.serviceShape.expectTrait() private val runtimeConfig = codegenContext.runtimeConfig - private val errorScope = arrayOf( - "Bytes" to RuntimeType.Bytes, - "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), - "Headers" to RuntimeType.headers(runtimeConfig), - "XmlDecodeError" to RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"), - ) + private val errorScope = + arrayOf( + "Bytes" to RuntimeType.Bytes, + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), + "Headers" to RuntimeType.headers(runtimeConfig), + "XmlDecodeError" to RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"), + ) - protected val restXmlErrors: RuntimeType = when (restXml.isNoErrorWrapping) { - true -> RuntimeType.unwrappedXmlErrors(runtimeConfig) - false -> RuntimeType.wrappedXmlErrors(runtimeConfig) - } + protected val restXmlErrors: RuntimeType = + when (restXml.isNoErrorWrapping) { + true -> RuntimeType.unwrappedXmlErrors(runtimeConfig) + false -> RuntimeType.wrappedXmlErrors(runtimeConfig) + } override val httpBindingResolver: HttpBindingResolver = HttpTraitHttpBindingResolver(codegenContext.model, ProtocolContentTypes("application/xml", "application/xml", "application/vnd.amazon.eventstream")) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/XmlNameIndex.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/XmlNameIndex.kt index e3f1be1ff7e..4495f6f92de 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/XmlNameIndex.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/XmlNameIndex.kt @@ -67,5 +67,6 @@ data class XmlMemberIndex(val dataMembers: List, val attributeMembe } fun isEmpty() = dataMembers.isEmpty() && attributeMembers.isEmpty() + fun isNotEmpty() = !isEmpty() } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt index 47aadd368d3..b114911d926 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt @@ -43,8 +43,7 @@ import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.toPascalCase -fun RustModule.Companion.eventStreamSerdeModule(): RustModule.LeafModule = - private("event_stream_serde") +fun RustModule.Companion.eventStreamSerdeModule(): RustModule.LeafModule = private("event_stream_serde") class EventStreamUnmarshallerGenerator( private val protocol: Protocol, @@ -58,28 +57,30 @@ class EventStreamUnmarshallerGenerator( private val codegenTarget = codegenContext.target private val runtimeConfig = codegenContext.runtimeConfig private val unionSymbol = symbolProvider.toSymbol(unionShape) - private val errorSymbol = if (codegenTarget == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { - RuntimeType.smithyHttp(runtimeConfig).resolve("event_stream::MessageStreamError").toSymbol() - } else { - symbolProvider.symbolForEventStreamError(unionShape) - } + private val errorSymbol = + if (codegenTarget == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { + RuntimeType.smithyHttp(runtimeConfig).resolve("event_stream::MessageStreamError").toSymbol() + } else { + symbolProvider.symbolForEventStreamError(unionShape) + } private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig) private val smithyTypes = RuntimeType.smithyTypes(runtimeConfig) private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule() - private val codegenScope = arrayOf( - "Blob" to RuntimeType.blob(runtimeConfig), - "expect_fns" to smithyEventStream.resolve("smithy"), - "MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"), - "Message" to smithyTypes.resolve("event_stream::Message"), - "Header" to smithyTypes.resolve("event_stream::Header"), - "HeaderValue" to smithyTypes.resolve("event_stream::HeaderValue"), - "Error" to smithyEventStream.resolve("error::Error"), - "OpError" to errorSymbol, - "SmithyError" to RuntimeType.smithyTypes(runtimeConfig).resolve("Error"), - "tracing" to RuntimeType.Tracing, - "UnmarshalledMessage" to smithyEventStream.resolve("frame::UnmarshalledMessage"), - "UnmarshallMessage" to smithyEventStream.resolve("frame::UnmarshallMessage"), - ) + private val codegenScope = + arrayOf( + "Blob" to RuntimeType.blob(runtimeConfig), + "expect_fns" to smithyEventStream.resolve("smithy"), + "MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"), + "Message" to smithyTypes.resolve("event_stream::Message"), + "Header" to smithyTypes.resolve("event_stream::Header"), + "HeaderValue" to smithyTypes.resolve("event_stream::HeaderValue"), + "Error" to smithyEventStream.resolve("error::Error"), + "OpError" to errorSymbol, + "SmithyError" to RuntimeType.smithyTypes(runtimeConfig).resolve("Error"), + "tracing" to RuntimeType.Tracing, + "UnmarshalledMessage" to smithyEventStream.resolve("frame::UnmarshalledMessage"), + "UnmarshallMessage" to smithyEventStream.resolve("frame::UnmarshallMessage"), + ) fun render(): RuntimeType { val unmarshallerType = unionShape.eventStreamUnmarshallerType() @@ -88,7 +89,10 @@ class EventStreamUnmarshallerGenerator( } } - private fun RustWriter.renderUnmarshaller(unmarshallerType: RuntimeType, unionSymbol: Symbol) { + private fun RustWriter.renderUnmarshaller( + unmarshallerType: RuntimeType, + unionSymbol: Symbol, + ) { val unmarshallerTypeName = unmarshallerType.name rust( """ @@ -139,11 +143,12 @@ class EventStreamUnmarshallerGenerator( } } - private fun expectedContentType(payloadTarget: Shape): String? = when (payloadTarget) { - is BlobShape -> "application/octet-stream" - is StringShape -> "text/plain" - else -> null - } + private fun expectedContentType(payloadTarget: Shape): String? = + when (payloadTarget) { + is BlobShape -> "application/octet-stream" + is StringShape -> "text/plain" + else -> null + } private fun RustWriter.renderUnmarshallEvent() { rustBlock("match response_headers.smithy_type.as_str()") { @@ -155,22 +160,27 @@ class EventStreamUnmarshallerGenerator( } rustBlock("_unknown_variant => ") { when (codegenTarget.renderUnknownVariant()) { - true -> rustTemplate( - "Ok(#{UnmarshalledMessage}::Event(#{Output}::${UnionGenerator.UnknownVariantName}))", - "Output" to unionSymbol, - *codegenScope, - ) + true -> + rustTemplate( + "Ok(#{UnmarshalledMessage}::Event(#{Output}::${UnionGenerator.UnknownVariantName}))", + "Output" to unionSymbol, + *codegenScope, + ) - false -> rustTemplate( - "return Err(#{Error}::unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));", - *codegenScope, - ) + false -> + rustTemplate( + "return Err(#{Error}::unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));", + *codegenScope, + ) } } } } - private fun RustWriter.renderUnmarshallUnionMember(unionMember: MemberShape, unionStruct: StructureShape) { + private fun RustWriter.renderUnmarshallUnionMember( + unionMember: MemberShape, + unionStruct: StructureShape, + ) { val unionMemberName = symbolProvider.toMemberName(unionMember) val empty = unionStruct.members().isEmpty() val payloadOnly = @@ -330,11 +340,12 @@ class EventStreamUnmarshallerGenerator( val syntheticUnion = unionShape.expectTrait() if (syntheticUnion.errorMembers.isNotEmpty()) { // clippy::single-match implied, using if when there's only one error - val (header, matchOperator) = if (syntheticUnion.errorMembers.size > 1) { - listOf("match response_headers.smithy_type.as_str() {", "=>") - } else { - listOf("if response_headers.smithy_type.as_str() == ", "") - } + val (header, matchOperator) = + if (syntheticUnion.errorMembers.size > 1) { + listOf("match response_headers.smithy_type.as_str() {", "=>") + } else { + listOf("if response_headers.smithy_type.as_str() == ", "") + } rust(header) for (member in syntheticUnion.errorMembers) { rustBlock("${member.memberName.dq()} $matchOperator ") { @@ -360,14 +371,15 @@ class EventStreamUnmarshallerGenerator( ) )) """, - "build" to builderInstantiator.finalizeBuilder( - "builder", target, - mapErr = { - rustTemplate( - """|err|#{Error}::unmarshalling(format!("{}", err))""", *codegenScope, - ) - }, - ), + "build" to + builderInstantiator.finalizeBuilder( + "builder", target, + mapErr = { + rustTemplate( + """|err|#{Error}::unmarshalling(format!("{}", err))""", *codegenScope, + ) + }, + ), "parser" to parser, *codegenScope, ) @@ -377,7 +389,12 @@ class EventStreamUnmarshallerGenerator( CodegenTarget.SERVER -> { val target = model.expectShape(member.target, StructureShape::class.java) val parser = protocol.structuredDataParser().errorParser(target) - val mut = if (parser != null) { " mut" } else { "" } + val mut = + if (parser != null) { + " mut" + } else { + "" + } rust("let$mut builder = #T::default();", symbolProvider.symbolForBuilder(target)) if (parser != null) { rustTemplate( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 6b0a7b507bd..31a1b0c7b54 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -104,23 +104,24 @@ class JsonParserGenerator( private val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType() private val protocolFunctions = ProtocolFunctions(codegenContext) private val builderInstantiator = codegenContext.builderInstantiator() - private val codegenScope = arrayOf( - "Error" to smithyJson.resolve("deserialize::error::DeserializeError"), - "expect_blob_or_null" to smithyJson.resolve("deserialize::token::expect_blob_or_null"), - "expect_bool_or_null" to smithyJson.resolve("deserialize::token::expect_bool_or_null"), - "expect_document" to smithyJson.resolve("deserialize::token::expect_document"), - "expect_number_or_null" to smithyJson.resolve("deserialize::token::expect_number_or_null"), - "expect_start_array" to smithyJson.resolve("deserialize::token::expect_start_array"), - "expect_start_object" to smithyJson.resolve("deserialize::token::expect_start_object"), - "expect_string_or_null" to smithyJson.resolve("deserialize::token::expect_string_or_null"), - "expect_timestamp_or_null" to smithyJson.resolve("deserialize::token::expect_timestamp_or_null"), - "json_token_iter" to smithyJson.resolve("deserialize::json_token_iter"), - "Peekable" to RuntimeType.std.resolve("iter::Peekable"), - "skip_value" to smithyJson.resolve("deserialize::token::skip_value"), - "skip_to_end" to smithyJson.resolve("deserialize::token::skip_to_end"), - "Token" to smithyJson.resolve("deserialize::Token"), - "or_empty" to orEmptyJson(), - ) + private val codegenScope = + arrayOf( + "Error" to smithyJson.resolve("deserialize::error::DeserializeError"), + "expect_blob_or_null" to smithyJson.resolve("deserialize::token::expect_blob_or_null"), + "expect_bool_or_null" to smithyJson.resolve("deserialize::token::expect_bool_or_null"), + "expect_document" to smithyJson.resolve("deserialize::token::expect_document"), + "expect_number_or_null" to smithyJson.resolve("deserialize::token::expect_number_or_null"), + "expect_start_array" to smithyJson.resolve("deserialize::token::expect_start_array"), + "expect_start_object" to smithyJson.resolve("deserialize::token::expect_start_object"), + "expect_string_or_null" to smithyJson.resolve("deserialize::token::expect_string_or_null"), + "expect_timestamp_or_null" to smithyJson.resolve("deserialize::token::expect_timestamp_or_null"), + "json_token_iter" to smithyJson.resolve("deserialize::json_token_iter"), + "Peekable" to RuntimeType.std.resolve("iter::Peekable"), + "skip_value" to smithyJson.resolve("deserialize::token::skip_value"), + "skip_to_end" to smithyJson.resolve("deserialize::token::skip_to_end"), + "Token" to smithyJson.resolve("deserialize::Token"), + "or_empty" to orEmptyJson(), + ) /** * Reusable structure parser implementation that can be used to generate parsing code for @@ -168,11 +169,12 @@ class JsonParserGenerator( *codegenScope, "ReturnType" to returnSymbolToParse.symbol, ) { - val input = if (shape is DocumentShape) { - "input" - } else { - "#{or_empty}(input)" - } + val input = + if (shape is DocumentShape) { + "input" + } else { + "#{or_empty}(input)" + } rustTemplate( """ @@ -212,19 +214,20 @@ class JsonParserGenerator( ) } - private fun orEmptyJson(): RuntimeType = ProtocolFunctions.crossOperationFn("or_empty_doc") { - rust( - """ - pub(crate) fn or_empty_doc(data: &[u8]) -> &[u8] { - if data.is_empty() { - b"{}" - } else { - data + private fun orEmptyJson(): RuntimeType = + ProtocolFunctions.crossOperationFn("or_empty_doc") { + rust( + """ + pub(crate) fn or_empty_doc(data: &[u8]) -> &[u8] { + if data.is_empty() { + b"{}" + } else { + data + } } - } - """, - ) - } + """, + ) + } override fun serverInputParser(operationShape: OperationShape): RuntimeType? { val includedMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) @@ -321,7 +324,10 @@ class JsonParserGenerator( } } - private fun RustWriter.deserializeStringInner(target: StringShape, escapedStrName: String) { + private fun RustWriter.deserializeStringInner( + target: StringShape, + escapedStrName: String, + ) { withBlock("$escapedStrName.to_unescaped().map(|u|", ")") { when (target.hasTrait()) { true -> { @@ -380,60 +386,61 @@ class JsonParserGenerator( private fun RustWriter.deserializeCollection(shape: CollectionShape) { val isSparse = shape.hasTrait() val (returnSymbol, returnUnconstrainedType) = returnSymbolToParse(shape) - val parser = protocolFunctions.deserializeFn(shape) { fnName -> - rustBlockTemplate( - """ - pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> - where I: Iterator, #{Error}>> - """, - "ReturnType" to returnSymbol, - *codegenScope, - ) { - startArrayOrNull { - rust("let mut items = Vec::new();") - rustBlock("loop") { - rustBlock("match tokens.peek()") { - rustBlockTemplate("Some(Ok(#{Token}::EndArray { .. })) =>", *codegenScope) { - rust("tokens.next().transpose().unwrap(); break;") - } - rustBlock("_ => ") { - if (isSparse) { - withBlock("items.push(", ");") { - deserializeMember(shape.member) - } - } else { - withBlock("let value =", ";") { - deserializeMember(shape.member) - } - rust( - """ - if let Some(value) = value { - items.push(value); + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + rustBlockTemplate( + """ + pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> + where I: Iterator, #{Error}>> + """, + "ReturnType" to returnSymbol, + *codegenScope, + ) { + startArrayOrNull { + rust("let mut items = Vec::new();") + rustBlock("loop") { + rustBlock("match tokens.peek()") { + rustBlockTemplate("Some(Ok(#{Token}::EndArray { .. })) =>", *codegenScope) { + rust("tokens.next().transpose().unwrap(); break;") + } + rustBlock("_ => ") { + if (isSparse) { + withBlock("items.push(", ");") { + deserializeMember(shape.member) } - """, - ) - codegenTarget.ifServer { - rustTemplate( + } else { + withBlock("let value =", ";") { + deserializeMember(shape.member) + } + rust( """ - else { - return Err(#{Error}::custom("dense list cannot contain null values")); + if let Some(value) = value { + items.push(value); } """, - *codegenScope, ) + codegenTarget.ifServer { + rustTemplate( + """ + else { + return Err(#{Error}::custom("dense list cannot contain null values")); + } + """, + *codegenScope, + ) + } } } } } - } - if (returnUnconstrainedType) { - rust("Ok(Some(#{T}(items)))", returnSymbol) - } else { - rust("Ok(Some(items))") + if (returnUnconstrainedType) { + rust("Ok(Some(#{T}(items)))", returnSymbol) + } else { + rust("Ok(Some(items))") + } } } } - } rust("#T(tokens)?", parser) } @@ -441,181 +448,190 @@ class JsonParserGenerator( val keyTarget = model.expectShape(shape.key.target) as StringShape val isSparse = shape.hasTrait() val returnSymbolToParse = returnSymbolToParse(shape) - val parser = protocolFunctions.deserializeFn(shape) { fnName -> - rustBlockTemplate( - """ - pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> - where I: Iterator, #{Error}>> - """, - "ReturnType" to returnSymbolToParse.symbol, - *codegenScope, - ) { - startObjectOrNull { - rust("let mut map = #T::new();", RuntimeType.HashMap) - objectKeyLoop(hasMembers = true) { - withBlock("let key =", "?;") { - deserializeStringInner(keyTarget, "key") - } - withBlock("let value =", ";") { - deserializeMember(shape.value) - } - if (isSparse) { - rust("map.insert(key, value);") - } else { - codegenTarget.ifServer { - rustTemplate( - """ - match value { - Some(value) => { map.insert(key, value); } - None => return Err(#{Error}::custom("dense map cannot contain null values")) - }""", - *codegenScope, - ) + val parser = + protocolFunctions.deserializeFn(shape) { fnName -> + rustBlockTemplate( + """ + pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> + where I: Iterator, #{Error}>> + """, + "ReturnType" to returnSymbolToParse.symbol, + *codegenScope, + ) { + startObjectOrNull { + rust("let mut map = #T::new();", RuntimeType.HashMap) + objectKeyLoop(hasMembers = true) { + withBlock("let key =", "?;") { + deserializeStringInner(keyTarget, "key") } - codegenTarget.ifClient { - rustTemplate( - """ - if let Some(value) = value { - map.insert(key, value); - } - """, - ) + withBlock("let value =", ";") { + deserializeMember(shape.value) + } + if (isSparse) { + rust("map.insert(key, value);") + } else { + codegenTarget.ifServer { + rustTemplate( + """ + match value { + Some(value) => { map.insert(key, value); } + None => return Err(#{Error}::custom("dense map cannot contain null values")) + }""", + *codegenScope, + ) + } + codegenTarget.ifClient { + rustTemplate( + """ + if let Some(value) = value { + map.insert(key, value); + } + """, + ) + } } } - } - if (returnSymbolToParse.isUnconstrained) { - rust("Ok(Some(#{T}(map)))", returnSymbolToParse.symbol) - } else { - rust("Ok(Some(map))") + if (returnSymbolToParse.isUnconstrained) { + rust("Ok(Some(#{T}(map)))", returnSymbolToParse.symbol) + } else { + rust("Ok(Some(map))") + } } } } - } rust("#T(tokens)?", parser) } private fun RustWriter.deserializeStruct(shape: StructureShape) { val returnSymbolToParse = returnSymbolToParse(shape) - val nestedParser = protocolFunctions.deserializeFn(shape) { fnName -> - rustBlockTemplate( - """ - pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> - where I: Iterator, #{Error}>> - """, - "ReturnType" to returnSymbolToParse.symbol, - *codegenScope, - ) { - startObjectOrNull { - Attribute.AllowUnusedMut.render(this) - rustTemplate( - "let mut builder = #{Builder}::default();", - *codegenScope, - "Builder" to symbolProvider.symbolForBuilder(shape), - ) - deserializeStructInner(shape.members()) - val builder = builderInstantiator.finalizeBuilder( - "builder", shape, - ) { + val nestedParser = + protocolFunctions.deserializeFn(shape) { fnName -> + rustBlockTemplate( + """ + pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> + where I: Iterator, #{Error}>> + """, + "ReturnType" to returnSymbolToParse.symbol, + *codegenScope, + ) { + startObjectOrNull { + Attribute.AllowUnusedMut.render(this) rustTemplate( - """|err|#{Error}::custom_source("Response was invalid", err)""", *codegenScope, + "let mut builder = #{Builder}::default();", + *codegenScope, + "Builder" to symbolProvider.symbolForBuilder(shape), ) + deserializeStructInner(shape.members()) + val builder = + builderInstantiator.finalizeBuilder( + "builder", shape, + ) { + rustTemplate( + """|err|#{Error}::custom_source("Response was invalid", err)""", *codegenScope, + ) + } + rust("Ok(Some(#T))", builder) } - rust("Ok(Some(#T))", builder) } } - } rust("#T(tokens)?", nestedParser) } private fun RustWriter.deserializeUnion(shape: UnionShape) { val returnSymbolToParse = returnSymbolToParse(shape) - val nestedParser = protocolFunctions.deserializeFn(shape) { fnName -> - rustBlockTemplate( - """ - pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> - where I: Iterator, #{Error}>> - """, - *codegenScope, - "Shape" to returnSymbolToParse.symbol, - ) { - rust("let mut variant = None;") - val checkValueSet = !shape.members().all { it.isTargetUnit() } && !codegenTarget.renderUnknownVariant() - rustBlock("match tokens.next().transpose()?") { - rustBlockTemplate( - """ - Some(#{Token}::ValueNull { .. }) => return Ok(None), - Some(#{Token}::StartObject { .. }) => - """, - *codegenScope, - ) { - objectKeyLoop(hasMembers = shape.members().isNotEmpty()) { - rustTemplate( - """ - let key = key.to_unescaped()?; - if key == "__type" { - #{skip_value}(tokens)?; - continue - } - if variant.is_some() { - return Err(#{Error}::custom("encountered mixed variants in union")); - } - """, - *codegenScope, - ) - withBlock("variant = match key.as_ref() {", "};") { - for (member in shape.members()) { - val variantName = symbolProvider.toMemberName(member) - rustBlock("${jsonName(member).dq()} =>") { - if (member.isTargetUnit()) { + val nestedParser = + protocolFunctions.deserializeFn(shape) { fnName -> + rustBlockTemplate( + """ + pub(crate) fn $fnName<'a, I>(tokens: &mut #{Peekable}) -> Result, #{Error}> + where I: Iterator, #{Error}>> + """, + *codegenScope, + "Shape" to returnSymbolToParse.symbol, + ) { + rust("let mut variant = None;") + val checkValueSet = !shape.members().all { it.isTargetUnit() } && !codegenTarget.renderUnknownVariant() + rustBlock("match tokens.next().transpose()?") { + rustBlockTemplate( + """ + Some(#{Token}::ValueNull { .. }) => return Ok(None), + Some(#{Token}::StartObject { .. }) => + """, + *codegenScope, + ) { + objectKeyLoop(hasMembers = shape.members().isNotEmpty()) { + rustTemplate( + """ + let key = key.to_unescaped()?; + if key == "__type" { + #{skip_value}(tokens)?; + continue + } + if variant.is_some() { + return Err(#{Error}::custom("encountered mixed variants in union")); + } + """, + *codegenScope, + ) + withBlock("variant = match key.as_ref() {", "};") { + for (member in shape.members()) { + val variantName = symbolProvider.toMemberName(member) + rustBlock("${jsonName(member).dq()} =>") { + if (member.isTargetUnit()) { + rustTemplate( + """ + #{skip_value}(tokens)?; + Some(#{Union}::$variantName) + """, + "Union" to returnSymbolToParse.symbol, *codegenScope, + ) + } else { + withBlock("Some(#T::$variantName(", "))", returnSymbolToParse.symbol) { + deserializeMember(member) + unwrapOrDefaultOrError(member, checkValueSet) + } + } + } + } + when (codegenTarget.renderUnknownVariant()) { + // In client mode, resolve an unknown union variant to the unknown variant. + true -> rustTemplate( """ - #{skip_value}(tokens)?; - Some(#{Union}::$variantName) + _ => { + #{skip_value}(tokens)?; + Some(#{Union}::${UnionGenerator.UnknownVariantName}) + } """, - "Union" to returnSymbolToParse.symbol, *codegenScope, + "Union" to returnSymbolToParse.symbol, + *codegenScope, + ) + // In server mode, use strict parsing. + // Consultation: https://github.com/awslabs/smithy/issues/1222 + false -> + rustTemplate( + """variant => return Err(#{Error}::custom(format!("unexpected union variant: {}", variant)))""", + *codegenScope, ) - } else { - withBlock("Some(#T::$variantName(", "))", returnSymbolToParse.symbol) { - deserializeMember(member) - unwrapOrDefaultOrError(member, checkValueSet) - } - } } } - when (codegenTarget.renderUnknownVariant()) { - // In client mode, resolve an unknown union variant to the unknown variant. - true -> rustTemplate( - """ - _ => { - #{skip_value}(tokens)?; - Some(#{Union}::${UnionGenerator.UnknownVariantName}) - } - """, - "Union" to returnSymbolToParse.symbol, - *codegenScope, - ) - // In server mode, use strict parsing. - // Consultation: https://github.com/awslabs/smithy/issues/1222 - false -> rustTemplate( - """variant => return Err(#{Error}::custom(format!("unexpected union variant: {}", variant)))""", - *codegenScope, - ) - } } } + rustTemplate( + """_ => return Err(#{Error}::custom("expected start object or null"))""", + *codegenScope, + ) } - rustTemplate( - """_ => return Err(#{Error}::custom("expected start object or null"))""", - *codegenScope, - ) + rust("Ok(variant)") } - rust("Ok(variant)") } - } rust("#T(tokens)?", nestedParser) } - private fun RustWriter.unwrapOrDefaultOrError(member: MemberShape, checkValueSet: Boolean) { + private fun RustWriter.unwrapOrDefaultOrError( + member: MemberShape, + checkValueSet: Boolean, + ) { if (symbolProvider.toSymbol(member).canUseDefault() && !checkValueSet) { rust(".unwrap_or_default()") } else { @@ -626,7 +642,10 @@ class JsonParserGenerator( } } - private fun RustWriter.objectKeyLoop(hasMembers: Boolean, inner: Writable) { + private fun RustWriter.objectKeyLoop( + hasMembers: Boolean, + inner: Writable, + ) { if (!hasMembers) { rustTemplate("#{skip_to_end}(tokens)?;", *codegenScope) } else { @@ -651,8 +670,13 @@ class JsonParserGenerator( } private fun RustWriter.startArrayOrNull(inner: Writable) = startOrNull("array", inner) + private fun RustWriter.startObjectOrNull(inner: Writable) = startOrNull("object", inner) - private fun RustWriter.startOrNull(objectOrArray: String, inner: Writable) { + + private fun RustWriter.startOrNull( + objectOrArray: String, + inner: Writable, + ) { rustBlockTemplate("match tokens.next().transpose()?", *codegenScope) { rustBlockTemplate( """ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt index d37413f29f2..cea67a46c2b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt @@ -24,11 +24,12 @@ class RestXmlParserGenerator( ) { context, inner -> val shapeName = context.outputShapeName // Get the non-synthetic version of the outputShape and check to see if it has the `AllowInvalidXmlRoot` trait - val allowInvalidRoot = context.model.getShape(context.shape.outputShape).orNull().let { shape -> - shape?.getTrait()?.originalId.let { shapeId -> - context.model.getShape(shapeId).orNull()?.hasTrait() ?: false + val allowInvalidRoot = + context.model.getShape(context.shape.outputShape).orNull().let { shape -> + shape?.getTrait()?.originalId.let { shapeId -> + context.model.getShape(shapeId).orNull()?.hasTrait() ?: false + } } - } // If we DON'T allow the XML root to be invalid, insert code to check for and report a mismatch if (!allowInvalidRoot) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt index 03caf527a89..a93981df211 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt @@ -70,7 +70,6 @@ class XmlBindingTraitParserGenerator( private val xmlErrors: RuntimeType, private val writeOperationWrapper: RustWriter.(OperationWrapperContext, OperationInnerWriteable) -> Unit, ) : StructuredDataParserGenerator { - /** Abstraction to represent an XML element name */ data class XmlName(val name: String) { /** Generates an expression to match a given element against this XML tag name */ @@ -103,15 +102,16 @@ class XmlBindingTraitParserGenerator( private val builderInstantiator = codegenContext.builderInstantiator() // The symbols we want all the time - private val codegenScope = arrayOf( - "Blob" to RuntimeType.blob(runtimeConfig), - "Document" to smithyXml.resolve("decode::Document"), - "XmlDecodeError" to xmlDecodeError, - "next_start_element" to smithyXml.resolve("decode::next_start_element"), - "try_data" to smithyXml.resolve("decode::try_data"), - "ScopedDecoder" to scopedDecoder, - "aws_smithy_types" to CargoDependency.smithyTypes(runtimeConfig).toType(), - ) + private val codegenScope = + arrayOf( + "Blob" to RuntimeType.blob(runtimeConfig), + "Document" to smithyXml.resolve("decode::Document"), + "XmlDecodeError" to xmlDecodeError, + "next_start_element" to smithyXml.resolve("decode::next_start_element"), + "try_data" to smithyXml.resolve("decode::try_data"), + "ScopedDecoder" to scopedDecoder, + "aws_smithy_types" to CargoDependency.smithyTypes(runtimeConfig).toType(), + ) private val model = codegenContext.model private val index = HttpBindingIndex.of(model) private val xmlIndex = XmlNameIndex.of(model) @@ -305,7 +305,11 @@ class XmlBindingTraitParserGenerator( /** * Update a structure builder based on the [members], specifying where to find each member (document vs. attributes) */ - private fun RustWriter.parseStructureInner(members: XmlMemberIndex, builder: String, outerCtx: Ctx) { + private fun RustWriter.parseStructureInner( + members: XmlMemberIndex, + builder: String, + outerCtx: Ctx, + ) { members.attributeMembers.forEach { member -> val temp = safeName("attrib") withBlock("let $temp =", ";") { @@ -339,7 +343,11 @@ class XmlBindingTraitParserGenerator( * generates a match expression * When [ignoreUnexpected] is true, unexpected tags are ignored */ - private fun RustWriter.parseLoop(ctx: Ctx, ignoreUnexpected: Boolean = true, inner: RustWriter.(Ctx) -> Unit) { + private fun RustWriter.parseLoop( + ctx: Ctx, + ignoreUnexpected: Boolean = true, + inner: RustWriter.(Ctx) -> Unit, + ) { rustBlock("while let Some(mut tag) = ${ctx.tag}.next_tag()") { rustBlock("match tag.start_el()") { inner(ctx.copy(tag = "tag")) @@ -353,7 +361,11 @@ class XmlBindingTraitParserGenerator( /** * Generate an XML parser for a given member */ - private fun RustWriter.parseMember(memberShape: MemberShape, ctx: Ctx, forceOptional: Boolean = false) { + private fun RustWriter.parseMember( + memberShape: MemberShape, + ctx: Ctx, + forceOptional: Boolean = false, + ) { val target = model.expectShape(memberShape.target) val symbol = symbolProvider.toSymbol(memberShape) conditionalBlock("Some(", ")", forceOptional || symbol.isOptional()) { @@ -364,17 +376,19 @@ class XmlBindingTraitParserGenerator( rustTemplate("#{try_data}(&mut ${ctx.tag})?.as_ref()", *codegenScope) } - is MapShape -> if (memberShape.isFlattened()) { - parseFlatMap(target, ctx) - } else { - parseMap(target, ctx) - } + is MapShape -> + if (memberShape.isFlattened()) { + parseFlatMap(target, ctx) + } else { + parseMap(target, ctx) + } - is CollectionShape -> if (memberShape.isFlattened()) { - parseFlatList(target, ctx) - } else { - parseList(target, ctx) - } + is CollectionShape -> + if (memberShape.isFlattened()) { + parseFlatList(target, ctx) + } else { + parseList(target, ctx) + } is StructureShape -> { parseStructure(target, ctx) @@ -389,7 +403,10 @@ class XmlBindingTraitParserGenerator( } } - private fun RustWriter.parseAttributeMember(memberShape: MemberShape, ctx: Ctx) { + private fun RustWriter.parseAttributeMember( + memberShape: MemberShape, + ctx: Ctx, + ) { rustBlock("") { rustTemplate( """ @@ -411,58 +428,66 @@ class XmlBindingTraitParserGenerator( } } - private fun RustWriter.parseUnion(shape: UnionShape, ctx: Ctx) { + private fun RustWriter.parseUnion( + shape: UnionShape, + ctx: Ctx, + ) { val symbol = symbolProvider.toSymbol(shape) - val nestedParser = protocolFunctions.deserializeFn(shape) { fnName -> - rustBlockTemplate( - "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{Shape}, #{XmlDecodeError}>", - *codegenScope, "Shape" to symbol, - ) { - val members = shape.members() - rustTemplate("let mut base: Option<#{Shape}> = None;", *codegenScope, "Shape" to symbol) - parseLoop(Ctx(tag = "decoder", accum = null), ignoreUnexpected = false) { ctx -> - members.forEach { member -> - val variantName = symbolProvider.toMemberName(member) - case(member) { - if (member.isTargetUnit()) { - rust("base = Some(#T::$variantName);", symbol) - } else { - val current = - """ - (match base.take() { - None => None, - Some(${format(symbol)}::$variantName(inner)) => Some(inner), - Some(_) => return Err(#{XmlDecodeError}::custom("mixed variants")) - }) - """ - withBlock("let tmp =", ";") { - parseMember(member, ctx.copy(accum = current.trim())) + val nestedParser = + protocolFunctions.deserializeFn(shape) { fnName -> + rustBlockTemplate( + "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{Shape}, #{XmlDecodeError}>", + *codegenScope, "Shape" to symbol, + ) { + val members = shape.members() + rustTemplate("let mut base: Option<#{Shape}> = None;", *codegenScope, "Shape" to symbol) + parseLoop(Ctx(tag = "decoder", accum = null), ignoreUnexpected = false) { ctx -> + members.forEach { member -> + val variantName = symbolProvider.toMemberName(member) + case(member) { + if (member.isTargetUnit()) { + rust("base = Some(#T::$variantName);", symbol) + } else { + val current = + """ + (match base.take() { + None => None, + Some(${format(symbol)}::$variantName(inner)) => Some(inner), + Some(_) => return Err(#{XmlDecodeError}::custom("mixed variants")) + }) + """ + withBlock("let tmp =", ";") { + parseMember(member, ctx.copy(accum = current.trim())) + } + rust("base = Some(#T::$variantName(tmp));", symbol) } - rust("base = Some(#T::$variantName(tmp));", symbol) } } + when (target.renderUnknownVariant()) { + true -> rust("_unknown => base = Some(#T::${UnionGenerator.UnknownVariantName}),", symbol) + false -> + rustTemplate( + """variant => return Err(#{XmlDecodeError}::custom(format!("unexpected union variant: {:?}", variant)))""", + *codegenScope, + ) + } } - when (target.renderUnknownVariant()) { - true -> rust("_unknown => base = Some(#T::${UnionGenerator.UnknownVariantName}),", symbol) - false -> rustTemplate( - """variant => return Err(#{XmlDecodeError}::custom(format!("unexpected union variant: {:?}", variant)))""", - *codegenScope, - ) - } + rustTemplate( + """base.ok_or_else(||#{XmlDecodeError}::custom("expected union, got nothing"))""", + *codegenScope, + ) } - rustTemplate( - """base.ok_or_else(||#{XmlDecodeError}::custom("expected union, got nothing"))""", - *codegenScope, - ) } - } rust("#T(&mut ${ctx.tag})", nestedParser) } /** * The match clause to check if the tag matches a given member */ - private fun RustWriter.case(member: MemberShape, inner: Writable) { + private fun RustWriter.case( + member: MemberShape, + inner: Writable, + ) { rustBlock( "s if ${ member.xmlName().matchExpression("s") @@ -473,61 +498,73 @@ class XmlBindingTraitParserGenerator( rust(",") } - private fun RustWriter.parseStructure(shape: StructureShape, ctx: Ctx) { + private fun RustWriter.parseStructure( + shape: StructureShape, + ctx: Ctx, + ) { val symbol = symbolProvider.toSymbol(shape) - val nestedParser = protocolFunctions.deserializeFn(shape) { fnName -> - Attribute.AllowNeedlessQuestionMark.render(this) - rustBlockTemplate( - "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{Shape}, #{XmlDecodeError}>", - *codegenScope, "Shape" to symbol, - ) { - Attribute.AllowUnusedMut.render(this) - rustTemplate("let mut builder = #{Shape}::builder();", *codegenScope, "Shape" to symbol) - val members = shape.xmlMembers() - if (members.isNotEmpty()) { - parseStructureInner(members, "builder", Ctx(tag = "decoder", accum = null)) - } else { - rust("let _ = decoder;") - } - val builder = builderInstantiator.finalizeBuilder( - "builder", - shape, - mapErr = { - rustTemplate( - """|_|#{XmlDecodeError}::custom("missing field")""", - *codegenScope, + val nestedParser = + protocolFunctions.deserializeFn(shape) { fnName -> + Attribute.AllowNeedlessQuestionMark.render(this) + rustBlockTemplate( + "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{Shape}, #{XmlDecodeError}>", + *codegenScope, "Shape" to symbol, + ) { + Attribute.AllowUnusedMut.render(this) + rustTemplate("let mut builder = #{Shape}::builder();", *codegenScope, "Shape" to symbol) + val members = shape.xmlMembers() + if (members.isNotEmpty()) { + parseStructureInner(members, "builder", Ctx(tag = "decoder", accum = null)) + } else { + rust("let _ = decoder;") + } + val builder = + builderInstantiator.finalizeBuilder( + "builder", + shape, + mapErr = { + rustTemplate( + """|_|#{XmlDecodeError}::custom("missing field")""", + *codegenScope, + ) + }, ) - }, - ) - rust("Ok(#T)", builder) + rust("Ok(#T)", builder) + } } - } rust("#T(&mut ${ctx.tag})", nestedParser) } - private fun RustWriter.parseList(target: CollectionShape, ctx: Ctx) { + private fun RustWriter.parseList( + target: CollectionShape, + ctx: Ctx, + ) { val member = target.member - val listParser = protocolFunctions.deserializeFn(target) { fnName -> - rustBlockTemplate( - "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{List}, #{XmlDecodeError}>", - *codegenScope, - "List" to symbolProvider.toSymbol(target), - ) { - rust("let mut out = std::vec::Vec::new();") - parseLoop(Ctx(tag = "decoder", accum = null)) { ctx -> - case(member) { - withBlock("out.push(", ");") { - parseMember(member, ctx) + val listParser = + protocolFunctions.deserializeFn(target) { fnName -> + rustBlockTemplate( + "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{List}, #{XmlDecodeError}>", + *codegenScope, + "List" to symbolProvider.toSymbol(target), + ) { + rust("let mut out = std::vec::Vec::new();") + parseLoop(Ctx(tag = "decoder", accum = null)) { ctx -> + case(member) { + withBlock("out.push(", ");") { + parseMember(member, ctx) + } } } + rust("Ok(out)") } - rust("Ok(out)") } - } rust("#T(&mut ${ctx.tag})", listParser) } - private fun RustWriter.parseFlatList(target: CollectionShape, ctx: Ctx) { + private fun RustWriter.parseFlatList( + target: CollectionShape, + ctx: Ctx, + ) { val list = safeName("list") withBlock("Result::<#T, #T>::Ok({", "})", symbolProvider.toSymbol(target), xmlDecodeError) { val accum = ctx.accum ?: throw CodegenException("Need accum to parse flat list") @@ -539,26 +576,33 @@ class XmlBindingTraitParserGenerator( } } - private fun RustWriter.parseMap(target: MapShape, ctx: Ctx) { - val mapParser = protocolFunctions.deserializeFn(target) { fnName -> - rustBlockTemplate( - "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{Map}, #{XmlDecodeError}>", - *codegenScope, - "Map" to symbolProvider.toSymbol(target), - ) { - rust("let mut out = #T::new();", RuntimeType.HashMap) - parseLoop(Ctx(tag = "decoder", accum = null)) { ctx -> - rustBlock("s if ${XmlName("entry").matchExpression("s")} => ") { - rust("#T(&mut ${ctx.tag}, &mut out)?;", mapEntryParser(target, ctx)) + private fun RustWriter.parseMap( + target: MapShape, + ctx: Ctx, + ) { + val mapParser = + protocolFunctions.deserializeFn(target) { fnName -> + rustBlockTemplate( + "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{Map}, #{XmlDecodeError}>", + *codegenScope, + "Map" to symbolProvider.toSymbol(target), + ) { + rust("let mut out = #T::new();", RuntimeType.HashMap) + parseLoop(Ctx(tag = "decoder", accum = null)) { ctx -> + rustBlock("s if ${XmlName("entry").matchExpression("s")} => ") { + rust("#T(&mut ${ctx.tag}, &mut out)?;", mapEntryParser(target, ctx)) + } } + rust("Ok(out)") } - rust("Ok(out)") } - } rust("#T(&mut ${ctx.tag})", mapParser) } - private fun RustWriter.parseFlatMap(target: MapShape, ctx: Ctx) { + private fun RustWriter.parseFlatMap( + target: MapShape, + ctx: Ctx, + ) { val map = safeName("map") val entryDecoder = mapEntryParser(target, ctx) withBlock("Result::<#T, #T>::Ok({", "})", symbolProvider.toSymbol(target), xmlDecodeError) { @@ -575,7 +619,10 @@ class XmlBindingTraitParserGenerator( } } - private fun mapEntryParser(target: MapShape, ctx: Ctx): RuntimeType { + private fun mapEntryParser( + target: MapShape, + ctx: Ctx, + ): RuntimeType { return protocolFunctions.deserializeFn(target, "entry") { fnName -> rustBlockTemplate( "pub fn $fnName(decoder: &mut #{ScopedDecoder}, out: &mut #{Map}) -> Result<(), #{XmlDecodeError}>", @@ -618,7 +665,10 @@ class XmlBindingTraitParserGenerator( * Parse a simple member from a data field * [provider] generates code for the inner data field */ - private fun RustWriter.parsePrimitiveInner(member: MemberShape, provider: Writable) { + private fun RustWriter.parsePrimitiveInner( + member: MemberShape, + provider: Writable, + ) { when (val shape = model.expectShape(member.target)) { is StringShape -> parseStringInner(shape, provider) is NumberShape, is BooleanShape -> { @@ -671,7 +721,10 @@ class XmlBindingTraitParserGenerator( } } - private fun RustWriter.parseStringInner(shape: StringShape, provider: Writable) { + private fun RustWriter.parseStringInner( + shape: StringShape, + provider: Writable, + ) { withBlock("Result::<#T, #T>::Ok(", ")", symbolProvider.toSymbol(shape), xmlDecodeError) { if (shape.hasTrait()) { val enumSymbol = symbolProvider.toSymbol(shape) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt index b9491f29e8c..97c8adb48a3 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt @@ -46,20 +46,22 @@ class EventStreamErrorMarshallerGenerator( private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig) private val smithyTypes = RuntimeType.smithyTypes(runtimeConfig) - private val operationErrorSymbol = if (target == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { - RuntimeType.smithyHttp(runtimeConfig).resolve("event_stream::MessageStreamError").toSymbol() - } else { - symbolProvider.symbolForEventStreamError(unionShape) - } + private val operationErrorSymbol = + if (target == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { + RuntimeType.smithyHttp(runtimeConfig).resolve("event_stream::MessageStreamError").toSymbol() + } else { + symbolProvider.symbolForEventStreamError(unionShape) + } private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule() private val errorsShape = unionShape.expectTrait() - private val codegenScope = arrayOf( - "MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"), - "Message" to smithyTypes.resolve("event_stream::Message"), - "Header" to smithyTypes.resolve("event_stream::Header"), - "HeaderValue" to smithyTypes.resolve("event_stream::HeaderValue"), - "Error" to smithyEventStream.resolve("error::Error"), - ) + private val codegenScope = + arrayOf( + "MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"), + "Message" to smithyTypes.resolve("event_stream::Message"), + "Header" to smithyTypes.resolve("event_stream::Header"), + "HeaderValue" to smithyTypes.resolve("event_stream::HeaderValue"), + "Error" to smithyEventStream.resolve("error::Error"), + ) override fun render(): RuntimeType { val marshallerType = unionShape.eventStreamMarshallerType() @@ -70,7 +72,10 @@ class EventStreamErrorMarshallerGenerator( } } - private fun RustWriter.renderMarshaller(marshallerType: RuntimeType, unionSymbol: Symbol) { + private fun RustWriter.renderMarshaller( + marshallerType: RuntimeType, + unionSymbol: Symbol, + ) { rust( """ ##[non_exhaustive] @@ -128,7 +133,10 @@ class EventStreamErrorMarshallerGenerator( } } - private fun RustWriter.renderMarshallEvent(unionMember: MemberShape, eventStruct: StructureShape) { + private fun RustWriter.renderMarshallEvent( + unionMember: MemberShape, + eventStruct: StructureShape, + ) { val headerMembers = eventStruct.members().filter { it.hasTrait() } val payloadMember = eventStruct.members().firstOrNull { it.hasTrait() } for (member in headerMembers) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt index ac6cf88ccc2..84e31a6173c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt @@ -56,13 +56,14 @@ open class EventStreamMarshallerGenerator( private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig) private val smithyTypes = RuntimeType.smithyTypes(runtimeConfig) private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule() - private val codegenScope = arrayOf( - "MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"), - "Message" to smithyTypes.resolve("event_stream::Message"), - "Header" to smithyTypes.resolve("event_stream::Header"), - "HeaderValue" to smithyTypes.resolve("event_stream::HeaderValue"), - "Error" to smithyEventStream.resolve("error::Error"), - ) + private val codegenScope = + arrayOf( + "MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"), + "Message" to smithyTypes.resolve("event_stream::Message"), + "Header" to smithyTypes.resolve("event_stream::Header"), + "HeaderValue" to smithyTypes.resolve("event_stream::HeaderValue"), + "Error" to smithyEventStream.resolve("error::Error"), + ) open fun render(): RuntimeType { val marshallerType = unionShape.eventStreamMarshallerType() @@ -73,7 +74,10 @@ open class EventStreamMarshallerGenerator( } } - private fun RustWriter.renderMarshaller(marshallerType: RuntimeType, unionSymbol: Symbol) { + private fun RustWriter.renderMarshaller( + marshallerType: RuntimeType, + unionSymbol: Symbol, + ) { rust( """ ##[non_exhaustive] @@ -125,7 +129,10 @@ open class EventStreamMarshallerGenerator( } } - private fun RustWriter.renderMarshallEvent(unionMember: MemberShape, eventStruct: StructureShape) { + private fun RustWriter.renderMarshallEvent( + unionMember: MemberShape, + eventStruct: StructureShape, + ) { val headerMembers = eventStruct.members().filter { it.hasTrait() } val payloadMember = eventStruct.members().firstOrNull { it.hasTrait() } for (member in headerMembers) { @@ -146,7 +153,11 @@ open class EventStreamMarshallerGenerator( } } - protected fun RustWriter.renderMarshallEventHeader(memberName: String, member: MemberShape, target: Shape) { + protected fun RustWriter.renderMarshallEventHeader( + memberName: String, + member: MemberShape, + target: Shape, + ) { val headerName = member.memberName handleOptional( symbolProvider.toSymbol(member).isOptional(), @@ -156,7 +167,11 @@ open class EventStreamMarshallerGenerator( ) } - private fun RustWriter.renderAddHeader(headerName: String, inputName: String, target: Shape) { + private fun RustWriter.renderAddHeader( + headerName: String, + inputName: String, + target: Shape, + ) { withBlock("headers.push(", ");") { rustTemplate( "#{Header}::new(${headerName.dq()}, #{HeaderValue}::${headerValue(inputName, target)})", @@ -167,17 +182,21 @@ open class EventStreamMarshallerGenerator( // Event stream header types: https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html#eventheader-trait // Note: there are no floating point header types for Event Stream. - private fun headerValue(inputName: String, target: Shape): String = when (target) { - is BooleanShape -> "Bool($inputName)" - is ByteShape -> "Byte($inputName)" - is ShortShape -> "Int16($inputName)" - is IntegerShape -> "Int32($inputName)" - is LongShape -> "Int64($inputName)" - is BlobShape -> "ByteArray($inputName.into_inner().into())" - is StringShape -> "String($inputName.into())" - is TimestampShape -> "Timestamp($inputName)" - else -> throw IllegalStateException("unsupported event stream header shape type: $target") - } + private fun headerValue( + inputName: String, + target: Shape, + ): String = + when (target) { + is BooleanShape -> "Bool($inputName)" + is ByteShape -> "Byte($inputName)" + is ShortShape -> "Int16($inputName)" + is IntegerShape -> "Int32($inputName)" + is LongShape -> "Int64($inputName)" + is BlobShape -> "ByteArray($inputName.into_inner().into())" + is StringShape -> "String($inputName.into())" + is TimestampShape -> "Timestamp($inputName)" + else -> throw IllegalStateException("unsupported event stream header shape type: $target") + } protected fun RustWriter.renderMarshallEventPayload( inputExpr: String, @@ -189,11 +208,12 @@ open class EventStreamMarshallerGenerator( if (target is BlobShape || target is StringShape) { data class PayloadContext(val conversionFn: String, val contentType: String) - val ctx = when (target) { - is BlobShape -> PayloadContext("into_inner", "application/octet-stream") - is StringShape -> PayloadContext("into_bytes", "text/plain") - else -> throw IllegalStateException("unreachable") - } + val ctx = + when (target) { + is BlobShape -> PayloadContext("into_inner", "application/octet-stream") + is StringShape -> PayloadContext("into_bytes", "text/plain") + else -> throw IllegalStateException("unreachable") + } addStringHeader(":content-type", "${ctx.contentType.dq()}.into()") handleOptional( optional, @@ -245,7 +265,10 @@ open class EventStreamMarshallerGenerator( } } - protected fun RustWriter.addStringHeader(name: String, valueExpr: String) { + protected fun RustWriter.addStringHeader( + name: String, + valueExpr: String, + ) { rustTemplate("headers.push(#{Header}::new(${name.dq()}, #{HeaderValue}::String($valueExpr)));", *codegenScope) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 981706b5665..a351b25b65c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -110,7 +110,10 @@ class JsonSerializerGenerator( val writeNulls: Boolean = false, ) { companion object { - fun collectionMember(context: Context, itemName: String): MemberContext = + fun collectionMember( + context: Context, + itemName: String, + ): MemberContext = MemberContext( "${context.writerExpression}.value()", ValueExpression.Reference(itemName), @@ -118,7 +121,11 @@ class JsonSerializerGenerator( writeNulls = true, ) - fun mapMember(context: Context, key: String, value: String): MemberContext = + fun mapMember( + context: Context, + key: String, + value: String, + ): MemberContext = MemberContext( "${context.writerExpression}.key($key)", ValueExpression.Reference(value), @@ -151,8 +158,10 @@ class JsonSerializerGenerator( ) /** Returns an expression to get a JsonValueWriter from a JsonObjectWriter */ - private fun objectValueWriterExpression(objectWriterName: String, jsonName: String): String = - "$objectWriterName.key(${jsonName.dq()})" + private fun objectValueWriterExpression( + objectWriterName: String, + jsonName: String, + ): String = "$objectWriterName.key(${jsonName.dq()})" } } @@ -170,14 +179,15 @@ class JsonSerializerGenerator( private val codegenTarget = codegenContext.target private val runtimeConfig = codegenContext.runtimeConfig private val protocolFunctions = ProtocolFunctions(codegenContext) - private val codegenScope = arrayOf( - *preludeScope, - "Error" to runtimeConfig.serializationError(), - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), - "JsonObjectWriter" to RuntimeType.smithyJson(runtimeConfig).resolve("serialize::JsonObjectWriter"), - "JsonValueWriter" to RuntimeType.smithyJson(runtimeConfig).resolve("serialize::JsonValueWriter"), - "ByteSlab" to RuntimeType.ByteSlab, - ) + private val codegenScope = + arrayOf( + *preludeScope, + "Error" to runtimeConfig.serializationError(), + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + "JsonObjectWriter" to RuntimeType.smithyJson(runtimeConfig).resolve("serialize::JsonObjectWriter"), + "JsonValueWriter" to RuntimeType.smithyJson(runtimeConfig).resolve("serialize::JsonValueWriter"), + "ByteSlab" to RuntimeType.ByteSlab, + ) private val serializerUtil = SerializerUtil(model) /** @@ -191,10 +201,11 @@ class JsonSerializerGenerator( makeSection: (StructureShape, String) -> JsonSerializerSection, error: Boolean, ): RuntimeType { - val suffix = when (error) { - true -> "error" - else -> "output" - } + val suffix = + when (error) { + true -> "error" + else -> "output" + } return protocolFunctions.serializeFn(structureShape, fnNameSuffix = suffix) { fnName -> rustBlockTemplate( "pub fn $fnName(value: &#{target}) -> Result", @@ -323,29 +334,33 @@ class JsonSerializerGenerator( context: StructContext, includedMembers: List? = null, ) { - val structureSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> - val inner = context.copy(objectName = "object", localName = "input") - val members = includedMembers ?: inner.shape.members() - val allowUnusedVariables = writable { - if (members.isEmpty()) { Attribute.AllowUnusedVariables.render(this) } - } - rustBlockTemplate( - """ - pub fn $fnName( - #{AllowUnusedVariables:W} object: &mut #{JsonObjectWriter}, - #{AllowUnusedVariables:W} input: &#{StructureSymbol}, - ) -> Result<(), #{Error}> - """, - "StructureSymbol" to symbolProvider.toSymbol(context.shape), - "AllowUnusedVariables" to allowUnusedVariables, - *codegenScope, - ) { - for (member in members) { - serializeMember(MemberContext.structMember(inner, member, symbolProvider, jsonName)) + val structureSerializer = + protocolFunctions.serializeFn(context.shape) { fnName -> + val inner = context.copy(objectName = "object", localName = "input") + val members = includedMembers ?: inner.shape.members() + val allowUnusedVariables = + writable { + if (members.isEmpty()) { + Attribute.AllowUnusedVariables.render(this) + } + } + rustBlockTemplate( + """ + pub fn $fnName( + #{AllowUnusedVariables:W} object: &mut #{JsonObjectWriter}, + #{AllowUnusedVariables:W} input: &#{StructureSymbol}, + ) -> Result<(), #{Error}> + """, + "StructureSymbol" to symbolProvider.toSymbol(context.shape), + "AllowUnusedVariables" to allowUnusedVariables, + *codegenScope, + ) { + for (member in members) { + serializeMember(MemberContext.structMember(inner, member, symbolProvider, jsonName)) + } + rust("Ok(())") } - rust("Ok(())") } - } rust("#T(&mut ${context.objectName}, ${context.localName})?;", structureSerializer) } @@ -386,7 +401,10 @@ class JsonSerializerGenerator( } } - private fun RustWriter.serializeMemberValue(context: MemberContext, target: Shape) { + private fun RustWriter.serializeMemberValue( + context: MemberContext, + target: Shape, + ) { val writer = context.writerExpression val value = context.valueExpression @@ -394,21 +412,23 @@ class JsonSerializerGenerator( is StringShape -> rust("$writer.string(${value.name}.as_str());") is BooleanShape -> rust("$writer.boolean(${value.asValue()});") is NumberShape -> { - val numberType = when (target) { - is IntegerShape, is ByteShape, is LongShape, is ShortShape -> "NegInt" - is DoubleShape, is FloatShape -> "Float" - else -> throw IllegalStateException("unreachable") - } + val numberType = + when (target) { + is IntegerShape, is ByteShape, is LongShape, is ShortShape -> "NegInt" + is DoubleShape, is FloatShape -> "Float" + else -> throw IllegalStateException("unreachable") + } rust( "$writer.number(##[allow(clippy::useless_conversion)]#T::$numberType((${value.asValue()}).into()));", RuntimeType.smithyTypes(runtimeConfig).resolve("Number"), ) } - is BlobShape -> rust( - "$writer.string_unchecked(&#T(${value.asRef()}));", - RuntimeType.base64Encode(runtimeConfig), - ) + is BlobShape -> + rust( + "$writer.string_unchecked(&#T(${value.asRef()}));", + RuntimeType.base64Encode(runtimeConfig), + ) is TimestampShape -> { val timestampFormat = @@ -420,28 +440,35 @@ class JsonSerializerGenerator( ) } - is CollectionShape -> jsonArrayWriter(context) { arrayName -> - serializeCollection(Context(arrayName, value, target)) - } + is CollectionShape -> + jsonArrayWriter(context) { arrayName -> + serializeCollection(Context(arrayName, value, target)) + } - is MapShape -> jsonObjectWriter(context) { objectName -> - serializeMap(Context(objectName, value, target)) - } + is MapShape -> + jsonObjectWriter(context) { objectName -> + serializeMap(Context(objectName, value, target)) + } - is StructureShape -> jsonObjectWriter(context) { objectName -> - serializeStructure(StructContext(objectName, value.asRef(), target)) - } + is StructureShape -> + jsonObjectWriter(context) { objectName -> + serializeStructure(StructContext(objectName, value.asRef(), target)) + } - is UnionShape -> jsonObjectWriter(context) { objectName -> - serializeUnion(Context(objectName, value, target)) - } + is UnionShape -> + jsonObjectWriter(context) { objectName -> + serializeUnion(Context(objectName, value, target)) + } is DocumentShape -> rust("$writer.document(${value.asRef()});") else -> TODO(target.toString()) } } - private fun RustWriter.jsonArrayWriter(context: MemberContext, inner: RustWriter.(String) -> Unit) { + private fun RustWriter.jsonArrayWriter( + context: MemberContext, + inner: RustWriter.(String) -> Unit, + ) { safeName("array").also { arrayName -> rust("let mut $arrayName = ${context.writerExpression}.start_array();") inner(arrayName) @@ -449,7 +476,10 @@ class JsonSerializerGenerator( } } - private fun RustWriter.jsonObjectWriter(context: MemberContext, inner: RustWriter.(String) -> Unit) { + private fun RustWriter.jsonObjectWriter( + context: MemberContext, + inner: RustWriter.(String) -> Unit, + ) { safeName("object").also { objectName -> rust("##[allow(unused_mut)]") rust("let mut $objectName = ${context.writerExpression}.start_object();") @@ -497,34 +527,36 @@ class JsonSerializerGenerator( private fun RustWriter.serializeUnion(context: Context) { val unionSymbol = symbolProvider.toSymbol(context.shape) - val unionSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> - rustBlockTemplate( - "pub fn $fnName(${context.writerExpression}: &mut #{JsonObjectWriter}, input: &#{Input}) -> Result<(), #{Error}>", - "Input" to unionSymbol, - *codegenScope, - ) { - rustBlock("match input") { - for (member in context.shape.members()) { - val variantName = if (member.isTargetUnit()) { - "${symbolProvider.toMemberName(member)}" - } else { - "${symbolProvider.toMemberName(member)}(inner)" + val unionSerializer = + protocolFunctions.serializeFn(context.shape) { fnName -> + rustBlockTemplate( + "pub fn $fnName(${context.writerExpression}: &mut #{JsonObjectWriter}, input: &#{Input}) -> Result<(), #{Error}>", + "Input" to unionSymbol, + *codegenScope, + ) { + rustBlock("match input") { + for (member in context.shape.members()) { + val variantName = + if (member.isTargetUnit()) { + "${symbolProvider.toMemberName(member)}" + } else { + "${symbolProvider.toMemberName(member)}(inner)" + } + withBlock("#T::$variantName => {", "},", unionSymbol) { + serializeMember(MemberContext.unionMember(context, "inner", member, jsonName)) + } } - withBlock("#T::$variantName => {", "},", unionSymbol) { - serializeMember(MemberContext.unionMember(context, "inner", member, jsonName)) + if (codegenTarget.renderUnknownVariant()) { + rustTemplate( + "#{Union}::${UnionGenerator.UnknownVariantName} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", + "Union" to unionSymbol, + *codegenScope, + ) } } - if (codegenTarget.renderUnknownVariant()) { - rustTemplate( - "#{Union}::${UnionGenerator.UnknownVariantName} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", - "Union" to unionSymbol, - *codegenScope, - ) - } + rust("Ok(())") } - rust("Ok(())") } - } rust("#T(&mut ${context.writerExpression}, ${context.valueExpression.asRef()})?;", unionSerializer) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/QuerySerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/QuerySerializerGenerator.kt index 23c8bbb4fb8..8d6454160ba 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/QuerySerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/QuerySerializerGenerator.kt @@ -98,16 +98,19 @@ abstract class QuerySerializerGenerator(private val codegenContext: CodegenConte private val smithyQuery = RuntimeType.smithyQuery(runtimeConfig) private val serdeUtil = SerializerUtil(model) private val protocolFunctions = ProtocolFunctions(codegenContext) - private val codegenScope = arrayOf( - "String" to RuntimeType.String, - "Error" to serializerError, - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), - "QueryWriter" to smithyQuery.resolve("QueryWriter"), - "QueryValueWriter" to smithyQuery.resolve("QueryValueWriter"), - ) + private val codegenScope = + arrayOf( + "String" to RuntimeType.String, + "Error" to serializerError, + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + "QueryWriter" to smithyQuery.resolve("QueryWriter"), + "QueryValueWriter" to smithyQuery.resolve("QueryValueWriter"), + ) abstract val protocolName: String + abstract fun MemberShape.queryKeyName(prioritizedFallback: String? = null): String + abstract fun MemberShape.isFlattened(): Boolean override fun documentSerializer(): RuntimeType { @@ -171,17 +174,18 @@ abstract class QuerySerializerGenerator(private val codegenContext: CodegenConte return } val structureSymbol = symbolProvider.toSymbol(context.shape) - val structureSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> - Attribute.AllowUnusedMut.render(this) - rustBlockTemplate( - "pub fn $fnName(mut writer: #{QueryValueWriter}, input: &#{Input}) -> Result<(), #{Error}>", - "Input" to structureSymbol, - *codegenScope, - ) { - serializeStructureInner(context) - rust("Ok(())") + val structureSerializer = + protocolFunctions.serializeFn(context.shape) { fnName -> + Attribute.AllowUnusedMut.render(this) + rustBlockTemplate( + "pub fn $fnName(mut writer: #{QueryValueWriter}, input: &#{Input}) -> Result<(), #{Error}>", + "Input" to structureSymbol, + *codegenScope, + ) { + serializeStructureInner(context) + rust("Ok(())") + } } - } rust("#T(${context.writerExpression}, ${context.valueExpression.asRef()})?;", structureSerializer) } @@ -216,7 +220,10 @@ abstract class QuerySerializerGenerator(private val codegenContext: CodegenConte } } - private fun RustWriter.serializeMemberValue(context: MemberContext, target: Shape) { + private fun RustWriter.serializeMemberValue( + context: MemberContext, + target: Shape, + ) { val writer = context.writerExpression val value = context.valueExpression when (target) { @@ -228,21 +235,23 @@ abstract class QuerySerializerGenerator(private val codegenContext: CodegenConte } is BooleanShape -> rust("$writer.boolean(${value.asValue()});") is NumberShape -> { - val numberType = when (symbolProvider.toSymbol(target).rustType()) { - is RustType.Float -> "Float" - // NegInt takes an i64 while PosInt takes u64. We need this to be signed here - is RustType.Integer -> "NegInt" - else -> throw IllegalStateException("unreachable") - } + val numberType = + when (symbolProvider.toSymbol(target).rustType()) { + is RustType.Float -> "Float" + // NegInt takes an i64 while PosInt takes u64. We need this to be signed here + is RustType.Integer -> "NegInt" + else -> throw IllegalStateException("unreachable") + } rust( "$writer.number(##[allow(clippy::useless_conversion)]#T::$numberType((${value.asValue()}).into()));", smithyTypes.resolve("Number"), ) } - is BlobShape -> rust( - "$writer.string(&#T(${value.asRef()}));", - RuntimeType.base64Encode(runtimeConfig), - ) + is BlobShape -> + rust( + "$writer.string(&#T(${value.asRef()}));", + RuntimeType.base64Encode(runtimeConfig), + ) is TimestampShape -> { val timestampFormat = determineTimestampFormat(context.shape) val timestampFormatType = RuntimeType.serializeTimestampFormat(runtimeConfig, timestampFormat) @@ -251,9 +260,10 @@ abstract class QuerySerializerGenerator(private val codegenContext: CodegenConte is CollectionShape -> serializeCollection(context, Context(writer, context.valueExpression, target)) is MapShape -> serializeMap(context, Context(writer, context.valueExpression, target)) is StructureShape -> serializeStructure(Context(writer, context.valueExpression, target)) - is UnionShape -> structWriter(context) { writerExpression -> - serializeUnion(Context(writerExpression, context.valueExpression, target)) - } + is UnionShape -> + structWriter(context) { writerExpression -> + serializeUnion(Context(writerExpression, context.valueExpression, target)) + } else -> TODO(target.toString()) } } @@ -262,7 +272,10 @@ abstract class QuerySerializerGenerator(private val codegenContext: CodegenConte shape.getMemberTrait(model, TimestampFormatTrait::class.java).orNull()?.format ?: TimestampFormatTrait.Format.DATE_TIME - private fun RustWriter.structWriter(context: MemberContext, inner: RustWriter.(String) -> Unit) { + private fun RustWriter.structWriter( + context: MemberContext, + inner: RustWriter.(String) -> Unit, + ) { val prefix = context.shape.queryKeyName() safeName("scope").also { scopeName -> Attribute.AllowUnusedMut.render(this) @@ -271,12 +284,16 @@ abstract class QuerySerializerGenerator(private val codegenContext: CodegenConte } } - private fun RustWriter.serializeCollection(memberContext: MemberContext, context: Context) { + private fun RustWriter.serializeCollection( + memberContext: MemberContext, + context: Context, + ) { val flat = memberContext.shape.isFlattened() - val memberOverride = when (val override = context.shape.member.getTrait()?.value) { - null -> "None" - else -> "Some(${override.dq()})" - } + val memberOverride = + when (val override = context.shape.member.getTrait()?.value) { + null -> "None" + else -> "Some(${override.dq()})" + } val itemName = safeName("item") safeName("list").also { listName -> rust("let mut $listName = ${context.writerExpression}.start_list($flat, $memberOverride);") @@ -294,7 +311,10 @@ abstract class QuerySerializerGenerator(private val codegenContext: CodegenConte } } - private fun RustWriter.serializeMap(memberContext: MemberContext, context: Context) { + private fun RustWriter.serializeMap( + memberContext: MemberContext, + context: Context, + ) { val flat = memberContext.shape.isFlattened() val entryKeyName = context.shape.key.queryKeyName("key").dq() val entryValueName = context.shape.value.queryKeyName("value").dq() @@ -304,10 +324,11 @@ abstract class QuerySerializerGenerator(private val codegenContext: CodegenConte rust("let mut $mapName = ${context.writerExpression}.start_map($flat, $entryKeyName, $entryValueName);") rustBlock("for ($keyName, $valueName) in ${context.valueExpression.asRef()}") { val keyTarget = model.expectShape(context.shape.key.target) - val keyExpression = when (keyTarget.hasTrait()) { - true -> "$keyName.as_str()" - else -> keyName - } + val keyExpression = + when (keyTarget.hasTrait()) { + true -> "$keyName.as_str()" + else -> keyName + } val entryName = safeName("entry") Attribute.AllowUnusedMut.render(this) rust("let mut $entryName = $mapName.entry($keyExpression);") @@ -319,41 +340,43 @@ abstract class QuerySerializerGenerator(private val codegenContext: CodegenConte private fun RustWriter.serializeUnion(context: Context) { val unionSymbol = symbolProvider.toSymbol(context.shape) - val unionSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> - Attribute.AllowUnusedMut.render(this) - rustBlockTemplate( - "pub fn $fnName(mut writer: #{QueryValueWriter}, input: &#{Input}) -> Result<(), #{Error}>", - "Input" to unionSymbol, - *codegenScope, - ) { - rustBlock("match input") { - for (member in context.shape.members()) { - val variantName = if (member.isTargetUnit()) { - "${symbolProvider.toMemberName(member)}" - } else { - "${symbolProvider.toMemberName(member)}(inner)" + val unionSerializer = + protocolFunctions.serializeFn(context.shape) { fnName -> + Attribute.AllowUnusedMut.render(this) + rustBlockTemplate( + "pub fn $fnName(mut writer: #{QueryValueWriter}, input: &#{Input}) -> Result<(), #{Error}>", + "Input" to unionSymbol, + *codegenScope, + ) { + rustBlock("match input") { + for (member in context.shape.members()) { + val variantName = + if (member.isTargetUnit()) { + "${symbolProvider.toMemberName(member)}" + } else { + "${symbolProvider.toMemberName(member)}(inner)" + } + withBlock("#T::$variantName => {", "},", unionSymbol) { + serializeMember( + MemberContext.unionMember( + context.copy(writerExpression = "writer"), + "inner", + member, + ), + ) + } } - withBlock("#T::$variantName => {", "},", unionSymbol) { - serializeMember( - MemberContext.unionMember( - context.copy(writerExpression = "writer"), - "inner", - member, - ), + if (target.renderUnknownVariant()) { + rustTemplate( + "#{Union}::${UnionGenerator.UnknownVariantName} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", + "Union" to unionSymbol, + *codegenScope, ) } } - if (target.renderUnknownVariant()) { - rustTemplate( - "#{Union}::${UnionGenerator.UnknownVariantName} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", - "Union" to unionSymbol, - *codegenScope, - ) - } + rust("Ok(())") } - rust("Ok(())") } - } rust("#T(${context.writerExpression}, ${context.valueExpression.asRef()})?;", unionSerializer) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/SerializerUtil.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/SerializerUtil.kt index 27293e43039..3617fa8ae9e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/SerializerUtil.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/SerializerUtil.kt @@ -13,7 +13,11 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock class SerializerUtil(private val model: Model) { - fun RustWriter.ignoreZeroValues(shape: MemberShape, value: ValueExpression, inner: Writable) { + fun RustWriter.ignoreZeroValues( + shape: MemberShape, + value: ValueExpression, + inner: Writable, + ) { // Required shapes should always be serialized // See https://github.com/smithy-lang/smithy-rs/issues/230 and https://github.com/aws/aws-sdk-go-v2/pull/1129 if ( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/ValueExpression.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/ValueExpression.kt index 00bc8ba74c7..fe21e20b081 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/ValueExpression.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/ValueExpression.kt @@ -11,17 +11,20 @@ sealed class ValueExpression { abstract val name: String data class Reference(override val name: String) : ValueExpression() + data class Value(override val name: String) : ValueExpression() - fun asValue(): String = when (this) { - is Reference -> autoDeref(name) - is Value -> name - } + fun asValue(): String = + when (this) { + is Reference -> autoDeref(name) + is Value -> name + } - fun asRef(): String = when (this) { - is Reference -> name - is Value -> "&$name" - } + fun asRef(): String = + when (this) { + is Reference -> name + is Value -> "&$name" + } override fun toString(): String = this.name } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt index fc4f0198387..f29e867e1e9 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt @@ -85,11 +85,15 @@ class XmlBindingTraitSerializerGenerator( companion object { // Kotlin doesn't have a "This" type @Suppress("UNCHECKED_CAST") - fun updateInput(input: T, newInput: String): T = when (input) { - is Element -> input.copy(input = newInput) as T - is Scope -> input.copy(input = newInput) as T - else -> TODO() - } + fun updateInput( + input: T, + newInput: String, + ): T = + when (input) { + is Element -> input.copy(input = newInput) as T + is Scope -> input.copy(input = newInput) as T + else -> TODO() + } } } @@ -105,8 +109,9 @@ class XmlBindingTraitSerializerGenerator( if (xmlMembers.isEmpty()) { return null } - val operationXmlName = xmlIndex.operationInputShapeName(operationShape) - ?: throw CodegenException("operation must have a name if it has members") + val operationXmlName = + xmlIndex.operationInputShapeName(operationShape) + ?: throw CodegenException("operation must have a name if it has members") return protocolFunctions.serializeFn(operationShape, fnNameSuffix = "op_input") { fnName -> rustBlockTemplate( "pub fn $fnName(input: &#{target}) -> Result<#{SdkBody}, #{Error}>", @@ -162,11 +167,12 @@ class XmlBindingTraitSerializerGenerator( *codegenScope, ) when (target) { - is StructureShape -> serializeStructure( - target, - XmlMemberIndex.fromMembers(target.members().toList()), - Ctx.Element("root", "input"), - ) + is StructureShape -> + serializeStructure( + target, + XmlMemberIndex.fromMembers(target.members().toList()), + Ctx.Element("root", "input"), + ) is UnionShape -> serializeUnion(target, Ctx.Element("root", "input")) else -> throw IllegalStateException("xml payloadSerializer only supports structs and unions") @@ -204,8 +210,9 @@ class XmlBindingTraitSerializerGenerator( if (xmlMembers.isEmpty()) { return null } - val operationXmlName = xmlIndex.operationOutputShapeName(operationShape) - ?: throw CodegenException("operation must have a name if it has members") + val operationXmlName = + xmlIndex.operationOutputShapeName(operationShape) + ?: throw CodegenException("operation must have a name if it has members") return protocolFunctions.serializeFn(operationShape, fnNameSuffix = "output") { fnName -> rustBlockTemplate( "pub fn $fnName(output: &#{target}) -> Result", @@ -235,9 +242,10 @@ class XmlBindingTraitSerializerGenerator( override fun serverErrorSerializer(shape: ShapeId): RuntimeType { val errorShape = model.expectShape(shape, StructureShape::class.java) - val xmlMembers = httpBindingResolver.errorResponseBindings(shape) - .filter { it.location == HttpLocation.DOCUMENT } - .map { it.member } + val xmlMembers = + httpBindingResolver.errorResponseBindings(shape) + .filter { it.location == HttpLocation.DOCUMENT } + .map { it.member } return protocolFunctions.serializeFn(errorShape, fnNameSuffix = "error") { fnName -> rustBlockTemplate( "pub fn $fnName(error: &#{target}) -> Result", @@ -273,7 +281,10 @@ class XmlBindingTraitSerializerGenerator( return ".write_ns(${uri.dq()}, $prefix)" } - private fun RustWriter.structureInner(members: XmlMemberIndex, ctx: Ctx.Element) { + private fun RustWriter.structureInner( + members: XmlMemberIndex, + ctx: Ctx.Element, + ) { if (members.attributeMembers.isNotEmpty()) { rust("let mut ${ctx.elementWriter} = ${ctx.elementWriter};") } @@ -293,16 +304,20 @@ class XmlBindingTraitSerializerGenerator( rust("scope.finish();") } - private fun RustWriter.serializeRawMember(member: MemberShape, input: String) { + private fun RustWriter.serializeRawMember( + member: MemberShape, + input: String, + ) { when (model.expectShape(member.target)) { is StringShape -> { // The `input` expression always evaluates to a reference type at this point, but if it does so because // it's preceded by the `&` operator, calling `as_str()` on it will upset Clippy. - val dereferenced = if (input.startsWith("&")) { - autoDeref(input) - } else { - input - } + val dereferenced = + if (input.startsWith("&")) { + autoDeref(input) + } else { + input + } rust("$dereferenced.as_str()") } @@ -330,7 +345,11 @@ class XmlBindingTraitSerializerGenerator( } @Suppress("NAME_SHADOWING") - private fun RustWriter.serializeMember(memberShape: MemberShape, ctx: Ctx.Scope, rootNameOverride: String? = null) { + private fun RustWriter.serializeMember( + memberShape: MemberShape, + ctx: Ctx.Scope, + rootNameOverride: String? = null, + ) { val target = model.expectShape(memberShape.target) val xmlName = rootNameOverride ?: xmlIndex.memberName(memberShape) val ns = memberShape.xmlNamespace(root = false).apply() @@ -343,19 +362,21 @@ class XmlBindingTraitSerializerGenerator( } } - is CollectionShape -> if (memberShape.hasTrait()) { - serializeFlatList(memberShape, target, ctx) - } else { - rust("let mut inner_writer = ${ctx.scopeWriter}.start_el(${xmlName.dq()})$ns.finish();") - serializeList(target, Ctx.Scope("inner_writer", ctx.input)) - } + is CollectionShape -> + if (memberShape.hasTrait()) { + serializeFlatList(memberShape, target, ctx) + } else { + rust("let mut inner_writer = ${ctx.scopeWriter}.start_el(${xmlName.dq()})$ns.finish();") + serializeList(target, Ctx.Scope("inner_writer", ctx.input)) + } - is MapShape -> if (memberShape.hasTrait()) { - serializeMap(target, xmlIndex.memberName(memberShape), ctx) - } else { - rust("let mut inner_writer = ${ctx.scopeWriter}.start_el(${xmlName.dq()})$ns.finish();") - serializeMap(target, "entry", Ctx.Scope("inner_writer", ctx.input)) - } + is MapShape -> + if (memberShape.hasTrait()) { + serializeMap(target, xmlIndex.memberName(memberShape), ctx) + } else { + rust("let mut inner_writer = ${ctx.scopeWriter}.start_el(${xmlName.dq()})$ns.finish();") + serializeMap(target, "entry", Ctx.Scope("inner_writer", ctx.input)) + } is StructureShape -> { // We call serializeStructure only when target.members() is nonempty. @@ -401,74 +422,91 @@ class XmlBindingTraitSerializerGenerator( fnNameSuffix: String? = null, ) { val structureSymbol = symbolProvider.toSymbol(structureShape) - val structureSerializer = protocolFunctions.serializeFn(structureShape, fnNameSuffix = fnNameSuffix) { fnName -> - rustBlockTemplate( - "pub fn $fnName(input: &#{Input}, writer: #{ElementWriter}) -> Result<(), #{Error}>", - "Input" to structureSymbol, - *codegenScope, - ) { - if (!members.isNotEmpty()) { - // removed unused warning if there are no fields we're going to read - rust("let _ = input;") + val structureSerializer = + protocolFunctions.serializeFn(structureShape, fnNameSuffix = fnNameSuffix) { fnName -> + rustBlockTemplate( + "pub fn $fnName(input: &#{Input}, writer: #{ElementWriter}) -> Result<(), #{Error}>", + "Input" to structureSymbol, + *codegenScope, + ) { + if (!members.isNotEmpty()) { + // removed unused warning if there are no fields we're going to read + rust("let _ = input;") + } + structureInner(members, Ctx.Element("writer", "&input")) + rust("Ok(())") } - structureInner(members, Ctx.Element("writer", "&input")) - rust("Ok(())") } - } rust("#T(${ctx.input}, ${ctx.elementWriter})?", structureSerializer) } - private fun RustWriter.serializeUnion(unionShape: UnionShape, ctx: Ctx.Element) { + private fun RustWriter.serializeUnion( + unionShape: UnionShape, + ctx: Ctx.Element, + ) { val unionSymbol = symbolProvider.toSymbol(unionShape) - val structureSerializer = protocolFunctions.serializeFn(unionShape) { fnName -> - rustBlockTemplate( - "pub fn $fnName(input: &#{Input}, writer: #{ElementWriter}) -> Result<(), #{Error}>", - "Input" to unionSymbol, - *codegenScope, - ) { - rust("let mut scope_writer = writer.finish();") - rustBlock("match input") { - val members = unionShape.members() - members.forEach { member -> - val variantName = if (member.isTargetUnit()) { - "${symbolProvider.toMemberName(member)}" - } else { - "${symbolProvider.toMemberName(member)}(inner)" - } - withBlock("#T::$variantName =>", ",", unionSymbol) { - serializeMember(member, Ctx.Scope("scope_writer", "inner")) + val structureSerializer = + protocolFunctions.serializeFn(unionShape) { fnName -> + rustBlockTemplate( + "pub fn $fnName(input: &#{Input}, writer: #{ElementWriter}) -> Result<(), #{Error}>", + "Input" to unionSymbol, + *codegenScope, + ) { + rust("let mut scope_writer = writer.finish();") + rustBlock("match input") { + val members = unionShape.members() + members.forEach { member -> + val variantName = + if (member.isTargetUnit()) { + "${symbolProvider.toMemberName(member)}" + } else { + "${symbolProvider.toMemberName(member)}(inner)" + } + withBlock("#T::$variantName =>", ",", unionSymbol) { + serializeMember(member, Ctx.Scope("scope_writer", "inner")) + } } - } - if (codegenTarget.renderUnknownVariant()) { - rustTemplate( - "#{Union}::${UnionGenerator.UnknownVariantName} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", - "Union" to unionSymbol, - *codegenScope, - ) + if (codegenTarget.renderUnknownVariant()) { + rustTemplate( + "#{Union}::${UnionGenerator.UnknownVariantName} => return Err(#{Error}::unknown_variant(${unionSymbol.name.dq()}))", + "Union" to unionSymbol, + *codegenScope, + ) + } } + rust("Ok(())") } - rust("Ok(())") } - } rust("#T(${ctx.input}, ${ctx.elementWriter})?", structureSerializer) } - private fun RustWriter.serializeList(listShape: CollectionShape, ctx: Ctx.Scope) { + private fun RustWriter.serializeList( + listShape: CollectionShape, + ctx: Ctx.Scope, + ) { val itemName = safeName("list_item") rustBlock("for $itemName in ${ctx.input}") { serializeMember(listShape.member, ctx.copy(input = itemName)) } } - private fun RustWriter.serializeFlatList(member: MemberShape, listShape: CollectionShape, ctx: Ctx.Scope) { + private fun RustWriter.serializeFlatList( + member: MemberShape, + listShape: CollectionShape, + ctx: Ctx.Scope, + ) { val itemName = safeName("list_item") rustBlock("for $itemName in ${ctx.input}") { serializeMember(listShape.member, ctx.copy(input = itemName), xmlIndex.memberName(member)) } } - private fun RustWriter.serializeMap(mapShape: MapShape, entryName: String, ctx: Ctx.Scope) { + private fun RustWriter.serializeMap( + mapShape: MapShape, + entryName: String, + ctx: Ctx.Scope, + ) { val key = safeName("key") val value = safeName("value") rustBlock("for ($key, $value) in ${ctx.input}") { @@ -501,28 +539,30 @@ class XmlBindingTraitSerializerGenerator( if (memberSymbol.isOptional()) { val tmp = safeName() val target = model.expectShape(member.target) - val pattern = if (target.isStructureShape && target.members().isEmpty()) { - // In this case, we mark a variable captured in the if-let - // expression as unused to prevent the warning coming - // from the following code generated by handleOptional: - // if let Some(var_2) = &input.input { - // scope.start_el("input").finish(); - // } - // where var_2 above is unused. - "Some(_$tmp)" - } else { - "Some($tmp)" - } + val pattern = + if (target.isStructureShape && target.members().isEmpty()) { + // In this case, we mark a variable captured in the if-let + // expression as unused to prevent the warning coming + // from the following code generated by handleOptional: + // if let Some(var_2) = &input.input { + // scope.start_el("input").finish(); + // } + // where var_2 above is unused. + "Some(_$tmp)" + } else { + "Some($tmp)" + } rustBlock("if let $pattern = ${ctx.input}") { inner(Ctx.updateInput(ctx, tmp)) } } else { with(util) { - val valueExpression = if (ctx.input.startsWith("&")) { - ValueExpression.Reference(ctx.input) - } else { - ValueExpression.Value(ctx.input) - } + val valueExpression = + if (ctx.input.startsWith("&")) { + ValueExpression.Reference(ctx.input) + } else { + ValueExpression.Value(ctx.input) + } ignoreZeroValues(member, valueExpression) { inner(ctx) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/RustBoxTrait.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/RustBoxTrait.kt index 41144d5945c..86e341062cf 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/RustBoxTrait.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/RustBoxTrait.kt @@ -19,6 +19,7 @@ import software.amazon.smithy.model.traits.Trait */ class RustBoxTrait : Trait { val ID = ShapeId.from("software.amazon.smithy.rust.codegen.smithy.rust.synthetic#box") + override fun toNode(): Node = Node.objectNode() override fun toShapeId(): ShapeId = ID diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt index ba086f9cb6a..d34ca0e39da 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/SyntheticOutputTrait.kt @@ -17,7 +17,7 @@ import software.amazon.smithy.model.traits.AnnotationTrait */ class SyntheticOutputTrait constructor(val operation: ShapeId, val originalId: ShapeId?) : AnnotationTrait(ID, Node.objectNode()) { - companion object { - val ID: ShapeId = ShapeId.from("smithy.api.internal#syntheticOutput") + companion object { + val ID: ShapeId = ShapeId.from("smithy.api.internal#syntheticOutput") + } } -} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/EventStreamNormalizer.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/EventStreamNormalizer.kt index e46323ce19c..236ce443383 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/EventStreamNormalizer.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/EventStreamNormalizer.kt @@ -28,17 +28,21 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape * place that does codegen with the unions. */ object EventStreamNormalizer { - fun transform(model: Model): Model = ModelTransformer.create().mapShapes(model) { shape -> - if (shape is OperationShape && shape.isEventStream(model)) { - addStreamErrorsToOperationErrors(model, shape) - } else if (shape is UnionShape && shape.isEventStream()) { - syntheticEquivalentEventStreamUnion(model, shape) - } else { - shape + fun transform(model: Model): Model = + ModelTransformer.create().mapShapes(model) { shape -> + if (shape is OperationShape && shape.isEventStream(model)) { + addStreamErrorsToOperationErrors(model, shape) + } else if (shape is UnionShape && shape.isEventStream()) { + syntheticEquivalentEventStreamUnion(model, shape) + } else { + shape + } } - } - private fun addStreamErrorsToOperationErrors(model: Model, operation: OperationShape): OperationShape { + private fun addStreamErrorsToOperationErrors( + model: Model, + operation: OperationShape, + ): OperationShape { if (!operation.isEventStream(model)) { return operation } @@ -57,10 +61,14 @@ object EventStreamNormalizer { .build() } - private fun syntheticEquivalentEventStreamUnion(model: Model, union: UnionShape): UnionShape { - val (errorMembers, eventMembers) = union.members().partition { member -> - model.expectShape(member.target).hasTrait() - } + private fun syntheticEquivalentEventStreamUnion( + model: Model, + union: UnionShape, + ): UnionShape { + val (errorMembers, eventMembers) = + union.members().partition { member -> + model.expectShape(member.target).hasTrait() + } return union.toBuilder() .members(eventMembers) .addTrait(SyntheticEventStreamUnionTrait(errorMembers)) @@ -73,7 +81,10 @@ fun OperationShape.operationErrors(model: Model): List { return operationIndex.getErrors(this) } -fun eventStreamErrors(model: Model, shape: Shape): Map> { +fun eventStreamErrors( + model: Model, + shape: Shape, +): Map> { return DirectedWalker(model) .walkShapes(shape) .filter { it is UnionShape && it.isEventStream() } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt index 241e0d44d14..4092174b55e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt @@ -54,10 +54,11 @@ object OperationNormalizer { fun transform(model: Model): Model { val transformer = ModelTransformer.create() val operations = model.shapes(OperationShape::class.java).toList() - val newShapes = operations.flatMap { operation -> - // Generate or modify the input and output of the given `Operation` to be a unique shape - listOf(syntheticInputShape(model, operation), syntheticOutputShape(model, operation)) - } + val newShapes = + operations.flatMap { operation -> + // Generate or modify the input and output of the given `Operation` to be a unique shape + listOf(syntheticInputShape(model, operation), syntheticOutputShape(model, operation)) + } val shapeConflict = newShapes.firstOrNull { shape -> model.getShape(shape.id).isPresent } check( shapeConflict == null, @@ -67,13 +68,14 @@ object OperationNormalizer { val modelWithOperationInputs = model.toBuilder().addShapes(newShapes).build() return transformer.mapShapes(modelWithOperationInputs) { // Update all operations to point to their new input/output shapes - val transformed: Optional = it.asOperationShape().map { operation -> - modelWithOperationInputs.expectShape(operation.syntheticInputId()) - operation.toBuilder() - .input(operation.syntheticInputId()) - .output(operation.syntheticOutputId()) - .build() - } + val transformed: Optional = + it.asOperationShape().map { operation -> + modelWithOperationInputs.expectShape(operation.syntheticInputId()) + operation.toBuilder() + .input(operation.syntheticInputId()) + .output(operation.syntheticOutputId()) + .build() + } transformed.orElse(it) } } @@ -84,11 +86,15 @@ object OperationNormalizer { * * If the operation does not have an output, an empty shape is generated */ - private fun syntheticOutputShape(model: Model, operation: OperationShape): StructureShape { + private fun syntheticOutputShape( + model: Model, + operation: OperationShape, + ): StructureShape { val outputId = operation.syntheticOutputId() - val outputShapeBuilder = operation.output.map { shapeId -> - model.expectShape(shapeId, StructureShape::class.java).toBuilder().rename(outputId) - }.orElse(empty(outputId)) + val outputShapeBuilder = + operation.output.map { shapeId -> + model.expectShape(shapeId, StructureShape::class.java).toBuilder().rename(outputId) + }.orElse(empty(outputId)) return outputShapeBuilder.addTrait( SyntheticOutputTrait( operation = operation.id, @@ -103,11 +109,15 @@ object OperationNormalizer { * * If the input operation does not have an input, an empty shape is generated */ - private fun syntheticInputShape(model: Model, operation: OperationShape): StructureShape { + private fun syntheticInputShape( + model: Model, + operation: OperationShape, + ): StructureShape { val inputId = operation.syntheticInputId() - val inputShapeBuilder = operation.input.map { shapeId -> - model.expectShape(shapeId, StructureShape::class.java).toBuilder().rename(inputId) - }.orElse(empty(inputId)) + val inputShapeBuilder = + operation.input.map { shapeId -> + model.expectShape(shapeId, StructureShape::class.java).toBuilder().rename(inputId) + }.orElse(empty(inputId)) // There are still shapes missing the input trait. If we don't apply this, we'll get bad results from the // nullability index if (!inputShapeBuilder.build().hasTrait()) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxer.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxer.kt index d53751829fb..f9c4d2d17c8 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxer.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxer.kt @@ -77,13 +77,14 @@ class RecursiveShapeBoxer( // (External to this function) Go back to 1. val index = TopologicalIndex.of(model) val recursiveShapes = index.recursiveShapes - val loops = recursiveShapes.map { shapeId -> - // Get all the shapes in the closure (represented as `Path`s). - index.getRecursiveClosure(shapeId) - }.flatMap { loops -> - // Flatten the connections into shapes. - loops.map { it.shapes } - } + val loops = + recursiveShapes.map { shapeId -> + // Get all the shapes in the closure (represented as `Path`s). + index.getRecursiveClosure(shapeId) + }.flatMap { loops -> + // Flatten the connections into shapes. + loops.map { it.shapes } + } val loopToFix = loops.firstOrNull { !containsIndirectionPredicate(it) } return loopToFix?.let { loop: List -> @@ -111,11 +112,12 @@ class RecursiveShapeBoxer( * indirection artificially ourselves using `Box`. * */ -private fun containsIndirection(loop: Collection): Boolean = loop.find { - when (it) { - is CollectionShape, is MapShape -> true - else -> it.hasTrait() - } -} != null +private fun containsIndirection(loop: Collection): Boolean = + loop.find { + when (it) { + is CollectionShape, is MapShape -> true + else -> it.hasTrait() + } + } != null private fun addRustBoxTrait(shape: MemberShape): MemberShape = shape.toBuilder().addTrait(RustBoxTrait()).build() diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/BasicTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/BasicTestModels.kt index eb4829702ee..c138a599bf5 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/BasicTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/BasicTestModels.kt @@ -6,7 +6,8 @@ package software.amazon.smithy.rust.codegen.core.testutil object BasicTestModels { - val AwsJson10TestModel = """ + val AwsJson10TestModel = + """ namespace com.example use aws.protocols#awsJson1_0 @awsJson1_0 @@ -19,5 +20,5 @@ object BasicTestModels { structure TestInput { foo: String, } - """.asSmithyModel() + """.asSmithyModel() } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/CodegenIntegrationTest.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/CodegenIntegrationTest.kt index 19bd7ddf580..9790b080b65 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/CodegenIntegrationTest.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/CodegenIntegrationTest.kt @@ -32,16 +32,21 @@ data class IntegrationTestParams( /** * Run cargo test on a true, end-to-end, codegen product of a given model. */ -fun codegenIntegrationTest(model: Model, params: IntegrationTestParams, invokePlugin: (PluginContext) -> Unit): Path { - val (ctx, testDir) = generatePluginContext( - model, - params.additionalSettings, - params.addModuleToEventStreamAllowList, - params.moduleVersion, - params.service, - params.runtimeConfig, - params.overrideTestDir, - ) +fun codegenIntegrationTest( + model: Model, + params: IntegrationTestParams, + invokePlugin: (PluginContext) -> Unit, +): Path { + val (ctx, testDir) = + generatePluginContext( + model, + params.additionalSettings, + params.addModuleToEventStreamAllowList, + params.moduleVersion, + params.service, + params.runtimeConfig, + params.overrideTestDir, + ) testDir.writeDotCargoConfigToml(listOf("--deny", "warnings")) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/DefaultBuilderInstantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/DefaultBuilderInstantiator.kt index 96af195c764..aa11d70edb7 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/DefaultBuilderInstantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/DefaultBuilderInstantiator.kt @@ -19,11 +19,19 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstant * and to serve as the base behavior for client and server instantiators. */ class DefaultBuilderInstantiator(private val checkFallibleBuilder: Boolean, private val symbolProvider: RustSymbolProvider) : BuilderInstantiator { - override fun setField(builder: String, value: Writable, field: MemberShape): Writable { + override fun setField( + builder: String, + value: Writable, + field: MemberShape, + ): Writable { return setFieldWithSetter(builder, value, field) } - override fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable?): Writable { + override fun finalizeBuilder( + builder: String, + shape: StructureShape, + mapErr: Writable?, + ): Writable { return writable { rust("builder.build()") if (checkFallibleBuilder && BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt index 2850f51c721..65ae5019cb5 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt @@ -23,8 +23,9 @@ object EventStreamMarshallTestCases { ) { val generator = "crate::event_stream_serde::TestStreamMarshaller" - val protocolTestHelpers = CargoDependency.smithyProtocolTestHelpers(TestRuntimeConfig) - .copy(scope = DependencyScope.Compile) + val protocolTestHelpers = + CargoDependency.smithyProtocolTestHelpers(TestRuntimeConfig) + .copy(scope = DependencyScope.Compile) fun builderInput( @Language("Rust", prefix = "macro_rules! foo { () => {{\n", suffix = "\n}}}") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt index e944a552a08..c0f61e07dbe 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt @@ -16,7 +16,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml private fun fillInBaseModel( protocolName: String, extraServiceAnnotations: String = "", -): String = """ +): String = + """ namespace test use smithy.framework#ValidationException @@ -87,14 +88,18 @@ private fun fillInBaseModel( $extraServiceAnnotations @$protocolName service TestService { version: "123", operations: [TestStreamOp] } -""" + """ object EventStreamTestModels { private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel() + private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel() + private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel() + private fun awsQuery(): Model = fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() + private fun ec2Query(): Model = fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() @@ -114,79 +119,82 @@ object EventStreamTestModels { override fun toString(): String = protocolShapeId } - val TEST_CASES = listOf( - // - // restJson1 - // - TestCase( - protocolShapeId = "aws.protocols#restJson1", - model = restJson1(), - mediaType = "application/json", - requestContentType = "application/vnd.amazon.eventstream", - responseContentType = "application/json", - validTestStruct = """{"someString":"hello","someInt":5}""", - validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", - validTestUnion = """{"Foo":"hello"}""", - validSomeError = """{"Message":"some error"}""", - validUnmodeledError = """{"Message":"unmodeled error"}""", - ) { RestJson(it) }, - - // - // awsJson1_1 - // - TestCase( - protocolShapeId = "aws.protocols#awsJson1_1", - model = awsJson11(), - mediaType = "application/x-amz-json-1.1", - requestContentType = "application/x-amz-json-1.1", - responseContentType = "application/x-amz-json-1.1", - validTestStruct = """{"someString":"hello","someInt":5}""", - validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", - validTestUnion = """{"Foo":"hello"}""", - validSomeError = """{"Message":"some error"}""", - validUnmodeledError = """{"Message":"unmodeled error"}""", - ) { AwsJson(it, AwsJsonVersion.Json11) }, - - // - // restXml - // - TestCase( - protocolShapeId = "aws.protocols#restXml", - model = restXml(), - mediaType = "application/xml", - requestContentType = "application/vnd.amazon.eventstream", - responseContentType = "application/xml", - validTestStruct = """ - - hello - 5 - - """.trimIndent(), - validMessageWithNoHeaderPayloadTraits = """ - - hello - 5 - - """.trimIndent(), - validTestUnion = "hello", - validSomeError = """ - - - SomeError - SomeError - some error - - - """.trimIndent(), - validUnmodeledError = """ - - - UnmodeledError - UnmodeledError - unmodeled error - - - """.trimIndent(), - ) { RestXml(it) }, - ) + val TEST_CASES = + listOf( + // + // restJson1 + // + TestCase( + protocolShapeId = "aws.protocols#restJson1", + model = restJson1(), + mediaType = "application/json", + requestContentType = "application/vnd.amazon.eventstream", + responseContentType = "application/json", + validTestStruct = """{"someString":"hello","someInt":5}""", + validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", + validTestUnion = """{"Foo":"hello"}""", + validSomeError = """{"Message":"some error"}""", + validUnmodeledError = """{"Message":"unmodeled error"}""", + ) { RestJson(it) }, + // + // awsJson1_1 + // + TestCase( + protocolShapeId = "aws.protocols#awsJson1_1", + model = awsJson11(), + mediaType = "application/x-amz-json-1.1", + requestContentType = "application/x-amz-json-1.1", + responseContentType = "application/x-amz-json-1.1", + validTestStruct = """{"someString":"hello","someInt":5}""", + validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", + validTestUnion = """{"Foo":"hello"}""", + validSomeError = """{"Message":"some error"}""", + validUnmodeledError = """{"Message":"unmodeled error"}""", + ) { AwsJson(it, AwsJsonVersion.Json11) }, + // + // restXml + // + TestCase( + protocolShapeId = "aws.protocols#restXml", + model = restXml(), + mediaType = "application/xml", + requestContentType = "application/vnd.amazon.eventstream", + responseContentType = "application/xml", + validTestStruct = + """ + + hello + 5 + + """.trimIndent(), + validMessageWithNoHeaderPayloadTraits = + """ + + hello + 5 + + """.trimIndent(), + validTestUnion = "hello", + validSomeError = + """ + + + SomeError + SomeError + some error + + + """.trimIndent(), + validUnmodeledError = + """ + + + UnmodeledError + UnmodeledError + unmodeled error + + + """.trimIndent(), + ) { RestXml(it) }, + ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt index 3adc813546a..4a94d0af3ae 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt @@ -77,13 +77,13 @@ object EventStreamUnmarshallTestCases { expect_event(result.unwrap()) ); """, - "DataInput" to conditionalBuilderInput( - """ - Blob::new(&b"hello, world!"[..]) - """, - conditional = optionalBuilderInputs, - ), - + "DataInput" to + conditionalBuilderInput( + """ + Blob::new(&b"hello, world!"[..]) + """, + conditional = optionalBuilderInputs, + ), ) } @@ -118,18 +118,18 @@ object EventStreamUnmarshallTestCases { expect_event(result.unwrap()) ); """, - "StructInput" to conditionalBuilderInput( - """ - TestStruct::builder() - .some_string(#{StringInput}) - .some_int(#{IntInput}) - .build() - """, - conditional = optionalBuilderInputs, - "StringInput" to conditionalBuilderInput("\"hello\"", conditional = optionalBuilderInputs), - "IntInput" to conditionalBuilderInput("5", conditional = optionalBuilderInputs), - ), - + "StructInput" to + conditionalBuilderInput( + """ + TestStruct::builder() + .some_string(#{StringInput}) + .some_int(#{IntInput}) + .build() + """, + conditional = optionalBuilderInputs, + "StringInput" to conditionalBuilderInput("\"hello\"", conditional = optionalBuilderInputs), + "IntInput" to conditionalBuilderInput("5", conditional = optionalBuilderInputs), + ), ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt index 72979545b94..5508611972e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/NamingObstacleCourseTestModels.kt @@ -16,167 +16,172 @@ object NamingObstacleCourseTestModels { * Test model that confounds the generation machinery by using operations named after every item * in the Rust prelude. */ - fun rustPreludeOperationsModel(): Model = StringBuilder().apply { - append( - """ - ${"$"}version: "2.0" - namespace crate - - use smithy.test#httpRequestTests - use smithy.test#httpResponseTests - use aws.protocols#awsJson1_1 - use aws.api#service - use smithy.framework#ValidationException - - structure InputAndOutput {} - - @awsJson1_1 - @service(sdkId: "Config") - service Config { - version: "2006-03-01", - rename: { "smithy.api#String": "PreludeString" }, - operations: [ - """, - ) - for (item in rustPrelude) { - append("$item,\n") - } - append( - """ - ] + fun rustPreludeOperationsModel(): Model = + StringBuilder().apply { + append( + """ + ${"$"}version: "2.0" + namespace crate + + use smithy.test#httpRequestTests + use smithy.test#httpResponseTests + use aws.protocols#awsJson1_1 + use aws.api#service + use smithy.framework#ValidationException + + structure InputAndOutput {} + + @awsJson1_1 + @service(sdkId: "Config") + service Config { + version: "2006-03-01", + rename: { "smithy.api#String": "PreludeString" }, + operations: [ + """, + ) + for (item in rustPrelude) { + append("$item,\n") } - """, - ) - for (item in rustPrelude) { - append("operation $item { input: InputAndOutput, output: InputAndOutput, errors: [ValidationException] }\n") - } - }.toString().asSmithyModel() - - fun rustPreludeStructsModel(): Model = StringBuilder().apply { - append( - """ - ${"$"}version: "2.0" - namespace crate - - use smithy.test#httpRequestTests - use smithy.test#httpResponseTests - use aws.protocols#awsJson1_1 - use aws.api#service - use smithy.framework#ValidationException - - structure InputAndOutput {} - - @awsJson1_1 - @service(sdkId: "Config") - service Config { - version: "2006-03-01", - rename: { "smithy.api#String": "PreludeString" }, - operations: [ - """, - ) - for (item in rustPrelude) { - append("Use$item,\n") - } - append( - """ - ] + append( + """ + ] + } + """, + ) + for (item in rustPrelude) { + append("operation $item { input: InputAndOutput, output: InputAndOutput, errors: [ValidationException] }\n") } - """, - ) - for (item in rustPrelude) { - append("structure $item { $item: smithy.api#String }\n") - append("operation Use$item { input: $item, output: $item, errors: [ValidationException] }\n") - } - println(toString()) - }.toString().asSmithyModel() - - fun rustPreludeEnumsModel(): Model = StringBuilder().apply { - append( - """ - ${"$"}version: "2.0" - namespace crate - - use smithy.test#httpRequestTests - use smithy.test#httpResponseTests - use aws.protocols#awsJson1_1 - use aws.api#service - use smithy.framework#ValidationException - - structure InputAndOutput {} - - @awsJson1_1 - @service(sdkId: "Config") - service Config { - version: "2006-03-01", - rename: { "smithy.api#String": "PreludeString" }, - operations: [ - """, - ) - for (item in rustPrelude) { - append("Use$item,\n") - } - append( - """ - ] + }.toString().asSmithyModel() + + fun rustPreludeStructsModel(): Model = + StringBuilder().apply { + append( + """ + ${"$"}version: "2.0" + namespace crate + + use smithy.test#httpRequestTests + use smithy.test#httpResponseTests + use aws.protocols#awsJson1_1 + use aws.api#service + use smithy.framework#ValidationException + + structure InputAndOutput {} + + @awsJson1_1 + @service(sdkId: "Config") + service Config { + version: "2006-03-01", + rename: { "smithy.api#String": "PreludeString" }, + operations: [ + """, + ) + for (item in rustPrelude) { + append("Use$item,\n") } - """, - ) - for (item in rustPrelude) { - append("enum $item { $item }\n") - append("structure Struct$item { $item: $item }\n") - append("operation Use$item { input: Struct$item, output: Struct$item, errors: [ValidationException] }\n") - } - }.toString().asSmithyModel() - - fun rustPreludeEnumVariantsModel(): Model = StringBuilder().apply { - append( - """ - ${"$"}version: "2.0" - namespace crate - - use smithy.test#httpRequestTests - use smithy.test#httpResponseTests - use aws.protocols#awsJson1_1 - use aws.api#service - use smithy.framework#ValidationException - - @awsJson1_1 - @service(sdkId: "Config") - service Config { - version: "2006-03-01", - rename: { "smithy.api#String": "PreludeString" }, - operations: [EnumOp] + append( + """ + ] + } + """, + ) + for (item in rustPrelude) { + append("structure $item { $item: smithy.api#String }\n") + append("operation Use$item { input: $item, output: $item, errors: [ValidationException] }\n") } - - operation EnumOp { - input: InputAndOutput, - output: InputAndOutput, - errors: [ValidationException], + println(toString()) + }.toString().asSmithyModel() + + fun rustPreludeEnumsModel(): Model = + StringBuilder().apply { + append( + """ + ${"$"}version: "2.0" + namespace crate + + use smithy.test#httpRequestTests + use smithy.test#httpResponseTests + use aws.protocols#awsJson1_1 + use aws.api#service + use smithy.framework#ValidationException + + structure InputAndOutput {} + + @awsJson1_1 + @service(sdkId: "Config") + service Config { + version: "2006-03-01", + rename: { "smithy.api#String": "PreludeString" }, + operations: [ + """, + ) + for (item in rustPrelude) { + append("Use$item,\n") } - - structure InputAndOutput { - the_enum: TheEnum, + append( + """ + ] + } + """, + ) + for (item in rustPrelude) { + append("enum $item { $item }\n") + append("structure Struct$item { $item: $item }\n") + append("operation Use$item { input: Struct$item, output: Struct$item, errors: [ValidationException] }\n") } - - enum TheEnum { - """, - ) - for (item in rustPrelude) { - append("$item,\n") - } - append( - """ + }.toString().asSmithyModel() + + fun rustPreludeEnumVariantsModel(): Model = + StringBuilder().apply { + append( + """ + ${"$"}version: "2.0" + namespace crate + + use smithy.test#httpRequestTests + use smithy.test#httpResponseTests + use aws.protocols#awsJson1_1 + use aws.api#service + use smithy.framework#ValidationException + + @awsJson1_1 + @service(sdkId: "Config") + service Config { + version: "2006-03-01", + rename: { "smithy.api#String": "PreludeString" }, + operations: [EnumOp] + } + + operation EnumOp { + input: InputAndOutput, + output: InputAndOutput, + errors: [ValidationException], + } + + structure InputAndOutput { + the_enum: TheEnum, + } + + enum TheEnum { + """, + ) + for (item in rustPrelude) { + append("$item,\n") } - """, - ) - }.toString().asSmithyModel() + append( + """ + } + """, + ) + }.toString().asSmithyModel() /** * This targets two bug classes: * - operation inputs used as nested outputs * - operation outputs used as nested outputs */ - fun reusedInputOutputShapesModel(protocol: Trait) = """ + fun reusedInputOutputShapesModel(protocol: Trait) = + """ namespace test use ${protocol.toShapeId()} use aws.api#service @@ -226,5 +231,5 @@ object NamingObstacleCourseTestModels { list GetThingInputList { member: GetThingInput } - """.asSmithyModel() + """.asSmithyModel() } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt index e428ea2e552..d52d65fb54a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt @@ -46,11 +46,13 @@ import java.nio.file.Path import kotlin.io.path.absolutePathString import kotlin.io.path.writeText -val TestModuleDocProvider = object : ModuleDocProvider { - override fun docsWriter(module: RustModule.LeafModule): Writable = writable { - docs("Some test documentation\n\nSome more details...") +val TestModuleDocProvider = + object : ModuleDocProvider { + override fun docsWriter(module: RustModule.LeafModule): Writable = + writable { + docs("Some test documentation\n\nSome more details...") + } } -} /** * Waiting for Kotlin to stabilize their temp directory functionality @@ -70,11 +72,12 @@ private fun tempDir(directory: File? = null): File { */ object TestWorkspace { private val baseDir by lazy { - val appDataDir = System.getProperty("APPDATA") - ?: System.getenv("XDG_DATA_HOME") - ?: System.getProperty("user.home") - ?.let { Path.of(it, ".local", "share").absolutePathString() } - ?.also { File(it).mkdirs() } + val appDataDir = + System.getProperty("APPDATA") + ?: System.getenv("XDG_DATA_HOME") + ?: System.getProperty("user.home") + ?.let { Path.of(it, ".local", "share").absolutePathString() } + ?.also { File(it).mkdirs() } if (appDataDir != null) { File(Path.of(appDataDir, "smithy-test-workspace").absolutePathString()) } else { @@ -89,13 +92,15 @@ object TestWorkspace { private fun generate() { val cargoToml = baseDir.resolve("Cargo.toml") - val workspaceToml = TomlWriter().write( - mapOf( - "workspace" to mapOf( - "members" to subprojects, + val workspaceToml = + TomlWriter().write( + mapOf( + "workspace" to + mapOf( + "members" to subprojects, + ), ), - ), - ) + ) cargoToml.writeText(workspaceToml) } @@ -173,28 +178,30 @@ fun generatePluginContext( val moduleName = "test_${testDir.nameWithoutExtension}" val testPath = testDir.toPath() val manifest = FileManifest.create(testPath) - var settingsBuilder = Node.objectNodeBuilder() - .withMember("module", Node.from(moduleName)) - .withMember("moduleVersion", Node.from(moduleVersion)) - .withMember("moduleDescription", Node.from("test")) - .withMember("moduleAuthors", Node.fromStrings("testgenerator@smithy.com")) - .letIf(service != null) { it.withMember("service", service) } - .withMember( - "runtimeConfig", - Node.objectNodeBuilder().withMember( - "relativePath", - Node.from(((runtimeConfig ?: TestRuntimeConfig).runtimeCrateLocation).path), - ).build(), - ) + var settingsBuilder = + Node.objectNodeBuilder() + .withMember("module", Node.from(moduleName)) + .withMember("moduleVersion", Node.from(moduleVersion)) + .withMember("moduleDescription", Node.from("test")) + .withMember("moduleAuthors", Node.fromStrings("testgenerator@smithy.com")) + .letIf(service != null) { it.withMember("service", service) } + .withMember( + "runtimeConfig", + Node.objectNodeBuilder().withMember( + "relativePath", + Node.from(((runtimeConfig ?: TestRuntimeConfig).runtimeCrateLocation).path), + ).build(), + ) if (addModuleToEventStreamAllowList) { - settingsBuilder = settingsBuilder.withMember( - "codegen", - Node.objectNodeBuilder().withMember( - "eventStreamAllowList", - Node.fromStrings(moduleName), - ).build(), - ) + settingsBuilder = + settingsBuilder.withMember( + "codegen", + Node.objectNodeBuilder().withMember( + "eventStreamAllowList", + Node.fromStrings(moduleName), + ).build(), + ) } val settings = settingsBuilder.merge(additionalSettings).build() @@ -232,10 +239,14 @@ fun RustWriter.unitTest( return testDependenciesOnly { rustBlock("fn $name()", *args, block = block) } } -fun RustWriter.cargoDependencies() = dependencies.map { RustDependency.fromSymbolDependency(it) } - .filterIsInstance().distinct() +fun RustWriter.cargoDependencies() = + dependencies.map { RustDependency.fromSymbolDependency(it) } + .filterIsInstance().distinct() -fun RustWriter.assertNoNewDependencies(block: Writable, dependencyFilter: (CargoDependency) -> String?): RustWriter { +fun RustWriter.assertNoNewDependencies( + block: Writable, + dependencyFilter: (CargoDependency) -> String?, +): RustWriter { val startingDependencies = cargoDependencies().toSet() block(this) val endingDependencies = cargoDependencies().toSet() @@ -260,19 +271,25 @@ fun RustWriter.assertNoNewDependencies(block: Writable, dependencyFilter: (Cargo return this } -fun RustWriter.testDependenciesOnly(block: Writable) = assertNoNewDependencies(block) { dep -> - if (dep.scope != DependencyScope.Dev) { - "Cannot add $dep — this writer should only add test dependencies." - } else { - null +fun RustWriter.testDependenciesOnly(block: Writable) = + assertNoNewDependencies(block) { dep -> + if (dep.scope != DependencyScope.Dev) { + "Cannot add $dep — this writer should only add test dependencies." + } else { + null + } } -} -fun testDependenciesOnly(block: Writable): Writable = { - testDependenciesOnly(block) -} +fun testDependenciesOnly(block: Writable): Writable = + { + testDependenciesOnly(block) + } -fun RustWriter.tokioTest(name: String, vararg args: Any, block: Writable) { +fun RustWriter.tokioTest( + name: String, + vararg args: Any, + block: Writable, +) { unitTest(name, attribute = Attribute.TokioTest, async = true, block = block, args = args) } @@ -301,13 +318,14 @@ class TestWriterDelegator( * * This should only be used in test code—the generated module name will be something like `tests_123` */ -fun RustCrate.testModule(block: Writable) = lib { - withInlineModule( - RustModule.inlineTests(safeName("tests")), - TestModuleDocProvider, - block, - ) -} +fun RustCrate.testModule(block: Writable) = + lib { + withInlineModule( + RustModule.inlineTests(safeName("tests")), + TestModuleDocProvider, + block, + ) + } fun FileManifest.printGeneratedFiles() { this.files.forEach { path -> @@ -324,12 +342,13 @@ fun TestWriterDelegator.compileAndTest( runClippy: Boolean = false, expectFailure: Boolean = false, ): String { - val stubModel = """ + val stubModel = + """ namespace fake service Fake { version: "123" } - """.asSmithyModel() + """.asSmithyModel() this.finalize( rustSettings(), stubModel, @@ -393,27 +412,31 @@ fun RustWriter.compileAndTest( clippy: Boolean = false, expectFailure: Boolean = false, ): String { - val deps = this.dependencies - .map { RustDependency.fromSymbolDependency(it) } - .filterIsInstance() - .distinct() - .mergeDependencyFeatures() - .mergeIdenticalTestDependencies() - val module = if (this.namespace.contains("::")) { - this.namespace.split("::")[1] - } else { - "lib" - } - val tempDir = this.toString() - .intoCrate(deps, module = module, main = main, strict = clippy) + val deps = + this.dependencies + .map { RustDependency.fromSymbolDependency(it) } + .filterIsInstance() + .distinct() + .mergeDependencyFeatures() + .mergeIdenticalTestDependencies() + val module = + if (this.namespace.contains("::")) { + this.namespace.split("::")[1] + } else { + "lib" + } + val tempDir = + this.toString() + .intoCrate(deps, module = module, main = main, strict = clippy) val mainRs = tempDir.resolve("src/main.rs") val testModule = tempDir.resolve("src/$module.rs") try { - val testOutput = if ((mainRs.readText() + testModule.readText()).contains("#[test]")) { - "cargo test".runCommand(tempDir.toPath()) - } else { - "cargo check".runCommand(tempDir.toPath()) - } + val testOutput = + if ((mainRs.readText() + testModule.readText()).contains("#[test]")) { + "cargo test".runCommand(tempDir.toPath()) + } else { + "cargo check".runCommand(tempDir.toPath()) + } if (expectFailure) { println("Test sources for debugging: file://${testModule.absolutePath}") } @@ -434,18 +457,19 @@ private fun String.intoCrate( ): File { this.shouldParseAsRust() val tempDir = TestWorkspace.subproject() - val cargoToml = RustWriter.toml("Cargo.toml").apply { - CargoTomlGenerator( - moduleName = tempDir.nameWithoutExtension, - moduleVersion = "0.0.1", - moduleAuthors = listOf("Testy McTesterson"), - moduleDescription = null, - moduleLicense = null, - moduleRepository = null, - writer = this, - dependencies = deps, - ).render() - }.toString() + val cargoToml = + RustWriter.toml("Cargo.toml").apply { + CargoTomlGenerator( + moduleName = tempDir.nameWithoutExtension, + moduleVersion = "0.0.1", + moduleAuthors = listOf("Testy McTesterson"), + moduleDescription = null, + moduleLicense = null, + moduleRepository = null, + writer = this, + dependencies = deps, + ).render() + }.toString() tempDir.resolve("Cargo.toml").writeText(cargoToml) tempDir.resolve("src").mkdirs() val mainRs = tempDir.resolve("src/main.rs") @@ -507,7 +531,10 @@ fun String.compileAndRun(vararg strings: String) { binary.absolutePath.runCommand() } -fun RustCrate.integrationTest(name: String, writable: Writable) = this.withFile("tests/$name.rs", writable) +fun RustCrate.integrationTest( + name: String, + writable: Writable, +) = this.withFile("tests/$name.rs", writable) fun TestWriterDelegator.unitTest(test: Writable): TestWriterDelegator { lib { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt index cbd29d33d46..f5797ba9b6e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt @@ -50,7 +50,7 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import java.io.File val TestRuntimeConfig = - RuntimeConfig(runtimeCrateLocation = RuntimeCrateLocation.Path(File("../rust-runtime/").absolutePath)) + RuntimeConfig(runtimeCrateLocation = RuntimeCrateLocation.path(File("../rust-runtime/").absolutePath)) /** * IMPORTANT: You shouldn't need to refer to these directly in code or tests. They are private for a reason. @@ -68,15 +68,19 @@ private object CodegenCoreTestModules { val OperationsTestModule = RustModule.public("test_operation") object TestModuleProvider : ModuleProvider { - override fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule = + override fun moduleForShape( + context: ModuleProviderContext, + shape: Shape, + ): RustModule.LeafModule = when (shape) { is OperationShape -> OperationsTestModule - is StructureShape -> when { - shape.hasTrait() -> ErrorsTestModule - shape.hasTrait() -> InputsTestModule - shape.hasTrait() -> OutputsTestModule - else -> ModelsTestModule - } + is StructureShape -> + when { + shape.hasTrait() -> ErrorsTestModule + shape.hasTrait() -> InputsTestModule + shape.hasTrait() -> OutputsTestModule + else -> ModelsTestModule + } else -> ModelsTestModule } @@ -107,12 +111,13 @@ private object CodegenCoreTestModules { } } -fun testRustSymbolProviderConfig(nullabilityCheckMode: NullableIndex.CheckMode) = RustSymbolProviderConfig( - runtimeConfig = TestRuntimeConfig, - renameExceptions = true, - nullabilityCheckMode = nullabilityCheckMode, - moduleProvider = CodegenCoreTestModules.TestModuleProvider, -) +fun testRustSymbolProviderConfig(nullabilityCheckMode: NullableIndex.CheckMode) = + RustSymbolProviderConfig( + runtimeConfig = TestRuntimeConfig, + renameExceptions = true, + nullabilityCheckMode = nullabilityCheckMode, + moduleProvider = CodegenCoreTestModules.TestModuleProvider, + ) fun testRustSettings( service: ShapeId = ShapeId.from("notrelevant#notrelevant"), @@ -139,6 +144,7 @@ fun testRustSettings( ) private const val SmithyVersion = "1.0" + fun String.asSmithyModel( sourceLocation: String? = null, smithyVersion: String = SmithyVersion, @@ -158,18 +164,19 @@ internal fun testSymbolProvider( model: Model, rustReservedWordConfig: RustReservedWordConfig? = null, nullabilityCheckMode: NullableIndex.CheckMode = NullableIndex.CheckMode.CLIENT, -): RustSymbolProvider = SymbolVisitor( - testRustSettings(), - model, - ServiceShape.builder().version("test").id("test#Service").build(), - testRustSymbolProviderConfig(nullabilityCheckMode), -).let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf(Attribute.NonExhaustive)) } - .let { - RustReservedWordSymbolProvider( - it, - rustReservedWordConfig ?: RustReservedWordConfig(emptyMap(), emptyMap(), emptyMap()), - ) - } +): RustSymbolProvider = + SymbolVisitor( + testRustSettings(), + model, + ServiceShape.builder().version("test").id("test#Service").build(), + testRustSymbolProviderConfig(nullabilityCheckMode), + ).let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf(Attribute.NonExhaustive)) } + .let { + RustReservedWordSymbolProvider( + it, + rustReservedWordConfig ?: RustReservedWordConfig(emptyMap(), emptyMap(), emptyMap()), + ) + } // Intentionally only visible to codegen-core since the other modules have their own contexts internal fun testCodegenContext( @@ -178,21 +185,22 @@ internal fun testCodegenContext( settings: CoreRustSettings = testRustSettings(), codegenTarget: CodegenTarget = CodegenTarget.CLIENT, nullabilityCheckMode: NullableIndex.CheckMode = NullableIndex.CheckMode.CLIENT, -): CodegenContext = object : CodegenContext( - model, - testSymbolProvider(model, nullabilityCheckMode = nullabilityCheckMode), - TestModuleDocProvider, - serviceShape - ?: model.serviceShapes.firstOrNull() - ?: ServiceShape.builder().version("test").id("test#Service").build(), - ShapeId.from("test#Protocol"), - settings, - codegenTarget, -) { - override fun builderInstantiator(): BuilderInstantiator { - return DefaultBuilderInstantiator(codegenTarget == CodegenTarget.CLIENT, symbolProvider) +): CodegenContext = + object : CodegenContext( + model, + testSymbolProvider(model, nullabilityCheckMode = nullabilityCheckMode), + TestModuleDocProvider, + serviceShape + ?: model.serviceShapes.firstOrNull() + ?: ServiceShape.builder().version("test").id("test#Service").build(), + ShapeId.from("test#Protocol"), + settings, + codegenTarget, + ) { + override fun builderInstantiator(): BuilderInstantiator { + return DefaultBuilderInstantiator(codegenTarget == CodegenTarget.CLIENT, symbolProvider) + } } -} /** * In tests, we frequently need to generate a struct, a builder, and an impl block to access said builder. @@ -214,7 +222,10 @@ fun StructureShape.renderWithModelBuilder( } } -fun RustCrate.unitTest(name: String? = null, test: Writable) { +fun RustCrate.unitTest( + name: String? = null, + test: Writable, +) { lib { val testName = name ?: safeName("test") unitTest(testName, block = test) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Exec.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Exec.kt index 64223fe5500..296d7bc39f2 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Exec.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Exec.kt @@ -12,17 +12,22 @@ import java.util.logging.Logger data class CommandError(val output: String) : Exception("Command Error\n$output") -fun String.runCommand(workdir: Path? = null, environment: Map = mapOf(), timeout: Long = 3600): String { +fun String.runCommand( + workdir: Path? = null, + environment: Map = mapOf(), + timeout: Long = 3600, +): String { val logger = Logger.getLogger("RunCommand") logger.fine("Invoking comment $this in `$workdir` with env $environment") val start = System.currentTimeMillis() val parts = this.split("\\s".toRegex()) - val builder = ProcessBuilder(*parts.toTypedArray()) - .redirectOutput(ProcessBuilder.Redirect.PIPE) - .redirectError(ProcessBuilder.Redirect.PIPE) - .letIf(workdir != null) { - it.directory(workdir?.toFile()) - } + val builder = + ProcessBuilder(*parts.toTypedArray()) + .redirectOutput(ProcessBuilder.Redirect.PIPE) + .redirectError(ProcessBuilder.Redirect.PIPE) + .letIf(workdir != null) { + it.directory(workdir?.toFile()) + } val env = builder.environment() environment.forEach { (k, v) -> env[k] = v } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/LetIf.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/LetIf.kt index 89868f7a005..4c6a70535a2 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/LetIf.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/LetIf.kt @@ -7,7 +7,10 @@ package software.amazon.smithy.rust.codegen.core.util /** * Utility function similar to `let` that conditionally applies [f] only if [cond] is true. */ -fun T.letIf(cond: Boolean, f: (T) -> T): T { +fun T.letIf( + cond: Boolean, + f: (T) -> T, +): T { return if (cond) { f(this) } else { @@ -15,19 +18,23 @@ fun T.letIf(cond: Boolean, f: (T) -> T): T { } } -fun List.extendIf(condition: Boolean, f: () -> T) = if (condition) { +fun List.extendIf( + condition: Boolean, + f: () -> T, +) = if (condition) { this + listOf(f()) } else { this } -fun Boolean.thenSingletonListOf(f: () -> T): List = if (this) { - listOf(f()) -} else { - listOf() -} +fun Boolean.thenSingletonListOf(f: () -> T): List = + if (this) { + listOf(f()) + } else { + listOf() + } /** * Returns this list if it is non-empty otherwise, it returns null */ -fun List.orNullIfEmpty(): List? = this.ifEmpty { null } +fun List.orNullIfEmpty(): List? = this.ifEmpty { null } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Map.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Map.kt index 00c4450a55a..0a214c6f825 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Map.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Map.kt @@ -8,11 +8,13 @@ package software.amazon.smithy.rust.codegen.core.util /** * Deep merges two maps, with the properties of `other` taking priority over the properties of `this`. */ -fun Map.deepMergeWith(other: Map): Map = - deepMergeMaps(this, other) +fun Map.deepMergeWith(other: Map): Map = deepMergeMaps(this, other) @Suppress("UNCHECKED_CAST") -private fun deepMergeMaps(left: Map, right: Map): Map { +private fun deepMergeMaps( + left: Map, + right: Map, +): Map { val result = mutableMapOf() for (leftEntry in left.entries) { val rightValue = right[leftEntry.key] diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Panic.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Panic.kt index 7d24a179883..ea7fe0724a2 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Panic.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Panic.kt @@ -6,7 +6,9 @@ package software.amazon.smithy.rust.codegen.core.util /** Something has gone horribly wrong due to a coding error */ +@Suppress("ktlint:standard:function-naming") fun PANIC(reason: String = ""): Nothing = throw RuntimeException(reason) /** This code should never be executed (but Kotlin cannot prove that) */ +@Suppress("ktlint:standard:function-naming") fun UNREACHABLE(reason: String): Nothing = throw IllegalStateException("This should be unreachable: $reason") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt index bde5c4b3389..b167f05d2d3 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt @@ -44,12 +44,15 @@ fun StructureShape.expectMember(member: String): MemberShape = fun UnionShape.expectMember(member: String): MemberShape = this.getMember(member).orElseThrow { CodegenException("$member did not exist on $this") } -fun StructureShape.errorMessageMember(): MemberShape? = this.getMember("message").or { - this.getMember("Message") -}.orNull() +fun StructureShape.errorMessageMember(): MemberShape? = + this.getMember("message").or { + this.getMember("Message") + }.orNull() fun StructureShape.hasStreamingMember(model: Model) = this.findStreamingMember(model) != null + fun UnionShape.hasStreamingMember(model: Model) = this.findMemberWithTrait(model) != null + fun MemberShape.isStreaming(model: Model) = this.getMemberTrait(model, StreamingTrait::class.java).isPresent fun UnionShape.isEventStream(): Boolean { @@ -90,9 +93,10 @@ fun OperationShape.isEventStream(model: Model): Boolean { return isInputEventStream(model) || isOutputEventStream(model) } -fun ServiceShape.hasEventStreamOperations(model: Model): Boolean = operations.any { id -> - model.expectShape(id, OperationShape::class.java).isEventStream(model) -} +fun ServiceShape.hasEventStreamOperations(model: Model): Boolean = + operations.any { id -> + model.expectShape(id, OperationShape::class.java).isEventStream(model) + } fun Shape.shouldRedact(model: Model): Boolean = when (this) { @@ -102,7 +106,10 @@ fun Shape.shouldRedact(model: Model): Boolean = const val REDACTION = "\"*** Sensitive Data Redacted ***\"" -fun Shape.redactIfNecessary(model: Model, safeToPrint: String): String = +fun Shape.redactIfNecessary( + model: Model, + safeToPrint: String, +): String = if (this.shouldRedact(model)) { REDACTION } else { @@ -149,5 +156,4 @@ fun String.shapeId() = ShapeId.from(this) fun ServiceShape.serviceNameOrDefault(default: String) = getTrait()?.value ?: default /** Returns the SDK ID of the given service shape */ -fun ServiceShape.sdkId(): String = - getTrait()?.sdkId?.lowercase()?.replace(" ", "") ?: id.getName(this) +fun ServiceShape.sdkId(): String = getTrait()?.sdkId?.lowercase()?.replace(" ", "") ?: id.getName(this) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Strings.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Strings.kt index d6668500280..a6e15baccb2 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Strings.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Strings.kt @@ -32,21 +32,25 @@ private fun String.splitOnWordBoundaries(): List { if (currentWord.isNotEmpty()) { out += currentWord.lowercase() } - currentWord = if (next.isLetterOrDigit()) { - next.toString() - } else { - "" - } + currentWord = + if (next.isLetterOrDigit()) { + next.toString() + } else { + "" + } } val allLowerCase = this.lowercase() == this this.forEachIndexed { index, nextCharacter -> val computeWordInProgress = { - val result = completeWordInProgress && currentWord.isNotEmpty() && completeWords.any { - it.startsWith(currentWord, ignoreCase = true) && (currentWord + this.substring(index)).startsWith( - it, - ignoreCase = true, - ) && !it.equals(currentWord, ignoreCase = true) - } + val result = + completeWordInProgress && currentWord.isNotEmpty() && + completeWords.any { + it.startsWith(currentWord, ignoreCase = true) && + (currentWord + this.substring(index)).startsWith( + it, + ignoreCase = true, + ) && !it.equals(currentWord, ignoreCase = true) + } completeWordInProgress = result result @@ -63,9 +67,10 @@ private fun String.splitOnWordBoundaries(): List { !computeWordInProgress() && loweredFollowedByUpper(currentWord, nextCharacter) -> emit(nextCharacter) // s3[k]ey - !computeWordInProgress() && allLowerCase && digitFollowedByLower(currentWord, nextCharacter) -> emit( - nextCharacter, - ) + !computeWordInProgress() && allLowerCase && digitFollowedByLower(currentWord, nextCharacter) -> + emit( + nextCharacter, + ) // DB[P]roxy, or `IAM[U]ser` but not AC[L]s endOfAcronym(currentWord, nextCharacter, this.getOrNull(index + 1), this.getOrNull(index + 2)) -> emit(nextCharacter) @@ -83,7 +88,12 @@ private fun String.splitOnWordBoundaries(): List { /** * Handle cases like `DB[P]roxy`, `ARN[S]upport`, `AC[L]s` */ -private fun endOfAcronym(current: String, nextChar: Char, peek: Char?, doublePeek: Char?): Boolean { +private fun endOfAcronym( + current: String, + nextChar: Char, + peek: Char?, + doublePeek: Char?, +): Boolean { if (!current.last().isUpperCase()) { // Not an acronym in progress return false @@ -109,14 +119,20 @@ private fun endOfAcronym(current: String, nextChar: Char, peek: Char?, doublePee return true } -private fun loweredFollowedByUpper(current: String, nextChar: Char): Boolean { +private fun loweredFollowedByUpper( + current: String, + nextChar: Char, +): Boolean { if (!nextChar.isUpperCase()) { return false } return current.last().isLowerCase() || current.last().isDigit() } -private fun digitFollowedByLower(current: String, nextChar: Char): Boolean { +private fun digitFollowedByLower( + current: String, + nextChar: Char, +): Boolean { return (current.last().isDigit() && nextChar.isLowerCase()) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Synthetics.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Synthetics.kt index f0746701e3f..4ecc62361d4 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Synthetics.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Synthetics.kt @@ -20,18 +20,20 @@ fun Model.Builder.cloneOperation( idTransform: (ShapeId) -> ShapeId, ): Model.Builder { val operationShape = model.expectShape(oldOperation.toShapeId(), OperationShape::class.java) - val inputShape = model.expectShape( - checkNotNull(operationShape.input.orNull()) { - "cloneOperation expects OperationNormalizer to be run first to add input shapes to all operations" - }, - StructureShape::class.java, - ) - val outputShape = model.expectShape( - checkNotNull(operationShape.output.orNull()) { - "cloneOperation expects OperationNormalizer to be run first to add output shapes to all operations" - }, - StructureShape::class.java, - ) + val inputShape = + model.expectShape( + checkNotNull(operationShape.input.orNull()) { + "cloneOperation expects OperationNormalizer to be run first to add input shapes to all operations" + }, + StructureShape::class.java, + ) + val outputShape = + model.expectShape( + checkNotNull(operationShape.output.orNull()) { + "cloneOperation expects OperationNormalizer to be run first to add output shapes to all operations" + }, + StructureShape::class.java, + ) val inputId = idTransform(inputShape.id) addShape(inputShape.toBuilder().rename(inputId).build()) @@ -54,8 +56,9 @@ fun Model.Builder.cloneOperation( * Renames a StructureShape builder and automatically fixes all the members. */ fun StructureShape.Builder.rename(newId: ShapeId): StructureShape.Builder { - val renamedMembers = this.build().members().map { - it.toBuilder().id(newId.withMember(it.memberName)).build() - } + val renamedMembers = + this.build().members().map { + it.toBuilder().id(newId.withMember(it.memberName)).build() + } return this.id(newId).members(renamedMembers) } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/VersionTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/VersionTest.kt index 2147dc857ce..31e91398bc2 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/VersionTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/VersionTest.kt @@ -26,26 +26,25 @@ class VersionTest { @ParameterizedTest() @MethodSource("invalidVersionProvider") - fun `fails to parse version`( - content: String, - ) { + fun `fails to parse version`(content: String) { shouldThrowAny { Version.parse(content) } } companion object { @JvmStatic - fun versionProvider() = listOf( - Arguments.of( - """{ "stableVersion": "1.0.1", "unstableVersion": "0.60.1","githash": "0198d26096eb1af510ce24766c921ffc5e4c191e", "runtimeCrates": {} }""", - "1.0.1-0198d26096eb1af510ce24766c921ffc5e4c191e", - "1.0.1", - ), - Arguments.of( - """{ "unstableVersion": "0.60.1", "stableVersion": "release-2022-08-04", "githash": "db48039065bec890ef387385773b37154b555b14", "runtimeCrates": {} }""", - "release-2022-08-04-db48039065bec890ef387385773b37154b555b14", - "release-2022-08-04", - ), - ) + fun versionProvider() = + listOf( + Arguments.of( + """{ "stableVersion": "1.0.1", "unstableVersion": "0.60.1","githash": "0198d26096eb1af510ce24766c921ffc5e4c191e", "runtimeCrates": {} }""", + "1.0.1-0198d26096eb1af510ce24766c921ffc5e4c191e", + "1.0.1", + ), + Arguments.of( + """{ "unstableVersion": "0.60.1", "stableVersion": "release-2022-08-04", "githash": "db48039065bec890ef387385773b37154b555b14", "runtimeCrates": {} }""", + "release-2022-08-04-db48039065bec890ef387385773b37154b555b14", + "release-2022-08-04", + ), + ) @JvmStatic fun invalidVersionProvider() = listOf("0.0.0", "") diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/InlineDependencyTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/InlineDependencyTest.kt index 4e341c769f8..af54e494ee8 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/InlineDependencyTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/InlineDependencyTest.kt @@ -17,9 +17,10 @@ import software.amazon.smithy.rust.codegen.core.testutil.unitTest import kotlin.io.path.pathString internal class InlineDependencyTest { - private fun makeDep(name: String) = InlineDependency(name, RustModule.private("module")) { - rustBlock("fn foo()") {} - } + private fun makeDep(name: String) = + InlineDependency(name, RustModule.private("module")) { + rustBlock("fn foo()") {} + } @Test fun `equal dependencies should be equal`() { @@ -60,13 +61,14 @@ internal class InlineDependencyTest { val a = RustModule.public("a") val b = RustModule.public("b", parent = a) val c = RustModule.public("c", parent = b) - val type = RuntimeType.forInlineFun("forty2", c) { - rust( - """ - pub fn forty2() -> usize { 42 } - """, - ) - } + val type = + RuntimeType.forInlineFun("forty2", c) { + rust( + """ + pub fn forty2() -> usize { 42 } + """, + ) + } val crate = TestWorkspace.testProject() crate.lib { unitTest("use_nested_module") { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustGenericsTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustGenericsTest.kt index a8eb41353f5..6d0e70d1268 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustGenericsTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustGenericsTest.kt @@ -57,44 +57,49 @@ class RustGenericsTest { @Test fun `bounds is correct for several args`() { - val gg = RustGenerics( - GenericTypeArg("A", testRT("Apple")), - GenericTypeArg("PL", testRT("Plum")), - GenericTypeArg("PE", testRT("Pear")), - ) + val gg = + RustGenerics( + GenericTypeArg("A", testRT("Apple")), + GenericTypeArg("PL", testRT("Plum")), + GenericTypeArg("PE", testRT("Pear")), + ) val writer = RustWriter.forModule("model") writer.rustTemplate("#{bounds:W}", "bounds" to gg.bounds()) - writer.toString() shouldContain """ + writer.toString() shouldContain + """ A: test::Apple, PL: test::Plum, PE: test::Pear, - """.trimIndent() + """.trimIndent() } @Test fun `bounds skips arg with no bounds`() { - val gg = RustGenerics( - GenericTypeArg("A", testRT("Apple")), - GenericTypeArg("PL"), - GenericTypeArg("PE", testRT("Pear")), - ) + val gg = + RustGenerics( + GenericTypeArg("A", testRT("Apple")), + GenericTypeArg("PL"), + GenericTypeArg("PE", testRT("Pear")), + ) val writer = RustWriter.forModule("model") writer.rustTemplate("#{bounds:W}", "bounds" to gg.bounds()) - writer.toString() shouldContain """ + writer.toString() shouldContain + """ A: test::Apple, PE: test::Pear, - """.trimIndent() + """.trimIndent() } @Test fun `bounds generates nothing if all args are skipped`() { - val gg = RustGenerics( - GenericTypeArg("A"), - GenericTypeArg("PL"), - GenericTypeArg("PE"), - ) + val gg = + RustGenerics( + GenericTypeArg("A"), + GenericTypeArg("PL"), + GenericTypeArg("PE"), + ) val writer = RustWriter.forModule("model") writer.rustTemplate("A#{bounds:W}B", "bounds" to gg.bounds()) @@ -103,19 +108,22 @@ class RustGenericsTest { @Test fun `Adding GenericGenerators works`() { - val ggA = RustGenerics( - GenericTypeArg("A", testRT("Apple")), - ) - val ggB = RustGenerics( - GenericTypeArg("B", testRT("Banana")), - ) + val ggA = + RustGenerics( + GenericTypeArg("A", testRT("Apple")), + ) + val ggB = + RustGenerics( + GenericTypeArg("B", testRT("Banana")), + ) RustWriter.forModule("model").let { it.rustTemplate("#{bounds:W}", "bounds" to (ggA + ggB).bounds()) - it.toString() shouldContain """ + it.toString() shouldContain + """ A: test::Apple, B: test::Banana, - """.trimIndent() + """.trimIndent() } RustWriter.forModule("model").let { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWordsTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWordsTest.kt index ec97799af42..99b0217eb9a 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWordsTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWordsTest.kt @@ -23,20 +23,28 @@ import software.amazon.smithy.rust.codegen.core.util.lookup internal class RustReservedWordSymbolProviderTest { private class TestSymbolProvider(model: Model, nullabilityCheckMode: NullableIndex.CheckMode) : WrappingSymbolProvider(SymbolVisitor(testRustSettings(), model, null, testRustSymbolProviderConfig(nullabilityCheckMode))) + private val emptyConfig = RustReservedWordConfig(emptyMap(), emptyMap(), emptyMap()) @Test fun `structs are escaped`() { - val model = """ + val model = + """ namespace test structure Self {} - """.asSmithyModel() - val provider = RustReservedWordSymbolProvider(TestSymbolProvider(model, NullableIndex.CheckMode.CLIENT), emptyConfig) + """.asSmithyModel() + val provider = + RustReservedWordSymbolProvider(TestSymbolProvider(model, NullableIndex.CheckMode.CLIENT), emptyConfig) val symbol = provider.toSymbol(model.lookup("test#Self")) symbol.name shouldBe "SelfValue" } - private fun mappingTest(config: RustReservedWordConfig, model: Model, id: String, test: (String) -> Unit) { + private fun mappingTest( + config: RustReservedWordConfig, + model: Model, + id: String, + test: (String) -> Unit, + ) { val provider = RustReservedWordSymbolProvider(TestSymbolProvider(model, NullableIndex.CheckMode.CLIENT), config) val symbol = provider.toMemberName(model.lookup("test#Container\$$id")) test(symbol) @@ -44,39 +52,44 @@ internal class RustReservedWordSymbolProviderTest { @Test fun `structs member names are mapped via config`() { - val config = emptyConfig.copy( - structureMemberMap = mapOf( - "name_to_map" to "mapped_name", - "NameToMap" to "MappedName", - ), - ) - var model = """ + val config = + emptyConfig.copy( + structureMemberMap = + mapOf( + "name_to_map" to "mapped_name", + "NameToMap" to "MappedName", + ), + ) + var model = + """ namespace test structure Container { name_to_map: String } - """.asSmithyModel() + """.asSmithyModel() mappingTest(config, model, "name_to_map") { memberName -> memberName shouldBe "mapped_name" } - model = """ + model = + """ namespace test enum Container { NameToMap = "NameToMap" } - """.asSmithyModel(smithyVersion = "2.0") + """.asSmithyModel(smithyVersion = "2.0") mappingTest(config, model, "NameToMap") { memberName -> // Container was not a struct, so the field keeps its old name memberName shouldBe "NameToMap" } - model = """ + model = + """ namespace test union Container { NameToMap: String } - """.asSmithyModel() + """.asSmithyModel() mappingTest(config, model, "NameToMap") { memberName -> // Container was not a struct, so the field keeps its old name memberName shouldBe "NameToMap" @@ -85,40 +98,45 @@ internal class RustReservedWordSymbolProviderTest { @Test fun `union member names are mapped via config`() { - val config = emptyConfig.copy( - unionMemberMap = mapOf( - "name_to_map" to "mapped_name", - "NameToMap" to "MappedName", - ), - ) - - var model = """ + val config = + emptyConfig.copy( + unionMemberMap = + mapOf( + "name_to_map" to "mapped_name", + "NameToMap" to "MappedName", + ), + ) + + var model = + """ namespace test union Container { NameToMap: String } - """.asSmithyModel() + """.asSmithyModel() mappingTest(config, model, "NameToMap") { memberName -> memberName shouldBe "MappedName" } - model = """ + model = + """ namespace test structure Container { name_to_map: String } - """.asSmithyModel() + """.asSmithyModel() mappingTest(config, model, "name_to_map") { memberName -> // Container was not a union, so the field keeps its old name memberName shouldBe "name_to_map" } - model = """ + model = + """ namespace test enum Container { NameToMap = "NameToMap" } - """.asSmithyModel(smithyVersion = "2.0") + """.asSmithyModel(smithyVersion = "2.0") mappingTest(config, model, "NameToMap") { memberName -> // Container was not a union, so the field keeps its old name memberName shouldBe "NameToMap" @@ -127,13 +145,15 @@ internal class RustReservedWordSymbolProviderTest { @Test fun `member names are escaped`() { - val model = """ + val model = + """ namespace namespace structure container { async: String } - """.asSmithyModel() - val provider = RustReservedWordSymbolProvider(TestSymbolProvider(model, NullableIndex.CheckMode.CLIENT), emptyConfig) + """.asSmithyModel() + val provider = + RustReservedWordSymbolProvider(TestSymbolProvider(model, NullableIndex.CheckMode.CLIENT), emptyConfig) provider.toMemberName( MemberShape.builder().id("namespace#container\$async").target("namespace#Integer").build(), ) shouldBe "r##async" @@ -145,27 +165,35 @@ internal class RustReservedWordSymbolProviderTest { @Test fun `enum variant names are updated to avoid conflicts`() { - val model = """ + val model = + """ namespace foo @enum([{ name: "dontcare", value: "dontcare" }]) string Container - """.asSmithyModel() - val provider = RustReservedWordSymbolProvider( - TestSymbolProvider(model, NullableIndex.CheckMode.CLIENT), - reservedWordConfig = emptyConfig.copy( - enumMemberMap = mapOf( - "Unknown" to "UnknownValue", - "UnknownValue" to "UnknownValue_", - ), - ), - ) - - fun expectEnumRename(original: String, expected: MaybeRenamed) { - val symbol = provider.toSymbol( - MemberShape.builder() - .id(ShapeId.fromParts("foo", "Container").withMember(original)) - .target("smithy.api#String") - .build(), + """.asSmithyModel() + val provider = + RustReservedWordSymbolProvider( + TestSymbolProvider(model, NullableIndex.CheckMode.CLIENT), + reservedWordConfig = + emptyConfig.copy( + enumMemberMap = + mapOf( + "Unknown" to "UnknownValue", + "UnknownValue" to "UnknownValue_", + ), + ), ) + + fun expectEnumRename( + original: String, + expected: MaybeRenamed, + ) { + val symbol = + provider.toSymbol( + MemberShape.builder() + .id(ShapeId.fromParts("foo", "Container").withMember(original)) + .target("smithy.api#String") + .build(), + ) symbol.name shouldBe expected.name symbol.renamedFrom() shouldBe expected.renamedFrom } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustTypeTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustTypeTest.kt index fcd9583d497..c0a81cd90cd 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustTypeTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustTypeTest.kt @@ -19,7 +19,10 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.util.dq internal class RustTypesTest { - private fun forInputExpectOutput(t: Writable, expectedOutput: String) { + private fun forInputExpectOutput( + t: Writable, + expectedOutput: String, + ) { val writer = RustWriter.forModule("rust_types") writer.rustInlineTemplate("'") t.invoke(writer) @@ -152,17 +155,18 @@ internal class RustTypesTest { @Test fun `attribute macros from strings render properly`() { - val attributeMacro = Attribute( - Attribute.cfg( - Attribute.all( - Attribute.pair("feature" to "unstable".dq()), - Attribute.any( - Attribute.pair("feature" to "serialize".dq()), - Attribute.pair("feature" to "deserialize".dq()), + val attributeMacro = + Attribute( + Attribute.cfg( + Attribute.all( + Attribute.pair("feature" to "unstable".dq()), + Attribute.any( + Attribute.pair("feature" to "serialize".dq()), + Attribute.pair("feature" to "deserialize".dq()), + ), ), ), - ), - ) + ) forInputExpectOutput( writable { attributeMacro.render(this) @@ -173,16 +177,17 @@ internal class RustTypesTest { @Test fun `attribute macros render writers properly`() { - val attributeMacro = Attribute( - cfg( - all( - // Normally we'd use the `pair` fn to define these but this is a test - writable { rustInline("""feature = "unstable"""") }, - writable { rustInline("""feature = "serialize"""") }, - writable { rustInline("""feature = "deserialize"""") }, + val attributeMacro = + Attribute( + cfg( + all( + // Normally we'd use the `pair` fn to define these but this is a test + writable { rustInline("""feature = "unstable"""") }, + writable { rustInline("""feature = "serialize"""") }, + writable { rustInline("""feature = "deserialize"""") }, + ), ), - ), - ) + ) forInputExpectOutput( writable { attributeMacro.render(this) @@ -200,13 +205,14 @@ internal class RustTypesTest { @Test fun `derive attribute macros render properly`() { - val attributeMacro = Attribute( - derive( - RuntimeType.Clone, - RuntimeType.Debug, - RuntimeType.StdError, - ), - ) + val attributeMacro = + Attribute( + derive( + RuntimeType.Clone, + RuntimeType.Debug, + RuntimeType.StdError, + ), + ) forInputExpectOutput( writable { attributeMacro.render(this) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriterTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriterTest.kt index 2bd5269cc25..18f8cbfadc3 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriterTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriterTest.kt @@ -46,14 +46,16 @@ class RustWriterTest { @Test fun `manually created struct`() { val stringShape = StringShape.builder().id("test#Hello").build() - val set = SetShape.builder() - .id("foo.bar#Records") - .member(stringShape.id) - .build() - val model = Model.assembler() - .addShapes(set, stringShape) - .assemble() - .unwrap() + val set = + SetShape.builder() + .id("foo.bar#Records") + .member(stringShape.id) + .build() + val model = + Model.assembler() + .addShapes(set, stringShape) + .assemble() + .unwrap() val provider = testSymbolProvider(model) val setSymbol = provider.toSymbol(set) @@ -99,10 +101,11 @@ class RustWriterTest { @Test fun `generate doc links`() { - val model = """ + val model = + """ namespace test structure Foo {} - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#Foo") val symbol = testSymbolProvider(model).toSymbol(shape) val writer = RustWriter.root() @@ -136,11 +139,12 @@ class RustWriterTest { @Test fun `attributes with derive helpers must come after derives`() { val attr = Attribute("foo", isDeriveHelper = true) - val metadata = RustMetadata( - derives = setOf(RuntimeType.Debug), - additionalAttributes = listOf(Attribute.AllowDeprecated, attr), - visibility = Visibility.PUBLIC, - ) + val metadata = + RustMetadata( + derives = setOf(RuntimeType.Debug), + additionalAttributes = listOf(Attribute.AllowDeprecated, attr), + visibility = Visibility.PUBLIC, + ) val sut = RustWriter.root() metadata.render(sut) sut.toString().shouldContain("#[allow(deprecated)]\n#[derive(::std::fmt::Debug)]\n#[foo]") @@ -189,13 +193,14 @@ class RustWriterTest { @Test fun `missing template parameters are enclosed in backticks in the exception message`() { val sut = RustWriter.root() - val exception = assertThrows { - sut.rustTemplate( - "#{Foo} #{Bar}", - "Foo Bar" to CargoDependency.Http.toType().resolve("foo"), - "Baz" to CargoDependency.Http.toType().resolve("foo"), - ) - } + val exception = + assertThrows { + sut.rustTemplate( + "#{Foo} #{Bar}", + "Foo Bar" to CargoDependency.Http.toType().resolve("foo"), + "Baz" to CargoDependency.Http.toType().resolve("foo"), + ) + } exception.message shouldBe """ Rust block template expected `Foo` but was not present in template. diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/WritableTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/WritableTest.kt index a9b45582ef5..b847768c55a 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/WritableTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/WritableTest.kt @@ -11,7 +11,10 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType internal class RustTypeParametersTest { - private fun forInputExpectOutput(input: Any, expectedOutput: String) { + private fun forInputExpectOutput( + input: Any, + expectedOutput: String, + ) { val writer = RustWriter.forModule("model") writer.rustInlineTemplate("'") writer.rustInlineTemplate("#{typeParameters:W}", "typeParameters" to rustTypeParameters(input)) @@ -50,13 +53,14 @@ internal class RustTypeParametersTest { @Test fun `rustTypeParameters accepts heterogeneous inputs`() { val writer = RustWriter.forModule("model") - val tps = rustTypeParameters( - RuntimeType("crate::operation::Operation").toSymbol(), - RustType.Unit, - RuntimeType.String, - "T", - RustGenerics(GenericTypeArg("A"), GenericTypeArg("B")), - ) + val tps = + rustTypeParameters( + RuntimeType("crate::operation::Operation").toSymbol(), + RustType.Unit, + RuntimeType.String, + "T", + RustGenerics(GenericTypeArg("A"), GenericTypeArg("B")), + ) writer.rustInlineTemplate("'") writer.rustInlineTemplate("#{tps:W}", "tps" to tps) writer.rustInlineTemplate("'") diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegatorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegatorTest.kt index 6ed2fc5ed32..e1d6074bdcb 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegatorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegatorTest.kt @@ -22,53 +22,56 @@ class CodegenDelegatorTest { CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f1")), CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f2")), CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f1", "f2")), - CargoDependency("B", CratesIo("2"), Compile, optional = false, features = setOf()), CargoDependency("B", CratesIo("2"), Compile, optional = true, features = setOf()), - CargoDependency("C", CratesIo("3"), Compile, optional = true, features = setOf()), CargoDependency("C", CratesIo("3"), Compile, optional = true, features = setOf()), ).shuffled().mergeDependencyFeatures() - merged shouldBe setOf( - CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f1", "f2")), - CargoDependency("B", CratesIo("2"), Compile, optional = false, features = setOf()), - CargoDependency("C", CratesIo("3"), Compile, optional = true, features = setOf()), - ) + merged shouldBe + setOf( + CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f1", "f2")), + CargoDependency("B", CratesIo("2"), Compile, optional = false, features = setOf()), + CargoDependency("C", CratesIo("3"), Compile, optional = true, features = setOf()), + ) } @RepeatedTest(10) // Test it several times since the shuffle adds in some randomness fun testMergeDependencyFeaturesDontMergeDevOnlyFeatures() { - val merged = listOf( - CargoDependency("A", CratesIo("1"), Compile, features = setOf("a")), - CargoDependency("A", CratesIo("1"), Compile, features = setOf("b")), - CargoDependency("A", CratesIo("1"), Dev, features = setOf("c")), - CargoDependency("A", CratesIo("1"), Dev, features = setOf("test-util")), - ).shuffled().mergeDependencyFeatures() - .sortedBy { it.scope } + val merged = + listOf( + CargoDependency("A", CratesIo("1"), Compile, features = setOf("a")), + CargoDependency("A", CratesIo("1"), Compile, features = setOf("b")), + CargoDependency("A", CratesIo("1"), Dev, features = setOf("c")), + CargoDependency("A", CratesIo("1"), Dev, features = setOf("test-util")), + ).shuffled().mergeDependencyFeatures() + .sortedBy { it.scope } - merged shouldBe setOf( - CargoDependency("A", CratesIo("1"), Compile, features = setOf("a", "b")), - CargoDependency("A", CratesIo("1"), Dev, features = setOf("c", "test-util")), - ) + merged shouldBe + setOf( + CargoDependency("A", CratesIo("1"), Compile, features = setOf("a", "b")), + CargoDependency("A", CratesIo("1"), Dev, features = setOf("c", "test-util")), + ) } @Test fun testMergeIdenticalFeatures() { - val merged = listOf( - CargoDependency("A", CratesIo("1"), Compile), - CargoDependency("A", CratesIo("1"), Dev), - CargoDependency("B", CratesIo("1"), Compile), - CargoDependency("B", CratesIo("1"), Dev, features = setOf("a", "b")), - CargoDependency("C", CratesIo("1"), Compile), - CargoDependency("C", CratesIo("1"), Dev, features = setOf("test-util")), - ).mergeIdenticalTestDependencies() - merged shouldBe setOf( - CargoDependency("A", CratesIo("1"), Compile), - CargoDependency("B", CratesIo("1"), Compile), - CargoDependency("B", CratesIo("1"), Dev, features = setOf("a", "b")), - CargoDependency("C", CratesIo("1"), Compile), - CargoDependency("C", CratesIo("1"), Dev, features = setOf("test-util")), - ) + val merged = + listOf( + CargoDependency("A", CratesIo("1"), Compile), + CargoDependency("A", CratesIo("1"), Dev), + CargoDependency("B", CratesIo("1"), Compile), + CargoDependency("B", CratesIo("1"), Dev, features = setOf("a", "b")), + CargoDependency("C", CratesIo("1"), Compile), + CargoDependency("C", CratesIo("1"), Dev, features = setOf("test-util")), + ).mergeIdenticalTestDependencies() + merged shouldBe + setOf( + CargoDependency("A", CratesIo("1"), Compile), + CargoDependency("B", CratesIo("1"), Compile), + CargoDependency("B", CratesIo("1"), Dev, features = setOf("a", "b")), + CargoDependency("C", CratesIo("1"), Compile), + CargoDependency("C", CratesIo("1"), Dev, features = setOf("test-util")), + ) } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeTypeTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeTypeTest.kt index b3dec08b7c4..3bd22665b2a 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeTypeTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeTypeTest.kt @@ -42,51 +42,52 @@ class RuntimeTypesTest { crateLoc.crateLocation("aws-smithy-runtime-api") shouldBe Local("/foo", null) crateLoc.crateLocation("aws-smithy-http") shouldBe Local("/foo", null) - val crateLocVersioned = RuntimeCrateLocation(null, CrateVersionMap(mapOf("aws-smithy-runtime-api" to "999.999"))) + val crateLocVersioned = + RuntimeCrateLocation(null, CrateVersionMap(mapOf("aws-smithy-runtime-api" to "999.999"))) crateLocVersioned.crateLocation("aws-smithy-runtime") shouldBe CratesIo(Version.stableCrateVersion()) crateLocVersioned.crateLocation("aws-smithy-runtime-api") shouldBe CratesIo("999.999") crateLocVersioned.crateLocation("aws-smithy-http") shouldBe CratesIo(Version.unstableCrateVersion()) } companion object { - @JvmStatic - fun runtimeConfigProvider() = listOf( - Arguments.of( - "{}", - RuntimeCrateLocation(null, CrateVersionMap(mapOf())), - ), - Arguments.of( - """ - { - "relativePath": "/path" - } - """, - RuntimeCrateLocation("/path", CrateVersionMap(mapOf())), - ), - Arguments.of( - """ - { - "versions": { - "a": "1.0", - "b": "2.0" + fun runtimeConfigProvider() = + listOf( + Arguments.of( + "{}", + RuntimeCrateLocation(null, CrateVersionMap(mapOf())), + ), + Arguments.of( + """ + { + "relativePath": "/path" + } + """, + RuntimeCrateLocation("/path", CrateVersionMap(mapOf())), + ), + Arguments.of( + """ + { + "versions": { + "a": "1.0", + "b": "2.0" + } } - } - """, - RuntimeCrateLocation(null, CrateVersionMap(mapOf("a" to "1.0", "b" to "2.0"))), - ), - Arguments.of( - """ - { - "relativePath": "/path", - "versions": { - "a": "1.0", - "b": "2.0" + """, + RuntimeCrateLocation(null, CrateVersionMap(mapOf("a" to "1.0", "b" to "2.0"))), + ), + Arguments.of( + """ + { + "relativePath": "/path", + "versions": { + "a": "1.0", + "b": "2.0" + } } - } - """, - RuntimeCrateLocation("/path", CrateVersionMap(mapOf("a" to "1.0", "b" to "2.0"))), - ), - ) + """, + RuntimeCrateLocation("/path", CrateVersionMap(mapOf("a" to "1.0", "b" to "2.0"))), + ), + ) } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitorTest.kt index 06485823c69..4821ecaf0c6 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitorTest.kt @@ -40,14 +40,16 @@ class SymbolVisitorTest { fun `creates structures`() { val memberBuilder = MemberShape.builder().id("foo.bar#MyStruct\$someField").target("smithy.api#String") val member = memberBuilder.build() - val struct = StructureShape.builder() - .id("foo.bar#MyStruct") - .addMember(member) - .build() - val model = Model.assembler() - .addShapes(struct, member) - .assemble() - .unwrap() + val struct = + StructureShape.builder() + .id("foo.bar#MyStruct") + .addMember(member) + .build() + val model = + Model.assembler() + .addShapes(struct, member) + .assemble() + .unwrap() val provider: SymbolProvider = testSymbolProvider(model) val sym = provider.toSymbol(struct) sym.rustType().render(false) shouldBe "MyStruct" @@ -59,15 +61,17 @@ class SymbolVisitorTest { fun `renames errors`() { val memberBuilder = MemberShape.builder().id("foo.bar#TerribleException\$someField").target("smithy.api#String") val member = memberBuilder.build() - val struct = StructureShape.builder() - .id("foo.bar#TerribleException") - .addMember(member) - .addTrait(ErrorTrait("server")) - .build() - val model = Model.assembler() - .addShapes(struct, member) - .assemble() - .unwrap() + val struct = + StructureShape.builder() + .id("foo.bar#TerribleException") + .addMember(member) + .addTrait(ErrorTrait("server")) + .build() + val model = + Model.assembler() + .addShapes(struct, member) + .assemble() + .unwrap() val provider: SymbolProvider = testSymbolProvider(model) val sym = provider.toSymbol(struct) sym.rustType().render(false) shouldBe "TerribleError" @@ -76,7 +80,8 @@ class SymbolVisitorTest { @Test fun `creates enums`() { - val model = """ + val model = + """ namespace test @enum([ @@ -90,7 +95,7 @@ class SymbolVisitorTest { } ]) string StandardUnit - """.asSmithyModel() + """.asSmithyModel() val shape = model.expectShape(ShapeId.from("test#StandardUnit")) val provider: SymbolProvider = testSymbolProvider(model) val sym = provider.toSymbol(shape) @@ -118,13 +123,18 @@ class SymbolVisitorTest { "Boolean, true, bool", "PrimitiveBoolean, false, bool", ) - fun `creates primitives`(primitiveType: String, optional: Boolean, rustName: String) { - val model = """ + fun `creates primitives`( + primitiveType: String, + optional: Boolean, + rustName: String, + ) { + val model = + """ namespace foo.bar structure MyStruct { quux: $primitiveType } - """.asSmithyModel() + """.asSmithyModel() val member = model.expectShape(ShapeId.from("foo.bar#MyStruct\$quux")) val provider: SymbolProvider = testSymbolProvider(model) val memberSymbol = provider.toSymbol(member) @@ -142,14 +152,16 @@ class SymbolVisitorTest { @Test fun `creates sets of strings`() { val stringShape = StringShape.builder().id("test#Hello").build() - val set = SetShape.builder() - .id("foo.bar#Records") - .member(stringShape.id) - .build() - val model = Model.assembler() - .addShapes(set, stringShape) - .assemble() - .unwrap() + val set = + SetShape.builder() + .id("foo.bar#Records") + .member(stringShape.id) + .build() + val model = + Model.assembler() + .addShapes(set, stringShape) + .assemble() + .unwrap() val provider: SymbolProvider = testSymbolProvider(model) val setSymbol = provider.toSymbol(set) @@ -161,14 +173,16 @@ class SymbolVisitorTest { fun `create vec instead for non-strings`() { val struct = StructureShape.builder().id("foo.bar#Record").build() val setMember = MemberShape.builder().id("foo.bar#Records\$member").target(struct).build() - val set = SetShape.builder() - .id("foo.bar#Records") - .member(setMember) - .build() - val model = Model.assembler() - .addShapes(set, setMember, struct) - .assemble() - .unwrap() + val set = + SetShape.builder() + .id("foo.bar#Records") + .member(setMember) + .build() + val model = + Model.assembler() + .addShapes(set, setMember, struct) + .assemble() + .unwrap() val provider: SymbolProvider = testSymbolProvider(model) val setSymbol = provider.toSymbol(set) @@ -180,16 +194,18 @@ class SymbolVisitorTest { fun `create sparse collections`() { val struct = StructureShape.builder().id("foo.bar#Record").build() val setMember = MemberShape.builder().id("foo.bar#Records\$member").target(struct).build() - val set = ListShape.builder() - .id("foo.bar#Records") - .member(setMember) - .addTrait(SparseTrait()) - .build() - val model = Model.assembler() - .putProperty(ModelAssembler.ALLOW_UNKNOWN_TRAITS, true) - .addShapes(set, setMember, struct) - .assemble() - .unwrap() + val set = + ListShape.builder() + .id("foo.bar#Records") + .member(setMember) + .addTrait(SparseTrait()) + .build() + val model = + Model.assembler() + .putProperty(ModelAssembler.ALLOW_UNKNOWN_TRAITS, true) + .addShapes(set, setMember, struct) + .assemble() + .unwrap() val provider: SymbolProvider = testSymbolProvider(model) val setSymbol = provider.toSymbol(set) @@ -201,14 +217,16 @@ class SymbolVisitorTest { fun `create timestamps`() { val memberBuilder = MemberShape.builder().id("foo.bar#MyStruct\$someField").target("smithy.api#Timestamp") val member = memberBuilder.build() - val struct = StructureShape.builder() - .id("foo.bar#MyStruct") - .addMember(member) - .build() - val model = Model.assembler() - .addShapes(struct, member) - .assemble() - .unwrap() + val struct = + StructureShape.builder() + .id("foo.bar#MyStruct") + .addMember(member) + .build() + val model = + Model.assembler() + .addShapes(struct, member) + .assemble() + .unwrap() val provider: SymbolProvider = testSymbolProvider(model) val sym = provider.toSymbol(member) sym.rustType().render(false) shouldBe "Option" @@ -218,7 +236,8 @@ class SymbolVisitorTest { @Test fun `creates operations`() { - val model = """ + val model = + """ namespace smithy.example @idempotent @@ -252,7 +271,7 @@ class SymbolVisitorTest { // Sent in the body additional: String, } - """.asSmithyModel() + """.asSmithyModel() val symbol = testSymbolProvider(model).toSymbol(model.expectShape(ShapeId.from("smithy.example#PutObject"))) symbol.definitionFile shouldBe "src/test_operation.rs" symbol.name shouldBe "PutObject" diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtraTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtraTest.kt index f9eda03a43f..71bd51cf21b 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtraTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtraTest.kt @@ -50,19 +50,22 @@ class SmithyTypesPubUseExtraTest { private val codegenContext: CodegenContext = testCodegenContext(model) init { - val (context, _) = generatePluginContext( - model, - runtimeConfig = codegenContext.runtimeConfig, - ) - rustCrate = RustCrate( - context.fileManifest, - codegenContext.symbolProvider, - codegenContext.settings.codegenConfig, - codegenContext.expectModuleDocProvider(), - ) + val (context, _) = + generatePluginContext( + model, + runtimeConfig = codegenContext.runtimeConfig, + ) + rustCrate = + RustCrate( + context.fileManifest, + codegenContext.symbolProvider, + codegenContext.settings.codegenConfig, + codegenContext.expectModuleDocProvider(), + ) } private fun reexportsWithEmptyModel() = reexportsWithMember() + private fun reexportsWithMember( inputMember: String = "", outputMember: String = "", @@ -77,19 +80,29 @@ class SmithyTypesPubUseExtraTest { writer.toString() } - private fun assertDoesntHaveReexports(reexports: String, expectedTypes: List) = - expectedTypes.forEach { assertDoesntHaveReexports(reexports, it) } + private fun assertDoesntHaveReexports( + reexports: String, + expectedTypes: List, + ) = expectedTypes.forEach { assertDoesntHaveReexports(reexports, it) } - private fun assertDoesntHaveReexports(reexports: String, type: String) { + private fun assertDoesntHaveReexports( + reexports: String, + type: String, + ) { if (reexports.contains(type)) { throw AssertionError("Expected $type to NOT be re-exported, but it was.") } } - private fun assertHasReexports(reexports: String, expectedTypes: List) = - expectedTypes.forEach { assertHasReexport(reexports, it) } + private fun assertHasReexports( + reexports: String, + expectedTypes: List, + ) = expectedTypes.forEach { assertHasReexport(reexports, it) } - private fun assertHasReexport(reexports: String, type: String) { + private fun assertHasReexport( + reexports: String, + type: String, + ) { if (!reexports.contains(type)) { throw AssertionError("Expected $type to be re-exported. Re-exported types:\n$reexports") } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt index b38aa586e78..2ea5c56f59e 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt @@ -67,11 +67,12 @@ internal class BuilderGeneratorTest { @Test fun `generate fallible builders`() { val baseProvider = testSymbolProvider(StructureGeneratorTest.model) - val provider = object : WrappingSymbolProvider(baseProvider) { - override fun toSymbol(shape: Shape): Symbol { - return baseProvider.toSymbol(shape).toBuilder().setDefault(Default.NoDefault).build() + val provider = + object : WrappingSymbolProvider(baseProvider) { + override fun toSymbol(shape: Shape): Symbol { + return baseProvider.toSymbol(shape).toBuilder().setDefault(Default.NoDefault).build() + } } - } val project = TestWorkspace.testProject(provider) project.moduleFor(StructureGeneratorTest.struct) { @@ -97,7 +98,12 @@ internal class BuilderGeneratorTest { project.compileAndTest() } - private fun generator(model: Model, provider: RustSymbolProvider, writer: RustWriter, shape: StructureShape) = StructureGenerator(model, provider, writer, shape, emptyList(), StructSettings(flattenVecAccessors = true)) + private fun generator( + model: Model, + provider: RustSymbolProvider, + writer: RustWriter, + shape: StructureShape, + ) = StructureGenerator(model, provider, writer, shape, emptyList(), StructSettings(flattenVecAccessors = true)) @Test fun `builder for a struct with sensitive fields should implement the debug trait as such`() { @@ -153,7 +159,8 @@ internal class BuilderGeneratorTest { @Test fun `it supports nonzero defaults`() { - val model = """ + val model = + """ namespace com.test structure MyStruct { @default(0) @@ -197,13 +204,14 @@ internal class BuilderGeneratorTest { @default(1) integer OneDefault - """.asSmithyModel(smithyVersion = "2.0") + """.asSmithyModel(smithyVersion = "2.0") - val provider = testSymbolProvider( - model, - rustReservedWordConfig = StructureGeneratorTest.rustReservedWordConfig, - nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_CAREFUL, - ) + val provider = + testSymbolProvider( + model, + rustReservedWordConfig = StructureGeneratorTest.rustReservedWordConfig, + nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_CAREFUL, + ) val project = TestWorkspace.testProject(provider) val shape: StructureShape = model.lookup("com.test#MyStruct") project.useShapeWriter(shape) { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt index b233a763bd7..b0fa42ef0fc 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt @@ -31,15 +31,17 @@ import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.core.util.orNull class EnumGeneratorTest { - private val rustReservedWordConfig = RustReservedWordConfig( - enumMemberMap = mapOf("Unknown" to "UnknownValue"), - structureMemberMap = emptyMap(), - unionMemberMap = emptyMap(), - ) + private val rustReservedWordConfig = + RustReservedWordConfig( + enumMemberMap = mapOf("Unknown" to "UnknownValue"), + structureMemberMap = emptyMap(), + unionMemberMap = emptyMap(), + ) @Nested inner class EnumMemberModelTests { - private val testModel = """ + private val testModel = + """ namespace test @enum([ { value: "some-value-1", @@ -53,16 +55,17 @@ class EnumGeneratorTest { documentation: "It has some docs that #need to be escaped" } ]) string EnumWithUnknown - """.asSmithyModel() + """.asSmithyModel() private val symbolProvider = testSymbolProvider(testModel, rustReservedWordConfig = rustReservedWordConfig) private val enumTrait = testModel.lookup("test#EnumWithUnknown").expectTrait() - private fun model(name: String): EnumMemberModel = EnumMemberModel( - testModel.lookup("test#EnumWithUnknown"), - enumTrait.values.first { it.name.orNull() == name }, - symbolProvider, - ) + private fun model(name: String): EnumMemberModel = + EnumMemberModel( + testModel.lookup("test#EnumWithUnknown"), + enumTrait.values.first { it.name.orNull() == name }, + symbolProvider, + ) @Test fun `it converts enum names to PascalCase and renames any named Unknown to UnknownValue`() { @@ -114,7 +117,8 @@ class EnumGeneratorTest { @Test fun `it generates named enums`() { - val model = """ + val model = + """ namespace test @enum([ { @@ -133,7 +137,7 @@ class EnumGeneratorTest { ]) @deprecated(since: "1.2.3") string InstanceType - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#InstanceType") val provider = testSymbolProvider(model) @@ -162,7 +166,8 @@ class EnumGeneratorTest { @Test fun `named enums implement eq and hash`() { - val model = """ + val model = + """ namespace test @enum([ { @@ -174,7 +179,7 @@ class EnumGeneratorTest { name: "Bar" }]) string FooEnum - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#FooEnum") val provider = testSymbolProvider(model) @@ -196,7 +201,8 @@ class EnumGeneratorTest { @Test fun `unnamed enums implement eq and hash`() { - val model = """ + val model = + """ namespace test @enum([ { @@ -207,7 +213,7 @@ class EnumGeneratorTest { }]) @deprecated string FooEnum - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#FooEnum") val provider = testSymbolProvider(model) @@ -230,7 +236,8 @@ class EnumGeneratorTest { @Test fun `it generates unnamed enums`() { - val model = """ + val model = + """ namespace test @enum([ { @@ -250,7 +257,7 @@ class EnumGeneratorTest { }, ]) string FooEnum - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#FooEnum") val provider = testSymbolProvider(model) @@ -271,7 +278,8 @@ class EnumGeneratorTest { @Test fun `it should generate documentation for enums`() { - val model = """ + val model = + """ namespace test /// Some top-level documentation. @@ -280,7 +288,7 @@ class EnumGeneratorTest { { name: "Unknown", value: "Unknown" }, ]) string SomeEnum - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#SomeEnum") val provider = testSymbolProvider(model, rustReservedWordConfig = rustReservedWordConfig) @@ -300,7 +308,8 @@ class EnumGeneratorTest { @Test fun `it should generate documentation for unnamed enums`() { - val model = """ + val model = + """ namespace test /// Some top-level documentation. @@ -309,7 +318,7 @@ class EnumGeneratorTest { { value: "Two" }, ]) string SomeEnum - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#SomeEnum") val provider = testSymbolProvider(model) @@ -327,14 +336,15 @@ class EnumGeneratorTest { @Test fun `it handles variants that clash with Rust reserved words`() { - val model = """ + val model = + """ namespace test @enum([ { name: "Known", value: "Known" }, { name: "Self", value: "other" }, ]) string SomeEnum - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#SomeEnum") val provider = testSymbolProvider(model) @@ -351,14 +361,15 @@ class EnumGeneratorTest { @Test fun `impl debug for non-sensitive enum should implement the derived debug trait`() { - val model = """ + val model = + """ namespace test @enum([ { name: "Foo", value: "Foo" }, { name: "Bar", value: "Bar" }, ]) string SomeEnum - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#SomeEnum") val provider = testSymbolProvider(model) @@ -378,7 +389,8 @@ class EnumGeneratorTest { @Test fun `impl debug for sensitive enum should redact text`() { - val model = """ + val model = + """ namespace test @sensitive @enum([ @@ -386,7 +398,7 @@ class EnumGeneratorTest { { name: "Bar", value: "Bar" }, ]) string SomeEnum - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#SomeEnum") val provider = testSymbolProvider(model) @@ -406,14 +418,15 @@ class EnumGeneratorTest { @Test fun `impl debug for non-sensitive unnamed enum should implement the derived debug trait`() { - val model = """ + val model = + """ namespace test @enum([ { value: "Foo" }, { value: "Bar" }, ]) string SomeEnum - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#SomeEnum") val provider = testSymbolProvider(model) @@ -437,7 +450,8 @@ class EnumGeneratorTest { @Test fun `impl debug for sensitive unnamed enum should redact text`() { - val model = """ + val model = + """ namespace test @sensitive @enum([ @@ -445,7 +459,7 @@ class EnumGeneratorTest { { value: "Bar" }, ]) string SomeEnum - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#SomeEnum") val provider = testSymbolProvider(model) @@ -470,41 +484,48 @@ class EnumGeneratorTest { @Test fun `it supports other enum types`() { class CustomizingEnumType : EnumType() { - override fun implFromForStr(context: EnumGeneratorContext): Writable = writable { - // intentional no-op - } + override fun implFromForStr(context: EnumGeneratorContext): Writable = + writable { + // intentional no-op + } - override fun implFromStr(context: EnumGeneratorContext): Writable = writable { - // intentional no-op - } + override fun implFromStr(context: EnumGeneratorContext): Writable = + writable { + // intentional no-op + } - override fun additionalEnumMembers(context: EnumGeneratorContext): Writable = writable { - rust("// additional enum members") - } + override fun additionalEnumMembers(context: EnumGeneratorContext): Writable = + writable { + rust("// additional enum members") + } - override fun additionalAsStrMatchArms(context: EnumGeneratorContext): Writable = writable { - rust("// additional as_str match arm") - } + override fun additionalAsStrMatchArms(context: EnumGeneratorContext): Writable = + writable { + rust("// additional as_str match arm") + } - override fun additionalDocs(context: EnumGeneratorContext): Writable = writable { - rust("// additional docs") - } + override fun additionalDocs(context: EnumGeneratorContext): Writable = + writable { + rust("// additional docs") + } } - val model = """ + val model = + """ namespace test @enum([ { name: "Known", value: "Known" }, { name: "Self", value: "other" }, ]) string SomeEnum - """.asSmithyModel() + """.asSmithyModel() val shape = model.lookup("test#SomeEnum") val provider = testSymbolProvider(model) - val output = RustWriter.root().apply { - renderEnum(model, provider, shape, CustomizingEnumType()) - }.toString() + val output = + RustWriter.root().apply { + renderEnum(model, provider, shape, CustomizingEnumType()) + }.toString() // Since we didn't use the Infallible EnumType, there should be no Unknown variant output shouldNotContain "Unknown" diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt index aa89acbb2bf..a1eef56057b 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt @@ -32,7 +32,8 @@ import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.lookup class InstantiatorTest { - private val model = """ + private val model = + """ namespace com.test @documentation("this documents the shape") @@ -84,7 +85,9 @@ class InstantiatorTest { @required num: Integer } - """.asSmithyModel().let { RecursiveShapeBoxer().transform(it) } + """.asSmithyModel().let { + RecursiveShapeBoxer().transform(it) + } private val codegenContext = testCodegenContext(model) private val symbolProvider = codegenContext.symbolProvider @@ -150,16 +153,17 @@ class InstantiatorTest { val structure = model.lookup("com.test#WithBox") val sut = Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext)) - val data = Node.parse( - """ - { - "member": { - "member": { } - }, - "value": 10 - } - """, - ) + val data = + Node.parse( + """ + { + "member": { + "member": { } + }, + "value": 10 + } + """, + ) val project = TestWorkspace.testProject(model) structure.renderWithModelBuilder(model, symbolProvider, project) @@ -204,12 +208,13 @@ class InstantiatorTest { @Test fun `generate sparse lists`() { val data = Node.parse(""" [ "bar", "foo", null ] """) - val sut = Instantiator( - symbolProvider, - model, - runtimeConfig, - BuilderKindBehavior(codegenContext), - ) + val sut = + Instantiator( + symbolProvider, + model, + runtimeConfig, + BuilderKindBehavior(codegenContext), + ) val project = TestWorkspace.testProject(model) project.lib { @@ -225,21 +230,23 @@ class InstantiatorTest { @Test fun `generate maps of maps`() { - val data = Node.parse( - """ - { - "k1": { "map": {} }, - "k2": { "map": { "k3": {} } }, - "k3": { } - } - """, - ) - val sut = Instantiator( - symbolProvider, - model, - runtimeConfig, - BuilderKindBehavior(codegenContext), - ) + val data = + Node.parse( + """ + { + "k1": { "map": {} }, + "k2": { "map": { "k3": {} } }, + "k3": { } + } + """, + ) + val sut = + Instantiator( + symbolProvider, + model, + runtimeConfig, + BuilderKindBehavior(codegenContext), + ) val inner = model.lookup("com.test#Inner") val project = TestWorkspace.testProject(model) @@ -266,12 +273,13 @@ class InstantiatorTest { fun `blob inputs are binary data`() { // "Parameter values that contain binary data MUST be defined using values // that can be represented in plain text (for example, use "foo" and not "Zm9vCg==")." - val sut = Instantiator( - symbolProvider, - model, - runtimeConfig, - BuilderKindBehavior(codegenContext), - ) + val sut = + Instantiator( + symbolProvider, + model, + runtimeConfig, + BuilderKindBehavior(codegenContext), + ) val project = TestWorkspace.testProject(model) project.testModule { @@ -293,12 +301,13 @@ class InstantiatorTest { fun `integer and fractional timestamps`() { // "Parameter values that contain binary data MUST be defined using values // that can be represented in plain text (for example, use "foo" and not "Zm9vCg==")." - val sut = Instantiator( - symbolProvider, - model, - runtimeConfig, - BuilderKindBehavior(codegenContext), - ) + val sut = + Instantiator( + symbolProvider, + model, + runtimeConfig, + BuilderKindBehavior(codegenContext), + ) val project = TestWorkspace.testProject(model) project.testModule { unitTest("timestamps") { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt index f31fd538e83..c197a7cb4c0 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt @@ -87,14 +87,20 @@ class StructureGeneratorTest { val secretStructure = model.lookup("com.test#SecretStructure") val structWithInnerSecretStructure = model.lookup("com.test#StructWithInnerSecretStructure") - val rustReservedWordConfig: RustReservedWordConfig = RustReservedWordConfig( - structureMemberMap = StructureGenerator.structureMemberNameMap, - enumMemberMap = emptyMap(), - unionMemberMap = emptyMap(), - ) + val rustReservedWordConfig: RustReservedWordConfig = + RustReservedWordConfig( + structureMemberMap = StructureGenerator.structureMemberNameMap, + enumMemberMap = emptyMap(), + unionMemberMap = emptyMap(), + ) } - private fun structureGenerator(model: Model, provider: RustSymbolProvider, writer: RustWriter, shape: StructureShape) = StructureGenerator(model, provider, writer, shape, emptyList(), StructSettings(flattenVecAccessors = true)) + private fun structureGenerator( + model: Model, + provider: RustSymbolProvider, + writer: RustWriter, + shape: StructureShape, + ) = StructureGenerator(model, provider, writer, shape, emptyList(), StructSettings(flattenVecAccessors = true)) @Test fun `generate basic structures`() { @@ -213,7 +219,8 @@ class StructureGeneratorTest { @Test fun `attach docs to everything`() { - val model = """ + val model = + """ namespace com.test @documentation("inner doc") structure Inner { } @@ -259,7 +266,8 @@ class StructureGeneratorTest { @Test fun `deprecated trait with message and since`() { - val model = """ + val model = + """ namespace test @deprecated @@ -273,7 +281,7 @@ class StructureGeneratorTest { @deprecated(message: "Fly, you fools!", since: "1.2.3") structure Qux {} - """.asSmithyModel() + """.asSmithyModel() val provider = testSymbolProvider(model, rustReservedWordConfig = rustReservedWordConfig) val project = TestWorkspace.testProject(provider) project.lib { rust("##![allow(deprecated)]") } @@ -290,7 +298,8 @@ class StructureGeneratorTest { @Test fun `nested deprecated trait`() { - val model = """ + val model = + """ namespace test structure Nested { @@ -306,7 +315,7 @@ class StructureGeneratorTest { @deprecated structure Bar {} - """.asSmithyModel() + """.asSmithyModel() val provider = testSymbolProvider(model, rustReservedWordConfig = rustReservedWordConfig) val project = TestWorkspace.testProject(provider) project.lib { rust("##![allow(deprecated)]") } @@ -404,7 +413,8 @@ class StructureGeneratorTest { @Test fun `fields are NOT doc-hidden`() { - val model = """ + val model = + """ namespace com.test structure MyStruct { foo: String, @@ -413,7 +423,7 @@ class StructureGeneratorTest { ts: Timestamp, byteValue: Byte, } - """.asSmithyModel() + """.asSmithyModel() val struct = model.lookup("com.test#MyStruct") val provider = testSymbolProvider(model, rustReservedWordConfig = rustReservedWordConfig) @@ -425,11 +435,12 @@ class StructureGeneratorTest { @Test fun `streaming fields are NOT doc-hidden`() { - val model = """ + val model = + """ namespace com.test @streaming blob SomeStreamingThing structure MyStruct { foo: SomeStreamingThing } - """.asSmithyModel() + """.asSmithyModel() val struct = model.lookup("com.test#MyStruct") val provider = testSymbolProvider(model, rustReservedWordConfig = rustReservedWordConfig) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt index e8ea12c0dbc..5699c9e5cad 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt @@ -13,37 +13,40 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.util.dq object TestEnumType : EnumType() { - override fun implFromForStr(context: EnumGeneratorContext): Writable = writable { - rustTemplate( - """ - impl #{From}<&str> for ${context.enumName} { - fn from(s: &str) -> Self { - match s { - #{matchArms} + override fun implFromForStr(context: EnumGeneratorContext): Writable = + writable { + rustTemplate( + """ + impl #{From}<&str> for ${context.enumName} { + fn from(s: &str) -> Self { + match s { + #{matchArms} + } } } - } - """, - "From" to RuntimeType.From, - "matchArms" to writable { - context.sortedMembers.forEach { member -> - rust("${member.value.dq()} => ${context.enumName}::${member.derivedName()},") - } - rust("_ => panic!()") - }, - ) - } + """, + "From" to RuntimeType.From, + "matchArms" to + writable { + context.sortedMembers.forEach { member -> + rust("${member.value.dq()} => ${context.enumName}::${member.derivedName()},") + } + rust("_ => panic!()") + }, + ) + } - override fun implFromStr(context: EnumGeneratorContext): Writable = writable { - rust( - """ - impl std::str::FromStr for ${context.enumName} { - type Err = std::convert::Infallible; - fn from_str(s: &str) -> std::result::Result { - Ok(${context.enumName}::from(s)) + override fun implFromStr(context: EnumGeneratorContext): Writable = + writable { + rust( + """ + impl std::str::FromStr for ${context.enumName} { + type Err = std::convert::Infallible; + fn from_str(s: &str) -> std::result::Result { + Ok(${context.enumName}::from(s)) + } } - } - """, - ) - } + """, + ) + } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGeneratorTest.kt index 8b6778890f0..6513ce0952b 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGeneratorTest.kt @@ -20,15 +20,16 @@ import software.amazon.smithy.rust.codegen.core.util.lookup class UnionGeneratorTest { @Test fun `generate basic unions`() { - val writer = generateUnion( - """ - union MyUnion { - stringConfig: String, - @documentation("This *is* documentation about the member") - intConfig: PrimitiveInteger - } - """, - ) + val writer = + generateUnion( + """ + union MyUnion { + stringConfig: String, + @documentation("This *is* documentation about the member") + intConfig: PrimitiveInteger + } + """, + ) writer.compileAndTest( """ @@ -43,14 +44,15 @@ class UnionGeneratorTest { @Test fun `generate conversion helper methods`() { - val writer = generateUnion( - """ - union MyUnion { - stringValue: String, - intValue: PrimitiveInteger - } - """, - ) + val writer = + generateUnion( + """ + union MyUnion { + stringValue: String, + intValue: PrimitiveInteger + } + """, + ) writer.compileAndTest( """ @@ -99,7 +101,8 @@ class UnionGeneratorTest { @Test fun `generate deprecated unions`() { - val model = """namespace test + val model = + """namespace test union Nested { foo: Foo, @deprecated @@ -112,7 +115,7 @@ class UnionGeneratorTest { @deprecated union Bar { x: Integer } - """.asSmithyModel() + """.asSmithyModel() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.lib { rust("##![allow(deprecated)]") } @@ -127,14 +130,15 @@ class UnionGeneratorTest { @Test fun `impl debug for non-sensitive union should implement the derived debug trait`() { - val writer = generateUnion( - """ - union MyUnion { - foo: PrimitiveInteger - bar: String, - } - """, - ) + val writer = + generateUnion( + """ + union MyUnion { + foo: PrimitiveInteger + bar: String, + } + """, + ) writer.compileAndTest( """ @@ -146,15 +150,16 @@ class UnionGeneratorTest { @Test fun `impl debug for sensitive union should redact text`() { - val writer = generateUnion( - """ - @sensitive - union MyUnion { - foo: PrimitiveInteger, - bar: String, - } - """, - ) + val writer = + generateUnion( + """ + @sensitive + union MyUnion { + foo: PrimitiveInteger, + bar: String, + } + """, + ) writer.compileAndTest( """ @@ -166,17 +171,18 @@ class UnionGeneratorTest { @Test fun `impl debug for union should redact text for sensitive member target`() { - val writer = generateUnion( - """ - @sensitive - string Bar - - union MyUnion { - foo: PrimitiveInteger, - bar: Bar, - } - """, - ) + val writer = + generateUnion( + """ + @sensitive + string Bar + + union MyUnion { + foo: PrimitiveInteger, + bar: Bar, + } + """, + ) writer.compileAndTest( """ @@ -188,17 +194,18 @@ class UnionGeneratorTest { @Test fun `impl debug for union with unit target should redact text for sensitive member target`() { - val writer = generateUnion( - """ - @sensitive - string Bar - - union MyUnion { - foo: Unit, - bar: Bar, - } - """, - ) + val writer = + generateUnion( + """ + @sensitive + string Bar + + union MyUnion { + foo: Unit, + bar: Bar, + } + """, + ) writer.compileAndTest( """ @@ -219,7 +226,11 @@ class UnionGeneratorTest { ) } - private fun generateUnion(modelSmithy: String, unionName: String = "MyUnion", unknownVariant: Boolean = true): RustWriter { + private fun generateUnion( + modelSmithy: String, + unionName: String = "MyUnion", + unknownVariant: Boolean = true, + ): RustWriter { val model = "namespace test\n$modelSmithy".asSmithyModel() val provider: SymbolProvider = testSymbolProvider(model) val writer = RustWriter.forModule("model") diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctionsTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctionsTest.kt index cbc08421f5f..06a5e9f42aa 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctionsTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/ProtocolFunctionsTest.kt @@ -12,7 +12,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.util.lookup class ProtocolFunctionsTest { - private val testModel = """ + private val testModel = + """ namespace test structure SomeStruct1 { @@ -79,13 +80,16 @@ class ProtocolFunctionsTest { operation Op2 { input: Op1Input, } - """.asSmithyModel() + """.asSmithyModel() @Test fun `generates function names for shapes`() { val symbolProvider = testSymbolProvider(testModel) - fun test(shapeId: String, expected: String) { + fun test( + shapeId: String, + expected: String, + ) { symbolProvider.shapeFunctionName(null, testModel.lookup(shapeId)) shouldBe expected } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt index 38beea1e1be..bc239f10215 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt @@ -23,7 +23,8 @@ import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.core.util.outputShape class AwsQueryParserGeneratorTest { - private val baseModel = """ + private val baseModel = + """ namespace test use aws.protocols#awsQuery @@ -37,17 +38,18 @@ class AwsQueryParserGeneratorTest { operation SomeOperation { output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() @Test fun `it modifies operation parsing to include Response and Result tags`() { val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider - val parserGenerator = AwsQueryParserGenerator( - codegenContext, - RuntimeType.wrappedXmlErrors(TestRuntimeConfig), - ) + val parserGenerator = + AwsQueryParserGenerator( + codegenContext, + RuntimeType.wrappedXmlErrors(TestRuntimeConfig), + ) val operationParser = parserGenerator.operationParser(model.lookup("test#SomeOperation"))!! val project = TestWorkspace.testProject(testSymbolProvider(model)) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt index 9a51b072538..aaac54b7399 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt @@ -23,7 +23,8 @@ import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.core.util.outputShape class Ec2QueryParserGeneratorTest { - private val baseModel = """ + private val baseModel = + """ namespace test use aws.protocols#awsQuery @@ -37,17 +38,18 @@ class Ec2QueryParserGeneratorTest { operation SomeOperation { output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel() @Test fun `it modifies operation parsing to include Response and Result tags`() { val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider - val parserGenerator = Ec2QueryParserGenerator( - codegenContext, - RuntimeType.wrappedXmlErrors(TestRuntimeConfig), - ) + val parserGenerator = + Ec2QueryParserGenerator( + codegenContext, + RuntimeType.wrappedXmlErrors(TestRuntimeConfig), + ) val operationParser = parserGenerator.operationParser(model.lookup("test#SomeOperation"))!! val project = TestWorkspace.testProject(testSymbolProvider(model)) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt index e5ca4d19aee..ed416545b1f 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt @@ -28,7 +28,8 @@ import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.core.util.outputShape class JsonParserGeneratorTest { - private val baseModel = """ + private val baseModel = + """ namespace test use aws.protocols#restJson1 @@ -108,7 +109,7 @@ class JsonParserGeneratorTest { output: OpOutput, errors: [Error] } - """.asSmithyModel() + """.asSmithyModel() @Test fun `generates valid deserializers`() { @@ -116,11 +117,12 @@ class JsonParserGeneratorTest { val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider - val parserGenerator = JsonParserGenerator( - codegenContext, - HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/json")), - ::restJsonFieldName, - ) + val parserGenerator = + JsonParserGenerator( + codegenContext, + HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/json")), + ::restJsonFieldName, + ) val operationGenerator = parserGenerator.operationParser(model.lookup("test#Op")) val payloadGenerator = parserGenerator.payloadParser(model.lookup("test#OpOutput\$top")) val errorParser = parserGenerator.errorParser(model.lookup("test#Error")) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt index 0d78af182ea..0ec685f15a6 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt @@ -29,7 +29,8 @@ import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.core.util.outputShape internal class XmlBindingTraitParserGeneratorTest { - private val baseModel = """ + private val baseModel = + """ namespace test use aws.protocols#restXml union Choice { @@ -89,17 +90,18 @@ internal class XmlBindingTraitParserGeneratorTest { input: Top, output: Top } - """.asSmithyModel() + """.asSmithyModel() @Test fun `generates valid parsers`() { val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider - val parserGenerator = XmlBindingTraitParserGenerator( - codegenContext, - RuntimeType.wrappedXmlErrors(TestRuntimeConfig), - ) { _, inner -> inner("decoder") } + val parserGenerator = + XmlBindingTraitParserGenerator( + codegenContext, + RuntimeType.wrappedXmlErrors(TestRuntimeConfig), + ) { _, inner -> inner("decoder") } val operationParser = parserGenerator.operationParser(model.lookup("test#Op"))!! val choiceShape = model.lookup("test#Choice") diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt index 2203b05dc1f..a0fd7fffe8c 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt @@ -28,7 +28,8 @@ import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.lookup class AwsQuerySerializerGeneratorTest { - private val baseModel = """ + private val baseModel = + """ namespace test use aws.protocols#restJson1 @@ -85,15 +86,16 @@ class AwsQuerySerializerGeneratorTest { operation Op { input: OpInput, } - """.asSmithyModel() + """.asSmithyModel() @ParameterizedTest @CsvSource("true", "false") fun `generates valid serializers`(generateUnknownVariant: Boolean) { - val codegenTarget = when (generateUnknownVariant) { - true -> CodegenTarget.CLIENT - false -> CodegenTarget.SERVER - } + val codegenTarget = + when (generateUnknownVariant) { + true -> CodegenTarget.CLIENT + false -> CodegenTarget.SERVER + } val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model, codegenTarget = codegenTarget) val symbolProvider = codegenContext.symbolProvider @@ -155,7 +157,8 @@ class AwsQuerySerializerGeneratorTest { project.compileAndTest() } - private val baseModelWithRequiredTypes = """ + private val baseModelWithRequiredTypes = + """ namespace test use aws.protocols#restJson1 @@ -219,7 +222,7 @@ class AwsQuerySerializerGeneratorTest { operation Op { input: OpInput, } - """.asSmithyModel() + """.asSmithyModel() @ParameterizedTest @CsvSource( @@ -229,13 +232,18 @@ class AwsQuerySerializerGeneratorTest { "true, CLIENT_ZERO_VALUE_V1_NO_INPUT", "false, SERVER", ) - fun `generates valid serializers for required types`(generateUnknownVariant: Boolean, nullabilityCheckMode: NullableIndex.CheckMode) { - val codegenTarget = when (generateUnknownVariant) { - true -> CodegenTarget.CLIENT - false -> CodegenTarget.SERVER - } + fun `generates valid serializers for required types`( + generateUnknownVariant: Boolean, + nullabilityCheckMode: NullableIndex.CheckMode, + ) { + val codegenTarget = + when (generateUnknownVariant) { + true -> CodegenTarget.CLIENT + false -> CodegenTarget.SERVER + } val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModelWithRequiredTypes)) - val codegenContext = testCodegenContext(model, codegenTarget = codegenTarget, nullabilityCheckMode = nullabilityCheckMode) + val codegenContext = + testCodegenContext(model, codegenTarget = codegenTarget, nullabilityCheckMode = nullabilityCheckMode) val symbolProvider = codegenContext.symbolProvider val parserGenerator = AwsQuerySerializerGenerator(codegenContext) val operationGenerator = parserGenerator.operationInputSerializer(model.lookup("test#Op")) @@ -245,7 +253,12 @@ class AwsQuerySerializerGeneratorTest { // Depending on the nullability check mode, the builder can be fallible or not. When it's fallible, we need to // add unwrap calls. val builderIsFallible = hasFallibleBuilder(model.lookup("test#Top"), symbolProvider) - val maybeUnwrap = if (builderIsFallible) { ".unwrap()" } else { "" } + val maybeUnwrap = + if (builderIsFallible) { + ".unwrap()" + } else { + "" + } project.lib { unitTest( "query_serializer", diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt index 4b5f490e13a..d0a33e1f5d9 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt @@ -29,7 +29,8 @@ import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.lookup class Ec2QuerySerializerGeneratorTest { - private val baseModel = """ + private val baseModel = + """ namespace test union Choice { @@ -85,7 +86,7 @@ class Ec2QuerySerializerGeneratorTest { operation Op { input: OpInput, } - """.asSmithyModel() + """.asSmithyModel() @ParameterizedTest @CsvSource( @@ -152,7 +153,8 @@ class Ec2QuerySerializerGeneratorTest { project.compileAndTest() } - private val baseModelWithRequiredTypes = """ + private val baseModelWithRequiredTypes = + """ namespace test union Choice { @@ -215,7 +217,7 @@ class Ec2QuerySerializerGeneratorTest { operation Op { input: OpInput, } - """.asSmithyModel() + """.asSmithyModel() @ParameterizedTest @CsvSource( @@ -239,7 +241,12 @@ class Ec2QuerySerializerGeneratorTest { // add unwrap calls. val builderIsFallible = BuilderGenerator.hasFallibleBuilder(model.lookup("test#Top"), symbolProvider) - val maybeUnwrap = if (builderIsFallible) { ".unwrap()" } else { "" } + val maybeUnwrap = + if (builderIsFallible) { + ".unwrap()" + } else { + "" + } project.lib { unitTest( "ec2query_serializer", diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt index 61140ecd420..922f2b7ae8d 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt @@ -31,7 +31,8 @@ import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.lookup class JsonSerializerGeneratorTest { - private val baseModel = """ + private val baseModel = + """ namespace test use aws.protocols#restJson1 @@ -100,7 +101,7 @@ class JsonSerializerGeneratorTest { operation Op { input: OpInput, } - """.asSmithyModel() + """.asSmithyModel() @ParameterizedTest @CsvSource( @@ -114,11 +115,12 @@ class JsonSerializerGeneratorTest { val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model, nullabilityCheckMode = nullabilityCheckMode) val symbolProvider = codegenContext.symbolProvider - val parserSerializer = JsonSerializerGenerator( - codegenContext, - HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/json")), - ::restJsonFieldName, - ) + val parserSerializer = + JsonSerializerGenerator( + codegenContext, + HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/json")), + ::restJsonFieldName, + ) val operationGenerator = parserSerializer.operationInputSerializer(model.lookup("test#Op")) val documentGenerator = parserSerializer.documentSerializer() @@ -167,7 +169,8 @@ class JsonSerializerGeneratorTest { project.compileAndTest() } - private val baseModelWithRequiredTypes = """ + private val baseModelWithRequiredTypes = + """ namespace test use aws.protocols#restJson1 @@ -245,7 +248,7 @@ class JsonSerializerGeneratorTest { operation Op { input: OpInput, } - """.asSmithyModel() + """.asSmithyModel() @ParameterizedTest @CsvSource( @@ -259,11 +262,12 @@ class JsonSerializerGeneratorTest { val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModelWithRequiredTypes)) val codegenContext = testCodegenContext(model, nullabilityCheckMode = nullabilityCheckMode) val symbolProvider = codegenContext.symbolProvider - val parserSerializer = JsonSerializerGenerator( - codegenContext, - HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/json")), - ::restJsonFieldName, - ) + val parserSerializer = + JsonSerializerGenerator( + codegenContext, + HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/json")), + ::restJsonFieldName, + ) val operationGenerator = parserSerializer.operationInputSerializer(model.lookup("test#Op")) val documentGenerator = parserSerializer.documentSerializer() @@ -273,7 +277,12 @@ class JsonSerializerGeneratorTest { // add unwrap calls. val builderIsFallible = BuilderGenerator.hasFallibleBuilder(model.lookup("test#Top"), symbolProvider) - val maybeUnwrap = if (builderIsFallible) { ".unwrap()" } else { "" } + val maybeUnwrap = + if (builderIsFallible) { + ".unwrap()" + } else { + "" + } project.lib { unitTest( "json_serializers", diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt index 2aa24581b55..58f482e2769 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt @@ -32,7 +32,8 @@ import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.lookup internal class XmlBindingTraitSerializerGeneratorTest { - private val baseModel = """ + private val baseModel = + """ namespace test use aws.protocols#restXml union Choice { @@ -106,7 +107,7 @@ internal class XmlBindingTraitSerializerGeneratorTest { operation Op { input: OpInput, } - """.asSmithyModel() + """.asSmithyModel() @ParameterizedTest @CsvSource( @@ -120,10 +121,11 @@ internal class XmlBindingTraitSerializerGeneratorTest { val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model, nullabilityCheckMode = nullabilityCheckMode) val symbolProvider = codegenContext.symbolProvider - val parserGenerator = XmlBindingTraitSerializerGenerator( - codegenContext, - HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/xml")), - ) + val parserGenerator = + XmlBindingTraitSerializerGenerator( + codegenContext, + HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/xml")), + ) val operationSerializer = parserGenerator.payloadSerializer(model.lookup("test#OpInput\$payload")) val project = TestWorkspace.testProject(testSymbolProvider(model)) @@ -171,7 +173,8 @@ internal class XmlBindingTraitSerializerGeneratorTest { project.compileAndTest() } - private val baseModelWithRequiredTypes = """ + private val baseModelWithRequiredTypes = + """ namespace test use aws.protocols#restXml union Choice { @@ -237,7 +240,7 @@ internal class XmlBindingTraitSerializerGeneratorTest { operation Op { input: OpInput, } - """.asSmithyModel() + """.asSmithyModel() @ParameterizedTest @CsvSource( @@ -251,10 +254,11 @@ internal class XmlBindingTraitSerializerGeneratorTest { val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModelWithRequiredTypes)) val codegenContext = testCodegenContext(model, nullabilityCheckMode = nullabilityCheckMode) val symbolProvider = codegenContext.symbolProvider - val parserGenerator = XmlBindingTraitSerializerGenerator( - codegenContext, - HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/xml")), - ) + val parserGenerator = + XmlBindingTraitSerializerGenerator( + codegenContext, + HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/xml")), + ) val operationSerializer = parserGenerator.payloadSerializer(model.lookup("test#OpInput\$payload")) val project = TestWorkspace.testProject(symbolProvider) @@ -263,11 +267,22 @@ internal class XmlBindingTraitSerializerGeneratorTest { // add unwrap calls. val builderIsFallible = BuilderGenerator.hasFallibleBuilder(model.lookup("test#Top"), symbolProvider) - val maybeUnwrap = if (builderIsFallible) { ".unwrap()" } else { "" } - val payloadIsOptional = model.lookup("test#OpInput\$payload").let { - symbolProvider.toSymbol(it).isOptional() - } - val maybeUnwrapPayload = if (payloadIsOptional) { ".unwrap()" } else { "" } + val maybeUnwrap = + if (builderIsFallible) { + ".unwrap()" + } else { + "" + } + val payloadIsOptional = + model.lookup("test#OpInput\$payload").let { + symbolProvider.toSymbol(it).isOptional() + } + val maybeUnwrapPayload = + if (payloadIsOptional) { + ".unwrap()" + } else { + "" + } project.lib { unitTest( "serialize_xml", diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/EventStreamNormalizerTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/EventStreamNormalizerTest.kt index 0398180b8d6..6c964fc3a01 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/EventStreamNormalizerTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/EventStreamNormalizerTest.kt @@ -17,15 +17,16 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait class EventStreamNormalizerTest { @Test fun `it should leave normal unions alone`() { - val transformed = EventStreamNormalizer.transform( - """ - namespace test - union SomeNormalUnion { - Foo: String, - Bar: Long, - } - """.asSmithyModel(), - ) + val transformed = + EventStreamNormalizer.transform( + """ + namespace test + union SomeNormalUnion { + Foo: String, + Bar: Long, + } + """.asSmithyModel(), + ) val shape = transformed.expectShape(ShapeId.from("test#SomeNormalUnion"), UnionShape::class.java) shape.hasTrait() shouldBe false @@ -34,24 +35,25 @@ class EventStreamNormalizerTest { @Test fun `it should transform event stream unions`() { - val transformed = EventStreamNormalizer.transform( - """ - namespace test - - structure SomeMember { - } - - @error("client") - structure SomeError { - } - - @streaming - union SomeEventStream { - SomeMember: SomeMember, - SomeError: SomeError, - } - """.asSmithyModel(), - ) + val transformed = + EventStreamNormalizer.transform( + """ + namespace test + + structure SomeMember { + } + + @error("client") + structure SomeError { + } + + @streaming + union SomeEventStream { + SomeMember: SomeMember, + SomeError: SomeError, + } + """.asSmithyModel(), + ) val shape = transformed.expectShape(ShapeId.from("test#SomeEventStream"), UnionShape::class.java) shape.hasTrait() shouldBe true diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizerTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizerTest.kt index 8e1d922820d..39553dd8802 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizerTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizerTest.kt @@ -20,10 +20,11 @@ import software.amazon.smithy.rust.codegen.core.util.orNull internal class OperationNormalizerTest { @Test fun `add inputs and outputs to empty operations`() { - val model = """ + val model = + """ namespace smithy.test operation Empty {} - """.asSmithyModel() + """.asSmithyModel() val operationId = ShapeId.from("smithy.test#Empty") model.expectShape(operationId, OperationShape::class.java).input.isPresent shouldBe false val modified = OperationNormalizer.transform(model) @@ -43,7 +44,8 @@ internal class OperationNormalizerTest { @Test fun `create cloned inputs for operations`() { - val model = """ + val model = + """ namespace smithy.test structure RenameMe { v: String @@ -51,7 +53,7 @@ internal class OperationNormalizerTest { operation MyOp { input: RenameMe } - """.asSmithyModel() + """.asSmithyModel() val operationId = ShapeId.from("smithy.test#MyOp") model.expectShape(operationId, OperationShape::class.java).input.isPresent shouldBe true val modified = OperationNormalizer.transform(model) @@ -66,7 +68,8 @@ internal class OperationNormalizerTest { @Test fun `create cloned outputs for operations`() { - val model = """ + val model = + """ namespace smithy.test structure RenameMe { v: String @@ -74,7 +77,7 @@ internal class OperationNormalizerTest { operation MyOp { output: RenameMe } - """.asSmithyModel() + """.asSmithyModel() val operationId = ShapeId.from("smithy.test#MyOp") model.expectShape(operationId, OperationShape::class.java).output.isPresent shouldBe true val modified = OperationNormalizer.transform(model) @@ -89,7 +92,8 @@ internal class OperationNormalizerTest { @Test fun `synthetics should not collide with other operations`() { - val model = """ + val model = + """ namespace test structure DeleteApplicationRequest {} @@ -107,7 +111,7 @@ internal class OperationNormalizerTest { input: DeleteApplicationOutputRequest, output: DeleteApplicationOutputResponse, } - """.asSmithyModel() + """.asSmithyModel() (model.expectShape(ShapeId.from("test#DeleteApplicationOutput")) is OperationShape) shouldBe true diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxerTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxerTest.kt index 293e2217131..f6f8dc2a4b7 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxerTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxerTest.kt @@ -18,7 +18,8 @@ import kotlin.streams.toList internal class RecursiveShapeBoxerTest { @Test fun `leave non-recursive models unchanged`() { - val model = """ + val model = + """ namespace com.example list BarList { member: Bar @@ -30,19 +31,20 @@ internal class RecursiveShapeBoxerTest { structure Bar { hello: Hello } - """.asSmithyModel() + """.asSmithyModel() RecursiveShapeBoxer().transform(model) shouldBe model } @Test fun `add the box trait to simple recursive shapes`() { - val model = """ + val model = + """ namespace com.example structure Recursive { RecursiveStruct: Recursive, anotherField: Boolean } - """.asSmithyModel() + """.asSmithyModel() val transformed = RecursiveShapeBoxer().transform(model) val member: MemberShape = transformed.lookup("com.example#Recursive\$RecursiveStruct") member.expectTrait() @@ -50,7 +52,8 @@ internal class RecursiveShapeBoxerTest { @Test fun `add the box trait to complex structures`() { - val model = """ + val model = + """ namespace com.example structure Expr { left: Atom, @@ -69,14 +72,15 @@ internal class RecursiveShapeBoxerTest { otherMember: Atom, third: SecondTree } - """.asSmithyModel() + """.asSmithyModel() val transformed = RecursiveShapeBoxer().transform(model) val boxed = transformed.shapes().filter { it.hasTrait() }.toList() - boxed.map { it.id.toString().removePrefix("com.example#") }.toSet() shouldBe setOf( - "Atom\$add", - "Atom\$sub", - "SecondTree\$third", - "Atom\$more", - ) + boxed.map { it.id.toString().removePrefix("com.example#") }.toSet() shouldBe + setOf( + "Atom\$add", + "Atom\$sub", + "SecondTree\$third", + "Atom\$more", + ) } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapesIntegrationTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapesIntegrationTest.kt index 110836fdadd..1f8917c35f6 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapesIntegrationTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapesIntegrationTest.kt @@ -24,7 +24,8 @@ import software.amazon.smithy.rust.codegen.core.util.lookup class RecursiveShapesIntegrationTest { @Test fun `recursive shapes are properly boxed`() { - val model = """ + val model = + """ namespace com.example structure Expr { left: Atom, @@ -43,7 +44,7 @@ class RecursiveShapesIntegrationTest { otherMember: Atom, third: SecondTree } - """.asSmithyModel() + """.asSmithyModel() val check = { input: Model -> val symbolProvider = testSymbolProvider(model) @@ -62,9 +63,10 @@ class RecursiveShapesIntegrationTest { project } val unmodifiedProject = check(model) - val output = assertThrows { - unmodifiedProject.compileAndTest(expectFailure = true) - } + val output = + assertThrows { + unmodifiedProject.compileAndTest(expectFailure = true) + } // THIS IS A LOAD-BEARING shouldContain! If the compiler error changes then this will break! output.message shouldContain "have infinite size" diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/util/MapTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/util/MapTest.kt index a38499de36b..1ed9bd77064 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/util/MapTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/util/MapTest.kt @@ -13,58 +13,68 @@ class MapTest { fun `it should deep merge maps`() { mapOf().deepMergeWith(mapOf()) shouldBe emptyMap() - mapOf("foo" to 1, "bar" to "baz").deepMergeWith(mapOf()) shouldBe mapOf( - "foo" to 1, - "bar" to "baz", - ) + mapOf("foo" to 1, "bar" to "baz").deepMergeWith(mapOf()) shouldBe + mapOf( + "foo" to 1, + "bar" to "baz", + ) - mapOf().deepMergeWith(mapOf("foo" to 1, "bar" to "baz")) shouldBe mapOf( - "foo" to 1, - "bar" to "baz", - ) + mapOf().deepMergeWith(mapOf("foo" to 1, "bar" to "baz")) shouldBe + mapOf( + "foo" to 1, + "bar" to "baz", + ) mapOf( - "package" to mapOf( - "name" to "foo", - "version" to "1.0.0", - ), + "package" to + mapOf( + "name" to "foo", + "version" to "1.0.0", + ), ).deepMergeWith( mapOf( - "package" to mapOf( - "readme" to "README.md", - ), + "package" to + mapOf( + "readme" to "README.md", + ), ), - ) shouldBe mapOf( - "package" to mapOf( - "name" to "foo", - "version" to "1.0.0", - "readme" to "README.md", - ), - ) + ) shouldBe + mapOf( + "package" to + mapOf( + "name" to "foo", + "version" to "1.0.0", + "readme" to "README.md", + ), + ) mapOf( - "package" to mapOf( - "name" to "foo", - "version" to "1.0.0", - "overwrite-me" to "wrong", - ), + "package" to + mapOf( + "name" to "foo", + "version" to "1.0.0", + "overwrite-me" to "wrong", + ), "make-me-not-a-map" to mapOf("foo" to "bar"), ).deepMergeWith( mapOf( - "package" to mapOf( - "readme" to "README.md", - "overwrite-me" to "correct", - ), + "package" to + mapOf( + "readme" to "README.md", + "overwrite-me" to "correct", + ), "make-me-not-a-map" to 5, ), - ) shouldBe mapOf( - "package" to mapOf( - "name" to "foo", - "version" to "1.0.0", - "readme" to "README.md", - "overwrite-me" to "correct", - ), - "make-me-not-a-map" to 5, - ) + ) shouldBe + mapOf( + "package" to + mapOf( + "name" to "foo", + "version" to "1.0.0", + "readme" to "README.md", + "overwrite-me" to "correct", + ), + "make-me-not-a-map" to 5, + ) } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/util/StringsTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/util/StringsTest.kt index af4e65d450e..ab46172348d 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/util/StringsTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/util/StringsTest.kt @@ -17,14 +17,14 @@ import java.io.File import java.util.stream.Stream internal class StringsTest { - @Test fun doubleQuote() { "abc".doubleQuote() shouldBe "\"abc\"" """{"some": "json"}""".doubleQuote() shouldBe """"{\"some\": \"json\"}"""" - """{"nested": "{\"nested\": 5}"}"}""".doubleQuote() shouldBe """ - "{\"nested\": \"{\\\"nested\\\": 5}\"}\"}" - """.trimIndent().trim() + """{"nested": "{\"nested\": 5}"}"}""".doubleQuote() shouldBe + """ + "{\"nested\": \"{\\\"nested\\\": 5}\"}\"}" + """.trimIndent().trim() } @Test @@ -59,7 +59,10 @@ internal class StringsTest { @ParameterizedTest @ArgumentsSource(TestCasesProvider::class) - fun testSnakeCase(input: String, output: String) { + fun testSnakeCase( + input: String, + output: String, + ) { input.toSnakeCase() shouldBe output } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/util/SyntheticsTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/util/SyntheticsTest.kt index 270183d05e8..a21479c697d 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/util/SyntheticsTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/util/SyntheticsTest.kt @@ -15,7 +15,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel class SyntheticsTest { @Test fun `it should clone operations`() { - val model = """ + val model = + """ namespace test service TestService { @@ -33,11 +34,12 @@ class SyntheticsTest { } operation SomeOperation { input: TestInput, output: TestOutput } - """.asSmithyModel() + """.asSmithyModel() - val transformed = model.toBuilder().cloneOperation(model, ShapeId.from("test#SomeOperation")) { shapeId -> - ShapeId.fromParts(shapeId.namespace + ".cloned", shapeId.name + "Foo") - }.build() + val transformed = + model.toBuilder().cloneOperation(model, ShapeId.from("test#SomeOperation")) { shapeId -> + ShapeId.fromParts(shapeId.namespace + ".cloned", shapeId.name + "Foo") + }.build() val newOp = transformed.expectShape(ShapeId.from("test.cloned#SomeOperationFoo"), OperationShape::class.java) newOp.input.orNull() shouldBe ShapeId.from("test.cloned#TestInputFoo") @@ -56,14 +58,15 @@ class SyntheticsTest { @Test fun `it should rename structs`() { - val model = """ + val model = + """ namespace test structure SomeInput { one: String, two: String, } - """.asSmithyModel() + """.asSmithyModel() val original = model.expectShape(ShapeId.from("test#SomeInput"), StructureShape::class.java) val new = original.toBuilder().rename(ShapeId.from("new#SomeOtherInput")).build() diff --git a/codegen-server/build.gradle.kts b/codegen-server/build.gradle.kts index be9709471ed..1bd38b864f1 100644 --- a/codegen-server/build.gradle.kts +++ b/codegen-server/build.gradle.kts @@ -32,7 +32,7 @@ dependencies { testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") } -tasks.compileKotlin { kotlinOptions.jvmTarget = "1.8" } +tasks.compileKotlin { kotlinOptions.jvmTarget = "11" } // Reusable license copySpec val licenseSpec = copySpec { @@ -63,7 +63,7 @@ if (isTestingEnabled.toBoolean()) { testImplementation("io.kotest:kotest-assertions-core-jvm:$kotestVersion") } - tasks.compileTestKotlin { kotlinOptions.jvmTarget = "1.8" } + tasks.compileTestKotlin { kotlinOptions.jvmTarget = "11" } tasks.test { useJUnitPlatform() diff --git a/codegen-server/python/build.gradle.kts b/codegen-server/python/build.gradle.kts index 337d099bccb..c79ae694be5 100644 --- a/codegen-server/python/build.gradle.kts +++ b/codegen-server/python/build.gradle.kts @@ -32,7 +32,7 @@ dependencies { testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") } -tasks.compileKotlin { kotlinOptions.jvmTarget = "1.8" } +tasks.compileKotlin { kotlinOptions.jvmTarget = "11" } // Reusable license copySpec val licenseSpec = copySpec { @@ -63,7 +63,7 @@ if (isTestingEnabled.toBoolean()) { testImplementation("io.kotest:kotest-assertions-core-jvm:$kotestVersion") } - tasks.compileTestKotlin { kotlinOptions.jvmTarget = "1.8" } + tasks.compileTestKotlin { kotlinOptions.jvmTarget = "11" } tasks.test { useJUnitPlatform() diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonEventStreamSymbolProvider.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonEventStreamSymbolProvider.kt index 38d3830bace..fe642c7cc9c 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonEventStreamSymbolProvider.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonEventStreamSymbolProvider.kt @@ -58,25 +58,28 @@ class PythonEventStreamSymbolProvider( // We can only wrap the type if it's either an input or an output that used in an operation model.expectShape(shape.container).let { maybeInputOutput -> - val operationId = maybeInputOutput.getTrait()?.operation - ?: maybeInputOutput.getTrait()?.operation + val operationId = + maybeInputOutput.getTrait()?.operation + ?: maybeInputOutput.getTrait()?.operation operationId?.let { model.expectShape(it, OperationShape::class.java) } } ?: return initial val unionShape = model.expectShape(shape.target).asUnionShape().get() - val error = if (unionShape.eventStreamErrors().isEmpty()) { - RuntimeType.smithyHttp(runtimeConfig).resolve("event_stream::MessageStreamError").toSymbol() - } else { - symbolForEventStreamError(unionShape) - } + val error = + if (unionShape.eventStreamErrors().isEmpty()) { + RuntimeType.smithyHttp(runtimeConfig).resolve("event_stream::MessageStreamError").toSymbol() + } else { + symbolForEventStreamError(unionShape) + } val inner = initial.rustType().stripOuter() val innerSymbol = Symbol.builder().name(inner.name).rustType(inner).build() val containerName = shape.container.name val memberName = shape.memberName.toPascalCase() - val outer = when (shape.isOutputEventStream(model)) { - true -> "${containerName}${memberName}EventStreamSender" - else -> "${containerName}${memberName}Receiver" - } + val outer = + when (shape.isOutputEventStream(model)) { + true -> "${containerName}${memberName}EventStreamSender" + else -> "${containerName}${memberName}Receiver" + } val rustType = RustType.Opaque(outer, PythonServerRustModule.PythonEventStream.fullyQualifiedPath()) return Symbol.builder() .name(rustType.name) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCargoDependency.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCargoDependency.kt index 2782e44b94c..2e8070d15e3 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCargoDependency.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCargoDependency.kt @@ -16,16 +16,20 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig */ object PythonServerCargoDependency { val PyO3: CargoDependency = CargoDependency("pyo3", CratesIo("0.18")) - val PyO3Asyncio: CargoDependency = CargoDependency("pyo3-asyncio", CratesIo("0.18"), features = setOf("attributes", "tokio-runtime")) + val PyO3Asyncio: CargoDependency = + CargoDependency("pyo3-asyncio", CratesIo("0.18"), features = setOf("attributes", "tokio-runtime")) val Tokio: CargoDependency = CargoDependency("tokio", CratesIo("1.20.1"), features = setOf("full")) val TokioStream: CargoDependency = CargoDependency("tokio-stream", CratesIo("0.1.12")) val Tracing: CargoDependency = CargoDependency("tracing", CratesIo("0.1")) val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val TowerHttp: CargoDependency = CargoDependency("tower-http", CratesIo("0.3"), features = setOf("trace")) - val Hyper: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), features = setOf("server", "http1", "http2", "tcp", "stream")) + val Hyper: CargoDependency = + CargoDependency("hyper", CratesIo("0.14.12"), features = setOf("server", "http1", "http2", "tcp", "stream")) val NumCpus: CargoDependency = CargoDependency("num_cpus", CratesIo("1.13")) val ParkingLot: CargoDependency = CargoDependency("parking_lot", CratesIo("0.12")) fun smithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-server") - fun smithyHttpServerPython(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-server-python") + + fun smithyHttpServerPython(runtimeConfig: RuntimeConfig) = + runtimeConfig.smithyRuntimeCrate("smithy-http-server-python") } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt index 3bb954bfc22..19ea83426ac 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt @@ -60,7 +60,6 @@ class PythonServerCodegenVisitor( context: PluginContext, private val codegenDecorator: ServerCodegenDecorator, ) : ServerCodegenVisitor(context, codegenDecorator) { - init { val rustSymbolProviderConfig = RustSymbolProviderConfig( @@ -95,23 +94,26 @@ class PythonServerCodegenVisitor( publicConstrainedTypes: Boolean, includeConstraintShapeProvider: Boolean, codegenDecorator: ServerCodegenDecorator, - ) = RustServerCodegenPythonPlugin.baseSymbolProvider(settings, model, serviceShape, rustSymbolProviderConfig, publicConstrainedTypes, includeConstraintShapeProvider, codegenDecorator) + ) = + RustServerCodegenPythonPlugin.baseSymbolProvider(settings, model, serviceShape, rustSymbolProviderConfig, publicConstrainedTypes, includeConstraintShapeProvider, codegenDecorator) - val serverSymbolProviders = ServerSymbolProviders.from( - settings, - model, - service, - rustSymbolProviderConfig, - settings.codegenConfig.publicConstrainedTypes, - codegenDecorator, - ::baseSymbolProviderFactory, - ) + val serverSymbolProviders = + ServerSymbolProviders.from( + settings, + model, + service, + rustSymbolProviderConfig, + settings.codegenConfig.publicConstrainedTypes, + codegenDecorator, + ::baseSymbolProviderFactory, + ) // Override `codegenContext` which carries the various symbol providers. - val moduleDocProvider = codegenDecorator.moduleDocumentationCustomization( - codegenContext, - PythonServerModuleDocProvider(ServerModuleDocProvider(codegenContext)), - ) + val moduleDocProvider = + codegenDecorator.moduleDocumentationCustomization( + codegenContext, + PythonServerModuleDocProvider(ServerModuleDocProvider(codegenContext)), + ) codegenContext = ServerCodegenContext( model, @@ -127,12 +129,13 @@ class PythonServerCodegenVisitor( ) // Override `rustCrate` which carries the symbolProvider. - rustCrate = RustCrate( - context.fileManifest, - codegenContext.symbolProvider, - settings.codegenConfig, - codegenContext.expectModuleDocProvider(), - ) + rustCrate = + RustCrate( + context.fileManifest, + codegenContext.symbolProvider, + settings.codegenConfig, + codegenContext.expectModuleDocProvider(), + ) // Override `protocolGenerator` which carries the symbolProvider. protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext) } @@ -175,8 +178,10 @@ class PythonServerCodegenVisitor( * Although raw strings require no code generation, enums are actually [EnumTrait] applied to string shapes. */ override fun stringShape(shape: StringShape) { - fun pythonServerEnumGeneratorFactory(codegenContext: ServerCodegenContext, shape: StringShape) = - PythonServerEnumGenerator(codegenContext, shape, validationExceptionConversionGenerator) + fun pythonServerEnumGeneratorFactory( + codegenContext: ServerCodegenContext, + shape: StringShape, + ) = PythonServerEnumGenerator(codegenContext, shape, validationExceptionConversionGenerator) stringShape(shape, ::pythonServerEnumGeneratorFactory) } @@ -193,7 +198,8 @@ class PythonServerCodegenVisitor( PythonServerUnionGenerator(model, codegenContext, this, shape, renderUnknownVariant = false).render() } - if (shape.isReachableFromOperationInput() && shape.canReachConstrainedShape( + if (shape.isReachableFromOperationInput() && + shape.canReachConstrainedShape( model, codegenContext.symbolProvider, ) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerSymbolProvider.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerSymbolProvider.kt index f01e9a41c4a..70e055e23f8 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerSymbolProvider.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerSymbolProvider.kt @@ -35,8 +35,13 @@ import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolP import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings import java.util.logging.Logger -/* Returns the Python implementation of the ByteStream shape or the original symbol that is provided in input. */ -private fun toPythonByteStreamSymbolOrOriginal(model: Model, config: RustSymbolProviderConfig, initial: Symbol, shape: Shape): Symbol { +// Returns the Python implementation of the ByteStream shape or the original symbol that is provided in input. +private fun toPythonByteStreamSymbolOrOriginal( + model: Model, + config: RustSymbolProviderConfig, + initial: Symbol, + shape: Shape, +): Symbol { if (shape !is MemberShape) { return initial } @@ -75,7 +80,6 @@ class PythonServerSymbolVisitor( serviceShape: ServiceShape?, config: RustSymbolProviderConfig, ) : SymbolVisitor(settings, model, serviceShape, config) { - private val runtimeConfig = config.runtimeConfig private val logger = Logger.getLogger(javaClass.name) @@ -111,7 +115,6 @@ class PythonConstrainedShapeSymbolProvider( serviceShape: ServiceShape, publicConstrainedTypes: Boolean, ) : ConstrainedShapeSymbolProvider(base, serviceShape, publicConstrainedTypes) { - override fun toSymbol(shape: Shape): Symbol { val initial = super.toSymbol(shape) return toPythonByteStreamSymbolOrOriginal(model, config, initial, shape) @@ -127,7 +130,6 @@ class PythonConstrainedShapeSymbolProvider( * Note that since streaming members can only be used on the root shape, this can only impact input and output shapes. */ class PythonStreamingShapeMetadataProvider(private val base: RustSymbolProvider) : SymbolMetadataProvider(base) { - override fun structureMeta(structureShape: StructureShape): RustMetadata { val baseMetadata = base.toSymbol(structureShape).expectRustMetadata() return if (structureShape.hasStreamingMember(model)) { @@ -147,10 +149,16 @@ class PythonStreamingShapeMetadataProvider(private val base: RustSymbolProvider) } override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() + override fun enumMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata() + override fun listMeta(listShape: ListShape) = base.toSymbol(listShape).expectRustMetadata() + override fun mapMeta(mapShape: MapShape) = base.toSymbol(mapShape).expectRustMetadata() + override fun stringMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata() + override fun numberMeta(numberShape: NumberShape) = base.toSymbol(numberShape).expectRustMetadata() + override fun blobMeta(blobShape: BlobShape) = base.toSymbol(blobShape).expectRustMetadata() } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonType.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonType.kt index 8cdc481b395..e22786dd5f0 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonType.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonType.kt @@ -101,23 +101,23 @@ sealed class PythonType { } data class Opaque(override val name: String, val pythonRootModuleName: String, val rustNamespace: String? = null) : PythonType() { - - override val namespace: String? = rustNamespace?.split("::")?.joinToString(".") { - when (it) { - "crate" -> pythonRootModuleName - // In Python, we expose submodules from `aws_smithy_http_server_python` - // like `types`, `middleware`, `tls` etc. from Python root module - "aws_smithy_http_server_python" -> pythonRootModuleName - else -> it - } - } - // Most opaque types have a leading `::`, so strip that for Python as needed - .let { - when (it?.startsWith(".")) { - true -> it.substring(1) + override val namespace: String? = + rustNamespace?.split("::")?.joinToString(".") { + when (it) { + "crate" -> pythonRootModuleName + // In Python, we expose submodules from `aws_smithy_http_server_python` + // like `types`, `middleware`, `tls` etc. from Python root module + "aws_smithy_http_server_python" -> pythonRootModuleName else -> it } } + // Most opaque types have a leading `::`, so strip that for Python as needed + .let { + when (it?.startsWith(".")) { + true -> it.substring(1) + else -> it + } + } } } @@ -139,12 +139,13 @@ fun RustType.pythonType(pythonRootModuleName: String): PythonType = is RustType.Option -> PythonType.Optional(this.member.pythonType(pythonRootModuleName)) is RustType.Box -> this.member.pythonType(pythonRootModuleName) is RustType.Dyn -> this.member.pythonType(pythonRootModuleName) - is RustType.Application -> PythonType.Application( - this.type.pythonType(pythonRootModuleName), - this.args.map { - it.pythonType(pythonRootModuleName) - }, - ) + is RustType.Application -> + PythonType.Application( + this.type.pythonType(pythonRootModuleName), + this.args.map { + it.pythonType(pythonRootModuleName) + }, + ) is RustType.Opaque -> PythonType.Opaque(this.name, pythonRootModuleName, rustNamespace = this.namespace) is RustType.MaybeConstrained -> this.member.pythonType(pythonRootModuleName) } @@ -154,39 +155,41 @@ fun RustType.pythonType(pythonRootModuleName: String): PythonType = * It generates something like `typing.Dict[String, String]`. */ fun PythonType.render(fullyQualified: Boolean = true): String { - val namespace = if (fullyQualified) { - this.namespace?.let { "$it." } ?: "" - } else { - "" - } - val base = when (this) { - is PythonType.None -> this.name - is PythonType.Bool -> this.name - is PythonType.Float -> this.name - is PythonType.Int -> this.name - is PythonType.Str -> this.name - is PythonType.Any -> this.name - is PythonType.Opaque -> this.name - is PythonType.List -> "${this.name}[${this.member.render(fullyQualified)}]" - is PythonType.Dict -> "${this.name}[${this.key.render(fullyQualified)}, ${this.member.render(fullyQualified)}]" - is PythonType.Set -> "${this.name}[${this.member.render(fullyQualified)}]" - is PythonType.Awaitable -> "${this.name}[${this.member.render(fullyQualified)}]" - is PythonType.Optional -> "${this.name}[${this.member.render(fullyQualified)}]" - is PythonType.AsyncIterator -> "${this.name}[${this.member.render(fullyQualified)}]" - is PythonType.Application -> { - val args = this.args.joinToString(", ") { it.render(fullyQualified) } - "${this.name}[$args]" - } - is PythonType.Callable -> { - val args = this.args.joinToString(", ") { it.render(fullyQualified) } - val rtype = this.rtype.render(fullyQualified) - "${this.name}[[$args], $rtype]" + val namespace = + if (fullyQualified) { + this.namespace?.let { "$it." } ?: "" + } else { + "" } - is PythonType.Union -> { - val args = this.args.joinToString(", ") { it.render(fullyQualified) } - "${this.name}[$args]" + val base = + when (this) { + is PythonType.None -> this.name + is PythonType.Bool -> this.name + is PythonType.Float -> this.name + is PythonType.Int -> this.name + is PythonType.Str -> this.name + is PythonType.Any -> this.name + is PythonType.Opaque -> this.name + is PythonType.List -> "${this.name}[${this.member.render(fullyQualified)}]" + is PythonType.Dict -> "${this.name}[${this.key.render(fullyQualified)}, ${this.member.render(fullyQualified)}]" + is PythonType.Set -> "${this.name}[${this.member.render(fullyQualified)}]" + is PythonType.Awaitable -> "${this.name}[${this.member.render(fullyQualified)}]" + is PythonType.Optional -> "${this.name}[${this.member.render(fullyQualified)}]" + is PythonType.AsyncIterator -> "${this.name}[${this.member.render(fullyQualified)}]" + is PythonType.Application -> { + val args = this.args.joinToString(", ") { it.render(fullyQualified) } + "${this.name}[$args]" + } + is PythonType.Callable -> { + val args = this.args.joinToString(", ") { it.render(fullyQualified) } + val rtype = this.rtype.render(fullyQualified) + "${this.name}[[$args], $rtype]" + } + is PythonType.Union -> { + val args = this.args.joinToString(", ") { it.render(fullyQualified) } + "${this.name}[$args]" + } } - } return "$namespace$base" } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/RustServerCodegenPythonPlugin.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/RustServerCodegenPythonPlugin.kt index a2c9ab45de6..735d6901571 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/RustServerCodegenPythonPlugin.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/RustServerCodegenPythonPlugin.kt @@ -75,8 +75,7 @@ class RustServerCodegenPythonPlugin : SmithyBuildPlugin { constrainedTypes: Boolean = true, includeConstrainedShapeProvider: Boolean = true, codegenDecorator: ServerCodegenDecorator, - ) = - // Rename a set of symbols that do not implement `PyClass` and have been wrapped in + ) = // Rename a set of symbols that do not implement `PyClass` and have been wrapped in // `aws_smithy_http_server_python::types`. PythonServerSymbolVisitor(settings, model, serviceShape = serviceShape, config = rustSymbolProviderConfig) // Generate public constrained types for directly constrained shapes. diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt index daa4a7b4dd4..ab103ff6394 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt @@ -35,15 +35,14 @@ class CdylibManifestDecorator : ServerCodegenDecorator { override val name: String = "CdylibDecorator" override val order: Byte = 0 - override fun crateManifestCustomizations( - codegenContext: ServerCodegenContext, - ): ManifestCustomizations = + override fun crateManifestCustomizations(codegenContext: ServerCodegenContext): ManifestCustomizations = mapOf( - "lib" to mapOf( - // Library target names cannot contain hyphen names. - "name" to codegenContext.settings.moduleName.toSnakeCase(), - "crate-type" to listOf("cdylib"), - ), + "lib" to + mapOf( + // Library target names cannot contain hyphen names. + "name" to codegenContext.settings.moduleName.toSnakeCase(), + "crate-type" to listOf("cdylib"), + ), ) } @@ -53,14 +52,15 @@ class CdylibManifestDecorator : ServerCodegenDecorator { class PubUsePythonTypes(private val codegenContext: ServerCodegenContext) : LibRsCustomization() { override fun section(section: LibRsSection): Writable { return when (section) { - is LibRsSection.Body -> writable { - docs("Re-exported Python types from supporting crates.") - rustBlock("pub mod python_types") { - rust("pub use #T;", PythonServerRuntimeType.blob(codegenContext.runtimeConfig).toSymbol()) - rust("pub use #T;", PythonServerRuntimeType.dateTime(codegenContext.runtimeConfig).toSymbol()) - rust("pub use #T;", PythonServerRuntimeType.document(codegenContext.runtimeConfig).toSymbol()) + is LibRsSection.Body -> + writable { + docs("Re-exported Python types from supporting crates.") + rustBlock("pub mod python_types") { + rust("pub use #T;", PythonServerRuntimeType.blob(codegenContext.runtimeConfig).toSymbol()) + rust("pub use #T;", PythonServerRuntimeType.dateTime(codegenContext.runtimeConfig).toSymbol()) + rust("pub use #T;", PythonServerRuntimeType.document(codegenContext.runtimeConfig).toSymbol()) + } } - } else -> emptySection } } @@ -73,7 +73,10 @@ class PythonExportModuleDecorator : ServerCodegenDecorator { override val name: String = "PythonExportModuleDecorator" override val order: Byte = 0 - override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ServerCodegenContext, + rustCrate: RustCrate, + ) { val service = codegenContext.settings.getService(codegenContext.model) val serviceShapes = DirectedWalker(codegenContext.model).walkShapes(service) PythonServerModuleGenerator(codegenContext, rustCrate, serviceShapes).render() @@ -104,19 +107,26 @@ class PyProjectTomlDecorator : ServerCodegenDecorator { override val name: String = "PyProjectTomlDecorator" override val order: Byte = 0 - override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ServerCodegenContext, + rustCrate: RustCrate, + ) { rustCrate.withFile("pyproject.toml") { - val config = mapOf( - "build-system" to listOfNotNull( - "requires" to listOfNotNull("maturin>=0.14,<0.15"), - "build-backend" to "maturin", - ).toMap(), - "tool" to listOfNotNull( - "maturin" to listOfNotNull( - "python-source" to "python", - ).toMap(), - ).toMap(), - ) + val config = + mapOf( + "build-system" to + listOfNotNull( + "requires" to listOfNotNull("maturin>=0.14,<0.15"), + "build-backend" to "maturin", + ).toMap(), + "tool" to + listOfNotNull( + "maturin" to + listOfNotNull( + "python-source" to "python", + ).toMap(), + ).toMap(), + ) writeWithNoFormatting(TomlWriter().write(config)) } } @@ -134,7 +144,10 @@ class PyO3ExtensionModuleDecorator : ServerCodegenDecorator { override val name: String = "PyO3ExtensionModuleDecorator" override val order: Byte = 0 - override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ServerCodegenContext, + rustCrate: RustCrate, + ) { // Add `pyo3/extension-module` to default features. rustCrate.mergeFeature(Feature("extension-module", true, listOf("pyo3/extension-module"))) } @@ -157,7 +170,10 @@ class InitPyDecorator : ServerCodegenDecorator { override val name: String = "InitPyDecorator" override val order: Byte = 0 - override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ServerCodegenContext, + rustCrate: RustCrate, + ) { val libName = codegenContext.settings.moduleName.toSnakeCase() rustCrate.withFile("python/$libName/__init__.py") { @@ -185,7 +201,10 @@ class PyTypedMarkerDecorator : ServerCodegenDecorator { override val name: String = "PyTypedMarkerDecorator" override val order: Byte = 0 - override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ServerCodegenContext, + rustCrate: RustCrate, + ) { val libName = codegenContext.settings.moduleName.toSnakeCase() rustCrate.withFile("python/$libName/py.typed") { @@ -204,7 +223,10 @@ class AddStubgenScriptDecorator : ServerCodegenDecorator { override val name: String = "AddStubgenScriptDecorator" override val order: Byte = 0 - override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ServerCodegenContext, + rustCrate: RustCrate, + ) { val stubgenPythonContent = this::class.java.getResource("/stubgen.py").readText() rustCrate.withFile("stubgen.py") { writeWithNoFormatting("$stubgenPythonContent") @@ -216,26 +238,27 @@ class AddStubgenScriptDecorator : ServerCodegenDecorator { } } -val DECORATORS = arrayOf( - /** - * Add the [InternalServerError] error to all operations. - * This is done because the Python interpreter can raise exceptions during execution. - */ - AddInternalServerErrorToAllOperationsDecorator(), - // Add the [lib] section to Cargo.toml to configure the generation of the shared library. - CdylibManifestDecorator(), - // Add `pub use` of `aws_smithy_http_server_python::types`. - PubUsePythonTypesDecorator(), - // Render the Python shared library export. - PythonExportModuleDecorator(), - // Generate `pyproject.toml` for the crate. - PyProjectTomlDecorator(), - // Add PyO3 extension module feature. - PyO3ExtensionModuleDecorator(), - // Generate `__init__.py` for the Python source. - InitPyDecorator(), - // Generate `py.typed` for the Python source. - PyTypedMarkerDecorator(), - // Generate scripts for stub generation. - AddStubgenScriptDecorator(), -) +val DECORATORS = + arrayOf( + /* + * Add the [InternalServerError] error to all operations. + * This is done because the Python interpreter can raise exceptions during execution. + */ + AddInternalServerErrorToAllOperationsDecorator(), + // Add the [lib] section to Cargo.toml to configure the generation of the shared library. + CdylibManifestDecorator(), + // Add `pub use` of `aws_smithy_http_server_python::types`. + PubUsePythonTypesDecorator(), + // Render the Python shared library export. + PythonExportModuleDecorator(), + // Generate `pyproject.toml` for the crate. + PyProjectTomlDecorator(), + // Add PyO3 extension module feature. + PyO3ExtensionModuleDecorator(), + // Generate `__init__.py` for the Python source. + InitPyDecorator(), + // Generate `py.typed` for the Python source. + PyTypedMarkerDecorator(), + // Generate scripts for stub generation. + AddStubgenScriptDecorator(), + ) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/ConstrainedPythonBlobGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/ConstrainedPythonBlobGenerator.kt index a9c2021520e..d25c1103592 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/ConstrainedPythonBlobGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/ConstrainedPythonBlobGenerator.kt @@ -45,9 +45,10 @@ class ConstrainedPythonBlobGenerator( } } val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) - private val blobConstraintsInfo: List = listOf(LengthTrait::class.java) - .mapNotNull { shape.getTrait(it).orNull() } - .map { BlobLength(it) } + private val blobConstraintsInfo: List = + listOf(LengthTrait::class.java) + .mapNotNull { shape.getTrait(it).orNull() } + .map { BlobLength(it) } private val constraintsInfo: List = blobConstraintsInfo.map { it.toTraitInfo() } fun render() { @@ -57,7 +58,10 @@ class ConstrainedPythonBlobGenerator( renderTryFrom(symbol, blobType) } - fun renderFrom(symbol: Symbol, blobType: RustType) { + fun renderFrom( + symbol: Symbol, + blobType: RustType, + ) { val name = symbol.name val inner = blobType.render() writer.rustTemplate( @@ -79,7 +83,10 @@ class ConstrainedPythonBlobGenerator( ) } - fun renderTryFrom(symbol: Symbol, blobType: RustType) { + fun renderTryFrom( + symbol: Symbol, + blobType: RustType, + ) { val name = symbol.name val inner = blobType.render() writer.rustTemplate( diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 9b05d22434a..84ef0bdbc55 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -70,11 +70,12 @@ class PythonApplicationGenerator( private val protocol: ServerProtocol, ) { private val index = TopDownIndex.of(codegenContext.model) - private val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet( - compareBy { - it.id - }, - ).toList() + private val operations = + index.getContainedOperations(codegenContext.serviceShape).toSortedSet( + compareBy { + it.id + }, + ).toList() private val symbolProvider = codegenContext.symbolProvider private val libName = codegenContext.settings.moduleName.toSnakeCase() private val runtimeConfig = codegenContext.runtimeConfig @@ -348,18 +349,19 @@ class PythonApplicationGenerator( val output = PythonType.Opaque("${operationName}Output", libName, rustNamespace = "crate::output") val context = PythonType.Opaque("Ctx", libName) val returnType = PythonType.Union(listOf(output, PythonType.Awaitable(output))) - val handler = PythonType.Union( - listOf( - PythonType.Callable( - listOf(input, context), - returnType, + val handler = + PythonType.Union( + listOf( + PythonType.Callable( + listOf(input, context), + returnType, + ), + PythonType.Callable( + listOf(input), + returnType, + ), ), - PythonType.Callable( - listOf(input), - returnType, - ), - ), - ) + ) rustTemplate( """ @@ -435,21 +437,23 @@ class PythonApplicationGenerator( ) } - private fun RustWriter.operationImplementationStubs(operations: List) = rust( - operations.joinToString("\n///\n") { - val operationDocumentation = it.getTrait()?.value - val ret = if (!operationDocumentation.isNullOrBlank()) { - operationDocumentation.replace("#", "##").prependIndent("/// ## ") + "\n" - } else { - "" - } - ret + - """ - /// ${it.signature()}: - /// raise NotImplementedError - """.trimIndent() - }, - ) + private fun RustWriter.operationImplementationStubs(operations: List) = + rust( + operations.joinToString("\n///\n") { + val operationDocumentation = it.getTrait()?.value + val ret = + if (!operationDocumentation.isNullOrBlank()) { + operationDocumentation.replace("#", "##").prependIndent("/// ## ") + "\n" + } else { + "" + } + ret + + """ + /// ${it.signature()}: + /// raise NotImplementedError + """.trimIndent() + }, + ) /** * Returns the function signature for an operation handler implementation. Used in the documentation. diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt index ac12cc0df37..a7bcdd56d5e 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt @@ -35,27 +35,28 @@ class PythonConstrainedEnum( override fun additionalEnumAttributes(context: EnumGeneratorContext): List = listOf(Attribute(pyO3.resolve("pyclass"))) - override fun additionalEnumImpls(context: EnumGeneratorContext): Writable = writable { - Attribute(pyO3.resolve("pymethods")).render(this) - rustTemplate( - """ - impl ${context.enumName} { - #{name_method:W} - ##[getter] - pub fn value(&self) -> &str { - self.as_str() - } - fn __repr__(&self) -> String { - self.as_str().to_owned() - } - fn __str__(&self) -> String { - self.as_str().to_owned() + override fun additionalEnumImpls(context: EnumGeneratorContext): Writable = + writable { + Attribute(pyO3.resolve("pymethods")).render(this) + rustTemplate( + """ + impl ${context.enumName} { + #{name_method:W} + ##[getter] + pub fn value(&self) -> &str { + self.as_str() + } + fn __repr__(&self) -> String { + self.as_str().to_owned() + } + fn __str__(&self) -> String { + self.as_str().to_owned() + } } - } - """, - "name_method" to pyEnumName(context), - ) - } + """, + "name_method" to pyEnumName(context), + ) + } private fun pyEnumName(context: EnumGeneratorContext): Writable = writable { @@ -80,8 +81,8 @@ class PythonServerEnumGenerator( shape: StringShape, validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, ) : EnumGenerator( - codegenContext.model, - codegenContext.symbolProvider, - shape, - PythonConstrainedEnum(codegenContext, shape, validationExceptionConversionGenerator), -) + codegenContext.model, + codegenContext.symbolProvider, + shape, + PythonConstrainedEnum(codegenContext, shape, validationExceptionConversionGenerator), + ) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEventStreamErrorGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEventStreamErrorGenerator.kt index b915e953cd9..118aa8a76c8 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEventStreamErrorGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEventStreamErrorGenerator.kt @@ -34,14 +34,15 @@ class PythonServerEventStreamErrorGenerator( private val symbolProvider: RustSymbolProvider, val shape: UnionShape, ) : ServerOperationErrorGenerator( - model, - symbolProvider, - shape, -) { + model, + symbolProvider, + shape, + ) { private val errorSymbol = symbolProvider.symbolForEventStreamError(shape) - private val errors = shape.eventStreamErrors().map { - model.expectShape(it.asMemberShape().get().target, StructureShape::class.java) - } + private val errors = + shape.eventStreamErrors().map { + model.expectShape(it.asMemberShape().get().target, StructureShape::class.java) + } private val pyO3 = PythonServerCargoDependency.PyO3.toType() diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt index cbec0b8551c..6edb2072fd2 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt @@ -28,10 +28,11 @@ class PythonServerModuleGenerator( private val rustCrate: RustCrate, private val serviceShapes: Set, ) { - private val codegenScope = arrayOf( - "SmithyPython" to PythonServerCargoDependency.smithyHttpServerPython(codegenContext.runtimeConfig).toType(), - "pyo3" to PythonServerCargoDependency.PyO3.toType(), - ) + private val codegenScope = + arrayOf( + "SmithyPython" to PythonServerCargoDependency.smithyHttpServerPython(codegenContext.runtimeConfig).toType(), + "pyo3" to PythonServerCargoDependency.PyO3.toType(), + ) private val symbolProvider = codegenContext.symbolProvider private val libName = codegenContext.settings.moduleName.toSnakeCase() @@ -84,18 +85,20 @@ class PythonServerModuleGenerator( visitedModelType = true } when (shape) { - is UnionShape -> rustTemplate( - """ - $moduleType.add_class::()?; - """, - *codegenScope, - ) - else -> rustTemplate( - """ - $moduleType.add_class::()?; - """, - *codegenScope, - ) + is UnionShape -> + rustTemplate( + """ + $moduleType.add_class::()?; + """, + *codegenScope, + ) + else -> + rustTemplate( + """ + $moduleType.add_class::()?; + """, + *codegenScope, + ) } } } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt index 53b82c2d293..c77798f88bf 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt @@ -82,7 +82,10 @@ class PythonServerOperationHandlerGenerator( ) } - private fun renderPyFunction(name: String, output: String): Writable = + private fun renderPyFunction( + name: String, + output: String, + ): Writable = writable { rustTemplate( """ @@ -101,7 +104,10 @@ class PythonServerOperationHandlerGenerator( ) } - private fun renderPyCoroutine(name: String, output: String): Writable = + private fun renderPyCoroutine( + name: String, + output: String, + ): Writable = writable { rustTemplate( """ diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerStructureGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerStructureGenerator.kt index 9357383e6b2..3a78b0ddd05 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerStructureGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerStructureGenerator.kt @@ -41,7 +41,6 @@ class PythonServerStructureGenerator( private val writer: RustWriter, private val shape: StructureShape, ) : StructureGenerator(model, codegenContext.symbolProvider, writer, shape, emptyList(), codegenContext.structSettings()) { - private val symbolProvider = codegenContext.symbolProvider private val libName = codegenContext.settings.moduleName.toSnakeCase() private val pyO3 = PythonServerCargoDependency.PyO3.toType() @@ -151,13 +150,19 @@ class PythonServerStructureGenerator( rust("/// :rtype ${PythonType.None.renderAsDocstring()}:") } - private fun renderMemberSignature(shape: MemberShape, symbol: Symbol): Writable = + private fun renderMemberSignature( + shape: MemberShape, + symbol: Symbol, + ): Writable = writable { val pythonType = memberPythonType(shape, symbol) rust("/// :type ${pythonType.renderAsDocstring()}:") } - private fun memberPythonType(shape: MemberShape, symbol: Symbol): PythonType = + private fun memberPythonType( + shape: MemberShape, + symbol: Symbol, + ): PythonType = if (shape.isEventStream(model)) { val eventStreamSymbol = PythonEventStreamSymbolProvider.parseSymbol(symbol) val innerT = eventStreamSymbol.innerT.pythonType(libName) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerUnionGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerUnionGenerator.kt index 01a2a833afe..ef5a043a505 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerUnionGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerUnionGenerator.kt @@ -179,11 +179,12 @@ class PythonServerUnionGenerator( ) writer.rust("/// :rtype ${pythonType.renderAsDocstring()}:") writer.rustBlockTemplate("pub fn as_$funcNamePart(&self) -> #{pyo3}::PyResult<${rustType.render()}>", "pyo3" to pyo3) { - val variantType = if (rustType.isCopy()) { - "*variant" - } else { - "variant.clone()" - } + val variantType = + if (rustType.isCopy()) { + "*variant" + } else { + "variant.clone()" + } val errorVariant = memberSymbol.rustType().pythonType(libName).renderAsDocstring() rustTemplate( """ diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt index dbfa9e029dc..4e41cc8ed37 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt @@ -37,18 +37,22 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJso */ class PythonServerAfterDeserializedMemberJsonParserCustomization(private val runtimeConfig: RuntimeConfig) : JsonParserCustomization() { - override fun section(section: JsonParserSection): Writable = when (section) { - is JsonParserSection.AfterTimestampDeserializedMember -> writable { - rust(".map(#T::from)", PythonServerRuntimeType.dateTime(runtimeConfig).toSymbol()) + override fun section(section: JsonParserSection): Writable = + when (section) { + is JsonParserSection.AfterTimestampDeserializedMember -> + writable { + rust(".map(#T::from)", PythonServerRuntimeType.dateTime(runtimeConfig).toSymbol()) + } + is JsonParserSection.AfterBlobDeserializedMember -> + writable { + rust(".map(#T::from)", PythonServerRuntimeType.blob(runtimeConfig).toSymbol()) + } + is JsonParserSection.AfterDocumentDeserializedMember -> + writable { + rust(".map(#T::from)", PythonServerRuntimeType.document(runtimeConfig).toSymbol()) + } + else -> emptySection } - is JsonParserSection.AfterBlobDeserializedMember -> writable { - rust(".map(#T::from)", PythonServerRuntimeType.blob(runtimeConfig).toSymbol()) - } - is JsonParserSection.AfterDocumentDeserializedMember -> writable { - rust(".map(#T::from)", PythonServerRuntimeType.document(runtimeConfig).toSymbol()) - } - else -> emptySection - } } /** @@ -57,13 +61,15 @@ class PythonServerAfterDeserializedMemberJsonParserCustomization(private val run */ class PythonServerAfterDeserializedMemberServerHttpBoundCustomization() : ServerHttpBoundProtocolCustomization() { - override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { - is ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember -> writable { - rust(".into()") - } + override fun section(section: ServerHttpBoundProtocolSection): Writable = + when (section) { + is ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember -> + writable { + rust(".into()") + } - else -> emptySection - } + else -> emptySection + } } /** @@ -71,12 +77,14 @@ class PythonServerAfterDeserializedMemberServerHttpBoundCustomization() : */ class PythonServerAfterDeserializedMemberHttpBindingCustomization(private val runtimeConfig: RuntimeConfig) : HttpBindingCustomization() { - override fun section(section: HttpBindingSection): Writable = when (section) { - is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders -> writable { - rust(".into_iter().map(#T::from).collect()", PythonServerRuntimeType.dateTime(runtimeConfig).toSymbol()) + override fun section(section: HttpBindingSection): Writable = + when (section) { + is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders -> + writable { + rust(".into_iter().map(#T::from).collect()", PythonServerRuntimeType.dateTime(runtimeConfig).toSymbol()) + } + else -> emptySection } - else -> emptySection - } } /** @@ -87,60 +95,73 @@ class PythonServerAfterDeserializedMemberHttpBindingCustomization(private val ru * `aws_smithy_http_server_python::types::ByteStream` which already implements the `Stream` trait. */ class PythonServerStreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() { - override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { - is ServerHttpBoundProtocolSection.WrapStreamPayload -> writable { - section.params.payloadGenerator.generatePayload(this, section.params.shapeName, section.params.shape) - } + override fun section(section: ServerHttpBoundProtocolSection): Writable = + when (section) { + is ServerHttpBoundProtocolSection.WrapStreamPayload -> + writable { + section.params.payloadGenerator.generatePayload(this, section.params.shapeName, section.params.shape) + } - else -> emptySection - } + else -> emptySection + } } class PythonServerProtocolLoader( private val supportedProtocols: ProtocolMap, ) : ProtocolLoader(supportedProtocols) { - companion object { fun defaultProtocols(runtimeConfig: RuntimeConfig) = mapOf( - RestJson1Trait.ID to ServerRestJsonFactory( - additionalParserCustomizations = listOf( - PythonServerAfterDeserializedMemberJsonParserCustomization(runtimeConfig), - ), - additionalServerHttpBoundProtocolCustomizations = listOf( - PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), - PythonServerStreamPayloadSerializerCustomization(), - ), - additionalHttpBindingCustomizations = listOf( - PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), - ), - ), - AwsJson1_0Trait.ID to ServerAwsJsonFactory( - AwsJsonVersion.Json10, - additionalParserCustomizations = listOf( - PythonServerAfterDeserializedMemberJsonParserCustomization(runtimeConfig), - ), - additionalServerHttpBoundProtocolCustomizations = listOf( - PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), - PythonServerStreamPayloadSerializerCustomization(), - ), - additionalHttpBindingCustomizations = listOf( - PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), - ), - ), - AwsJson1_1Trait.ID to ServerAwsJsonFactory( - AwsJsonVersion.Json11, - additionalParserCustomizations = listOf( - PythonServerAfterDeserializedMemberJsonParserCustomization(runtimeConfig), + RestJson1Trait.ID to + ServerRestJsonFactory( + additionalParserCustomizations = + listOf( + PythonServerAfterDeserializedMemberJsonParserCustomization(runtimeConfig), + ), + additionalServerHttpBoundProtocolCustomizations = + listOf( + PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), + PythonServerStreamPayloadSerializerCustomization(), + ), + additionalHttpBindingCustomizations = + listOf( + PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), + ), ), - additionalServerHttpBoundProtocolCustomizations = listOf( - PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), - PythonServerStreamPayloadSerializerCustomization(), + AwsJson1_0Trait.ID to + ServerAwsJsonFactory( + AwsJsonVersion.Json10, + additionalParserCustomizations = + listOf( + PythonServerAfterDeserializedMemberJsonParserCustomization(runtimeConfig), + ), + additionalServerHttpBoundProtocolCustomizations = + listOf( + PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), + PythonServerStreamPayloadSerializerCustomization(), + ), + additionalHttpBindingCustomizations = + listOf( + PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), + ), ), - additionalHttpBindingCustomizations = listOf( - PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), + AwsJson1_1Trait.ID to + ServerAwsJsonFactory( + AwsJsonVersion.Json11, + additionalParserCustomizations = + listOf( + PythonServerAfterDeserializedMemberJsonParserCustomization(runtimeConfig), + ), + additionalServerHttpBoundProtocolCustomizations = + listOf( + PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), + PythonServerStreamPayloadSerializerCustomization(), + ), + additionalHttpBindingCustomizations = + listOf( + PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), + ), ), - ), ) } } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/testutil/PythonServerTestHelpers.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/testutil/PythonServerTestHelpers.kt index c7e6023f3cc..3f4199c4fa4 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/testutil/PythonServerTestHelpers.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/testutil/PythonServerTestHelpers.kt @@ -21,10 +21,9 @@ import java.io.File import java.nio.file.Path val TestRuntimeConfig = - RuntimeConfig(runtimeCrateLocation = RuntimeCrateLocation.Path(File("../../rust-runtime").absolutePath)) + RuntimeConfig(runtimeCrateLocation = RuntimeCrateLocation.path(File("../../rust-runtime").absolutePath)) -fun generatePythonServerPluginContext(model: Model) = - generatePluginContext(model, runtimeConfig = TestRuntimeConfig) +fun generatePythonServerPluginContext(model: Model) = generatePluginContext(model, runtimeConfig = TestRuntimeConfig) fun executePythonServerCodegenVisitor(pluginCtx: PluginContext) { val codegenDecorator = diff --git a/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerSymbolProviderTest.kt b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerSymbolProviderTest.kt index e3ab955a9f1..3f0922f49e4 100644 --- a/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerSymbolProviderTest.kt +++ b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerSymbolProviderTest.kt @@ -21,7 +21,8 @@ internal class PythonServerSymbolProviderTest { @Test fun `python symbol provider rewrites timestamp shape symbol`() { - val model = """ + val model = + """ namespace test structure TimestampStruct { @@ -45,7 +46,7 @@ internal class PythonServerSymbolProviderTest { key: String, value: Timestamp } - """.asSmithyModel() + """.asSmithyModel() val provider = PythonServerSymbolVisitor(serverTestRustSettings(), model, null, ServerTestRustSymbolProviderConfig) @@ -72,7 +73,8 @@ internal class PythonServerSymbolProviderTest { @Test fun `python symbol provider rewrites blob shape symbol`() { - val model = """ + val model = + """ namespace test structure BlobStruct { @@ -96,7 +98,7 @@ internal class PythonServerSymbolProviderTest { key: String, value: Blob } - """.asSmithyModel() + """.asSmithyModel() val provider = PythonServerSymbolVisitor(serverTestRustSettings(), model, null, ServerTestRustSymbolProviderConfig) diff --git a/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerTypesTest.kt b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerTypesTest.kt index 9c38d343958..44b83c31326 100644 --- a/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerTypesTest.kt +++ b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerTypesTest.kt @@ -19,7 +19,8 @@ import kotlin.io.path.appendText internal class PythonServerTypesTest { @Test fun `document type`() { - val model = """ + val model = + """ namespace test use aws.protocols#restJson1 @@ -44,50 +45,51 @@ internal class PythonServerTypesTest { structure EchoOutput { value: Document, } - """.asSmithyModel() + """.asSmithyModel() val (pluginCtx, testDir) = generatePythonServerPluginContext(model) executePythonServerCodegenVisitor(pluginCtx) - val testCases = listOf( - Pair( - """ { "value": 42 } """, - """ - assert input.value == 42 - output = EchoOutput(value=input.value) - """, - ), - Pair( - """ { "value": "foobar" } """, - """ - assert input.value == "foobar" - output = EchoOutput(value=input.value) - """, - ), - Pair( - """ - { - "value": [ - true, - false, - 42, - 42.0, - -42, - { - "nested": "value" - }, - { - "nested": [1, 2, 3] - } - ] - } - """, - """ - assert input.value == [True, False, 42, 42.0, -42, {"nested": "value"}, {"nested": [1, 2, 3]}] - output = EchoOutput(value=input.value) - """, - ), - ) + val testCases = + listOf( + Pair( + """ { "value": 42 } """, + """ + assert input.value == 42 + output = EchoOutput(value=input.value) + """, + ), + Pair( + """ { "value": "foobar" } """, + """ + assert input.value == "foobar" + output = EchoOutput(value=input.value) + """, + ), + Pair( + """ + { + "value": [ + true, + false, + 42, + 42.0, + -42, + { + "nested": "value" + }, + { + "nested": [1, 2, 3] + } + ] + } + """, + """ + assert input.value == [True, False, 42, 42.0, -42, {"nested": "value"}, {"nested": [1, 2, 3]}] + output = EchoOutput(value=input.value) + """, + ), + ) val writer = RustWriter.forModule("service") writer.tokioTest("document_type") { @@ -147,7 +149,8 @@ internal class PythonServerTypesTest { @Test fun `timestamp type`() { - val model = """ + val model = + """ namespace test use aws.protocols#restJson1 @@ -178,7 +181,7 @@ internal class PythonServerTypesTest { value: Timestamp, opt_value: Timestamp, } - """.asSmithyModel() + """.asSmithyModel() val (pluginCtx, testDir) = generatePythonServerPluginContext(model) executePythonServerCodegenVisitor(pluginCtx) diff --git a/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonTypeInformationGenerationTest.kt b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonTypeInformationGenerationTest.kt index c552d32eddb..9a7bbc1ed79 100644 --- a/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonTypeInformationGenerationTest.kt +++ b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonTypeInformationGenerationTest.kt @@ -16,7 +16,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCode internal class PythonTypeInformationGenerationTest { @Test fun `generates python type information`() { - val model = """ + val model = + """ namespace test structure Foo { @@ -24,7 +25,7 @@ internal class PythonTypeInformationGenerationTest { bar: String, baz: Integer } - """.asSmithyModel() + """.asSmithyModel() val foo = model.lookup("test#Foo") val codegenContext = serverTestCodegenContext(model) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolMetadataProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolMetadataProvider.kt index df933eb657d..f100fb55704 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolMetadataProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolMetadataProvider.kt @@ -31,10 +31,12 @@ class ConstrainedShapeSymbolMetadataProvider( private val base: RustSymbolProvider, private val constrainedTypes: Boolean, ) : SymbolMetadataProvider(base) { - override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() + override fun structureMeta(structureShape: StructureShape) = base.toSymbol(structureShape).expectRustMetadata() + override fun unionMeta(unionShape: UnionShape) = base.toSymbol(unionShape).expectRustMetadata() + override fun enumMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata() private fun addDerivesAndAdjustVisibilityIfConstrained(shape: Shape): RustMetadata { @@ -58,8 +60,14 @@ class ConstrainedShapeSymbolMetadataProvider( } override fun listMeta(listShape: ListShape): RustMetadata = addDerivesAndAdjustVisibilityIfConstrained(listShape) + override fun mapMeta(mapShape: MapShape): RustMetadata = addDerivesAndAdjustVisibilityIfConstrained(mapShape) - override fun stringMeta(stringShape: StringShape): RustMetadata = addDerivesAndAdjustVisibilityIfConstrained(stringShape) - override fun numberMeta(numberShape: NumberShape): RustMetadata = addDerivesAndAdjustVisibilityIfConstrained(numberShape) + + override fun stringMeta(stringShape: StringShape): RustMetadata = + addDerivesAndAdjustVisibilityIfConstrained(stringShape) + + override fun numberMeta(numberShape: NumberShape): RustMetadata = + addDerivesAndAdjustVisibilityIfConstrained(numberShape) + override fun blobMeta(blobShape: BlobShape) = addDerivesAndAdjustVisibilityIfConstrained(blobShape) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt index 0dfda68d1eb..4d4d9db21a5 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt @@ -144,8 +144,9 @@ open class ConstrainedShapeSymbolProvider( supportedCollectionConstraintTraits.mapNotNull { shape.getTrait(it).orNull() }.toSet() val allConstraintTraits = allConstraintTraits.mapNotNull { shape.getTrait(it).orNull() }.toSet() - return supportedConstraintTraits.isNotEmpty() && allConstraintTraits.subtract(supportedConstraintTraits) - .isEmpty() + return supportedConstraintTraits.isNotEmpty() && + allConstraintTraits.subtract(supportedConstraintTraits) + .isEmpty() } /** @@ -160,8 +161,9 @@ open class ConstrainedShapeSymbolProvider( defaultModule: RustModule.LeafModule, pubCrateServerBuilder: Boolean, ): Pair { - val syntheticMemberTrait = shape.getTrait() - ?: return Pair(shape.contextName(serviceShape), defaultModule) + val syntheticMemberTrait = + shape.getTrait() + ?: return Pair(shape.contextName(serviceShape), defaultModule) return if (syntheticMemberTrait.container is StructureShape) { val builderModule = syntheticMemberTrait.container.serverBuilderModule(base, pubCrateServerBuilder) @@ -171,19 +173,22 @@ open class ConstrainedShapeSymbolProvider( // For non-structure shapes, the new shape defined for a constrained member shape // needs to be placed in an inline module named `pub {container_name_in_snake_case}`. val moduleName = RustReservedWords.escapeIfNeeded(syntheticMemberTrait.container.id.name.toSnakeCase()) - val innerModuleName = moduleName + if (pubCrateServerBuilder) { - "_internal" - } else { - "" - } + val innerModuleName = + moduleName + + if (pubCrateServerBuilder) { + "_internal" + } else { + "" + } - val innerModule = RustModule.new( - innerModuleName, - visibility = Visibility.publicIf(!pubCrateServerBuilder, Visibility.PUBCRATE), - parent = defaultModule, - inline = true, - documentationOverride = "", - ) + val innerModule = + RustModule.new( + innerModuleName, + visibility = Visibility.publicIf(!pubCrateServerBuilder, Visibility.PUBCRATE), + parent = defaultModule, + inline = true, + documentationOverride = "", + ) val renameTo = syntheticMemberTrait.member.memberName ?: syntheticMemberTrait.member.id.name Pair(renameTo.toPascalCase(), innerModule) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt index c42ed10198f..ca8ae559f80 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt @@ -72,31 +72,34 @@ class ConstraintViolationSymbolProvider( private val serviceShape: ServiceShape, ) : WrappingSymbolProvider(base) { private val constraintViolationName = "ConstraintViolation" - private val visibility = when (publicConstrainedTypes) { - true -> Visibility.PUBLIC - false -> Visibility.PUBCRATE - } + private val visibility = + when (publicConstrainedTypes) { + true -> Visibility.PUBLIC + false -> Visibility.PUBCRATE + } private fun Shape.shapeModule(): RustModule.LeafModule { - val documentation = if (publicConstrainedTypes && this.isDirectlyConstrained(base)) { - val symbol = base.toSymbol(this) - "See [`${this.contextName(serviceShape)}`]($symbol)." - } else { - "" - } + val documentation = + if (publicConstrainedTypes && this.isDirectlyConstrained(base)) { + val symbol = base.toSymbol(this) + "See [`${this.contextName(serviceShape)}`]($symbol)." + } else { + "" + } val syntheticTrait = getTrait() - val (module, name) = if (syntheticTrait != null) { - // For constrained member shapes, the ConstraintViolation code needs to go in an inline rust module - // that is a descendant of the module that contains the extracted shape itself. - val overriddenMemberModule = this.getParentAndInlineModuleForConstrainedMember(base, publicConstrainedTypes)!! - val name = syntheticTrait.member.memberName - Pair(overriddenMemberModule.second, RustReservedWords.escapeIfNeeded(name).toSnakeCase()) - } else { - // Need to use the context name so we get the correct name for maps. - Pair(ServerRustModule.Model, RustReservedWords.escapeIfNeeded(this.contextName(serviceShape)).toSnakeCase()) - } + val (module, name) = + if (syntheticTrait != null) { + // For constrained member shapes, the ConstraintViolation code needs to go in an inline rust module + // that is a descendant of the module that contains the extracted shape itself. + val overriddenMemberModule = this.getParentAndInlineModuleForConstrainedMember(base, publicConstrainedTypes)!! + val name = syntheticTrait.member.memberName + Pair(overriddenMemberModule.second, RustReservedWords.escapeIfNeeded(name).toSnakeCase()) + } else { + // Need to use the context name so we get the correct name for maps. + Pair(ServerRustModule.Model, RustReservedWords.escapeIfNeeded(this.contextName(serviceShape)).toSnakeCase()) + } return RustModule.new( name = name, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt index 5d0313e132c..f6dc9eac664 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt @@ -41,7 +41,7 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderModule import software.amazon.smithy.rust.codegen.server.smithy.traits.SyntheticStructureFromConstrainedMemberTrait -/** +/* * This file contains utilities to work with constrained shapes. */ @@ -49,27 +49,28 @@ import software.amazon.smithy.rust.codegen.server.smithy.traits.SyntheticStructu * Whether the shape has any trait that could cause a request to be rejected with a constraint violation, _whether * we support it or not_. */ -fun Shape.hasConstraintTrait() = - allConstraintTraits.any(this::hasTrait) - -val allConstraintTraits = setOf( - LengthTrait::class.java, - PatternTrait::class.java, - RangeTrait::class.java, - UniqueItemsTrait::class.java, - EnumTrait::class.java, - RequiredTrait::class.java, -) +fun Shape.hasConstraintTrait() = allConstraintTraits.any(this::hasTrait) + +val allConstraintTraits = + setOf( + LengthTrait::class.java, + PatternTrait::class.java, + RangeTrait::class.java, + UniqueItemsTrait::class.java, + EnumTrait::class.java, + RequiredTrait::class.java, + ) val supportedStringConstraintTraits = setOf(LengthTrait::class.java, PatternTrait::class.java) /** * Supported constraint traits for the `list` and `set` shapes. */ -val supportedCollectionConstraintTraits = setOf( - LengthTrait::class.java, - UniqueItemsTrait::class.java, -) +val supportedCollectionConstraintTraits = + setOf( + LengthTrait::class.java, + UniqueItemsTrait::class.java, + ) /** * We say a shape is _directly_ constrained if: @@ -86,30 +87,39 @@ val supportedCollectionConstraintTraits = setOf( * * [the spec]: https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html */ -fun Shape.isDirectlyConstrained(symbolProvider: SymbolProvider): Boolean = when (this) { - is StructureShape -> { - // TODO(https://github.com/smithy-lang/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): - // The only reason why the functions in this file have - // to take in a `SymbolProvider` is because non-`required` blob streaming members are interpreted as - // `required`, so we can't use `member.isOptional` here. - this.members().any { !symbolProvider.toSymbol(it).isOptional() && !it.hasNonNullDefault() } - } +fun Shape.isDirectlyConstrained(symbolProvider: SymbolProvider): Boolean = + when (this) { + is StructureShape -> { + // TODO(https://github.com/smithy-lang/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): + // The only reason why the functions in this file have + // to take in a `SymbolProvider` is because non-`required` blob streaming members are interpreted as + // `required`, so we can't use `member.isOptional` here. + this.members().any { !symbolProvider.toSymbol(it).isOptional() && !it.hasNonNullDefault() } + } - is MapShape -> this.hasTrait() - is StringShape -> this.hasTrait() || supportedStringConstraintTraits.any { this.hasTrait(it) } - is CollectionShape -> supportedCollectionConstraintTraits.any { this.hasTrait(it) } - is IntegerShape, is ShortShape, is LongShape, is ByteShape -> this.hasTrait() - is BlobShape -> this.hasTrait() - else -> false -} + is MapShape -> this.hasTrait() + is StringShape -> this.hasTrait() || supportedStringConstraintTraits.any { this.hasTrait(it) } + is CollectionShape -> supportedCollectionConstraintTraits.any { this.hasTrait(it) } + is IntegerShape, is ShortShape, is LongShape, is ByteShape -> this.hasTrait() + is BlobShape -> this.hasTrait() + else -> false + } -fun MemberShape.hasConstraintTraitOrTargetHasConstraintTrait(model: Model, symbolProvider: SymbolProvider): Boolean = +fun MemberShape.hasConstraintTraitOrTargetHasConstraintTrait( + model: Model, + symbolProvider: SymbolProvider, +): Boolean = this.isDirectlyConstrained(symbolProvider) || model.expectShape(this.target).isDirectlyConstrained(symbolProvider) -fun Shape.isTransitivelyButNotDirectlyConstrained(model: Model, symbolProvider: SymbolProvider): Boolean = - !this.isDirectlyConstrained(symbolProvider) && this.canReachConstrainedShape(model, symbolProvider) +fun Shape.isTransitivelyButNotDirectlyConstrained( + model: Model, + symbolProvider: SymbolProvider, +): Boolean = !this.isDirectlyConstrained(symbolProvider) && this.canReachConstrainedShape(model, symbolProvider) -fun Shape.canReachConstrainedShape(model: Model, symbolProvider: SymbolProvider): Boolean = +fun Shape.canReachConstrainedShape( + model: Model, + symbolProvider: SymbolProvider, +): Boolean = if (this is MemberShape) { // TODO(https://github.com/smithy-lang/smithy-rs/issues/1401) Constraint traits on member shapes are not implemented // yet. Also, note that a walker over a member shape can, perhaps counterintuitively, reach the _containing_ shape, @@ -119,18 +129,24 @@ fun Shape.canReachConstrainedShape(model: Model, symbolProvider: SymbolProvider) DirectedWalker(model).walkShapes(this).toSet().any { it.isDirectlyConstrained(symbolProvider) } } -fun MemberShape.targetCanReachConstrainedShape(model: Model, symbolProvider: SymbolProvider): Boolean = - model.expectShape(this.target).canReachConstrainedShape(model, symbolProvider) - -fun Shape.hasPublicConstrainedWrapperTupleType(model: Model, publicConstrainedTypes: Boolean): Boolean = when (this) { - is CollectionShape -> publicConstrainedTypes && supportedCollectionConstraintTraits.any(this::hasTrait) - is MapShape -> publicConstrainedTypes && this.hasTrait() - is StringShape -> !this.hasTrait() && (publicConstrainedTypes && supportedStringConstraintTraits.any(this::hasTrait)) - is IntegerShape, is ShortShape, is LongShape, is ByteShape -> publicConstrainedTypes && this.hasTrait() - is MemberShape -> model.expectShape(this.target).hasPublicConstrainedWrapperTupleType(model, publicConstrainedTypes) - is BlobShape -> publicConstrainedTypes && this.hasTrait() - else -> false -} +fun MemberShape.targetCanReachConstrainedShape( + model: Model, + symbolProvider: SymbolProvider, +): Boolean = model.expectShape(this.target).canReachConstrainedShape(model, symbolProvider) + +fun Shape.hasPublicConstrainedWrapperTupleType( + model: Model, + publicConstrainedTypes: Boolean, +): Boolean = + when (this) { + is CollectionShape -> publicConstrainedTypes && supportedCollectionConstraintTraits.any(this::hasTrait) + is MapShape -> publicConstrainedTypes && this.hasTrait() + is StringShape -> !this.hasTrait() && (publicConstrainedTypes && supportedStringConstraintTraits.any(this::hasTrait)) + is IntegerShape, is ShortShape, is LongShape, is ByteShape -> publicConstrainedTypes && this.hasTrait() + is MemberShape -> model.expectShape(this.target).hasPublicConstrainedWrapperTupleType(model, publicConstrainedTypes) + is BlobShape -> publicConstrainedTypes && this.hasTrait() + else -> false + } fun Shape.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(model: Model): Boolean = hasPublicConstrainedWrapperTupleType(model, true) @@ -141,8 +157,11 @@ fun Shape.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled( * This function is used in core code generators, so it takes in a [CodegenContext] that is downcast * to [ServerCodegenContext] when generating servers. */ -fun workingWithPublicConstrainedWrapperTupleType(shape: Shape, model: Model, publicConstrainedTypes: Boolean): Boolean = - shape.hasPublicConstrainedWrapperTupleType(model, publicConstrainedTypes) +fun workingWithPublicConstrainedWrapperTupleType( + shape: Shape, + model: Model, + publicConstrainedTypes: Boolean, +): Boolean = shape.hasPublicConstrainedWrapperTupleType(model, publicConstrainedTypes) /** * Returns whether a shape's type _name_ contains a non-public type when `publicConstrainedTypes` is `false`. @@ -159,16 +178,19 @@ fun Shape.typeNameContainsNonPublicType( model: Model, symbolProvider: SymbolProvider, publicConstrainedTypes: Boolean, -): Boolean = !publicConstrainedTypes && when (this) { - is SimpleShape -> wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(model) - is MemberShape -> model.expectShape(this.target) - .typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) - - is CollectionShape -> this.canReachConstrainedShape(model, symbolProvider) - is MapShape -> this.canReachConstrainedShape(model, symbolProvider) - is StructureShape, is UnionShape -> false - else -> UNREACHABLE("the above arms should be exhaustive, but we received shape: $this") -} +): Boolean = + !publicConstrainedTypes && + when (this) { + is SimpleShape -> wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(model) + is MemberShape -> + model.expectShape(this.target) + .typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) + + is CollectionShape -> this.canReachConstrainedShape(model, symbolProvider) + is MapShape -> this.canReachConstrainedShape(model, symbolProvider) + is StructureShape, is UnionShape -> false + else -> UNREACHABLE("the above arms should be exhaustive, but we received shape: $this") + } /** * For synthetic shapes that are added to the model because of member constrained shapes, it returns @@ -183,7 +205,10 @@ fun Shape.overriddenConstrainedMemberInfo(): Pair? { /** * Returns the parent and the inline module that this particular shape should go in. */ -fun Shape.getParentAndInlineModuleForConstrainedMember(symbolProvider: RustSymbolProvider, publicConstrainedTypes: Boolean): Pair? { +fun Shape.getParentAndInlineModuleForConstrainedMember( + symbolProvider: RustSymbolProvider, + publicConstrainedTypes: Boolean, +): Pair? { val overriddenTrait = getTrait() ?: return null return if (overriddenTrait.container is StructureShape) { val structureModule = symbolProvider.toSymbol(overriddenTrait.container).module() @@ -202,12 +227,13 @@ fun Shape.getParentAndInlineModuleForConstrainedMember(symbolProvider: RustSymbo Pair(shapeModule.parent as RustModule.LeafModule, shapeModule) } else { val name = RustReservedWords.escapeIfNeeded(overriddenTrait.container.id.name).toSnakeCase() + "_internal" - val innerModule = RustModule.new( - name = name, - visibility = Visibility.PUBCRATE, - parent = ServerRustModule.Model, - inline = true, - ) + val innerModule = + RustModule.new( + name = name, + visibility = Visibility.PUBCRATE, + parent = ServerRustModule.Model, + inline = true, + ) Pair(ServerRustModule.Model, innerModule) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProvider.kt index d3f70592714..d2ff87195dc 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProvider.kt @@ -83,12 +83,18 @@ class DeriveEqAndHashSymbolMetadataProvider( override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() override fun structureMeta(structureShape: StructureShape) = addDeriveEqAndHashIfPossible(structureShape) + override fun unionMeta(unionShape: UnionShape) = addDeriveEqAndHashIfPossible(unionShape) + override fun enumMeta(stringShape: StringShape) = addDeriveEqAndHashIfPossible(stringShape) override fun listMeta(listShape: ListShape): RustMetadata = addDeriveEqAndHashIfPossible(listShape) + override fun mapMeta(mapShape: MapShape): RustMetadata = addDeriveEqAndHashIfPossible(mapShape) + override fun stringMeta(stringShape: StringShape): RustMetadata = addDeriveEqAndHashIfPossible(stringShape) + override fun numberMeta(numberShape: NumberShape): RustMetadata = addDeriveEqAndHashIfPossible(numberShape) + override fun blobMeta(blobShape: BlobShape): RustMetadata = addDeriveEqAndHashIfPossible(blobShape) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt index cb08de3c734..f31feae6ad3 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt @@ -9,15 +9,16 @@ import software.amazon.smithy.model.traits.LengthTrait fun LengthTrait.validationErrorMessage(): String { val beginning = "Value with length {} at '{}' failed to satisfy constraint: Member must have length " - val ending = if (this.min.isPresent && this.max.isPresent) { - "between ${this.min.get()} and ${this.max.get()}, inclusive" - } else if (this.min.isPresent) { - ( - "greater than or equal to ${this.min.get()}" + val ending = + if (this.min.isPresent && this.max.isPresent) { + "between ${this.min.get()} and ${this.max.get()}, inclusive" + } else if (this.min.isPresent) { + ( + "greater than or equal to ${this.min.get()}" ) - } else { - check(this.max.isPresent) - "less than or equal to ${this.max.get()}" - } + } else { + check(this.max.isPresent) + "less than or equal to ${this.max.get()}" + } return "$beginning$ending" } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitEscapedSpecialCharsValidator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitEscapedSpecialCharsValidator.kt index a26dff9e36c..4a51a2040a2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitEscapedSpecialCharsValidator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitEscapedSpecialCharsValidator.kt @@ -14,18 +14,20 @@ import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait class PatternTraitEscapedSpecialCharsValidator : AbstractValidator() { - private val specialCharsWithEscapes = mapOf( - '\b' to "\\b", - '\u000C' to "\\f", - '\n' to "\\n", - '\r' to "\\r", - '\t' to "\\t", - ) + private val specialCharsWithEscapes = + mapOf( + '\b' to "\\b", + '\u000C' to "\\f", + '\n' to "\\n", + '\r' to "\\r", + '\t' to "\\t", + ) private val specialChars = specialCharsWithEscapes.keys override fun validate(model: Model): List { - val shapes = model.getStringShapesWithTrait(PatternTrait::class.java) + - model.getMemberShapesWithTrait(PatternTrait::class.java) + val shapes = + model.getStringShapesWithTrait(PatternTrait::class.java) + + model.getMemberShapesWithTrait(PatternTrait::class.java) return shapes .filter { shape -> checkMisuse(shape) } .map { shape -> makeError(shape) } @@ -34,10 +36,11 @@ class PatternTraitEscapedSpecialCharsValidator : AbstractValidator() { private fun makeError(shape: Shape): ValidationEvent { val pattern = shape.expectTrait() - val replacement = pattern.pattern.toString() - .map { specialCharsWithEscapes.getOrElse(it) { it.toString() } } - .joinToString("") - .dq() + val replacement = + pattern.pattern.toString() + .map { specialCharsWithEscapes.getOrElse(it) { it.toString() } } + .joinToString("") + .dq() val message = """ Non-escaped special characters used inside `@pattern`. diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt index c64182f152d..24c32c66c20 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt @@ -68,12 +68,13 @@ class PubCrateConstrainedShapeSymbolProvider( check(shape is CollectionShape || shape is MapShape) val name = constrainedTypeNameForCollectionOrMapShape(shape, serviceShape) - val module = RustModule.new( - RustReservedWords.escapeIfNeeded(name.toSnakeCase()), - visibility = Visibility.PUBCRATE, - parent = ServerRustModule.ConstrainedModule, - inline = true, - ) + val module = + RustModule.new( + RustReservedWords.escapeIfNeeded(name.toSnakeCase()), + visibility = Visibility.PUBCRATE, + parent = ServerRustModule.ConstrainedModule, + inline = true, + ) val rustType = RustType.Opaque(name, module.fullyQualifiedPath()) return Symbol.builder() .rustType(rustType) @@ -127,7 +128,10 @@ class PubCrateConstrainedShapeSymbolProvider( } } -fun constrainedTypeNameForCollectionOrMapShape(shape: Shape, serviceShape: ServiceShape): String { +fun constrainedTypeNameForCollectionOrMapShape( + shape: Shape, + serviceShape: ServiceShape, +): String { check(shape is CollectionShape || shape is MapShape) return "${shape.id.getName(serviceShape).toPascalCase()}Constrained" } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt index 9cde66e6e04..560777debbc 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt @@ -9,15 +9,16 @@ import software.amazon.smithy.model.traits.RangeTrait fun RangeTrait.validationErrorMessage(): String { val beginning = "Value at '{}' failed to satisfy constraint: Member must be " - val ending = if (this.min.isPresent && this.max.isPresent) { - "between ${this.min.get()} and ${this.max.get()}, inclusive" - } else if (this.min.isPresent) { - ( - "greater than or equal to ${this.min.get()}" + val ending = + if (this.min.isPresent && this.max.isPresent) { + "between ${this.min.get()} and ${this.max.get()}, inclusive" + } else if (this.min.isPresent) { + ( + "greater than or equal to ${this.min.get()}" ) - } else { - check(this.max.isPresent) - "less than or equal to ${this.max.get()}" - } + } else { + check(this.max.isPresent) + "less than or equal to ${this.max.get()}" + } return "$beginning$ending" } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriter.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriter.kt index 96b3b5739c3..2a6867c6f6a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriter.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriter.kt @@ -147,7 +147,8 @@ private val crateToInlineModule: ConcurrentHashMap = class InnerModule(private val moduleDocProvider: ModuleDocProvider, debugMode: Boolean) { // Holds the root modules to start rendering the descendents from. private val topLevelModuleWriters: ConcurrentHashMap = ConcurrentHashMap() - private val inlineModuleWriters: ConcurrentHashMap> = ConcurrentHashMap() + private val inlineModuleWriters: ConcurrentHashMap> = + ConcurrentHashMap() private val docWriters: ConcurrentHashMap> = ConcurrentHashMap() private val writerCreator = RustWriter.factory(debugMode) @@ -155,13 +156,19 @@ class InnerModule(private val moduleDocProvider: ModuleDocProvider, debugMode: B // indicating that it contains generated code and should not be manually edited. This comment // appears on each descendent inline module. To remove those comments, each time an inline // module is rendered, first `emptyLineCount` characters are removed from it. - private val emptyLineCount: Int = writerCreator - .apply("lines-it-always-writes.rs", "crate") - .toString() - .split("\n")[0] - .length - - fun withInlineModule(outerWriter: RustWriter, innerModule: RustModule.LeafModule, docWriter: DocWriter? = null, writable: Writable) { + private val emptyLineCount: Int = + writerCreator + .apply("lines-it-always-writes.rs", "crate") + .toString() + .split("\n")[0] + .length + + fun withInlineModule( + outerWriter: RustWriter, + innerModule: RustModule.LeafModule, + docWriter: DocWriter? = null, + writable: Writable, + ) { if (docWriter != null) { val moduleDocWriterList = docWriters.getOrPut(innerModule) { mutableListOf() } moduleDocWriterList.add(docWriter) @@ -172,7 +179,12 @@ class InnerModule(private val moduleDocProvider: ModuleDocProvider, debugMode: B /** * Given a `RustCrate` and a `RustModule.LeafModule()`, it creates a writer to that module and calls the writable. */ - fun withInlineModuleHierarchyUsingCrate(rustCrate: RustCrate, inlineModule: RustModule.LeafModule, docWriter: DocWriter? = null, writable: Writable) { + fun withInlineModuleHierarchyUsingCrate( + rustCrate: RustCrate, + inlineModule: RustModule.LeafModule, + docWriter: DocWriter? = null, + writable: Writable, + ) { val hierarchy = getHierarchy(inlineModule).toMutableList() check(!hierarchy.first().isInline()) { "When adding a `RustModule.LeafModule` to the crate, the topmost module in the hierarchy cannot be an inline module." @@ -211,7 +223,12 @@ class InnerModule(private val moduleDocProvider: ModuleDocProvider, debugMode: B * Given a `Writer` to a module and an inline `RustModule.LeafModule()`, it creates a writer to that module and calls the writable. * It registers the complete hierarchy including the `outerWriter` if that is not already registrered. */ - fun withInlineModuleHierarchy(outerWriter: RustWriter, inlineModule: RustModule.LeafModule, docWriter: DocWriter? = null, writable: Writable) { + fun withInlineModuleHierarchy( + outerWriter: RustWriter, + inlineModule: RustModule.LeafModule, + docWriter: DocWriter? = null, + writable: Writable, + ) { val hierarchy = getHierarchy(inlineModule).toMutableList() if (!hierarchy.first().isInline()) { hierarchy.removeFirst() @@ -263,12 +280,18 @@ class InnerModule(private val moduleDocProvider: ModuleDocProvider, debugMode: B fun render() { var writerToAddDependencies: RustWriter? = null - fun writeInlineCode(rustWriter: RustWriter, code: String) { + fun writeInlineCode( + rustWriter: RustWriter, + code: String, + ) { val inlineCode = code.drop(emptyLineCount) rustWriter.writeWithNoFormatting(inlineCode) } - fun renderDescendents(topLevelWriter: RustWriter, inMemoryWriter: RustWriter) { + fun renderDescendents( + topLevelWriter: RustWriter, + inMemoryWriter: RustWriter, + ) { // Traverse all descendent inline modules and render them. inlineModuleWriters[inMemoryWriter]!!.forEach { writeDocs(it.inlineModule) @@ -300,7 +323,10 @@ class InnerModule(private val moduleDocProvider: ModuleDocProvider, debugMode: B * Given the inline-module returns an existing `RustWriter`, or if that inline module * has never been registered before then a new `RustWriter` is created and returned. */ - private fun getWriter(outerWriter: RustWriter, inlineModule: RustModule.LeafModule): RustWriter { + private fun getWriter( + outerWriter: RustWriter, + inlineModule: RustModule.LeafModule, + ): RustWriter { val nestedModuleWriter = inlineModuleWriters[outerWriter] if (nestedModuleWriter != null) { return findOrAddToList(nestedModuleWriter, inlineModule) @@ -326,9 +352,10 @@ class InnerModule(private val moduleDocProvider: ModuleDocProvider, debugMode: B inlineModuleList: MutableList, lookForModule: RustModule.LeafModule, ): RustWriter { - val inlineModuleAndWriter = inlineModuleList.firstOrNull() { - it.inlineModule.name == lookForModule.name - } + val inlineModuleAndWriter = + inlineModuleList.firstOrNull { + it.inlineModule.name == lookForModule.name + } return if (inlineModuleAndWriter == null) { val inlineWriter = createNewInlineModule() inlineModuleList.add(InlineModuleWithWriter(lookForModule, inlineWriter)) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustServerCodegenPlugin.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustServerCodegenPlugin.kt index 2672bbe9f3d..e68fa26a33c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustServerCodegenPlugin.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustServerCodegenPlugin.kt @@ -70,30 +70,29 @@ class RustServerCodegenPlugin : ServerDecoratableBuildPlugin() { constrainedTypes: Boolean = true, includeConstrainedShapeProvider: Boolean = true, codegenDecorator: ServerCodegenDecorator, - ) = - SymbolVisitor(settings, model, serviceShape = serviceShape, config = rustSymbolProviderConfig) - // Generate public constrained types for directly constrained shapes. - .let { - if (includeConstrainedShapeProvider) ConstrainedShapeSymbolProvider(it, serviceShape, constrainedTypes) else it - } - // Generate different types for EventStream shapes (e.g. transcribe streaming) - .let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.SERVER) } - // Generate [ByteStream] instead of `Blob` for streaming binary shapes (e.g. S3 GetObject) - .let { StreamingShapeSymbolProvider(it) } - // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes - .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf()) } - // Constrained shapes generate newtypes that need the same derives we place on types generated from aggregate shapes. - .let { ConstrainedShapeSymbolMetadataProvider(it, constrainedTypes) } - // Streaming shapes need different derives (e.g. they cannot derive `PartialEq`) - .let { StreamingShapeMetadataProvider(it) } - // Derive `Eq` and `Hash` if possible. - .let { DeriveEqAndHashSymbolMetadataProvider(it) } - // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot - // be the name of an operation input - .let { RustReservedWordSymbolProvider(it, ServerReservedWords) } - // Allows decorators to inject a custom symbol provider - .let { codegenDecorator.symbolProvider(it) } - // Inject custom symbols. - .let { CustomShapeSymbolProvider(it) } + ) = SymbolVisitor(settings, model, serviceShape = serviceShape, config = rustSymbolProviderConfig) + // Generate public constrained types for directly constrained shapes. + .let { + if (includeConstrainedShapeProvider) ConstrainedShapeSymbolProvider(it, serviceShape, constrainedTypes) else it + } + // Generate different types for EventStream shapes (e.g. transcribe streaming) + .let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.SERVER) } + // Generate [ByteStream] instead of `Blob` for streaming binary shapes (e.g. S3 GetObject) + .let { StreamingShapeSymbolProvider(it) } + // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes + .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf()) } + // Constrained shapes generate newtypes that need the same derives we place on types generated from aggregate shapes. + .let { ConstrainedShapeSymbolMetadataProvider(it, constrainedTypes) } + // Streaming shapes need different derives (e.g. they cannot derive `PartialEq`) + .let { StreamingShapeMetadataProvider(it) } + // Derive `Eq` and `Hash` if possible. + .let { DeriveEqAndHashSymbolMetadataProvider(it) } + // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot + // be the name of an operation input + .let { RustReservedWordSymbolProvider(it, ServerReservedWords) } + // Allows decorators to inject a custom symbol provider + .let { codegenDecorator.symbolProvider(it) } + // Inject custom symbols. + .let { CustomShapeSymbolProvider(it) } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt index 38de49c39bb..01df0c0a937 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt @@ -29,5 +29,6 @@ object ServerCargoDependency { val HyperDev: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), DependencyScope.Dev) fun smithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-server") + fun smithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-types") } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt index d952a7771b0..41c3c94708e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt @@ -37,8 +37,8 @@ data class ServerCodegenContext( val constraintViolationSymbolProvider: ConstraintViolationSymbolProvider, val pubCrateConstrainedShapeSymbolProvider: PubCrateConstrainedShapeSymbolProvider, ) : CodegenContext( - model, symbolProvider, moduleDocProvider, serviceShape, protocol, settings, CodegenTarget.SERVER, -) { + model, symbolProvider, moduleDocProvider, serviceShape, protocol, settings, CodegenTarget.SERVER, + ) { override fun builderInstantiator(): BuilderInstantiator { return ServerBuilderInstantiator(symbolProvider, returnSymbolToParseFn(this)) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 9e9ce5c598d..87b55066165 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -101,7 +101,6 @@ open class ServerCodegenVisitor( context: PluginContext, private val codegenDecorator: ServerCodegenDecorator, ) : ShapeVisitor.Default() { - protected val logger = Logger.getLogger(javaClass.name) protected var settings = ServerRustSettings.from(context.model, context.settings) @@ -114,12 +113,13 @@ open class ServerCodegenVisitor( protected var validationExceptionConversionGenerator: ValidationExceptionConversionGenerator init { - val rustSymbolProviderConfig = RustSymbolProviderConfig( - runtimeConfig = settings.runtimeConfig, - renameExceptions = false, - nullabilityCheckMode = NullableIndex.CheckMode.SERVER, - moduleProvider = ServerModuleProvider, - ) + val rustSymbolProviderConfig = + RustSymbolProviderConfig( + runtimeConfig = settings.runtimeConfig, + renameExceptions = false, + nullabilityCheckMode = NullableIndex.CheckMode.SERVER, + moduleProvider = ServerModuleProvider, + ) val baseModel = baselineTransform(context.model) val service = settings.getService(baseModel) @@ -135,45 +135,50 @@ open class ServerCodegenVisitor( model = codegenDecorator.transformModel(service, baseModel, settings) - val serverSymbolProviders = ServerSymbolProviders.from( - settings, - model, - service, - rustSymbolProviderConfig, - settings.codegenConfig.publicConstrainedTypes, - codegenDecorator, - RustServerCodegenPlugin::baseSymbolProvider, - ) + val serverSymbolProviders = + ServerSymbolProviders.from( + settings, + model, + service, + rustSymbolProviderConfig, + settings.codegenConfig.publicConstrainedTypes, + codegenDecorator, + RustServerCodegenPlugin::baseSymbolProvider, + ) - codegenContext = ServerCodegenContext( - model, - serverSymbolProviders.symbolProvider, - null, - service, - protocolShape, - settings, - serverSymbolProviders.unconstrainedShapeSymbolProvider, - serverSymbolProviders.constrainedShapeSymbolProvider, - serverSymbolProviders.constraintViolationSymbolProvider, - serverSymbolProviders.pubCrateConstrainedShapeSymbolProvider, - ) + codegenContext = + ServerCodegenContext( + model, + serverSymbolProviders.symbolProvider, + null, + service, + protocolShape, + settings, + serverSymbolProviders.unconstrainedShapeSymbolProvider, + serverSymbolProviders.constrainedShapeSymbolProvider, + serverSymbolProviders.constraintViolationSymbolProvider, + serverSymbolProviders.pubCrateConstrainedShapeSymbolProvider, + ) // We can use a not-null assertion because [CombinedServerCodegenDecorator] returns a not null value. validationExceptionConversionGenerator = codegenDecorator.validationExceptionConversion(codegenContext)!! - codegenContext = codegenContext.copy( - moduleDocProvider = codegenDecorator.moduleDocumentationCustomization( - codegenContext, - ServerModuleDocProvider(codegenContext), - ), - ) + codegenContext = + codegenContext.copy( + moduleDocProvider = + codegenDecorator.moduleDocumentationCustomization( + codegenContext, + ServerModuleDocProvider(codegenContext), + ), + ) - rustCrate = RustCrate( - context.fileManifest, - codegenContext.symbolProvider, - settings.codegenConfig, - codegenContext.expectModuleDocProvider(), - ) + rustCrate = + RustCrate( + context.fileManifest, + codegenContext.symbolProvider, + settings.codegenConfig, + codegenContext.expectModuleDocProvider(), + ) protocolGenerator = this.protocolGeneratorFactory.buildProtocolGenerator(codegenContext) } @@ -209,8 +214,7 @@ open class ServerCodegenVisitor( /** * Exposure purely for unit test purposes. */ - internal fun baselineTransformInternalTest(model: Model) = - baselineTransform(model) + internal fun baselineTransformInternalTest(model: Model) = baselineTransform(model) /** * Execute code generation @@ -324,12 +328,13 @@ open class ServerCodegenVisitor( writer: RustWriter, ) { if (codegenContext.settings.codegenConfig.publicConstrainedTypes || shape.isReachableFromOperationInput()) { - val serverBuilderGenerator = ServerBuilderGenerator( - codegenContext, - shape, - validationExceptionConversionGenerator, - protocolGenerator.protocol, - ) + val serverBuilderGenerator = + ServerBuilderGenerator( + codegenContext, + shape, + validationExceptionConversionGenerator, + protocolGenerator.protocol, + ) serverBuilderGenerator.render(rustCrate, writer) if (codegenContext.settings.codegenConfig.publicConstrainedTypes) { @@ -366,14 +371,16 @@ open class ServerCodegenVisitor( } override fun listShape(shape: ListShape) = collectionShape(shape) + override fun setShape(shape: SetShape) = collectionShape(shape) private fun collectionShape(shape: CollectionShape) { val renderUnconstrainedList = - shape.isReachableFromOperationInput() && shape.canReachConstrainedShape( - model, - codegenContext.symbolProvider, - ) + shape.isReachableFromOperationInput() && + shape.canReachConstrainedShape( + model, + codegenContext.symbolProvider, + ) val isDirectlyConstrained = shape.isDirectlyConstrained(codegenContext.symbolProvider) if (renderUnconstrainedList) { @@ -433,10 +440,11 @@ open class ServerCodegenVisitor( override fun mapShape(shape: MapShape) { val renderUnconstrainedMap = - shape.isReachableFromOperationInput() && shape.canReachConstrainedShape( - model, - codegenContext.symbolProvider, - ) + shape.isReachableFromOperationInput() && + shape.canReachConstrainedShape( + model, + codegenContext.symbolProvider, + ) val isDirectlyConstrained = shape.isDirectlyConstrained(codegenContext.symbolProvider) if (renderUnconstrainedMap) { @@ -498,15 +506,21 @@ open class ServerCodegenVisitor( * Although raw strings require no code generation, enums are actually [EnumTrait] applied to string shapes. */ override fun stringShape(shape: StringShape) { - fun serverEnumGeneratorFactory(codegenContext: ServerCodegenContext, shape: StringShape) = - ServerEnumGenerator(codegenContext, shape, validationExceptionConversionGenerator) + fun serverEnumGeneratorFactory( + codegenContext: ServerCodegenContext, + shape: StringShape, + ) = ServerEnumGenerator(codegenContext, shape, validationExceptionConversionGenerator) stringShape(shape, ::serverEnumGeneratorFactory) } override fun integerShape(shape: IntegerShape) = integralShape(shape) + override fun shortShape(shape: ShortShape) = integralShape(shape) + override fun longShape(shape: LongShape) = integralShape(shape) + override fun byteShape(shape: ByteShape) = integralShape(shape) + private fun integralShape(shape: NumberShape) { if (shape.isDirectlyConstrained(codegenContext.symbolProvider)) { logger.info("[rust-server-codegen] Generating a constrained integral $shape") @@ -569,7 +583,8 @@ open class ServerCodegenVisitor( UnionGenerator(model, codegenContext.symbolProvider, this, shape, renderUnknownVariant = false).render() } - if (shape.isReachableFromOperationInput() && shape.canReachConstrainedShape( + if (shape.isReachableFromOperationInput() && + shape.canReachConstrainedShape( model, codegenContext.symbolProvider, ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerReservedWords.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerReservedWords.kt index 64d9ea04f41..23c2e5c1684 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerReservedWords.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerReservedWords.kt @@ -8,8 +8,9 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordConfig import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator -val ServerReservedWords = RustReservedWordConfig( - structureMemberMap = StructureGenerator.structureMemberNameMap, - unionMemberMap = emptyMap(), - enumMemberMap = emptyMap(), -) +val ServerReservedWords = + RustReservedWordConfig( + structureMemberMap = StructureGenerator.structureMemberNameMap, + unionMemberMap = emptyMap(), + enumMemberMap = emptyMap(), + ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt index 46b18face58..7e539915281 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt @@ -17,8 +17,11 @@ object ServerRuntimeType { fun router(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::Router") - fun protocol(name: String, path: String, runtimeConfig: RuntimeConfig) = - ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("protocol::$path::$name") + fun protocol( + name: String, + path: String, + runtimeConfig: RuntimeConfig, + ) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("protocol::$path::$name") fun protocol(runtimeConfig: RuntimeConfig) = protocol("Protocol", "", runtimeConfig) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustModule.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustModule.kt index 2c2ef75f744..b7e349ddc02 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustModule.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustModule.kt @@ -30,6 +30,7 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.generators.DocHandlerGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.handlerImports + object ServerRustModule { val root = RustModule.LibRs @@ -67,43 +68,49 @@ class ServerModuleDocProvider(private val codegenContext: ServerCodegenContext) } } - private fun operationShapeModuleDoc(): Writable = writable { - val index = TopDownIndex.of(codegenContext.model) - val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet(compareBy { it.id }) + private fun operationShapeModuleDoc(): Writable = + writable { + val index = TopDownIndex.of(codegenContext.model) + val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet(compareBy { it.id }) - val firstOperation = operations.first() ?: return@writable - val crateName = codegenContext.settings.moduleName.toSnakeCase() + val firstOperation = operations.first() ?: return@writable + val crateName = codegenContext.settings.moduleName.toSnakeCase() - rustTemplate( - """ - /// A collection of types representing each operation defined in the service closure. - /// - /// The [plugin system](#{SmithyHttpServer}::plugin) makes use of these - /// [zero-sized types](https://doc.rust-lang.org/nomicon/exotic-sizes.html##zero-sized-types-zsts) (ZSTs) to - /// parameterize [`Plugin`](#{SmithyHttpServer}::plugin::Plugin) implementations. Their traits, such as - /// [`OperationShape`](#{SmithyHttpServer}::operation::OperationShape), can be used to provide - /// operation specific information to the [`Layer`](#{Tower}::Layer) being applied. - """.trimIndent(), - "SmithyHttpServer" to - ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType(), - "Tower" to ServerCargoDependency.Tower.toType(), - "Handler" to DocHandlerGenerator(codegenContext, firstOperation, "handler", commentToken = "///").docSignature(), - "HandlerImports" to handlerImports(crateName, operations, commentToken = "///"), - ) - } + rustTemplate( + """ + /// A collection of types representing each operation defined in the service closure. + /// + /// The [plugin system](#{SmithyHttpServer}::plugin) makes use of these + /// [zero-sized types](https://doc.rust-lang.org/nomicon/exotic-sizes.html##zero-sized-types-zsts) (ZSTs) to + /// parameterize [`Plugin`](#{SmithyHttpServer}::plugin::Plugin) implementations. Their traits, such as + /// [`OperationShape`](#{SmithyHttpServer}::operation::OperationShape), can be used to provide + /// operation specific information to the [`Layer`](#{Tower}::Layer) being applied. + """.trimIndent(), + "SmithyHttpServer" to + ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType(), + "Tower" to ServerCargoDependency.Tower.toType(), + "Handler" to DocHandlerGenerator(codegenContext, firstOperation, "handler", commentToken = "///").docSignature(), + "HandlerImports" to handlerImports(crateName, operations, commentToken = "///"), + ) + } } object ServerModuleProvider : ModuleProvider { - override fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule = when (shape) { - is OperationShape -> ServerRustModule.Operation - is StructureShape -> when { - shape.hasTrait() -> ServerRustModule.Error - shape.hasTrait() -> ServerRustModule.Input - shape.hasTrait() -> ServerRustModule.Output + override fun moduleForShape( + context: ModuleProviderContext, + shape: Shape, + ): RustModule.LeafModule = + when (shape) { + is OperationShape -> ServerRustModule.Operation + is StructureShape -> + when { + shape.hasTrait() -> ServerRustModule.Error + shape.hasTrait() -> ServerRustModule.Input + shape.hasTrait() -> ServerRustModule.Output + else -> ServerRustModule.Model + } else -> ServerRustModule.Model } - else -> ServerRustModule.Model - } override fun moduleForOperationError( context: ModuleProviderContext, @@ -115,18 +122,24 @@ object ServerModuleProvider : ModuleProvider { eventStream: UnionShape, ): RustModule.LeafModule = ServerRustModule.Error - override fun moduleForBuilder(context: ModuleProviderContext, shape: Shape, symbol: Symbol): RustModule.LeafModule { + override fun moduleForBuilder( + context: ModuleProviderContext, + shape: Shape, + symbol: Symbol, + ): RustModule.LeafModule { val pubCrate = !(context.settings as ServerRustSettings).codegenConfig.publicConstrainedTypes - val builderNamespace = RustReservedWords.escapeIfNeeded(symbol.name.toSnakeCase()) + - if (pubCrate) { - "_internal" - } else { - "" + val builderNamespace = + RustReservedWords.escapeIfNeeded(symbol.name.toSnakeCase()) + + if (pubCrate) { + "_internal" + } else { + "" + } + val visibility = + when (pubCrate) { + true -> Visibility.PUBCRATE + false -> Visibility.PUBLIC } - val visibility = when (pubCrate) { - true -> Visibility.PUBCRATE - false -> Visibility.PUBLIC - } return RustModule.new( builderNamespace, visibility, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt index 7b0ba87e611..6f6c25a450a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt @@ -14,7 +14,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import java.util.Optional -/** +/* * [ServerRustSettings] and [ServerCodegenConfig] classes. * * These classes are entirely analogous to [ClientRustSettings] and [ClientCodegenConfig]. Refer to the documentation @@ -40,20 +40,23 @@ data class ServerRustSettings( override val examplesUri: String?, override val customizationConfig: ObjectNode?, ) : CoreRustSettings( - service, - moduleName, - moduleVersion, - moduleAuthors, - moduleDescription, - moduleRepository, - runtimeConfig, - codegenConfig, - license, - examplesUri, - customizationConfig, -) { + service, + moduleName, + moduleVersion, + moduleAuthors, + moduleDescription, + moduleRepository, + runtimeConfig, + codegenConfig, + license, + examplesUri, + customizationConfig, + ) { companion object { - fun from(model: Model, config: ObjectNode): ServerRustSettings { + fun from( + model: Model, + config: ObjectNode, + ): ServerRustSettings { val coreRustSettings = CoreRustSettings.from(model, config) val codegenSettingsNode = config.getObjectMember(CODEGEN_SETTINGS) val coreCodegenConfig = CoreCodegenConfig.fromNode(codegenSettingsNode) @@ -91,27 +94,29 @@ data class ServerCodegenConfig( */ val experimentalCustomValidationExceptionWithReasonPleaseDoNotUse: String? = defaultExperimentalCustomValidationExceptionWithReasonPleaseDoNotUse, ) : CoreCodegenConfig( - formatTimeoutSeconds, debugMode, -) { + formatTimeoutSeconds, debugMode, + ) { companion object { private const val defaultPublicConstrainedTypes = true private const val defaultIgnoreUnsupportedConstraints = false private val defaultExperimentalCustomValidationExceptionWithReasonPleaseDoNotUse = null - fun fromCodegenConfigAndNode(coreCodegenConfig: CoreCodegenConfig, node: Optional) = - if (node.isPresent) { - ServerCodegenConfig( - formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, - debugMode = coreCodegenConfig.debugMode, - publicConstrainedTypes = node.get().getBooleanMemberOrDefault("publicConstrainedTypes", defaultPublicConstrainedTypes), - ignoreUnsupportedConstraints = node.get().getBooleanMemberOrDefault("ignoreUnsupportedConstraints", defaultIgnoreUnsupportedConstraints), - experimentalCustomValidationExceptionWithReasonPleaseDoNotUse = node.get().getStringMemberOrDefault("experimentalCustomValidationExceptionWithReasonPleaseDoNotUse", defaultExperimentalCustomValidationExceptionWithReasonPleaseDoNotUse), - ) - } else { - ServerCodegenConfig( - formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, - debugMode = coreCodegenConfig.debugMode, - ) - } + fun fromCodegenConfigAndNode( + coreCodegenConfig: CoreCodegenConfig, + node: Optional, + ) = if (node.isPresent) { + ServerCodegenConfig( + formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, + debugMode = coreCodegenConfig.debugMode, + publicConstrainedTypes = node.get().getBooleanMemberOrDefault("publicConstrainedTypes", defaultPublicConstrainedTypes), + ignoreUnsupportedConstraints = node.get().getBooleanMemberOrDefault("ignoreUnsupportedConstraints", defaultIgnoreUnsupportedConstraints), + experimentalCustomValidationExceptionWithReasonPleaseDoNotUse = node.get().getStringMemberOrDefault("experimentalCustomValidationExceptionWithReasonPleaseDoNotUse", defaultExperimentalCustomValidationExceptionWithReasonPleaseDoNotUse), + ) + } else { + ServerCodegenConfig( + formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, + debugMode = coreCodegenConfig.debugMode, + ) + } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt index a2693050a3c..c8035121d65 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt @@ -11,6 +11,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator +typealias BaseSymbolProviderFactory = + (settings: ServerRustSettings, model: Model, service: ServiceShape, rustSymbolProviderConfig: RustSymbolProviderConfig, publicConstrainedTypes: Boolean, includeConstraintShapeProvider: Boolean, codegenDecorator: ServerCodegenDecorator) -> RustSymbolProvider + /** * Just a handy class to centralize initialization all the symbol providers required by the server code generators, to * make the init blocks of the codegen visitors ([ServerCodegenVisitor] and [PythonServerCodegenVisitor]), and the @@ -31,41 +34,46 @@ class ServerSymbolProviders private constructor( rustSymbolProviderConfig: RustSymbolProviderConfig, publicConstrainedTypes: Boolean, codegenDecorator: ServerCodegenDecorator, - baseSymbolProviderFactory: (settings: ServerRustSettings, model: Model, service: ServiceShape, rustSymbolProviderConfig: RustSymbolProviderConfig, publicConstrainedTypes: Boolean, includeConstraintShapeProvider: Boolean, codegenDecorator: ServerCodegenDecorator) -> RustSymbolProvider, + baseSymbolProviderFactory: BaseSymbolProviderFactory, ): ServerSymbolProviders { - val baseSymbolProvider = baseSymbolProviderFactory(settings, model, service, rustSymbolProviderConfig, publicConstrainedTypes, publicConstrainedTypes, codegenDecorator) + val baseSymbolProvider = + baseSymbolProviderFactory(settings, model, service, rustSymbolProviderConfig, publicConstrainedTypes, publicConstrainedTypes, codegenDecorator) return ServerSymbolProviders( symbolProvider = baseSymbolProvider, - constrainedShapeSymbolProvider = baseSymbolProviderFactory( - settings, - model, - service, - rustSymbolProviderConfig, - publicConstrainedTypes, - true, - codegenDecorator, - ), - unconstrainedShapeSymbolProvider = UnconstrainedShapeSymbolProvider( + constrainedShapeSymbolProvider = baseSymbolProviderFactory( settings, model, service, rustSymbolProviderConfig, - false, - false, + publicConstrainedTypes, + true, codegenDecorator, ), - publicConstrainedTypes, service, - ), - pubCrateConstrainedShapeSymbolProvider = PubCrateConstrainedShapeSymbolProvider( - baseSymbolProvider, - service, - ), - constraintViolationSymbolProvider = ConstraintViolationSymbolProvider( - baseSymbolProvider, - publicConstrainedTypes, - service, - ), + unconstrainedShapeSymbolProvider = + UnconstrainedShapeSymbolProvider( + baseSymbolProviderFactory( + settings, + model, + service, + rustSymbolProviderConfig, + false, + false, + codegenDecorator, + ), + publicConstrainedTypes, service, + ), + pubCrateConstrainedShapeSymbolProvider = + PubCrateConstrainedShapeSymbolProvider( + baseSymbolProvider, + service, + ), + constraintViolationSymbolProvider = + ConstraintViolationSymbolProvider( + baseSymbolProvider, + publicConstrainedTypes, + service, + ), ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt index 488c9347a20..a615d403665 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt @@ -100,12 +100,13 @@ class UnconstrainedShapeSymbolProvider( val name = unconstrainedTypeNameForCollectionOrMapOrUnionShape(shape) val parent = shape.getParentAndInlineModuleForConstrainedMember(this, publicConstrainedTypes)?.second ?: ServerRustModule.UnconstrainedModule - val module = RustModule.new( - RustReservedWords.escapeIfNeeded(name.toSnakeCase()), - visibility = Visibility.PUBCRATE, - parent = parent, - inline = true, - ) + val module = + RustModule.new( + RustReservedWords.escapeIfNeeded(name.toSnakeCase()), + visibility = Visibility.PUBCRATE, + parent = parent, + inline = true, + ) val rustType = RustType.Opaque(name, module.fullyQualifiedPath()) return Symbol.builder() .rustType(rustType) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt index abee410b5b0..0d3f50ddca1 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt @@ -37,7 +37,12 @@ private sealed class UnsupportedConstraintMessageKind { private val constraintTraitsUberIssue = "https://github.com/smithy-lang/smithy-rs/issues/1401" fun intoLogMessage(ignoreUnsupportedConstraints: Boolean): LogMessage { - fun buildMessage(intro: String, willSupport: Boolean, trackingIssue: String? = null, canBeIgnored: Boolean = true): String { + fun buildMessage( + intro: String, + willSupport: Boolean, + trackingIssue: String? = null, + canBeIgnored: Boolean = true, + ): String { var msg = """ $intro This is not supported in the smithy-rs server SDK.""" @@ -62,79 +67,85 @@ private sealed class UnsupportedConstraintMessageKind { constraintTrait: Trait, trackingIssue: String, willSupport: Boolean = true, - ) = - buildMessage( - "The ${shape.type} shape `${shape.id}` has the constraint trait `${constraintTrait.toShapeId()}` attached.", - willSupport, - trackingIssue, - ) + ) = buildMessage( + "The ${shape.type} shape `${shape.id}` has the constraint trait `${constraintTrait.toShapeId()}` attached.", + willSupport, + trackingIssue, + ) val level = if (ignoreUnsupportedConstraints) Level.WARNING else Level.SEVERE return when (this) { - is UnsupportedConstraintOnMemberShape -> LogMessage( - level, - buildMessageShapeHasUnsupportedConstraintTrait(shape, constraintTrait, constraintTraitsUberIssue), - ) + is UnsupportedConstraintOnMemberShape -> + LogMessage( + level, + buildMessageShapeHasUnsupportedConstraintTrait(shape, constraintTrait, constraintTraitsUberIssue), + ) - is UnsupportedConstraintOnShapeReachableViaAnEventStream -> LogMessage( - Level.SEVERE, - buildMessage( - """ - The ${shape.type} shape `${shape.id}` has the constraint trait `${constraintTrait.toShapeId()}` attached. - This shape is also part of an event stream; it is unclear what the semantics for constrained shapes in event streams are. - Please remove the trait from the shape to synthesize your model. - """.trimIndent().replace("\n", " "), - willSupport = false, - "https://github.com/awslabs/smithy/issues/1388", - canBeIgnored = false, - ), - ) + is UnsupportedConstraintOnShapeReachableViaAnEventStream -> + LogMessage( + Level.SEVERE, + buildMessage( + """ + The ${shape.type} shape `${shape.id}` has the constraint trait `${constraintTrait.toShapeId()}` attached. + This shape is also part of an event stream; it is unclear what the semantics for constrained shapes in event streams are. + Please remove the trait from the shape to synthesize your model. + """.trimIndent().replace("\n", " "), + willSupport = false, + "https://github.com/awslabs/smithy/issues/1388", + canBeIgnored = false, + ), + ) - is UnsupportedLengthTraitOnStreamingBlobShape -> LogMessage( - level, - buildMessage( - """ - The ${shape.type} shape `${shape.id}` has both the `${lengthTrait.toShapeId()}` and `${streamingTrait.toShapeId()}` constraint traits attached. - It is unclear what the semantics for streaming blob shapes are. - """.trimIndent().replace("\n", " "), - willSupport = false, - "https://github.com/awslabs/smithy/issues/1389", - ), - ) + is UnsupportedLengthTraitOnStreamingBlobShape -> + LogMessage( + level, + buildMessage( + """ + The ${shape.type} shape `${shape.id}` has both the `${lengthTrait.toShapeId()}` and `${streamingTrait.toShapeId()}` constraint traits attached. + It is unclear what the semantics for streaming blob shapes are. + """.trimIndent().replace("\n", " "), + willSupport = false, + "https://github.com/awslabs/smithy/issues/1389", + ), + ) - is UnsupportedRangeTraitOnShape -> LogMessage( - level, - buildMessageShapeHasUnsupportedConstraintTrait( - shape, - rangeTrait, - willSupport = false, - trackingIssue = "https://github.com/smithy-lang/smithy-rs/issues/2007", - ), - ) + is UnsupportedRangeTraitOnShape -> + LogMessage( + level, + buildMessageShapeHasUnsupportedConstraintTrait( + shape, + rangeTrait, + willSupport = false, + trackingIssue = "https://github.com/smithy-lang/smithy-rs/issues/2007", + ), + ) - is UnsupportedUniqueItemsTraitOnShape -> LogMessage( - level, - buildMessageShapeHasUnsupportedConstraintTrait(shape, uniqueItemsTrait, constraintTraitsUberIssue), - ) + is UnsupportedUniqueItemsTraitOnShape -> + LogMessage( + level, + buildMessageShapeHasUnsupportedConstraintTrait(shape, uniqueItemsTrait, constraintTraitsUberIssue), + ) - is UnsupportedMapShapeReachableFromUniqueItemsList -> LogMessage( - Level.SEVERE, - buildMessage( - """ - The map shape `${mapShape.id}` is reachable from the list shape `${listShape.id}`, which has the - `@uniqueItems` trait attached. - """.trimIndent().replace("\n", " "), - willSupport = false, - trackingIssue = "https://github.com/awslabs/smithy/issues/1567", - canBeIgnored = false, - ), - ) + is UnsupportedMapShapeReachableFromUniqueItemsList -> + LogMessage( + Level.SEVERE, + buildMessage( + """ + The map shape `${mapShape.id}` is reachable from the list shape `${listShape.id}`, which has the + `@uniqueItems` trait attached. + """.trimIndent().replace("\n", " "), + willSupport = false, + trackingIssue = "https://github.com/awslabs/smithy/issues/1567", + canBeIgnored = false, + ), + ) } } } private data class OperationWithConstrainedInputWithoutValidationException(val shape: OperationShape) + private data class UnsupportedConstraintOnMemberShape(val shape: MemberShape, val constraintTrait: Trait) : UnsupportedConstraintMessageKind() @@ -160,6 +171,7 @@ private data class UnsupportedMapShapeReachableFromUniqueItemsList( ) : UnsupportedConstraintMessageKind() data class LogMessage(val level: Level, val message: String) + data class ValidationResult(val shouldAbort: Boolean, val messages: List) : Throwable(message = messages.joinToString("\n") { it.message }) @@ -176,17 +188,18 @@ fun validateOperationsWithConstrainedInputHaveValidationExceptionAttached( // TODO(https://github.com/smithy-lang/smithy-rs/issues/1401): This check will go away once we add support for // `disableDefaultValidation` set to `true`, allowing service owners to map from constraint violations to operation errors. val walker = DirectedWalker(model) - val operationsWithConstrainedInputWithoutValidationExceptionSet = walker.walkShapes(service) - .filterIsInstance() - .asSequence() - .filter { operationShape -> - // Walk the shapes reachable via this operation input. - walker.walkShapes(operationShape.inputShape(model)) - .any { it is SetShape || it is EnumShape || it.hasConstraintTrait() } - } - .filter { !it.errors.contains(validationExceptionShapeId) } - .map { OperationWithConstrainedInputWithoutValidationException(it) } - .toSet() + val operationsWithConstrainedInputWithoutValidationExceptionSet = + walker.walkShapes(service) + .filterIsInstance() + .asSequence() + .filter { operationShape -> + // Walk the shapes reachable via this operation input. + walker.walkShapes(operationShape.inputShape(model)) + .any { it is SetShape || it is EnumShape || it.hasConstraintTrait() } + } + .filter { !it.errors.contains(validationExceptionShapeId) } + .map { OperationWithConstrainedInputWithoutValidationException(it) } + .toSet() val messages = operationsWithConstrainedInputWithoutValidationExceptionSet.map { @@ -224,63 +237,70 @@ fun validateUnsupportedConstraints( // 1. Constraint traits on streaming blob shapes are used. Their semantics are unclear. // TODO(https://github.com/awslabs/smithy/issues/1389) - val unsupportedLengthTraitOnStreamingBlobShapeSet = walker - .walkShapes(service) - .asSequence() - .filterIsInstance() - .filter { it.hasTrait() && it.hasTrait() } - .map { UnsupportedLengthTraitOnStreamingBlobShape(it, it.expectTrait(), it.expectTrait()) } - .toSet() + val unsupportedLengthTraitOnStreamingBlobShapeSet = + walker + .walkShapes(service) + .asSequence() + .filterIsInstance() + .filter { it.hasTrait() && it.hasTrait() } + .map { UnsupportedLengthTraitOnStreamingBlobShape(it, it.expectTrait(), it.expectTrait()) } + .toSet() // 2. Constraint traits in event streams are used. Their semantics are unclear. // TODO(https://github.com/awslabs/smithy/issues/1388) - val eventStreamShapes = walker - .walkShapes(service) - .asSequence() - .filter { it.hasTrait() } - val unsupportedConstraintOnNonErrorShapeReachableViaAnEventStreamSet = eventStreamShapes - .flatMap { walker.walkShapes(it) } - .filterMapShapesToTraits(allConstraintTraits) - .map { (shape, trait) -> UnsupportedConstraintOnShapeReachableViaAnEventStream(shape, trait) } - .toSet() - val eventStreamErrors = eventStreamShapes.map { - it.expectTrait() - }.map { it.errorMembers } - val unsupportedConstraintErrorShapeReachableViaAnEventStreamSet = eventStreamErrors - .flatMap { it } - .flatMap { walker.walkShapes(it) } - .filterMapShapesToTraits(allConstraintTraits) - .map { (shape, trait) -> UnsupportedConstraintOnShapeReachableViaAnEventStream(shape, trait) } - .toSet() + val eventStreamShapes = + walker + .walkShapes(service) + .asSequence() + .filter { it.hasTrait() } + val unsupportedConstraintOnNonErrorShapeReachableViaAnEventStreamSet = + eventStreamShapes + .flatMap { walker.walkShapes(it) } + .filterMapShapesToTraits(allConstraintTraits) + .map { (shape, trait) -> UnsupportedConstraintOnShapeReachableViaAnEventStream(shape, trait) } + .toSet() + val eventStreamErrors = + eventStreamShapes.map { + it.expectTrait() + }.map { it.errorMembers } + val unsupportedConstraintErrorShapeReachableViaAnEventStreamSet = + eventStreamErrors + .flatMap { it } + .flatMap { walker.walkShapes(it) } + .filterMapShapesToTraits(allConstraintTraits) + .map { (shape, trait) -> UnsupportedConstraintOnShapeReachableViaAnEventStream(shape, trait) } + .toSet() val unsupportedConstraintShapeReachableViaAnEventStreamSet = unsupportedConstraintOnNonErrorShapeReachableViaAnEventStreamSet + unsupportedConstraintErrorShapeReachableViaAnEventStreamSet // 3. Range trait used on unsupported shapes. // TODO(https://github.com/smithy-lang/smithy-rs/issues/2007) - val unsupportedRangeTraitOnShapeSet = walker - .walkShapes(service) - .asSequence() - .filterNot { it is IntegerShape || it is ShortShape || it is LongShape || it is ByteShape } - .filterMapShapesToTraits(setOf(RangeTrait::class.java)) - .map { (shape, rangeTrait) -> UnsupportedRangeTraitOnShape(shape, rangeTrait as RangeTrait) } - .toSet() + val unsupportedRangeTraitOnShapeSet = + walker + .walkShapes(service) + .asSequence() + .filterNot { it is IntegerShape || it is ShortShape || it is LongShape || it is ByteShape } + .filterMapShapesToTraits(setOf(RangeTrait::class.java)) + .map { (shape, rangeTrait) -> UnsupportedRangeTraitOnShape(shape, rangeTrait as RangeTrait) } + .toSet() // 4. `@uniqueItems` cannot reach a map shape. // See https://github.com/awslabs/smithy/issues/1567. - val mapShapeReachableFromUniqueItemsListShapeSet = walker - .walkShapes(service) - .asSequence() - .filterMapShapesToTraits(setOf(UniqueItemsTrait::class.java)) - .flatMap { (listShape, uniqueItemsTrait) -> - walker.walkShapes(listShape).filterIsInstance().map { mapShape -> - UnsupportedMapShapeReachableFromUniqueItemsList( - listShape as ListShape, - uniqueItemsTrait as UniqueItemsTrait, - mapShape, - ) + val mapShapeReachableFromUniqueItemsListShapeSet = + walker + .walkShapes(service) + .asSequence() + .filterMapShapesToTraits(setOf(UniqueItemsTrait::class.java)) + .flatMap { (listShape, uniqueItemsTrait) -> + walker.walkShapes(listShape).filterIsInstance().map { mapShape -> + UnsupportedMapShapeReachableFromUniqueItemsList( + listShape as ListShape, + uniqueItemsTrait as UniqueItemsTrait, + mapShape, + ) + } } - } - .toSet() + .toSet() val messages = ( @@ -294,16 +314,17 @@ fun validateUnsupportedConstraints( mapShapeReachableFromUniqueItemsListShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } - ).toMutableList() + ).toMutableList() if (messages.isEmpty() && codegenConfig.ignoreUnsupportedConstraints) { - messages += LogMessage( - Level.SEVERE, - """ - The `ignoreUnsupportedConstraints` flag in the `codegen` configuration is set to `true`, but it has no - effect. All the constraint traits used in the model are well-supported, please remove this flag. - """.trimIndent().replace("\n", " "), - ) + messages += + LogMessage( + Level.SEVERE, + """ + The `ignoreUnsupportedConstraints` flag in the `codegen` configuration is set to `true`, but it has no + effect. All the constraint traits used in the model are well-supported, please remove this flag. + """.trimIndent().replace("\n", " "), + ) } return ValidationResult(shouldAbort = messages.any { it.level == Level.SEVERE }, messages) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt index 8842686d116..3dbf7c9804d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt @@ -35,8 +35,11 @@ class AddInternalServerErrorToInfallibleOperationsDecorator : ServerCodegenDecor override val name: String = "AddInternalServerErrorToInfallibleOperations" override val order: Byte = 0 - override fun transformModel(service: ServiceShape, model: Model, settings: ServerRustSettings): Model = - addErrorShapeToModelOperations(service, model) { shape -> shape.allErrors(model).isEmpty() } + override fun transformModel( + service: ServiceShape, + model: Model, + settings: ServerRustSettings, + ): Model = addErrorShapeToModelOperations(service, model) { shape -> shape.allErrors(model).isEmpty() } } /** @@ -61,11 +64,18 @@ class AddInternalServerErrorToAllOperationsDecorator : ServerCodegenDecorator { override val name: String = "AddInternalServerErrorToAllOperations" override val order: Byte = 0 - override fun transformModel(service: ServiceShape, model: Model, settings: ServerRustSettings): Model = - addErrorShapeToModelOperations(service, model) { true } + override fun transformModel( + service: ServiceShape, + model: Model, + settings: ServerRustSettings, + ): Model = addErrorShapeToModelOperations(service, model) { true } } -fun addErrorShapeToModelOperations(service: ServiceShape, model: Model, opSelector: (OperationShape) -> Boolean): Model { +fun addErrorShapeToModelOperations( + service: ServiceShape, + model: Model, + opSelector: (OperationShape) -> Boolean, +): Model { val errorShape = internalServerError(service.id.namespace) val modelShapes = model.toBuilder().addShapes(listOf(errorShape)).build() return ModelTransformer.create().mapShapes(modelShapes) { shape -> diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapOrCollectionJsonCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapOrCollectionJsonCustomization.kt index 8dabd4ac02a..b891de94856 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapOrCollectionJsonCustomization.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeIteratingOverMapOrCollectionJsonCustomization.kt @@ -21,19 +21,21 @@ import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstr * That value will be a `std::collections::HashMap` for map shapes, and a `std::vec::Vec` for collection shapes. */ class BeforeIteratingOverMapOrCollectionJsonCustomization(private val codegenContext: ServerCodegenContext) : JsonSerializerCustomization() { - override fun section(section: JsonSerializerSection): Writable = when (section) { - is JsonSerializerSection.BeforeIteratingOverMapOrCollection -> writable { - check(section.shape is CollectionShape || section.shape is MapShape) - if (workingWithPublicConstrainedWrapperTupleType( - section.shape, - codegenContext.model, - codegenContext.settings.codegenConfig.publicConstrainedTypes, - ) - ) { - section.context.valueExpression = - ValueExpression.Reference("&${section.context.valueExpression.name}.0") - } + override fun section(section: JsonSerializerSection): Writable = + when (section) { + is JsonSerializerSection.BeforeIteratingOverMapOrCollection -> + writable { + check(section.shape is CollectionShape || section.shape is MapShape) + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name}.0") + } + } + else -> emptySection } - else -> emptySection - } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberJsonCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberJsonCustomization.kt index 71122511ad8..bd548cb743b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberJsonCustomization.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberJsonCustomization.kt @@ -24,21 +24,23 @@ import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstr */ class BeforeSerializingMemberJsonCustomization(private val codegenContext: ServerCodegenContext) : JsonSerializerCustomization() { - override fun section(section: JsonSerializerSection): Writable = when (section) { - is JsonSerializerSection.BeforeSerializingNonNullMember -> writable { - if (workingWithPublicConstrainedWrapperTupleType( - section.shape, - codegenContext.model, - codegenContext.settings.codegenConfig.publicConstrainedTypes, - ) - ) { - if (section.shape is IntegerShape || section.shape is ShortShape || section.shape is LongShape || section.shape is ByteShape || section.shape is BlobShape) { - section.context.valueExpression = - ValueExpression.Reference("&${section.context.valueExpression.name}.0") + override fun section(section: JsonSerializerSection): Writable = + when (section) { + is JsonSerializerSection.BeforeSerializingNonNullMember -> + writable { + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + if (section.shape is IntegerShape || section.shape is ShortShape || section.shape is LongShape || section.shape is ByteShape || section.shape is BlobShape) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name}.0") + } + } } - } - } - else -> emptySection - } + else -> emptySection + } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecorator.kt index b2b2aa0b992..ad080febbba 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecorator.kt @@ -53,8 +53,9 @@ class CustomValidationExceptionWithReasonDecorator : ServerCodegenDecorator { override val order: Byte get() = -69 - override fun validationExceptionConversion(codegenContext: ServerCodegenContext): - ValidationExceptionConversionGenerator? = + override fun validationExceptionConversion( + codegenContext: ServerCodegenContext, + ): ValidationExceptionConversionGenerator? = if (codegenContext.settings.codegenConfig.experimentalCustomValidationExceptionWithReasonPleaseDoNotUse != null) { ValidationExceptionWithReasonConversionGenerator(codegenContext) } else { @@ -67,136 +68,141 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex override val shapeId: ShapeId = ShapeId.from(codegenContext.settings.codegenConfig.experimentalCustomValidationExceptionWithReasonPleaseDoNotUse) - override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = writable { - rustTemplate( - """ - impl #{From} for #{RequestRejection} { - fn from(constraint_violation: ConstraintViolation) -> Self { - let first_validation_exception_field = constraint_violation.as_validation_exception_field("".to_owned()); - let validation_exception = crate::error::ValidationException { - message: format!("1 validation error detected. {}", &first_validation_exception_field.message), - reason: crate::model::ValidationExceptionReason::FieldValidationFailed, - fields: Some(vec![first_validation_exception_field]), - }; - Self::ConstraintViolation( - crate::protocol_serde::shape_validation_exception::ser_validation_exception_error(&validation_exception) - .expect("validation exceptions should never fail to serialize; please file a bug report under https://github.com/smithy-lang/smithy-rs/issues") - ) + override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = + writable { + rustTemplate( + """ + impl #{From} for #{RequestRejection} { + fn from(constraint_violation: ConstraintViolation) -> Self { + let first_validation_exception_field = constraint_violation.as_validation_exception_field("".to_owned()); + let validation_exception = crate::error::ValidationException { + message: format!("1 validation error detected. {}", &first_validation_exception_field.message), + reason: crate::model::ValidationExceptionReason::FieldValidationFailed, + fields: Some(vec![first_validation_exception_field]), + }; + Self::ConstraintViolation( + crate::protocol_serde::shape_validation_exception::ser_validation_exception_error(&validation_exception) + .expect("validation exceptions should never fail to serialize; please file a bug report under https://github.com/smithy-lang/smithy-rs/issues") + ) + } } - } - """, - "RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig), - "From" to RuntimeType.From, - ) - } + """, + "RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig), + "From" to RuntimeType.From, + ) + } - override fun stringShapeConstraintViolationImplBlock(stringConstraintsInfo: Collection): Writable = writable { - val validationExceptionFields = - stringConstraintsInfo.map { - writable { - when (it) { - is Pattern -> { - rustTemplate( - """ - Self::Pattern(_) => crate::model::ValidationExceptionField { - message: #{MessageWritable:W}, - name: path, - reason: crate::model::ValidationExceptionFieldReason::PatternNotValid, - }, - """, - "MessageWritable" to it.errorMessage(), - ) - } - is Length -> { - rust( - """ - Self::Length(length) => crate::model::ValidationExceptionField { - message: format!("${it.lengthTrait.validationErrorMessage()}", length, &path), - name: path, - reason: crate::model::ValidationExceptionFieldReason::LengthNotValid, - }, - """, - ) + override fun stringShapeConstraintViolationImplBlock(stringConstraintsInfo: Collection): Writable = + writable { + val validationExceptionFields = + stringConstraintsInfo.map { + writable { + when (it) { + is Pattern -> { + rustTemplate( + """ + Self::Pattern(_) => crate::model::ValidationExceptionField { + message: #{MessageWritable:W}, + name: path, + reason: crate::model::ValidationExceptionFieldReason::PatternNotValid, + }, + """, + "MessageWritable" to it.errorMessage(), + ) + } + is Length -> { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${it.lengthTrait.validationErrorMessage()}", length, &path), + name: path, + reason: crate::model::ValidationExceptionFieldReason::LengthNotValid, + }, + """, + ) + } } } - } - }.join("\n") + }.join("\n") - rustTemplate( - """ - pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { - match self { - #{ValidationExceptionFields:W} - } - } - """, - "String" to RuntimeType.String, - "ValidationExceptionFields" to validationExceptionFields, - ) - } - - override fun enumShapeConstraintViolationImplBlock(enumTrait: EnumTrait) = writable { - val enumValueSet = enumTrait.enumDefinitionValues.joinToString(", ") - val message = "Value at '{}' failed to satisfy constraint: Member must satisfy enum value set: [$enumValueSet]" - rustTemplate( - """ - pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { - crate::model::ValidationExceptionField { - message: format!(r##"$message"##, &path), - name: path, - reason: crate::model::ValidationExceptionFieldReason::ValueNotValid, + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + #{ValidationExceptionFields:W} + } } - } - """, - "String" to RuntimeType.String, - ) - } + """, + "String" to RuntimeType.String, + "ValidationExceptionFields" to validationExceptionFields, + ) + } - override fun numberShapeConstraintViolationImplBlock(rangeInfo: Range) = writable { - rustTemplate( - """ - pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { - match self { - Self::Range(_) => crate::model::ValidationExceptionField { - message: format!("${rangeInfo.rangeTrait.validationErrorMessage()}", &path), + override fun enumShapeConstraintViolationImplBlock(enumTrait: EnumTrait) = + writable { + val enumValueSet = enumTrait.enumDefinitionValues.joinToString(", ") + val message = "Value at '{}' failed to satisfy constraint: Member must satisfy enum value set: [$enumValueSet]" + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + crate::model::ValidationExceptionField { + message: format!(r##"$message"##, &path), name: path, reason: crate::model::ValidationExceptionFieldReason::ValueNotValid, } } - } - """, - "String" to RuntimeType.String, - ) - } + """, + "String" to RuntimeType.String, + ) + } - override fun blobShapeConstraintViolationImplBlock(blobConstraintsInfo: Collection) = writable { - val validationExceptionFields = - blobConstraintsInfo.map { - writable { - rust( - """ - Self::Length(length) => crate::model::ValidationExceptionField { - message: format!("${it.lengthTrait.validationErrorMessage()}", length, &path), + override fun numberShapeConstraintViolationImplBlock(rangeInfo: Range) = + writable { + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + Self::Range(_) => crate::model::ValidationExceptionField { + message: format!("${rangeInfo.rangeTrait.validationErrorMessage()}", &path), name: path, - reason: crate::model::ValidationExceptionFieldReason::LengthNotValid, - }, - """, - ) + reason: crate::model::ValidationExceptionFieldReason::ValueNotValid, + } + } } - }.join("\n") + """, + "String" to RuntimeType.String, + ) + } - rustTemplate( - """ - pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { - match self { - #{ValidationExceptionFields:W} + override fun blobShapeConstraintViolationImplBlock(blobConstraintsInfo: Collection) = + writable { + val validationExceptionFields = + blobConstraintsInfo.map { + writable { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${it.lengthTrait.validationErrorMessage()}", length, &path), + name: path, + reason: crate::model::ValidationExceptionFieldReason::LengthNotValid, + }, + """, + ) + } + }.join("\n") + + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + #{ValidationExceptionFields:W} + } } - } - """, - "String" to RuntimeType.String, - "ValidationExceptionFields" to validationExceptionFields, - ) - } + """, + "String" to RuntimeType.String, + "ValidationExceptionFields" to validationExceptionFields, + ) + } override fun mapShapeConstraintViolationImplBlock( shape: MapShape, @@ -231,60 +237,61 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex } } - override fun builderConstraintViolationImplBlock(constraintViolations: Collection) = writable { - rustBlock("match self") { - constraintViolations.forEach { - if (it.hasInner()) { - rust("""ConstraintViolation::${it.name()}(inner) => inner.as_validation_exception_field(path + "/${it.forMember.memberName}"),""") - } else { - rust( - """ - ConstraintViolation::${it.name()} => crate::model::ValidationExceptionField { - message: format!("Value at '{}/${it.forMember.memberName}' failed to satisfy constraint: Member must not be null", path), - name: path + "/${it.forMember.memberName}", - reason: crate::model::ValidationExceptionFieldReason::Other, - }, - """, - ) + override fun builderConstraintViolationImplBlock(constraintViolations: Collection) = + writable { + rustBlock("match self") { + constraintViolations.forEach { + if (it.hasInner()) { + rust("""ConstraintViolation::${it.name()}(inner) => inner.as_validation_exception_field(path + "/${it.forMember.memberName}"),""") + } else { + rust( + """ + ConstraintViolation::${it.name()} => crate::model::ValidationExceptionField { + message: format!("Value at '{}/${it.forMember.memberName}' failed to satisfy constraint: Member must not be null", path), + name: path + "/${it.forMember.memberName}", + reason: crate::model::ValidationExceptionFieldReason::Other, + }, + """, + ) + } } } } - } override fun collectionShapeConstraintViolationImplBlock( - collectionConstraintsInfo: - Collection, + collectionConstraintsInfo: Collection, isMemberConstrained: Boolean, ) = writable { - val validationExceptionFields = collectionConstraintsInfo.map { - writable { - when (it) { - is CollectionTraitInfo.Length -> { - rust( - """ - Self::Length(length) => crate::model::ValidationExceptionField { - message: format!("${it.lengthTrait.validationErrorMessage()}", length, &path), - name: path, - reason: crate::model::ValidationExceptionFieldReason::LengthNotValid, - }, - """, - ) - } - is CollectionTraitInfo.UniqueItems -> { - rust( - """ - Self::UniqueItems { duplicate_indices, .. } => - crate::model::ValidationExceptionField { - message: format!("${it.uniqueItemsTrait.validationErrorMessage()}", &duplicate_indices, &path), + val validationExceptionFields = + collectionConstraintsInfo.map { + writable { + when (it) { + is CollectionTraitInfo.Length -> { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${it.lengthTrait.validationErrorMessage()}", length, &path), name: path, - reason: crate::model::ValidationExceptionFieldReason::ValueNotValid, + reason: crate::model::ValidationExceptionFieldReason::LengthNotValid, }, - """, - ) + """, + ) + } + is CollectionTraitInfo.UniqueItems -> { + rust( + """ + Self::UniqueItems { duplicate_indices, .. } => + crate::model::ValidationExceptionField { + message: format!("${it.uniqueItemsTrait.validationErrorMessage()}", &duplicate_indices, &path), + name: path, + reason: crate::model::ValidationExceptionFieldReason::ValueNotValid, + }, + """, + ) + } } } - } - }.toMutableList() + }.toMutableList() if (isMemberConstrained) { validationExceptionFields += { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt index 7823c53326d..f4205e794d3 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt @@ -31,10 +31,12 @@ class ServerRequiredCustomizations : ServerCodegenDecorator { override fun libRsCustomizations( codegenContext: ServerCodegenContext, baseCustomizations: List, - ): List = - baseCustomizations + AllowLintsCustomization() + ): List = baseCustomizations + AllowLintsCustomization() - override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ServerCodegenContext, + rustCrate: RustCrate, + ) { val rc = codegenContext.runtimeConfig // Add rt-tokio feature for `ByteStream::from_path` diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SmithyValidationExceptionDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SmithyValidationExceptionDecorator.kt index f9124f7f778..9e9438b6416 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SmithyValidationExceptionDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SmithyValidationExceptionDecorator.kt @@ -53,72 +53,76 @@ class SmithyValidationExceptionDecorator : ServerCodegenDecorator { override val order: Byte get() = 69 - override fun validationExceptionConversion(codegenContext: ServerCodegenContext): ValidationExceptionConversionGenerator = - SmithyValidationExceptionConversionGenerator(codegenContext) + override fun validationExceptionConversion( + codegenContext: ServerCodegenContext, + ): ValidationExceptionConversionGenerator = SmithyValidationExceptionConversionGenerator(codegenContext) } class SmithyValidationExceptionConversionGenerator(private val codegenContext: ServerCodegenContext) : ValidationExceptionConversionGenerator { - // Define a companion object so that we can refer to this shape id globally. companion object { val SHAPE_ID: ShapeId = ShapeId.from("smithy.framework#ValidationException") } + override val shapeId: ShapeId = SHAPE_ID - override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = writable { - rustTemplate( - """ - impl #{From} for #{RequestRejection} { - fn from(constraint_violation: ConstraintViolation) -> Self { - let first_validation_exception_field = constraint_violation.as_validation_exception_field("".to_owned()); - let validation_exception = crate::error::ValidationException { - message: format!("1 validation error detected. {}", &first_validation_exception_field.message), - field_list: Some(vec![first_validation_exception_field]), - }; - Self::ConstraintViolation( - crate::protocol_serde::shape_validation_exception::ser_validation_exception_error(&validation_exception) - .expect("validation exceptions should never fail to serialize; please file a bug report under https://github.com/smithy-lang/smithy-rs/issues") - ) + override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = + writable { + rustTemplate( + """ + impl #{From} for #{RequestRejection} { + fn from(constraint_violation: ConstraintViolation) -> Self { + let first_validation_exception_field = constraint_violation.as_validation_exception_field("".to_owned()); + let validation_exception = crate::error::ValidationException { + message: format!("1 validation error detected. {}", &first_validation_exception_field.message), + field_list: Some(vec![first_validation_exception_field]), + }; + Self::ConstraintViolation( + crate::protocol_serde::shape_validation_exception::ser_validation_exception_error(&validation_exception) + .expect("validation exceptions should never fail to serialize; please file a bug report under https://github.com/smithy-lang/smithy-rs/issues") + ) + } } - } - """, - "RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig), - "From" to RuntimeType.From, - ) - } + """, + "RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig), + "From" to RuntimeType.From, + ) + } - override fun stringShapeConstraintViolationImplBlock(stringConstraintsInfo: Collection): Writable = writable { - val constraintsInfo: List = stringConstraintsInfo.map(StringTraitInfo::toTraitInfo) + override fun stringShapeConstraintViolationImplBlock(stringConstraintsInfo: Collection): Writable = + writable { + val constraintsInfo: List = stringConstraintsInfo.map(StringTraitInfo::toTraitInfo) - rustTemplate( - """ - pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { - match self { - #{ValidationExceptionFields:W} + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + #{ValidationExceptionFields:W} + } } - } - """, - "String" to RuntimeType.String, - "ValidationExceptionFields" to constraintsInfo.map { it.asValidationExceptionField }.join("\n"), - ) - } + """, + "String" to RuntimeType.String, + "ValidationExceptionFields" to constraintsInfo.map { it.asValidationExceptionField }.join("\n"), + ) + } - override fun blobShapeConstraintViolationImplBlock(blobConstraintsInfo: Collection): Writable = writable { - val constraintsInfo: List = blobConstraintsInfo.map(BlobLength::toTraitInfo) + override fun blobShapeConstraintViolationImplBlock(blobConstraintsInfo: Collection): Writable = + writable { + val constraintsInfo: List = blobConstraintsInfo.map(BlobLength::toTraitInfo) - rustTemplate( - """ - pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { - match self { - #{ValidationExceptionFields:W} + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + #{ValidationExceptionFields:W} + } } - } - """, - "String" to RuntimeType.String, - "ValidationExceptionFields" to constraintsInfo.map { it.asValidationExceptionField }.join("\n"), - ) - } + """, + "String" to RuntimeType.String, + "ValidationExceptionFields" to constraintsInfo.map { it.asValidationExceptionField }.join("\n"), + ) + } override fun mapShapeConstraintViolationImplBlock( shape: MapShape, @@ -155,63 +159,66 @@ class SmithyValidationExceptionConversionGenerator(private val codegenContext: S } } - override fun enumShapeConstraintViolationImplBlock(enumTrait: EnumTrait) = writable { - val enumValueSet = enumTrait.enumDefinitionValues.joinToString(", ") - val message = "Value at '{}' failed to satisfy constraint: Member must satisfy enum value set: [$enumValueSet]" - rustTemplate( - """ - pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { - crate::model::ValidationExceptionField { - message: format!(r##"$message"##, &path), - path, + override fun enumShapeConstraintViolationImplBlock(enumTrait: EnumTrait) = + writable { + val enumValueSet = enumTrait.enumDefinitionValues.joinToString(", ") + val message = "Value at '{}' failed to satisfy constraint: Member must satisfy enum value set: [$enumValueSet]" + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + crate::model::ValidationExceptionField { + message: format!(r##"$message"##, &path), + path, + } } - } - """, - "String" to RuntimeType.String, - ) - } + """, + "String" to RuntimeType.String, + ) + } - override fun numberShapeConstraintViolationImplBlock(rangeInfo: Range) = writable { - rustTemplate( - """ - pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { - match self { - #{ValidationExceptionFields:W} + override fun numberShapeConstraintViolationImplBlock(rangeInfo: Range) = + writable { + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + #{ValidationExceptionFields:W} + } } - } - """, - "String" to RuntimeType.String, - "ValidationExceptionFields" to rangeInfo.toTraitInfo().asValidationExceptionField, - ) - } + """, + "String" to RuntimeType.String, + "ValidationExceptionFields" to rangeInfo.toTraitInfo().asValidationExceptionField, + ) + } - override fun builderConstraintViolationImplBlock(constraintViolations: Collection) = writable { - rustBlock("match self") { - constraintViolations.forEach { - if (it.hasInner()) { - rust("""ConstraintViolation::${it.name()}(inner) => inner.as_validation_exception_field(path + "/${it.forMember.memberName}"),""") - } else { - rust( - """ - ConstraintViolation::${it.name()} => crate::model::ValidationExceptionField { - message: format!("Value at '{}/${it.forMember.memberName}' failed to satisfy constraint: Member must not be null", path), - path: path + "/${it.forMember.memberName}", - }, - """, - ) + override fun builderConstraintViolationImplBlock(constraintViolations: Collection) = + writable { + rustBlock("match self") { + constraintViolations.forEach { + if (it.hasInner()) { + rust("""ConstraintViolation::${it.name()}(inner) => inner.as_validation_exception_field(path + "/${it.forMember.memberName}"),""") + } else { + rust( + """ + ConstraintViolation::${it.name()} => crate::model::ValidationExceptionField { + message: format!("Value at '{}/${it.forMember.memberName}' failed to satisfy constraint: Member must not be null", path), + path: path + "/${it.forMember.memberName}", + }, + """, + ) + } } } } - } override fun collectionShapeConstraintViolationImplBlock( - collectionConstraintsInfo: - Collection, + collectionConstraintsInfo: Collection, isMemberConstrained: Boolean, ) = writable { - val validationExceptionFields = collectionConstraintsInfo.map { - it.toTraitInfo().asValidationExceptionField - }.toMutableList() + val validationExceptionFields = + collectionConstraintsInfo.map { + it.toTraitInfo().asValidationExceptionField + }.toMutableList() if (isMemberConstrained) { validationExceptionFields += { rust( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt index 22df729e5ae..5bd79ed7a06 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt @@ -27,8 +27,13 @@ typealias ServerProtocolMap = ProtocolMap { - fun protocols(serviceId: ShapeId, currentProtocols: ServerProtocolMap): ServerProtocolMap = currentProtocols - fun validationExceptionConversion(codegenContext: ServerCodegenContext): ValidationExceptionConversionGenerator? = null + fun protocols( + serviceId: ShapeId, + currentProtocols: ServerProtocolMap, + ): ServerProtocolMap = currentProtocols + + fun validationExceptionConversion(codegenContext: ServerCodegenContext): ValidationExceptionConversionGenerator? = + null /** * Injection point to allow a decorator to postprocess the error message that arises when an operation is @@ -42,7 +47,8 @@ interface ServerCodegenDecorator : CoreCodegenDecorator = emptyList() + fun postprocessOperationGenerateAdditionalStructures(operationShape: OperationShape): List = + emptyList() /** * For each service, this hook allows decorators to return a collection of structure shapes that will additionally be generated. @@ -67,7 +73,6 @@ interface ServerCodegenDecorator : CoreCodegenDecorator) : CombinedCoreCodegenDecorator(decorators), ServerCodegenDecorator { - private val orderedDecorators = decorators.sortedBy { it.order } override val name: String @@ -75,22 +80,31 @@ class CombinedServerCodegenDecorator(decorators: List) : override val order: Byte get() = 0 - override fun protocols(serviceId: ShapeId, currentProtocols: ServerProtocolMap): ServerProtocolMap = + override fun protocols( + serviceId: ShapeId, + currentProtocols: ServerProtocolMap, + ): ServerProtocolMap = combineCustomizations(currentProtocols) { decorator, protocolMap -> decorator.protocols(serviceId, protocolMap) } - override fun validationExceptionConversion(codegenContext: ServerCodegenContext): ValidationExceptionConversionGenerator = + override fun validationExceptionConversion( + codegenContext: ServerCodegenContext, + ): ValidationExceptionConversionGenerator = // We use `firstNotNullOf` instead of `firstNotNullOfOrNull` because the [SmithyValidationExceptionDecorator] // is registered. orderedDecorators.firstNotNullOf { it.validationExceptionConversion(codegenContext) } - override fun postprocessValidationExceptionNotAttachedErrorMessage(validationResult: ValidationResult): ValidationResult = + override fun postprocessValidationExceptionNotAttachedErrorMessage( + validationResult: ValidationResult, + ): ValidationResult = orderedDecorators.foldRight(validationResult) { decorator, accumulated -> decorator.postprocessValidationExceptionNotAttachedErrorMessage(accumulated) } - override fun postprocessOperationGenerateAdditionalStructures(operationShape: OperationShape): List = + override fun postprocessOperationGenerateAdditionalStructures( + operationShape: OperationShape, + ): List = orderedDecorators.flatMap { it.postprocessOperationGenerateAdditionalStructures(operationShape) } override fun postprocessServiceGenerateAdditionalStructures(serviceShape: ServiceShape): List = diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGenerator.kt index 5a1c3bc4e18..9d90562908a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGenerator.kt @@ -47,9 +47,10 @@ class ConstrainedBlobGenerator( PubCrateConstraintViolationSymbolProvider(this) } } - private val blobConstraintsInfo: List = listOf(LengthTrait::class.java) - .mapNotNull { shape.getTrait(it).orNull() } - .map { BlobLength(it) } + private val blobConstraintsInfo: List = + listOf(LengthTrait::class.java) + .mapNotNull { shape.getTrait(it).orNull() } + .map { BlobLength(it) } private val constraintsInfo: List = blobConstraintsInfo.map { it.toTraitInfo() } fun render() { @@ -116,7 +117,11 @@ class ConstrainedBlobGenerator( } } - private fun renderConstraintViolationEnum(writer: RustWriter, shape: BlobShape, constraintViolation: Symbol) { + private fun renderConstraintViolationEnum( + writer: RustWriter, + shape: BlobShape, + constraintViolation: Symbol, + ) { writer.rustTemplate( """ ##[derive(Debug, PartialEq)] @@ -141,41 +146,46 @@ class ConstrainedBlobGenerator( } data class BlobLength(val lengthTrait: LengthTrait) { - fun toTraitInfo(): TraitInfo = TraitInfo( - { rust("Self::check_length(&value)?;") }, - { - docs("Error when a blob doesn't satisfy its `@length` requirements.") - rust("Length(usize)") - }, - { - rust( - """ - Self::Length(length) => crate::model::ValidationExceptionField { - message: format!("${lengthTrait.validationErrorMessage()}", length, &path), - path, + fun toTraitInfo(): TraitInfo = + TraitInfo( + { rust("Self::check_length(&value)?;") }, + { + docs("Error when a blob doesn't satisfy its `@length` requirements.") + rust("Length(usize)") + }, + { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${lengthTrait.validationErrorMessage()}", length, &path), + path, },""", - ) - }, - this::renderValidationFunction, - ) + ) + }, + this::renderValidationFunction, + ) /** * Renders a `check_length` function to validate the blob matches the * required length indicated by the `@length` trait. */ - private fun renderValidationFunction(constraintViolation: Symbol, unconstrainedTypeName: String): Writable = { - rust( - """ - fn check_length(blob: &$unconstrainedTypeName) -> Result<(), $constraintViolation> { - let length = blob.as_ref().len(); + private fun renderValidationFunction( + constraintViolation: Symbol, + unconstrainedTypeName: String, + ): Writable = + { + rust( + """ + fn check_length(blob: &$unconstrainedTypeName) -> Result<(), $constraintViolation> { + let length = blob.as_ref().len(); - if ${lengthTrait.rustCondition("length")} { - Ok(()) - } else { - Err($constraintViolation::Length(length)) + if ${lengthTrait.rustCondition("length")} { + Ok(()) + } else { + Err($constraintViolation::Length(length)) + } } - } - """, - ) - } + """, + ) + } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGenerator.kt index f705e4af3a0..d2029bfd2e5 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGenerator.kt @@ -76,12 +76,13 @@ class ConstrainedCollectionGenerator( val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) val constrainedSymbol = symbolProvider.toSymbol(shape) - val codegenScope = arrayOf( - "ValueMemberSymbol" to constrainedShapeSymbolProvider.toSymbol(shape.member), - "From" to RuntimeType.From, - "TryFrom" to RuntimeType.TryFrom, - "ConstraintViolation" to constraintViolation, - ) + val codegenScope = + arrayOf( + "ValueMemberSymbol" to constrainedShapeSymbolProvider.toSymbol(shape.member), + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + "ConstraintViolation" to constraintViolation, + ) writer.documentShape(shape, model) writer.docs(rustDocsConstrainedTypeEpilogue(name)) @@ -116,9 +117,10 @@ class ConstrainedCollectionGenerator( #{ValidationFunctions:W} """, *codegenScope, - "ValidationFunctions" to constraintsInfo.map { - it.validationFunctionDefinition(constraintViolation, inner) - }.join("\n"), + "ValidationFunctions" to + constraintsInfo.map { + it.validationFunctionDefinition(constraintViolation, inner) + }.join("\n"), ) } @@ -355,7 +357,11 @@ sealed class CollectionTraitInfo { } companion object { - private fun fromTrait(trait: Trait, shape: CollectionShape, symbolProvider: SymbolProvider): CollectionTraitInfo { + private fun fromTrait( + trait: Trait, + shape: CollectionShape, + symbolProvider: SymbolProvider, + ): CollectionTraitInfo { check(shape.hasTrait(trait.toShapeId())) return when (trait) { is LengthTrait -> { @@ -370,7 +376,10 @@ sealed class CollectionTraitInfo { } } - fun fromShape(shape: CollectionShape, symbolProvider: SymbolProvider): List = + fun fromShape( + shape: CollectionShape, + symbolProvider: SymbolProvider, + ): List = supportedCollectionConstraintTraits .mapNotNull { shape.getTrait(it).orNull() } .map { trait -> fromTrait(trait, shape, symbolProvider) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt index 28b0d9f8d75..2128cdaec84 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt @@ -64,14 +64,15 @@ class ConstrainedMapGenerator( val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) val constrainedSymbol = symbolProvider.toSymbol(shape) - val codegenScope = arrayOf( - "HashMap" to RuntimeType.HashMap, - "KeySymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.key.target)), - "ValueMemberSymbol" to constrainedShapeSymbolProvider.toSymbol(shape.value), - "From" to RuntimeType.From, - "TryFrom" to RuntimeType.TryFrom, - "ConstraintViolation" to constraintViolation, - ) + val codegenScope = + arrayOf( + "HashMap" to RuntimeType.HashMap, + "KeySymbol" to constrainedShapeSymbolProvider.toSymbol(model.expectShape(shape.key.target)), + "ValueMemberSymbol" to constrainedShapeSymbolProvider.toSymbol(shape.value), + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + "ConstraintViolation" to constraintViolation, + ) writer.documentShape(shape, model) writer.docs(rustDocsConstrainedTypeEpilogue(name)) @@ -134,11 +135,12 @@ class ConstrainedMapGenerator( ) { val keyShape = model.expectShape(shape.key.target, StringShape::class.java) val keyNeedsConversion = keyShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) - val key = if (keyNeedsConversion) { - "k.into()" - } else { - "k" - } + val key = + if (keyNeedsConversion) { + "k.into()" + } else { + "k" + } writer.rustTemplate( """ diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorCommon.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorCommon.kt index fb5ce1daee0..0c88a447c47 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorCommon.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorCommon.kt @@ -16,7 +16,13 @@ import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained * Common helper functions used in [UnconstrainedMapGenerator] and [MapConstraintViolationGenerator]. */ -fun isKeyConstrained(shape: StringShape, symbolProvider: SymbolProvider) = shape.isDirectlyConstrained(symbolProvider) +fun isKeyConstrained( + shape: StringShape, + symbolProvider: SymbolProvider, +) = shape.isDirectlyConstrained(symbolProvider) -fun isValueConstrained(shape: Shape, model: Model, symbolProvider: SymbolProvider): Boolean = - shape.canReachConstrainedShape(model, symbolProvider) +fun isValueConstrained( + shape: Shape, + model: Model, + symbolProvider: SymbolProvider, +): Boolean = shape.canReachConstrainedShape(model, symbolProvider) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGenerator.kt index 9680865a915..248f9a57747 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGenerator.kt @@ -50,13 +50,14 @@ class ConstrainedNumberGenerator( val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes - private val unconstrainedType = when (shape) { - is ByteShape -> RustType.Integer(8) - is ShortShape -> RustType.Integer(16) - is IntegerShape -> RustType.Integer(32) - is LongShape -> RustType.Integer(64) - else -> UNREACHABLE("Trying to generate a constrained number for an unsupported Smithy number shape") - } + private val unconstrainedType = + when (shape) { + is ByteShape -> RustType.Integer(8) + is ShortShape -> RustType.Integer(16) + is IntegerShape -> RustType.Integer(32) + is LongShape -> RustType.Integer(64) + else -> UNREACHABLE("Trying to generate a constrained number for an unsupported Smithy number shape") + } private val constraintViolationSymbolProvider = with(codegenContext.constraintViolationSymbolProvider) { @@ -158,46 +159,52 @@ class ConstrainedNumberGenerator( } data class Range(val rangeTrait: RangeTrait) { - fun toTraitInfo(): TraitInfo = TraitInfo( - { rust("Self::check_range(value)?;") }, - { docs("Error when a number doesn't satisfy its `@range` requirements.") }, - { - rust( - """ - Self::Range(_) => crate::model::ValidationExceptionField { - message: format!("${rangeTrait.validationErrorMessage()}", &path), - path, - }, - """, - ) - }, - this::renderValidationFunction, - ) + fun toTraitInfo(): TraitInfo = + TraitInfo( + { rust("Self::check_range(value)?;") }, + { docs("Error when a number doesn't satisfy its `@range` requirements.") }, + { + rust( + """ + Self::Range(_) => crate::model::ValidationExceptionField { + message: format!("${rangeTrait.validationErrorMessage()}", &path), + path, + }, + """, + ) + }, + this::renderValidationFunction, + ) /** * Renders a `check_range` function to validate that the value matches the * required range indicated by the `@range` trait. */ - private fun renderValidationFunction(constraintViolation: Symbol, unconstrainedTypeName: String): Writable = { - val valueVariableName = "value" - val condition = if (rangeTrait.min.isPresent && rangeTrait.max.isPresent) { - "(${rangeTrait.min.get()}..=${rangeTrait.max.get()}).contains(&$valueVariableName)" - } else if (rangeTrait.min.isPresent) { - "${rangeTrait.min.get()} <= $valueVariableName" - } else { - "$valueVariableName <= ${rangeTrait.max.get()}" - } - - rust( - """ - fn check_range($valueVariableName: $unconstrainedTypeName) -> Result<(), $constraintViolation> { - if $condition { - Ok(()) + private fun renderValidationFunction( + constraintViolation: Symbol, + unconstrainedTypeName: String, + ): Writable = + { + val valueVariableName = "value" + val condition = + if (rangeTrait.min.isPresent && rangeTrait.max.isPresent) { + "(${rangeTrait.min.get()}..=${rangeTrait.max.get()}).contains(&$valueVariableName)" + } else if (rangeTrait.min.isPresent) { + "${rangeTrait.min.get()} <= $valueVariableName" } else { - Err($constraintViolation::Range($valueVariableName)) + "$valueVariableName <= ${rangeTrait.max.get()}" } - } - """, - ) - } + + rust( + """ + fn check_range($valueVariableName: $unconstrainedTypeName) -> Result<(), $constraintViolation> { + if $condition { + Ok(()) + } else { + Err($constraintViolation::Range($valueVariableName)) + } + } + """, + ) + } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedShapeGeneratorCommon.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedShapeGeneratorCommon.kt index 1e63fda7125..e001eda9818 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedShapeGeneratorCommon.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedShapeGeneratorCommon.kt @@ -9,18 +9,20 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators * Functions shared amongst the constrained shape generators, to keep them DRY and consistent. */ -fun rustDocsConstrainedTypeEpilogue(typeName: String) = """ +fun rustDocsConstrainedTypeEpilogue(typeName: String) = + """ This is a constrained type because its corresponding modeled Smithy shape has one or more [constraint traits]. Use [`$typeName::try_from`] to construct values of this type. [constraint traits]: https://awslabs.github.io/smithy/1.0/spec/core/constraint-traits.html -""" + """ -fun rustDocsTryFromMethod(typeName: String, inner: String) = +fun rustDocsTryFromMethod( + typeName: String, + inner: String, +) = "Constructs a `$typeName` from an [`$inner`], failing when the provided value does not satisfy the modeled constraints." -fun rustDocsInnerMethod(inner: String) = - "Returns an immutable reference to the underlying [`$inner`]." +fun rustDocsInnerMethod(inner: String) = "Returns an immutable reference to the underlying [`$inner`]." -fun rustDocsIntoInnerMethod(inner: String) = - "Consumes the value, returning the underlying [`$inner`]." +fun rustDocsIntoInnerMethod(inner: String) = "Consumes the value, returning the underlying [`$inner`]." diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt index 7ab164b39bf..994dff469fe 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt @@ -147,7 +147,11 @@ class ConstrainedStringGenerator( renderTests(shape) } - private fun renderConstraintViolationEnum(writer: RustWriter, shape: StringShape, constraintViolation: Symbol) { + private fun renderConstraintViolationEnum( + writer: RustWriter, + shape: StringShape, + constraintViolation: Symbol, + ) { writer.rustTemplate( """ ##[derive(Debug, PartialEq)] @@ -186,49 +190,55 @@ class ConstrainedStringGenerator( } } } + data class Length(val lengthTrait: LengthTrait) : StringTraitInfo() { - override fun toTraitInfo(): TraitInfo = TraitInfo( - tryFromCheck = { rust("Self::check_length(&value)?;") }, - constraintViolationVariant = { - docs("Error when a string doesn't satisfy its `@length` requirements.") - rust("Length(usize)") - }, - asValidationExceptionField = { - rust( - """ - Self::Length(length) => crate::model::ValidationExceptionField { - message: format!("${lengthTrait.validationErrorMessage()}", length, &path), - path, - }, - """, - ) - }, - validationFunctionDefinition = this::renderValidationFunction, - ) + override fun toTraitInfo(): TraitInfo = + TraitInfo( + tryFromCheck = { rust("Self::check_length(&value)?;") }, + constraintViolationVariant = { + docs("Error when a string doesn't satisfy its `@length` requirements.") + rust("Length(usize)") + }, + asValidationExceptionField = { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${lengthTrait.validationErrorMessage()}", length, &path), + path, + }, + """, + ) + }, + validationFunctionDefinition = this::renderValidationFunction, + ) /** * Renders a `check_length` function to validate the string matches the * required length indicated by the `@length` trait. */ @Suppress("UNUSED_PARAMETER") - private fun renderValidationFunction(constraintViolation: Symbol, unconstrainedTypeName: String): Writable = { - // Note that we're using the linear time check `chars().count()` instead of `len()` on the input value, since the - // Smithy specification says the `length` trait counts the number of Unicode code points when applied to string shapes. - // https://awslabs.github.io/smithy/1.0/spec/core/constraint-traits.html#length-trait - rust( - """ - fn check_length(string: &str) -> Result<(), $constraintViolation> { - let length = string.chars().count(); + private fun renderValidationFunction( + constraintViolation: Symbol, + unconstrainedTypeName: String, + ): Writable = + { + // Note that we're using the linear time check `chars().count()` instead of `len()` on the input value, since the + // Smithy specification says the `length` trait counts the number of Unicode code points when applied to string shapes. + // https://awslabs.github.io/smithy/1.0/spec/core/constraint-traits.html#length-trait + rust( + """ + fn check_length(string: &str) -> Result<(), $constraintViolation> { + let length = string.chars().count(); - if ${lengthTrait.rustCondition("length")} { - Ok(()) - } else { - Err($constraintViolation::Length(length)) + if ${lengthTrait.rustCondition("length")} { + Ok(()) + } else { + Err($constraintViolation::Length(length)) + } } - } - """, - ) - } + """, + ) + } } data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait, val isSensitive: Boolean) : StringTraitInfo() { @@ -253,16 +263,17 @@ data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait, val isSen ) }, this::renderValidationFunction, - testCases = listOf { - unitTest("regex_compiles") { - rustTemplate( - """ - #{T}::compile_regex(); - """, - "T" to symbol, - ) - } - }, + testCases = + listOf { + unitTest("regex_compiles") { + rustTemplate( + """ + #{T}::compile_regex(); + """, + "T" to symbol, + ) + } + }, ) } @@ -282,7 +293,10 @@ data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait, val isSen * Renders a `check_pattern` function to validate the string matches the * supplied regex in the `@pattern` trait. */ - private fun renderValidationFunction(constraintViolation: Symbol, unconstrainedTypeName: String): Writable { + private fun renderValidationFunction( + constraintViolation: Symbol, + unconstrainedTypeName: String, + ): Writable { val pattern = patternTrait.pattern val errorMessageForUnsupportedRegex = """The regular expression $pattern is not supported by the `regex` crate; feel free to file an issue under https://github.com/smithy-lang/smithy-rs/issues for support""" @@ -317,18 +331,21 @@ data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait, val isSen sealed class StringTraitInfo { companion object { - fun fromTrait(symbol: Symbol, trait: Trait, isSensitive: Boolean) = - when (trait) { - is PatternTrait -> { - Pattern(symbol, trait, isSensitive) - } - - is LengthTrait -> { - Length(trait) - } + fun fromTrait( + symbol: Symbol, + trait: Trait, + isSensitive: Boolean, + ) = when (trait) { + is PatternTrait -> { + Pattern(symbol, trait, isSensitive) + } - else -> PANIC("StringTraitInfo.fromTrait called with unsupported trait $trait") + is LengthTrait -> { + Length(trait) } + + else -> PANIC("StringTraitInfo.fromTrait called with unsupported trait $trait") + } } abstract fun toTraitInfo(): TraitInfo diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/DocHandlerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/DocHandlerGenerator.kt index d67b79b4553..1872bc29ae5 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/DocHandlerGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/DocHandlerGenerator.kt @@ -36,11 +36,12 @@ class DocHandlerGenerator( * Returns the function signature for an operation handler implementation. Used in the documentation. */ fun docSignature(): Writable { - val outputT = if (operation.errors.isEmpty()) { - "${OutputModule.name}::${outputSymbol.name}" - } else { - "Result<${OutputModule.name}::${outputSymbol.name}, ${ErrorModule.name}::${errorSymbol.name}>" - } + val outputT = + if (operation.errors.isEmpty()) { + "${OutputModule.name}::${outputSymbol.name}" + } else { + "Result<${OutputModule.name}::${outputSymbol.name}, ${ErrorModule.name}::${errorSymbol.name}>" + } return writable { rust( @@ -58,11 +59,12 @@ class DocHandlerGenerator( * difference that we don't ellide the error for use in `tower::service_fn`. */ fun docFixedSignature(): Writable { - val errorT = if (operation.errors.isEmpty()) { - "std::convert::Infallible" - } else { - "${ErrorModule.name}::${errorSymbol.name}" - } + val errorT = + if (operation.errors.isEmpty()) { + "std::convert::Infallible" + } else { + "${ErrorModule.name}::${errorSymbol.name}" + } return writable { rust( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/LenghTraitCommon.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/LenghTraitCommon.kt index bc074126ee1..ab8691b0cd2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/LenghTraitCommon.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/LenghTraitCommon.kt @@ -8,13 +8,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.traits.LengthTrait fun LengthTrait.rustCondition(lengthVariable: String): String { - val condition = if (min.isPresent && max.isPresent) { - "(${min.get()}..=${max.get()}).contains(&$lengthVariable)" - } else if (min.isPresent) { - "${min.get()} <= $lengthVariable" - } else { - "$lengthVariable <= ${max.get()}" - } + val condition = + if (min.isPresent && max.isPresent) { + "(${min.get()}..=${max.get()}).contains(&$lengthVariable)" + } else if (min.isPresent) { + "${min.get()} <= $lengthVariable" + } else { + "$lengthVariable <= ${max.get()}" + } return condition } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt index 38917b14776..eb4fcc66879 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt @@ -63,11 +63,12 @@ class MapConstraintViolationGenerator( } val constraintViolationCodegenScope = constraintViolationCodegenScopeMutableList.toTypedArray() - val constraintViolationVisibility = if (publicConstrainedTypes) { - Visibility.PUBLIC - } else { - Visibility.PUBCRATE - } + val constraintViolationVisibility = + if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } inlineModuleCreator(constraintViolationSymbol) { // TODO(https://github.com/smithy-lang/smithy-rs/issues/1401) We should really have two `ConstraintViolation` @@ -94,13 +95,14 @@ class MapConstraintViolationGenerator( #{MapShapeConstraintViolationImplBlock} } """, - "MapShapeConstraintViolationImplBlock" to validationExceptionConversionGenerator.mapShapeConstraintViolationImplBlock( - shape, - keyShape, - valueShape, - symbolProvider, - model, - ), + "MapShapeConstraintViolationImplBlock" to + validationExceptionConversionGenerator.mapShapeConstraintViolationImplBlock( + shape, + keyShape, + valueShape, + symbolProvider, + model, + ), ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt index 016da351773..294719fd7d3 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt @@ -59,19 +59,21 @@ class PubCrateConstrainedCollectionGenerator( val unconstrainedSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape) val name = constrainedSymbol.name val innerShape = model.expectShape(shape.member.target) - val innerMemberSymbol = if (innerShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) { - pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.member) - } else { - constrainedShapeSymbolProvider.toSymbol(shape.member) - } + val innerMemberSymbol = + if (innerShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) { + pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.member) + } else { + constrainedShapeSymbolProvider.toSymbol(shape.member) + } - val codegenScope = arrayOf( - "InnerMemberSymbol" to innerMemberSymbol, - "ConstrainedTrait" to RuntimeType.ConstrainedTrait, - "UnconstrainedSymbol" to unconstrainedSymbol, - "Symbol" to symbol, - "From" to RuntimeType.From, - ) + val codegenScope = + arrayOf( + "InnerMemberSymbol" to innerMemberSymbol, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait, + "UnconstrainedSymbol" to unconstrainedSymbol, + "Symbol" to symbol, + "From" to RuntimeType.From, + ) inlineModuleCreator(constrainedSymbol) { rustTemplate( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt index 838d4da0856..7a62171fc66 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt @@ -59,20 +59,22 @@ class PubCrateConstrainedMapGenerator( val keyShape = model.expectShape(shape.key.target, StringShape::class.java) val valueShape = model.expectShape(shape.value.target) val keySymbol = constrainedShapeSymbolProvider.toSymbol(keyShape) - val valueMemberSymbol = if (valueShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) { - pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.value) - } else { - constrainedShapeSymbolProvider.toSymbol(shape.value) - } + val valueMemberSymbol = + if (valueShape.isTransitivelyButNotDirectlyConstrained(model, symbolProvider)) { + pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.value) + } else { + constrainedShapeSymbolProvider.toSymbol(shape.value) + } - val codegenScope = arrayOf( - "KeySymbol" to keySymbol, - "ValueMemberSymbol" to valueMemberSymbol, - "ConstrainedTrait" to RuntimeType.ConstrainedTrait, - "UnconstrainedSymbol" to unconstrainedSymbol, - "Symbol" to symbol, - "From" to RuntimeType.From, - ) + val codegenScope = + arrayOf( + "KeySymbol" to keySymbol, + "ValueMemberSymbol" to valueMemberSymbol, + "ConstrainedTrait" to RuntimeType.ConstrainedTrait, + "UnconstrainedSymbol" to unconstrainedSymbol, + "Symbol" to symbol, + "From" to RuntimeType.From, + ) inlineModuleCreator(constrainedSymbol) { rustTemplate( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ScopeMacroGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ScopeMacroGenerator.kt index 7c3acff6ab2..1329c2b7cdb 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ScopeMacroGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ScopeMacroGenerator.kt @@ -27,153 +27,158 @@ class ScopeMacroGenerator( private val index = TopDownIndex.of(codegenContext.model) private val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet(compareBy { it.id }) - private fun macro(): Writable = writable { - val firstOperationName = codegenContext.symbolProvider.toSymbol(operations.first()).name.toPascalCase() - val operationNames = operations.joinToString(" ") { - codegenContext.symbolProvider.toSymbol(it).name.toPascalCase() - } + private fun macro(): Writable = + writable { + val firstOperationName = codegenContext.symbolProvider.toSymbol(operations.first()).name.toPascalCase() + val operationNames = + operations.joinToString(" ") { + codegenContext.symbolProvider.toSymbol(it).name.toPascalCase() + } - // When writing `macro_rules!` we add whitespace between `$` and the arguments to avoid Kotlin templating. + // When writing `macro_rules!` we add whitespace between `$` and the arguments to avoid Kotlin templating. - // To acheive the desired API we need to calculate the set theoretic complement `B \ A`. - // The macro below, for rules prefixed with `@`, encodes a state machine which performs this. - // The initial state is `(A) () (B)`, where `A` and `B` are lists of elements of `A` and `B`. - // The rules, in order: - // - Terminate on pattern `() (t0, t1, ...) (b0, b1, ...)`, the complement has been calculated as - // `{ t0, t1, ..., b0, b1, ...}`. - // - Send pattern `(x, a0, a1, ...) (t0, t1, ...) (x, b0, b1, ...)` to - // `(a0, a1, ...) (t0, t1, ...) (b0, b1, ...)`, eliminating a matching `x` from `A` and `B`. - // - Send pattern `(a0, a1, ...) (t0, t1, ...) ()` to `(a0, a1, ...) () (t0, t1, ...)`, restarting the search. - // - Send pattern `(a0, a1, ...) (t0, t1, ...) (b0, b1, ...)` to `(a0, a1, ...) (b0, t0, t1, ...) (b1, ...)`, - // iterating through the `B`. - val operationBranches = operations - .map { codegenContext.symbolProvider.toSymbol(it).name.toPascalCase() }.joinToString("") { - """ - // $it match found, pop from both `member` and `not_member` - (@ $ name: ident, $ contains: ident ($it $($ member: ident)*) ($($ temp: ident)*) ($it $($ not_member: ident)*)) => { - scope! { @ $ name, $ contains ($($ member)*) ($($ temp)*) ($($ not_member)*) } - }; - // $it match not found, pop from `not_member` into `temp` stack - (@ $ name: ident, $ contains: ident ($it $($ member: ident)*) ($($ temp: ident)*) ($ other: ident $($ not_member: ident)*)) => { - scope! { @ $ name, $ contains ($it $($ member)*) ($ other $($ temp)*) ($($ not_member)*) } - }; - """ - } - val crateName = codegenContext.moduleUseName() + // To acheive the desired API we need to calculate the set theoretic complement `B \ A`. + // The macro below, for rules prefixed with `@`, encodes a state machine which performs this. + // The initial state is `(A) () (B)`, where `A` and `B` are lists of elements of `A` and `B`. + // The rules, in order: + // - Terminate on pattern `() (t0, t1, ...) (b0, b1, ...)`, the complement has been calculated as + // `{ t0, t1, ..., b0, b1, ...}`. + // - Send pattern `(x, a0, a1, ...) (t0, t1, ...) (x, b0, b1, ...)` to + // `(a0, a1, ...) (t0, t1, ...) (b0, b1, ...)`, eliminating a matching `x` from `A` and `B`. + // - Send pattern `(a0, a1, ...) (t0, t1, ...) ()` to `(a0, a1, ...) () (t0, t1, ...)`, restarting the search. + // - Send pattern `(a0, a1, ...) (t0, t1, ...) (b0, b1, ...)` to `(a0, a1, ...) (b0, t0, t1, ...) (b1, ...)`, + // iterating through the `B`. + val operationBranches = + operations + .map { codegenContext.symbolProvider.toSymbol(it).name.toPascalCase() }.joinToString("") { + """ + // $it match found, pop from both `member` and `not_member` + (@ $ name: ident, $ contains: ident ($it $($ member: ident)*) ($($ temp: ident)*) ($it $($ not_member: ident)*)) => { + scope! { @ $ name, $ contains ($($ member)*) ($($ temp)*) ($($ not_member)*) } + }; + // $it match not found, pop from `not_member` into `temp` stack + (@ $ name: ident, $ contains: ident ($it $($ member: ident)*) ($($ temp: ident)*) ($ other: ident $($ not_member: ident)*)) => { + scope! { @ $ name, $ contains ($it $($ member)*) ($ other $($ temp)*) ($($ not_member)*) } + }; + """ + } + val crateName = codegenContext.moduleUseName() - // If we have a second operation we can perform further checks - val otherOperationName: String? = operations.toList().getOrNull(1)?.let { - codegenContext.symbolProvider.toSymbol(it).name - } - val furtherTests = if (otherOperationName != null) { - writable { - rustTemplate( - """ - /// ## let a = Plugin::<(), $otherOperationName, u64>::apply(&scoped_a, 6); - /// ## let b = Plugin::<(), $otherOperationName, u64>::apply(&scoped_b, 6); - /// ## assert_eq!(a, 6_u64); - /// ## assert_eq!(b, 3_u32); - """, - ) - } - } else { - writable {} - } + // If we have a second operation we can perform further checks + val otherOperationName: String? = + operations.toList().getOrNull(1)?.let { + codegenContext.symbolProvider.toSymbol(it).name + } + val furtherTests = + if (otherOperationName != null) { + writable { + rustTemplate( + """ + /// ## let a = Plugin::<(), $otherOperationName, u64>::apply(&scoped_a, 6); + /// ## let b = Plugin::<(), $otherOperationName, u64>::apply(&scoped_b, 6); + /// ## assert_eq!(a, 6_u64); + /// ## assert_eq!(b, 3_u32); + """, + ) + } + } else { + writable {} + } - rustTemplate( - """ - /// A macro to help with scoping [plugins](#{SmithyHttpServer}::plugin) to a subset of all operations. - /// - /// In contrast to [`aws_smithy_http_server::scope`](#{SmithyHttpServer}::scope), this macro has knowledge - /// of the service and any operations _not_ specified will be placed in the opposing group. - /// - /// ## Example - /// - /// ```rust - /// scope! { - /// /// Includes [`$firstOperationName`], excluding all other operations. - /// struct ScopeA { - /// includes: [$firstOperationName] - /// } - /// } - /// - /// scope! { - /// /// Excludes [`$firstOperationName`], excluding all other operations. - /// struct ScopeB { - /// excludes: [$firstOperationName] - /// } - /// } - /// - /// ## use #{SmithyHttpServer}::plugin::{Plugin, Scoped}; - /// ## use $crateName::scope; - /// ## struct MockPlugin; - /// ## impl Plugin for MockPlugin { type Output = u32; fn apply(&self, input: T) -> u32 { 3 } } - /// ## let scoped_a = Scoped::new::(MockPlugin); - /// ## let scoped_b = Scoped::new::(MockPlugin); - /// ## let a = Plugin::<(), $crateName::operation_shape::$firstOperationName, u64>::apply(&scoped_a, 6); - /// ## let b = Plugin::<(), $crateName::operation_shape::$firstOperationName, u64>::apply(&scoped_b, 6); - /// ## assert_eq!(a, 3_u32); - /// ## assert_eq!(b, 6_u64); - /// ``` - ##[macro_export] - macro_rules! scope { - // Completed, render impls - (@ $ name: ident, $ contains: ident () ($($ temp: ident)*) ($($ not_member: ident)*)) => { - $( - impl #{SmithyHttpServer}::plugin::scoped::Membership<$ temp> for $ name { - type Contains = #{SmithyHttpServer}::plugin::scoped::$ contains; + rustTemplate( + """ + /// A macro to help with scoping [plugins](#{SmithyHttpServer}::plugin) to a subset of all operations. + /// + /// In contrast to [`aws_smithy_http_server::scope`](#{SmithyHttpServer}::scope), this macro has knowledge + /// of the service and any operations _not_ specified will be placed in the opposing group. + /// + /// ## Example + /// + /// ```rust + /// scope! { + /// /// Includes [`$firstOperationName`], excluding all other operations. + /// struct ScopeA { + /// includes: [$firstOperationName] + /// } + /// } + /// + /// scope! { + /// /// Excludes [`$firstOperationName`], excluding all other operations. + /// struct ScopeB { + /// excludes: [$firstOperationName] + /// } + /// } + /// + /// ## use #{SmithyHttpServer}::plugin::{Plugin, Scoped}; + /// ## use $crateName::scope; + /// ## struct MockPlugin; + /// ## impl Plugin for MockPlugin { type Output = u32; fn apply(&self, input: T) -> u32 { 3 } } + /// ## let scoped_a = Scoped::new::(MockPlugin); + /// ## let scoped_b = Scoped::new::(MockPlugin); + /// ## let a = Plugin::<(), $crateName::operation_shape::$firstOperationName, u64>::apply(&scoped_a, 6); + /// ## let b = Plugin::<(), $crateName::operation_shape::$firstOperationName, u64>::apply(&scoped_b, 6); + /// ## assert_eq!(a, 3_u32); + /// ## assert_eq!(b, 6_u64); + /// ``` + ##[macro_export] + macro_rules! scope { + // Completed, render impls + (@ $ name: ident, $ contains: ident () ($($ temp: ident)*) ($($ not_member: ident)*)) => { + $( + impl #{SmithyHttpServer}::plugin::scoped::Membership<$ temp> for $ name { + type Contains = #{SmithyHttpServer}::plugin::scoped::$ contains; + } + )* + $( + impl #{SmithyHttpServer}::plugin::scoped::Membership<$ not_member> for $ name { + type Contains = #{SmithyHttpServer}::plugin::scoped::$ contains; + } + )* + }; + // All `not_member`s exhausted, move `temp` into `not_member` + (@ $ name: ident, $ contains: ident ($($ member: ident)*) ($($ temp: ident)*) ()) => { + scope! { @ $ name, $ contains ($($ member)*) () ($($ temp)*) } + }; + $operationBranches + ( + $(##[$ attrs:meta])* + $ vis:vis struct $ name:ident { + includes: [$($ include:ident),*] } - )* - $( - impl #{SmithyHttpServer}::plugin::scoped::Membership<$ not_member> for $ name { - type Contains = #{SmithyHttpServer}::plugin::scoped::$ contains; + ) => { + use $ crate::operation_shape::*; + #{SmithyHttpServer}::scope! { + $(##[$ attrs])* + $ vis struct $ name { + includes: [$($ include),*], + excludes: [] + } } - )* - }; - // All `not_member`s exhausted, move `temp` into `not_member` - (@ $ name: ident, $ contains: ident ($($ member: ident)*) ($($ temp: ident)*) ()) => { - scope! { @ $ name, $ contains ($($ member)*) () ($($ temp)*) } - }; - $operationBranches - ( - $(##[$ attrs:meta])* - $ vis:vis struct $ name:ident { - includes: [$($ include:ident),*] - } - ) => { - use $ crate::operation_shape::*; - #{SmithyHttpServer}::scope! { - $(##[$ attrs])* - $ vis struct $ name { - includes: [$($ include),*], - excludes: [] + scope! { @ $ name, False ($($ include)*) () ($operationNames) } + }; + ( + $(##[$ attrs:meta])* + $ vis:vis struct $ name:ident { + excludes: [$($ exclude:ident),*] } - } - scope! { @ $ name, False ($($ include)*) () ($operationNames) } - }; - ( - $(##[$ attrs:meta])* - $ vis:vis struct $ name:ident { - excludes: [$($ exclude:ident),*] - } - ) => { - use $ crate::operation_shape::*; + ) => { + use $ crate::operation_shape::*; - #{SmithyHttpServer}::scope! { - $(##[$ attrs])* - $ vis struct $ name { - includes: [], - excludes: [$($ exclude),*] + #{SmithyHttpServer}::scope! { + $(##[$ attrs])* + $ vis struct $ name { + includes: [], + excludes: [$($ exclude),*] + } } - } - scope! { @ $ name, True ($($ exclude)*) () ($operationNames) } - }; - } - """, - *codegenScope, - "FurtherTests" to furtherTests, - ) - } + scope! { @ $ name, True ($($ exclude)*) () ($operationNames) } + }; + } + """, + *codegenScope, + "FurtherTests" to furtherTests, + ) + } fun render(writer: RustWriter) { macro()(writer) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt index d91176c0f8f..3e9595ca638 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt @@ -50,12 +50,13 @@ class ServerBuilderConstraintViolations( } } private val members: List = shape.allMembers.values.toList() - val all = members.flatMap { member -> - listOfNotNull( - forMember(member), - builderConstraintViolationForMember(member), - ) - } + val all = + members.flatMap { member -> + listOfNotNull( + forMember(member), + builderConstraintViolationForMember(member), + ) + } fun render( writer: RustWriter, @@ -117,11 +118,12 @@ class ServerBuilderConstraintViolations( rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") { rustBlock("match self") { all.forEach { - val arm = if (it.hasInner()) { - "ConstraintViolation::${it.name()}(_)" - } else { - "ConstraintViolation::${it.name()}" - } + val arm = + if (it.hasInner()) { + "ConstraintViolation::${it.name()}(_)" + } else { + "ConstraintViolation::${it.name()}" + } rust("""$arm => write!(f, "${it.message(symbolProvider, model)}"),""") } } @@ -192,10 +194,11 @@ enum class ConstraintViolationKind { } data class ConstraintViolation(val forMember: MemberShape, val kind: ConstraintViolationKind) { - fun name() = when (kind) { - ConstraintViolationKind.MISSING_MEMBER -> "Missing${forMember.memberName.toPascalCase()}" - ConstraintViolationKind.CONSTRAINED_SHAPE_FAILURE -> forMember.memberName.toPascalCase() - } + fun name() = + when (kind) { + ConstraintViolationKind.MISSING_MEMBER -> "Missing${forMember.memberName.toPascalCase()}" + ConstraintViolationKind.CONSTRAINED_SHAPE_FAILURE -> forMember.memberName.toPascalCase() + } /** * Whether the constraint violation is a Rust tuple struct with one element. @@ -205,7 +208,10 @@ data class ConstraintViolation(val forMember: MemberShape, val kind: ConstraintV /** * A message for a `ConstraintViolation` variant. This is used in both Rust documentation and the `Display` trait implementation. */ - fun message(symbolProvider: SymbolProvider, model: Model): String { + fun message( + symbolProvider: SymbolProvider, + model: Model, + ): String { val memberName = symbolProvider.toMemberName(forMember) val structureSymbol = symbolProvider.toSymbol(model.expectShape(forMember.container)) return when (kind) { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt index b8c74b9c5c1..3b9ca3df3a4 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt @@ -107,24 +107,28 @@ class ServerBuilderGenerator( takeInUnconstrainedTypes: Boolean, ): Boolean { val members = structureShape.members() + fun isOptional(member: MemberShape) = symbolProvider.toSymbol(member).isOptional() + fun hasDefault(member: MemberShape) = member.hasNonNullDefault() + fun isNotConstrained(member: MemberShape) = !member.canReachConstrainedShape(model, symbolProvider) - val notFallible = members.all { - if (structureShape.isReachableFromOperationInput()) { - // When deserializing an input structure, constraints might not be satisfied by the data in the - // incoming request. - // For this builder not to be fallible, no members must be constrained (constraints in input must - // always be checked) and all members must _either_ be optional (no need to set it; not required) - // or have a default value. - isNotConstrained(it) && (isOptional(it) || hasDefault(it)) - } else { - // This structure will be constructed manually by the user. - // Constraints will have to be dealt with before members are set in the builder. - isOptional(it) || hasDefault(it) + val notFallible = + members.all { + if (structureShape.isReachableFromOperationInput()) { + // When deserializing an input structure, constraints might not be satisfied by the data in the + // incoming request. + // For this builder not to be fallible, no members must be constrained (constraints in input must + // always be checked) and all members must _either_ be optional (no need to set it; not required) + // or have a default value. + isNotConstrained(it) && (isOptional(it) || hasDefault(it)) + } else { + // This structure will be constructed manually by the user. + // Constraints will have to be dealt with before members are set in the builder. + isOptional(it) || hasDefault(it) + } } - } return if (takeInUnconstrainedTypes) { !notFallible && structureShape.canReachConstrainedShape(model, symbolProvider) @@ -150,15 +154,19 @@ class ServerBuilderGenerator( ServerBuilderConstraintViolations(codegenContext, shape, takeInUnconstrainedTypes, customValidationExceptionWithReasonConversionGenerator) private val lifetime = shape.lifetimeDeclaration(symbolProvider) - private val codegenScope = arrayOf( - "RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig), - "Structure" to structureSymbol, - "From" to RuntimeType.From, - "TryFrom" to RuntimeType.TryFrom, - "MaybeConstrained" to RuntimeType.MaybeConstrained, - ) + private val codegenScope = + arrayOf( + "RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig), + "Structure" to structureSymbol, + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + "MaybeConstrained" to RuntimeType.MaybeConstrained, + ) - fun render(rustCrate: RustCrate, writer: RustWriter) { + fun render( + rustCrate: RustCrate, + writer: RustWriter, + ) { val docWriter: () -> Unit = { writer.docs("See #D.", structureSymbol) } rustCrate.withInMemoryInlineModule(writer, builderSymbol.module(), docWriter) { renderBuilder(this) @@ -194,9 +202,10 @@ class ServerBuilderGenerator( // since we are a builder and everything is optional. val baseDerives = structureSymbol.expectRustMetadata().derives // Filter out any derive that isn't Debug or Clone. Then add a Default derive - val builderDerives = baseDerives.filter { - it == RuntimeType.Debug || it == RuntimeType.Clone - } + RuntimeType.Default + val builderDerives = + baseDerives.filter { + it == RuntimeType.Debug || it == RuntimeType.Clone + } + RuntimeType.Default Attribute(derive(builderDerives)).render(writer) writer.rustBlock("${visibility.toRustQualifier()} struct Builder$lifetime") { members.forEach { renderBuilderMember(this, it) } @@ -287,7 +296,10 @@ class ServerBuilderGenerator( } } - private fun renderBuilderMember(writer: RustWriter, member: MemberShape) { + private fun renderBuilderMember( + writer: RustWriter, + member: MemberShape, + ) { val memberSymbol = builderMemberSymbol(member) val memberName = constrainedShapeSymbolProvider.toMemberName(member) // Builder members are crate-public to enable using them directly in serializers/deserializers. @@ -383,14 +395,15 @@ class ServerBuilderGenerator( member: MemberShape, ) { val builderMemberSymbol = builderMemberSymbol(member) - val inputType = builderMemberSymbol.rustType().stripOuter().implInto() - .letIf( - // TODO(https://github.com/smithy-lang/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): - // The only reason why this condition can't simply be `member.isOptional` - // is because non-`required` blob streaming members are interpreted as - // `required`, so we can't use `member.isOptional` here. - symbolProvider.toSymbol(member).isOptional(), - ) { "Option<$it>" } + val inputType = + builderMemberSymbol.rustType().stripOuter().implInto() + .letIf( + // TODO(https://github.com/smithy-lang/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): + // The only reason why this condition can't simply be `member.isOptional` + // is because non-`required` blob streaming members are interpreted as + // `required`, so we can't use `member.isOptional` here. + symbolProvider.toSymbol(member).isOptional(), + ) { "Option<$it>" } val memberName = symbolProvider.toMemberName(member) writer.documentShape(member, model) @@ -463,13 +476,14 @@ class ServerBuilderGenerator( */ private fun builderMemberSymbol(member: MemberShape): Symbol = if (takeInUnconstrainedTypes && member.targetCanReachConstrainedShape(model, symbolProvider)) { - val strippedOption = if (member.hasConstraintTraitOrTargetHasConstraintTrait(model, symbolProvider)) { - constrainedShapeSymbolProvider.toSymbol(member) - } else { - pubCrateConstrainedShapeSymbolProvider.toSymbol(member) - } - // Strip the `Option` in case the member is not `required`. - .mapRustType { it.stripOuter() } + val strippedOption = + if (member.hasConstraintTraitOrTargetHasConstraintTrait(model, symbolProvider)) { + constrainedShapeSymbolProvider.toSymbol(member) + } else { + pubCrateConstrainedShapeSymbolProvider.toSymbol(member) + } + // Strip the `Option` in case the member is not `required`. + .mapRustType { it.stripOuter() } val hadBox = strippedOption.isRustBoxed() strippedOption @@ -543,13 +557,18 @@ class ServerBuilderGenerator( } } - private fun enforceConstraints(writer: RustWriter, member: MemberShape, constraintViolation: ConstraintViolation) { + private fun enforceConstraints( + writer: RustWriter, + member: MemberShape, + constraintViolation: ConstraintViolation, + ) { // This member is constrained. Enforce the constraint traits on the value set in the builder. // The code is slightly different in case the member is recursive, since it will be wrapped in // `std::boxed::Box`. - val hasBox = builderMemberSymbol(member) - .mapRustType { it.stripOuter() } - .isRustBoxed() + val hasBox = + builderMemberSymbol(member) + .mapRustType { it.stripOuter() } + .isRustBoxed() val errHasBox = member.hasTrait() if (hasBox) { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorCommon.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorCommon.kt index 4912d9eb314..b40743524bd 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorCommon.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorCommon.kt @@ -47,7 +47,7 @@ import software.amazon.smithy.rust.codegen.core.util.isStreaming import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.hasPublicConstrainedWrapperTupleType -/** +/* * Some common freestanding functions shared across: * - [ServerBuilderGenerator]; and * - [ServerBuilderGeneratorWithoutPublicConstrainedTypes], @@ -57,7 +57,11 @@ import software.amazon.smithy.rust.codegen.server.smithy.hasPublicConstrainedWra /** * Returns a writable to render the return type of the server builders' `build()` method. */ -fun buildFnReturnType(isBuilderFallible: Boolean, structureSymbol: Symbol, lifetime: String) = writable { +fun buildFnReturnType( + isBuilderFallible: Boolean, + structureSymbol: Symbol, + lifetime: String, +) = writable { if (isBuilderFallible) { rust("Result<#T $lifetime, ConstraintViolation>", structureSymbol) } else { @@ -128,29 +132,33 @@ fun defaultValue( CodegenException("Default value $node for member shape ${member.id} is unsupported or cannot exist; please file a bug report under https://github.com/smithy-lang/smithy-rs/issues") when (val target = model.expectShape(member.target)) { is EnumShape, is IntEnumShape -> { - val value = when (target) { - is IntEnumShape -> node.expectNumberNode().value - is EnumShape -> node.expectStringNode().value - else -> throw CodegenException("Default value for shape ${target.id} must be of EnumShape or IntEnumShape") - } - val enumValues = when (target) { - is IntEnumShape -> target.enumValues - is EnumShape -> target.enumValues - else -> UNREACHABLE( - "Target shape ${target.id} must be an `EnumShape` or an `IntEnumShape` at this point, otherwise it would have failed above", - ) - } - val variant = enumValues - .entries - .filter { entry -> entry.value == value } - .map { entry -> - EnumMemberModel.toEnumVariantName( - symbolProvider, - target, - EnumDefinition.builder().name(entry.key).value(entry.value.toString()).build(), - )!! + val value = + when (target) { + is IntEnumShape -> node.expectNumberNode().value + is EnumShape -> node.expectStringNode().value + else -> throw CodegenException("Default value for shape ${target.id} must be of EnumShape or IntEnumShape") } - .first() + val enumValues = + when (target) { + is IntEnumShape -> target.enumValues + is EnumShape -> target.enumValues + else -> + UNREACHABLE( + "Target shape ${target.id} must be an `EnumShape` or an `IntEnumShape` at this point, otherwise it would have failed above", + ) + } + val variant = + enumValues + .entries + .filter { entry -> entry.value == value } + .map { entry -> + EnumMemberModel.toEnumVariantName( + symbolProvider, + target, + EnumDefinition.builder().name(entry.key).value(entry.value.toString()).build(), + )!! + } + .first() rust("#T::${variant.name}", symbolProvider.toSymbol(target)) } @@ -162,20 +170,21 @@ fun defaultValue( is DoubleShape -> rust(node.expectNumberNode().value.toDouble().toString() + "f64") is BooleanShape -> rust(node.expectBooleanNode().value.toString()) is StringShape -> rust("String::from(${node.expectStringNode().value.dq()})") - is TimestampShape -> when (node) { - is NumberNode -> rust(node.expectNumberNode().value.toString()) - is StringNode -> { - val value = node.expectStringNode().value - rustTemplate( - """ - #{SmithyTypes}::DateTime::from_str("$value", #{SmithyTypes}::date_time::Format::DateTime) - .expect("default value `$value` cannot be parsed into a valid date time; please file a bug report under https://github.com/smithy-lang/smithy-rs/issues") - """, - "SmithyTypes" to types, - ) + is TimestampShape -> + when (node) { + is NumberNode -> rust(node.expectNumberNode().value.toString()) + is StringNode -> { + val value = node.expectStringNode().value + rustTemplate( + """ + #{SmithyTypes}::DateTime::from_str("$value", #{SmithyTypes}::date_time::Format::DateTime) + .expect("default value `$value` cannot be parsed into a valid date time; please file a bug report under https://github.com/smithy-lang/smithy-rs/issues") + """, + "SmithyTypes" to types, + ) + } + else -> throw unsupportedDefaultValueException } - else -> throw unsupportedDefaultValueException - } is ListShape -> { check(node is ArrayNode && node.isEmpty) rust("Vec::new()") @@ -186,19 +195,21 @@ fun defaultValue( } is DocumentShape -> { when (node) { - is NullNode -> rustTemplate( - "#{SmithyTypes}::Document::Null", - "SmithyTypes" to types, - ) + is NullNode -> + rustTemplate( + "#{SmithyTypes}::Document::Null", + "SmithyTypes" to types, + ) is BooleanNode -> rustTemplate("""#{SmithyTypes}::Document::Bool(${node.value})""", "SmithyTypes" to types) is StringNode -> rustTemplate("#{SmithyTypes}::Document::String(String::from(${node.value.dq()}))", "SmithyTypes" to types) is NumberNode -> { val value = node.value.toString() - val variant = when (node.value) { - is Float, is Double -> "Float" - else -> if (node.value.toLong() >= 0) "PosInt" else "NegInt" - } + val variant = + when (node.value) { + is Float, is Double -> "Float" + else -> if (node.value.toLong() >= 0) "PosInt" else "NegInt" + } rustTemplate( "#{SmithyTypes}::Document::Number(#{SmithyTypes}::Number::$variant($value))", "SmithyTypes" to types, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt index 7589a761c08..0e27fb0b493 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt @@ -64,12 +64,15 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes( symbolProvider: SymbolProvider, ): Boolean { val members = structureShape.members() + fun isOptional(member: MemberShape) = symbolProvider.toSymbol(member).isOptional() + fun hasDefault(member: MemberShape) = member.hasNonNullDefault() - val notFallible = members.all { - isOptional(it) || hasDefault(it) - } + val notFallible = + members.all { + isOptional(it) || hasDefault(it) + } return !notFallible } @@ -87,15 +90,19 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes( ServerBuilderConstraintViolations(codegenContext, shape, builderTakesInUnconstrainedTypes = false, validationExceptionConversionGenerator) private val lifetime = shape.lifetimeDeclaration(symbolProvider) - private val codegenScope = arrayOf( - "RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig), - "Structure" to structureSymbol, - "From" to RuntimeType.From, - "TryFrom" to RuntimeType.TryFrom, - "MaybeConstrained" to RuntimeType.MaybeConstrained, - ) + private val codegenScope = + arrayOf( + "RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig), + "Structure" to structureSymbol, + "From" to RuntimeType.From, + "TryFrom" to RuntimeType.TryFrom, + "MaybeConstrained" to RuntimeType.MaybeConstrained, + ) - fun render(rustCrate: RustCrate, writer: RustWriter) { + fun render( + rustCrate: RustCrate, + writer: RustWriter, + ) { check(!codegenContext.settings.codegenConfig.publicConstrainedTypes) { "ServerBuilderGeneratorWithoutPublicConstrainedTypes should only be used when `publicConstrainedTypes` is false" } @@ -206,7 +213,10 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes( } } - private fun renderBuilderMember(writer: RustWriter, member: MemberShape) { + private fun renderBuilderMember( + writer: RustWriter, + member: MemberShape, + ) { val memberSymbol = builderMemberSymbol(member) val memberName = symbolProvider.toMemberName(member) // Builder members are crate-public to enable using them directly in serializers/deserializers. @@ -220,7 +230,10 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes( * * This method is meant for use by the user; it is not used by the generated crate's (de)serializers. */ - private fun renderBuilderMemberFn(writer: RustWriter, member: MemberShape) { + private fun renderBuilderMemberFn( + writer: RustWriter, + member: MemberShape, + ) { val memberSymbol = symbolProvider.toSymbol(member) val memberName = symbolProvider.toMemberName(member) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt index f0ca507db16..9e79f918c64 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt @@ -24,18 +24,23 @@ fun StructureShape.serverBuilderSymbol(codegenContext: ServerCodegenContext): Sy ) // TODO(https://github.com/smithy-lang/smithy-rs/issues/2396): Replace this with `RustSymbolProvider.moduleForBuilder` -fun StructureShape.serverBuilderModule(symbolProvider: SymbolProvider, pubCrate: Boolean): RustModule.LeafModule { +fun StructureShape.serverBuilderModule( + symbolProvider: SymbolProvider, + pubCrate: Boolean, +): RustModule.LeafModule { val structureSymbol = symbolProvider.toSymbol(this) - val builderNamespace = RustReservedWords.escapeIfNeeded(structureSymbol.name.toSnakeCase()) + - if (pubCrate) { - "_internal" - } else { - "" + val builderNamespace = + RustReservedWords.escapeIfNeeded(structureSymbol.name.toSnakeCase()) + + if (pubCrate) { + "_internal" + } else { + "" + } + val visibility = + when (pubCrate) { + true -> Visibility.PUBCRATE + false -> Visibility.PUBLIC } - val visibility = when (pubCrate) { - true -> Visibility.PUBCRATE - false -> Visibility.PUBLIC - } return RustModule.new( builderNamespace, visibility, @@ -46,7 +51,10 @@ fun StructureShape.serverBuilderModule(symbolProvider: SymbolProvider, pubCrate: } // TODO(https://github.com/smithy-lang/smithy-rs/issues/2396): Replace this with `RustSymbolProvider.symbolForBuilder` -fun StructureShape.serverBuilderSymbol(symbolProvider: SymbolProvider, pubCrate: Boolean): Symbol { +fun StructureShape.serverBuilderSymbol( + symbolProvider: SymbolProvider, + pubCrate: Boolean, +): Symbol { val builderModule = serverBuilderModule(symbolProvider, pubCrate) val rustType = RustType.Opaque("Builder", builderModule.fullyQualifiedPath()) return Symbol.builder() diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt index 09a0d2d5cd7..354cbc0f662 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt @@ -38,71 +38,75 @@ open class ConstrainedEnum( } private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) private val constraintViolationName = constraintViolationSymbol.name - private val codegenScope = arrayOf( - "String" to RuntimeType.String, - ) - - override fun implFromForStr(context: EnumGeneratorContext): Writable = writable { - withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) { - rustTemplate( - """ - ##[derive(Debug, PartialEq)] - pub struct $constraintViolationName(pub(crate) #{String}); - """, - *codegenScope, - ) + private val codegenScope = + arrayOf( + "String" to RuntimeType.String, + ) - if (shape.isReachableFromOperationInput()) { + override fun implFromForStr(context: EnumGeneratorContext): Writable = + writable { + withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) { rustTemplate( """ - impl $constraintViolationName { - #{EnumShapeConstraintViolationImplBlock:W} - } + ##[derive(Debug, PartialEq)] + pub struct $constraintViolationName(pub(crate) #{String}); """, - "EnumShapeConstraintViolationImplBlock" to validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock( - context.enumTrait, - ), + *codegenScope, ) + + if (shape.isReachableFromOperationInput()) { + rustTemplate( + """ + impl $constraintViolationName { + #{EnumShapeConstraintViolationImplBlock:W} + } + """, + "EnumShapeConstraintViolationImplBlock" to + validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock( + context.enumTrait, + ), + ) + } } - } - rustBlock("impl #T<&str> for ${context.enumName}", RuntimeType.TryFrom) { - rust("type Error = #T;", constraintViolationSymbol) - rustBlockTemplate("fn try_from(s: &str) -> #{Result}>::Error>", *preludeScope) { - rustBlock("match s") { - context.sortedMembers.forEach { member -> - rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),") + rustBlock("impl #T<&str> for ${context.enumName}", RuntimeType.TryFrom) { + rust("type Error = #T;", constraintViolationSymbol) + rustBlockTemplate("fn try_from(s: &str) -> #{Result}>::Error>", *preludeScope) { + rustBlock("match s") { + context.sortedMembers.forEach { member -> + rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),") + } + rust("_ => Err(#T(s.to_owned()))", constraintViolationSymbol) } - rust("_ => Err(#T(s.to_owned()))", constraintViolationSymbol) } } - } - rustTemplate( - """ - impl #{TryFrom}<#{String}> for ${context.enumName} { - type Error = #{ConstraintViolation}; - fn try_from(s: #{String}) -> #{Result}>::Error> { - s.as_str().try_into() + rustTemplate( + """ + impl #{TryFrom}<#{String}> for ${context.enumName} { + type Error = #{ConstraintViolation}; + fn try_from(s: #{String}) -> #{Result}>::Error> { + s.as_str().try_into() + } } - } - """, - *preludeScope, - "ConstraintViolation" to constraintViolationSymbol, - ) - } + """, + *preludeScope, + "ConstraintViolation" to constraintViolationSymbol, + ) + } - override fun implFromStr(context: EnumGeneratorContext): Writable = writable { - rustTemplate( - """ - impl std::str::FromStr for ${context.enumName} { - type Err = #{ConstraintViolation}; - fn from_str(s: &str) -> std::result::Result::Err> { - Self::try_from(s) + override fun implFromStr(context: EnumGeneratorContext): Writable = + writable { + rustTemplate( + """ + impl std::str::FromStr for ${context.enumName} { + type Err = #{ConstraintViolation}; + fn from_str(s: &str) -> std::result::Result::Err> { + Self::try_from(s) + } } - } - """, - "ConstraintViolation" to constraintViolationSymbol, - ) - } + """, + "ConstraintViolation" to constraintViolationSymbol, + ) + } } class ServerEnumGenerator( @@ -110,8 +114,8 @@ class ServerEnumGenerator( shape: StringShape, validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, ) : EnumGenerator( - codegenContext.model, - codegenContext.symbolProvider, - shape, - enumType = ConstrainedEnum(codegenContext, shape, validationExceptionConversionGenerator), -) + codegenContext.model, + codegenContext.symbolProvider, + shape, + enumType = ConstrainedEnum(codegenContext, shape, validationExceptionConversionGenerator), + ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGenerator.kt index 0347013de5c..468c59547b0 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGenerator.kt @@ -33,25 +33,28 @@ import java.util.Optional /** Models the ways status codes can be bound and sensitive. */ class StatusCodeSensitivity(private val sensitive: Boolean, runtimeConfig: RuntimeConfig) { - private val codegenScope = arrayOf( - "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), - ) + private val codegenScope = + arrayOf( + "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), + ) /** Returns the type of the `MakeFmt`. */ - fun type(): Writable = writable { - if (sensitive) { - rustTemplate("#{SmithyHttpServer}::instrumentation::MakeSensitive", *codegenScope) - } else { - rustTemplate("#{SmithyHttpServer}::instrumentation::MakeIdentity", *codegenScope) + fun type(): Writable = + writable { + if (sensitive) { + rustTemplate("#{SmithyHttpServer}::instrumentation::MakeSensitive", *codegenScope) + } else { + rustTemplate("#{SmithyHttpServer}::instrumentation::MakeIdentity", *codegenScope) + } } - } /** Returns the setter. */ - fun setter(): Writable = writable { - if (sensitive) { - rust(".status_code()") + fun setter(): Writable = + writable { + if (sensitive) { + rust(".status_code()") + } } - } } /** Represents the information needed to specify the position of a greedy label. */ @@ -68,54 +71,59 @@ class LabelSensitivity(internal val labelIndexes: List, internal val greedy arrayOf("SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType()) /** Returns the closure used during construction. */ - fun closure(): Writable = writable { - if (labelIndexes.isNotEmpty()) { - rustTemplate( - """ - { - |index: usize| matches!(index, ${labelIndexes.joinToString("|")}) - } as fn(usize) -> bool - """, - *codegenScope, - ) - } else { - rust("{ |_index: usize| false } as fn(usize) -> bool") + fun closure(): Writable = + writable { + if (labelIndexes.isNotEmpty()) { + rustTemplate( + """ + { + |index: usize| matches!(index, ${labelIndexes.joinToString("|")}) + } as fn(usize) -> bool + """, + *codegenScope, + ) + } else { + rust("{ |_index: usize| false } as fn(usize) -> bool") + } } - } + private fun hasRedactions(): Boolean = labelIndexes.isNotEmpty() || greedyLabel != null /** Returns the type of the `MakeFmt`. */ - fun type(): Writable = if (hasRedactions()) { - writable { - rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::uri::MakeLabel bool>", *codegenScope) - } - } else { - writable { - rustTemplate("#{SmithyHttpServer}::instrumentation::MakeIdentity", *codegenScope) + fun type(): Writable = + if (hasRedactions()) { + writable { + rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::uri::MakeLabel bool>", *codegenScope) + } + } else { + writable { + rustTemplate("#{SmithyHttpServer}::instrumentation::MakeIdentity", *codegenScope) + } } - } /** Returns the value of the `GreedyLabel`. */ - private fun greedyLabelStruct(): Writable = writable { - if (greedyLabel != null) { - rustTemplate( - """ + private fun greedyLabelStruct(): Writable = + writable { + if (greedyLabel != null) { + rustTemplate( + """ Some(#{SmithyHttpServer}::instrumentation::sensitivity::uri::GreedyLabel::new(${greedyLabel.segmentIndex}, ${greedyLabel.endOffset}))""", - *codegenScope, - ) - } else { - rust("None") + *codegenScope, + ) + } else { + rust("None") + } } - } /** Returns the setter enclosing the closure or suffix position. */ - fun setter(): Writable = if (hasRedactions()) { - writable { - rustTemplate(".label(#{Closure:W}, #{GreedyLabel:W})", "Closure" to closure(), "GreedyLabel" to greedyLabelStruct()) + fun setter(): Writable = + if (hasRedactions()) { + writable { + rustTemplate(".label(#{Closure:W}, #{GreedyLabel:W})", "Closure" to closure(), "GreedyLabel" to greedyLabelStruct()) + } + } else { + writable { } } - } else { - writable { } - } } /** Models the ways headers can be bound and sensitive */ @@ -124,10 +132,11 @@ sealed class HeaderSensitivity( val headerKeys: List, runtimeConfig: RuntimeConfig, ) { - private val codegenScope = arrayOf( - "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), - "Http" to CargoDependency.Http.toType(), - ) + private val codegenScope = + arrayOf( + "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), + "Http" to CargoDependency.Http.toType(), + ) /** The case where `prefixHeaders` value is not sensitive. */ class NotSensitiveMapValue( @@ -148,55 +157,62 @@ sealed class HeaderSensitivity( ) : HeaderSensitivity(headerKeys, runtimeConfig) /** Is there anything to redact? */ - internal fun hasRedactions(): Boolean = headerKeys.isNotEmpty() || when (this) { - is NotSensitiveMapValue -> prefixHeader != null - is SensitiveMapValue -> true - } + internal fun hasRedactions(): Boolean = + headerKeys.isNotEmpty() || + when (this) { + is NotSensitiveMapValue -> prefixHeader != null + is SensitiveMapValue -> true + } /** Returns the type of the `MakeDebug`. */ - fun type(): Writable = writable { - if (hasRedactions()) { - rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::headers::MakeHeaders #{SmithyHttpServer}::instrumentation::sensitivity::headers::HeaderMarker>", *codegenScope) - } else { - rustTemplate("#{SmithyHttpServer}::instrumentation::MakeIdentity", *codegenScope) + fun type(): Writable = + writable { + if (hasRedactions()) { + rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::headers::MakeHeaders #{SmithyHttpServer}::instrumentation::sensitivity::headers::HeaderMarker>", *codegenScope) + } else { + rustTemplate("#{SmithyHttpServer}::instrumentation::MakeIdentity", *codegenScope) + } } - } /** Returns the closure used during construction. */ internal fun closure(): Writable { - val nameMatch = if (headerKeys.isEmpty()) { - writable { - rust("false") - } - } else { - writable { - val matches = headerKeys.joinToString("|") { it.dq() } - rust("matches!(name.as_str(), $matches)") + val nameMatch = + if (headerKeys.isEmpty()) { + writable { + rust("false") + } + } else { + writable { + val matches = headerKeys.joinToString("|") { it.dq() } + rust("matches!(name.as_str(), $matches)") + } } - } - val suffixAndValue = when (this) { - is NotSensitiveMapValue -> writable { - prefixHeader?.let { - rust( - """ - let starts_with = name.as_str().starts_with("$it"); - let key_suffix = if starts_with { Some(${it.length}) } else { None }; - """, - ) - } ?: rust("let key_suffix = None;") - rust("let value = name_match;") - } - is SensitiveMapValue -> writable { - rust("let starts_with = name.as_str().starts_with(${prefixHeader.dq()});") - if (keySensitive) { - rust("let key_suffix = if starts_with { Some(${prefixHeader.length}) } else { None };") - } else { - rust("let key_suffix = None;") - } - rust("let value = name_match || starts_with;") + val suffixAndValue = + when (this) { + is NotSensitiveMapValue -> + writable { + prefixHeader?.let { + rust( + """ + let starts_with = name.as_str().starts_with("$it"); + let key_suffix = if starts_with { Some(${it.length}) } else { None }; + """, + ) + } ?: rust("let key_suffix = None;") + rust("let value = name_match;") + } + is SensitiveMapValue -> + writable { + rust("let starts_with = name.as_str().starts_with(${prefixHeader.dq()});") + if (keySensitive) { + rust("let key_suffix = if starts_with { Some(${prefixHeader.length}) } else { None };") + } else { + rust("let key_suffix = None;") + } + rust("let value = name_match || starts_with;") + } } - } return writable { rustTemplate( @@ -217,11 +233,12 @@ sealed class HeaderSensitivity( } /** Returns the setter enclosing the closure. */ - fun setter(): Writable = writable { - if (hasRedactions()) { - rustTemplate(".header(#{Closure:W})", "Closure" to closure()) + fun setter(): Writable = + writable { + if (hasRedactions()) { + rustTemplate(".header(#{Closure:W})", "Closure" to closure()) + } } - } } /** Models the ways query strings can be bound and sensitive. */ @@ -244,37 +261,42 @@ sealed class QuerySensitivity( class SensitiveMapValue(allKeysSensitive: Boolean, runtimeConfig: RuntimeConfig) : QuerySensitivity(allKeysSensitive, runtimeConfig) /** Is there anything to redact? */ - internal fun hasRedactions(): Boolean = when (this) { - is NotSensitiveMapValue -> allKeysSensitive || queryKeys.isNotEmpty() - is SensitiveMapValue -> true - } + internal fun hasRedactions(): Boolean = + when (this) { + is NotSensitiveMapValue -> allKeysSensitive || queryKeys.isNotEmpty() + is SensitiveMapValue -> true + } /** Returns the type of the `MakeFmt`. */ - fun type(): Writable = writable { - if (hasRedactions()) { - rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::uri::MakeQuery #{SmithyHttpServer}::instrumentation::sensitivity::uri::QueryMarker>", *codegenScope) - } else { - rustTemplate("#{SmithyHttpServer}::instrumentation::MakeIdentity", *codegenScope) + fun type(): Writable = + writable { + if (hasRedactions()) { + rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::uri::MakeQuery #{SmithyHttpServer}::instrumentation::sensitivity::uri::QueryMarker>", *codegenScope) + } else { + rustTemplate("#{SmithyHttpServer}::instrumentation::MakeIdentity", *codegenScope) + } } - } /** Returns the closure used during construction. */ internal fun closure(): Writable { - val value = when (this) { - is SensitiveMapValue -> writable { - rust("true") - } - is NotSensitiveMapValue -> if (queryKeys.isEmpty()) { - writable { - rust("false;") - } - } else { - writable { - val matches = queryKeys.joinToString("|") { it.dq() } - rust("matches!(name, $matches);") - } + val value = + when (this) { + is SensitiveMapValue -> + writable { + rust("true") + } + is NotSensitiveMapValue -> + if (queryKeys.isEmpty()) { + writable { + rust("false;") + } + } else { + writable { + val matches = queryKeys.joinToString("|") { it.dq() } + rust("matches!(name, $matches);") + } + } } - } return writable { rustTemplate( @@ -294,11 +316,12 @@ sealed class QuerySensitivity( } /** Returns the setter enclosing the closure. */ - fun setters(): Writable = writable { - if (hasRedactions()) { - rustTemplate(".query(#{Closure:W})", "Closure" to closure()) + fun setters(): Writable = + writable { + if (hasRedactions()) { + rustTemplate(".query(#{Closure:W})", "Closure" to closure()) + } } - } } /** Represents a `RequestFmt` or `ResponseFmt` type and value. */ @@ -322,10 +345,11 @@ class ServerHttpSensitivityGenerator( private val operation: OperationShape, private val runtimeConfig: RuntimeConfig, ) { - private val codegenScope = arrayOf( - "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), - "Http" to CargoDependency.Http.toType(), - ) + private val codegenScope = + arrayOf( + "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), + "Http" to CargoDependency.Http.toType(), + ) /** Constructs `StatusCodeSensitivity` of a `Shape` */ private fun findStatusCodeSensitivity(rootShape: Shape): StatusCodeSensitivity { @@ -333,10 +357,11 @@ class ServerHttpSensitivityGenerator( val rootSensitive = rootShape.hasTrait() // Find all sensitive `httpResponseCode` bindings in the `rootShape`. - val isSensitive = rootShape - .members() - .filter { it.hasTrait() } - .any { rootSensitive || it.getMemberTrait(model, SensitiveTrait::class.java).isPresent } + val isSensitive = + rootShape + .members() + .filter { it.hasTrait() } + .any { rootSensitive || it.getMemberTrait(model, SensitiveTrait::class.java).isPresent } return StatusCodeSensitivity(isSensitive, runtimeConfig) } @@ -351,17 +376,20 @@ class ServerHttpSensitivityGenerator( // Is `rootShape` sensitive and does `httpPrefixHeaders` exist? val rootSensitive = rootShape.hasTrait() - if (rootSensitive) if (prefixHeader != null) { - return HeaderSensitivity.SensitiveMapValue( - headerKeys.map { it.second }, true, - prefixHeader.second, runtimeConfig, - ) + if (rootSensitive) { + if (prefixHeader != null) { + return HeaderSensitivity.SensitiveMapValue( + headerKeys.map { it.second }, true, + prefixHeader.second, runtimeConfig, + ) + } } // Which headers are sensitive? - val sensitiveHeaders = headerKeys - .filter { (member, _) -> rootSensitive || member.getMemberTrait(model, SensitiveTrait::class.java).orNull() != null } - .map { (_, name) -> name } + val sensitiveHeaders = + headerKeys + .filter { (member, _) -> rootSensitive || member.getMemberTrait(model, SensitiveTrait::class.java).orNull() != null } + .map { (_, name) -> name } return if (prefixHeader != null) { // Get the `httpPrefixHeader` map. @@ -390,17 +418,20 @@ class ServerHttpSensitivityGenerator( // Is `rootShape` sensitive and does `httpQueryParams` exist? val rootSensitive = rootShape.hasTrait() - if (rootSensitive) if (queryParams != null) { - return QuerySensitivity.SensitiveMapValue(true, runtimeConfig) + if (rootSensitive) { + if (queryParams != null) { + return QuerySensitivity.SensitiveMapValue(true, runtimeConfig) + } } // Find all `httpQuery` bindings in the `rootShape`. val queryKeys = rootShape.members().mapNotNull { member -> member.getTrait()?.let { trait -> Pair(member, trait.value) } }.distinct() // Which queries are sensitive? - val sensitiveQueries = queryKeys - .filter { (member, _) -> rootSensitive || member.getMemberTrait(model, SensitiveTrait::class.java).orNull() != null } - .map { (_, name) -> name } + val sensitiveQueries = + queryKeys + .filter { (member, _) -> rootSensitive || member.getMemberTrait(model, SensitiveTrait::class.java).orNull() != null } + .map { (_, name) -> name } return if (queryParams != null) { // Get the `httpQueryParams` map. @@ -421,41 +452,48 @@ class ServerHttpSensitivityGenerator( } /** Constructs `LabelSensitivity` of a `Shape` */ - internal fun findLabelSensitivity(uriPattern: UriPattern, rootShape: Shape): LabelSensitivity { + internal fun findLabelSensitivity( + uriPattern: UriPattern, + rootShape: Shape, + ): LabelSensitivity { // Is root shape sensitive? val rootSensitive = rootShape.hasTrait() // Find `httpLabel` trait which are also sensitive. - val httpLabels = rootShape - .members() - .filter { it.hasTrait() } - .filter { rootSensitive || it.getMemberTrait(model, SensitiveTrait::class.java).orNull() != null } + val httpLabels = + rootShape + .members() + .filter { it.hasTrait() } + .filter { rootSensitive || it.getMemberTrait(model, SensitiveTrait::class.java).orNull() != null } + + val labelIndexes = + httpLabels + .mapNotNull { member -> + uriPattern + .segments + .withIndex() + .find { (_, segment) -> + segment.isLabel && !segment.isGreedyLabel && segment.content == member.memberName + } + } + .map { (index, _) -> index } - val labelIndexes = httpLabels - .mapNotNull { member -> + val greedyLabel = + httpLabels.mapNotNull { member -> uriPattern .segments .withIndex() - .find { (_, segment) -> - segment.isLabel && !segment.isGreedyLabel && segment.content == member.memberName - } - } - .map { (index, _) -> index } - - val greedyLabel = httpLabels.mapNotNull { member -> - uriPattern - .segments - .withIndex() - .find { (_, segment) -> segment.isGreedyLabel && segment.content == member.memberName } - } - .singleOrNull() - ?.let { (index, _) -> - val remainder = uriPattern - .segments - .drop(index + 1) - .sumOf { it.content.length + 1 } - GreedyLabel(index, remainder) + .find { (_, segment) -> segment.isGreedyLabel && segment.content == member.memberName } } + .singleOrNull() + ?.let { (index, _) -> + val remainder = + uriPattern + .segments + .drop(index + 1) + .sumOf { it.content.length + 1 } + GreedyLabel(index, remainder) + } return LabelSensitivity(labelIndexes, greedyLabel, runtimeConfig) } @@ -473,12 +511,14 @@ class ServerHttpSensitivityGenerator( } private fun defaultRequestFmt(): MakeFmt { - val type = writable { - rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::DefaultRequestFmt", *codegenScope) - } - val value = writable { - rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::RequestFmt::new()", *codegenScope) - } + val type = + writable { + rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::DefaultRequestFmt", *codegenScope) + } + val value = + writable { + rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::RequestFmt::new()", *codegenScope) + } return MakeFmt(type, value) } @@ -496,39 +536,43 @@ class ServerHttpSensitivityGenerator( // httpQuery/httpQueryParams bindings val querySensitivity = findQuerySensitivity(inputShape) - val type = writable { - rustTemplate( - """ - #{SmithyHttpServer}::instrumentation::sensitivity::RequestFmt< - #{HeaderType:W}, - #{SmithyHttpServer}::instrumentation::sensitivity::uri::MakeUri< - #{LabelType:W}, - #{QueryType:W} + val type = + writable { + rustTemplate( + """ + #{SmithyHttpServer}::instrumentation::sensitivity::RequestFmt< + #{HeaderType:W}, + #{SmithyHttpServer}::instrumentation::sensitivity::uri::MakeUri< + #{LabelType:W}, + #{QueryType:W} + > > - > - """, - "HeaderType" to headerSensitivity.type(), - "LabelType" to labelSensitivity.type(), - "QueryType" to querySensitivity.type(), - *codegenScope, - ) - } + """, + "HeaderType" to headerSensitivity.type(), + "LabelType" to labelSensitivity.type(), + "QueryType" to querySensitivity.type(), + *codegenScope, + ) + } - val value = writable { - rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::RequestFmt::new()", *codegenScope) - } + headerSensitivity.setter() + labelSensitivity.setter() + querySensitivity.setters() + val value = + writable { + rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::RequestFmt::new()", *codegenScope) + } + headerSensitivity.setter() + labelSensitivity.setter() + querySensitivity.setters() return MakeFmt(type, value) } private fun defaultResponseFmt(): MakeFmt { - val type = writable { - rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::DefaultResponseFmt", *codegenScope) - } + val type = + writable { + rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::DefaultResponseFmt", *codegenScope) + } - val value = writable { - rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::ResponseFmt::new()", *codegenScope) - } + val value = + writable { + rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::ResponseFmt::new()", *codegenScope) + } return MakeFmt(type, value) } @@ -543,18 +587,20 @@ class ServerHttpSensitivityGenerator( // Status code bindings val statusCodeSensitivity = findStatusCodeSensitivity(outputShape) - val type = writable { - rustTemplate( - "#{SmithyHttpServer}::instrumentation::sensitivity::ResponseFmt<#{HeaderType:W}, #{StatusType:W}>", - "HeaderType" to headerSensitivity.type(), - "StatusType" to statusCodeSensitivity.type(), - *codegenScope, - ) - } + val type = + writable { + rustTemplate( + "#{SmithyHttpServer}::instrumentation::sensitivity::ResponseFmt<#{HeaderType:W}, #{StatusType:W}>", + "HeaderType" to headerSensitivity.type(), + "StatusType" to statusCodeSensitivity.type(), + *codegenScope, + ) + } - val value = writable { - rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::ResponseFmt::new()", *codegenScope) - } + headerSensitivity.setter() + statusCodeSensitivity.setter() + val value = + writable { + rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::ResponseFmt::new()", *codegenScope) + } + headerSensitivity.setter() + statusCodeSensitivity.setter() return MakeFmt(type, value) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt index fade448863b..f52ccd484ed 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt @@ -26,14 +26,15 @@ import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromO class ServerAfterInstantiatingValueConstrainItIfNecessary(val codegenContext: CodegenContext) : InstantiatorCustomization() { - - override fun section(section: InstantiatorSection): Writable = when (section) { - is InstantiatorSection.AfterInstantiatingValue -> writable { - if (section.shape.isDirectlyConstrained(codegenContext.symbolProvider)) { - rust(""".try_into().expect("this is only used in tests")""") - } + override fun section(section: InstantiatorSection): Writable = + when (section) { + is InstantiatorSection.AfterInstantiatingValue -> + writable { + if (section.shape.isDirectlyConstrained(codegenContext.symbolProvider)) { + rust(""".try_into().expect("this is only used in tests")""") + } + } } - } } class ServerBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior { @@ -41,11 +42,12 @@ class ServerBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiat // Only operation input builders take in unconstrained types. val takesInUnconstrainedTypes = shape.isReachableFromOperationInput() - val publicConstrainedTypes = if (codegenContext is ServerCodegenContext) { - codegenContext.settings.codegenConfig.publicConstrainedTypes - } else { - true - } + val publicConstrainedTypes = + if (codegenContext is ServerCodegenContext) { + codegenContext.settings.codegenConfig.publicConstrainedTypes + } else { + true + } return if (publicConstrainedTypes) { ServerBuilderGenerator.hasFallibleBuilder( @@ -84,7 +86,11 @@ class ServerBuilderInstantiator( private val symbolProvider: RustSymbolProvider, private val symbolParseFn: (Shape) -> ReturnSymbolToParse, ) : BuilderInstantiator { - override fun setField(builder: String, value: Writable, field: MemberShape): Writable { + override fun setField( + builder: String, + value: Writable, + field: MemberShape, + ): Writable { // Server builders have the ability to have non-optional fields. When one of these fields is used, // we need to use `if let(...)` to only set the field when it is present. return if (!symbolProvider.toSymbol(field).isOptional()) { @@ -104,12 +110,17 @@ class ServerBuilderInstantiator( } } - override fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable?): Writable = writable { - val returnSymbolToParse = symbolParseFn(shape) - if (returnSymbolToParse.isUnconstrained) { - rust(builder) - } else { - rust("$builder.build()") + override fun finalizeBuilder( + builder: String, + shape: StructureShape, + mapErr: Writable?, + ): Writable = + writable { + val returnSymbolToParse = symbolParseFn(shape) + if (returnSymbolToParse.isUnconstrained) { + rust(builder) + } else { + rust("$builder.build()") + } } - } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGenerator.kt index 0330b05905a..62ac9c38b73 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGenerator.kt @@ -39,23 +39,26 @@ open class ServerOperationErrorGenerator( private fun operationErrors(): List = (operationOrEventStream as OperationShape).operationErrors(model).map { it.asStructureShape().get() } + private fun eventStreamErrors(): List = (operationOrEventStream as UnionShape).eventStreamErrors() .map { model.expectShape(it.asMemberShape().get().target, StructureShape::class.java) } open fun render(writer: RustWriter) { - val (errorSymbol, errors) = when (operationOrEventStream) { - is OperationShape -> symbolProvider.symbolForOperationError(operationOrEventStream) to operationErrors() - is UnionShape -> symbolProvider.symbolForEventStreamError(operationOrEventStream) to eventStreamErrors() - else -> UNREACHABLE("OperationErrorGenerator only supports operation or event stream shapes") - } + val (errorSymbol, errors) = + when (operationOrEventStream) { + is OperationShape -> symbolProvider.symbolForOperationError(operationOrEventStream) to operationErrors() + is UnionShape -> symbolProvider.symbolForEventStreamError(operationOrEventStream) to eventStreamErrors() + else -> UNREACHABLE("OperationErrorGenerator only supports operation or event stream shapes") + } if (errors.isEmpty()) { return } - val meta = RustMetadata( - derives = setOf(RuntimeType.Debug), - visibility = Visibility.PUBLIC, - ) + val meta = + RustMetadata( + derives = setOf(RuntimeType.Debug), + visibility = Visibility.PUBLIC, + ) writer.rust("/// Error type for the `${symbol.name}` operation.") writer.rust("/// Each variant represents an error that can occur for the `${symbol.name}` operation.") diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt index 3a5af6cb509..f7fe272f699 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt @@ -34,14 +34,15 @@ class ServerOperationGenerator( private val operationId = operation.id /** Returns `std::convert::Infallible` if the model provides no errors. */ - private fun operationError(): Writable = writable { - if (operation.errors.isEmpty()) { - rust("std::convert::Infallible") - } else { - // Name comes from [ServerOperationErrorGenerator]. - rust("crate::error::${symbolProvider.toSymbol(operation).name}Error") + private fun operationError(): Writable = + writable { + if (operation.errors.isEmpty()) { + rust("std::convert::Infallible") + } else { + // Name comes from [ServerOperationErrorGenerator]. + rust("crate::error::${symbolProvider.toSymbol(operation).name}Error") + } } - } fun render(writer: RustWriter) { writer.documentShape(operation, model) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRootGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRootGenerator.kt index 63f55954da2..177c6fa8455 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRootGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRootGenerator.kt @@ -34,11 +34,12 @@ open class ServerRootGenerator( private val isConfigBuilderFallible: Boolean, ) { private val index = TopDownIndex.of(codegenContext.model) - private val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet( - compareBy { - it.id - }, - ).toList() + private val operations = + index.getContainedOperations(codegenContext.serviceShape).toSortedSet( + compareBy { + it.id + }, + ).toList() private val serviceName = codegenContext.serviceShape.id.name.toPascalCase() fun documentation(writer: RustWriter) { @@ -52,11 +53,12 @@ open class ServerRootGenerator( val crateName = codegenContext.moduleUseName() val builderName = "${serviceName}Builder" val hasErrors = operations.any { it.errors.isNotEmpty() } - val handlers: Writable = operations - .map { operation -> - DocHandlerGenerator(codegenContext, operation, builderFieldNames[operation]!!, "//!").docSignature() - } - .join("//!\n") + val handlers: Writable = + operations + .map { operation -> + DocHandlerGenerator(codegenContext, operation, builderFieldNames[operation]!!, "//!").docSignature() + } + .join("//!\n") val unwrapConfigBuilder = if (isConfigBuilderFallible) ".expect(\"config failed to build\")" else "" @@ -246,11 +248,12 @@ open class ServerRootGenerator( documentation(rustWriter) // Only export config builder error if fallible. - val configErrorReExport = if (isConfigBuilderFallible) { - "${serviceName}ConfigError," - } else { - "" - } + val configErrorReExport = + if (isConfigBuilderFallible) { + "${serviceName}ConfigError," + } else { + "" + } rustWriter.rust( """ pub use crate::service::{ diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRuntimeTypesReExportsGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRuntimeTypesReExportsGenerator.kt index a6f62246d87..e432d7e8468 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRuntimeTypesReExportsGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRuntimeTypesReExportsGenerator.kt @@ -14,9 +14,10 @@ class ServerRuntimeTypesReExportsGenerator( codegenContext: CodegenContext, ) { private val runtimeConfig = codegenContext.runtimeConfig - private val codegenScope = arrayOf( - "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), - ) + private val codegenScope = + arrayOf( + "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), + ) fun render(writer: RustWriter) { writer.rustTemplate( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index 8eaec006a71..1898c6c3db0 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -80,269 +80,276 @@ class ServerServiceGenerator( private val requestSpecMap: Map> = operations.associateWith { operationShape -> val operationName = symbolProvider.toSymbol(operationShape).name - val spec = protocol.serverRouterRequestSpec( - operationShape, - operationName, - serviceName, - smithyHttpServer.resolve("routing::request_spec"), - ) + val spec = + protocol.serverRouterRequestSpec( + operationShape, + operationName, + serviceName, + smithyHttpServer.resolve("routing::request_spec"), + ) val functionName = RustReservedWords.escapeIfNeeded(operationName.toSnakeCase()) - val functionBody = writable { + val functionBody = + writable { + rustTemplate( + """ + fn $functionName() -> #{SpecType} { + #{Spec:W} + } + """, + "Spec" to spec, + "SpecType" to protocol.serverRouterRequestSpecType(smithyHttpServer.resolve("routing::request_spec")), + ) + } + Pair(functionName, functionBody) + } + + /** A `Writable` block containing all the `Handler` and `Operation` setters for the builder. */ + private fun builderSetters(): Writable = + writable { + for ((operationShape, structName) in operationStructNames) { + val fieldName = builderFieldNames[operationShape] + val docHandler = DocHandlerGenerator(codegenContext, operationShape, "handler", "///") + val handler = docHandler.docSignature() + val handlerFixed = docHandler.docFixedSignature() + val unwrapConfigBuilder = + if (isConfigBuilderFallible) { + ".expect(\"config failed to build\")" + } else { + "" + } rustTemplate( """ - fn $functionName() -> #{SpecType} { - #{Spec:W} + /// Sets the [`$structName`](crate::operation_shape::$structName) operation. + /// + /// This should be an async function satisfying the [`Handler`](#{SmithyHttpServer}::operation::Handler) trait. + /// See the [operation module documentation](#{SmithyHttpServer}::operation) for more information. + /// + /// ## Example + /// + /// ```no_run + /// use $crateName::{$serviceName, ${serviceName}Config}; + /// + #{HandlerImports:W} + /// + #{Handler:W} + /// + /// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder; + /// let app = $serviceName::builder(config) + /// .$fieldName(handler) + /// /* Set other handlers */ + /// .build() + /// .unwrap(); + /// ## let app: $serviceName<#{SmithyHttpServer}::routing::RoutingService<#{Router}<#{SmithyHttpServer}::routing::Route>, #{Protocol}>> = app; + /// ``` + /// + pub fn $fieldName(self, handler: HandlerType) -> Self + where + HandlerType: #{SmithyHttpServer}::operation::Handler, + + ModelPl: #{SmithyHttpServer}::plugin::Plugin< + $serviceName, + crate::operation_shape::$structName, + #{SmithyHttpServer}::operation::IntoService + >, + #{SmithyHttpServer}::operation::UpgradePlugin::: #{SmithyHttpServer}::plugin::Plugin< + $serviceName, + crate::operation_shape::$structName, + ModelPl::Output + >, + HttpPl: #{SmithyHttpServer}::plugin::Plugin< + $serviceName, + crate::operation_shape::$structName, + < + #{SmithyHttpServer}::operation::UpgradePlugin:: + as #{SmithyHttpServer}::plugin::Plugin< + $serviceName, + crate::operation_shape::$structName, + ModelPl::Output + > + >::Output + >, + + HttpPl::Output: #{Tower}::Service<#{Http}::Request, Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>, Error = ::std::convert::Infallible> + Clone + Send + 'static, + >>::Future: Send + 'static, + + { + use #{SmithyHttpServer}::operation::OperationShapeExt; + use #{SmithyHttpServer}::plugin::Plugin; + let svc = crate::operation_shape::$structName::from_handler(handler); + let svc = self.model_plugin.apply(svc); + let svc = #{SmithyHttpServer}::operation::UpgradePlugin::::new().apply(svc); + let svc = self.http_plugin.apply(svc); + self.${fieldName}_custom(svc) + } + + /// Sets the [`$structName`](crate::operation_shape::$structName) operation. + /// + /// This should be an async function satisfying the [`Handler`](#{SmithyHttpServer}::operation::Handler) trait. + /// See the [operation module documentation](#{SmithyHttpServer}::operation) for more information. + /// + /// ## Example + /// + /// ```no_run + /// use $crateName::{$serviceName, ${serviceName}Config}; + /// + #{HandlerImports:W} + /// + #{HandlerFixed:W} + /// + /// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder; + /// let svc = #{Tower}::util::service_fn(handler); + /// let app = $serviceName::builder(config) + /// .${fieldName}_service(svc) + /// /* Set other handlers */ + /// .build() + /// .unwrap(); + /// ## let app: $serviceName<#{SmithyHttpServer}::routing::RoutingService<#{Router}<#{SmithyHttpServer}::routing::Route>, #{Protocol}>> = app; + /// ``` + /// + pub fn ${fieldName}_service(self, service: S) -> Self + where + S: #{SmithyHttpServer}::operation::OperationService, + + ModelPl: #{SmithyHttpServer}::plugin::Plugin< + $serviceName, + crate::operation_shape::$structName, + #{SmithyHttpServer}::operation::Normalize + >, + #{SmithyHttpServer}::operation::UpgradePlugin::: #{SmithyHttpServer}::plugin::Plugin< + $serviceName, + crate::operation_shape::$structName, + ModelPl::Output + >, + HttpPl: #{SmithyHttpServer}::plugin::Plugin< + $serviceName, + crate::operation_shape::$structName, + < + #{SmithyHttpServer}::operation::UpgradePlugin:: + as #{SmithyHttpServer}::plugin::Plugin< + $serviceName, + crate::operation_shape::$structName, + ModelPl::Output + > + >::Output + >, + + HttpPl::Output: #{Tower}::Service<#{Http}::Request, Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>, Error = ::std::convert::Infallible> + Clone + Send + 'static, + >>::Future: Send + 'static, + + { + use #{SmithyHttpServer}::operation::OperationShapeExt; + use #{SmithyHttpServer}::plugin::Plugin; + let svc = crate::operation_shape::$structName::from_service(service); + let svc = self.model_plugin.apply(svc); + let svc = #{SmithyHttpServer}::operation::UpgradePlugin::::new().apply(svc); + let svc = self.http_plugin.apply(svc); + self.${fieldName}_custom(svc) + } + + /// Sets the [`$structName`](crate::operation_shape::$structName) to a custom [`Service`](tower::Service). + /// not constrained by the Smithy contract. + fn ${fieldName}_custom(mut self, svc: S) -> Self + where + S: #{Tower}::Service<#{Http}::Request, Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>, Error = ::std::convert::Infallible> + Clone + Send + 'static, + S::Future: Send + 'static, + { + self.$fieldName = Some(#{SmithyHttpServer}::routing::Route::new(svc)); + self } """, - "Spec" to spec, - "SpecType" to protocol.serverRouterRequestSpecType(smithyHttpServer.resolve("routing::request_spec")), + "Router" to protocol.routerType(), + "Protocol" to protocol.markerStruct(), + "Handler" to handler, + "HandlerFixed" to handlerFixed, + "HandlerImports" to handlerImports(crateName, operations), + *codegenScope, ) - } - Pair(functionName, functionBody) - } - /** A `Writable` block containing all the `Handler` and `Operation` setters for the builder. */ - private fun builderSetters(): Writable = writable { - for ((operationShape, structName) in operationStructNames) { - val fieldName = builderFieldNames[operationShape] - val docHandler = DocHandlerGenerator(codegenContext, operationShape, "handler", "///") - val handler = docHandler.docSignature() - val handlerFixed = docHandler.docFixedSignature() - val unwrapConfigBuilder = if (isConfigBuilderFallible) { - ".expect(\"config failed to build\")" - } else { - "" + // Adds newline between setters. + rust("") } - rustTemplate( - """ - /// Sets the [`$structName`](crate::operation_shape::$structName) operation. - /// - /// This should be an async function satisfying the [`Handler`](#{SmithyHttpServer}::operation::Handler) trait. - /// See the [operation module documentation](#{SmithyHttpServer}::operation) for more information. - /// - /// ## Example - /// - /// ```no_run - /// use $crateName::{$serviceName, ${serviceName}Config}; - /// - #{HandlerImports:W} - /// - #{Handler:W} - /// - /// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder; - /// let app = $serviceName::builder(config) - /// .$fieldName(handler) - /// /* Set other handlers */ - /// .build() - /// .unwrap(); - /// ## let app: $serviceName<#{SmithyHttpServer}::routing::RoutingService<#{Router}<#{SmithyHttpServer}::routing::Route>, #{Protocol}>> = app; - /// ``` - /// - pub fn $fieldName(self, handler: HandlerType) -> Self - where - HandlerType: #{SmithyHttpServer}::operation::Handler, - - ModelPl: #{SmithyHttpServer}::plugin::Plugin< - $serviceName, - crate::operation_shape::$structName, - #{SmithyHttpServer}::operation::IntoService - >, - #{SmithyHttpServer}::operation::UpgradePlugin::: #{SmithyHttpServer}::plugin::Plugin< - $serviceName, - crate::operation_shape::$structName, - ModelPl::Output - >, - HttpPl: #{SmithyHttpServer}::plugin::Plugin< - $serviceName, - crate::operation_shape::$structName, - < - #{SmithyHttpServer}::operation::UpgradePlugin:: - as #{SmithyHttpServer}::plugin::Plugin< - $serviceName, - crate::operation_shape::$structName, - ModelPl::Output - > - >::Output - >, + } - HttpPl::Output: #{Tower}::Service<#{Http}::Request, Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>, Error = ::std::convert::Infallible> + Clone + Send + 'static, - >>::Future: Send + 'static, + private fun buildMethod(): Writable = + writable { + val missingOperationsVariableName = "missing_operation_names" + val expectMessageVariableName = "unexpected_error_msg" - { - use #{SmithyHttpServer}::operation::OperationShapeExt; - use #{SmithyHttpServer}::plugin::Plugin; - let svc = crate::operation_shape::$structName::from_handler(handler); - let svc = self.model_plugin.apply(svc); - let svc = #{SmithyHttpServer}::operation::UpgradePlugin::::new().apply(svc); - let svc = self.http_plugin.apply(svc); - self.${fieldName}_custom(svc) + val nullabilityChecks = + writable { + for (operationShape in operations) { + val fieldName = builderFieldNames[operationShape]!! + val operationZstTypeName = operationStructNames[operationShape]!! + rust( + """ + if self.$fieldName.is_none() { + $missingOperationsVariableName.insert(crate::operation_shape::$operationZstTypeName::ID, ".$fieldName()"); + } + """, + ) + } + } + val routesArrayElements = + writable { + for (operationShape in operations) { + val fieldName = builderFieldNames[operationShape]!! + val (specBuilderFunctionName, _) = requestSpecMap.getValue(operationShape) + rust( + """ + ($requestSpecsModuleName::$specBuilderFunctionName(), self.$fieldName.expect($expectMessageVariableName)), + """, + ) + } } - /// Sets the [`$structName`](crate::operation_shape::$structName) operation. - /// - /// This should be an async function satisfying the [`Handler`](#{SmithyHttpServer}::operation::Handler) trait. - /// See the [operation module documentation](#{SmithyHttpServer}::operation) for more information. - /// - /// ## Example - /// - /// ```no_run - /// use $crateName::{$serviceName, ${serviceName}Config}; - /// - #{HandlerImports:W} - /// - #{HandlerFixed:W} + rustTemplate( + """ + /// Constructs a [`$serviceName`] from the arguments provided to the builder. /// - /// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder; - /// let svc = #{Tower}::util::service_fn(handler); - /// let app = $serviceName::builder(config) - /// .${fieldName}_service(svc) - /// /* Set other handlers */ - /// .build() - /// .unwrap(); - /// ## let app: $serviceName<#{SmithyHttpServer}::routing::RoutingService<#{Router}<#{SmithyHttpServer}::routing::Route>, #{Protocol}>> = app; - /// ``` + /// Forgetting to register a handler for one or more operations will result in an error. /// - pub fn ${fieldName}_service(self, service: S) -> Self - where - S: #{SmithyHttpServer}::operation::OperationService, - - ModelPl: #{SmithyHttpServer}::plugin::Plugin< - $serviceName, - crate::operation_shape::$structName, - #{SmithyHttpServer}::operation::Normalize - >, - #{SmithyHttpServer}::operation::UpgradePlugin::: #{SmithyHttpServer}::plugin::Plugin< - $serviceName, - crate::operation_shape::$structName, - ModelPl::Output - >, - HttpPl: #{SmithyHttpServer}::plugin::Plugin< - $serviceName, - crate::operation_shape::$structName, - < - #{SmithyHttpServer}::operation::UpgradePlugin:: - as #{SmithyHttpServer}::plugin::Plugin< - $serviceName, - crate::operation_shape::$structName, - ModelPl::Output - > - >::Output + /// Check out [`$builderName::build_unchecked`] if you'd prefer the service to return status code 500 when an + /// unspecified route is requested. + pub fn build(self) -> Result< + $serviceName< + #{SmithyHttpServer}::routing::RoutingService< + #{Router}, + #{Protocol}, + >, >, - - HttpPl::Output: #{Tower}::Service<#{Http}::Request, Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>, Error = ::std::convert::Infallible> + Clone + Send + 'static, - >>::Future: Send + 'static, - - { - use #{SmithyHttpServer}::operation::OperationShapeExt; - use #{SmithyHttpServer}::plugin::Plugin; - let svc = crate::operation_shape::$structName::from_service(service); - let svc = self.model_plugin.apply(svc); - let svc = #{SmithyHttpServer}::operation::UpgradePlugin::::new().apply(svc); - let svc = self.http_plugin.apply(svc); - self.${fieldName}_custom(svc) - } - - /// Sets the [`$structName`](crate::operation_shape::$structName) to a custom [`Service`](tower::Service). - /// not constrained by the Smithy contract. - fn ${fieldName}_custom(mut self, svc: S) -> Self + MissingOperationsError, + > where - S: #{Tower}::Service<#{Http}::Request, Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>, Error = ::std::convert::Infallible> + Clone + Send + 'static, - S::Future: Send + 'static, + L: #{Tower}::Layer<#{SmithyHttpServer}::routing::Route>, { - self.$fieldName = Some(#{SmithyHttpServer}::routing::Route::new(svc)); - self + let router = { + use #{SmithyHttpServer}::operation::OperationShape; + let mut $missingOperationsVariableName = std::collections::HashMap::new(); + #{NullabilityChecks:W} + if !$missingOperationsVariableName.is_empty() { + return Err(MissingOperationsError { + operation_names2setter_methods: $missingOperationsVariableName, + }); + } + let $expectMessageVariableName = "this should never panic since we are supposed to check beforehand that a handler has been registered for this operation; please file a bug report under https://github.com/smithy-lang/smithy-rs/issues"; + + #{PatternInitializations:W} + + #{Router}::from_iter([#{RoutesArrayElements:W}]) + }; + let svc = #{SmithyHttpServer}::routing::RoutingService::new(router); + let svc = svc.map(|s| s.layer(self.layer)); + Ok($serviceName { svc }) } """, - "Router" to protocol.routerType(), - "Protocol" to protocol.markerStruct(), - "Handler" to handler, - "HandlerFixed" to handlerFixed, - "HandlerImports" to handlerImports(crateName, operations), *codegenScope, + "Protocol" to protocol.markerStruct(), + "Router" to protocol.routerType(), + "NullabilityChecks" to nullabilityChecks, + "RoutesArrayElements" to routesArrayElements, + "PatternInitializations" to patternInitializations(), ) - - // Adds newline between setters. - rust("") - } - } - - private fun buildMethod(): Writable = writable { - val missingOperationsVariableName = "missing_operation_names" - val expectMessageVariableName = "unexpected_error_msg" - - val nullabilityChecks = writable { - for (operationShape in operations) { - val fieldName = builderFieldNames[operationShape]!! - val operationZstTypeName = operationStructNames[operationShape]!! - rust( - """ - if self.$fieldName.is_none() { - $missingOperationsVariableName.insert(crate::operation_shape::$operationZstTypeName::ID, ".$fieldName()"); - } - """, - ) - } } - val routesArrayElements = writable { - for (operationShape in operations) { - val fieldName = builderFieldNames[operationShape]!! - val (specBuilderFunctionName, _) = requestSpecMap.getValue(operationShape) - rust( - """ - ($requestSpecsModuleName::$specBuilderFunctionName(), self.$fieldName.expect($expectMessageVariableName)), - """, - ) - } - } - - rustTemplate( - """ - /// Constructs a [`$serviceName`] from the arguments provided to the builder. - /// - /// Forgetting to register a handler for one or more operations will result in an error. - /// - /// Check out [`$builderName::build_unchecked`] if you'd prefer the service to return status code 500 when an - /// unspecified route is requested. - pub fn build(self) -> Result< - $serviceName< - #{SmithyHttpServer}::routing::RoutingService< - #{Router}, - #{Protocol}, - >, - >, - MissingOperationsError, - > - where - L: #{Tower}::Layer<#{SmithyHttpServer}::routing::Route>, - { - let router = { - use #{SmithyHttpServer}::operation::OperationShape; - let mut $missingOperationsVariableName = std::collections::HashMap::new(); - #{NullabilityChecks:W} - if !$missingOperationsVariableName.is_empty() { - return Err(MissingOperationsError { - operation_names2setter_methods: $missingOperationsVariableName, - }); - } - let $expectMessageVariableName = "this should never panic since we are supposed to check beforehand that a handler has been registered for this operation; please file a bug report under https://github.com/smithy-lang/smithy-rs/issues"; - - #{PatternInitializations:W} - - #{Router}::from_iter([#{RoutesArrayElements:W}]) - }; - let svc = #{SmithyHttpServer}::routing::RoutingService::new(router); - let svc = svc.map(|s| s.layer(self.layer)); - Ok($serviceName { svc }) - } - """, - *codegenScope, - "Protocol" to protocol.markerStruct(), - "Router" to protocol.routerType(), - "NullabilityChecks" to nullabilityChecks, - "RoutesArrayElements" to routesArrayElements, - "PatternInitializations" to patternInitializations(), - ) - } /** * Renders `PatternString::compile_regex()` function calls for every @@ -350,14 +357,15 @@ class ServerServiceGenerator( */ @Suppress("DEPRECATION") private fun patternInitializations(): Writable { - val patterns = Walker(model).walkShapes(service) - .filter { shape -> shape is StringShape && shape.hasTrait() && !shape.hasTrait() } - .map { shape -> codegenContext.constrainedShapeSymbolProvider.toSymbol(shape) } - .map { symbol -> - writable { - rustTemplate("#{Type}::compile_regex();", "Type" to symbol) + val patterns = + Walker(model).walkShapes(service) + .filter { shape -> shape is StringShape && shape.hasTrait() && !shape.hasTrait() } + .map { shape -> codegenContext.constrainedShapeSymbolProvider.toSymbol(shape) } + .map { symbol -> + writable { + rustTemplate("#{Type}::compile_regex();", "Type" to symbol) + } } - } patterns.letIf(patterns.isNotEmpty()) { val docs = listOf(writable { rust("// Eagerly initialize regexes for `@pattern` strings.") }) @@ -368,406 +376,417 @@ class ServerServiceGenerator( return patterns.join("") } - private fun buildUncheckedMethod(): Writable = writable { - val pairs = writable { - for (operationShape in operations) { - val fieldName = builderFieldNames[operationShape]!! - val (specBuilderFunctionName, _) = requestSpecMap.getValue(operationShape) - rustTemplate( - """ - ( - $requestSpecsModuleName::$specBuilderFunctionName(), - self.$fieldName.unwrap_or_else(|| { - let svc = #{SmithyHttpServer}::operation::MissingFailure::<#{Protocol}>::default(); - #{SmithyHttpServer}::routing::Route::new(svc) - }) - ), - """, - "SmithyHttpServer" to smithyHttpServer, - "Protocol" to protocol.markerStruct(), - ) - } + private fun buildUncheckedMethod(): Writable = + writable { + val pairs = + writable { + for (operationShape in operations) { + val fieldName = builderFieldNames[operationShape]!! + val (specBuilderFunctionName, _) = requestSpecMap.getValue(operationShape) + rustTemplate( + """ + ( + $requestSpecsModuleName::$specBuilderFunctionName(), + self.$fieldName.unwrap_or_else(|| { + let svc = #{SmithyHttpServer}::operation::MissingFailure::<#{Protocol}>::default(); + #{SmithyHttpServer}::routing::Route::new(svc) + }) + ), + """, + "SmithyHttpServer" to smithyHttpServer, + "Protocol" to protocol.markerStruct(), + ) + } + } + rustTemplate( + """ + /// Constructs a [`$serviceName`] from the arguments provided to the builder. + /// Operations without a handler default to returning 500 Internal Server Error to the caller. + /// + /// Check out [`$builderName::build`] if you'd prefer the builder to fail if one or more operations do + /// not have a registered handler. + pub fn build_unchecked(self) -> $serviceName + where + Body: Send + 'static, + L: #{Tower}::Layer< + #{SmithyHttpServer}::routing::RoutingService<#{Router}<#{SmithyHttpServer}::routing::Route>, #{Protocol}> + > + { + let router = #{Router}::from_iter([#{Pairs:W}]); + let svc = self + .layer + .layer(#{SmithyHttpServer}::routing::RoutingService::new(router)); + $serviceName { svc } + } + """, + *codegenScope, + "Protocol" to protocol.markerStruct(), + "Router" to protocol.routerType(), + "Pairs" to pairs, + ) } - rustTemplate( - """ - /// Constructs a [`$serviceName`] from the arguments provided to the builder. - /// Operations without a handler default to returning 500 Internal Server Error to the caller. - /// - /// Check out [`$builderName::build`] if you'd prefer the builder to fail if one or more operations do - /// not have a registered handler. - pub fn build_unchecked(self) -> $serviceName - where - Body: Send + 'static, - L: #{Tower}::Layer< - #{SmithyHttpServer}::routing::RoutingService<#{Router}<#{SmithyHttpServer}::routing::Route>, #{Protocol}> - > - { - let router = #{Router}::from_iter([#{Pairs:W}]); - let svc = self - .layer - .layer(#{SmithyHttpServer}::routing::RoutingService::new(router)); - $serviceName { svc } - } - """, - *codegenScope, - "Protocol" to protocol.markerStruct(), - "Router" to protocol.routerType(), - "Pairs" to pairs, - ) - } /** Returns a `Writable` containing the builder struct definition and its implementations. */ - private fun builder(): Writable = writable { - val builderGenerics = listOf("Body", "L", "HttpPl", "ModelPl").joinToString(", ") - rustTemplate( - """ - /// The service builder for [`$serviceName`]. - /// - /// Constructed via [`$serviceName::builder`]. - pub struct $builderName<$builderGenerics> { - ${builderFields.joinToString(", ")}, - layer: L, - http_plugin: HttpPl, - model_plugin: ModelPl - } - - impl<$builderGenerics> $builderName<$builderGenerics> { - #{Setters:W} - } + private fun builder(): Writable = + writable { + val builderGenerics = listOf("Body", "L", "HttpPl", "ModelPl").joinToString(", ") + rustTemplate( + """ + /// The service builder for [`$serviceName`]. + /// + /// Constructed via [`$serviceName::builder`]. + pub struct $builderName<$builderGenerics> { + ${builderFields.joinToString(", ")}, + layer: L, + http_plugin: HttpPl, + model_plugin: ModelPl + } - impl<$builderGenerics> $builderName<$builderGenerics> { - #{BuildMethod:W} + impl<$builderGenerics> $builderName<$builderGenerics> { + #{Setters:W} + } - #{BuildUncheckedMethod:W} - } - """, - "Setters" to builderSetters(), - "BuildMethod" to buildMethod(), - "BuildUncheckedMethod" to buildUncheckedMethod(), - *codegenScope, - ) - } + impl<$builderGenerics> $builderName<$builderGenerics> { + #{BuildMethod:W} - private fun requestSpecsModule(): Writable = writable { - val functions = writable { - for ((_, function) in requestSpecMap.values) { - rustTemplate( - """ - pub(super) #{Function:W} - """, - "Function" to function, - ) - } + #{BuildUncheckedMethod:W} + } + """, + "Setters" to builderSetters(), + "BuildMethod" to buildMethod(), + "BuildUncheckedMethod" to buildUncheckedMethod(), + *codegenScope, + ) } - rustTemplate( - """ - mod $requestSpecsModuleName { - #{SpecFunctions:W} - } - """, - "SpecFunctions" to functions, - ) - } - /** Returns a `Writable` comma delimited sequence of `builder_field: None`. */ - private fun notSetFields(): Writable = builderFieldNames.values.map { + private fun requestSpecsModule(): Writable = writable { + val functions = + writable { + for ((_, function) in requestSpecMap.values) { + rustTemplate( + """ + pub(super) #{Function:W} + """, + "Function" to function, + ) + } + } rustTemplate( - "$it: None", - *codegenScope, + """ + mod $requestSpecsModuleName { + #{SpecFunctions:W} + } + """, + "SpecFunctions" to functions, ) } - }.join(", ") + + /** Returns a `Writable` comma delimited sequence of `builder_field: None`. */ + private fun notSetFields(): Writable = + builderFieldNames.values.map { + writable { + rustTemplate( + "$it: None", + *codegenScope, + ) + } + }.join(", ") /** Returns a `Writable` containing the service struct definition and its implementations. */ - private fun serviceStruct(): Writable = writable { - documentShape(service, model) + private fun serviceStruct(): Writable = + writable { + documentShape(service, model) - rustTemplate( - """ - /// - /// See the [root](crate) documentation for more information. - ##[derive(Clone)] - pub struct $serviceName< - S = #{SmithyHttpServer}::routing::RoutingService< - #{Router}< - #{SmithyHttpServer}::routing::Route< - #{SmithyHttpServer}::body::BoxBody + rustTemplate( + """ + /// + /// See the [root](crate) documentation for more information. + ##[derive(Clone)] + pub struct $serviceName< + S = #{SmithyHttpServer}::routing::RoutingService< + #{Router}< + #{SmithyHttpServer}::routing::Route< + #{SmithyHttpServer}::body::BoxBody + >, >, - >, - #{Protocol}, - > - > { - // This is the router wrapped by layers. - svc: S, - } + #{Protocol}, + > + > { + // This is the router wrapped by layers. + svc: S, + } - impl $serviceName<()> { - /// Constructs a builder for [`$serviceName`]. - /// You must specify a configuration object holding any plugins and layers that should be applied - /// to the operations in this service. - pub fn builder< - Body, - L, - HttpPl: #{SmithyHttpServer}::plugin::HttpMarker, - ModelPl: #{SmithyHttpServer}::plugin::ModelMarker, - >( - config: ${serviceName}Config, - ) -> $builderName { - $builderName { - #{NotSetFields1:W}, - layer: config.layers, - http_plugin: config.http_plugins, - model_plugin: config.model_plugins, + impl $serviceName<()> { + /// Constructs a builder for [`$serviceName`]. + /// You must specify a configuration object holding any plugins and layers that should be applied + /// to the operations in this service. + pub fn builder< + Body, + L, + HttpPl: #{SmithyHttpServer}::plugin::HttpMarker, + ModelPl: #{SmithyHttpServer}::plugin::ModelMarker, + >( + config: ${serviceName}Config, + ) -> $builderName { + $builderName { + #{NotSetFields1:W}, + layer: config.layers, + http_plugin: config.http_plugins, + model_plugin: config.model_plugins, + } } - } - /// Constructs a builder for [`$serviceName`]. - /// You must specify what plugins should be applied to the operations in this service. - /// - /// Use [`$serviceName::builder_without_plugins`] if you don't need to apply plugins. - /// - /// Check out [`HttpPlugins`](#{SmithyHttpServer}::plugin::HttpPlugins) and - /// [`ModelPlugins`](#{SmithyHttpServer}::plugin::ModelPlugins) if you need to apply - /// multiple plugins. - ##[deprecated( - since = "0.57.0", - note = "please use the `builder` constructor and register plugins on the `${serviceName}Config` object instead; see https://github.com/smithy-lang/smithy-rs/discussions/3096" - )] - pub fn builder_with_plugins< - Body, - HttpPl: #{SmithyHttpServer}::plugin::HttpMarker, - ModelPl: #{SmithyHttpServer}::plugin::ModelMarker - >( - http_plugin: HttpPl, - model_plugin: ModelPl - ) -> $builderName { - $builderName { - #{NotSetFields2:W}, - layer: #{Tower}::layer::util::Identity::new(), - http_plugin, - model_plugin + /// Constructs a builder for [`$serviceName`]. + /// You must specify what plugins should be applied to the operations in this service. + /// + /// Use [`$serviceName::builder_without_plugins`] if you don't need to apply plugins. + /// + /// Check out [`HttpPlugins`](#{SmithyHttpServer}::plugin::HttpPlugins) and + /// [`ModelPlugins`](#{SmithyHttpServer}::plugin::ModelPlugins) if you need to apply + /// multiple plugins. + ##[deprecated( + since = "0.57.0", + note = "please use the `builder` constructor and register plugins on the `${serviceName}Config` object instead; see https://github.com/smithy-lang/smithy-rs/discussions/3096" + )] + pub fn builder_with_plugins< + Body, + HttpPl: #{SmithyHttpServer}::plugin::HttpMarker, + ModelPl: #{SmithyHttpServer}::plugin::ModelMarker + >( + http_plugin: HttpPl, + model_plugin: ModelPl + ) -> $builderName { + $builderName { + #{NotSetFields2:W}, + layer: #{Tower}::layer::util::Identity::new(), + http_plugin, + model_plugin + } } - } - /// Constructs a builder for [`$serviceName`]. - /// - /// Use [`$serviceName::builder_with_plugins`] if you need to specify plugins. - ##[deprecated( - since = "0.57.0", - note = "please use the `builder` constructor instead; see https://github.com/smithy-lang/smithy-rs/discussions/3096" - )] - pub fn builder_without_plugins() -> $builderName< - Body, - #{Tower}::layer::util::Identity, - #{SmithyHttpServer}::plugin::IdentityPlugin, - #{SmithyHttpServer}::plugin::IdentityPlugin - > { - Self::builder_with_plugins(#{SmithyHttpServer}::plugin::IdentityPlugin, #{SmithyHttpServer}::plugin::IdentityPlugin) + /// Constructs a builder for [`$serviceName`]. + /// + /// Use [`$serviceName::builder_with_plugins`] if you need to specify plugins. + ##[deprecated( + since = "0.57.0", + note = "please use the `builder` constructor instead; see https://github.com/smithy-lang/smithy-rs/discussions/3096" + )] + pub fn builder_without_plugins() -> $builderName< + Body, + #{Tower}::layer::util::Identity, + #{SmithyHttpServer}::plugin::IdentityPlugin, + #{SmithyHttpServer}::plugin::IdentityPlugin + > { + Self::builder_with_plugins(#{SmithyHttpServer}::plugin::IdentityPlugin, #{SmithyHttpServer}::plugin::IdentityPlugin) + } } - } - impl $serviceName { - /// Converts [`$serviceName`] into a [`MakeService`](tower::make::MakeService). - pub fn into_make_service(self) -> #{SmithyHttpServer}::routing::IntoMakeService { - #{SmithyHttpServer}::routing::IntoMakeService::new(self) - } + impl $serviceName { + /// Converts [`$serviceName`] into a [`MakeService`](tower::make::MakeService). + pub fn into_make_service(self) -> #{SmithyHttpServer}::routing::IntoMakeService { + #{SmithyHttpServer}::routing::IntoMakeService::new(self) + } - /// Converts [`$serviceName`] into a [`MakeService`](tower::make::MakeService) with [`ConnectInfo`](#{SmithyHttpServer}::request::connect_info::ConnectInfo). - pub fn into_make_service_with_connect_info(self) -> #{SmithyHttpServer}::routing::IntoMakeServiceWithConnectInfo { - #{SmithyHttpServer}::routing::IntoMakeServiceWithConnectInfo::new(self) + /// Converts [`$serviceName`] into a [`MakeService`](tower::make::MakeService) with [`ConnectInfo`](#{SmithyHttpServer}::request::connect_info::ConnectInfo). + pub fn into_make_service_with_connect_info(self) -> #{SmithyHttpServer}::routing::IntoMakeServiceWithConnectInfo { + #{SmithyHttpServer}::routing::IntoMakeServiceWithConnectInfo::new(self) + } } - } - impl - $serviceName< - #{SmithyHttpServer}::routing::RoutingService< - #{Router}, - #{Protocol}, - >, - > - { - /// Applies a [`Layer`](#{Tower}::Layer) uniformly to all routes. - ##[deprecated( - since = "0.57.0", - note = "please add layers to the `${serviceName}Config` object instead; see https://github.com/smithy-lang/smithy-rs/discussions/3096" - )] - pub fn layer( - self, - layer: &L, - ) -> $serviceName< - #{SmithyHttpServer}::routing::RoutingService< - #{Router}, - #{Protocol}, - >, - > - where - L: #{Tower}::Layer, + impl + $serviceName< + #{SmithyHttpServer}::routing::RoutingService< + #{Router}, + #{Protocol}, + >, + > { - $serviceName { - svc: self.svc.map(|s| s.layer(layer)), + /// Applies a [`Layer`](#{Tower}::Layer) uniformly to all routes. + ##[deprecated( + since = "0.57.0", + note = "please add layers to the `${serviceName}Config` object instead; see https://github.com/smithy-lang/smithy-rs/discussions/3096" + )] + pub fn layer( + self, + layer: &L, + ) -> $serviceName< + #{SmithyHttpServer}::routing::RoutingService< + #{Router}, + #{Protocol}, + >, + > + where + L: #{Tower}::Layer, + { + $serviceName { + svc: self.svc.map(|s| s.layer(layer)), + } } - } - /// Applies [`Route::new`](#{SmithyHttpServer}::routing::Route::new) to all routes. - /// - /// This has the effect of erasing all types accumulated via layers. - pub fn boxed( - self, - ) -> $serviceName< - #{SmithyHttpServer}::routing::RoutingService< - #{Router}< - #{SmithyHttpServer}::routing::Route, + /// Applies [`Route::new`](#{SmithyHttpServer}::routing::Route::new) to all routes. + /// + /// This has the effect of erasing all types accumulated via layers. + pub fn boxed( + self, + ) -> $serviceName< + #{SmithyHttpServer}::routing::RoutingService< + #{Router}< + #{SmithyHttpServer}::routing::Route, + >, + #{Protocol}, >, - #{Protocol}, - >, - > + > + where + S: #{Tower}::Service< + #{Http}::Request, + Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>, + Error = std::convert::Infallible, + >, + S: Clone + Send + 'static, + S::Future: Send + 'static, + { + self.layer(&::tower::layer::layer_fn( + #{SmithyHttpServer}::routing::Route::new, + )) + } + } + + impl #{Tower}::Service for $serviceName where - S: #{Tower}::Service< - #{Http}::Request, - Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>, - Error = std::convert::Infallible, - >, - S: Clone + Send + 'static, - S::Future: Send + 'static, + S: #{Tower}::Service, { - self.layer(&::tower::layer::layer_fn( - #{SmithyHttpServer}::routing::Route::new, - )) - } - } + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; - impl #{Tower}::Service for $serviceName - where - S: #{Tower}::Service, - { - type Response = S::Response; - type Error = S::Error; - type Future = S::Future; + fn poll_ready(&mut self, cx: &mut std::task::Context) -> std::task::Poll> { + self.svc.poll_ready(cx) + } - fn poll_ready(&mut self, cx: &mut std::task::Context) -> std::task::Poll> { - self.svc.poll_ready(cx) + fn call(&mut self, request: R) -> Self::Future { + self.svc.call(request) + } } + """, + "NotSetFields1" to notSetFields(), + "NotSetFields2" to notSetFields(), + "Router" to protocol.routerType(), + "Protocol" to protocol.markerStruct(), + *codegenScope, + ) + } - fn call(&mut self, request: R) -> Self::Future { - self.svc.call(request) + private fun missingOperationsError(): Writable = + writable { + rustTemplate( + """ + /// The error encountered when calling the [`$builderName::build`] method if one or more operation handlers are not + /// specified. + ##[derive(Debug)] + pub struct MissingOperationsError { + operation_names2setter_methods: std::collections::HashMap<#{SmithyHttpServer}::shape_id::ShapeId, &'static str>, } - } - """, - "NotSetFields1" to notSetFields(), - "NotSetFields2" to notSetFields(), - "Router" to protocol.routerType(), - "Protocol" to protocol.markerStruct(), - *codegenScope, - ) - } - private fun missingOperationsError(): Writable = writable { - rustTemplate( - """ - /// The error encountered when calling the [`$builderName::build`] method if one or more operation handlers are not - /// specified. - ##[derive(Debug)] - pub struct MissingOperationsError { - operation_names2setter_methods: std::collections::HashMap<#{SmithyHttpServer}::shape_id::ShapeId, &'static str>, - } - - impl std::fmt::Display for MissingOperationsError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "You must specify a handler for all operations attached to `$serviceName`.\n\ - We are missing handlers for the following operations:\n", - )?; - for operation_name in self.operation_names2setter_methods.keys() { - writeln!(f, "- {}", operation_name.absolute())?; + impl std::fmt::Display for MissingOperationsError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "You must specify a handler for all operations attached to `$serviceName`.\n\ + We are missing handlers for the following operations:\n", + )?; + for operation_name in self.operation_names2setter_methods.keys() { + writeln!(f, "- {}", operation_name.absolute())?; + } + + writeln!(f, "\nUse the dedicated methods on `$builderName` to register the missing handlers:")?; + for setter_name in self.operation_names2setter_methods.values() { + writeln!(f, "- {}", setter_name)?; + } + Ok(()) } - - writeln!(f, "\nUse the dedicated methods on `$builderName` to register the missing handlers:")?; - for setter_name in self.operation_names2setter_methods.values() { - writeln!(f, "- {}", setter_name)?; - } - Ok(()) } - } - - impl std::error::Error for MissingOperationsError {} - """, - *codegenScope, - ) - } - - private fun serviceShapeImpl(): Writable = writable { - val namespace = serviceId.namespace - val name = serviceId.name - val absolute = serviceId.toString().replace("#", "##") - val version = codegenContext.serviceShape.version?.let { "Some(\"$it\")" } ?: "None" - rustTemplate( - """ - impl #{SmithyHttpServer}::service::ServiceShape for $serviceName { - const ID: #{SmithyHttpServer}::shape_id::ShapeId = #{SmithyHttpServer}::shape_id::ShapeId::new("$absolute", "$namespace", "$name"); - const VERSION: Option<&'static str> = $version; + impl std::error::Error for MissingOperationsError {} + """, + *codegenScope, + ) + } - type Protocol = #{Protocol}; + private fun serviceShapeImpl(): Writable = + writable { + val namespace = serviceId.namespace + val name = serviceId.name + val absolute = serviceId.toString().replace("#", "##") + val version = codegenContext.serviceShape.version?.let { "Some(\"$it\")" } ?: "None" + rustTemplate( + """ + impl #{SmithyHttpServer}::service::ServiceShape for $serviceName { + const ID: #{SmithyHttpServer}::shape_id::ShapeId = #{SmithyHttpServer}::shape_id::ShapeId::new("$absolute", "$namespace", "$name"); - type Operations = Operation; - } - """, - "Protocol" to protocol.markerStruct(), - *codegenScope, - ) - } + const VERSION: Option<&'static str> = $version; - private fun operationEnum(): Writable = writable { - val operations = operationStructNames.values.joinToString(",") - val matchArms: Writable = operationStructNames.map { - (shape, name) -> - writable { - val absolute = shape.id.toString().replace("#", "##") - rustTemplate( - """ - Operation::$name => #{SmithyHttpServer}::shape_id::ShapeId::new("$absolute", "${shape.id.namespace}", "${shape.id.name}") - """, - *codegenScope, - ) - } - }.join(",") - rustTemplate( - """ - /// An enumeration of all [operations](https://smithy.io/2.0/spec/service-types.html##operation) in $serviceName. - ##[derive(Debug, PartialEq, Eq, Clone, Copy)] - pub enum Operation { - $operations - } + type Protocol = #{Protocol}; - impl Operation { - /// Returns the [operations](https://smithy.io/2.0/spec/service-types.html##operation) [`ShapeId`](#{SmithyHttpServer}::shape_id::ShapeId). - pub fn shape_id(&self) -> #{SmithyHttpServer}::shape_id::ShapeId { - match self { - #{Arms} - } + type Operations = Operation; } - } - """, - *codegenScope, - "Arms" to matchArms, - ) + """, + "Protocol" to protocol.markerStruct(), + *codegenScope, + ) + } - for ((_, value) in operationStructNames) { + private fun operationEnum(): Writable = + writable { + val operations = operationStructNames.values.joinToString(",") + val matchArms: Writable = + operationStructNames.map { + (shape, name) -> + writable { + val absolute = shape.id.toString().replace("#", "##") + rustTemplate( + """ + Operation::$name => #{SmithyHttpServer}::shape_id::ShapeId::new("$absolute", "${shape.id.namespace}", "${shape.id.name}") + """, + *codegenScope, + ) + } + }.join(",") rustTemplate( """ - impl #{SmithyHttpServer}::service::ContainsOperation - for $serviceName - { - const VALUE: Operation = Operation::$value; + /// An enumeration of all [operations](https://smithy.io/2.0/spec/service-types.html##operation) in $serviceName. + ##[derive(Debug, PartialEq, Eq, Clone, Copy)] + pub enum Operation { + $operations + } + + impl Operation { + /// Returns the [operations](https://smithy.io/2.0/spec/service-types.html##operation) [`ShapeId`](#{SmithyHttpServer}::shape_id::ShapeId). + pub fn shape_id(&self) -> #{SmithyHttpServer}::shape_id::ShapeId { + match self { + #{Arms} + } + } } """, *codegenScope, + "Arms" to matchArms, ) + + for ((_, value) in operationStructNames) { + rustTemplate( + """ + impl #{SmithyHttpServer}::service::ContainsOperation + for $serviceName + { + const VALUE: Operation = Operation::$value; + } + """, + *codegenScope, + ) + } } - } fun render(writer: RustWriter) { writer.rustTemplate( @@ -802,7 +821,11 @@ class ServerServiceGenerator( * use my_service::{input, output, error}; * ``` */ -fun handlerImports(crateName: String, operations: Collection, commentToken: String = "///") = writable { +fun handlerImports( + crateName: String, + operations: Collection, + commentToken: String = "///", +) = writable { val hasErrors = operations.any { it.errors.isNotEmpty() } val errorImport = if (hasErrors) ", ${ErrorModule.name}" else "" if (operations.isNotEmpty()) { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt index 525487968fa..b1ca880ebc1 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt @@ -120,26 +120,28 @@ class ServiceConfigGenerator( ) { private val crateName = codegenContext.moduleUseName() private val smithyHttpServer = ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType() - private val codegenScope = arrayOf( - *preludeScope, - "Debug" to RuntimeType.Debug, - "SmithyHttpServer" to smithyHttpServer, - "PluginStack" to smithyHttpServer.resolve("plugin::PluginStack"), - "ModelMarker" to smithyHttpServer.resolve("plugin::ModelMarker"), - "HttpMarker" to smithyHttpServer.resolve("plugin::HttpMarker"), - "Tower" to RuntimeType.Tower, - "Stack" to RuntimeType.Tower.resolve("layer::util::Stack"), - ) + private val codegenScope = + arrayOf( + *preludeScope, + "Debug" to RuntimeType.Debug, + "SmithyHttpServer" to smithyHttpServer, + "PluginStack" to smithyHttpServer.resolve("plugin::PluginStack"), + "ModelMarker" to smithyHttpServer.resolve("plugin::ModelMarker"), + "HttpMarker" to smithyHttpServer.resolve("plugin::HttpMarker"), + "Tower" to RuntimeType.Tower, + "Stack" to RuntimeType.Tower.resolve("layer::util::Stack"), + ) private val serviceName = codegenContext.serviceShape.id.name.toPascalCase() fun render(writer: RustWriter) { - val unwrapConfigBuilder = if (isBuilderFallible) { - """ - /// .expect("config failed to build"); - """ - } else { - ";" - } + val unwrapConfigBuilder = + if (isBuilderFallible) { + """ + /// .expect("config failed to build"); + """ + } else { + ";" + } writer.rustTemplate( """ @@ -197,12 +199,12 @@ class ServiceConfigGenerator( pub(crate) model_plugins: M, #{BuilderRequiredMethodFlagDefinitions:W} } - + #{BuilderRequiredMethodError:W} impl ${serviceName}ConfigBuilder { #{InjectedMethods:W} - + /// Add a [`#{Tower}::Layer`] to the service. pub fn layer(self, layer: NewLayer) -> ${serviceName}ConfigBuilder<#{Stack}, H, M> { ${serviceName}ConfigBuilder { @@ -246,7 +248,7 @@ class ServiceConfigGenerator( #{BuilderRequiredMethodFlagsMove3:W} } } - + #{BuilderBuildMethod:W} } """, @@ -264,182 +266,198 @@ class ServiceConfigGenerator( private val isBuilderFallible = configMethods.isBuilderFallible() - private fun builderBuildRequiredMethodChecks() = configMethods.filter { it.isRequired }.map { - writable { - rustTemplate( - """ - if !self.${it.requiredBuilderFlagName()} { - return #{Err}(${serviceName}ConfigError::${it.requiredErrorVariant()}); - } - """, - *codegenScope, - ) - } - }.join("\n") + private fun builderBuildRequiredMethodChecks() = + configMethods.filter { it.isRequired }.map { + writable { + rustTemplate( + """ + if !self.${it.requiredBuilderFlagName()} { + return #{Err}(${serviceName}ConfigError::${it.requiredErrorVariant()}); + } + """, + *codegenScope, + ) + } + }.join("\n") - private fun builderRequiredMethodFlagsDefinitions() = configMethods.filter { it.isRequired }.map { - writable { rust("pub(crate) ${it.requiredBuilderFlagName()}: bool,") } - }.join("\n") + private fun builderRequiredMethodFlagsDefinitions() = + configMethods.filter { it.isRequired }.map { + writable { rust("pub(crate) ${it.requiredBuilderFlagName()}: bool,") } + }.join("\n") - private fun builderRequiredMethodFlagsInit() = configMethods.filter { it.isRequired }.map { - writable { rust("${it.requiredBuilderFlagName()}: false,") } - }.join("\n") + private fun builderRequiredMethodFlagsInit() = + configMethods.filter { it.isRequired }.map { + writable { rust("${it.requiredBuilderFlagName()}: false,") } + }.join("\n") - private fun builderRequiredMethodFlagsMove() = configMethods.filter { it.isRequired }.map { - writable { rust("${it.requiredBuilderFlagName()}: self.${it.requiredBuilderFlagName()},") } - }.join("\n") + private fun builderRequiredMethodFlagsMove() = + configMethods.filter { it.isRequired }.map { + writable { rust("${it.requiredBuilderFlagName()}: self.${it.requiredBuilderFlagName()},") } + }.join("\n") - private fun builderRequiredMethodError() = writable { - if (isBuilderFallible) { - val variants = configMethods.filter { it.isRequired }.map { - writable { - rust( - """ - ##[error("service is not fully configured; invoke `${it.name}` on the config builder")] - ${it.requiredErrorVariant()}, + private fun builderRequiredMethodError() = + writable { + if (isBuilderFallible) { + val variants = + configMethods.filter { it.isRequired }.map { + writable { + rust( + """ + ##[error("service is not fully configured; invoke `${it.name}` on the config builder")] + ${it.requiredErrorVariant()}, + """, + ) + } + } + rustTemplate( + """ + ##[derive(Debug, #{ThisError}::Error)] + pub enum ${serviceName}ConfigError { + #{Variants:W} + } """, - ) - } + "ThisError" to ServerCargoDependency.ThisError.toType(), + "Variants" to variants.join("\n"), + ) } - rustTemplate( - """ - ##[derive(Debug, #{ThisError}::Error)] - pub enum ${serviceName}ConfigError { - #{Variants:W} - } - """, - "ThisError" to ServerCargoDependency.ThisError.toType(), - "Variants" to variants.join("\n"), - ) } - } - private fun injectedMethods() = configMethods.map { - writable { - val paramBindings = it.params.map { binding -> - writable { rustTemplate("${binding.name}: #{BindingTy},", "BindingTy" to binding.ty) } - }.join("\n") - - // This produces a nested type like: "S>", where - // - "S" denotes a "stack type" with two generic type parameters: the first is the "inner" part of the stack - // and the second is the "outer" part of the stack. The outer part gets executed first. For an example, - // see `aws_smithy_http_server::plugin::PluginStack`. - // - "A", "B" are the types of the "things" that are added. - // - "T" is the generic type variable name used in the enclosing impl block. - fun List.stackReturnType(genericTypeVarName: String, stackType: RuntimeType): Writable = - this.fold(writable { rust(genericTypeVarName) }) { acc, next -> + private fun injectedMethods() = + configMethods.map { + writable { + val paramBindings = + it.params.map { binding -> + writable { rustTemplate("${binding.name}: #{BindingTy},", "BindingTy" to binding.ty) } + }.join("\n") + + // This produces a nested type like: "S>", where + // - "S" denotes a "stack type" with two generic type parameters: the first is the "inner" part of the stack + // and the second is the "outer" part of the stack. The outer part gets executed first. For an example, + // see `aws_smithy_http_server::plugin::PluginStack`. + // - "A", "B" are the types of the "things" that are added. + // - "T" is the generic type variable name used in the enclosing impl block. + fun List.stackReturnType( + genericTypeVarName: String, + stackType: RuntimeType, + ): Writable = + this.fold(writable { rust(genericTypeVarName) }) { acc, next -> + writable { + rustTemplate( + "#{StackType}<#{Ty}, #{Acc:W}>", + "StackType" to stackType, + "Ty" to next.ty, + "Acc" to acc, + ) + } + } + + val layersReturnTy = + it.initializer.layerBindings.stackReturnType("L", RuntimeType.Tower.resolve("layer::util::Stack")) + val httpPluginsReturnTy = + it.initializer.httpPluginBindings.stackReturnType("H", smithyHttpServer.resolve("plugin::PluginStack")) + val modelPluginsReturnTy = + it.initializer.modelPluginBindings.stackReturnType("M", smithyHttpServer.resolve("plugin::PluginStack")) + + val configBuilderReturnTy = writable { rustTemplate( - "#{StackType}<#{Ty}, #{Acc:W}>", - "StackType" to stackType, - "Ty" to next.ty, - "Acc" to acc, + """ + ${serviceName}ConfigBuilder< + #{LayersReturnTy:W}, + #{HttpPluginsReturnTy:W}, + #{ModelPluginsReturnTy:W}, + > + """, + "LayersReturnTy" to layersReturnTy, + "HttpPluginsReturnTy" to httpPluginsReturnTy, + "ModelPluginsReturnTy" to modelPluginsReturnTy, ) } - } - val layersReturnTy = - it.initializer.layerBindings.stackReturnType("L", RuntimeType.Tower.resolve("layer::util::Stack")) - val httpPluginsReturnTy = - it.initializer.httpPluginBindings.stackReturnType("H", smithyHttpServer.resolve("plugin::PluginStack")) - val modelPluginsReturnTy = - it.initializer.modelPluginBindings.stackReturnType("M", smithyHttpServer.resolve("plugin::PluginStack")) + val returnTy = + if (it.errorType != null) { + writable { + rustTemplate( + "#{Result}<#{T:W}, #{E}>", + "T" to configBuilderReturnTy, + "E" to it.errorType, + *codegenScope, + ) + } + } else { + configBuilderReturnTy + } - val configBuilderReturnTy = writable { - rustTemplate( + docs(it.docs) + rustBlockTemplate( """ - ${serviceName}ConfigBuilder< - #{LayersReturnTy:W}, - #{HttpPluginsReturnTy:W}, - #{ModelPluginsReturnTy:W}, - > + pub fn ${it.name}( + ##[allow(unused_mut)] + mut self, + #{ParamBindings:W} + ) -> #{ReturnTy:W} """, - "LayersReturnTy" to layersReturnTy, - "HttpPluginsReturnTy" to httpPluginsReturnTy, - "ModelPluginsReturnTy" to modelPluginsReturnTy, - ) - } + "ReturnTy" to returnTy, + "ParamBindings" to paramBindings, + ) { + rustTemplate("#{InitializerCode:W}", "InitializerCode" to it.initializer.code) - val returnTy = if (it.errorType != null) { - writable { - rustTemplate( - "#{Result}<#{T:W}, #{E}>", - "T" to configBuilderReturnTy, - "E" to it.errorType, - *codegenScope, - ) + check(it.initializer.layerBindings.size + it.initializer.httpPluginBindings.size + it.initializer.modelPluginBindings.size > 0) { + "This method's initializer does not register any layers, HTTP plugins, or model plugins. It must register at least something!" + } + + if (it.isRequired) { + rust("self.${it.requiredBuilderFlagName()} = true;") + } + conditionalBlock("Ok(", ")", conditional = it.errorType != null) { + val registrations = + ( + it.initializer.layerBindings.map { ".layer(${it.name})" } + + it.initializer.httpPluginBindings.map { ".http_plugin(${it.name})" } + + it.initializer.modelPluginBindings.map { ".model_plugin(${it.name})" } + ).joinToString("") + rust("self$registrations") + } } + } + }.join("\n\n") + + private fun builderBuildReturnType() = + writable { + val t = "super::${serviceName}Config" + + if (isBuilderFallible) { + rustTemplate("#{Result}<$t, ${serviceName}ConfigError>", *codegenScope) } else { - configBuilderReturnTy + rust(t) } + } - docs(it.docs) + private fun builderBuildMethod() = + writable { rustBlockTemplate( """ - pub fn ${it.name}( - ##[allow(unused_mut)] - mut self, - #{ParamBindings:W} - ) -> #{ReturnTy:W} + /// Build the configuration. + pub fn build(self) -> #{BuilderBuildReturnTy:W} """, - "ReturnTy" to returnTy, - "ParamBindings" to paramBindings, + "BuilderBuildReturnTy" to builderBuildReturnType(), ) { - rustTemplate("#{InitializerCode:W}", "InitializerCode" to it.initializer.code) - - check(it.initializer.layerBindings.size + it.initializer.httpPluginBindings.size + it.initializer.modelPluginBindings.size > 0) { - "This method's initializer does not register any layers, HTTP plugins, or model plugins. It must register at least something!" - } + rustTemplate( + "#{BuilderBuildRequiredMethodChecks:W}", + "BuilderBuildRequiredMethodChecks" to builderBuildRequiredMethodChecks(), + ) - if (it.isRequired) { - rust("self.${it.requiredBuilderFlagName()} = true;") - } - conditionalBlock("Ok(", ")", conditional = it.errorType != null) { - val registrations = ( - it.initializer.layerBindings.map { ".layer(${it.name})" } + - it.initializer.httpPluginBindings.map { ".http_plugin(${it.name})" } + - it.initializer.modelPluginBindings.map { ".model_plugin(${it.name})" } - ).joinToString("") - rust("self$registrations") + conditionalBlock("Ok(", ")", isBuilderFallible) { + rust( + """ + super::${serviceName}Config { + layers: self.layers, + http_plugins: self.http_plugins, + model_plugins: self.model_plugins, + } + """, + ) } } } - }.join("\n\n") - - private fun builderBuildReturnType() = writable { - val t = "super::${serviceName}Config" - - if (isBuilderFallible) { - rustTemplate("#{Result}<$t, ${serviceName}ConfigError>", *codegenScope) - } else { - rust(t) - } - } - - private fun builderBuildMethod() = writable { - rustBlockTemplate( - """ - /// Build the configuration. - pub fn build(self) -> #{BuilderBuildReturnTy:W} - """, - "BuilderBuildReturnTy" to builderBuildReturnType(), - ) { - rustTemplate( - "#{BuilderBuildRequiredMethodChecks:W}", - "BuilderBuildRequiredMethodChecks" to builderBuildRequiredMethodChecks(), - ) - - conditionalBlock("Ok(", ")", isBuilderFallible) { - rust( - """ - super::${serviceName}Config { - layers: self.layers, - http_plugins: self.http_plugins, - model_plugins: self.model_plugins, - } - """, - ) - } - } - } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/TraitInfo.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/TraitInfo.kt index 2179130dff2..611c8400e8f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/TraitInfo.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/TraitInfo.kt @@ -44,13 +44,14 @@ fun RustWriter.renderTryFrom( #{ValidationFunctions:W} } """, - "ValidationFunctions" to constraintsInfo.map { - it.validationFunctionDefinition( - constraintViolationError, - unconstrainedTypeName, - ) - } - .join("\n"), + "ValidationFunctions" to + constraintsInfo.map { + it.validationFunctionDefinition( + constraintViolationError, + unconstrainedTypeName, + ) + } + .join("\n"), ) this.rustTemplate( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt index 6916617b6ff..041b5a13811 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt @@ -60,11 +60,12 @@ class UnconstrainedCollectionGenerator( } private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider - private val constrainedSymbol = if (shape.isDirectlyConstrained(symbolProvider)) { - constrainedShapeSymbolProvider.toSymbol(shape) - } else { - pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) - } + private val constrainedSymbol = + if (shape.isDirectlyConstrained(symbolProvider)) { + constrainedShapeSymbolProvider.toSymbol(shape) + } else { + pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) + } private val innerShape = model.expectShape(shape.member.target) fun render() { @@ -103,22 +104,25 @@ class UnconstrainedCollectionGenerator( !innerShape.isDirectlyConstrained(symbolProvider) && innerShape !is StructureShape && innerShape !is UnionShape - val constrainedMemberSymbol = if (resolvesToNonPublicConstrainedValueType) { - pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.member) - } else { - constrainedShapeSymbolProvider.toSymbol(shape.member) - } + val constrainedMemberSymbol = + if (resolvesToNonPublicConstrainedValueType) { + pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.member) + } else { + constrainedShapeSymbolProvider.toSymbol(shape.member) + } val innerConstraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(innerShape) - val boxErr = if (shape.member.hasTrait()) { - ".map_err(|(idx, inner_violation)| (idx, Box::new(inner_violation)))" - } else { - "" - } - val constrainValueWritable = writable { - conditionalBlock("inner.map(|inner| ", ").transpose()", constrainedMemberSymbol.isOptional()) { - rust("inner.try_into().map_err(|inner_violation| (idx, inner_violation))") + val boxErr = + if (shape.member.hasTrait()) { + ".map_err(|(idx, inner_violation)| (idx, Box::new(inner_violation)))" + } else { + "" + } + val constrainValueWritable = + writable { + conditionalBlock("inner.map(|inner| ", ").transpose()", constrainedMemberSymbol.isOptional()) { + rust("inner.try_into().map_err(|inner_violation| (idx, inner_violation))") + } } - } rustTemplate( """ diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt index 0862a26987c..c0114c8a8d0 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt @@ -60,11 +60,12 @@ class UnconstrainedMapGenerator( } private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider - private val constrainedSymbol = if (shape.isDirectlyConstrained(symbolProvider)) { - constrainedShapeSymbolProvider.toSymbol(shape) - } else { - pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) - } + private val constrainedSymbol = + if (shape.isDirectlyConstrained(symbolProvider)) { + constrainedShapeSymbolProvider.toSymbol(shape) + } else { + pubCrateConstrainedShapeSymbolProvider.toSymbol(shape) + } private val keyShape = model.expectShape(shape.key.target, StringShape::class.java) private val valueShape = model.expectShape(shape.value.target) @@ -107,74 +108,80 @@ class UnconstrainedMapGenerator( !valueShape.isDirectlyConstrained(symbolProvider) && valueShape !is StructureShape && valueShape !is UnionShape - val constrainedMemberValueSymbol = if (resolvesToNonPublicConstrainedValueType) { - pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.value) - } else { - constrainedShapeSymbolProvider.toSymbol(shape.value) - } - val constrainedValueSymbol = if (resolvesToNonPublicConstrainedValueType) { - pubCrateConstrainedShapeSymbolProvider.toSymbol(valueShape) - } else { - constrainedShapeSymbolProvider.toSymbol(valueShape) - } + val constrainedMemberValueSymbol = + if (resolvesToNonPublicConstrainedValueType) { + pubCrateConstrainedShapeSymbolProvider.toSymbol(shape.value) + } else { + constrainedShapeSymbolProvider.toSymbol(shape.value) + } + val constrainedValueSymbol = + if (resolvesToNonPublicConstrainedValueType) { + pubCrateConstrainedShapeSymbolProvider.toSymbol(valueShape) + } else { + constrainedShapeSymbolProvider.toSymbol(valueShape) + } val constrainedKeySymbol = constrainedShapeSymbolProvider.toSymbol(keyShape) val epilogueWritable = writable { rust("Ok((k, v))") } - val constrainKeyWritable = writable { - rustTemplate( - "let k: #{ConstrainedKeySymbol} = k.try_into().map_err(Self::Error::Key)?;", - "ConstrainedKeySymbol" to constrainedKeySymbol, - ) - } - val constrainValueWritable = writable { - val boxErr = if (shape.value.hasTrait()) { - ".map_err(Box::new)" - } else { - "" + val constrainKeyWritable = + writable { + rustTemplate( + "let k: #{ConstrainedKeySymbol} = k.try_into().map_err(Self::Error::Key)?;", + "ConstrainedKeySymbol" to constrainedKeySymbol, + ) } - if (constrainedMemberValueSymbol.isOptional()) { - // The map is `@sparse`. - rustBlock("match v") { - rust("None => Ok((k, None)),") - withBlock("Some(v) =>", ",") { - // DRYing this up with the else branch below would make this less understandable. - rustTemplate( - """ - match #{ConstrainedValueSymbol}::try_from(v)$boxErr { - Ok(v) => Ok((k, Some(v))), - Err(inner_constraint_violation) => Err(Self::Error::Value(k, inner_constraint_violation)), - } - """, - "ConstrainedValueSymbol" to constrainedValueSymbol, - ) + val constrainValueWritable = + writable { + val boxErr = + if (shape.value.hasTrait()) { + ".map_err(Box::new)" + } else { + "" } - } - } else { - rustTemplate( - """ - match #{ConstrainedValueSymbol}::try_from(v)$boxErr { - Ok(v) => #{Epilogue:W}, - Err(inner_constraint_violation) => Err(Self::Error::Value(k, inner_constraint_violation)), + if (constrainedMemberValueSymbol.isOptional()) { + // The map is `@sparse`. + rustBlock("match v") { + rust("None => Ok((k, None)),") + withBlock("Some(v) =>", ",") { + // DRYing this up with the else branch below would make this less understandable. + rustTemplate( + """ + match #{ConstrainedValueSymbol}::try_from(v)$boxErr { + Ok(v) => Ok((k, Some(v))), + Err(inner_constraint_violation) => Err(Self::Error::Value(k, inner_constraint_violation)), + } + """, + "ConstrainedValueSymbol" to constrainedValueSymbol, + ) + } } - """, - "ConstrainedValueSymbol" to constrainedValueSymbol, - "Epilogue" to epilogueWritable, - ) + } else { + rustTemplate( + """ + match #{ConstrainedValueSymbol}::try_from(v)$boxErr { + Ok(v) => #{Epilogue:W}, + Err(inner_constraint_violation) => Err(Self::Error::Value(k, inner_constraint_violation)), + } + """, + "ConstrainedValueSymbol" to constrainedValueSymbol, + "Epilogue" to epilogueWritable, + ) + } } - } - val constrainKVWritable = if ( - isKeyConstrained(keyShape, symbolProvider) && - isValueConstrained(valueShape, model, symbolProvider) - ) { - listOf(constrainKeyWritable, constrainValueWritable).join("\n") - } else if (isKeyConstrained(keyShape, symbolProvider)) { - listOf(constrainKeyWritable, epilogueWritable).join("\n") - } else if (isValueConstrained(valueShape, model, symbolProvider)) { - constrainValueWritable - } else { - epilogueWritable - } + val constrainKVWritable = + if ( + isKeyConstrained(keyShape, symbolProvider) && + isValueConstrained(valueShape, model, symbolProvider) + ) { + listOf(constrainKeyWritable, constrainValueWritable).join("\n") + } else if (isKeyConstrained(keyShape, symbolProvider)) { + listOf(constrainKeyWritable, epilogueWritable).join("\n") + } else if (isValueConstrained(valueShape, model, symbolProvider)) { + constrainValueWritable + } else { + epilogueWritable + } rustTemplate( """ diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt index d4f6cd48605..b4a0976e6be 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt @@ -128,11 +128,12 @@ class UnconstrainedUnionGenerator( "UnconstrainedSymbol" to symbol, ) - val constraintViolationVisibility = if (publicConstrainedTypes) { - Visibility.PUBLIC - } else { - Visibility.PUBCRATE - } + val constraintViolationVisibility = + if (publicConstrainedTypes) { + Visibility.PUBLIC + } else { + Visibility.PUBCRATE + } inlineModuleCreator( constraintViolationSymbol, @@ -172,7 +173,10 @@ class UnconstrainedUnionGenerator( .filter { it.targetCanReachConstrainedShape(model, symbolProvider) } .map { ConstraintViolation(it) } - private fun renderConstraintViolation(writer: RustWriter, constraintViolation: ConstraintViolation) { + private fun renderConstraintViolation( + writer: RustWriter, + constraintViolation: ConstraintViolation, + ) { val targetShape = model.expectShape(constraintViolation.forMember.target) val constraintViolationSymbol = @@ -188,68 +192,71 @@ class UnconstrainedUnionGenerator( ) } - private fun generateTryFromUnconstrainedUnionImpl() = writable { - withBlock("Ok(", ")") { - withBlock("match value {", "}") { - sortedMembers.forEach { member -> - val memberName = unconstrainedShapeSymbolProvider.toMemberName(member) - withBlockTemplate( - "#{UnconstrainedUnion}::$memberName(unconstrained) => Self::$memberName(", - "),", - "UnconstrainedUnion" to symbol, - ) { - if (!member.canReachConstrainedShape(model, symbolProvider)) { - rust("unconstrained") - } else { - val targetShape = model.expectShape(member.target) - val resolveToNonPublicConstrainedType = - targetShape !is StructureShape && targetShape !is UnionShape && !targetShape.hasTrait() && - (!publicConstrainedTypes || !targetShape.isDirectlyConstrained(symbolProvider)) - - val (unconstrainedVar, boxIt) = if (member.hasTrait()) { - "(*unconstrained)" to ".map(Box::new)" + private fun generateTryFromUnconstrainedUnionImpl() = + writable { + withBlock("Ok(", ")") { + withBlock("match value {", "}") { + sortedMembers.forEach { member -> + val memberName = unconstrainedShapeSymbolProvider.toMemberName(member) + withBlockTemplate( + "#{UnconstrainedUnion}::$memberName(unconstrained) => Self::$memberName(", + "),", + "UnconstrainedUnion" to symbol, + ) { + if (!member.canReachConstrainedShape(model, symbolProvider)) { + rust("unconstrained") } else { - "unconstrained" to "" - } - val boxErr = if (member.hasTrait()) { - ".map_err(Box::new)" - } else { - "" - } + val targetShape = model.expectShape(member.target) + val resolveToNonPublicConstrainedType = + targetShape !is StructureShape && targetShape !is UnionShape && !targetShape.hasTrait() && + (!publicConstrainedTypes || !targetShape.isDirectlyConstrained(symbolProvider)) - if (resolveToNonPublicConstrainedType) { - val constrainedSymbol = - if (!publicConstrainedTypes && targetShape.isDirectlyConstrained(symbolProvider)) { - codegenContext.constrainedShapeSymbolProvider.toSymbol(targetShape) + val (unconstrainedVar, boxIt) = + if (member.hasTrait()) { + "(*unconstrained)" to ".map(Box::new)" } else { - pubCrateConstrainedShapeSymbolProvider.toSymbol(targetShape) + "unconstrained" to "" } - rustTemplate( - """ - { - let constrained: #{ConstrainedSymbol} = $unconstrainedVar - .try_into()$boxIt$boxErr - .map_err(Self::Error::${ConstraintViolation(member).name()})?; - constrained.into() + val boxErr = + if (member.hasTrait()) { + ".map_err(Box::new)" + } else { + "" } - """, - "ConstrainedSymbol" to constrainedSymbol, - ) - } else { - rust( - """ - $unconstrainedVar - .try_into() - $boxIt - $boxErr - .map_err(Self::Error::${ConstraintViolation(member).name()})? - """, - ) + + if (resolveToNonPublicConstrainedType) { + val constrainedSymbol = + if (!publicConstrainedTypes && targetShape.isDirectlyConstrained(symbolProvider)) { + codegenContext.constrainedShapeSymbolProvider.toSymbol(targetShape) + } else { + pubCrateConstrainedShapeSymbolProvider.toSymbol(targetShape) + } + rustTemplate( + """ + { + let constrained: #{ConstrainedSymbol} = $unconstrainedVar + .try_into()$boxIt$boxErr + .map_err(Self::Error::${ConstraintViolation(member).name()})?; + constrained.into() + } + """, + "ConstrainedSymbol" to constrainedSymbol, + ) + } else { + rust( + """ + $unconstrainedVar + .try_into() + $boxIt + $boxErr + .map_err(Self::Error::${ConstraintViolation(member).name()})? + """, + ) + } } } } } } } - } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ValidationExceptionConversionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ValidationExceptionConversionGenerator.kt index 3434afb0b12..b43af2a1fde 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ValidationExceptionConversionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ValidationExceptionConversionGenerator.kt @@ -31,8 +31,11 @@ interface ValidationExceptionConversionGenerator { // Simple shapes. fun stringShapeConstraintViolationImplBlock(stringConstraintsInfo: Collection): Writable + fun enumShapeConstraintViolationImplBlock(enumTrait: EnumTrait): Writable + fun numberShapeConstraintViolationImplBlock(rangeInfo: Range): Writable + fun blobShapeConstraintViolationImplBlock(blobConstraintsInfo: Collection): Writable // Aggregate shapes. @@ -43,7 +46,9 @@ interface ValidationExceptionConversionGenerator { symbolProvider: RustSymbolProvider, model: Model, ): Writable + fun builderConstraintViolationImplBlock(constraintViolations: Collection): Writable + fun collectionShapeConstraintViolationImplBlock( collectionConstraintsInfo: Collection, isMemberConstrained: Boolean, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/RestRequestSpecGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/RestRequestSpecGenerator.kt index b43eb582e60..a02a0634a23 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/RestRequestSpecGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/RestRequestSpecGenerator.kt @@ -36,34 +36,38 @@ class RestRequestSpecGenerator( }.toTypedArray() // TODO(https://github.com/smithy-lang/smithy-rs/issues/950): Support the `endpoint` trait. - val pathSegmentsVec = writable { - withBlock("vec![", "]") { - for (segment in httpTrait.uri.segments) { - val variant = when { - segment.isGreedyLabel -> "Greedy" - segment.isLabel -> "Label" - else -> """Literal(String::from("${segment.content}"))""" + val pathSegmentsVec = + writable { + withBlock("vec![", "]") { + for (segment in httpTrait.uri.segments) { + val variant = + when { + segment.isGreedyLabel -> "Greedy" + segment.isLabel -> "Label" + else -> """Literal(String::from("${segment.content}"))""" + } + rustTemplate( + "#{PathSegment}::$variant,", + *extraCodegenScope, + ) } - rustTemplate( - "#{PathSegment}::$variant,", - *extraCodegenScope, - ) } } - } - val querySegmentsVec = writable { - withBlock("vec![", "]") { - for (queryLiteral in httpTrait.uri.queryLiterals) { - val variant = if (queryLiteral.value == "") { - """Key(String::from("${queryLiteral.key}"))""" - } else { - """KeyValue(String::from("${queryLiteral.key}"), String::from("${queryLiteral.value}"))""" + val querySegmentsVec = + writable { + withBlock("vec![", "]") { + for (queryLiteral in httpTrait.uri.queryLiterals) { + val variant = + if (queryLiteral.value == "") { + """Key(String::from("${queryLiteral.key}"))""" + } else { + """KeyValue(String::from("${queryLiteral.key}"), String::from("${queryLiteral.value}"))""" + } + rustTemplate("#{QuerySegment}::$variant,", *extraCodegenScope) } - rustTemplate("#{QuerySegment}::$variant,", *extraCodegenScope) } } - } return writable { rustTemplate( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt index e1e6c747f76..ce02e2c99bf 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt @@ -50,16 +50,16 @@ class ServerRequestBindingGenerator( binding: HttpBindingDescriptor, errorSymbol: Symbol, structuredHandler: RustWriter.(String) -> Unit, - ): RuntimeType = httpBindingGenerator.generateDeserializePayloadFn( - binding, - errorSymbol, - structuredHandler, - HttpMessageType.REQUEST, - ) + ): RuntimeType = + httpBindingGenerator.generateDeserializePayloadFn( + binding, + errorSymbol, + structuredHandler, + HttpMessageType.REQUEST, + ) - fun generateDeserializePrefixHeadersFn( - binding: HttpBindingDescriptor, - ): RuntimeType = httpBindingGenerator.generateDeserializePrefixHeaderFn(binding) + fun generateDeserializePrefixHeadersFn(binding: HttpBindingDescriptor): RuntimeType = + httpBindingGenerator.generateDeserializePrefixHeaderFn(binding) } /** @@ -68,20 +68,22 @@ class ServerRequestBindingGenerator( */ class ServerRequestAfterDeserializingIntoAHashMapOfHttpPrefixHeadersWrapInUnconstrainedMapHttpBindingCustomization(val codegenContext: ServerCodegenContext) : HttpBindingCustomization() { - override fun section(section: HttpBindingSection): Writable = when (section) { - is HttpBindingSection.BeforeRenderingHeaderValue, - is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders, - -> emptySection - is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders -> writable { - if (section.memberShape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.unconstrainedShapeSymbolProvider)) { - rust( - "let out = out.map(#T);", - codegenContext.unconstrainedShapeSymbolProvider.toSymbol(section.memberShape).mapRustType { - it.stripOuter() - }, - ) - } + override fun section(section: HttpBindingSection): Writable = + when (section) { + is HttpBindingSection.BeforeRenderingHeaderValue, + is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders, + -> emptySection + is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders -> + writable { + if (section.memberShape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.unconstrainedShapeSymbolProvider)) { + rust( + "let out = out.map(#T);", + codegenContext.unconstrainedShapeSymbolProvider.toSymbol(section.memberShape).mapRustType { + it.stripOuter() + }, + ) + } + } + else -> emptySection } - else -> emptySection - } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt index 01448d27a47..960cbd735d7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt @@ -57,23 +57,25 @@ class ServerResponseBindingGenerator( */ class ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstrainedMapHttpBindingCustomization(val codegenContext: ServerCodegenContext) : HttpBindingCustomization() { - override fun section(section: HttpBindingSection): Writable = when (section) { - is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders -> writable { - if (workingWithPublicConstrainedWrapperTupleType( - section.shape, - codegenContext.model, - codegenContext.settings.codegenConfig.publicConstrainedTypes, - ) - ) { - rust("let ${section.variableName} = &${section.variableName}.0;") - } - } + override fun section(section: HttpBindingSection): Writable = + when (section) { + is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders -> + writable { + if (workingWithPublicConstrainedWrapperTupleType( + section.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + ) { + rust("let ${section.variableName} = &${section.variableName}.0;") + } + } - is HttpBindingSection.BeforeRenderingHeaderValue, - is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, - is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders, - -> emptySection - } + is HttpBindingSection.BeforeRenderingHeaderValue, + is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, + is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders, + -> emptySection + } } /** @@ -82,26 +84,29 @@ class ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstr */ class ServerResponseBeforeRenderingHeadersHttpBindingCustomization(val codegenContext: ServerCodegenContext) : HttpBindingCustomization() { - override fun section(section: HttpBindingSection): Writable = when (section) { - is HttpBindingSection.BeforeRenderingHeaderValue -> writable { - val isIntegral = section.context.shape is ByteShape || section.context.shape is ShortShape || section.context.shape is IntegerShape || section.context.shape is LongShape - val isCollection = section.context.shape is CollectionShape + override fun section(section: HttpBindingSection): Writable = + when (section) { + is HttpBindingSection.BeforeRenderingHeaderValue -> + writable { + val isIntegral = section.context.shape is ByteShape || section.context.shape is ShortShape || section.context.shape is IntegerShape || section.context.shape is LongShape + val isCollection = section.context.shape is CollectionShape - val workingWithPublicWrapper = workingWithPublicConstrainedWrapperTupleType( - section.context.shape, - codegenContext.model, - codegenContext.settings.codegenConfig.publicConstrainedTypes, - ) + val workingWithPublicWrapper = + workingWithPublicConstrainedWrapperTupleType( + section.context.shape, + codegenContext.model, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) - if (workingWithPublicWrapper && (isIntegral || isCollection)) { - section.context.valueExpression = - ValueExpression.Reference("&${section.context.valueExpression.name.removePrefix("&")}.0") - } - } + if (workingWithPublicWrapper && (isIntegral || isCollection)) { + section.context.valueExpression = + ValueExpression.Reference("&${section.context.valueExpression.name.removePrefix("&")}.0") + } + } - is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders, - is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, - is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders, - -> emptySection - } + is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders, + is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, + is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders, + -> emptySection + } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index f86c547da19..2fb76bf879b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -126,10 +126,11 @@ class ServerAwsJsonProtocol( private val runtimeConfig = codegenContext.runtimeConfig override val protocolModulePath: String - get() = when (version) { - is AwsJsonVersion.Json10 -> "aws_json_10" - is AwsJsonVersion.Json11 -> "aws_json_11" - } + get() = + when (version) { + is AwsJsonVersion.Json10 -> "aws_json_10" + is AwsJsonVersion.Json11 -> "aws_json_11" + } override fun structuredDataParser(): StructuredDataParserGenerator = jsonParserGenerator( @@ -149,8 +150,9 @@ class ServerAwsJsonProtocol( } } - override fun routerType() = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() - .resolve("protocol::aws_json::router::AwsJsonRouter") + override fun routerType() = + ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() + .resolve("protocol::aws_json::router::AwsJsonRouter") /** * Returns the operation name as required by the awsJson1.x protocols. @@ -164,14 +166,13 @@ class ServerAwsJsonProtocol( rust("""String::from("$serviceName.$operationName")""") } - override fun serverRouterRequestSpecType( - requestSpecModule: RuntimeType, - ): RuntimeType = RuntimeType.String + override fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType = RuntimeType.String - override fun serverRouterRuntimeConstructor() = when (version) { - AwsJsonVersion.Json10 -> "new_aws_json_10_router" - AwsJsonVersion.Json11 -> "new_aws_json_11_router" - } + override fun serverRouterRuntimeConstructor() = + when (version) { + AwsJsonVersion.Json10 -> "new_aws_json_10_router" + AwsJsonVersion.Json11 -> "new_aws_json_11_router" + } override fun requestRejection(runtimeConfig: RuntimeConfig): RuntimeType = ServerCargoDependency.smithyHttpServer(runtimeConfig) @@ -259,17 +260,19 @@ class ServerRestXmlProtocol( */ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(val codegenContext: ServerCodegenContext) : JsonParserCustomization() { - override fun section(section: JsonParserSection): Writable = when (section) { - is JsonParserSection.BeforeBoxingDeserializedMember -> writable { - // We're only interested in _structure_ member shapes that can reach constrained shapes. - if ( - codegenContext.model.expectShape(section.shape.container) is StructureShape && - section.shape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider) - ) { - rust(".map(|x| x.into())") - } + override fun section(section: JsonParserSection): Writable = + when (section) { + is JsonParserSection.BeforeBoxingDeserializedMember -> + writable { + // We're only interested in _structure_ member shapes that can reach constrained shapes. + if ( + codegenContext.model.expectShape(section.shape.container) is StructureShape && + section.shape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider) + ) { + rust(".map(|x| x.into())") + } + } + + else -> emptySection } - - else -> emptySection - } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 968f87ea3a0..389133f8f78 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -76,37 +76,40 @@ class ServerProtocolTestGenerator( private val operations = TopDownIndex.of(codegenContext.model).getContainedOperations(codegenContext.serviceShape).sortedBy { it.id } - private val operationInputOutputTypes = operations.associateWith { - val inputSymbol = symbolProvider.toSymbol(it.inputShape(model)) - val outputSymbol = symbolProvider.toSymbol(it.outputShape(model)) - val operationSymbol = symbolProvider.toSymbol(it) - - val inputT = inputSymbol.fullName - val t = outputSymbol.fullName - val outputT = if (it.errors.isEmpty()) { - t - } else { - val errorType = RuntimeType("crate::error::${operationSymbol.name}Error") - val e = errorType.fullyQualifiedName() - "Result<$t, $e>" - } + private val operationInputOutputTypes = + operations.associateWith { + val inputSymbol = symbolProvider.toSymbol(it.inputShape(model)) + val outputSymbol = symbolProvider.toSymbol(it.outputShape(model)) + val operationSymbol = symbolProvider.toSymbol(it) + + val inputT = inputSymbol.fullName + val t = outputSymbol.fullName + val outputT = + if (it.errors.isEmpty()) { + t + } else { + val errorType = RuntimeType("crate::error::${operationSymbol.name}Error") + val e = errorType.fullyQualifiedName() + "Result<$t, $e>" + } - inputT to outputT - } + inputT to outputT + } private val instantiator = ServerInstantiator(codegenContext) - private val codegenScope = arrayOf( - "Bytes" to RuntimeType.Bytes, - "SmithyHttp" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), - "Http" to RuntimeType.Http, - "Hyper" to RuntimeType.Hyper, - "Tokio" to ServerCargoDependency.TokioDev.toType(), - "Tower" to RuntimeType.Tower, - "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType(), - "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), - "Router" to ServerRuntimeType.router(codegenContext.runtimeConfig), - ) + private val codegenScope = + arrayOf( + "Bytes" to RuntimeType.Bytes, + "SmithyHttp" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), + "Http" to RuntimeType.Http, + "Hyper" to RuntimeType.Hyper, + "Tokio" to ServerCargoDependency.TokioDev.toType(), + "Tower" to RuntimeType.Tower, + "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType(), + "AssertEq" to RuntimeType.PrettyAssertions.resolve("assert_eq!"), + "Router" to ServerRuntimeType.router(codegenContext.runtimeConfig), + ) sealed class TestCase { abstract val id: String @@ -142,67 +145,84 @@ class ServerProtocolTestGenerator( } } - private fun renderOperationTestCases(operationShape: OperationShape, writer: RustWriter) { + private fun renderOperationTestCases( + operationShape: OperationShape, + writer: RustWriter, + ) { val outputShape = operationShape.outputShape(codegenContext.model) val operationSymbol = symbolProvider.toSymbol(operationShape) - val requestTests = operationShape.getTrait() - ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.RequestTest(it, operationShape) } - val responseTests = operationShape.getTrait() - ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.ResponseTest(it, outputShape) } - val errorTests = operationIndex.getErrors(operationShape).flatMap { error -> - val testCases = error.getTrait() - ?.getTestCasesFor(AppliesTo.SERVER).orEmpty() - testCases.map { TestCase.ResponseTest(it, error) } - } - val malformedRequestTests = operationShape.getTrait() - ?.testCases.orEmpty().map { TestCase.MalformedRequestTest(it) } - val allTests: List = (requestTests + responseTests + errorTests + malformedRequestTests) - .filterMatching() - .fixBroken() + val requestTests = + operationShape.getTrait() + ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.RequestTest(it, operationShape) } + val responseTests = + operationShape.getTrait() + ?.getTestCasesFor(AppliesTo.SERVER).orEmpty().map { TestCase.ResponseTest(it, outputShape) } + val errorTests = + operationIndex.getErrors(operationShape).flatMap { error -> + val testCases = + error.getTrait() + ?.getTestCasesFor(AppliesTo.SERVER).orEmpty() + testCases.map { TestCase.ResponseTest(it, error) } + } + val malformedRequestTests = + operationShape.getTrait() + ?.testCases.orEmpty().map { TestCase.MalformedRequestTest(it) } + val allTests: List = + (requestTests + responseTests + errorTests + malformedRequestTests) + .filterMatching() + .fixBroken() if (allTests.isNotEmpty()) { val operationName = operationSymbol.name - val module = RustModule.LeafModule( - "server_${operationName.toSnakeCase()}_test", - RustMetadata( - additionalAttributes = listOf( - Attribute.CfgTest, - Attribute(allow("unreachable_code", "unused_variables")), + val module = + RustModule.LeafModule( + "server_${operationName.toSnakeCase()}_test", + RustMetadata( + additionalAttributes = + listOf( + Attribute.CfgTest, + Attribute(allow("unreachable_code", "unused_variables")), + ), + visibility = Visibility.PRIVATE, ), - visibility = Visibility.PRIVATE, - ), - inline = true, - ) + inline = true, + ) writer.withInlineModule(module, null) { renderAllTestCases(operationShape, allTests) } } } - private fun RustWriter.renderAllTestCases(operationShape: OperationShape, allTests: List) { + private fun RustWriter.renderAllTestCases( + operationShape: OperationShape, + allTests: List, + ) { allTests.forEach { val operationSymbol = symbolProvider.toSymbol(operationShape) renderTestCaseBlock(it, this) { when (it) { - is TestCase.RequestTest -> this.renderHttpRequestTestCase( - it.testCase, - operationShape, - operationSymbol, - ) - - is TestCase.ResponseTest -> this.renderHttpResponseTestCase( - it.testCase, - it.targetShape, - operationShape, - operationSymbol, - ) - - is TestCase.MalformedRequestTest -> this.renderHttpMalformedRequestTestCase( - it.testCase, - operationShape, - operationSymbol, - ) + is TestCase.RequestTest -> + this.renderHttpRequestTestCase( + it.testCase, + operationShape, + operationSymbol, + ) + + is TestCase.ResponseTest -> + this.renderHttpResponseTestCase( + it.testCase, + it.targetShape, + operationShape, + operationSymbol, + ) + + is TestCase.MalformedRequestTest -> + this.renderHttpMalformedRequestTestCase( + it.testCase, + operationShape, + operationSymbol, + ) } } } @@ -228,20 +248,21 @@ class ServerProtocolTestGenerator( // This function applies a "fix function" to each broken test before we synthesize it. // Broken tests are those whose definitions in the `awslabs/smithy` repository are wrong, usually because they have // not been written with a server-side perspective in mind. - private fun List.fixBroken(): List = this.map { - when (it) { - is TestCase.MalformedRequestTest -> { - val howToFixIt = BrokenMalformedRequestTests[Pair(codegenContext.serviceShape.id.toString(), it.id)] - if (howToFixIt == null) { - it - } else { - val fixed = howToFixIt(it.testCase) - TestCase.MalformedRequestTest(fixed) + private fun List.fixBroken(): List = + this.map { + when (it) { + is TestCase.MalformedRequestTest -> { + val howToFixIt = BrokenMalformedRequestTests[Pair(codegenContext.serviceShape.id.toString(), it.id)] + if (howToFixIt == null) { + it + } else { + val fixed = howToFixIt(it.testCase) + TestCase.MalformedRequestTest(fixed) + } } + else -> it } - else -> it } - } private fun renderTestCaseBlock( testCase: TestCase, @@ -261,11 +282,12 @@ class ServerProtocolTestGenerator( if (expectFail(testCase)) { testModuleWriter.writeWithNoFormatting("#[should_panic]") } - val fnNameSuffix = when (testCase.testType) { - is TestType.Response -> "_response" - is TestType.Request -> "_request" - is TestType.MalformedRequest -> "_malformed_request" - } + val fnNameSuffix = + when (testCase.testType) { + is TestType.Response -> "_response" + is TestType.Request -> "_request" + is TestType.MalformedRequest -> "_malformed_request" + } testModuleWriter.rustBlock("async fn ${testCase.id.toSnakeCase()}$fnNameSuffix()") { block(this) } @@ -305,9 +327,10 @@ class ServerProtocolTestGenerator( } } - private fun expectFail(testCase: TestCase): Boolean = ExpectFail.find { - it.id == testCase.id && it.testType == testCase.testType && it.service == codegenContext.serviceShape.id.toString() - } != null + private fun expectFail(testCase: TestCase): Boolean = + ExpectFail.find { + it.id == testCase.id && it.testType == testCase.testType && it.service == codegenContext.serviceShape.id.toString() + } != null /** * Renders an HTTP response test case. @@ -325,7 +348,7 @@ class ServerProtocolTestGenerator( if (!protocolSupport.responseSerialization || ( !protocolSupport.errorSerialization && shape.hasTrait() - ) + ) ) { rust("/* test case disabled for this protocol (not yet supported) */") return @@ -430,29 +453,31 @@ class ServerProtocolTestGenerator( } /** Returns the body of the request test. */ - private fun checkRequestHandler(operationShape: OperationShape, httpRequestTestCase: HttpRequestTestCase) = - writable { - val inputShape = operationShape.inputShape(codegenContext.model) - val outputShape = operationShape.outputShape(codegenContext.model) - - // Construct expected request. - withBlock("let expected = ", ";") { - instantiator.render(this, inputShape, httpRequestTestCase.params, httpRequestTestCase.headers) - } + private fun checkRequestHandler( + operationShape: OperationShape, + httpRequestTestCase: HttpRequestTestCase, + ) = writable { + val inputShape = operationShape.inputShape(codegenContext.model) + val outputShape = operationShape.outputShape(codegenContext.model) - checkRequestParams(inputShape, this) + // Construct expected request. + withBlock("let expected = ", ";") { + instantiator.render(this, inputShape, httpRequestTestCase.params, httpRequestTestCase.headers) + } - // Construct a dummy response. - withBlock("let response = ", ";") { - instantiator.render(this, outputShape, Node.objectNode()) - } + checkRequestParams(inputShape, this) - if (operationShape.errors.isEmpty()) { - rust("response") - } else { - rust("Ok(response)") - } + // Construct a dummy response. + withBlock("let response = ", ";") { + instantiator.render(this, outputShape, Node.objectNode()) + } + + if (operationShape.errors.isEmpty()) { + rust("response") + } else { + rust("Ok(response)") } + } /** Checks the request. */ private fun makeRequest( @@ -495,7 +520,10 @@ class ServerProtocolTestGenerator( ) } - private fun checkRequestParams(inputShape: StructureShape, rustWriter: RustWriter) { + private fun checkRequestParams( + inputShape: StructureShape, + rustWriter: RustWriter, + ) { if (inputShape.hasStreamingMember(model)) { // A streaming shape does not implement `PartialEq`, so we have to iterate over the input shape's members // and handle the equality assertion separately. @@ -521,10 +549,11 @@ class ServerProtocolTestGenerator( } } } else { - val hasFloatingPointMembers = inputShape.members().any { - val target = model.expectShape(it.target) - (target is DoubleShape) || (target is FloatShape) - } + val hasFloatingPointMembers = + inputShape.members().any { + val target = model.expectShape(it.target) + (target is DoubleShape) || (target is FloatShape) + } // TODO(https://github.com/smithy-lang/smithy-rs/issues/1147) Handle the case of nested floating point members. if (hasFloatingPointMembers) { @@ -560,7 +589,10 @@ class ServerProtocolTestGenerator( } } - private fun checkResponse(rustWriter: RustWriter, testCase: HttpResponseTestCase) { + private fun checkResponse( + rustWriter: RustWriter, + testCase: HttpResponseTestCase, + ) { checkStatusCode(rustWriter, testCase.code) checkHeaders(rustWriter, "http_response.headers()", testCase.headers) checkForbidHeaders(rustWriter, "http_response.headers()", testCase.forbidHeaders) @@ -578,7 +610,10 @@ class ServerProtocolTestGenerator( } } - private fun checkResponse(rustWriter: RustWriter, testCase: HttpMalformedResponseDefinition) { + private fun checkResponse( + rustWriter: RustWriter, + testCase: HttpMalformedResponseDefinition, + ) { checkStatusCode(rustWriter, testCase.code) checkHeaders(rustWriter, "http_response.headers()", testCase.headers) @@ -610,7 +645,11 @@ class ServerProtocolTestGenerator( } } - private fun checkBody(rustWriter: RustWriter, body: String, mediaType: String?) { + private fun checkBody( + rustWriter: RustWriter, + body: String, + mediaType: String?, + ) { rustWriter.rustTemplate( """ let body = #{Hyper}::body::to_bytes(http_response.into_body()).await.expect("unable to extract body to bytes"); @@ -638,7 +677,10 @@ class ServerProtocolTestGenerator( } } - private fun checkStatusCode(rustWriter: RustWriter, statusCode: Int) { + private fun checkStatusCode( + rustWriter: RustWriter, + statusCode: Int, + ) { rustWriter.rustTemplate( """ #{AssertEq}( @@ -650,7 +692,11 @@ class ServerProtocolTestGenerator( ) } - private fun checkRequiredHeaders(rustWriter: RustWriter, actualExpression: String, requireHeaders: List) { + private fun checkRequiredHeaders( + rustWriter: RustWriter, + actualExpression: String, + requireHeaders: List, + ) { basicCheck( requireHeaders, rustWriter, @@ -660,7 +706,11 @@ class ServerProtocolTestGenerator( ) } - private fun checkForbidHeaders(rustWriter: RustWriter, actualExpression: String, forbidHeaders: List) { + private fun checkForbidHeaders( + rustWriter: RustWriter, + actualExpression: String, + forbidHeaders: List, + ) { basicCheck( forbidHeaders, rustWriter, @@ -670,7 +720,11 @@ class ServerProtocolTestGenerator( ) } - private fun checkHeaders(rustWriter: RustWriter, actualExpression: String, headers: Map) { + private fun checkHeaders( + rustWriter: RustWriter, + actualExpression: String, + headers: Map, + ) { if (headers.isEmpty()) { return } @@ -715,13 +769,19 @@ class ServerProtocolTestGenerator( * wraps `inner` in a call to `aws_smithy_protocol_test::assert_ok`, a convenience wrapper * for pretty printing protocol test helper results */ - private fun assertOk(rustWriter: RustWriter, inner: Writable) { + private fun assertOk( + rustWriter: RustWriter, + inner: Writable, + ) { rustWriter.rust("#T(", RuntimeType.protocolTest(codegenContext.runtimeConfig, "assert_ok")) inner(rustWriter) rustWriter.write(");") } - private fun strSlice(writer: RustWriter, args: List) { + private fun strSlice( + writer: RustWriter, + args: List, + ) { writer.withBlock("&[", "]") { rust(args.joinToString(",") { it.dq() }) } @@ -730,7 +790,9 @@ class ServerProtocolTestGenerator( companion object { sealed class TestType { object Request : TestType() + object Response : TestType() + object MalformedRequest : TestType() } @@ -743,112 +805,105 @@ class ServerProtocolTestGenerator( private const val AwsJson11 = "aws.protocoltests.json#JsonProtocol" private const val RestJson = "aws.protocoltests.restjson#RestJson" private const val RestJsonValidation = "aws.protocoltests.restjson.validation#RestJsonValidation" - private val ExpectFail: Set = setOf( - // Endpoint trait is not implemented yet, see https://github.com/smithy-lang/smithy-rs/issues/950. - FailingTest(RestJson, "RestJsonEndpointTrait", TestType.Request), - FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", TestType.Request), - - FailingTest(RestJson, "RestJsonOmitsEmptyListQueryValues", TestType.Request), - // Tests involving `@range` on floats. - // Pending resolution from the Smithy team, see https://github.com/smithy-lang/smithy-rs/issues/2007. - FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloat_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloat_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxFloat", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinFloat", TestType.MalformedRequest), - - // Tests involving floating point shapes and the `@range` trait; see https://github.com/smithy-lang/smithy-rs/issues/2007 - FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloatOverride_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloatOverride_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxFloatOverride", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinFloatOverride", TestType.MalformedRequest), - - // Some tests for the S3 service (restXml). - FailingTest("com.amazonaws.s3#AmazonS3", "GetBucketLocationUnwrappedOutput", TestType.Response), - FailingTest("com.amazonaws.s3#AmazonS3", "S3DefaultAddressing", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAddressing", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3PathAddressing", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAddressing", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAccelerateAddressing", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAccelerateAddressing", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationAddressingPreferred", TestType.Request), - FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationNoErrorWrappingResponse", TestType.Response), - - // AwsJson1.0 failing tests. - FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTraitWithHostLabel", TestType.Request), - FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTrait", TestType.Request), - - // AwsJson1.1 failing tests. - FailingTest(AwsJson11, "AwsJson11EndpointTraitWithHostLabel", TestType.Request), - FailingTest(AwsJson11, "AwsJson11EndpointTrait", TestType.Request), - FailingTest(AwsJson11, "parses_the_request_id_from_the_response", TestType.Response), - - // TODO(https://github.com/awslabs/smithy/issues/1683): This has been marked as failing until resolution of said issue - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsBlobList", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsBooleanList_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsBooleanList_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsStringList", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsByteList", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsShortList", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsIntegerList", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsLongList", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsTimestampList", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsDateTimeList", TestType.MalformedRequest), - FailingTest( - RestJsonValidation, - "RestJsonMalformedUniqueItemsHttpDateList_case0", - TestType.MalformedRequest, - ), - FailingTest( - RestJsonValidation, - "RestJsonMalformedUniqueItemsHttpDateList_case1", - TestType.MalformedRequest, - ), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsEnumList", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsIntEnumList", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsListList", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsStructureList", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsUnionList_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsUnionList_case1", TestType.MalformedRequest), - - // TODO(https://github.com/smithy-lang/smithy-rs/issues/2472): We don't respect the `@internal` trait - FailingTest(RestJsonValidation, "RestJsonMalformedEnumList_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumList_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapKey_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapKey_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapValue_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapValue_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumString_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumString_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumUnion_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumUnion_case1", TestType.MalformedRequest), - - // TODO(https://github.com/awslabs/smithy/issues/1737): Specs on @internal, @tags, and enum values need to be clarified - FailingTest(RestJsonValidation, "RestJsonMalformedEnumTraitString_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedEnumTraitString_case1", TestType.MalformedRequest), - ) + private val ExpectFail: Set = + setOf( + // Endpoint trait is not implemented yet, see https://github.com/smithy-lang/smithy-rs/issues/950. + FailingTest(RestJson, "RestJsonEndpointTrait", TestType.Request), + FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", TestType.Request), + FailingTest(RestJson, "RestJsonOmitsEmptyListQueryValues", TestType.Request), + // Tests involving `@range` on floats. + // Pending resolution from the Smithy team, see https://github.com/smithy-lang/smithy-rs/issues/2007. + FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloat_case0", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloat_case1", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxFloat", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinFloat", TestType.MalformedRequest), + // Tests involving floating point shapes and the `@range` trait; see https://github.com/smithy-lang/smithy-rs/issues/2007 + FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloatOverride_case0", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloatOverride_case1", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxFloatOverride", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinFloatOverride", TestType.MalformedRequest), + // Some tests for the S3 service (restXml). + FailingTest("com.amazonaws.s3#AmazonS3", "GetBucketLocationUnwrappedOutput", TestType.Response), + FailingTest("com.amazonaws.s3#AmazonS3", "S3DefaultAddressing", TestType.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAddressing", TestType.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3PathAddressing", TestType.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAddressing", TestType.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAccelerateAddressing", TestType.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAccelerateAddressing", TestType.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationAddressingPreferred", TestType.Request), + FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationNoErrorWrappingResponse", TestType.Response), + // AwsJson1.0 failing tests. + FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTraitWithHostLabel", TestType.Request), + FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTrait", TestType.Request), + // AwsJson1.1 failing tests. + FailingTest(AwsJson11, "AwsJson11EndpointTraitWithHostLabel", TestType.Request), + FailingTest(AwsJson11, "AwsJson11EndpointTrait", TestType.Request), + FailingTest(AwsJson11, "parses_the_request_id_from_the_response", TestType.Response), + // TODO(https://github.com/awslabs/smithy/issues/1683): This has been marked as failing until resolution of said issue + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsBlobList", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsBooleanList_case0", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsBooleanList_case1", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsStringList", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsByteList", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsShortList", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsIntegerList", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsLongList", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsTimestampList", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsDateTimeList", TestType.MalformedRequest), + FailingTest( + RestJsonValidation, + "RestJsonMalformedUniqueItemsHttpDateList_case0", + TestType.MalformedRequest, + ), + FailingTest( + RestJsonValidation, + "RestJsonMalformedUniqueItemsHttpDateList_case1", + TestType.MalformedRequest, + ), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsEnumList", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsIntEnumList", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsListList", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsStructureList", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsUnionList_case0", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedUniqueItemsUnionList_case1", TestType.MalformedRequest), + // TODO(https://github.com/smithy-lang/smithy-rs/issues/2472): We don't respect the `@internal` trait + FailingTest(RestJsonValidation, "RestJsonMalformedEnumList_case0", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedEnumList_case1", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapKey_case0", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapKey_case1", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapValue_case0", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedEnumMapValue_case1", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedEnumString_case0", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedEnumString_case1", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedEnumUnion_case0", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedEnumUnion_case1", TestType.MalformedRequest), + // TODO(https://github.com/awslabs/smithy/issues/1737): Specs on @internal, @tags, and enum values need to be clarified + FailingTest(RestJsonValidation, "RestJsonMalformedEnumTraitString_case0", TestType.MalformedRequest), + FailingTest(RestJsonValidation, "RestJsonMalformedEnumTraitString_case1", TestType.MalformedRequest), + ) private val RunOnly: Set? = null // These tests are not even attempted to be generated, either because they will not compile // or because they are flaky - private val DisableTests = setOf( - // TODO(https://github.com/smithy-lang/smithy-rs/issues/2891): Implement support for `@requestCompression` - "SDKAppendedGzipAfterProvidedEncoding_restJson1", - "SDKAppendedGzipAfterProvidedEncoding_restXml", - "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_0", - "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_1", - "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsQuery", - "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_ec2Query", - "SDKAppliedContentEncoding_awsJson1_0", - "SDKAppliedContentEncoding_awsJson1_1", - "SDKAppliedContentEncoding_awsQuery", - "SDKAppliedContentEncoding_ec2Query", - "SDKAppliedContentEncoding_restJson1", - "SDKAppliedContentEncoding_restXml", - - // RestXml S3 tests that fail to compile - "S3EscapeObjectKeyInUriLabel", - "S3EscapePathObjectKeyInUriLabel", - ) + private val DisableTests = + setOf( + // TODO(https://github.com/smithy-lang/smithy-rs/issues/2891): Implement support for `@requestCompression` + "SDKAppendedGzipAfterProvidedEncoding_restJson1", + "SDKAppendedGzipAfterProvidedEncoding_restXml", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_0", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_1", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsQuery", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_ec2Query", + "SDKAppliedContentEncoding_awsJson1_0", + "SDKAppliedContentEncoding_awsJson1_1", + "SDKAppliedContentEncoding_awsQuery", + "SDKAppliedContentEncoding_ec2Query", + "SDKAppliedContentEncoding_restJson1", + "SDKAppliedContentEncoding_restXml", + // RestXml S3 tests that fail to compile + "S3EscapeObjectKeyInUriLabel", + "S3EscapePathObjectKeyInUriLabel", + ) private fun fixRestJsonAllQueryStringTypes( testCase: HttpRequestTestCase, @@ -905,20 +960,23 @@ class ServerProtocolTestGenerator( ).build() // TODO(https://github.com/awslabs/smithy/issues/1506) - private fun fixRestJsonMalformedPatternReDOSString(testCase: HttpMalformedRequestTestCase): HttpMalformedRequestTestCase { + private fun fixRestJsonMalformedPatternReDOSString( + testCase: HttpMalformedRequestTestCase, + ): HttpMalformedRequestTestCase { val brokenResponse = testCase.response val brokenBody = brokenResponse.body.get() - val fixedBody = HttpMalformedResponseBodyDefinition.builder() - .mediaType(brokenBody.mediaType) - .contents( - """ - { - "message" : "1 validation error detected. Value at '/evilString' failed to satisfy constraint: Member must satisfy regular expression pattern: ^([0-9]+)+${'$'}", - "fieldList" : [{"message": "Value at '/evilString' failed to satisfy constraint: Member must satisfy regular expression pattern: ^([0-9]+)+${'$'}", "path": "/evilString"}] - } - """.trimIndent(), - ) - .build() + val fixedBody = + HttpMalformedResponseBodyDefinition.builder() + .mediaType(brokenBody.mediaType) + .contents( + """ + { + "message" : "1 validation error detected. Value at '/evilString' failed to satisfy constraint: Member must satisfy regular expression pattern: ^([0-9]+)+${'$'}", + "fieldList" : [{"message": "Value at '/evilString' failed to satisfy constraint: Member must satisfy regular expression pattern: ^([0-9]+)+${'$'}", "path": "/evilString"}] + } + """.trimIndent(), + ) + .build() return testCase.toBuilder() .response(brokenResponse.toBuilder().body(fixedBody).build()) @@ -930,7 +988,8 @@ class ServerProtocolTestGenerator( // advantage that once our upstream PRs get merged and we upgrade to the next Smithy release, our build will // fail and we will take notice to remove the fixes from `rest-json-extras.smithy`. This is exactly what the // client does. - private val BrokenMalformedRequestTests: Map, KFunction1> = + private val BrokenMalformedRequestTests: + Map, KFunction1> = // TODO(https://github.com/awslabs/smithy/issues/1506) mapOf( Pair( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt index 549ca88b1ee..f967e9450d3 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt @@ -51,12 +51,12 @@ class ServerAwsJsonFactory( override fun support(): ProtocolSupport { return ProtocolSupport( - /* Client support */ + // Client support requestSerialization = false, requestBodySerialization = false, responseDeserialization = false, errorDeserialization = false, - /* Server support */ + // Server support requestDeserialization = true, requestBodyDeserialization = true, responseSerialization = true, @@ -78,23 +78,26 @@ class ServerAwsJsonFactory( * > field named __type */ class ServerAwsJsonError(private val awsJsonVersion: AwsJsonVersion) : JsonSerializerCustomization() { - override fun section(section: JsonSerializerSection): Writable = when (section) { - is JsonSerializerSection.ServerError -> writable { - if (section.structureShape.hasTrait()) { - val typeId = when (awsJsonVersion) { - // AwsJson 1.0 wants the whole shape ID (namespace#Shape). - // https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#operation-error-serialization - AwsJsonVersion.Json10 -> section.structureShape.id.toString() - // AwsJson 1.1 wants only the shape name (Shape). - // https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#operation-error-serialization - AwsJsonVersion.Json11 -> section.structureShape.id.name.toString() + override fun section(section: JsonSerializerSection): Writable = + when (section) { + is JsonSerializerSection.ServerError -> + writable { + if (section.structureShape.hasTrait()) { + val typeId = + when (awsJsonVersion) { + // AwsJson 1.0 wants the whole shape ID (namespace#Shape). + // https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#operation-error-serialization + AwsJsonVersion.Json10 -> section.structureShape.id.toString() + // AwsJson 1.1 wants only the shape name (Shape). + // https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#operation-error-serialization + AwsJsonVersion.Json11 -> section.structureShape.id.name.toString() + } + rust("""${section.jsonObject}.key("__type").string("${escape(typeId)}");""") + } } - rust("""${section.jsonObject}.key("__type").string("${escape(typeId)}");""") - } - } - else -> emptySection - } + else -> emptySection + } } /** @@ -112,10 +115,11 @@ class ServerAwsJsonSerializerGenerator( codegenContext, httpBindingResolver, ::awsJsonFieldName, - customizations = listOf( - ServerAwsJsonError(awsJsonVersion), - BeforeIteratingOverMapOrCollectionJsonCustomization(codegenContext), - BeforeSerializingMemberJsonCustomization(codegenContext), - ), + customizations = + listOf( + ServerAwsJsonError(awsJsonVersion), + BeforeIteratingOverMapOrCollectionJsonCustomization(codegenContext), + BeforeSerializingMemberJsonCustomization(codegenContext), + ), ), ) : StructuredDataSerializerGenerator by jsonSerializerGenerator diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 49ddcb03349..071b1095047 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -122,9 +122,9 @@ class ServerHttpBoundProtocolGenerator( customizations: List = listOf(), additionalHttpBindingCustomizations: List = listOf(), ) : ServerProtocolGenerator( - protocol, - ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations), -) { + protocol, + ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations), + ) { // Define suffixes for operation input / output / error wrappers companion object { const val OPERATION_INPUT_WRAPPER_SUFFIX = "OperationInputWrapper" @@ -136,26 +136,26 @@ class ServerHttpBoundProtocolPayloadGenerator( codegenContext: CodegenContext, protocol: Protocol, ) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator( - codegenContext, protocol, HttpMessageType.RESPONSE, - renderEventStreamBody = { writer, params -> - writer.rustTemplate( - """ - { - let error_marshaller = #{errorMarshallerConstructorFn}(); - let marshaller = #{marshallerConstructorFn}(); - let signer = #{NoOpSigner}{}; - let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> = - ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer); - adapter - } - """, - "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), - "NoOpSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::NoOpSigner"), - "marshallerConstructorFn" to params.marshallerConstructorFn, - "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, - ) - }, -) + codegenContext, protocol, HttpMessageType.RESPONSE, + renderEventStreamBody = { writer, params -> + writer.rustTemplate( + """ + { + let error_marshaller = #{errorMarshallerConstructorFn}(); + let marshaller = #{marshallerConstructorFn}(); + let signer = #{NoOpSigner}{}; + let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> = + ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer); + adapter + } + """, + "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), + "NoOpSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::NoOpSigner"), + "marshallerConstructorFn" to params.marshallerConstructorFn, + "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, + ) + }, + ) /* * Generate all operation input parsers and output serializers for streaming and @@ -175,32 +175,36 @@ class ServerHttpBoundProtocolTraitImplGenerator( private val httpBindingResolver = protocol.httpBindingResolver private val protocolFunctions = ProtocolFunctions(codegenContext) - private val codegenScope = arrayOf( - "AsyncTrait" to ServerCargoDependency.AsyncTrait.toType(), - "Cow" to RuntimeType.Cow, - "DateTime" to RuntimeType.dateTime(runtimeConfig), - "FormUrlEncoded" to ServerCargoDependency.FormUrlEncoded.toType(), - "FuturesUtil" to ServerCargoDependency.FuturesUtil.toType(), - "HttpBody" to RuntimeType.HttpBody, - "header_util" to RuntimeType.smithyHttp(runtimeConfig).resolve("header"), - "Hyper" to RuntimeType.Hyper, - "LazyStatic" to RuntimeType.LazyStatic, - "Mime" to ServerCargoDependency.Mime.toType(), - "Nom" to ServerCargoDependency.Nom.toType(), - "OnceCell" to RuntimeType.OnceCell, - "PercentEncoding" to RuntimeType.PercentEncoding, - "Regex" to RuntimeType.Regex, - "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), - "SmithyTypes" to RuntimeType.smithyTypes(runtimeConfig), - "RuntimeError" to protocol.runtimeError(runtimeConfig), - "RequestRejection" to protocol.requestRejection(runtimeConfig), - "ResponseRejection" to protocol.responseRejection(runtimeConfig), - "PinProjectLite" to ServerCargoDependency.PinProjectLite.toType(), - "http" to RuntimeType.Http, - "Tracing" to RuntimeType.Tracing, - ) + private val codegenScope = + arrayOf( + "AsyncTrait" to ServerCargoDependency.AsyncTrait.toType(), + "Cow" to RuntimeType.Cow, + "DateTime" to RuntimeType.dateTime(runtimeConfig), + "FormUrlEncoded" to ServerCargoDependency.FormUrlEncoded.toType(), + "FuturesUtil" to ServerCargoDependency.FuturesUtil.toType(), + "HttpBody" to RuntimeType.HttpBody, + "header_util" to RuntimeType.smithyHttp(runtimeConfig).resolve("header"), + "Hyper" to RuntimeType.Hyper, + "LazyStatic" to RuntimeType.LazyStatic, + "Mime" to ServerCargoDependency.Mime.toType(), + "Nom" to ServerCargoDependency.Nom.toType(), + "OnceCell" to RuntimeType.OnceCell, + "PercentEncoding" to RuntimeType.PercentEncoding, + "Regex" to RuntimeType.Regex, + "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(runtimeConfig).toType(), + "SmithyTypes" to RuntimeType.smithyTypes(runtimeConfig), + "RuntimeError" to protocol.runtimeError(runtimeConfig), + "RequestRejection" to protocol.requestRejection(runtimeConfig), + "ResponseRejection" to protocol.responseRejection(runtimeConfig), + "PinProjectLite" to ServerCargoDependency.PinProjectLite.toType(), + "http" to RuntimeType.Http, + "Tracing" to RuntimeType.Tracing, + ) - fun generateTraitImpls(operationWriter: RustWriter, operationShape: OperationShape) { + fun generateTraitImpls( + operationWriter: RustWriter, + operationShape: OperationShape, + ) { val inputSymbol = symbolProvider.toSymbol(operationShape.inputShape(model)) val outputSymbol = symbolProvider.toSymbol(operationShape.outputShape(model)) @@ -223,62 +227,66 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) { val operationName = symbolProvider.toSymbol(operationShape).name val staticContentType = "CONTENT_TYPE_${operationName.uppercase()}" - val verifyAcceptHeader = writable { - httpBindingResolver.responseContentType(operationShape)?.also { contentType -> - rustTemplate( - """ - if !#{SmithyHttpServer}::protocol::accept_header_classifier(request.headers(), &$staticContentType) { - return Err(#{RequestRejection}::NotAcceptable); - } - """, - *codegenScope, - ) - } - } - val verifyAcceptHeaderStaticContentTypeInit = writable { - httpBindingResolver.responseContentType(operationShape)?.also { contentType -> - val init = when (contentType) { - "application/json" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_JSON;" - "application/octet-stream" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_OCTET_STREAM;" - "application/x-www-form-urlencoded" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_WWW_FORM_URLENCODED;" - else -> - """ - static $staticContentType: #{OnceCell}::sync::Lazy<#{Mime}::Mime> = #{OnceCell}::sync::Lazy::new(|| { - ${contentType.dq()}.parse::<#{Mime}::Mime>().expect("BUG: MIME parsing failed, content_type is not valid") - }); + val verifyAcceptHeader = + writable { + httpBindingResolver.responseContentType(operationShape)?.also { contentType -> + rustTemplate( """ + if !#{SmithyHttpServer}::protocol::accept_header_classifier(request.headers(), &$staticContentType) { + return Err(#{RequestRejection}::NotAcceptable); + } + """, + *codegenScope, + ) + } + } + val verifyAcceptHeaderStaticContentTypeInit = + writable { + httpBindingResolver.responseContentType(operationShape)?.also { contentType -> + val init = + when (contentType) { + "application/json" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_JSON;" + "application/octet-stream" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_OCTET_STREAM;" + "application/x-www-form-urlencoded" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_WWW_FORM_URLENCODED;" + else -> + """ + static $staticContentType: #{OnceCell}::sync::Lazy<#{Mime}::Mime> = #{OnceCell}::sync::Lazy::new(|| { + ${contentType.dq()}.parse::<#{Mime}::Mime>().expect("BUG: MIME parsing failed, content_type is not valid") + }); + """ + } + rustTemplate(init, *codegenScope) } - rustTemplate(init, *codegenScope) } - } // This checks for the expected `Content-Type` header if the `@httpPayload` trait is present, as dictated by // the core Smithy library, which _does not_ require deserializing the payload. // If no members have `@httpPayload`, the expected `Content-Type` header as dictated _by the protocol_ is // checked later on for non-streaming operations, in `serverRenderShapeParser`: that check _does_ require at // least buffering the entire payload, since the check must only be performed if the payload is empty. - val verifyRequestContentTypeHeader = writable { - operationShape - .inputShape(model) - .members() - .find { it.hasTrait() } - ?.let { payload -> - val target = model.expectShape(payload.target) - if (!target.isBlobShape || target.hasTrait()) { - // `null` is only returned by Smithy when there are no members, but we know there's at least - // the one with `@httpPayload`, so `!!` is safe here. - val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)!! - rustTemplate( - """ - #{SmithyHttpServer}::protocol::content_type_header_classifier_http( - request.headers(), - Some("$expectedRequestContentType"), - )?; - """, - *codegenScope, - ) + val verifyRequestContentTypeHeader = + writable { + operationShape + .inputShape(model) + .members() + .find { it.hasTrait() } + ?.let { payload -> + val target = model.expectShape(payload.target) + if (!target.isBlobShape || target.hasTrait()) { + // `null` is only returned by Smithy when there are no members, but we know there's at least + // the one with `@httpPayload`, so `!!` is safe here. + val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)!! + rustTemplate( + """ + #{SmithyHttpServer}::protocol::content_type_header_classifier_http( + request.headers(), + Some("$expectedRequestContentType"), + )?; + """, + *codegenScope, + ) + } } - } - } + } // Implement `from_request` trait for input types. val inputFuture = "${inputSymbol.name}Future" @@ -542,10 +550,11 @@ class ServerHttpBoundProtocolTraitImplGenerator( rustTemplate("let mut builder = #{http}::Response::builder();", *codegenScope) serverRenderResponseHeaders(operationShape) // Fallback to the default code of `@http`, which should be 200. - val httpTraitDefaultStatusCode = HttpTrait - .builder().method("GET").uri(UriPattern.parse("/")) /* Required to build */ - .build() - .code + val httpTraitDefaultStatusCode = + HttpTrait + .builder().method("GET").uri(UriPattern.parse("/")) // Required to build + .build() + .code check(httpTraitDefaultStatusCode == 200) val httpTraitStatusCode = operationShape.getTrait()?.code ?: httpTraitDefaultStatusCode bindings.find { it.location == HttpLocation.RESPONSE_CODE } @@ -605,7 +614,10 @@ class ServerHttpBoundProtocolTraitImplGenerator( * 2. The protocol-specific `Content-Type` header for the operation. * 3. Additional protocol-specific headers for errors, if [errorShape] is non-null. */ - private fun RustWriter.serverRenderResponseHeaders(operationShape: OperationShape, errorShape: StructureShape? = null) { + private fun RustWriter.serverRenderResponseHeaders( + operationShape: OperationShape, + errorShape: StructureShape? = null, + ) { val bindingGenerator = ServerResponseBindingGenerator(protocol, codegenContext, operationShape) val addHeadersFn = bindingGenerator.generateAddHeadersFn(errorShape ?: operationShape) if (addHeadersFn != null) { @@ -669,21 +681,22 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } - private fun serverRenderHttpResponseCode(defaultCode: Int) = writable { - check(defaultCode in 100..999) { - """ - Smithy library lied to us. According to https://smithy.io/2.0/spec/http-bindings.html#http-trait, - "The provided value SHOULD be between 100 and 599, and it MUST be between 100 and 999". - """.replace("\n", "").trimIndent() + private fun serverRenderHttpResponseCode(defaultCode: Int) = + writable { + check(defaultCode in 100..999) { + """ + Smithy library lied to us. According to https://smithy.io/2.0/spec/http-bindings.html#http-trait, + "The provided value SHOULD be between 100 and 599, and it MUST be between 100 and 999". + """.replace("\n", "").trimIndent() + } + rustTemplate( + """ + let http_status: u16 = $defaultCode; + builder = builder.status(http_status); + """, + *codegenScope, + ) } - rustTemplate( - """ - let http_status: u16 = $defaultCode; - builder = builder.status(http_status); - """, - *codegenScope, - ) - } private fun serverRenderResponseCodeBinding( binding: HttpBindingDescriptor, @@ -715,7 +728,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( inputShape: StructureShape, bindings: List, ) { - val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) + val httpBindingGenerator = + ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) val structuredDataParser = protocol.structuredDataParser() Attribute.AllowUnusedMut.render(this) rust( @@ -756,7 +770,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( } for (binding in bindings) { val member = binding.member - val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) + val parsedValue = + serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) if (parsedValue != null) { rust("if let Some(value) = ") parsedValue(this) @@ -788,17 +803,18 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } } - val err = if (ServerBuilderGenerator.hasFallibleBuilder( - inputShape, - model, - symbolProvider, - takeInUnconstrainedTypes = true, - ) - ) { - "?" - } else { - "" - } + val err = + if (ServerBuilderGenerator.hasFallibleBuilder( + inputShape, + model, + symbolProvider, + takeInUnconstrainedTypes = true, + ) + ) { + "?" + } else { + "" + } rustTemplate("input.build()$err", *codegenScope) } @@ -816,11 +832,12 @@ class ServerHttpBoundProtocolTraitImplGenerator( rust("#T($body)", structuredDataParser.payloadParser(binding.member)) } val errorSymbol = getDeserializePayloadErrorSymbol(binding) - val deserializer = httpBindingGenerator.generateDeserializePayloadFn( - binding, - errorSymbol, - structuredHandler = structureShapeHandler, - ) + val deserializer = + httpBindingGenerator.generateDeserializePayloadFn( + binding, + errorSymbol, + structuredHandler = structureShapeHandler, + ) return writable { if (binding.member.isStreaming(model)) { rustTemplate( @@ -857,7 +874,10 @@ class ServerHttpBoundProtocolTraitImplGenerator( } } - private fun serverRenderUriPathParser(writer: RustWriter, operationShape: OperationShape) { + private fun serverRenderUriPathParser( + writer: RustWriter, + operationShape: OperationShape, + ) { val pathBindings = httpBindingResolver.requestBindings(operationShape).filter { it.location == HttpLocation.LABEL @@ -879,31 +899,37 @@ class ServerHttpBoundProtocolTraitImplGenerator( } else { "" } - val labeledNames = segments - .mapIndexed { index, segment -> - if (segment.isLabel) { "m$index" } else { "_" } - } - .joinToString(prefix = (if (segments.size > 1) "(" else ""), separator = ",", postfix = (if (segments.size > 1) ")" else "")) - val nomParser = segments - .map { segment -> - if (segment.isGreedyLabel) { - "#{Nom}::combinator::rest::<_, #{Nom}::error::Error<&str>>" - } else if (segment.isLabel) { - """#{Nom}::branch::alt::<_, _, #{Nom}::error::Error<&str>, _>((#{Nom}::bytes::complete::take_until("/"), #{Nom}::combinator::rest))""" - } else { - """#{Nom}::bytes::complete::tag::<_, _, #{Nom}::error::Error<&str>>("${segment.content}")""" + val labeledNames = + segments + .mapIndexed { index, segment -> + if (segment.isLabel) { + "m$index" + } else { + "_" + } } - } - .joinToString( - // TODO(https://github.com/smithy-lang/smithy-rs/issues/1289): Note we're limited to 21 labels because of `tuple`. - prefix = if (segments.size > 1) "#{Nom}::sequence::tuple::<_, _, #{Nom}::error::Error<&str>, _>((" else "", - postfix = if (segments.size > 1) "))" else "", - transform = { parser -> - """ - #{Nom}::sequence::preceded(#{Nom}::bytes::complete::tag("/"), $parser) - """.trimIndent() - }, - ) + .joinToString(prefix = (if (segments.size > 1) "(" else ""), separator = ",", postfix = (if (segments.size > 1) ")" else "")) + val nomParser = + segments + .map { segment -> + if (segment.isGreedyLabel) { + "#{Nom}::combinator::rest::<_, #{Nom}::error::Error<&str>>" + } else if (segment.isLabel) { + """#{Nom}::branch::alt::<_, _, #{Nom}::error::Error<&str>, _>((#{Nom}::bytes::complete::take_until("/"), #{Nom}::combinator::rest))""" + } else { + """#{Nom}::bytes::complete::tag::<_, _, #{Nom}::error::Error<&str>>("${segment.content}")""" + } + } + .joinToString( + // TODO(https://github.com/smithy-lang/smithy-rs/issues/1289): Note we're limited to 21 labels because of `tuple`. + prefix = if (segments.size > 1) "#{Nom}::sequence::tuple::<_, _, #{Nom}::error::Error<&str>, _>((" else "", + postfix = if (segments.size > 1) "))" else "", + transform = { parser -> + """ + #{Nom}::sequence::preceded(#{Nom}::bytes::complete::tag("/"), $parser) + """.trimIndent() + }, + ) with(writer) { rustTemplate("let input_string = uri.path();") if (greedyLabelIndex >= 0 && greedyLabelIndex + 1 < httpTrait.uri.segments.size) { @@ -948,7 +974,9 @@ class ServerHttpBoundProtocolTraitImplGenerator( // * a map of list of string; or // * a map of set of string. enum class QueryParamsTargetMapValueType { - STRING, LIST, SET + STRING, + LIST, + SET, } private fun queryParamsTargetMapValueType(targetMapValue: Shape): QueryParamsTargetMapValueType = @@ -967,7 +995,10 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } - private fun serverRenderQueryStringParser(writer: RustWriter, operationShape: OperationShape) { + private fun serverRenderQueryStringParser( + writer: RustWriter, + operationShape: OperationShape, + ) { val queryBindings = httpBindingResolver.requestBindings(operationShape).filter { it.location == HttpLocation.QUERY @@ -1156,8 +1187,13 @@ class ServerHttpBoundProtocolTraitImplGenerator( } } - private fun serverRenderHeaderParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) { - val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) + private fun serverRenderHeaderParser( + writer: RustWriter, + binding: HttpBindingDescriptor, + operationShape: OperationShape, + ) { + val httpBindingGenerator = + ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding) writer.rustTemplate( """ @@ -1168,7 +1204,11 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } - private fun serverRenderPrefixHeadersParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) { + private fun serverRenderPrefixHeadersParser( + writer: RustWriter, + binding: HttpBindingDescriptor, + operationShape: OperationShape, + ) { check(binding.location == HttpLocation.PREFIX_HEADERS) val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) @@ -1182,7 +1222,10 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } - private fun generateParseStrFn(binding: HttpBindingDescriptor, percentDecoding: Boolean): RuntimeType { + private fun generateParseStrFn( + binding: HttpBindingDescriptor, + percentDecoding: Boolean, + ): RuntimeType { val output = unconstrainedShapeSymbolProvider.toSymbol(binding.member) return protocolFunctions.deserializeFn(binding.member) { fnName -> rustBlockTemplate( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index a72ad201c0d..ae87ec5723b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -21,56 +21,64 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator class StreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() { - override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { - is ServerHttpBoundProtocolSection.WrapStreamPayload -> writable { - if (section.params.shape.isOutputEventStream(section.params.codegenContext.model)) { - // Event stream payload, of type `aws_smithy_http::event_stream::MessageStreamAdapter`, already - // implements the `Stream` trait, so no need to wrap it in the new-type. - section.params.payloadGenerator.generatePayload(this, section.params.shapeName, section.params.shape) - } else { - // Otherwise, the stream payload is `aws_smithy_types::byte_stream::ByteStream`. We wrap it in the - // new-type to enable the `Stream` trait. - withBlockTemplate( - "#{FuturesStreamCompatByteStream}::new(", - ")", - "FuturesStreamCompatByteStream" to RuntimeType.futuresStreamCompatByteStream(section.params.codegenContext.runtimeConfig), - ) { - section.params.payloadGenerator.generatePayload( - this, - section.params.shapeName, - section.params.shape, - ) + override fun section(section: ServerHttpBoundProtocolSection): Writable = + when (section) { + is ServerHttpBoundProtocolSection.WrapStreamPayload -> + writable { + if (section.params.shape.isOutputEventStream(section.params.codegenContext.model)) { + // Event stream payload, of type `aws_smithy_http::event_stream::MessageStreamAdapter`, already + // implements the `Stream` trait, so no need to wrap it in the new-type. + section.params.payloadGenerator.generatePayload(this, section.params.shapeName, section.params.shape) + } else { + // Otherwise, the stream payload is `aws_smithy_types::byte_stream::ByteStream`. We wrap it in the + // new-type to enable the `Stream` trait. + withBlockTemplate( + "#{FuturesStreamCompatByteStream}::new(", + ")", + "FuturesStreamCompatByteStream" to RuntimeType.futuresStreamCompatByteStream(section.params.codegenContext.runtimeConfig), + ) { + section.params.payloadGenerator.generatePayload( + this, + section.params.shapeName, + section.params.shape, + ) + } + } } - } - } - else -> emptySection - } + else -> emptySection + } } class ServerProtocolLoader(supportedProtocols: ProtocolMap) : ProtocolLoader(supportedProtocols) { - companion object { - val DefaultProtocols = mapOf( - RestJson1Trait.ID to ServerRestJsonFactory( - additionalServerHttpBoundProtocolCustomizations = listOf( - StreamPayloadSerializerCustomization(), - ), - ), - RestXmlTrait.ID to ServerRestXmlFactory( - additionalServerHttpBoundProtocolCustomizations = listOf( - StreamPayloadSerializerCustomization(), - ), - ), - AwsJson1_0Trait.ID to ServerAwsJsonFactory( - AwsJsonVersion.Json10, - additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), - ), - AwsJson1_1Trait.ID to ServerAwsJsonFactory( - AwsJsonVersion.Json11, - additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), - ), - ) + val DefaultProtocols = + mapOf( + RestJson1Trait.ID to + ServerRestJsonFactory( + additionalServerHttpBoundProtocolCustomizations = + listOf( + StreamPayloadSerializerCustomization(), + ), + ), + RestXmlTrait.ID to + ServerRestXmlFactory( + additionalServerHttpBoundProtocolCustomizations = + listOf( + StreamPayloadSerializerCustomization(), + ), + ), + AwsJson1_0Trait.ID to + ServerAwsJsonFactory( + AwsJsonVersion.Json10, + additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), + ), + AwsJson1_1Trait.ID to + ServerAwsJsonFactory( + AwsJsonVersion.Json11, + additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), + ), + ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt index ddf1ca08c33..5810081df52 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJson.kt @@ -28,7 +28,8 @@ class ServerRestJsonFactory( private val additionalServerHttpBoundProtocolCustomizations: List = listOf(), private val additionalHttpBindingCustomizations: List = listOf(), ) : ProtocolGeneratorFactory { - override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerRestJsonProtocol(codegenContext, additionalParserCustomizations) + override fun protocol(codegenContext: ServerCodegenContext): Protocol = + ServerRestJsonProtocol(codegenContext, additionalParserCustomizations) override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = ServerHttpBoundProtocolGenerator( @@ -43,12 +44,12 @@ class ServerRestJsonFactory( override fun support(): ProtocolSupport { return ProtocolSupport( - /* Client support */ + // Client support requestSerialization = false, requestBodySerialization = false, responseDeserialization = false, errorDeserialization = false, - /* Server support */ + // Server support requestDeserialization = true, requestBodyDeserialization = true, responseSerialization = true, @@ -65,9 +66,10 @@ class ServerRestJsonSerializerGenerator( codegenContext, httpBindingResolver, ::restJsonFieldName, - customizations = listOf( - BeforeIteratingOverMapOrCollectionJsonCustomization(codegenContext), - BeforeSerializingMemberJsonCustomization(codegenContext), - ), + customizations = + listOf( + BeforeIteratingOverMapOrCollectionJsonCustomization(codegenContext), + BeforeSerializingMemberJsonCustomization(codegenContext), + ), ), ) : StructuredDataSerializerGenerator by jsonSerializerGenerator diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXmlFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXmlFactory.kt index 9207c56046e..7ef7566a6f2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXmlFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXmlFactory.kt @@ -29,12 +29,12 @@ class ServerRestXmlFactory( override fun support(): ProtocolSupport { return ProtocolSupport( - /* Client support */ + // Client support requestSerialization = false, requestBodySerialization = false, responseDeserialization = false, errorDeserialization = false, - /* Server support */ + // Server support requestDeserialization = true, requestBodyDeserialization = true, responseSerialization = true, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt index fc83f1392b0..8c0254904e8 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt @@ -27,16 +27,20 @@ fun serverIntegrationTest( test: (ServerCodegenContext, RustCrate) -> Unit = { _, _ -> }, ): Path { fun invokeRustCodegenPlugin(ctx: PluginContext) { - val codegenDecorator = object : ServerCodegenDecorator { - override val name: String = "Add tests" - override val order: Byte = 0 - - override fun classpathDiscoverable(): Boolean = false - - override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { - test(codegenContext, rustCrate) + val codegenDecorator = + object : ServerCodegenDecorator { + override val name: String = "Add tests" + override val order: Byte = 0 + + override fun classpathDiscoverable(): Boolean = false + + override fun extras( + codegenContext: ServerCodegenContext, + rustCrate: RustCrate, + ) { + test(codegenContext, rustCrate) + } } - } RustServerCodegenPlugin().executeWithDecorator(ctx, codegenDecorator, *additionalDecorators.toTypedArray()) } return codegenIntegrationTest(model, params, invokePlugin = ::invokeRustCodegenPlugin) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt index f161282242a..03643d80a80 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt @@ -35,38 +35,40 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader // These are the settings we default to if the user does not override them in their `smithy-build.json`. -val ServerTestRustSymbolProviderConfig = RustSymbolProviderConfig( - runtimeConfig = TestRuntimeConfig, - renameExceptions = false, - nullabilityCheckMode = NullableIndex.CheckMode.SERVER, - moduleProvider = ServerModuleProvider, -) +val ServerTestRustSymbolProviderConfig = + RustSymbolProviderConfig( + runtimeConfig = TestRuntimeConfig, + renameExceptions = false, + nullabilityCheckMode = NullableIndex.CheckMode.SERVER, + moduleProvider = ServerModuleProvider, + ) private fun testServiceShapeFor(model: Model) = model.serviceShapes.firstOrNull() ?: ServiceShape.builder().version("test").id("test#Service").build() -fun serverTestSymbolProvider(model: Model, serviceShape: ServiceShape? = null) = - serverTestSymbolProviders(model, serviceShape).symbolProvider +fun serverTestSymbolProvider( + model: Model, + serviceShape: ServiceShape? = null, +) = serverTestSymbolProviders(model, serviceShape).symbolProvider fun serverTestSymbolProviders( model: Model, serviceShape: ServiceShape? = null, settings: ServerRustSettings? = null, decorators: List = emptyList(), -) = - ServerSymbolProviders.from( - serverTestRustSettings(), - model, - serviceShape ?: testServiceShapeFor(model), - ServerTestRustSymbolProviderConfig, - ( - settings ?: serverTestRustSettings( - (serviceShape ?: testServiceShapeFor(model)).id, - ) - ).codegenConfig.publicConstrainedTypes, - CombinedServerCodegenDecorator(decorators), - RustServerCodegenPlugin::baseSymbolProvider, - ) +) = ServerSymbolProviders.from( + serverTestRustSettings(), + model, + serviceShape ?: testServiceShapeFor(model), + ServerTestRustSymbolProviderConfig, + ( + settings ?: serverTestRustSettings( + (serviceShape ?: testServiceShapeFor(model)).id, + ) + ).codegenConfig.publicConstrainedTypes, + CombinedServerCodegenDecorator(decorators), + RustServerCodegenPlugin::baseSymbolProvider, +) fun serverTestRustSettings( service: ShapeId = ShapeId.from("notrelevant#notrelevant"), @@ -103,15 +105,16 @@ fun serverTestCodegenContext( ): ServerCodegenContext { val service = serviceShape ?: testServiceShapeFor(model) val protocol = protocolShapeId ?: ShapeId.from("test#Protocol") - val serverSymbolProviders = ServerSymbolProviders.from( - settings, - model, - service, - ServerTestRustSymbolProviderConfig, - settings.codegenConfig.publicConstrainedTypes, - CombinedServerCodegenDecorator(decorators), - RustServerCodegenPlugin::baseSymbolProvider, - ) + val serverSymbolProviders = + ServerSymbolProviders.from( + settings, + model, + service, + ServerTestRustSymbolProviderConfig, + settings.codegenConfig.publicConstrainedTypes, + CombinedServerCodegenDecorator(decorators), + RustServerCodegenPlugin::baseSymbolProvider, + ) return ServerCodegenContext( model, @@ -148,12 +151,13 @@ fun StructureShape.serverRenderWithModelBuilder( val serverCodegenContext = serverTestCodegenContext(model) // Note that this always uses `ServerBuilderGenerator` and _not_ `ServerBuilderGeneratorWithoutPublicConstrainedTypes`, // regardless of the `publicConstrainedTypes` setting. - val modelBuilder = ServerBuilderGenerator( - serverCodegenContext, - this, - SmithyValidationExceptionConversionGenerator(serverCodegenContext), - protocol ?: loadServerProtocol(model), - ) + val modelBuilder = + ServerBuilderGenerator( + serverCodegenContext, + this, + SmithyValidationExceptionConversionGenerator(serverCodegenContext), + protocol ?: loadServerProtocol(model), + ) modelBuilder.render(rustCrate, writer) writer.implBlock(symbolProvider.toSymbol(this)) { modelBuilder.renderConvenienceMethod(this) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ConstraintViolationRustBoxTrait.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ConstraintViolationRustBoxTrait.kt index 9aee2b884ef..73574328897 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ConstraintViolationRustBoxTrait.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ConstraintViolationRustBoxTrait.kt @@ -19,6 +19,7 @@ import software.amazon.smithy.model.traits.Trait */ class ConstraintViolationRustBoxTrait : Trait { val ID = ShapeId.from("software.amazon.smithy.rust.codegen.smithy.rust.synthetic#constraintViolationBox") + override fun toNode(): Node = Node.objectNode() override fun toShapeId(): ShapeId = ID diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ShapeReachableFromOperationInputTagTrait.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ShapeReachableFromOperationInputTagTrait.kt index d3684687fe3..b44dc44b840 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ShapeReachableFromOperationInputTagTrait.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ShapeReachableFromOperationInputTagTrait.kt @@ -35,19 +35,27 @@ class ShapeReachableFromOperationInputTagTrait : AnnotationTrait(ID, Node.object } } -private fun isShapeReachableFromOperationInput(shape: Shape) = when (shape) { - is StructureShape, is UnionShape, is MapShape, is ListShape, is StringShape, is IntegerShape, is ShortShape, is LongShape, is ByteShape, is BlobShape -> { - shape.hasTrait() - } +private fun isShapeReachableFromOperationInput(shape: Shape) = + when (shape) { + is StructureShape, is UnionShape, is MapShape, is ListShape, is StringShape, is IntegerShape, is ShortShape, is LongShape, is ByteShape, is BlobShape -> { + shape.hasTrait() + } - else -> PANIC("this method does not support shape type ${shape.type}") -} + else -> PANIC("this method does not support shape type ${shape.type}") + } fun StringShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) + fun StructureShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) + fun CollectionShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) + fun UnionShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) + fun MapShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) + fun IntegerShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) + fun NumberShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) + fun BlobShape.isReachableFromOperationInput() = isShapeReachableFromOperationInput(this) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt index 93bc55df128..18baaadf221 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt @@ -40,7 +40,6 @@ object AttachValidationExceptionToConstrainedOperationInputsInAllowList { ShapeId.from("aws.protocoltests.json#JsonProtocol"), ShapeId.from("com.amazonaws.s3#AmazonS3"), ShapeId.from("com.amazonaws.ebs#Ebs"), - // These are only loaded in the classpath and need this model transformer, but we don't generate server // SDKs for them. Here they are for reference. // ShapeId.from("aws.protocoltests.restxml#RestXml"), @@ -52,16 +51,17 @@ object AttachValidationExceptionToConstrainedOperationInputsInAllowList { fun transform(model: Model): Model { val walker = DirectedWalker(model) - val operationsWithConstrainedInputWithoutValidationException = model.serviceShapes - .filter { sherviceShapeIdAllowList.contains(it.toShapeId()) } - .flatMap { it.operations } - .map { model.expectShape(it, OperationShape::class.java) } - .filter { operationShape -> - // Walk the shapes reachable via this operation input. - walker.walkShapes(operationShape.inputShape(model)) - .any { it is SetShape || it is EnumShape || it.hasConstraintTrait() } - } - .filter { !it.errors.contains(SmithyValidationExceptionConversionGenerator.SHAPE_ID) } + val operationsWithConstrainedInputWithoutValidationException = + model.serviceShapes + .filter { sherviceShapeIdAllowList.contains(it.toShapeId()) } + .flatMap { it.operations } + .map { model.expectShape(it, OperationShape::class.java) } + .filter { operationShape -> + // Walk the shapes reachable via this operation input. + walker.walkShapes(operationShape.inputShape(model)) + .any { it is SetShape || it is EnumShape || it.hasConstraintTrait() } + } + .filter { !it.errors.contains(SmithyValidationExceptionConversionGenerator.SHAPE_ID) } return ModelTransformer.create().mapShapes(model) { shape -> if (shape is OperationShape && operationsWithConstrainedInputWithoutValidationException.contains(shape)) { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ConstrainedMemberTransform.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ConstrainedMemberTransform.kt index 3ab99ce64db..c3540b273a0 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ConstrainedMemberTransform.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ConstrainedMemberTransform.kt @@ -26,7 +26,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.allConstraintTraits import software.amazon.smithy.rust.codegen.server.smithy.traits.SyntheticStructureFromConstrainedMemberTrait import software.amazon.smithy.utils.ToSmithyBuilder import java.lang.IllegalStateException -import java.util.* +import java.util.Locale /** * Transforms all member shapes that have constraints on them into equivalent non-constrained @@ -69,8 +69,7 @@ object ConstrainedMemberTransform { private val memberConstraintTraitsToOverride = allConstraintTraits - RequiredTrait::class.java - private fun Shape.hasMemberConstraintTrait() = - memberConstraintTraitsToOverride.any(this::hasTrait) + private fun Shape.hasMemberConstraintTrait() = memberConstraintTraitsToOverride.any(this::hasTrait) fun transform(model: Model): Model { val additionalNames = HashSet() @@ -81,20 +80,21 @@ object ConstrainedMemberTransform { // convert them into non-constrained members and then pass them to the transformer. // The transformer will add new shapes, and will replace existing member shapes' target // with the newly added shapes. - val transformations = model.operationShapes - .flatMap { listOfNotNull(it.input.orNull(), it.output.orNull()) + it.errors } - .mapNotNull { model.expectShape(it).asStructureShape().orElse(null) } - .filter { it.hasTrait(SyntheticInputTrait.ID) || it.hasTrait(SyntheticOutputTrait.ID) } - .flatMap { walker.walkShapes(it) } - .filter { it is StructureShape || it is ListShape || it is UnionShape || it is MapShape } - .flatMap { it.constrainedMembers() } - .mapNotNull { - val transformation = it.makeNonConstrained(model, additionalNames) - // Keep record of new names that have been generated to ensure none of them regenerated. - additionalNames.add(transformation.newShape.id) - - transformation - } + val transformations = + model.operationShapes + .flatMap { listOfNotNull(it.input.orNull(), it.output.orNull()) + it.errors } + .mapNotNull { model.expectShape(it).asStructureShape().orElse(null) } + .filter { it.hasTrait(SyntheticInputTrait.ID) || it.hasTrait(SyntheticOutputTrait.ID) } + .flatMap { walker.walkShapes(it) } + .filter { it is StructureShape || it is ListShape || it is UnionShape || it is MapShape } + .flatMap { it.constrainedMembers() } + .mapNotNull { + val transformation = it.makeNonConstrained(model, additionalNames) + // Keep record of new names that have been generated to ensure none of them regenerated. + additionalNames.add(transformation.newShape.id) + + transformation + } return applyTransformations(model, transformations) } @@ -108,15 +108,16 @@ object ConstrainedMemberTransform { ): Model { val modelBuilder = model.toBuilder() - val memberShapesToReplace = transformations.map { - // Add the new shape to the model. - modelBuilder.addShape(it.newShape) + val memberShapesToReplace = + transformations.map { + // Add the new shape to the model. + modelBuilder.addShape(it.newShape) - it.memberToChange.toBuilder() - .target(it.newShape.id) - .traits(it.traitsToKeep) - .build() - } + it.memberToChange.toBuilder() + .target(it.newShape.id) + .traits(it.traitsToKeep) + .build() + } // Change all original constrained member shapes with the new standalone shapes. return ModelTransformer.create() @@ -140,14 +141,14 @@ object ConstrainedMemberTransform { memberShape: ShapeId, ): ShapeId { val structName = memberShape.name - val memberName = memberShape.member.orElse(null) - .replaceFirstChar { if (it.isLowerCase()) it.titlecase(Locale.getDefault()) else it.toString() } + val memberName = + memberShape.member.orElse(null) + .replaceFirstChar { if (it.isLowerCase()) it.titlecase(Locale.getDefault()) else it.toString() } fun makeStructName(suffix: String = "") = ShapeId.from("${memberShape.namespace}#${structName}${memberName}$suffix") - fun structNameIsUnique(newName: ShapeId) = - model.getShape(newName).isEmpty && !additionalNames.contains(newName) + fun structNameIsUnique(newName: ShapeId) = model.getShape(newName).isEmpty && !additionalNames.contains(newName) fun generateUniqueName(): ShapeId { // Ensure the name does not already exist in the model, else make it unique @@ -173,10 +174,11 @@ object ConstrainedMemberTransform { model: Model, additionalNames: MutableSet, ): MemberShapeTransformation { - val (memberConstraintTraits, otherTraits) = this.allTraits.values - .partition { - memberConstraintTraitsToOverride.contains(it.javaClass) - } + val (memberConstraintTraits, otherTraits) = + this.allTraits.values + .partition { + memberConstraintTraitsToOverride.contains(it.javaClass) + } check(memberConstraintTraits.isNotEmpty()) { "There must at least be one member constraint on the shape" @@ -211,9 +213,10 @@ object ConstrainedMemberTransform { // Create a new unique standalone shape that will be added to the model later on val shapeId = overriddenShapeId(model, additionalNames, this.id) - val standaloneShape = builder.id(shapeId) - .traits(newTraits) - .build() + val standaloneShape = + builder.id(shapeId) + .traits(newTraits) + .build() // Since the new shape has not been added to the model as yet, the current // memberShape's target cannot be changed to the new shape. diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxer.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxer.kt index fc938aab7a1..56df1b4e606 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxer.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxer.kt @@ -65,10 +65,11 @@ object RecursiveConstraintViolationBoxer { * * [0] https://github.com/smithy-lang/smithy-rs/pull/2040 */ - fun transform(model: Model): Model = RecursiveShapeBoxer( - containsIndirectionPredicate = ::constraintViolationLoopContainsIndirection, - boxShapeFn = ::addConstraintViolationRustBoxTrait, - ).transform(model) + fun transform(model: Model): Model = + RecursiveShapeBoxer( + containsIndirectionPredicate = ::constraintViolationLoopContainsIndirection, + boxShapeFn = ::addConstraintViolationRustBoxTrait, + ).transform(model) private fun constraintViolationLoopContainsIndirection(loop: Collection): Boolean = loop.find { it.hasTrait() } != null diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt index 74cfda8f7ed..11615995798 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ShapesReachableFromOperationInputTagger.kt @@ -45,31 +45,34 @@ import software.amazon.smithy.rust.codegen.server.smithy.traits.ShapeReachableFr */ object ShapesReachableFromOperationInputTagger { fun transform(model: Model): Model { - val inputShapes = model.operationShapes.map { - model.expectShape(it.inputShape, StructureShape::class.java) - } + val inputShapes = + model.operationShapes.map { + model.expectShape(it.inputShape, StructureShape::class.java) + } val walker = DirectedWalker(model) - val shapesReachableFromOperationInputs = inputShapes - .flatMap { walker.walkShapes(it) } - .toSet() + val shapesReachableFromOperationInputs = + inputShapes + .flatMap { walker.walkShapes(it) } + .toSet() return ModelTransformer.create().mapShapes(model) { shape -> when (shape) { is StructureShape, is UnionShape, is ListShape, is MapShape, is StringShape, is IntegerShape, is ShortShape, is LongShape, is ByteShape, is BlobShape -> { if (shapesReachableFromOperationInputs.contains(shape)) { - val builder = when (shape) { - is StructureShape -> shape.toBuilder() - is UnionShape -> shape.toBuilder() - is ListShape -> shape.toBuilder() - is MapShape -> shape.toBuilder() - is StringShape -> shape.toBuilder() - is IntegerShape -> shape.toBuilder() - is ShortShape -> shape.toBuilder() - is LongShape -> shape.toBuilder() - is ByteShape -> shape.toBuilder() - is BlobShape -> shape.toBuilder() - else -> UNREACHABLE("the `when` is exhaustive") - } + val builder = + when (shape) { + is StructureShape -> shape.toBuilder() + is UnionShape -> shape.toBuilder() + is ListShape -> shape.toBuilder() + is MapShape -> shape.toBuilder() + is StringShape -> shape.toBuilder() + is IntegerShape -> shape.toBuilder() + is ShortShape -> shape.toBuilder() + is LongShape -> shape.toBuilder() + is ByteShape -> shape.toBuilder() + is BlobShape -> shape.toBuilder() + else -> UNREACHABLE("the `when` is exhaustive") + } builder.addTrait(ShapeReachableFromOperationInputTagTrait()).build() } else { shape diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt index f5b619ea747..b97328eb8d7 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt @@ -37,7 +37,8 @@ import java.io.File import java.nio.file.Path class ConstraintsMemberShapeTest { - private val outputModelOnly = """ + private val outputModelOnly = + """ namespace constrainedMemberShape use aws.protocols#restJson1 @@ -158,7 +159,7 @@ class ConstraintsMemberShapeTest { string PatternString @range(min: 0, max:1000) integer RangedInteger - """.asSmithyModel() + """.asSmithyModel() private fun loadModel(model: Model): Model = ConstrainedMemberTransform.transform(OperationNormalizer.transform(model)) @@ -283,15 +284,20 @@ class ConstraintsMemberShapeTest { ) } - private fun runServerCodeGen(model: Model, dirToUse: File? = null, writable: Writable): Path { + private fun runServerCodeGen( + model: Model, + dirToUse: File? = null, + writable: Writable, + ): Path { val runtimeConfig = - RuntimeConfig(runtimeCrateLocation = RuntimeCrateLocation.Path(File("../rust-runtime").absolutePath)) + RuntimeConfig(runtimeCrateLocation = RuntimeCrateLocation.path(File("../rust-runtime").absolutePath)) - val (context, dir) = generatePluginContext( - model, - runtimeConfig = runtimeConfig, - overrideTestDir = dirToUse, - ) + val (context, dir) = + generatePluginContext( + model, + runtimeConfig = runtimeConfig, + overrideTestDir = dirToUse, + ) val codegenDecorator = CombinedServerCodegenDecorator.fromClasspath( context, @@ -305,12 +311,13 @@ class ConstraintsMemberShapeTest { val codegenContext = serverTestCodegenContext(model) val settings = ServerRustSettings.from(context.model, context.settings) - val rustCrate = RustCrate( - context.fileManifest, - codegenContext.symbolProvider, - settings.codegenConfig, - codegenContext.expectModuleDocProvider(), - ) + val rustCrate = + RustCrate( + context.fileManifest, + codegenContext.symbolProvider, + settings.codegenConfig, + codegenContext.expectModuleDocProvider(), + ) // We cannot write to the lib anymore as the RustWriter overwrites it, so writing code directly to check.rs // and then adding a `mod check;` to the lib.rs @@ -328,64 +335,65 @@ class ConstraintsMemberShapeTest { @Test fun `generate code and check member constrained shapes are in the right modules`() { - val dir = runServerCodeGen(outputModelOnly) { - fun RustWriter.testTypeExistsInBuilderModule(typeName: String) { - unitTest( - "builder_module_has_${typeName.toSnakeCase()}", - """ - #[allow(unused_imports)] use crate::output::operation_using_get_output::$typeName; - """, - ) - } - - // All directly constrained members of the output structure should be in the builder module - setOf( - "ConstrainedLong", - "ConstrainedByte", - "ConstrainedShort", - "ConstrainedInteger", - "ConstrainedString", - "RequiredConstrainedString", - "RequiredConstrainedLong", - "RequiredConstrainedByte", - "RequiredConstrainedInteger", - "RequiredConstrainedShort", - "ConstrainedPatternString", - ).forEach(::testTypeExistsInBuilderModule) - - fun Set.generateUseStatements(prefix: String) = - this.joinToString(separator = "\n") { - "#[allow(unused_imports)] use $prefix::$it;" + val dir = + runServerCodeGen(outputModelOnly) { + fun RustWriter.testTypeExistsInBuilderModule(typeName: String) { + unitTest( + "builder_module_has_${typeName.toSnakeCase()}", + """ + #[allow(unused_imports)] use crate::output::operation_using_get_output::$typeName; + """, + ) } - unitTest( - "map_overridden_enum", + // All directly constrained members of the output structure should be in the builder module setOf( - "Value", - "value::ConstraintViolation as ValueCV", - "Key", - "key::ConstraintViolation as KeyCV", - ).generateUseStatements("crate::model::pattern_map_override"), - ) + "ConstrainedLong", + "ConstrainedByte", + "ConstrainedShort", + "ConstrainedInteger", + "ConstrainedString", + "RequiredConstrainedString", + "RequiredConstrainedLong", + "RequiredConstrainedByte", + "RequiredConstrainedInteger", + "RequiredConstrainedShort", + "ConstrainedPatternString", + ).forEach(::testTypeExistsInBuilderModule) + + fun Set.generateUseStatements(prefix: String) = + this.joinToString(separator = "\n") { + "#[allow(unused_imports)] use $prefix::$it;" + } - unitTest( - "union_overridden_enum", - setOf( - "First", - "first::ConstraintViolation as FirstCV", - "Second", - "second::ConstraintViolation as SecondCV", - ).generateUseStatements("crate::model::pattern_union_override"), - ) + unitTest( + "map_overridden_enum", + setOf( + "Value", + "value::ConstraintViolation as ValueCV", + "Key", + "key::ConstraintViolation as KeyCV", + ).generateUseStatements("crate::model::pattern_map_override"), + ) - unitTest( - "list_overridden_enum", - setOf( - "Member", - "member::ConstraintViolation as MemberCV", - ).generateUseStatements("crate::model::pattern_string_list_override"), - ) - } + unitTest( + "union_overridden_enum", + setOf( + "First", + "first::ConstraintViolation as FirstCV", + "Second", + "second::ConstraintViolation as SecondCV", + ).generateUseStatements("crate::model::pattern_union_override"), + ) + + unitTest( + "list_overridden_enum", + setOf( + "Member", + "member::ConstraintViolation as MemberCV", + ).generateUseStatements("crate::model::pattern_string_list_override"), + ) + } val env = mapOf("RUSTFLAGS" to "-A dead_code") "cargo test".runCommand(dir, env) @@ -424,8 +432,9 @@ class ConstraintsMemberShapeTest { memberTargetShape.id.name shouldNotBe beforeTransformMemberShape.target.name // Target shape's name should match the expected name - val expectedName = memberShape.container.name.substringAfter('#') + - memberShape.memberName.substringBefore('#').toPascalCase() + val expectedName = + memberShape.container.name.substringAfter('#') + + memberShape.memberName.substringBefore('#').toPascalCase() memberTargetShape.id.name shouldBe expectedName @@ -438,19 +447,21 @@ class ConstraintsMemberShapeTest { val leftOutConstraintTrait = beforeTransformConstraintTraits - newShapeConstrainedTraits assert( - leftOutConstraintTrait.isEmpty() || leftOutConstraintTrait.all { - it.toShapeId() == RequiredTrait.ID - }, + leftOutConstraintTrait.isEmpty() || + leftOutConstraintTrait.all { + it.toShapeId() == RequiredTrait.ID + }, ) { lazyMessage } // In case the target shape has some more constraints, which the member shape did not override, // then those still need to apply on the new standalone shape that has been defined. - val leftOverTraits = originalTargetShape.allTraits.values - .filter { beforeOverridingTrait -> - beforeTransformConstraintTraits.none { - beforeOverridingTrait.toShapeId() == it.toShapeId() + val leftOverTraits = + originalTargetShape.allTraits.values + .filter { beforeOverridingTrait -> + beforeTransformConstraintTraits.none { + beforeOverridingTrait.toShapeId() == it.toShapeId() + } } - } val allNewShapeTraits = memberTargetShape.allTraits.values.toList() assert((leftOverTraits + newShapeConstrainedTraits).all { it in allNewShapeTraits }) { lazyMessage } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/CustomShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/CustomShapeSymbolProviderTest.kt index ea5d63ab2c0..f4a21fd0094 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/CustomShapeSymbolProviderTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/CustomShapeSymbolProviderTest.kt @@ -40,20 +40,23 @@ class CustomShapeSymbolProviderTest { """.asSmithyModel(smithyVersion = "2.0") private val serviceShape = baseModel.lookup("test#TestService") private val rustType = RustType.Opaque("fake-type") - private val symbol = Symbol.builder() - .name("fake-symbol") - .rustType(rustType) - .build() - private val model = ModelTransformer.create() - .mapShapes(baseModel) { - if (it is MemberShape) { - it.toBuilder().addTrait(SyntheticCustomShapeTrait(ShapeId.from("some#id"), symbol)).build() - } else { - it + private val symbol = + Symbol.builder() + .name("fake-symbol") + .rustType(rustType) + .build() + private val model = + ModelTransformer.create() + .mapShapes(baseModel) { + if (it is MemberShape) { + it.toBuilder().addTrait(SyntheticCustomShapeTrait(ShapeId.from("some#id"), symbol)).build() + } else { + it + } } - } - private val symbolProvider = serverTestSymbolProvider(baseModel, serviceShape) - .let { CustomShapeSymbolProvider(it) } + private val symbolProvider = + serverTestSymbolProvider(baseModel, serviceShape) + .let { CustomShapeSymbolProvider(it) } @Test fun `override with custom symbol`() { diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProviderTest.kt index 5f2dea66e2d..360a8fe6061 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProviderTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProviderTest.kt @@ -170,44 +170,48 @@ internal class DeriveEqAndHashSymbolMetadataProviderTest { } """.asSmithyModel(smithyVersion = "2.0") private val serviceShape = model.lookup("test#TestService") - private val deriveEqAndHashSymbolMetadataProvider = serverTestSymbolProvider(model, serviceShape) - .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf()) } - .let { DeriveEqAndHashSymbolMetadataProvider(it) } + private val deriveEqAndHashSymbolMetadataProvider = + serverTestSymbolProvider(model, serviceShape) + .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf()) } + .let { DeriveEqAndHashSymbolMetadataProvider(it) } companion object { @JvmStatic fun getShapes(): Stream { - val shapesWithNeitherEqNorHash = listOf( - "test#StreamingOperationInputOutput", - "test#EventStreamOperationInputOutput", - "test#StreamingUnion", - "test#BlobStream", - "test#TestInputOutput", - "test#HasFloat", - "test#HasDouble", - "test#HasDocument", - "test#ContainsFloat", - "test#ContainsDouble", - "test#ContainsDocument", - ) - - val shapesWithEqAndHash = listOf( - "test#EqAndHashStruct", - "test#EqAndHashUnion", - "test#Enum", - "test#HasList", - ) - - val shapesWithOnlyEq = listOf( - "test#HasListWithMap", - "test#HasMap", - ) + val shapesWithNeitherEqNorHash = + listOf( + "test#StreamingOperationInputOutput", + "test#EventStreamOperationInputOutput", + "test#StreamingUnion", + "test#BlobStream", + "test#TestInputOutput", + "test#HasFloat", + "test#HasDouble", + "test#HasDocument", + "test#ContainsFloat", + "test#ContainsDouble", + "test#ContainsDocument", + ) + + val shapesWithEqAndHash = + listOf( + "test#EqAndHashStruct", + "test#EqAndHashUnion", + "test#Enum", + "test#HasList", + ) + + val shapesWithOnlyEq = + listOf( + "test#HasListWithMap", + "test#HasMap", + ) return ( shapesWithNeitherEqNorHash.map { Arguments.of(it, emptyList()) } + shapesWithEqAndHash.map { Arguments.of(it, listOf(RuntimeType.Eq, RuntimeType.Hash)) } + shapesWithOnlyEq.map { Arguments.of(it, listOf(RuntimeType.Eq)) } - ).stream() + ).stream() } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitEscapedSpecialCharsValidatorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitEscapedSpecialCharsValidatorTest.kt index 7a54849c208..5395d63f07b 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitEscapedSpecialCharsValidatorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitEscapedSpecialCharsValidatorTest.kt @@ -17,84 +17,90 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel class PatternTraitEscapedSpecialCharsValidatorTest { @Test fun `should error out with a suggestion if non-escaped special chars used inside @pattern`() { - val exception = shouldThrow { - """ - namespace test - - @pattern("\t") - string MyString - """.asSmithyModel(smithyVersion = "2") - } + val exception = + shouldThrow { + """ + namespace test + + @pattern("\t") + string MyString + """.asSmithyModel(smithyVersion = "2") + } val events = exception.validationEvents.filter { it.severity == Severity.ERROR } events shouldHaveSize 1 events[0].shapeId.get() shouldBe ShapeId.from("test#MyString") - events[0].message shouldBe """ + events[0].message shouldBe + """ Non-escaped special characters used inside `@pattern`. You must escape them: `@pattern("\\t")`. See https://github.com/smithy-lang/smithy-rs/issues/2508 for more details. - """.trimIndent() + """.trimIndent() } @Test fun `should suggest escaping spacial characters properly`() { - val exception = shouldThrow { - """ - namespace test - - @pattern("[.\n\\r]+") - string MyString - """.asSmithyModel(smithyVersion = "2") - } + val exception = + shouldThrow { + """ + namespace test + + @pattern("[.\n\\r]+") + string MyString + """.asSmithyModel(smithyVersion = "2") + } val events = exception.validationEvents.filter { it.severity == Severity.ERROR } events shouldHaveSize 1 events[0].shapeId.get() shouldBe ShapeId.from("test#MyString") - events[0].message shouldBe """ + events[0].message shouldBe + """ Non-escaped special characters used inside `@pattern`. You must escape them: `@pattern("[.\\n\\r]+")`. See https://github.com/smithy-lang/smithy-rs/issues/2508 for more details. - """.trimIndent() + """.trimIndent() } @Test fun `should report all non-escaped special characters`() { - val exception = shouldThrow { - """ - namespace test + val exception = + shouldThrow { + """ + namespace test - @pattern("\b") - string MyString + @pattern("\b") + string MyString - @pattern("^\n$") - string MyString2 + @pattern("^\n$") + string MyString2 - @pattern("^[\n]+$") - string MyString3 + @pattern("^[\n]+$") + string MyString3 - @pattern("^[\r\t]$") - string MyString4 - """.asSmithyModel(smithyVersion = "2") - } + @pattern("^[\r\t]$") + string MyString4 + """.asSmithyModel(smithyVersion = "2") + } val events = exception.validationEvents.filter { it.severity == Severity.ERROR } events shouldHaveSize 4 } @Test fun `should report errors on string members`() { - val exception = shouldThrow { - """ - namespace test - - @pattern("\t") - string MyString - - structure MyStructure { - @pattern("\b") - field: String + val exception = + shouldThrow { + """ + namespace test + + @pattern("\t") + string MyString + + structure MyStructure { + @pattern("\b") + field: String + } + """.asSmithyModel(smithyVersion = "2") } - """.asSmithyModel(smithyVersion = "2") - } val events = exception.validationEvents.filter { it.severity == Severity.ERROR } events shouldHaveSize 2 diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt index 8d07a5959af..0c223f884bb 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt @@ -21,7 +21,8 @@ import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProviders class PubCrateConstrainedShapeSymbolProviderTest { - private val model = """ + private val model = + """ $baseModelString structure NonTransitivelyConstrainedStructureShape { @@ -46,7 +47,7 @@ class PubCrateConstrainedShapeSymbolProviderTest { union Union { structure: Structure } - """.asSmithyModel() + """.asSmithyModel() private val serverTestSymbolProviders = serverTestSymbolProviders(model) private val symbolProvider = serverTestSymbolProviders.symbolProvider @@ -74,10 +75,11 @@ class PubCrateConstrainedShapeSymbolProviderTest { val transitivelyConstrainedCollectionType = pubCrateConstrainedShapeSymbolProvider.toSymbol(transitivelyConstrainedCollectionShape).rustType() - transitivelyConstrainedCollectionType shouldBe RustType.Opaque( - "TransitivelyConstrainedCollectionConstrained", - "crate::constrained::transitively_constrained_collection_constrained", - ) + transitivelyConstrainedCollectionType shouldBe + RustType.Opaque( + "TransitivelyConstrainedCollectionConstrained", + "crate::constrained::transitively_constrained_collection_constrained", + ) } @Test @@ -86,10 +88,11 @@ class PubCrateConstrainedShapeSymbolProviderTest { val transitivelyConstrainedMapType = pubCrateConstrainedShapeSymbolProvider.toSymbol(transitivelyConstrainedMapShape).rustType() - transitivelyConstrainedMapType shouldBe RustType.Opaque( - "TransitivelyConstrainedMapConstrained", - "crate::constrained::transitively_constrained_map_constrained", - ) + transitivelyConstrainedMapType shouldBe + RustType.Opaque( + "TransitivelyConstrainedMapConstrained", + "crate::constrained::transitively_constrained_map_constrained", + ) } @Test @@ -97,12 +100,13 @@ class PubCrateConstrainedShapeSymbolProviderTest { val memberShape = model.lookup("test#StructureWithMemberTargetingAggregateShape\$member") val memberType = pubCrateConstrainedShapeSymbolProvider.toSymbol(memberShape).rustType() - memberType shouldBe RustType.Option( - RustType.Opaque( - "TransitivelyConstrainedCollectionConstrained", - "crate::constrained::transitively_constrained_collection_constrained", - ), - ) + memberType shouldBe + RustType.Option( + RustType.Opaque( + "TransitivelyConstrainedCollectionConstrained", + "crate::constrained::transitively_constrained_collection_constrained", + ), + ) } @Test diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt index 7467d0d76ff..432ad64fde4 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt @@ -16,7 +16,6 @@ import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrat import java.util.stream.Stream internal class RecursiveConstraintViolationsTest { - data class TestCase( /** The test name is only used in the generated report, to easily identify a failing test. **/ val testName: String, @@ -49,7 +48,10 @@ internal class RecursiveConstraintViolationsTest { } """ - private fun recursiveListModel(sparse: Boolean, listPrefix: String = ""): Pair = + private fun recursiveListModel( + sparse: Boolean, + listPrefix: String = "", + ): Pair = """ $baseModel @@ -57,18 +59,26 @@ internal class RecursiveConstraintViolationsTest { list: ${listPrefix}List } - ${ if (sparse) { "@sparse" } else { "" } } + ${ if (sparse) { + "@sparse" + } else { + "" + } } @length(min: 69) list ${listPrefix}List { member: Recursive } - """.asSmithyModel() to if ("${listPrefix}List" < "Recursive") { - "com.amazonaws.recursiveconstraintviolations#${listPrefix}List\$member" - } else { - "com.amazonaws.recursiveconstraintviolations#Recursive\$list" - } + """.asSmithyModel() to + if ("${listPrefix}List" < "Recursive") { + "com.amazonaws.recursiveconstraintviolations#${listPrefix}List\$member" + } else { + "com.amazonaws.recursiveconstraintviolations#Recursive\$list" + } - private fun recursiveMapModel(sparse: Boolean, mapPrefix: String = ""): Pair = + private fun recursiveMapModel( + sparse: Boolean, + mapPrefix: String = "", + ): Pair = """ $baseModel @@ -76,17 +86,22 @@ internal class RecursiveConstraintViolationsTest { map: ${mapPrefix}Map } - ${ if (sparse) { "@sparse" } else { "" } } + ${ if (sparse) { + "@sparse" + } else { + "" + } } @length(min: 69) map ${mapPrefix}Map { key: String, value: Recursive } - """.asSmithyModel() to if ("${mapPrefix}Map" < "Recursive") { - "com.amazonaws.recursiveconstraintviolations#${mapPrefix}Map\$value" - } else { - "com.amazonaws.recursiveconstraintviolations#Recursive\$map" - } + """.asSmithyModel() to + if ("${mapPrefix}Map" < "Recursive") { + "com.amazonaws.recursiveconstraintviolations#${mapPrefix}Map\$value" + } else { + "com.amazonaws.recursiveconstraintviolations#Recursive\$map" + } private fun recursiveUnionModel(unionPrefix: String = ""): Pair = """ @@ -137,34 +152,37 @@ internal class RecursiveConstraintViolationsTest { } override fun provideArguments(context: ExtensionContext?): Stream { - val listModels = listOf(false, true).flatMap { isSparse -> - listOf("", "ZZZ").map { listPrefix -> - val (model, shapeIdWithConstraintViolationRustBoxTrait) = recursiveListModel(isSparse, listPrefix) - var testName = "${ if (isSparse) "sparse" else "non-sparse" } recursive list" - if (listPrefix.isNotEmpty()) { - testName += " with shape name prefix $listPrefix" + val listModels = + listOf(false, true).flatMap { isSparse -> + listOf("", "ZZZ").map { listPrefix -> + val (model, shapeIdWithConstraintViolationRustBoxTrait) = recursiveListModel(isSparse, listPrefix) + var testName = "${ if (isSparse) "sparse" else "non-sparse" } recursive list" + if (listPrefix.isNotEmpty()) { + testName += " with shape name prefix $listPrefix" + } + TestCase(testName, model, shapeIdWithConstraintViolationRustBoxTrait) } - TestCase(testName, model, shapeIdWithConstraintViolationRustBoxTrait) } - } - val mapModels = listOf(false, true).flatMap { isSparse -> - listOf("", "ZZZ").map { mapPrefix -> - val (model, shapeIdWithConstraintViolationRustBoxTrait) = recursiveMapModel(isSparse, mapPrefix) - var testName = "${ if (isSparse) "sparse" else "non-sparse" } recursive map" - if (mapPrefix.isNotEmpty()) { - testName += " with shape name prefix $mapPrefix" + val mapModels = + listOf(false, true).flatMap { isSparse -> + listOf("", "ZZZ").map { mapPrefix -> + val (model, shapeIdWithConstraintViolationRustBoxTrait) = recursiveMapModel(isSparse, mapPrefix) + var testName = "${ if (isSparse) "sparse" else "non-sparse" } recursive map" + if (mapPrefix.isNotEmpty()) { + testName += " with shape name prefix $mapPrefix" + } + TestCase(testName, model, shapeIdWithConstraintViolationRustBoxTrait) } - TestCase(testName, model, shapeIdWithConstraintViolationRustBoxTrait) } - } - val unionModels = listOf("", "ZZZ").map { unionPrefix -> - val (model, shapeIdWithConstraintViolationRustBoxTrait) = recursiveUnionModel(unionPrefix) - var testName = "recursive union" - if (unionPrefix.isNotEmpty()) { - testName += " with shape name prefix $unionPrefix" + val unionModels = + listOf("", "ZZZ").map { unionPrefix -> + val (model, shapeIdWithConstraintViolationRustBoxTrait) = recursiveUnionModel(unionPrefix) + var testName = "recursive union" + if (unionPrefix.isNotEmpty()) { + testName += " with shape name prefix $unionPrefix" + } + TestCase(testName, model, shapeIdWithConstraintViolationRustBoxTrait) } - TestCase(testName, model, shapeIdWithConstraintViolationRustBoxTrait) - } return listOf(listModels, mapModels, unionModels) .flatten() .map { Arguments.of(it) }.stream() diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriterTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriterTest.kt index d4358aa4c03..5dcb4b3cf85 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriterTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriterTest.kt @@ -29,7 +29,8 @@ import java.io.File class RustCrateInlineModuleComposingWriterTest { private val rustCrate: RustCrate private val codegenContext: ServerCodegenContext - private val model: Model = """ + private val model: Model = + """ ${'$'}version: "2.0" namespace test @@ -56,27 +57,32 @@ class RustCrateInlineModuleComposingWriterTest { @pattern("^[a-m]+${'$'}") string PatternString - """.trimIndent().asSmithyModel() + """.trimIndent().asSmithyModel() init { codegenContext = serverTestCodegenContext(model) val runtimeConfig = - RuntimeConfig(runtimeCrateLocation = RuntimeCrateLocation.Path(File("../rust-runtime").absolutePath)) + RuntimeConfig(runtimeCrateLocation = RuntimeCrateLocation.path(File("../rust-runtime").absolutePath)) - val (context, _) = generatePluginContext( - model, - runtimeConfig = runtimeConfig, - ) + val (context, _) = + generatePluginContext( + model, + runtimeConfig = runtimeConfig, + ) val settings = ServerRustSettings.from(context.model, context.settings) - rustCrate = RustCrate( - context.fileManifest, - codegenContext.symbolProvider, - settings.codegenConfig, - codegenContext.expectModuleDocProvider(), - ) + rustCrate = + RustCrate( + context.fileManifest, + codegenContext.symbolProvider, + settings.codegenConfig, + codegenContext.expectModuleDocProvider(), + ) } - private fun createTestInlineModule(parentModule: RustModule, moduleName: String): RustModule.LeafModule = + private fun createTestInlineModule( + parentModule: RustModule, + moduleName: String, + ): RustModule.LeafModule = RustModule.new( moduleName, visibility = Visibility.PUBLIC, @@ -92,13 +98,19 @@ class RustCrateInlineModuleComposingWriterTest { inline = true, ) - private fun helloWorld(writer: RustWriter, moduleName: String) { + private fun helloWorld( + writer: RustWriter, + moduleName: String, + ) { writer.rustBlock("pub fn hello_world()") { writer.comment("Module $moduleName") } } - private fun byeWorld(writer: RustWriter, moduleName: String) { + private fun byeWorld( + writer: RustWriter, + moduleName: String, + ) { writer.rustBlock("pub fn bye_world()") { writer.comment("Module $moduleName") writer.rust("""println!("from inside $moduleName");""") @@ -152,12 +164,13 @@ class RustCrateInlineModuleComposingWriterTest { // crate::output::h::i::hello_world(); val testProject = TestWorkspace.testProject(serverTestSymbolProvider(model)) - val modules = hashMapOf( - "a" to createTestOrphanInlineModule("a"), - "d" to createTestOrphanInlineModule("d"), - "e" to createTestOrphanInlineModule("e"), - "i" to createTestOrphanInlineModule("i"), - ) + val modules = + hashMapOf( + "a" to createTestOrphanInlineModule("a"), + "d" to createTestOrphanInlineModule("d"), + "e" to createTestOrphanInlineModule("e"), + "i" to createTestOrphanInlineModule("i"), + ) modules["b"] = createTestInlineModule(ServerRustModule.Model, "b") modules["c"] = createTestInlineModule(modules["b"]!!, "c") diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitorTest.kt index 942e5f76170..54c3606fa11 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitorTest.kt @@ -18,7 +18,8 @@ import kotlin.io.path.writeText class ServerCodegenVisitorTest { @Test fun `baseline transform verify mixins removed`() { - val model = """ + val model = + """ namespace com.example use aws.protocols#awsJson1_0 @@ -43,7 +44,7 @@ class ServerCodegenVisitorTest { ] { greeting: String } - """.asSmithyModel(smithyVersion = "2.0") + """.asSmithyModel(smithyVersion = "2.0") val (ctx, testDir) = generatePluginContext(model) testDir.resolve("src/main.rs").writeText("fn main() {}") val codegenDecorator = diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt index a140322e361..72279a7a205 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt @@ -37,7 +37,10 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { } """ - private fun validateModel(model: Model, serverCodegenConfig: ServerCodegenConfig = ServerCodegenConfig()): ValidationResult { + private fun validateModel( + model: Model, + serverCodegenConfig: ServerCodegenConfig = ServerCodegenConfig(), + ): ValidationResult { val service = model.serviceShapes.first() return validateUnsupportedConstraints(model, service, serverCodegenConfig) } @@ -54,16 +57,18 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { } """.asSmithyModel() val service = model.lookup("test#TestService") - val validationResult = validateOperationsWithConstrainedInputHaveValidationExceptionAttached( - model, - service, - SmithyValidationExceptionConversionGenerator.SHAPE_ID, - ) + val validationResult = + validateOperationsWithConstrainedInputHaveValidationExceptionAttached( + model, + service, + SmithyValidationExceptionConversionGenerator.SHAPE_ID, + ) validationResult.messages shouldHaveSize 1 // Asserts the exact message, to ensure the formatting is appropriate. - validationResult.messages[0].message shouldBe """ + validationResult.messages[0].message shouldBe + """ Operation test#TestOperation takes in input that is constrained (https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html), and as such can fail with a validation exception. You must model this behavior in the operation shape in your model file. ```smithy use smithy.framework#ValidationException @@ -73,7 +78,7 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { errors: [..., ValidationException] // <-- Add this. } ``` - """.trimIndent() + """.trimIndent() } private val constraintTraitOnStreamingBlobShapeModel = @@ -95,10 +100,11 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { val validationResult = validateModel(constraintTraitOnStreamingBlobShapeModel) validationResult.messages shouldHaveSize 1 - validationResult.messages[0].message shouldContain """ + validationResult.messages[0].message shouldContain + """ The blob shape `test#StreamingBlob` has both the `smithy.api#length` and `smithy.api#streaming` constraint traits attached. It is unclear what the semantics for streaming blob shapes are. - """.trimIndent().replace("\n", " ") + """.trimIndent().replace("\n", " ") } private val constrainedShapesInEventStreamModel = @@ -183,25 +189,27 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { The map shape `test#Map` is reachable from the list shape `test#UniqueItemsList`, which has the `@uniqueItems` trait attached. """.trimIndent().replace("\n", " ") - ) + ) } @Test fun `it should abort when a map shape is reachable from a uniqueItems list shape, despite opting into ignoreUnsupportedConstraintTraits`() { - val validationResult = validateModel( - mapShapeReachableFromUniqueItemsListShapeModel, - ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), - ) + val validationResult = + validateModel( + mapShapeReachableFromUniqueItemsListShapeModel, + ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), + ) validationResult.shouldAbort shouldBe true } @Test fun `it should abort when constraint traits in event streams are used, despite opting into ignoreUnsupportedConstraintTraits`() { - val validationResult = validateModel( - EventStreamNormalizer.transform(constrainedShapesInEventStreamModel), - ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), - ) + val validationResult = + validateModel( + EventStreamNormalizer.transform(constrainedShapesInEventStreamModel), + ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), + ) validationResult.shouldAbort shouldBe true } @@ -216,10 +224,11 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { @Test fun `it should not abort when ignoreUnsupportedConstraints is true and unsupported constraints are used`() { - val validationResult = validateModel( - constraintTraitOnStreamingBlobShapeModel, - ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), - ) + val validationResult = + validateModel( + constraintTraitOnStreamingBlobShapeModel, + ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), + ) validationResult.messages shouldHaveAtLeastSize 1 validationResult.shouldAbort shouldBe false @@ -235,10 +244,11 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { @Test fun `it should set log level to warn when ignoreUnsupportedConstraints is true and unsupported constraints are used`() { - val validationResult = validateModel( - constraintTraitOnStreamingBlobShapeModel, - ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), - ) + val validationResult = + validateModel( + constraintTraitOnStreamingBlobShapeModel, + ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), + ) validationResult.messages shouldHaveAtLeastSize 1 validationResult.messages.shouldForAll { it.level shouldBe Level.WARNING } @@ -246,14 +256,16 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { @Test fun `it should abort when ignoreUnsupportedConstraints is true and all used constraints are supported`() { - val allConstraintTraitsAreSupported = File("../codegen-core/common-test-models/constraints.smithy") - .readText() - .asSmithyModel() - - val validationResult = validateModel( - allConstraintTraitsAreSupported, - ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), - ) + val allConstraintTraitsAreSupported = + File("../codegen-core/common-test-models/constraints.smithy") + .readText() + .asSmithyModel() + + val validationResult = + validateModel( + allConstraintTraitsAreSupported, + ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), + ) validationResult.messages shouldHaveSize 1 validationResult.shouldAbort shouldBe true @@ -262,6 +274,6 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { The `ignoreUnsupportedConstraints` flag in the `codegen` configuration is set to `true`, but it has no effect. All the constraint traits used in the model are well-supported, please remove this flag. """.trimIndent().replace("\n", " ") - ) + ) } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecoratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecoratorTest.kt index d7783c7ac10..8fc655da042 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecoratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecoratorTest.kt @@ -15,7 +15,8 @@ import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestRustSettings class AdditionalErrorsDecoratorTest { - private val baseModel = """ + private val baseModel = + """ namespace test operation Infallible { @@ -33,7 +34,7 @@ class AdditionalErrorsDecoratorTest { @error("client") structure AnError { } - """.asSmithyModel() + """.asSmithyModel() private val model = OperationNormalizer.transform(baseModel) private val service = ServiceShape.builder().id("smithy.test#Test").build() private val settings = serverTestRustSettings() diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecoratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecoratorTest.kt index 31b46023819..9176a7bf086 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecoratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecoratorTest.kt @@ -75,16 +75,18 @@ fun swapOutSmithyValidationExceptionForCustomOne(model: Model): Model { """.asSmithyModel(smithyVersion = "2.0") // Remove Smithy's `ValidationException`. - var model = ModelTransformer.create().removeShapes( - model, - listOf(model.expectShape(SmithyValidationExceptionConversionGenerator.SHAPE_ID)), - ) + var model = + ModelTransformer.create().removeShapes( + model, + listOf(model.expectShape(SmithyValidationExceptionConversionGenerator.SHAPE_ID)), + ) // Add our custom one. model = ModelTransformer.create().replaceShapes(model, customValidationExceptionModel.shapes().toList()) // Make all operations use our custom one. - val newOperationShapes = model.operationShapes.map { operationShape -> - operationShape.toBuilder().addError(ShapeId.from("com.amazonaws.constraints#ValidationException")).build() - } + val newOperationShapes = + model.operationShapes.map { operationShape -> + operationShape.toBuilder().addError(ShapeId.from("com.amazonaws.constraints#ValidationException")).build() + } return ModelTransformer.create().replaceShapes(model, newOperationShapes) } @@ -97,12 +99,13 @@ internal class CustomValidationExceptionWithReasonDecoratorTest { serverIntegrationTest( model, IntegrationTestParams( - additionalSettings = Node.objectNodeBuilder().withMember( - "codegen", - Node.objectNodeBuilder() - .withMember("experimentalCustomValidationExceptionWithReasonPleaseDoNotUse", "com.amazonaws.constraints#ValidationException") - .build(), - ).build(), + additionalSettings = + Node.objectNodeBuilder().withMember( + "codegen", + Node.objectNodeBuilder() + .withMember("experimentalCustomValidationExceptionWithReasonPleaseDoNotUse", "com.amazonaws.constraints#ValidationException") + .build(), + ).build(), ), ) } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/PostprocessValidationExceptionNotAttachedErrorMessageDecoratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/PostprocessValidationExceptionNotAttachedErrorMessageDecoratorTest.kt index 9130d30b330..11a2f882aba 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/PostprocessValidationExceptionNotAttachedErrorMessageDecoratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/PostprocessValidationExceptionNotAttachedErrorMessageDecoratorTest.kt @@ -39,33 +39,37 @@ internal class PostprocessValidationExceptionNotAttachedErrorMessageDecoratorTes } """.asSmithyModel() - val validationExceptionNotAttachedErrorMessageDummyPostprocessorDecorator = object : ServerCodegenDecorator { - override val name: String - get() = "ValidationExceptionNotAttachedErrorMessageDummyPostprocessorDecorator" - override val order: Byte - get() = 69 + val validationExceptionNotAttachedErrorMessageDummyPostprocessorDecorator = + object : ServerCodegenDecorator { + override val name: String + get() = "ValidationExceptionNotAttachedErrorMessageDummyPostprocessorDecorator" + override val order: Byte + get() = 69 - override fun postprocessValidationExceptionNotAttachedErrorMessage(validationResult: ValidationResult): ValidationResult { - check(validationResult.messages.size == 1) + override fun postprocessValidationExceptionNotAttachedErrorMessage( + validationResult: ValidationResult, + ): ValidationResult { + check(validationResult.messages.size == 1) - val level = validationResult.messages.first().level - val message = - """ - ${validationResult.messages.first().message} + val level = validationResult.messages.first().level + val message = + """ + ${validationResult.messages.first().message} - There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man. - """ + There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man. + """ - return validationResult.copy(messages = listOf(LogMessage(level, message))) + return validationResult.copy(messages = listOf(LogMessage(level, message))) + } } - } - val exception = assertThrows { - serverIntegrationTest( - model, - additionalDecorators = listOf(validationExceptionNotAttachedErrorMessageDummyPostprocessorDecorator), - ) - } + val exception = + assertThrows { + serverIntegrationTest( + model, + additionalDecorators = listOf(validationExceptionNotAttachedErrorMessageDummyPostprocessorDecorator), + ) + } val exceptionCause = (exception.cause!! as ValidationResult) exceptionCause.messages.size shouldBe 1 exceptionCause.messages.first().message shouldContain "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man." diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGeneratorTest.kt index 3a35120f839..bc73ac7f327 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGeneratorTest.kt @@ -32,27 +32,28 @@ class ConstrainedBlobGeneratorTest { data class TestCase(val model: Model, val validBlob: String, val invalidBlob: String) class ConstrainedBlobGeneratorTestProvider : ArgumentsProvider { - private val testCases = listOf( - // Min and max. - Triple("@length(min: 11, max: 12)", "validString", "invalidString"), - // Min equal to max. - Triple("@length(min: 11, max: 11)", "validString", "invalidString"), - // Only min. - Triple("@length(min: 11)", "validString", ""), - // Only max. - Triple("@length(max: 11)", "", "invalidString"), - ).map { - TestCase( - """ - namespace test - - ${it.first} - blob ConstrainedBlob - """.asSmithyModel(), - "aws_smithy_types::Blob::new(Vec::from(${it.second.dq()}))", - "aws_smithy_types::Blob::new(Vec::from(${it.third.dq()}))", - ) - } + private val testCases = + listOf( + // Min and max. + Triple("@length(min: 11, max: 12)", "validString", "invalidString"), + // Min equal to max. + Triple("@length(min: 11, max: 11)", "validString", "invalidString"), + // Only min. + Triple("@length(min: 11)", "validString", ""), + // Only max. + Triple("@length(max: 11)", "", "invalidString"), + ).map { + TestCase( + """ + namespace test + + ${it.first} + blob ConstrainedBlob + """.asSmithyModel(), + "aws_smithy_types::Blob::new(Vec::from(${it.second.dq()}))", + "aws_smithy_types::Blob::new(Vec::from(${it.third.dq()}))", + ) + } override fun provideArguments(context: ExtensionContext?): Stream = testCases.map { Arguments.of(it) }.stream() @@ -118,12 +119,13 @@ class ConstrainedBlobGeneratorTest { @Test fun `type should not be constructable without using a constructor`() { - val model = """ + val model = + """ namespace test @length(min: 1, max: 70) blob ConstrainedBlob - """.asSmithyModel() + """.asSmithyModel() val constrainedBlobShape = model.lookup("test#ConstrainedBlob") val codegenContext = serverTestCodegenContext(model) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGeneratorTest.kt index 444d75d2ad5..23158df233b 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGeneratorTest.kt @@ -70,90 +70,99 @@ class ConstrainedCollectionGeneratorTest { } """.asSmithyModel().let(ShapesReachableFromOperationInputTagger::transform) - private val lengthTraitTestCases = listOf( - // Min and max. - Triple("@length(min: 11, max: 12)", 11, 13), - // Min equal to max. - Triple("@length(min: 11, max: 11)", 11, 12), - // Only min. - Triple("@length(min: 11)", 15, 10), - // Only max. - Triple("@length(max: 11)", 11, 12), - ).map { - // Generate lists of strings of the specified length with consecutive items "0", "1", ... - val validList = List(it.second, Int::toString) - val invalidList = List(it.third, Int::toString) - - Triple(it.first, ArrayNode.fromStrings(validList), ArrayNode.fromStrings(invalidList)) - }.map { (trait, validList, invalidList) -> - TestCase( - model = generateModel(trait), - validLists = listOf(validList), - invalidLists = listOf(InvalidList(invalidList, expectedErrorFn = null)), - ) - } - - private fun constraintViolationForDuplicateIndices(duplicateIndices: List): - ((constraintViolation: Symbol, originalValueBindingName: String) -> Writable) { - fun ret(constraintViolation: Symbol, originalValueBindingName: String): Writable = writable { - // Public documentation for the unique items constraint violation states that callers should not - // rely on the order of the elements in `duplicate_indices`. However, the algorithm is deterministic, - // so we can internally assert the order. If the algorithm changes, the test cases will need to be - // adjusted. - rustTemplate( - """ - #{ConstraintViolation}::UniqueItems { - duplicate_indices: vec![${duplicateIndices.joinToString(", ")}], - original: $originalValueBindingName, - } - """, - "ConstraintViolation" to constraintViolation, + private val lengthTraitTestCases = + listOf( + // Min and max. + Triple("@length(min: 11, max: 12)", 11, 13), + // Min equal to max. + Triple("@length(min: 11, max: 11)", 11, 12), + // Only min. + Triple("@length(min: 11)", 15, 10), + // Only max. + Triple("@length(max: 11)", 11, 12), + ).map { + // Generate lists of strings of the specified length with consecutive items "0", "1", ... + val validList = List(it.second, Int::toString) + val invalidList = List(it.third, Int::toString) + + Triple(it.first, ArrayNode.fromStrings(validList), ArrayNode.fromStrings(invalidList)) + }.map { (trait, validList, invalidList) -> + TestCase( + model = generateModel(trait), + validLists = listOf(validList), + invalidLists = listOf(InvalidList(invalidList, expectedErrorFn = null)), ) } + private fun constraintViolationForDuplicateIndices( + duplicateIndices: List, + ): ((constraintViolation: Symbol, originalValueBindingName: String) -> Writable) { + fun ret( + constraintViolation: Symbol, + originalValueBindingName: String, + ): Writable = + writable { + // Public documentation for the unique items constraint violation states that callers should not + // rely on the order of the elements in `duplicate_indices`. However, the algorithm is deterministic, + // so we can internally assert the order. If the algorithm changes, the test cases will need to be + // adjusted. + rustTemplate( + """ + #{ConstraintViolation}::UniqueItems { + duplicate_indices: vec![${duplicateIndices.joinToString(", ")}], + original: $originalValueBindingName, + } + """, + "ConstraintViolation" to constraintViolation, + ) + } + return ::ret } - private val uniqueItemsTraitTestCases = listOf( - // We only need one test case, since `@uniqueItems` is not parameterizable. - TestCase( - model = generateModel("@uniqueItems"), - validLists = listOf( - ArrayNode.fromStrings(), - ArrayNode.fromStrings("0", "1"), - ArrayNode.fromStrings("a", "b", "a2"), - ArrayNode.fromStrings((0..69).map(Int::toString).toList()), - ), - invalidLists = listOf( - // Two elements, both duplicate. - InvalidList( - node = ArrayNode.fromStrings("0", "0"), - expectedErrorFn = constraintViolationForDuplicateIndices(listOf(0, 1)), - ), - // Two duplicate items, one at the beginning, one at the end. - InvalidList( - node = ArrayNode.fromStrings("0", "1", "2", "3", "4", "5", "0"), - expectedErrorFn = constraintViolationForDuplicateIndices(listOf(0, 6)), - ), - // Several duplicate items, all the same. - InvalidList( - node = ArrayNode.fromStrings("0", "1", "0", "0", "4", "0", "6", "7"), - expectedErrorFn = constraintViolationForDuplicateIndices(listOf(0, 2, 3, 5)), - ), - // Several equivalence classes. - InvalidList( - node = ArrayNode.fromStrings("0", "1", "0", "2", "1", "0", "2", "7", "2"), - // Note how the duplicate indices are not ordered. - expectedErrorFn = constraintViolationForDuplicateIndices(listOf(0, 1, 2, 3, 6, 5, 4, 8)), - ), - // The worst case: a fairly large number of elements, all duplicate. - InvalidList( - node = ArrayNode.fromStrings(generateSequence { "69" }.take(69).toList()), - expectedErrorFn = constraintViolationForDuplicateIndices((0..68).toList()), - ), + private val uniqueItemsTraitTestCases = + listOf( + // We only need one test case, since `@uniqueItems` is not parameterizable. + TestCase( + model = generateModel("@uniqueItems"), + validLists = + listOf( + ArrayNode.fromStrings(), + ArrayNode.fromStrings("0", "1"), + ArrayNode.fromStrings("a", "b", "a2"), + ArrayNode.fromStrings((0..69).map(Int::toString).toList()), + ), + invalidLists = + listOf( + // Two elements, both duplicate. + InvalidList( + node = ArrayNode.fromStrings("0", "0"), + expectedErrorFn = constraintViolationForDuplicateIndices(listOf(0, 1)), + ), + // Two duplicate items, one at the beginning, one at the end. + InvalidList( + node = ArrayNode.fromStrings("0", "1", "2", "3", "4", "5", "0"), + expectedErrorFn = constraintViolationForDuplicateIndices(listOf(0, 6)), + ), + // Several duplicate items, all the same. + InvalidList( + node = ArrayNode.fromStrings("0", "1", "0", "0", "4", "0", "6", "7"), + expectedErrorFn = constraintViolationForDuplicateIndices(listOf(0, 2, 3, 5)), + ), + // Several equivalence classes. + InvalidList( + node = ArrayNode.fromStrings("0", "1", "0", "2", "1", "0", "2", "7", "2"), + // Note how the duplicate indices are not ordered. + expectedErrorFn = constraintViolationForDuplicateIndices(listOf(0, 1, 2, 3, 6, 5, 4, 8)), + ), + // The worst case: a fairly large number of elements, all duplicate. + InvalidList( + node = ArrayNode.fromStrings(generateSequence { "69" }.take(69).toList()), + expectedErrorFn = constraintViolationForDuplicateIndices((0..68).toList()), + ), + ), ), - ), - ) + ) override fun provideArguments(context: ExtensionContext?): Stream = (lengthTraitTestCases + uniqueItemsTraitTestCases).map { Arguments.of(it) }.stream() @@ -170,11 +179,12 @@ class ConstrainedCollectionGeneratorTest { val project = TestWorkspace.testProject(codegenContext.symbolProvider) for (shape in listOf(constrainedListShape, constrainedSetShape)) { - val shapeName = when (shape) { - is SetShape -> "set" - is ListShape -> "list" - else -> UNREACHABLE("Shape is either list or set.") - } + val shapeName = + when (shape) { + is SetShape -> "set" + is ListShape -> "list" + else -> UNREACHABLE("Shape is either list or set.") + } project.withModule(ServerRustModule.Model) { render(codegenContext, this, shape) @@ -226,29 +236,31 @@ class ConstrainedCollectionGeneratorTest { } unitTest( name = "${shapeNameIdx}_try_from_fail", - block = writable { - rust( - """ - let $shapeNameIdx = $buildInvalidFnName(); - let constrained_res: Result <$typeName, _> = $shapeNameIdx.clone().try_into(); - """, - ) - - invalidList.expectedErrorFn?.also { expectedErrorFn -> - val expectedErrorWritable = expectedErrorFn( - codegenContext.constraintViolationSymbolProvider.toSymbol(shape), - shapeNameIdx, + block = + writable { + rust( + """ + let $shapeNameIdx = $buildInvalidFnName(); + let constrained_res: Result <$typeName, _> = $shapeNameIdx.clone().try_into(); + """, ) - rust("let err = constrained_res.unwrap_err();") - withBlock("let expected_err = ", ";") { - rustTemplate("#{ExpectedError:W}", "ExpectedError" to expectedErrorWritable) + invalidList.expectedErrorFn?.also { expectedErrorFn -> + val expectedErrorWritable = + expectedErrorFn( + codegenContext.constraintViolationSymbolProvider.toSymbol(shape), + shapeNameIdx, + ) + + rust("let err = constrained_res.unwrap_err();") + withBlock("let expected_err = ", ";") { + rustTemplate("#{ExpectedError:W}", "ExpectedError" to expectedErrorWritable) + } + rust("assert_eq!(err, expected_err);") + } ?: run { + rust("constrained_res.unwrap_err();") } - rust("assert_eq!(err, expected_err);") - } ?: run { - rust("constrained_res.unwrap_err();") - } - }, + }, ) } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt index d475ccbd37e..ae050460a48 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt @@ -31,38 +31,38 @@ import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReac import java.util.stream.Stream class ConstrainedMapGeneratorTest { - data class TestCase(val model: Model, val validMap: ObjectNode, val invalidMap: ObjectNode) class ConstrainedMapGeneratorTestProvider : ArgumentsProvider { - private val testCases = listOf( - // Min and max. - Triple("@length(min: 11, max: 12)", 11, 13), - // Min equal to max. - Triple("@length(min: 11, max: 11)", 11, 12), - // Only min. - Triple("@length(min: 11)", 15, 10), - // Only max. - Triple("@length(max: 11)", 11, 12), - ).map { - val validStringMap = List(it.second) { index -> index.toString() to "value" }.toMap() - val inValidStringMap = List(it.third) { index -> index.toString() to "value" }.toMap() - Triple(it.first, ObjectNode.fromStringMap(validStringMap), ObjectNode.fromStringMap(inValidStringMap)) - }.map { (trait, validMap, invalidMap) -> - TestCase( - """ - namespace test - - $trait - map ConstrainedMap { - key: String, - value: String - } - """.asSmithyModel().let(ShapesReachableFromOperationInputTagger::transform), - validMap, - invalidMap, - ) - } + private val testCases = + listOf( + // Min and max. + Triple("@length(min: 11, max: 12)", 11, 13), + // Min equal to max. + Triple("@length(min: 11, max: 11)", 11, 12), + // Only min. + Triple("@length(min: 11)", 15, 10), + // Only max. + Triple("@length(max: 11)", 11, 12), + ).map { + val validStringMap = List(it.second) { index -> index.toString() to "value" }.toMap() + val inValidStringMap = List(it.third) { index -> index.toString() to "value" }.toMap() + Triple(it.first, ObjectNode.fromStringMap(validStringMap), ObjectNode.fromStringMap(inValidStringMap)) + }.map { (trait, validMap, invalidMap) -> + TestCase( + """ + namespace test + + $trait + map ConstrainedMap { + key: String, + value: String + } + """.asSmithyModel().let(ShapesReachableFromOperationInputTagger::transform), + validMap, + invalidMap, + ) + } override fun provideArguments(context: ExtensionContext?): Stream = testCases.map { Arguments.of(it) }.stream() @@ -129,7 +129,8 @@ class ConstrainedMapGeneratorTest { @Test fun `type should not be constructable without using a constructor`() { - val model = """ + val model = + """ namespace test @length(min: 1, max: 69) @@ -137,7 +138,7 @@ class ConstrainedMapGeneratorTest { key: String, value: String } - """.asSmithyModel().let(ShapesReachableFromOperationInputTagger::transform) + """.asSmithyModel().let(ShapesReachableFromOperationInputTagger::transform) val constrainedMapShape = model.lookup("test#ConstrainedMap") val writer = RustWriter.forModule(ServerRustModule.Model.name) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGeneratorTest.kt index 3a34c7753cc..9e19265e592 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGeneratorTest.kt @@ -26,8 +26,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCode import java.util.stream.Stream class ConstrainedNumberGeneratorTest { - data class TestCaseInputs(val constraintAnnotation: String, val validValue: Int, val invalidValue: Int) + data class TestCase(val model: Model, val validValue: Int, val invalidValue: Int, val shapeName: String) class ConstrainedNumberGeneratorTestProvider : ArgumentsProvider { @@ -129,12 +129,13 @@ class ConstrainedNumberGeneratorTest { @ArgumentsSource(NoStructuralConstructorTestProvider::class) fun `type should not be constructable without using a constructor`(args: Triple) { val (smithyType, shapeName, rustType) = args - val model = """ + val model = + """ namespace test @range(min: -1, max: 5) $smithyType $shapeName - """.asSmithyModel() + """.asSmithyModel() val constrainedShape = model.lookup("test#$shapeName") val codegenContext = serverTestCodegenContext(model) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt index 5e2f828e1b4..bdc89c00782 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt @@ -32,42 +32,44 @@ class ConstrainedStringGeneratorTest { data class TestCase(val model: Model, val validString: String, val invalidString: String) class ConstrainedStringGeneratorTestProvider : ArgumentsProvider { - private val testCases = listOf( - // Min and max. - Triple("@length(min: 11, max: 12)", "validString", "invalidString"), - // Min equal to max. - Triple("@length(min: 11, max: 11)", "validString", "invalidString"), - // Only min. - Triple("@length(min: 11)", "validString", ""), - // Only max. - Triple("@length(max: 11)", "", "invalidString"), - // Count Unicode scalar values, not `.len()`. - Triple( - "@length(min: 3, max: 3)", - "👍👍👍", // These three emojis are three Unicode scalar values. - "👍👍👍👍", - ), - Triple("@pattern(\"^[a-z]+$\")", "valid", "123 invalid"), - Triple( - """ - @length(min: 3, max: 10) - @pattern("^a string$") - """, - "a string", "an invalid string", - ), - Triple("@pattern(\"123\")", "some pattern 123 in the middle", "no pattern at all"), - ).map { - TestCase( - """ - namespace test - - ${it.first} - string ConstrainedString - """.asSmithyModel(), - it.second, - it.third, - ) - } + private val testCases = + listOf( + // Min and max. + Triple("@length(min: 11, max: 12)", "validString", "invalidString"), + // Min equal to max. + Triple("@length(min: 11, max: 11)", "validString", "invalidString"), + // Only min. + Triple("@length(min: 11)", "validString", ""), + // Only max. + Triple("@length(max: 11)", "", "invalidString"), + // Count Unicode scalar values, not `.len()`. + Triple( + "@length(min: 3, max: 3)", + // These three emojis are three Unicode scalar values. + "👍👍👍", + "👍👍👍👍", + ), + Triple("@pattern(\"^[a-z]+$\")", "valid", "123 invalid"), + Triple( + """ + @length(min: 3, max: 10) + @pattern("^a string$") + """, + "a string", "an invalid string", + ), + Triple("@pattern(\"123\")", "some pattern 123 in the middle", "no pattern at all"), + ).map { + TestCase( + """ + namespace test + + ${it.first} + string ConstrainedString + """.asSmithyModel(), + it.second, + it.third, + ) + } override fun provideArguments(context: ExtensionContext?): Stream = testCases.map { Arguments.of(it) }.stream() @@ -132,12 +134,13 @@ class ConstrainedStringGeneratorTest { @Test fun `type should not be constructable without using a constructor`() { - val model = """ + val model = + """ namespace test @length(min: 1, max: 69) string ConstrainedString - """.asSmithyModel() + """.asSmithyModel() val constrainedStringShape = model.lookup("test#ConstrainedString") val codegenContext = serverTestCodegenContext(model) @@ -158,7 +161,8 @@ class ConstrainedStringGeneratorTest { @Test fun `Display implementation`() { - val model = """ + val model = + """ namespace test @length(min: 1, max: 69) @@ -167,7 +171,7 @@ class ConstrainedStringGeneratorTest { @sensitive @length(min: 1, max: 78) string SensitiveConstrainedString - """.asSmithyModel() + """.asSmithyModel() val constrainedStringShape = model.lookup("test#ConstrainedString") val sensitiveConstrainedStringShape = model.lookup("test#SensitiveConstrainedString") @@ -216,12 +220,13 @@ class ConstrainedStringGeneratorTest { @Test fun `A regex that is accepted by Smithy but not by the regex crate causes tests to fail`() { - val model = """ + val model = + """ namespace test @pattern("import (?!static).+") string PatternStringWithLookahead - """.asSmithyModel() + """.asSmithyModel() val constrainedStringShape = model.lookup("test#PatternStringWithLookahead") val codegenContext = serverTestCodegenContext(model) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolationsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolationsTest.kt index f36976e715e..329f75699ea 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolationsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolationsTest.kt @@ -10,14 +10,14 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest class ServerBuilderConstraintViolationsTest { - // This test exists not to regress on [this](https://github.com/smithy-lang/smithy-rs/issues/2343) issue. // We generated constraint violation variants, pointing to a structure (StructWithInnerDefault below), // but the structure was not constrained, because the structure's member have a default value // and default values are validated at generation time from the model. @Test fun `it should not generate constraint violations for members with a default value`() { - val model = """ + val model = + """ namespace test use aws.protocols#restJson1 @@ -43,7 +43,7 @@ class ServerBuilderConstraintViolationsTest { @default(false) inner: PrimitiveBoolean } - """.asSmithyModel(smithyVersion = "2") + """.asSmithyModel(smithyVersion = "2") serverIntegrationTest(model) } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt index 966426f948d..e0a20542851 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt @@ -43,52 +43,54 @@ import java.util.stream.Stream @TestInstance(TestInstance.Lifecycle.PER_CLASS) class ServerBuilderDefaultValuesTest { // When defaults are used, the model will be generated with these in the `@default` trait. - private val defaultValues = mapOf( - "Boolean" to "true", - "String" to "foo".dq(), - "Byte" to "5", - "Short" to "55", - "Integer" to "555", - "Long" to "5555", - "Float" to "0.5", - "Double" to "0.55", - "Timestamp" to "1985-04-12T23:20:50.52Z".dq(), - // "BigInteger" to "55555", "BigDecimal" to "0.555", // TODO(https://github.com/smithy-lang/smithy-rs/issues/312) - "StringList" to "[]", - "IntegerMap" to "{}", - "Language" to "en".dq(), - "DocumentBoolean" to "true", - "DocumentString" to "foo".dq(), - "DocumentNumberPosInt" to "100", - "DocumentNumberNegInt" to "-100", - "DocumentNumberFloat" to "0.1", - "DocumentList" to "[]", - "DocumentMap" to "{}", - ) + private val defaultValues = + mapOf( + "Boolean" to "true", + "String" to "foo".dq(), + "Byte" to "5", + "Short" to "55", + "Integer" to "555", + "Long" to "5555", + "Float" to "0.5", + "Double" to "0.55", + "Timestamp" to "1985-04-12T23:20:50.52Z".dq(), + // "BigInteger" to "55555", "BigDecimal" to "0.555", // TODO(https://github.com/smithy-lang/smithy-rs/issues/312) + "StringList" to "[]", + "IntegerMap" to "{}", + "Language" to "en".dq(), + "DocumentBoolean" to "true", + "DocumentString" to "foo".dq(), + "DocumentNumberPosInt" to "100", + "DocumentNumberNegInt" to "-100", + "DocumentNumberFloat" to "0.1", + "DocumentList" to "[]", + "DocumentMap" to "{}", + ) // When the test applies values to validate we honor custom values, use these (different) values. - private val customValues = mapOf( - "Boolean" to "false", - "String" to "bar".dq(), - "Byte" to "6", - "Short" to "66", - "Integer" to "666", - "Long" to "6666", - "Float" to "0.6", - "Double" to "0.66", - "Timestamp" to "2022-11-25T17:30:50.00Z".dq(), - // "BigInteger" to "55555", "BigDecimal" to "0.555", // TODO(https://github.com/smithy-lang/smithy-rs/issues/312) - "StringList" to "[]", - "IntegerMap" to "{}", - "Language" to "fr".dq(), - "DocumentBoolean" to "false", - "DocumentString" to "bar".dq(), - "DocumentNumberPosInt" to "1000", - "DocumentNumberNegInt" to "-1000", - "DocumentNumberFloat" to "0.01", - "DocumentList" to "[]", - "DocumentMap" to "{}", - ) + private val customValues = + mapOf( + "Boolean" to "false", + "String" to "bar".dq(), + "Byte" to "6", + "Short" to "66", + "Integer" to "666", + "Long" to "6666", + "Float" to "0.6", + "Double" to "0.66", + "Timestamp" to "2022-11-25T17:30:50.00Z".dq(), + // "BigInteger" to "55555", "BigDecimal" to "0.555", // TODO(https://github.com/smithy-lang/smithy-rs/issues/312) + "StringList" to "[]", + "IntegerMap" to "{}", + "Language" to "fr".dq(), + "DocumentBoolean" to "false", + "DocumentString" to "bar".dq(), + "DocumentNumberPosInt" to "1000", + "DocumentNumberNegInt" to "-1000", + "DocumentNumberFloat" to "0.01", + "DocumentList" to "[]", + "DocumentMap" to "{}", + ) @ParameterizedTest(name = "(#{index}) Server builders and default values. Params = requiredTrait: {0}, nullDefault: {1}, applyDefaultValues: {2}, builderGeneratorKind: {3}, assertValues: {4}") @MethodSource("testParameters") @@ -116,34 +118,37 @@ class ServerBuilderDefaultValuesTest { } val rustValues = setupRustValuesForTest(assertValues) - val setters = if (applyDefaultValues) { - structSetters(rustValues, nullDefault && !requiredTrait) - } else { - writable { } - } + val setters = + if (applyDefaultValues) { + structSetters(rustValues, nullDefault && !requiredTrait) + } else { + writable { } + } val unwrapBuilder = if (nullDefault && requiredTrait && applyDefaultValues) ".unwrap()" else "" unitTest( name = "generates_default_required_values", - block = writable { - rustTemplate( - """ - let my_struct = MyStruct::builder() - #{Setters:W} - .build() - $unwrapBuilder; - - #{Assertions:W} - """, - "Assertions" to assertions( - rustValues, - applyDefaultValues, - nullDefault, - requiredTrait, - applyDefaultValues, - ), - "Setters" to setters, - ) - }, + block = + writable { + rustTemplate( + """ + let my_struct = MyStruct::builder() + #{Setters:W} + .build() + $unwrapBuilder; + + #{Assertions:W} + """, + "Assertions" to + assertions( + rustValues, + applyDefaultValues, + nullDefault, + requiredTrait, + applyDefaultValues, + ), + "Setters" to setters, + ) + }, ) } @@ -153,39 +158,49 @@ class ServerBuilderDefaultValuesTest { } private fun setupRustValuesForTest(valuesMap: Map): Map { - return valuesMap + mapOf( - "Byte" to "${valuesMap["Byte"]}i8", - "Short" to "${valuesMap["Short"]}i16", - "Integer" to "${valuesMap["Integer"]}i32", - "Long" to "${valuesMap["Long"]}i64", - "Float" to "${valuesMap["Float"]}f32", - "Double" to "${valuesMap["Double"]}f64", - "Language" to "crate::model::Language::${valuesMap["Language"]!!.replace(""""""", "").toPascalCase()}", - "Timestamp" to """aws_smithy_types::DateTime::from_str(${valuesMap["Timestamp"]}, aws_smithy_types::date_time::Format::DateTime).unwrap()""", - // These must be empty - "StringList" to "Vec::::new()", - "IntegerMap" to "std::collections::HashMap::::new()", - "DocumentList" to "Vec::::new()", - "DocumentMap" to "std::collections::HashMap::::new()", - ) + valuesMap - .filter { it.value?.startsWith("Document") ?: false } - .map { it.key to "${it.value}.into()" } + return valuesMap + + mapOf( + "Byte" to "${valuesMap["Byte"]}i8", + "Short" to "${valuesMap["Short"]}i16", + "Integer" to "${valuesMap["Integer"]}i32", + "Long" to "${valuesMap["Long"]}i64", + "Float" to "${valuesMap["Float"]}f32", + "Double" to "${valuesMap["Double"]}f64", + "Language" to "crate::model::Language::${valuesMap["Language"]!!.replace(""""""", "").toPascalCase()}", + "Timestamp" to """aws_smithy_types::DateTime::from_str(${valuesMap["Timestamp"]}, aws_smithy_types::date_time::Format::DateTime).unwrap()""", + // These must be empty + "StringList" to "Vec::::new()", + "IntegerMap" to "std::collections::HashMap::::new()", + "DocumentList" to "Vec::::new()", + "DocumentMap" to "std::collections::HashMap::::new()", + ) + + valuesMap + .filter { it.value?.startsWith("Document") ?: false } + .map { it.key to "${it.value}.into()" } } - private fun writeServerBuilderGeneratorWithoutPublicConstrainedTypes(rustCrate: RustCrate, writer: RustWriter, model: Model, symbolProvider: RustSymbolProvider) { + private fun writeServerBuilderGeneratorWithoutPublicConstrainedTypes( + rustCrate: RustCrate, + writer: RustWriter, + model: Model, + symbolProvider: RustSymbolProvider, + ) { val struct = model.lookup("com.test#MyStruct") - val codegenContext = serverTestCodegenContext( - model, - settings = serverTestRustSettings( - codegenConfig = ServerCodegenConfig(publicConstrainedTypes = false), - ), - ) - val builderGenerator = ServerBuilderGeneratorWithoutPublicConstrainedTypes( - codegenContext, - struct, - SmithyValidationExceptionConversionGenerator(codegenContext), - ServerRestJsonProtocol(codegenContext), - ) + val codegenContext = + serverTestCodegenContext( + model, + settings = + serverTestRustSettings( + codegenConfig = ServerCodegenConfig(publicConstrainedTypes = false), + ), + ) + val builderGenerator = + ServerBuilderGeneratorWithoutPublicConstrainedTypes( + codegenContext, + struct, + SmithyValidationExceptionConversionGenerator(codegenContext), + ServerRestJsonProtocol(codegenContext), + ) writer.implBlock(symbolProvider.toSymbol(struct)) { builderGenerator.renderConvenienceMethod(writer) @@ -200,15 +215,21 @@ class ServerBuilderDefaultValuesTest { StructureGenerator(model, symbolProvider, writer, struct, emptyList(), codegenContext.structSettings()).render() } - private fun writeServerBuilderGenerator(rustCrate: RustCrate, writer: RustWriter, model: Model, symbolProvider: RustSymbolProvider) { + private fun writeServerBuilderGenerator( + rustCrate: RustCrate, + writer: RustWriter, + model: Model, + symbolProvider: RustSymbolProvider, + ) { val struct = model.lookup("com.test#MyStruct") val codegenContext = serverTestCodegenContext(model) - val builderGenerator = ServerBuilderGenerator( - codegenContext, - struct, - SmithyValidationExceptionConversionGenerator(codegenContext), - ServerRestJsonProtocol(codegenContext), - ) + val builderGenerator = + ServerBuilderGenerator( + codegenContext, + struct, + SmithyValidationExceptionConversionGenerator(codegenContext), + ServerRestJsonProtocol(codegenContext), + ) writer.implBlock(symbolProvider.toSymbol(struct)) { builderGenerator.renderConvenienceMethod(writer) @@ -223,7 +244,10 @@ class ServerBuilderDefaultValuesTest { StructureGenerator(model, symbolProvider, writer, struct, emptyList(), codegenContext.structSettings()).render() } - private fun structSetters(values: Map, optional: Boolean) = writable { + private fun structSetters( + values: Map, + optional: Boolean, + ) = writable { for ((key, value) in values) { withBlock(".${key.toSnakeCase()}(", ")") { conditionalBlock("Some(", ")", optional) { @@ -259,27 +283,30 @@ class ServerBuilderDefaultValuesTest { if (!hasSetValues) { rust("assert!($member.is_none());") } else { - val actual = writable { - rust(member) - if (!requiredTrait && !(hasDefaults && !hasNullValues)) { - rust(".unwrap()") + val actual = + writable { + rust(member) + if (!requiredTrait && !(hasDefaults && !hasNullValues)) { + rust(".unwrap()") + } } - } - val expected = writable { - val expected = if (key == "DocumentNull") { - "aws_smithy_types::Document::Null" - } else if (key == "DocumentString") { - "String::from($value).into()" - } else if (key.startsWith("DocumentNumber")) { - val type = key.replace("DocumentNumber", "") - "aws_smithy_types::Document::Number(aws_smithy_types::Number::$type($value))" - } else if (key.startsWith("Document")) { - "$value.into()" - } else { - "$value" + val expected = + writable { + val expected = + if (key == "DocumentNull") { + "aws_smithy_types::Document::Null" + } else if (key == "DocumentString") { + "String::from($value).into()" + } else if (key.startsWith("DocumentNumber")) { + val type = key.replace("DocumentNumber", "") + "aws_smithy_types::Document::Number(aws_smithy_types::Number::$type($value))" + } else if (key.startsWith("Document")) { + "$value.into()" + } else { + "$value" + } + rust(expected) } - rust(expected) - } rustTemplate("assert_eq!(#{Actual:W}, #{Expected:W});", "Actual" to actual, "Expected" to expected) } } @@ -293,19 +320,21 @@ class ServerBuilderDefaultValuesTest { ): Model { val requiredOrNot = if (requiredTrait) "@required" else "" - val members = values.entries.joinToString(", ") { - val value = if (applyDefaultValues) { - "= ${it.value}" - } else if (nullDefault) { - "= null" - } else { - "" + val members = + values.entries.joinToString(", ") { + val value = + if (applyDefaultValues) { + "= ${it.value}" + } else if (nullDefault) { + "= null" + } else { + "" + } + """ + $requiredOrNot + ${it.key.toPascalCase()}: ${it.key} $value + """ } - """ - $requiredOrNot - ${it.key.toPascalCase()}: ${it.key} $value - """ - } val model = """ namespace com.test @@ -352,26 +381,22 @@ class ServerBuilderDefaultValuesTest { } private fun testParameters(): Stream { - val builderGeneratorKindList = listOf( - BuilderGeneratorKind.SERVER_BUILDER_GENERATOR, - BuilderGeneratorKind.SERVER_BUILDER_GENERATOR_WITHOUT_PUBLIC_CONSTRAINED_TYPES, - ) + val builderGeneratorKindList = + listOf( + BuilderGeneratorKind.SERVER_BUILDER_GENERATOR, + BuilderGeneratorKind.SERVER_BUILDER_GENERATOR_WITHOUT_PUBLIC_CONSTRAINED_TYPES, + ) return Stream.of( TestConfig(defaultValues, requiredTrait = false, nullDefault = true, applyDefaultValues = true), TestConfig(defaultValues, requiredTrait = false, nullDefault = true, applyDefaultValues = false), - TestConfig(customValues, requiredTrait = false, nullDefault = true, applyDefaultValues = true), TestConfig(customValues, requiredTrait = false, nullDefault = true, applyDefaultValues = false), - TestConfig(defaultValues, requiredTrait = true, nullDefault = true, applyDefaultValues = true), TestConfig(customValues, requiredTrait = true, nullDefault = true, applyDefaultValues = true), - TestConfig(defaultValues, requiredTrait = false, nullDefault = false, applyDefaultValues = true), TestConfig(defaultValues, requiredTrait = false, nullDefault = false, applyDefaultValues = false), - TestConfig(customValues, requiredTrait = false, nullDefault = false, applyDefaultValues = true), TestConfig(customValues, requiredTrait = false, nullDefault = false, applyDefaultValues = false), - TestConfig(defaultValues, requiredTrait = true, nullDefault = false, applyDefaultValues = true), TestConfig(customValues, requiredTrait = true, nullDefault = false, applyDefaultValues = true), ).flatMap { (assertValues, requiredTrait, nullDefault, applyDefaultValues) -> diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorTest.kt index 27e45c373c1..ef20b24fb92 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorTest.kt @@ -24,7 +24,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCode class ServerBuilderGeneratorTest { @Test fun `it respects the sensitive trait in Debug impl`() { - val model = """ + val model = + """ namespace test @sensitive string SecretKey @@ -37,7 +38,7 @@ class ServerBuilderGeneratorTest { password: Password, secretKey: SecretKey } - """.asSmithyModel() + """.asSmithyModel() val codegenContext = serverTestCodegenContext(model) val project = TestWorkspace.testProject() @@ -46,12 +47,13 @@ class ServerBuilderGeneratorTest { val shape = model.lookup("test#Credentials") StructureGenerator(model, codegenContext.symbolProvider, writer, shape, emptyList(), codegenContext.structSettings()).render() - val builderGenerator = ServerBuilderGenerator( - codegenContext, - shape, - SmithyValidationExceptionConversionGenerator(codegenContext), - ServerRestJsonProtocol(codegenContext), - ) + val builderGenerator = + ServerBuilderGenerator( + codegenContext, + shape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ServerRestJsonProtocol(codegenContext), + ) builderGenerator.render(project, writer) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt index bffb82bb408..32361060328 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt @@ -16,7 +16,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyVa import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext class ServerEnumGeneratorTest { - private val model = """ + private val model = + """ namespace test @enum([ { @@ -33,7 +34,7 @@ class ServerEnumGeneratorTest { }, ]) string InstanceType - """.asSmithyModel() + """.asSmithyModel() private val codegenContext = serverTestCodegenContext(model) private val writer = RustWriter.forModule("model") diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGeneratorTest.kt index a90acec6e48..0cddbd246a3 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGeneratorTest.kt @@ -22,14 +22,16 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider class ServerHttpSensitivityGeneratorTest { - private val codegenScope = arrayOf( - "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(TestRuntimeConfig).toType(), - "Http" to CargoDependency.Http.toType(), - ) + private val codegenScope = + arrayOf( + "SmithyHttpServer" to ServerCargoDependency.smithyHttpServer(TestRuntimeConfig).toType(), + "Http" to CargoDependency.Http.toType(), + ) @Test fun `query closure`() { - val model = """ + val model = + """ namespace test operation Secret { @@ -48,7 +50,7 @@ class ServerHttpSensitivityGeneratorTest { @httpQuery("query_b") queryB: SensitiveString } - """.asSmithyModel() + """.asSmithyModel() val operation = model.operationShapes.toList()[0] val generator = ServerHttpSensitivityGenerator(model, operation, TestRuntimeConfig) @@ -77,7 +79,8 @@ class ServerHttpSensitivityGeneratorTest { @Test fun `query params closure`() { - val model = """ + val model = + """ namespace test operation Secret { @@ -95,7 +98,7 @@ class ServerHttpSensitivityGeneratorTest { @httpQueryParams() params: StringMap, } - """.asSmithyModel() + """.asSmithyModel() val operation = model.operationShapes.toList()[0] val generator = ServerHttpSensitivityGenerator(model, operation, TestRuntimeConfig) @@ -123,7 +126,8 @@ class ServerHttpSensitivityGeneratorTest { @Test fun `query params key closure`() { - val model = """ + val model = + """ namespace test operation Secret { @@ -144,7 +148,7 @@ class ServerHttpSensitivityGeneratorTest { value: String } - """.asSmithyModel() + """.asSmithyModel() val operation = model.operationShapes.toList()[0] val generator = ServerHttpSensitivityGenerator(model, operation, TestRuntimeConfig) @@ -171,7 +175,8 @@ class ServerHttpSensitivityGeneratorTest { @Test fun `query params value closure`() { - val model = """ + val model = + """ namespace test operation Secret { @@ -192,7 +197,7 @@ class ServerHttpSensitivityGeneratorTest { value: SensitiveValue } - """.asSmithyModel() + """.asSmithyModel() val operation = model.operationShapes.toList()[0] val generator = ServerHttpSensitivityGenerator(model, operation, TestRuntimeConfig) @@ -219,7 +224,8 @@ class ServerHttpSensitivityGeneratorTest { @Test fun `query params none`() { - val model = """ + val model = + """ namespace test operation Secret { @@ -237,7 +243,7 @@ class ServerHttpSensitivityGeneratorTest { value: String } - """.asSmithyModel() + """.asSmithyModel() val operation = model.operationShapes.toList()[0] val generator = ServerHttpSensitivityGenerator(model, operation, TestRuntimeConfig) @@ -250,7 +256,8 @@ class ServerHttpSensitivityGeneratorTest { @Test fun `header closure`() { - val model = """ + val model = + """ namespace test operation Secret { @@ -269,7 +276,7 @@ class ServerHttpSensitivityGeneratorTest { @httpHeader("header-b") headerB: SensitiveString } - """.asSmithyModel() + """.asSmithyModel() val operation = model.operationShapes.toList()[0] val generator = ServerHttpSensitivityGenerator(model, operation, TestRuntimeConfig) @@ -299,7 +306,8 @@ class ServerHttpSensitivityGeneratorTest { @Test fun `prefix header closure`() { - val model = """ + val model = + """ namespace test operation Secret { @@ -318,7 +326,7 @@ class ServerHttpSensitivityGeneratorTest { value: String } - """.asSmithyModel() + """.asSmithyModel() val operation = model.operationShapes.toList()[0] val generator = ServerHttpSensitivityGenerator(model, operation, TestRuntimeConfig) @@ -349,7 +357,8 @@ class ServerHttpSensitivityGeneratorTest { @Test fun `prefix header none`() { - val model = """ + val model = + """ namespace test operation Secret { @@ -367,7 +376,7 @@ class ServerHttpSensitivityGeneratorTest { value: String } - """.asSmithyModel() + """.asSmithyModel() val operation = model.operationShapes.toList()[0] val generator = ServerHttpSensitivityGenerator(model, operation, TestRuntimeConfig) @@ -379,7 +388,8 @@ class ServerHttpSensitivityGeneratorTest { @Test fun `prefix headers key closure`() { - val model = """ + val model = + """ namespace test operation Secret { @@ -399,7 +409,7 @@ class ServerHttpSensitivityGeneratorTest { value: String } - """.asSmithyModel() + """.asSmithyModel() val operation = model.operationShapes.toList()[0] val generator = ServerHttpSensitivityGenerator(model, operation, TestRuntimeConfig) @@ -432,7 +442,8 @@ class ServerHttpSensitivityGeneratorTest { @Test fun `prefix headers value closure`() { - val model = """ + val model = + """ namespace test operation Secret { @@ -452,7 +463,7 @@ class ServerHttpSensitivityGeneratorTest { key: String, value: SensitiveValue } - """.asSmithyModel() + """.asSmithyModel() val operation = model.operationShapes.toList()[0] val generator = ServerHttpSensitivityGenerator(model, operation, TestRuntimeConfig) @@ -486,7 +497,8 @@ class ServerHttpSensitivityGeneratorTest { @Test fun `uri closure`() { - val model = """ + val model = + """ namespace test @http(method: "GET", uri: "/secret/{labelA}/{labelB}") @@ -505,7 +517,7 @@ class ServerHttpSensitivityGeneratorTest { @httpLabel labelB: SensitiveString, } - """.asSmithyModel() + """.asSmithyModel() val operation = model.operationShapes.toList()[0] val generator = ServerHttpSensitivityGenerator(model, operation, TestRuntimeConfig) @@ -535,7 +547,8 @@ class ServerHttpSensitivityGeneratorTest { @Test fun `uri greedy`() { - val model = """ + val model = + """ namespace test @http(method: "GET", uri: "/secret/{labelA}/{labelB+}/labelC") @@ -554,7 +567,7 @@ class ServerHttpSensitivityGeneratorTest { @httpLabel labelB: SensitiveString, } - """.asSmithyModel() + """.asSmithyModel() val operation = model.operationShapes.toList()[0] val generator = ServerHttpSensitivityGenerator(model, operation, TestRuntimeConfig) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt index 5a6f9791518..1a1f9f01eab 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt @@ -39,7 +39,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCode class ServerInstantiatorTest { // This model started off from the one in `InstantiatorTest.kt` from `codegen-core`. - private val model = """ + private val model = + """ namespace com.test use smithy.framework#ValidationException @@ -136,7 +137,9 @@ class ServerInstantiatorTest { }, ]) string NamedEnum - """.asSmithyModel().let { RecursiveShapeBoxer().transform(it) } + """.asSmithyModel().let { + RecursiveShapeBoxer().transform(it) + } private val codegenContext = serverTestCodegenContext(model) private val symbolProvider = codegenContext.symbolProvider @@ -269,17 +272,19 @@ class ServerInstantiatorTest { UnionGenerator(model, symbolProvider, this, nestedUnion).render() unitTest("writable_for_shapes") { - val sut = ServerInstantiator( - codegenContext, - customWritable = object : Instantiator.CustomWritable { - override fun generate(shape: Shape): Writable? = - if (model.lookup("com.test#NestedStruct\$num") == shape) { - writable("40 + 2") - } else { - null - } - }, - ) + val sut = + ServerInstantiator( + codegenContext, + customWritable = + object : Instantiator.CustomWritable { + override fun generate(shape: Shape): Writable? = + if (model.lookup("com.test#NestedStruct\$num") == shape) { + writable("40 + 2") + } else { + null + } + }, + ) val data = Node.parse("""{ "str": "hello", "num": 1 }""") withBlock("let result = ", ";") { sut.render(this, model.lookup("com.test#NestedStruct"), data as ObjectNode) @@ -294,39 +299,43 @@ class ServerInstantiatorTest { unitTest("writable_for_nested_inner_members") { val map = model.lookup("com.test#Inner\$map") - val sut = ServerInstantiator( - codegenContext, - customWritable = object : Instantiator.CustomWritable { - private var n: Int = 0 - override fun generate(shape: Shape): Writable? = - if (shape != map) { - null - } else if (n != 2) { - n += 1 - null - } else { - n += 1 - writable("None") - } - }, - ) - val data = Node.parse( - """ - { - "map": { - "k1": { - "map": { - "k2": { - "map": { - "never": {} + val sut = + ServerInstantiator( + codegenContext, + customWritable = + object : Instantiator.CustomWritable { + private var n: Int = 0 + + override fun generate(shape: Shape): Writable? = + if (shape != map) { + null + } else if (n != 2) { + n += 1 + null + } else { + n += 1 + writable("None") + } + }, + ) + val data = + Node.parse( + """ + { + "map": { + "k1": { + "map": { + "k2": { + "map": { + "never": {} + } } } } } } - } - """, - ) + """, + ) withBlock("let result = ", ";") { sut.render(this, inner, data as ObjectNode) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGeneratorTest.kt index a73466277a8..8f0cc80dcb5 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGeneratorTest.kt @@ -19,7 +19,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWi import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider class ServerOperationErrorGeneratorTest { - private val baseModel = """ + private val baseModel = + """ namespace error use aws.protocols#restJson1 @@ -52,7 +53,7 @@ class ServerOperationErrorGeneratorTest { @error("server") @deprecated structure Deprecated { } - """.asSmithyModel() + """.asSmithyModel() private val model = OperationNormalizer.transform(baseModel) private val symbolProvider = serverTestSymbolProvider(model) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt index c2c568b291c..63f1f7385d2 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt @@ -24,62 +24,69 @@ internal class ServiceConfigGeneratorTest { fun `it should inject an aws_auth method that configures an HTTP plugin and a model plugin`() { val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel() - val decorator = object : ServerCodegenDecorator { - override val name: String - get() = "AWSAuth pre-applied middleware decorator" - override val order: Byte - get() = -69 - - override fun configMethods(codegenContext: ServerCodegenContext): List { - val smithyHttpServer = ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType() - val codegenScope = arrayOf( - "SmithyHttpServer" to smithyHttpServer, - ) - return listOf( - ConfigMethod( - name = "aws_auth", - docs = "Docs", - params = listOf( - Binding("auth_spec", RuntimeType.String), - Binding("authorizer", RuntimeType.U64), - ), - errorType = RuntimeType.std.resolve("io::Error"), - initializer = Initializer( - code = writable { - rustTemplate( - """ - if authorizer != 69 { - return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1")); - } - - if auth_spec.len() != 69 { - return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 2")); - } - let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; - let authz_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; - """, - *codegenScope, - ) - }, - layerBindings = emptyList(), - httpPluginBindings = listOf( - Binding( - "authn_plugin", - smithyHttpServer.resolve("plugin::IdentityPlugin"), + val decorator = + object : ServerCodegenDecorator { + override val name: String + get() = "AWSAuth pre-applied middleware decorator" + override val order: Byte + get() = -69 + + override fun configMethods(codegenContext: ServerCodegenContext): List { + val smithyHttpServer = ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType() + val codegenScope = + arrayOf( + "SmithyHttpServer" to smithyHttpServer, + ) + return listOf( + ConfigMethod( + name = "aws_auth", + docs = "Docs", + params = + listOf( + Binding("auth_spec", RuntimeType.String), + Binding("authorizer", RuntimeType.U64), ), - ), - modelPluginBindings = listOf( - Binding( - "authz_plugin", - smithyHttpServer.resolve("plugin::IdentityPlugin"), + errorType = RuntimeType.std.resolve("io::Error"), + initializer = + Initializer( + code = + writable { + rustTemplate( + """ + if authorizer != 69 { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1")); + } + + if auth_spec.len() != 69 { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 2")); + } + let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; + let authz_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; + """, + *codegenScope, + ) + }, + layerBindings = emptyList(), + httpPluginBindings = + listOf( + Binding( + "authn_plugin", + smithyHttpServer.resolve("plugin::IdentityPlugin"), + ), + ), + modelPluginBindings = + listOf( + Binding( + "authz_plugin", + smithyHttpServer.resolve("plugin::IdentityPlugin"), + ), + ), ), - ), + isRequired = true, ), - isRequired = true, - ), - ) + ) + } } - } serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, rustCrate -> rustCrate.testModule { @@ -150,47 +157,52 @@ internal class ServiceConfigGeneratorTest { fun `it should inject an method that applies three non-required layers`() { val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel() - val decorator = object : ServerCodegenDecorator { - override val name: String - get() = "ApplyThreeNonRequiredLayers" - override val order: Byte - get() = 69 - - override fun configMethods(codegenContext: ServerCodegenContext): List { - val identityLayer = RuntimeType.Tower.resolve("layer::util::Identity") - val codegenScope = arrayOf( - "Identity" to identityLayer, - ) - return listOf( - ConfigMethod( - name = "three_non_required_layers", - docs = "Docs", - params = emptyList(), - errorType = null, - initializer = Initializer( - code = writable { - rustTemplate( - """ - let layer1 = #{Identity}::new(); - let layer2 = #{Identity}::new(); - let layer3 = #{Identity}::new(); - """, - *codegenScope, - ) - }, - layerBindings = listOf( - Binding("layer1", identityLayer), - Binding("layer2", identityLayer), - Binding("layer3", identityLayer), - ), - httpPluginBindings = emptyList(), - modelPluginBindings = emptyList(), + val decorator = + object : ServerCodegenDecorator { + override val name: String + get() = "ApplyThreeNonRequiredLayers" + override val order: Byte + get() = 69 + + override fun configMethods(codegenContext: ServerCodegenContext): List { + val identityLayer = RuntimeType.Tower.resolve("layer::util::Identity") + val codegenScope = + arrayOf( + "Identity" to identityLayer, + ) + return listOf( + ConfigMethod( + name = "three_non_required_layers", + docs = "Docs", + params = emptyList(), + errorType = null, + initializer = + Initializer( + code = + writable { + rustTemplate( + """ + let layer1 = #{Identity}::new(); + let layer2 = #{Identity}::new(); + let layer3 = #{Identity}::new(); + """, + *codegenScope, + ) + }, + layerBindings = + listOf( + Binding("layer1", identityLayer), + Binding("layer2", identityLayer), + Binding("layer3", identityLayer), + ), + httpPluginBindings = emptyList(), + modelPluginBindings = emptyList(), + ), + isRequired = false, ), - isRequired = false, - ), - ) + ) + } } - } serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, rustCrate -> rustCrate.testModule { diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxerTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxerTest.kt index 622d3806039..3fe7cc75db7 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxerTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxerTest.kt @@ -20,9 +20,10 @@ internal class RecursiveConstraintViolationBoxerTest { fun `recursive constraint violation boxer test`(testCase: RecursiveConstraintViolationsTest.TestCase) { val transformed = RecursiveConstraintViolationBoxer.transform(testCase.model) - val shapesWithConstraintViolationRustBoxTrait = transformed.shapes().filter { - it.hasTrait() - }.toList() + val shapesWithConstraintViolationRustBoxTrait = + transformed.shapes().filter { + it.hasTrait() + }.toList() // Only the provided member shape should have the trait attached. shapesWithConstraintViolationRustBoxTrait shouldBe diff --git a/codegen-server/typescript/build.gradle.kts b/codegen-server/typescript/build.gradle.kts index 81f543c2220..39ca407a2d6 100644 --- a/codegen-server/typescript/build.gradle.kts +++ b/codegen-server/typescript/build.gradle.kts @@ -29,7 +29,7 @@ dependencies { implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") } -tasks.compileKotlin { kotlinOptions.jvmTarget = "1.8" } +tasks.compileKotlin { kotlinOptions.jvmTarget = "11" } // Reusable license copySpec val licenseSpec = copySpec { @@ -60,7 +60,7 @@ if (isTestingEnabled.toBoolean()) { testImplementation("io.kotest:kotest-assertions-core-jvm:$kotestVersion") } - tasks.compileTestKotlin { kotlinOptions.jvmTarget = "1.8" } + tasks.compileTestKotlin { kotlinOptions.jvmTarget = "11" } tasks.test { useJUnitPlatform() diff --git a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/RustServerCodegenTsPlugin.kt b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/RustServerCodegenTsPlugin.kt index 10aafec81bb..5cbf3732f9c 100644 --- a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/RustServerCodegenTsPlugin.kt +++ b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/RustServerCodegenTsPlugin.kt @@ -78,28 +78,27 @@ class RustServerCodegenTsPlugin : SmithyBuildPlugin { constrainedTypes: Boolean = true, includeConstrainedShapeProvider: Boolean = true, codegenDecorator: ServerCodegenDecorator, - ) = - TsServerSymbolVisitor(settings, model, serviceShape = serviceShape, config = rustSymbolProviderConfig) - // Generate public constrained types for directly constrained shapes. - // In the Typescript server project, this is only done to generate constrained types for simple shapes (e.g. - // a `string` shape with the `length` trait), but these always remain `pub(crate)`. - .let { - if (includeConstrainedShapeProvider) ConstrainedShapeSymbolProvider(it, serviceShape, constrainedTypes) else it - } - // Generate different types for EventStream shapes (e.g. transcribe streaming) - .let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.SERVER) } - // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes - .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf()) } - // Constrained shapes generate newtypes that need the same derives we place on types generated from aggregate shapes. - .let { ConstrainedShapeSymbolMetadataProvider(it, constrainedTypes) } - // Streaming shapes need different derives (e.g. they cannot derive Eq) - .let { TsStreamingShapeMetadataProvider(it) } - // Derive `Eq` and `Hash` if possible. - .let { DeriveEqAndHashSymbolMetadataProvider(it) } - // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot - // be the name of an operation input - .let { RustReservedWordSymbolProvider(it, ServerReservedWords) } - // Allows decorators to inject a custom symbol provider - .let { codegenDecorator.symbolProvider(it) } + ) = TsServerSymbolVisitor(settings, model, serviceShape = serviceShape, config = rustSymbolProviderConfig) + // Generate public constrained types for directly constrained shapes. + // In the Typescript server project, this is only done to generate constrained types for simple shapes (e.g. + // a `string` shape with the `length` trait), but these always remain `pub(crate)`. + .let { + if (includeConstrainedShapeProvider) ConstrainedShapeSymbolProvider(it, serviceShape, constrainedTypes) else it + } + // Generate different types for EventStream shapes (e.g. transcribe streaming) + .let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.SERVER) } + // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes + .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf()) } + // Constrained shapes generate newtypes that need the same derives we place on types generated from aggregate shapes. + .let { ConstrainedShapeSymbolMetadataProvider(it, constrainedTypes) } + // Streaming shapes need different derives (e.g. they cannot derive Eq) + .let { TsStreamingShapeMetadataProvider(it) } + // Derive `Eq` and `Hash` if possible. + .let { DeriveEqAndHashSymbolMetadataProvider(it) } + // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot + // be the name of an operation input + .let { RustReservedWordSymbolProvider(it, ServerReservedWords) } + // Allows decorators to inject a custom symbol provider + .let { codegenDecorator.symbolProvider(it) } } } diff --git a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCargoDependency.kt b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCargoDependency.kt index 5c3b4c22093..c23a59433b0 100644 --- a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCargoDependency.kt +++ b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCargoDependency.kt @@ -23,11 +23,14 @@ object TsServerCargoDependency { val Tracing: CargoDependency = CargoDependency("tracing", CratesIo("0.1")) val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val TowerHttp: CargoDependency = CargoDependency("tower-http", CratesIo("0.3"), features = setOf("trace")) - val Hyper: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), features = setOf("server", "http1", "http2", "tcp", "stream")) + val Hyper: CargoDependency = + CargoDependency("hyper", CratesIo("0.14.12"), features = setOf("server", "http1", "http2", "tcp", "stream")) val NumCpus: CargoDependency = CargoDependency("num_cpus", CratesIo("1.13")) val ParkingLot: CargoDependency = CargoDependency("parking_lot", CratesIo("0.12")) val Socket2: CargoDependency = CargoDependency("socket2", CratesIo("0.4")) fun smithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-server") - fun smithyHttpServerTs(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-server-typescript") + + fun smithyHttpServerTs(runtimeConfig: RuntimeConfig) = + runtimeConfig.smithyRuntimeCrate("smithy-http-server-typescript") } diff --git a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCodegenVisitor.kt b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCodegenVisitor.kt index b1513b1be35..d47ed0b09a8 100644 --- a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCodegenVisitor.kt +++ b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCodegenVisitor.kt @@ -49,7 +49,6 @@ class TsServerCodegenVisitor( context: PluginContext, private val codegenDecorator: ServerCodegenDecorator, ) : ServerCodegenVisitor(context, codegenDecorator) { - init { val symbolVisitorConfig = RustSymbolProviderConfig( @@ -84,17 +83,19 @@ class TsServerCodegenVisitor( publicConstrainedTypes: Boolean, includeConstraintShapeProvider: Boolean, codegenDecorator: ServerCodegenDecorator, - ) = RustServerCodegenTsPlugin.baseSymbolProvider(settings, model, serviceShape, rustSymbolProviderConfig, publicConstrainedTypes, includeConstraintShapeProvider, codegenDecorator) - - val serverSymbolProviders = ServerSymbolProviders.from( - settings, - model, - service, - symbolVisitorConfig, - settings.codegenConfig.publicConstrainedTypes, - codegenDecorator, - ::baseSymbolProviderFactory, - ) + ) = + RustServerCodegenTsPlugin.baseSymbolProvider(settings, model, serviceShape, rustSymbolProviderConfig, publicConstrainedTypes, includeConstraintShapeProvider, codegenDecorator) + + val serverSymbolProviders = + ServerSymbolProviders.from( + settings, + model, + service, + symbolVisitorConfig, + settings.codegenConfig.publicConstrainedTypes, + codegenDecorator, + ::baseSymbolProviderFactory, + ) // Override `codegenContext` which carries the various symbol providers. codegenContext = @@ -111,18 +112,21 @@ class TsServerCodegenVisitor( serverSymbolProviders.pubCrateConstrainedShapeSymbolProvider, ) - codegenContext = codegenContext.copy( - moduleDocProvider = codegenDecorator.moduleDocumentationCustomization( - codegenContext, - TsServerModuleDocProvider(ServerModuleDocProvider(codegenContext)), - ), - ) + codegenContext = + codegenContext.copy( + moduleDocProvider = + codegenDecorator.moduleDocumentationCustomization( + codegenContext, + TsServerModuleDocProvider(ServerModuleDocProvider(codegenContext)), + ), + ) // Override `rustCrate` which carries the symbolProvider. - rustCrate = RustCrate( - context.fileManifest, codegenContext.symbolProvider, settings.codegenConfig, - codegenContext.expectModuleDocProvider(), - ) + rustCrate = + RustCrate( + context.fileManifest, codegenContext.symbolProvider, settings.codegenConfig, + codegenContext.expectModuleDocProvider(), + ) // Override `protocolGenerator` which carries the symbolProvider. protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext) } @@ -164,8 +168,10 @@ class TsServerCodegenVisitor( * Although raw strings require no code generation, enums are actually [EnumTrait] applied to string shapes. */ override fun stringShape(shape: StringShape) { - fun tsServerEnumGeneratorFactory(codegenContext: ServerCodegenContext, shape: StringShape) = - TsServerEnumGenerator(codegenContext, shape, validationExceptionConversionGenerator) + fun tsServerEnumGeneratorFactory( + codegenContext: ServerCodegenContext, + shape: StringShape, + ) = TsServerEnumGenerator(codegenContext, shape, validationExceptionConversionGenerator) stringShape(shape, ::tsServerEnumGeneratorFactory) } diff --git a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerSymbolProvider.kt b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerSymbolProvider.kt index a7ca74d064c..cc042767db9 100644 --- a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerSymbolProvider.kt +++ b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerSymbolProvider.kt @@ -120,6 +120,7 @@ class TsStreamingShapeMetadataProvider(private val base: RustSymbolProvider) : S } override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() + override fun enumMeta(stringShape: StringShape): RustMetadata = RustMetadata( setOf(RuntimeType.Eq, RuntimeType.Ord, RuntimeType.PartialEq, RuntimeType.PartialOrd, RuntimeType.Debug), @@ -128,8 +129,12 @@ class TsStreamingShapeMetadataProvider(private val base: RustSymbolProvider) : S ) override fun listMeta(listShape: ListShape) = base.toSymbol(listShape).expectRustMetadata() + override fun mapMeta(mapShape: MapShape) = base.toSymbol(mapShape).expectRustMetadata() + override fun stringMeta(stringShape: StringShape) = base.toSymbol(stringShape).expectRustMetadata() + override fun numberMeta(numberShape: NumberShape) = base.toSymbol(numberShape).expectRustMetadata() + override fun blobMeta(blobShape: BlobShape) = base.toSymbol(blobShape).expectRustMetadata() } diff --git a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/customizations/TsServerCodegenDecorator.kt b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/customizations/TsServerCodegenDecorator.kt index 885bac4a9a0..00e42d7d511 100644 --- a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/customizations/TsServerCodegenDecorator.kt +++ b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/customizations/TsServerCodegenDecorator.kt @@ -25,13 +25,12 @@ class CdylibManifestDecorator : ServerCodegenDecorator { override val name: String = "CdylibDecorator" override val order: Byte = 0 - override fun crateManifestCustomizations( - codegenContext: ServerCodegenContext, - ): ManifestCustomizations = + override fun crateManifestCustomizations(codegenContext: ServerCodegenContext): ManifestCustomizations = mapOf( - "lib" to mapOf( - "crate-type" to listOf("cdylib"), - ), + "lib" to + mapOf( + "crate-type" to listOf("cdylib"), + ), ) } @@ -40,7 +39,10 @@ class NapiBuildRsDecorator : ServerCodegenDecorator { override val order: Byte = 0 private val napi_build = TsServerCargoDependency.NapiBuild.toType() - override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ServerCodegenContext, + rustCrate: RustCrate, + ) { rustCrate.withFile("build.rs") { rustTemplate( """ @@ -58,7 +60,10 @@ class NapiPackageJsonDecorator : ServerCodegenDecorator { override val name: String = "NapiPackageJsonDecorator" override val order: Byte = 0 - override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { + override fun extras( + codegenContext: ServerCodegenContext, + rustCrate: RustCrate, + ) { val name = codegenContext.settings.moduleName.toSnakeCase() val version = codegenContext.settings.moduleVersion @@ -99,16 +104,17 @@ class NapiPackageJsonDecorator : ServerCodegenDecorator { } } -val DECORATORS = arrayOf( - /** - * Add the [InternalServerError] error to all operations. - * This is done because the Typescript interpreter can raise eceptions during execution. - */ - AddInternalServerErrorToAllOperationsDecorator(), - // Add the [lib] section to Cargo.toml to configure the generation of the shared library. - CdylibManifestDecorator(), - // Add the build.rs file needed to generate Typescript code. - NapiBuildRsDecorator(), - // Add the napi package.json. - NapiPackageJsonDecorator(), -) +val DECORATORS = + arrayOf( + /* + * Add the [InternalServerError] error to all operations. + * This is done because the Typescript interpreter can raise eceptions during execution. + */ + AddInternalServerErrorToAllOperationsDecorator(), + // Add the [lib] section to Cargo.toml to configure the generation of the shared library. + CdylibManifestDecorator(), + // Add the build.rs file needed to generate Typescript code. + NapiBuildRsDecorator(), + // Add the napi package.json. + NapiPackageJsonDecorator(), + ) diff --git a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/generators/TsApplicationGenerator.kt b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/generators/TsApplicationGenerator.kt index 92e346fa918..7809c9562c1 100644 --- a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/generators/TsApplicationGenerator.kt +++ b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/generators/TsApplicationGenerator.kt @@ -29,11 +29,12 @@ class TsApplicationGenerator( private val protocol: ServerProtocol, ) { private val index = TopDownIndex.of(codegenContext.model) - private val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet( - compareBy { - it.id - }, - ).toList() + private val operations = + index.getContainedOperations(codegenContext.serviceShape).toSortedSet( + compareBy { + it.id + }, + ).toList() private val symbolProvider = codegenContext.symbolProvider private val libName = codegenContext.settings.moduleName.toSnakeCase() private val runtimeConfig = codegenContext.runtimeConfig @@ -263,7 +264,8 @@ class TsApplicationGenerator( writer.rustBlock("impl TsSocket") { writer.rustTemplate( - """pub fn to_raw_socket(&self) -> #{napi}::Result<#{socket2}::Socket> { + """ + pub fn to_raw_socket(&self) -> #{napi}::Result<#{socket2}::Socket> { self.0 .try_clone() .map_err(|e| #{napi}::Error::from_reason(e.to_string())) @@ -277,7 +279,8 @@ class TsApplicationGenerator( private fun renderServer(writer: RustWriter) { writer.rustBlockTemplate( - """pub fn start_hyper_worker( + """ + pub fn start_hyper_worker( socket: &TsSocket, app: #{tower}::util::BoxCloneService< #{http}::Request<#{SmithyServer}::body::Body>, diff --git a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/generators/TsServerEnumGenerator.kt b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/generators/TsServerEnumGenerator.kt index f12ebb5ef68..aef08a3484b 100644 --- a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/generators/TsServerEnumGenerator.kt +++ b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/generators/TsServerEnumGenerator.kt @@ -28,9 +28,10 @@ class TsConstrainedEnum( ) : ConstrainedEnum(codegenContext, shape, validationExceptionConversionGenerator) { private val napiDerive = TsServerCargoDependency.NapiDerive.toType() - override fun additionalEnumImpls(context: EnumGeneratorContext): Writable = writable { - this.rust("use napi::bindgen_prelude::ToNapiValue;") - } + override fun additionalEnumImpls(context: EnumGeneratorContext): Writable = + writable { + this.rust("use napi::bindgen_prelude::ToNapiValue;") + } override fun additionalEnumAttributes(context: EnumGeneratorContext): List = listOf(Attribute(napiDerive.resolve("napi"))) @@ -41,12 +42,12 @@ class TsServerEnumGenerator( shape: StringShape, validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, ) : EnumGenerator( - codegenContext.model, - codegenContext.symbolProvider, - shape, - TsConstrainedEnum( - codegenContext, + codegenContext.model, + codegenContext.symbolProvider, shape, - validationExceptionConversionGenerator, - ), -) + TsConstrainedEnum( + codegenContext, + shape, + validationExceptionConversionGenerator, + ), + ) diff --git a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/generators/TsServerStructureGenerator.kt b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/generators/TsServerStructureGenerator.kt index cf08892bc9b..9bdf53db1a5 100644 --- a/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/generators/TsServerStructureGenerator.kt +++ b/codegen-server/typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/generators/TsServerStructureGenerator.kt @@ -29,15 +29,15 @@ class TsServerStructureGenerator( private val writer: RustWriter, private val shape: StructureShape, ) : StructureGenerator(model, symbolProvider, writer, shape, listOf(), StructSettings(flattenVecAccessors = false)) { - private val napiDerive = TsServerCargoDependency.NapiDerive.toType() override fun renderStructure() { - val flavour = if (shape.hasTrait()) { - "constructor" - } else { - "object" - } + val flavour = + if (shape.hasTrait()) { + "constructor" + } else { + "object" + } Attribute( writable { rustInlineTemplate( diff --git a/gradle.properties b/gradle.properties index ef7900cdc14..ce30214ca8b 100644 --- a/gradle.properties +++ b/gradle.properties @@ -24,11 +24,11 @@ smithyGradlePluginVersion=0.7.0 smithyVersion=1.40.0 # kotlin -kotlinVersion=1.7.21 +kotlinVersion=1.9.20 # testing/utility -ktlintVersion=0.48.2 -kotestVersion=5.2.3 +ktlintVersion=1.0.1 +kotestVersion=5.8.0 # Avoid registering dependencies/plugins/tasks that are only used for testing purposes isTestingEnabled=true