diff --git a/src/handshake-ik.ts b/src/handshake-ik.ts index 7a9d64b..bf3851f 100644 --- a/src/handshake-ik.ts +++ b/src/handshake-ik.ts @@ -73,7 +73,7 @@ export class IKHandshake implements IHandshake { throw new Error('ik handshake stage 0 decryption validation fail') } logger('IK Stage 0 - Responder got message, going to verify payload.') - const decodedPayload = await decodePayload(plaintext) + const decodedPayload = decodePayload(plaintext) this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload) await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer) this.setRemoteEarlyData(decodedPayload.data) @@ -99,7 +99,7 @@ export class IKHandshake implements IHandshake { if (!valid) { throw new Error('ik stage 1 decryption validation fail') } - const decodedPayload = await decodePayload(plaintext) + const decodedPayload = decodePayload(plaintext) this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload) await verifySignedPayload(receivedMessageBuffer.ns.slice(0, 32), decodedPayload, this.remotePeer) this.setRemoteEarlyData(decodedPayload.data) diff --git a/src/handshake-xx-fallback.ts b/src/handshake-xx-fallback.ts index bb9842a..9e00d58 100644 --- a/src/handshake-xx-fallback.ts +++ b/src/handshake-xx-fallback.ts @@ -69,7 +69,7 @@ export class XXFallbackHandshake extends XXHandshake { logger("Initiator going to check remote's signature...") try { - const decodedPayload = await decodePayload(plaintext) + const decodedPayload = decodePayload(plaintext) this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload) await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer) this.setRemoteEarlyData(decodedPayload.data) diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts index aa58942..cec4b3a 100644 --- a/src/handshake-xx.ts +++ b/src/handshake-xx.ts @@ -94,9 +94,9 @@ export class XXHandshake implements IHandshake { logger("Initiator going to check remote's signature...") try { - const decodedPayload = await decodePayload(plaintext) + const decodedPayload = decodePayload(plaintext) this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload) - this.remotePeer = await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer) + await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer) this.setRemoteEarlyData(decodedPayload.data) } catch (e) { const err = e as Error @@ -129,7 +129,7 @@ export class XXHandshake implements IHandshake { logger('Stage 2 - Responder received the message, finished handshake.') try { - const decodedPayload = await decodePayload(plaintext) + const decodedPayload = decodePayload(plaintext) this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload) await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer) this.setRemoteEarlyData(decodedPayload.data) diff --git a/src/handshakes/ik.ts b/src/handshakes/ik.ts index 3eef984..2ec938f 100644 --- a/src/handshakes/ik.ts +++ b/src/handshakes/ik.ts @@ -108,7 +108,7 @@ export class IK extends AbstractHandshake { this.mixHash(hs.ss, hs.re) this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re)) const { plaintext: ns, valid: valid1 } = this.decryptAndHash(hs.ss, message.ns) - if (valid1 && ns.length === 32 && isValidPublicKey(ns)) { + if (valid1 && isValidPublicKey(ns)) { hs.rs = ns } this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.rs)) diff --git a/src/handshakes/xx.ts b/src/handshakes/xx.ts index eebad4b..3499d4c 100644 --- a/src/handshakes/xx.ts +++ b/src/handshakes/xx.ts @@ -87,7 +87,7 @@ export class XX extends AbstractHandshake { } this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)) const { plaintext: ns, valid: valid1 } = this.decryptAndHash(hs.ss, message.ns) - if (valid1 && ns.length === 32 && isValidPublicKey(ns)) { + if (valid1 && isValidPublicKey(ns)) { hs.rs = ns } this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)) @@ -97,7 +97,7 @@ export class XX extends AbstractHandshake { private readMessageC (hs: HandshakeState, message: MessageBuffer): {h: bytes, plaintext: bytes, valid: boolean, cs1: CipherState, cs2: CipherState} { const { plaintext: ns, valid: valid1 } = this.decryptAndHash(hs.ss, message.ns) - if (valid1 && ns.length === 32 && isValidPublicKey(ns)) { + if (valid1 && isValidPublicKey(ns)) { hs.rs = ns } if (!hs.e) { diff --git a/src/utils.ts b/src/utils.ts index 7dd7fd8..cc9ff33 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -62,11 +62,6 @@ export function getHandshakePayload (publicKey: bytes): bytes { return uint8ArrayConcat([prefix, publicKey], prefix.length + publicKey.length) } -async function isValidPeerId (peerId: PeerId, publicKeyProtobuf: bytes): Promise { - const generatedPeerId = await peerIdFromKeys(publicKeyProtobuf) - return generatedPeerId.equals(peerId) -} - /** * Verifies signed payload, throws on any irregularities. * @@ -80,15 +75,14 @@ export async function verifySignedPayload ( payload: pb.NoiseHandshakePayload, remotePeer: PeerId ): Promise { - const identityKey = payload.identityKey - if (!(await isValidPeerId(remotePeer, identityKey))) { + // Unmarshaling from PublicKey protobuf + const payloadPeerId = await peerIdFromKeys(payload.identityKey) + if (!payloadPeerId.equals(remotePeer)) { throw new Error("Peer ID doesn't match libp2p public key.") } const generatedPayload = getHandshakePayload(noiseStaticKey) - // Unmarshaling from PublicKey protobuf - const peerId = await peerIdFromKeys(identityKey) - if (peerId.publicKey == null) { + if (payloadPeerId.publicKey == null) { throw new Error('PublicKey was missing from PeerId') } @@ -96,7 +90,7 @@ export async function verifySignedPayload ( throw new Error('Signature was missing from message') } - const publicKey = unmarshalPublicKey(peerId.publicKey) + const publicKey = unmarshalPublicKey(payloadPeerId.publicKey) const valid = await publicKey.verify(generatedPayload, payload.identitySig) @@ -104,7 +98,7 @@ export async function verifySignedPayload ( throw new Error("Static key doesn't match to peer that signed payload!") } - return peerId + return payloadPeerId } export function isValidPublicKey (pk: bytes): boolean {