Skip to content

Commit

Permalink
Don't encode Authorization header (#486)
Browse files Browse the repository at this point in the history
* Fix encoding Authorization header

* Fix wrong header in KordBuilder
  • Loading branch information
lukellmann authored Jan 17, 2022
1 parent 378d71d commit 3863444
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 21 deletions.
30 changes: 18 additions & 12 deletions core/src/main/kotlin/builder/kord/KordBuilder.kt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

package dev.kord.core.builder.kord

import dev.kord.cache.api.DataCache
Expand Down Expand Up @@ -26,11 +25,11 @@ import dev.kord.rest.request.RequestHandler
import dev.kord.rest.request.isError
import dev.kord.rest.route.Route
import dev.kord.rest.service.RestClient
import io.ktor.client.HttpClient
import io.ktor.client.request.get
import io.ktor.client.statement.HttpResponse
import io.ktor.client.statement.readText
import io.ktor.http.HttpStatusCode
import io.ktor.client.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.http.HttpHeaders.Authorization
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
Expand All @@ -56,6 +55,7 @@ public operator fun DefaultGateway.Companion.invoke(
}

private val logger = KotlinLogging.logger { }
private val gatewayInfoJson = Json { ignoreUnknownKeys = true }

public class KordBuilder(public val token: String) {
private var shardsBuilder: (recommended: Int) -> Shards = { Shards(it) }
Expand Down Expand Up @@ -83,7 +83,7 @@ public class KordBuilder(public val token: String) {
* The event flow used by [Kord.eventFlow] to publish [events][Kord.events].
*
*
* By default a [MutableSharedFlow] with an `extraBufferCapacity` of `Int.MAX_VALUE`.
* By default, a [MutableSharedFlow] with an `extraBufferCapacity` of `Int.MAX_VALUE` is used.
*/
public var eventFlow: MutableSharedFlow<Event> = MutableSharedFlow(
extraBufferCapacity = Int.MAX_VALUE
Expand Down Expand Up @@ -184,13 +184,17 @@ public class KordBuilder(public val token: String) {
* Requests the gateway info for the bot, or throws a [KordInitializationException] when something went wrong.
*/
private suspend fun HttpClient.getGatewayInfo(): BotGatewayResponse {
val response = get<HttpResponse>("${Route.baseUrl}${Route.GatewayBotGet.path}")
val response = get<HttpResponse>("${Route.baseUrl}${Route.GatewayBotGet.path}") {
header(Authorization, "Bot $token")
}
val responseBody = response.readText()
if (response.isError) {
val message = buildString {
append("Something went wrong while initializing Kord.")
append("Something went wrong while initializing Kord")
if (response.status == HttpStatusCode.Unauthorized) {
append(", make sure the bot token you entered is valid.")
} else {
append('.')
}

appendLine(responseBody)
Expand All @@ -199,7 +203,7 @@ public class KordBuilder(public val token: String) {
throw KordInitializationException(message)
}

return Json { ignoreUnknownKeys = true }.decodeFromString(BotGatewayResponse.serializer(), responseBody)
return gatewayInfoJson.decodeFromString(BotGatewayResponse.serializer(), responseBody)
}

/**
Expand All @@ -216,11 +220,13 @@ public class KordBuilder(public val token: String) {
logger.warn {
"""
kord's http client is currently using ${client.engine.config.threadsCount} threads,
which is less than the advised threadcount of ${shards.size + 1} (number of shards + 1)""".trimIndent()
which is less than the advised thread count of ${shards.size + 1} (number of shards + 1)
""".trimIndent()
}
}

val resources = ClientResources(token,applicationId ?: getBotIdFromToken(token), shardsInfo, client, defaultStrategy)
val resources =
ClientResources(token, applicationId ?: getBotIdFromToken(token), shardsInfo, client, defaultStrategy)
val rest = RestClient(handlerBuilder(resources))
val cache = KordCacheBuilder().apply { cacheBuilder(resources) }.build()
cache.registerKordData()
Expand Down
2 changes: 1 addition & 1 deletion rest/src/main/kotlin/request/HttpUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ private const val auditLogReason = "X-Audit-Log-Reason"
/**
* Sets the reason that will show up in the [Discord Audit Log]() to [reason] for this request.
*/
fun <T> RequestBuilder<T>.auditLogReason(reason: String?) = reason?.let { header(auditLogReason, reason) }
fun <T> RequestBuilder<T>.auditLogReason(reason: String?) = reason?.let { urlEncodedHeader(auditLogReason, reason) }

val HttpResponse.channelResetPoint: Instant
get() {
Expand Down
17 changes: 16 additions & 1 deletion rest/src/main/kotlin/request/RequestBuilder.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,22 @@ class RequestBuilder<T>(private val route: Route<T>, keySize: Int = 2) {

fun parameter(key: String, value: Any) = parameters.append(key, value.toString())

fun header(key: String, value: String) = headers.append(key, value.encodeURLQueryComponent())
@Deprecated(
"'header' was renamed to 'urlEncodedHeader'",
ReplaceWith("urlEncodedHeader(key, value)"),
DeprecationLevel.ERROR,
)
fun header(key: String, value: String) = urlEncodedHeader(key, value)

/** Adds a header and encodes its [value] as an [URL query component][encodeURLQueryComponent]. */
fun urlEncodedHeader(key: String, value: String) {
headers.append(key, value.encodeURLQueryComponent())
}

/** Adds a header without encoding its [value]. */
fun unencodedHeader(key: String, value: String) {
headers.append(key, value)
}

fun file(name: String, input: java.io.InputStream) {
files.add(NamedFile(name, input))
Expand Down
9 changes: 5 additions & 4 deletions rest/src/main/kotlin/service/ChannelService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class ChannelService(requestHandler: RequestHandler) : RestService(requestHandle
suspend fun bulkDelete(channelId: Snowflake, messages: BulkDeleteRequest, reason: String?) = call(Route.BulkMessageDeletePost) {
keys[Route.ChannelId] = channelId
body(BulkDeleteRequest.serializer(), messages)
auditLogReason(reason)
}

suspend fun deleteChannel(channelId: Snowflake, reason: String? = null) = call(Route.ChannelDelete) {
Expand Down Expand Up @@ -243,7 +244,7 @@ class ChannelService(requestHandler: RequestHandler) : RestService(requestHandle
call(Route.ChannelPut) {
keys[Route.ChannelId] = channelId
body(ChannelModifyPutRequest.serializer(), channel)
reason?.let { header("X-Audit-Log-Reason", reason) }
auditLogReason(reason)
}

suspend fun patchChannel(channelId: Snowflake, channel: ChannelModifyPatchRequest, reason: String? = null) =
Expand All @@ -256,7 +257,7 @@ class ChannelService(requestHandler: RequestHandler) : RestService(requestHandle
call(Route.ChannelPatch) {
keys[Route.ChannelId] = threadId
body(ChannelModifyPatchRequest.serializer(), thread)
reason?.let { header("X-Audit-Log-Reason", reason) }
auditLogReason(reason)
}


Expand All @@ -282,7 +283,7 @@ class ChannelService(requestHandler: RequestHandler) : RestService(requestHandle
keys[Route.ChannelId] = channelId
keys[Route.MessageId] = messageId
body(StartThreadRequest.serializer(), request)
reason?.let { header("X-Audit-Log-Reason", reason) }
auditLogReason(reason)
}
}

Expand All @@ -306,7 +307,7 @@ class ChannelService(requestHandler: RequestHandler) : RestService(requestHandle
return call(Route.StartThreadPost) {
keys[Route.ChannelId] = channelId
body(StartThreadRequest.serializer(), request)
reason?.let { header("X-Audit-Log-Reason", reason) }
auditLogReason(reason)
}
}

Expand Down
5 changes: 2 additions & 3 deletions rest/src/main/kotlin/service/RestService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package dev.kord.rest.service
import dev.kord.rest.request.RequestBuilder
import dev.kord.rest.request.RequestHandler
import dev.kord.rest.route.Route
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpHeaders.Authorization
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract

Expand All @@ -18,11 +18,10 @@ abstract class RestService(@PublishedApi internal val requestHandler: RequestHan
.apply(builder)
.apply {
if (route.requiresAuthorization) {
header(HttpHeaders.Authorization, "Bot ${requestHandler.token}")
unencodedHeader(Authorization, "Bot ${requestHandler.token}")
}
}
.build()
return requestHandler.handle(request)
}

}

0 comments on commit 3863444

Please sign in to comment.