Skip to content

Commit

Permalink
KTOR-7679 Allow disabling body decoding on server (#4444)
Browse files Browse the repository at this point in the history
* KTOR-7679 Allow disabling body decoding on server
  • Loading branch information
e5l authored and osipxd committed Jan 9, 2025
1 parent 649043c commit d742cc0
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 44 deletions.
2 changes: 2 additions & 0 deletions ktor-server/ktor-server-core/api/ktor-server-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,9 @@ public final class io/ktor/server/http/content/StaticContentResolutionKt {
public final class io/ktor/server/http/content/SuppressionAttributeKt {
public static final fun getSuppressionAttribute ()Lio/ktor/util/AttributeKey;
public static final fun isCompressionSuppressed (Lio/ktor/server/application/ApplicationCall;)Z
public static final fun isDecompressionSuppressed (Lio/ktor/server/application/ApplicationCall;)Z
public static final fun suppressCompression (Lio/ktor/server/application/ApplicationCall;)V
public static final fun suppressDecompression (Lio/ktor/server/application/ApplicationCall;)V
}

public final class io/ktor/server/logging/LoggingKt {
Expand Down
3 changes: 3 additions & 0 deletions ktor-server/ktor-server-core/api/ktor-server-core.klib.api
Original file line number Diff line number Diff line change
Expand Up @@ -1637,6 +1637,8 @@ final val io.ktor.server.http.content/SuppressionAttribute // io.ktor.server.htt
final fun <get-SuppressionAttribute>(): io.ktor.util/AttributeKey<kotlin/Boolean> // io.ktor.server.http.content/SuppressionAttribute.<get-SuppressionAttribute>|<get-SuppressionAttribute>(){}[0]
final val io.ktor.server.http.content/isCompressionSuppressed // io.ktor.server.http.content/isCompressionSuppressed|@io.ktor.server.application.ApplicationCall{}isCompressionSuppressed[0]
final fun (io.ktor.server.application/ApplicationCall).<get-isCompressionSuppressed>(): kotlin/Boolean // io.ktor.server.http.content/isCompressionSuppressed.<get-isCompressionSuppressed>|<get-isCompressionSuppressed>@io.ktor.server.application.ApplicationCall(){}[0]
final val io.ktor.server.http.content/isDecompressionSuppressed // io.ktor.server.http.content/isDecompressionSuppressed|@io.ktor.server.application.ApplicationCall{}isDecompressionSuppressed[0]
final fun (io.ktor.server.application/ApplicationCall).<get-isDecompressionSuppressed>(): kotlin/Boolean // io.ktor.server.http.content/isDecompressionSuppressed.<get-isDecompressionSuppressed>|<get-isDecompressionSuppressed>@io.ktor.server.application.ApplicationCall(){}[0]
final val io.ktor.server.logging/mdcProvider // io.ktor.server.logging/mdcProvider|@io.ktor.server.application.Application{}mdcProvider[0]
final fun (io.ktor.server.application/Application).<get-mdcProvider>(): io.ktor.server.logging/MDCProvider // io.ktor.server.logging/mdcProvider.<get-mdcProvider>|<get-mdcProvider>@io.ktor.server.application.Application(){}[0]
final val io.ktor.server.plugins/MutableOriginConnectionPointKey // io.ktor.server.plugins/MutableOriginConnectionPointKey|{}MutableOriginConnectionPointKey[0]
Expand Down Expand Up @@ -1672,6 +1674,7 @@ final fun (io.ktor.http/HeadersBuilder).io.ktor.server.response/contentRange(kot
final fun (io.ktor.http/URLBuilder.Companion).io.ktor.server.util/createFromCall(io.ktor.server.application/ApplicationCall): io.ktor.http/URLBuilder // io.ktor.server.util/createFromCall|[email protected](io.ktor.server.application.ApplicationCall){}[0]
final fun (io.ktor.server.application/Application).io.ktor.server.routing/routing(kotlin/Function1<io.ktor.server.routing/Routing, kotlin/Unit>): io.ktor.server.routing/RoutingRoot // io.ktor.server.routing/routing|[email protected](kotlin.Function1<io.ktor.server.routing.Routing,kotlin.Unit>){}[0]
final fun (io.ktor.server.application/ApplicationCall).io.ktor.server.http.content/suppressCompression() // io.ktor.server.http.content/suppressCompression|[email protected](){}[0]
final fun (io.ktor.server.application/ApplicationCall).io.ktor.server.http.content/suppressDecompression() // io.ktor.server.http.content/suppressDecompression|[email protected](){}[0]
final fun (io.ktor.server.application/ApplicationCall).io.ktor.server.http/push(kotlin/Function1<io.ktor.server.response/ResponsePushBuilder, kotlin/Unit>) // io.ktor.server.http/push|[email protected](kotlin.Function1<io.ktor.server.response.ResponsePushBuilder,kotlin.Unit>){}[0]
final fun (io.ktor.server.application/ApplicationCall).io.ktor.server.http/push(kotlin/String) // io.ktor.server.http/push|[email protected](kotlin.String){}[0]
final fun (io.ktor.server.application/ApplicationCall).io.ktor.server.http/push(kotlin/String, io.ktor.http/Parameters) // io.ktor.server.http/push|[email protected](kotlin.String;io.ktor.http.Parameters){}[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ public class OnCallReceiveContext<PluginConfig : Any> internal constructor(
public suspend fun transformBody(transform: suspend TransformBodyContext.(body: ByteReadChannel) -> Any) {
val receiveBody = context.subject as? ByteReadChannel ?: return
val typeInfo = context.call.receiveType
if (typeInfo == typeInfo<ByteReadChannel>()) return

val transformContext = TransformBodyContext(typeInfo)
context.subject = transformContext.transform(receiveBody)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import io.ktor.util.*
level = DeprecationLevel.ERROR
)
public val SuppressionAttribute: AttributeKey<Boolean> = AttributeKey("preventCompression")
internal val DecompressionSuppressionAttribute: AttributeKey<Boolean> = AttributeKey("preventDecompression")

/**
* Suppress response body compression plugin for this [ApplicationCall].
Expand All @@ -24,8 +25,23 @@ public fun ApplicationCall.suppressCompression() {
attributes.put(SuppressionAttribute, true)
}

/**
* Suppresses the decompression for the current application call.
*/
public fun ApplicationCall.suppressDecompression() {
attributes.put(DecompressionSuppressionAttribute, true)
}

/**
* Checks if response body compression is suppressed for this [ApplicationCall].
*/
@Suppress("DEPRECATION_ERROR")
public val ApplicationCall.isCompressionSuppressed: Boolean get() = SuppressionAttribute in attributes

/**
* Indicates whether decompression is suppressed for the current application call.
* If decompression is suppressed, the plugin will not decompress the request body.
*
* To suppress decompression, use [suppressDecompression].
*/
public val ApplicationCall.isDecompressionSuppressed: Boolean get() = DecompressionSuppressionAttribute in attributes
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,22 @@
package io.ktor.server.plugins.compression

import io.ktor.http.*
import io.ktor.http.cio.*
import io.ktor.http.content.*
import io.ktor.server.application.*
import io.ktor.server.http.content.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.util.*
import io.ktor.util.logging.*
import io.ktor.util.pipeline.*
import io.ktor.utils.io.*
import kotlinx.coroutines.*

/**
* List of [ContentEncoder] names that were used to decode request body.
*/
public val ApplicationRequest.appliedDecoders: List<String>
get() = call.attributes.getOrNull(DecompressionListAttribute) ?: emptyList()

internal val DecompressionListAttribute: AttributeKey<List<String>> = AttributeKey("DecompressionListAttribute")

internal val LOGGER = KtorSimpleLogger("io.ktor.server.plugins.compression.Compression")

Expand All @@ -24,27 +29,6 @@ internal val LOGGER = KtorSimpleLogger("io.ktor.server.plugins.compression.Compr
*/
internal const val DEFAULT_MINIMAL_COMPRESSION_SIZE: Long = 200L

private object ContentEncoding : Hook<suspend ContentEncoding.Context.(PipelineCall) -> Unit> {

class Context(private val pipelineContext: PipelineContext<Any, PipelineCall>) {
fun transformBody(block: (OutgoingContent) -> OutgoingContent?) {
val transformedContent = block(pipelineContext.subject as OutgoingContent)
if (transformedContent != null) {
pipelineContext.subject = transformedContent
}
}
}

override fun install(
pipeline: ApplicationCallPipeline,
handler: suspend Context.(PipelineCall) -> Unit
) {
pipeline.sendPipeline.intercept(ApplicationSendPipeline.ContentEncoding) {
handler(Context(this), call)
}
}
}

/**
* A plugin that provides the capability to compress a response and decompress request bodies.
* You can use different compression algorithms, including `gzip` and `deflate`,
Expand Down Expand Up @@ -82,15 +66,20 @@ public val Compression: RouteScopedPlugin<CompressionConfig> = createRouteScoped
if (!mode.response) return@on
encode(call, options)
}
onCall { call ->
if (!mode.request) return@onCall

onCallReceive { call ->
if (!mode.request) return@onCallReceive
decode(call, options)
}
}

@OptIn(InternalAPI::class)
private fun decode(call: PipelineCall, options: CompressionOptions) {
private suspend fun OnCallReceiveContext<CompressionConfig>.decode(call: PipelineCall, options: CompressionOptions) {
val encodingRaw = call.request.headers[HttpHeaders.ContentEncoding]
if (call.isDecompressionSuppressed) {
LOGGER.trace("Skip decompression for ${call.request.uri} because it is suppressed.")
return
}
if (encodingRaw == null) {
LOGGER.trace("Skip decompression for ${call.request.uri} because no content encoding provided.")
return
Expand All @@ -112,10 +101,12 @@ private fun decode(call: PipelineCall, options: CompressionOptions) {
} else {
call.request.setHeader(HttpHeaders.ContentEncoding, null)
}
val originalChannel = call.request.receiveChannel()
val decoded = encoders.fold(originalChannel) { content, encoder -> encoder.encoder.decode(content) }
call.request.setReceiveChannel(decoded)

call.attributes.put(DecompressionListAttribute, encoderNames)

transformBody { body ->
encoders.fold(body) { content, encoder -> encoder.encoder.decode(content) }
}
}

private fun ContentEncoding.Context.encode(call: PipelineCall, options: CompressionOptions) {
Expand Down Expand Up @@ -180,14 +171,6 @@ private fun ContentEncoding.Context.encode(call: PipelineCall, options: Compress
}
}

internal val DecompressionListAttribute: AttributeKey<List<String>> = AttributeKey("DecompressionListAttribute")

/**
* List of [ContentEncoder] names that were used to decode request body.
*/
public val ApplicationRequest.appliedDecoders: List<String>
get() = call.attributes.getOrNull(DecompressionListAttribute) ?: emptyList()

private fun PipelineResponse.isSSEResponse(): Boolean {
val contentType = headers[HttpHeaders.ContentType]?.let { ContentType.parse(it) }
return contentType?.withoutParameters() == ContentType.Text.EventStream
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.plugins.compression

import io.ktor.http.content.*
import io.ktor.server.application.*
import io.ktor.server.response.*
import io.ktor.util.pipeline.*

internal object ContentEncoding : Hook<suspend ContentEncoding.Context.(PipelineCall) -> Unit> {

class Context(private val pipelineContext: PipelineContext<Any, PipelineCall>) {
fun transformBody(block: (OutgoingContent) -> OutgoingContent?) {
val transformedContent = block(pipelineContext.subject as OutgoingContent)
if (transformedContent != null) {
pipelineContext.subject = transformedContent
}
}
}

override fun install(
pipeline: ApplicationCallPipeline,
handler: suspend Context.(PipelineCall) -> Unit
) {
pipeline.sendPipeline.intercept(ApplicationSendPipeline.ContentEncoding) {
handler(Context(this), call)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -647,24 +647,31 @@ class CompressionTest {
install(Compression)
routing {
post("/identity") {
val message = call.receiveText()
assertNull(call.request.headers[HttpHeaders.ContentEncoding])
assertEquals(listOf("identity"), call.request.appliedDecoders)
call.respond(call.receiveText())

call.respond(message)
}
post("/gzip") {
val message = call.receiveText()

assertNull(call.request.headers[HttpHeaders.ContentEncoding])
assertEquals(listOf("gzip"), call.request.appliedDecoders)
call.respond(call.receiveText())

call.respond(message)
}
post("/deflate") {
val message = call.receiveText()
assertNull(call.request.headers[HttpHeaders.ContentEncoding])
assertEquals(listOf("deflate"), call.request.appliedDecoders)
call.respond(call.receiveText())
call.respond(message)
}
post("/multiple") {
val message = call.receiveText()
assertNull(call.request.headers[HttpHeaders.ContentEncoding])
assertEquals(listOf("identity", "deflate", "gzip"), call.request.appliedDecoders)
call.respond(call.receiveText())
call.respond(message)
}
post("/unknown") {
assertEquals("unknown", call.request.headers[HttpHeaders.ContentEncoding])
Expand Down Expand Up @@ -759,9 +766,38 @@ class CompressionTest {
}
routing {
post("/gzip") {
val body = call.receive<ByteArray>()

assertNull(call.request.headers[HttpHeaders.ContentEncoding])
assertContentEquals(compressed, body)

call.respond(textToCompressAsBytes)
}
}

val response = client.post("/gzip") {
setBody(compressed)
header(HttpHeaders.ContentEncoding, "gzip")
header(HttpHeaders.AcceptEncoding, "gzip")
}
assertContentEquals(textToCompressAsBytes, response.body<ByteArray>())
}

@Test
fun testDisableCallEncoding() = testApplication {
val compressed = GZip.encode(ByteReadChannel(textToCompressAsBytes)).readRemaining().readByteArray()
install(Compression)

routing {
post("/gzip") {
call.suppressCompression()

val body = call.receive<ByteArray>()
assertContentEquals(textToCompressAsBytes, body)

assertEquals("gzip", call.request.appliedDecoders.first())
assertNull(call.request.headers[HttpHeaders.ContentEncoding])
assertContentEquals(compressed, body)

call.respond(textToCompressAsBytes)
}
}
Expand All @@ -774,6 +810,29 @@ class CompressionTest {
assertContentEquals(textToCompressAsBytes, response.body<ByteArray>())
}

@Test
fun testDisableCallDecoding() = testApplication {
val compressed = GZip.encode(ByteReadChannel(textToCompressAsBytes)).readRemaining().readByteArray()

install(Compression)
routing {
post("/gzip") {
call.suppressDecompression()
assertEquals("gzip", call.request.headers[HttpHeaders.ContentEncoding])
val body = call.receive<ByteArray>()
assertContentEquals(compressed, body)
call.respond(textToCompress)
}
}

val response = client.post("/gzip") {
setBody(compressed)
header(HttpHeaders.ContentEncoding, "gzip")
header(HttpHeaders.AcceptEncoding, "gzip")
}
assertContentEquals(compressed, response.body<ByteArray>())
}

private suspend fun ApplicationTestBuilder.handleAndAssert(
url: String,
acceptHeader: String?,
Expand Down

0 comments on commit d742cc0

Please sign in to comment.