Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't encode Authorization header #486

Merged
merged 2 commits into from
Jan 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions core/src/main/kotlin/builder/kord/KordBuilder.kt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

package dev.kord.core.builder.kord

import dev.kord.cache.api.DataCache
Expand Down Expand Up @@ -26,11 +25,11 @@ import dev.kord.rest.request.RequestHandler
import dev.kord.rest.request.isError
import dev.kord.rest.route.Route
import dev.kord.rest.service.RestClient
import io.ktor.client.HttpClient
import io.ktor.client.request.get
import io.ktor.client.statement.HttpResponse
import io.ktor.client.statement.readText
import io.ktor.http.HttpStatusCode
import io.ktor.client.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.http.HttpHeaders.Authorization
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
Expand All @@ -56,6 +55,7 @@ public operator fun DefaultGateway.Companion.invoke(
}

private val logger = KotlinLogging.logger { }
private val gatewayInfoJson = Json { ignoreUnknownKeys = true }

public class KordBuilder(public val token: String) {
private var shardsBuilder: (recommended: Int) -> Shards = { Shards(it) }
Expand Down Expand Up @@ -83,7 +83,7 @@ public class KordBuilder(public val token: String) {
* The event flow used by [Kord.eventFlow] to publish [events][Kord.events].
*
*
* By default a [MutableSharedFlow] with an `extraBufferCapacity` of `Int.MAX_VALUE`.
* By default, a [MutableSharedFlow] with an `extraBufferCapacity` of `Int.MAX_VALUE` is used.
*/
public var eventFlow: MutableSharedFlow<Event> = MutableSharedFlow(
extraBufferCapacity = Int.MAX_VALUE
Expand Down Expand Up @@ -184,13 +184,17 @@ public class KordBuilder(public val token: String) {
* Requests the gateway info for the bot, or throws a [KordInitializationException] when something went wrong.
*/
private suspend fun HttpClient.getGatewayInfo(): BotGatewayResponse {
val response = get<HttpResponse>("${Route.baseUrl}${Route.GatewayBotGet.path}")
val response = get<HttpResponse>("${Route.baseUrl}${Route.GatewayBotGet.path}") {
header(Authorization, "Bot $token")
}
val responseBody = response.readText()
if (response.isError) {
val message = buildString {
append("Something went wrong while initializing Kord.")
append("Something went wrong while initializing Kord")
if (response.status == HttpStatusCode.Unauthorized) {
append(", make sure the bot token you entered is valid.")
} else {
append('.')
}

appendLine(responseBody)
Expand All @@ -199,7 +203,7 @@ public class KordBuilder(public val token: String) {
throw KordInitializationException(message)
}

return Json { ignoreUnknownKeys = true }.decodeFromString(BotGatewayResponse.serializer(), responseBody)
return gatewayInfoJson.decodeFromString(BotGatewayResponse.serializer(), responseBody)
}

/**
Expand All @@ -216,11 +220,13 @@ public class KordBuilder(public val token: String) {
logger.warn {
"""
kord's http client is currently using ${client.engine.config.threadsCount} threads,
which is less than the advised threadcount of ${shards.size + 1} (number of shards + 1)""".trimIndent()
which is less than the advised thread count of ${shards.size + 1} (number of shards + 1)
""".trimIndent()
}
}

val resources = ClientResources(token,applicationId ?: getBotIdFromToken(token), shardsInfo, client, defaultStrategy)
val resources =
ClientResources(token, applicationId ?: getBotIdFromToken(token), shardsInfo, client, defaultStrategy)
val rest = RestClient(handlerBuilder(resources))
val cache = KordCacheBuilder().apply { cacheBuilder(resources) }.build()
cache.registerKordData()
Expand Down
2 changes: 1 addition & 1 deletion rest/src/main/kotlin/request/HttpUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ private const val auditLogReason = "X-Audit-Log-Reason"
/**
* Sets the reason that will show up in the [Discord Audit Log]() to [reason] for this request.
*/
fun <T> RequestBuilder<T>.auditLogReason(reason: String?) = reason?.let { header(auditLogReason, reason) }
fun <T> RequestBuilder<T>.auditLogReason(reason: String?) = reason?.let { urlEncodedHeader(auditLogReason, reason) }

val HttpResponse.channelResetPoint: Instant
get() {
Expand Down
17 changes: 16 additions & 1 deletion rest/src/main/kotlin/request/RequestBuilder.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,22 @@ class RequestBuilder<T>(private val route: Route<T>, keySize: Int = 2) {

fun parameter(key: String, value: Any) = parameters.append(key, value.toString())

fun header(key: String, value: String) = headers.append(key, value.encodeURLQueryComponent())
@Deprecated(
"'header' was renamed to 'urlEncodedHeader'",
ReplaceWith("urlEncodedHeader(key, value)"),
DeprecationLevel.ERROR,
)
fun header(key: String, value: String) = urlEncodedHeader(key, value)

/** Adds a header and encodes its [value] as an [URL query component][encodeURLQueryComponent]. */
fun urlEncodedHeader(key: String, value: String) {
headers.append(key, value.encodeURLQueryComponent())
}

/** Adds a header without encoding its [value]. */
fun unencodedHeader(key: String, value: String) {
headers.append(key, value)
}

fun file(name: String, input: java.io.InputStream) {
files.add(NamedFile(name, input))
Expand Down
9 changes: 5 additions & 4 deletions rest/src/main/kotlin/service/ChannelService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class ChannelService(requestHandler: RequestHandler) : RestService(requestHandle
suspend fun bulkDelete(channelId: Snowflake, messages: BulkDeleteRequest, reason: String?) = call(Route.BulkMessageDeletePost) {
keys[Route.ChannelId] = channelId
body(BulkDeleteRequest.serializer(), messages)
auditLogReason(reason)
}

suspend fun deleteChannel(channelId: Snowflake, reason: String? = null) = call(Route.ChannelDelete) {
Expand Down Expand Up @@ -243,7 +244,7 @@ class ChannelService(requestHandler: RequestHandler) : RestService(requestHandle
call(Route.ChannelPut) {
keys[Route.ChannelId] = channelId
body(ChannelModifyPutRequest.serializer(), channel)
reason?.let { header("X-Audit-Log-Reason", reason) }
auditLogReason(reason)
}

suspend fun patchChannel(channelId: Snowflake, channel: ChannelModifyPatchRequest, reason: String? = null) =
Expand All @@ -256,7 +257,7 @@ class ChannelService(requestHandler: RequestHandler) : RestService(requestHandle
call(Route.ChannelPatch) {
keys[Route.ChannelId] = threadId
body(ChannelModifyPatchRequest.serializer(), thread)
reason?.let { header("X-Audit-Log-Reason", reason) }
auditLogReason(reason)
}


Expand All @@ -282,7 +283,7 @@ class ChannelService(requestHandler: RequestHandler) : RestService(requestHandle
keys[Route.ChannelId] = channelId
keys[Route.MessageId] = messageId
body(StartThreadRequest.serializer(), request)
reason?.let { header("X-Audit-Log-Reason", reason) }
auditLogReason(reason)
}
}

Expand All @@ -306,7 +307,7 @@ class ChannelService(requestHandler: RequestHandler) : RestService(requestHandle
return call(Route.StartThreadPost) {
keys[Route.ChannelId] = channelId
body(StartThreadRequest.serializer(), request)
reason?.let { header("X-Audit-Log-Reason", reason) }
auditLogReason(reason)
}
}

Expand Down
5 changes: 2 additions & 3 deletions rest/src/main/kotlin/service/RestService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package dev.kord.rest.service
import dev.kord.rest.request.RequestBuilder
import dev.kord.rest.request.RequestHandler
import dev.kord.rest.route.Route
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpHeaders.Authorization
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract

Expand All @@ -18,11 +18,10 @@ abstract class RestService(@PublishedApi internal val requestHandler: RequestHan
.apply(builder)
.apply {
if (route.requiresAuthorization) {
header(HttpHeaders.Authorization, "Bot ${requestHandler.token}")
unencodedHeader(Authorization, "Bot ${requestHandler.token}")
}
}
.build()
return requestHandler.handle(request)
}

}