diff --git a/settings.gradle.kts b/settings.gradle.kts index c30874ebfbb5..afa0842bfcb3 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -78,6 +78,7 @@ fun VersionCatalogBuilder.common() { fun VersionCatalogBuilder.misc() { alias("codahale-xsalsa20poly1305").to("com.codahale", "xsalsa20poly1305").version("0.10.1") + alias("okio-io").to("com.squareup.okio", "okio").version("3.0.0-alpha.10") } fun VersionCatalogBuilder.tests() { diff --git a/voice/build.gradle.kts b/voice/build.gradle.kts index f3c9ecaf4d26..7bc3e4bd4d2c 100644 --- a/voice/build.gradle.kts +++ b/voice/build.gradle.kts @@ -7,6 +7,7 @@ dependencies { api(common) api(gateway) + implementation(libs.okio.io) implementation(libs.codahale.xsalsa20poly1305) api(libs.ktor.client.json) api(libs.ktor.client.serialization) diff --git a/voice/src/main/kotlin/rtp/AudioPacket.kt b/voice/src/main/kotlin/rtp/AudioPacket.kt new file mode 100644 index 000000000000..f5c11b8334f4 --- /dev/null +++ b/voice/src/main/kotlin/rtp/AudioPacket.kt @@ -0,0 +1,84 @@ +package dev.kord.voice.rtp + +import dev.kord.common.annotation.KordVoice +import dev.kord.voice.XSalsa20Poly1305Codec +import io.ktor.utils.io.core.* +import okio.Buffer + +/** + * A guesstimated list of known Discord RTP payloads. 
+ */ +sealed class PayloadType(val value: Byte) { + object Alive: PayloadType(0x37.toByte()) + object Audio: PayloadType(0x78.toByte()) + class Unknown(value: Byte): PayloadType(value) + + companion object { + fun from(value: Byte) = when(value) { + 0x37.toByte() -> Alive + 0x78.toByte() -> Audio + else -> Unknown(value) + } + } +} + +@OptIn(ExperimentalUnsignedTypes::class) +@KordVoice +sealed class AudioPacket private constructor(internal val packet: RTPPacket) { + val sequence: UShort = packet.sequence + val timestamp: UInt = packet.timestamp + val ssrc: UInt = packet.ssrc + val payload: ByteArray = packet.payload + val payloadType: PayloadType = PayloadType.from(packet.payloadType) + + class EncryptedPacket internal constructor( + packet: RTPPacket, + ) : AudioPacket(packet) { + fun decrypt(key: ByteArray): DecryptedPacket { + val nonce = ByteArray(NONCE_LENGTH) + packet.asByteArray().copyInto(nonce, 0, 0, RTP_HEADER_LENGTH) + val decrypted = XSalsa20Poly1305Codec.decrypt(payload, key, nonce) ?: error("fuck me") + return DecryptedPacket(packet.copy(payload = decrypted)) + } + } + + fun asByteReadPacket() = ByteReadPacket(packet.asByteArray()) + + class DecryptedPacket internal constructor( + packet: RTPPacket, + ) : AudioPacket(packet) { + fun encrypt(key: ByteArray): EncryptedPacket { + val nonce = ByteArray(NONCE_LENGTH) + packet.asByteArray().copyInto(nonce, 0, 0, RTP_HEADER_LENGTH) + val encrypted = XSalsa20Poly1305Codec.encrypt(payload, key, nonce) + return EncryptedPacket(packet.copy(payload = encrypted)) + } + + @PublishedApi + internal companion object { + fun create( + sequence: UShort, + timestamp: UInt, + ssrc: UInt, + decryptedData: ByteArray + ): DecryptedPacket = DecryptedPacket( + RTPPacket(ssrc, timestamp, sequence, PayloadType.Audio.value, decryptedData) + ) + } + } + + companion object { + private const val NONCE_LENGTH = 24 + private const val RTP_HEADER_LENGTH = 12 + + @OptIn(ExperimentalStdlibApi::class) + fun encryptedFrom(data: 
ByteReadPacket): EncryptedPacket? { + return try { + val packet = RTPPacket.fromBytes(Buffer().write(data.copy().readBytes())) + EncryptedPacket(packet) + } catch(e: Exception) { + null + } + } + } +} \ No newline at end of file diff --git a/voice/src/main/kotlin/rtp/RTPPacket.kt b/voice/src/main/kotlin/rtp/RTPPacket.kt new file mode 100644 index 000000000000..5d085fa41953 --- /dev/null +++ b/voice/src/main/kotlin/rtp/RTPPacket.kt @@ -0,0 +1,157 @@ +package dev.kord.voice.rtp + +import okio.Buffer +import kotlin.experimental.and + +/** + * Originally from [this GitHub library](https://github.com/vidtec/rtp-packet/blob/0b54fdeab5666089215b0074c64a6735b8937f8d/src/main/java/org/vidtec/rfc3550/rtp/RTPPacket.java). + * + * Ported to Kotlin. + * + * Contains modifications to support RTP packets from Discord as they do not adhere to the standard RTP spec. + * + * Some changes include: + * - Timestamps are stored as UInt (4 bytes) not long. + * - Sequences are stored as UShort (2 bytes) not long. + * - Payload types are stored as Byte (1 byte) not short. + * - SSRCs are stored as UInt (4 bytes) not long. + * - And most notably, the extension header does not immediately follow the RTP header; Discord instead encrypts it along with the payload. 
+ */ +@Suppress("ArrayInDataClass") +@OptIn(ExperimentalUnsignedTypes::class) +internal data class RTPPacket( + val paddingBytes: UByte, + val payloadType: Byte, + val sequence: UShort, + val timestamp: UInt, + val ssrc: UInt, + val csrcIdentifiers: UIntArray, + val hasMarker: Boolean, + val hasExtension: Boolean, + val payload: ByteArray, +) { + companion object { + const val VERSION = 2 + + fun fromBytes(buffer: Buffer) = with(buffer) { + val initialLength = size + + require(initialLength > 13) { "packet too short" } + + /* + * first byte | bit table + * 0-1 = version + * 2 = padding bit + * 3 = extension bit + * 4-7 = csrc count + */ + val paddingBytes: UByte + val hasExtension: Boolean + val csrcCount: Byte + with(readByte()) { + require(((toInt() and 0xC0) == VERSION shl 6)) { "invalid version" } + + paddingBytes = if (((toInt() and 0x20) == 0x20)) buffer[initialLength - 1].toUByte() else 0u; + hasExtension = (toInt() and 0x10) == 0x10 + csrcCount = toByte() and 0x0F + } + + /* + * second byte | bit table + * 0 = marker + * 1-7 = payload type + */ + val marker: Boolean + val payloadType: Byte + with(readByte()) { + marker = (this and 0x80.toByte()) == 0x80.toByte() + payloadType = this and 0x7F + } + + val sequence = readShort().toUShort() + val timestamp = readInt().toUInt() + val ssrc = readInt().toUInt() + + // each csrc takes up 4 bytes, plus more data is required + require(size > csrcCount * 4 + 1) { "packet too short" } + val csrcIdentifiers = UIntArray(csrcCount.toInt()) { readInt().toUInt() } + + val payload = readByteArray().apply { copyOfRange(0, size - paddingBytes.toInt()) } + + RTPPacket( + paddingBytes, + payloadType, + sequence, + timestamp, + ssrc, + csrcIdentifiers, + marker, + hasExtension, + payload + ) + } + } + + init { + require(payloadType < 127 || payloadType > 0) { "invalid payload type" } + require(payload.isNotEmpty()) { "invalid payload" } + } + + private val header = with(Buffer()) { + val hasPadding = if (paddingBytes > 0u) 0x20 
else 0x00 + val hasExtension = if (hasExtension) 0x10 else 0x0 + writeByte((VERSION shl 6) or (hasPadding) or (hasExtension) or 0) + writeByte(payloadType.toInt()) + writeShort(sequence.toInt()) + writeInt(timestamp.toInt()) + writeInt(ssrc.toInt()) + + readByteArray() + } + + fun asByteArray() = with(Buffer()) { + write(header) + write(payload) + + if (paddingBytes > 0u) { + write(ByteArray(paddingBytes.toInt() - 1) { 0x0 }) + writeByte(paddingBytes.toInt()) + } + + buffer.readByteArray() + } + + class Builder( + var ssrc: UInt, + var timestamp: UInt, + var sequence: UShort, + var payloadType: Byte, + var payload: ByteArray + ) { + var marker: Boolean = false + var paddingBytes: UByte = 0u + var hasExtension: Boolean = false + var csrcIdentifiers: UIntArray = uintArrayOf() + + fun build() = RTPPacket( + paddingBytes, + payloadType, + sequence, + timestamp, + ssrc, + csrcIdentifiers, + marker, + hasExtension, + payload + ) + } +} + +internal fun RTPPacket( + ssrc: UInt, + timestamp: UInt, + sequence: UShort, + payloadType: Byte, + payload: ByteArray, + builder: RTPPacket.Builder.() -> Unit = {} +) = RTPPacket.Builder(ssrc, timestamp, sequence, payloadType, payload).apply(builder).build() \ No newline at end of file diff --git a/voice/src/main/kotlin/streams/DefaultStreams.kt b/voice/src/main/kotlin/streams/DefaultStreams.kt index 48b5445109a4..60e314312316 100644 --- a/voice/src/main/kotlin/streams/DefaultStreams.kt +++ b/voice/src/main/kotlin/streams/DefaultStreams.kt @@ -5,7 +5,8 @@ import dev.kord.common.entity.Snowflake import dev.kord.voice.AudioFrame import dev.kord.voice.gateway.Speaking import dev.kord.voice.gateway.VoiceGateway -import dev.kord.voice.udp.AudioPacket +import dev.kord.voice.rtp.AudioPacket +import dev.kord.voice.rtp.PayloadType import dev.kord.voice.udp.VoiceUdpConnection import kotlinx.atomicfu.AtomicRef import kotlinx.atomicfu.atomic @@ -16,6 +17,7 @@ import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.SupervisorJob import 
kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.* +import okio.Buffer import kotlin.coroutines.CoroutineContext @KordVoice @@ -31,12 +33,31 @@ class DefaultStreams( override val incomingAudioPackets: SharedFlow = udp.incoming + .map { it.copy() } .mapNotNull(AudioPacket::encryptedFrom) - .map { it.decrypt(key!!) } + .filter { it.payloadType == PayloadType.Audio } + .map { it.decrypt(key!!).removeExtensionHeader() } .shareIn(this, SharingStarted.Lazily) + // perhaps we can expose the extension header later + @OptIn(ExperimentalUnsignedTypes::class) + fun AudioPacket.DecryptedPacket.removeExtensionHeader(): AudioPacket.DecryptedPacket { + fun processExtensionHeader(data: ByteArray): ByteArray = with(Buffer()) { + buffer.write(data) + readShort() // profile, ignore it + val countOf32BitWords = readShort().toLong() // amount of extension header "words" + readByteArray((countOf32BitWords * 32)/Byte.SIZE_BITS) // consume extension header + + readByteArray() // consume rest of payload and return it + } + + return if(packet.hasExtension) + AudioPacket.DecryptedPacket(packet.copy(hasExtension = false, payload = processExtensionHeader(packet.payload))) + else this + } + override val incomingAudioFrames: Flow> - get() = incomingAudioPackets.map { it.ssrc to AudioFrame.fromData(it.data)!! } + get() = incomingAudioPackets.map { it.ssrc to AudioFrame.fromData(it.payload)!! 
} private val _incomingUserAudioFrames: MutableSharedFlow> = MutableSharedFlow() diff --git a/voice/src/main/kotlin/streams/NOPStreams.kt b/voice/src/main/kotlin/streams/NOPStreams.kt index e7aea64a864a..f9e5639cbf7d 100644 --- a/voice/src/main/kotlin/streams/NOPStreams.kt +++ b/voice/src/main/kotlin/streams/NOPStreams.kt @@ -3,7 +3,7 @@ package dev.kord.voice.streams import dev.kord.common.annotation.KordVoice import dev.kord.common.entity.Snowflake import dev.kord.voice.AudioFrame -import dev.kord.voice.udp.AudioPacket +import dev.kord.voice.rtp.AudioPacket import kotlinx.coroutines.flow.Flow @KordVoice diff --git a/voice/src/main/kotlin/streams/Streams.kt b/voice/src/main/kotlin/streams/Streams.kt index b365a203c23c..e8f938d66599 100644 --- a/voice/src/main/kotlin/streams/Streams.kt +++ b/voice/src/main/kotlin/streams/Streams.kt @@ -3,7 +3,7 @@ package dev.kord.voice.streams import dev.kord.common.annotation.KordVoice import dev.kord.common.entity.Snowflake import dev.kord.voice.AudioFrame -import dev.kord.voice.udp.AudioPacket +import dev.kord.voice.rtp.AudioPacket import kotlinx.coroutines.flow.Flow diff --git a/voice/src/main/kotlin/udp/AudioPacket.kt b/voice/src/main/kotlin/udp/AudioPacket.kt deleted file mode 100644 index b742e06e4b93..000000000000 --- a/voice/src/main/kotlin/udp/AudioPacket.kt +++ /dev/null @@ -1,80 +0,0 @@ -package dev.kord.voice.udp - -import dev.kord.common.annotation.KordVoice -import dev.kord.voice.XSalsa20Poly1305Codec -import io.ktor.utils.io.core.* - -@OptIn(ExperimentalUnsignedTypes::class) -@KordVoice -sealed class AudioPacket( - val sequence: UShort, - val timestamp: UInt, - val ssrc: UInt, - val data: ByteArray -) { - @OptIn(ExperimentalIoApi::class) - val header = BytePacketBuilder().apply { - writeByte(RTP_TYPE) - writeByte(RTP_VERSION) - writeUShort(sequence) - writeUInt(timestamp) - writeUInt(ssrc) - }.build().readBytes() - - class EncryptedPacket( - sequence: UShort, - timestamp: UInt, - ssrc: UInt, - encryptedData: 
ByteArray - ) : AudioPacket(sequence, timestamp, ssrc, encryptedData) { - fun decrypt(key: ByteArray): DecryptedPacket { - val nonce = ByteArray(NONCE_LENGTH) - header.copyInto(nonce, 0, 0, RTP_HEADER_LENGTH) - val decrypted = XSalsa20Poly1305Codec.decrypt(data, key, nonce) ?: error("couldn't decrypt audio data") - return DecryptedPacket(sequence, timestamp, ssrc, decrypted) - } - } - - fun asByteReadPacket() = BytePacketBuilder().also { - it.writeFully(header) - it.writeFully(data) - }.build() - - class DecryptedPacket( - sequence: UShort, - timestamp: UInt, - ssrc: UInt, - decryptedData: ByteArray - ) : AudioPacket(sequence, timestamp, ssrc, decryptedData) { - fun encrypt(key: ByteArray): EncryptedPacket { - val nonce = ByteArray(NONCE_LENGTH) - header.copyInto(nonce, 0, 0, RTP_HEADER_LENGTH) - val encrypted = XSalsa20Poly1305Codec.encrypt(data, key, nonce) - return EncryptedPacket(sequence, timestamp, ssrc, encrypted) - } - } - - companion object { - private const val RTP_TYPE: Byte = 0x90.toByte() - private const val RTP_VERSION = 0x78.toByte() - private const val NONCE_LENGTH = 24 - private const val RTP_HEADER_LENGTH = 12 - - fun encryptedFrom(data: ByteReadPacket): EncryptedPacket? 
{ - try { - with(data) { - require(readByte() == RTP_TYPE) - require(readByte() == RTP_VERSION) - val sequence = readUShort() - val timestamp = readUInt() - val ssrc = readUInt() - val encryptedData = readBytes() - - return EncryptedPacket(sequence, timestamp, ssrc, encryptedData) - } - } catch (e: Exception) { - return null - } - } - } -} \ No newline at end of file diff --git a/voice/src/main/kotlin/udp/DefaultAudioFrameSender.kt b/voice/src/main/kotlin/udp/DefaultAudioFrameSender.kt index 0357193b7125..ef99c2a03ba8 100644 --- a/voice/src/main/kotlin/udp/DefaultAudioFrameSender.kt +++ b/voice/src/main/kotlin/udp/DefaultAudioFrameSender.kt @@ -3,6 +3,7 @@ package dev.kord.voice.udp import dev.kord.common.annotation.KordVoice import dev.kord.voice.AudioFrame import dev.kord.voice.FrameInterceptor +import dev.kord.voice.rtp.AudioPacket import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.SupervisorJob @@ -51,7 +52,7 @@ class DefaultAudioFrameSender( .transform { emit(interceptor.intercept(it)) } .collect { frame -> if (frame != null) { - val encryptedPacket = AudioPacket.DecryptedPacket( + val encryptedPacket = AudioPacket.DecryptedPacket.create( sequence = sequence, timestamp = sequence * 960u, ssrc = ssrc,