diff --git a/ktor-server/ktor-server-core/api/ktor-server-core.api b/ktor-server/ktor-server-core/api/ktor-server-core.api index 8e3fa2229d..d719c85ddc 100644 --- a/ktor-server/ktor-server-core/api/ktor-server-core.api +++ b/ktor-server/ktor-server-core/api/ktor-server-core.api @@ -895,7 +895,9 @@ public final class io/ktor/server/http/content/StaticContentResolutionKt { public final class io/ktor/server/http/content/SuppressionAttributeKt { public static final fun getSuppressionAttribute ()Lio/ktor/util/AttributeKey; public static final fun isCompressionSuppressed (Lio/ktor/server/application/ApplicationCall;)Z + public static final fun isDecompressionSuppressed (Lio/ktor/server/application/ApplicationCall;)Z public static final fun suppressCompression (Lio/ktor/server/application/ApplicationCall;)V + public static final fun suppressDecompression (Lio/ktor/server/application/ApplicationCall;)V } public final class io/ktor/server/logging/LoggingKt { diff --git a/ktor-server/ktor-server-core/api/ktor-server-core.klib.api b/ktor-server/ktor-server-core/api/ktor-server-core.klib.api index 464607d402..9348ed2193 100644 --- a/ktor-server/ktor-server-core/api/ktor-server-core.klib.api +++ b/ktor-server/ktor-server-core/api/ktor-server-core.klib.api @@ -1637,6 +1637,8 @@ final val io.ktor.server.http.content/SuppressionAttribute // io.ktor.server.htt final fun (): io.ktor.util/AttributeKey // io.ktor.server.http.content/SuppressionAttribute.|(){}[0] final val io.ktor.server.http.content/isCompressionSuppressed // io.ktor.server.http.content/isCompressionSuppressed|@io.ktor.server.application.ApplicationCall{}isCompressionSuppressed[0] final fun (io.ktor.server.application/ApplicationCall).(): kotlin/Boolean // io.ktor.server.http.content/isCompressionSuppressed.|@io.ktor.server.application.ApplicationCall(){}[0] +final val io.ktor.server.http.content/isDecompressionSuppressed // io.ktor.server.http.content/isDecompressionSuppressed|@io.ktor.server.application.ApplicationCall{}isDecompressionSuppressed[0] + final fun (io.ktor.server.application/ApplicationCall).(): kotlin/Boolean // io.ktor.server.http.content/isDecompressionSuppressed.|@io.ktor.server.application.ApplicationCall(){}[0] final val io.ktor.server.logging/mdcProvider // io.ktor.server.logging/mdcProvider|@io.ktor.server.application.Application{}mdcProvider[0] final fun (io.ktor.server.application/Application).(): io.ktor.server.logging/MDCProvider // io.ktor.server.logging/mdcProvider.|@io.ktor.server.application.Application(){}[0] final val io.ktor.server.plugins/MutableOriginConnectionPointKey // io.ktor.server.plugins/MutableOriginConnectionPointKey|{}MutableOriginConnectionPointKey[0] @@ -1672,6 +1674,7 @@ final fun (io.ktor.http/HeadersBuilder).io.ktor.server.response/contentRange(kot final fun (io.ktor.http/URLBuilder.Companion).io.ktor.server.util/createFromCall(io.ktor.server.application/ApplicationCall): io.ktor.http/URLBuilder // io.ktor.server.util/createFromCall|createFromCall@io.ktor.http.URLBuilder.Companion(io.ktor.server.application.ApplicationCall){}[0] final fun (io.ktor.server.application/Application).io.ktor.server.routing/routing(kotlin/Function1): io.ktor.server.routing/RoutingRoot // io.ktor.server.routing/routing|routing@io.ktor.server.application.Application(kotlin.Function1){}[0] final fun (io.ktor.server.application/ApplicationCall).io.ktor.server.http.content/suppressCompression() // io.ktor.server.http.content/suppressCompression|suppressCompression@io.ktor.server.application.ApplicationCall(){}[0] +final fun (io.ktor.server.application/ApplicationCall).io.ktor.server.http.content/suppressDecompression() // io.ktor.server.http.content/suppressDecompression|suppressDecompression@io.ktor.server.application.ApplicationCall(){}[0] final fun (io.ktor.server.application/ApplicationCall).io.ktor.server.http/push(kotlin/Function1) // io.ktor.server.http/push|push@io.ktor.server.application.ApplicationCall(kotlin.Function1){}[0] final fun (io.ktor.server.application/ApplicationCall).io.ktor.server.http/push(kotlin/String) // io.ktor.server.http/push|push@io.ktor.server.application.ApplicationCall(kotlin.String){}[0] final fun (io.ktor.server.application/ApplicationCall).io.ktor.server.http/push(kotlin/String, io.ktor.http/Parameters) // io.ktor.server.http/push|push@io.ktor.server.application.ApplicationCall(kotlin.String;io.ktor.http.Parameters){}[0] diff --git a/ktor-server/ktor-server-core/common/src/io/ktor/server/application/KtorCallContexts.kt b/ktor-server/ktor-server-core/common/src/io/ktor/server/application/KtorCallContexts.kt index 641478b9c1..dc74db9cd9 100644 --- a/ktor-server/ktor-server-core/common/src/io/ktor/server/application/KtorCallContexts.kt +++ b/ktor-server/ktor-server-core/common/src/io/ktor/server/application/KtorCallContexts.kt @@ -59,7 +59,6 @@ public class OnCallReceiveContext internal constructor( public suspend fun transformBody(transform: suspend TransformBodyContext.(body: ByteReadChannel) -> Any) { val receiveBody = context.subject as? ByteReadChannel ?: return val typeInfo = context.call.receiveType - if (typeInfo == typeInfo()) return val transformContext = TransformBodyContext(typeInfo) context.subject = transformContext.transform(receiveBody) diff --git a/ktor-server/ktor-server-core/common/src/io/ktor/server/http/content/SuppressionAttribute.kt b/ktor-server/ktor-server-core/common/src/io/ktor/server/http/content/SuppressionAttribute.kt index dc93bd309a..53712f45b8 100644 --- a/ktor-server/ktor-server-core/common/src/io/ktor/server/http/content/SuppressionAttribute.kt +++ b/ktor-server/ktor-server-core/common/src/io/ktor/server/http/content/SuppressionAttribute.kt @@ -15,6 +15,7 @@ import io.ktor.util.* level = DeprecationLevel.ERROR ) public val SuppressionAttribute: AttributeKey = AttributeKey("preventCompression") +internal val DecompressionSuppressionAttribute: AttributeKey = AttributeKey("preventDecompression") /** * Suppress response body compression plugin for this [ApplicationCall]. @@ -24,8 +25,23 @@ public fun ApplicationCall.suppressCompression() { attributes.put(SuppressionAttribute, true) } +/** + * Suppresses the decompression for the current application call. + */ +public fun ApplicationCall.suppressDecompression() { + attributes.put(DecompressionSuppressionAttribute, true) +} + /** * Checks if response body compression is suppressed for this [ApplicationCall]. */ @Suppress("DEPRECATION_ERROR") public val ApplicationCall.isCompressionSuppressed: Boolean get() = SuppressionAttribute in attributes + +/** + * Indicates whether decompression is suppressed for the current application call. + * If decompression is suppressed, the plugin will not decompress the request body. + * + * To suppress decompression, use [suppressDecompression]. + */ +public val ApplicationCall.isDecompressionSuppressed: Boolean get() = DecompressionSuppressionAttribute in attributes diff --git a/ktor-server/ktor-server-plugins/ktor-server-compression/jvm/src/io/ktor/server/plugins/compression/Compression.kt b/ktor-server/ktor-server-plugins/ktor-server-compression/jvm/src/io/ktor/server/plugins/compression/Compression.kt index ed1c18cd52..a20db3782a 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-compression/jvm/src/io/ktor/server/plugins/compression/Compression.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-compression/jvm/src/io/ktor/server/plugins/compression/Compression.kt @@ -5,7 +5,6 @@ package io.ktor.server.plugins.compression import io.ktor.http.* -import io.ktor.http.cio.* import io.ktor.http.content.* import io.ktor.server.application.* import io.ktor.server.http.content.* @@ -13,9 +12,15 @@ import io.ktor.server.request.* import io.ktor.server.response.* import io.ktor.util.* import io.ktor.util.logging.* -import io.ktor.util.pipeline.* import io.ktor.utils.io.* -import kotlinx.coroutines.* + +/** + * List of [ContentEncoder] names that were used to decode request body. + */ +public val ApplicationRequest.appliedDecoders: List + get() = call.attributes.getOrNull(DecompressionListAttribute) ?: emptyList() + +internal val DecompressionListAttribute: AttributeKey> = AttributeKey("DecompressionListAttribute") internal val LOGGER = KtorSimpleLogger("io.ktor.server.plugins.compression.Compression") @@ -24,27 +29,6 @@ internal val LOGGER = KtorSimpleLogger("io.ktor.server.plugins.compression.Compr */ internal const val DEFAULT_MINIMAL_COMPRESSION_SIZE: Long = 200L -private object ContentEncoding : Hook Unit> { - - class Context(private val pipelineContext: PipelineContext) { - fun transformBody(block: (OutgoingContent) -> OutgoingContent?) { - val transformedContent = block(pipelineContext.subject as OutgoingContent) - if (transformedContent != null) { - pipelineContext.subject = transformedContent - } - } - } - - override fun install( - pipeline: ApplicationCallPipeline, - handler: suspend Context.(PipelineCall) -> Unit - ) { - pipeline.sendPipeline.intercept(ApplicationSendPipeline.ContentEncoding) { - handler(Context(this), call) - } - } -} - /** * A plugin that provides the capability to compress a response and decompress request bodies. * You can use different compression algorithms, including `gzip` and `deflate`, @@ -82,15 +66,20 @@ public val Compression: RouteScopedPlugin = createRouteScoped if (!mode.response) return@on encode(call, options) } - onCall { call -> - if (!mode.request) return@onCall + + onCallReceive { call -> + if (!mode.request) return@onCallReceive decode(call, options) } } @OptIn(InternalAPI::class) -private fun decode(call: PipelineCall, options: CompressionOptions) { +private suspend fun OnCallReceiveContext.decode(call: PipelineCall, options: CompressionOptions) { val encodingRaw = call.request.headers[HttpHeaders.ContentEncoding] + if (call.isDecompressionSuppressed) { + LOGGER.trace("Skip decompression for ${call.request.uri} because it is suppressed.") + return + } if (encodingRaw == null) { LOGGER.trace("Skip decompression for ${call.request.uri} because no content encoding provided.") return @@ -112,10 +101,12 @@ private fun decode(call: PipelineCall, options: CompressionOptions) { } else { call.request.setHeader(HttpHeaders.ContentEncoding, null) } - val originalChannel = call.request.receiveChannel() - val decoded = encoders.fold(originalChannel) { content, encoder -> encoder.encoder.decode(content) } - call.request.setReceiveChannel(decoded) + call.attributes.put(DecompressionListAttribute, encoderNames) + + transformBody { body -> + encoders.fold(body) { content, encoder -> encoder.encoder.decode(content) } + } } private fun ContentEncoding.Context.encode(call: PipelineCall, options: CompressionOptions) { @@ -180,14 +171,6 @@ private fun ContentEncoding.Context.encode(call: PipelineCall, options: Compress } } -internal val DecompressionListAttribute: AttributeKey> = AttributeKey("DecompressionListAttribute") - -/** - * List of [ContentEncoder] names that were used to decode request body. - */ -public val ApplicationRequest.appliedDecoders: List - get() = call.attributes.getOrNull(DecompressionListAttribute) ?: emptyList() - private fun PipelineResponse.isSSEResponse(): Boolean { val contentType = headers[HttpHeaders.ContentType]?.let { ContentType.parse(it) } return contentType?.withoutParameters() == ContentType.Text.EventStream diff --git a/ktor-server/ktor-server-plugins/ktor-server-compression/jvm/src/io/ktor/server/plugins/compression/ContentEncoding.kt b/ktor-server/ktor-server-plugins/ktor-server-compression/jvm/src/io/ktor/server/plugins/compression/ContentEncoding.kt new file mode 100644 index 0000000000..5a67799e18 --- /dev/null +++ b/ktor-server/ktor-server-plugins/ktor-server-compression/jvm/src/io/ktor/server/plugins/compression/ContentEncoding.kt @@ -0,0 +1,31 @@ +/* + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.plugins.compression + +import io.ktor.http.content.* +import io.ktor.server.application.* +import io.ktor.server.response.* +import io.ktor.util.pipeline.* + +internal object ContentEncoding : Hook Unit> { + + class Context(private val pipelineContext: PipelineContext) { + fun transformBody(block: (OutgoingContent) -> OutgoingContent?) { + val transformedContent = block(pipelineContext.subject as OutgoingContent) + if (transformedContent != null) { + pipelineContext.subject = transformedContent + } + } + } + + override fun install( + pipeline: ApplicationCallPipeline, + handler: suspend Context.(PipelineCall) -> Unit + ) { + pipeline.sendPipeline.intercept(ApplicationSendPipeline.ContentEncoding) { + handler(Context(this), call) + } + } +} diff --git a/ktor-server/ktor-server-tests/jvm/test/io/ktor/server/plugins/CompressionTest.kt b/ktor-server/ktor-server-tests/jvm/test/io/ktor/server/plugins/CompressionTest.kt index 4277b345e2..53e5dbe334 100644 --- a/ktor-server/ktor-server-tests/jvm/test/io/ktor/server/plugins/CompressionTest.kt +++ b/ktor-server/ktor-server-tests/jvm/test/io/ktor/server/plugins/CompressionTest.kt @@ -647,24 +647,31 @@ class CompressionTest { install(Compression) routing { post("/identity") { + val message = call.receiveText() assertNull(call.request.headers[HttpHeaders.ContentEncoding]) assertEquals(listOf("identity"), call.request.appliedDecoders) - call.respond(call.receiveText()) + + call.respond(message) } post("/gzip") { + val message = call.receiveText() + assertNull(call.request.headers[HttpHeaders.ContentEncoding]) assertEquals(listOf("gzip"), call.request.appliedDecoders) - call.respond(call.receiveText()) + + call.respond(message) } post("/deflate") { + val message = call.receiveText() assertNull(call.request.headers[HttpHeaders.ContentEncoding]) assertEquals(listOf("deflate"), call.request.appliedDecoders) - call.respond(call.receiveText()) + call.respond(message) } post("/multiple") { + val message = call.receiveText() assertNull(call.request.headers[HttpHeaders.ContentEncoding]) assertEquals(listOf("identity", "deflate", "gzip"), call.request.appliedDecoders) - call.respond(call.receiveText()) + call.respond(message) } post("/unknown") { assertEquals("unknown", call.request.headers[HttpHeaders.ContentEncoding]) @@ -759,9 +766,38 @@ class CompressionTest { } routing { post("/gzip") { + val body = call.receive() + assertNull(call.request.headers[HttpHeaders.ContentEncoding]) + assertContentEquals(compressed, body) + + call.respond(textToCompressAsBytes) + } + } + + val response = client.post("/gzip") { + setBody(compressed) + header(HttpHeaders.ContentEncoding, "gzip") + header(HttpHeaders.AcceptEncoding, "gzip") + } + assertContentEquals(textToCompressAsBytes, response.body()) + } + + @Test + fun testDisableCallEncoding() = testApplication { + val compressed = GZip.encode(ByteReadChannel(textToCompressAsBytes)).readRemaining().readByteArray() + install(Compression) + + routing { + post("/gzip") { + call.suppressCompression() + val body = call.receive() - assertContentEquals(textToCompressAsBytes, body) + + assertEquals("gzip", call.request.appliedDecoders.first()) + assertNull(call.request.headers[HttpHeaders.ContentEncoding]) + assertContentEquals(compressed, body) + call.respond(textToCompressAsBytes) } } @@ -774,6 +810,29 @@ class CompressionTest { assertContentEquals(textToCompressAsBytes, response.body()) } + @Test + fun testDisableCallDecoding() = testApplication { + val compressed = GZip.encode(ByteReadChannel(textToCompressAsBytes)).readRemaining().readByteArray() + + install(Compression) + routing { + post("/gzip") { + call.suppressDecompression() + assertEquals("gzip", call.request.headers[HttpHeaders.ContentEncoding]) + val body = call.receive() + assertContentEquals(compressed, body) + call.respond(textToCompress) + } + } + + val response = client.post("/gzip") { + setBody(compressed) + header(HttpHeaders.ContentEncoding, "gzip") + header(HttpHeaders.AcceptEncoding, "gzip") + } + assertContentEquals(compressed, response.body()) + } + private suspend fun ApplicationTestBuilder.handleAndAssert( url: String, acceptHeader: String?,