Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make handshake abortable #442

Merged
merged 1 commit into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions src/noise.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { unmarshalPrivateKey } from '@libp2p/crypto/keys'
import { type MultiaddrConnection, type SecuredConnection, type PeerId, CodeError, type PrivateKey, serviceCapabilities, isPeerId } from '@libp2p/interface'
import { type MultiaddrConnection, type SecuredConnection, type PeerId, CodeError, type PrivateKey, serviceCapabilities, isPeerId, type AbortOptions } from '@libp2p/interface'
import { peerIdFromKeys } from '@libp2p/peer-id'
import { decode } from 'it-length-prefixed'
import { lpStream, type LengthPrefixedStream } from 'it-length-prefixed-stream'
Expand Down Expand Up @@ -72,10 +72,10 @@ export class Noise implements INoiseConnection {
* @param connection - streaming iterable duplex that will be encrypted
* @param remotePeer - PeerId of the remote peer. Used to validate the integrity of the remote peer.
*/
public async secureOutbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (connection: Stream, remotePeer?: PeerId): Promise<SecuredConnection<Stream, NoiseExtensions>>
public async secureOutbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (connection: Stream, options?: { remotePeer?: PeerId, signal?: AbortSignal }): Promise<SecuredConnection<Stream, NoiseExtensions>>
public async secureOutbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (localPeer: PeerId, connection: Stream, remotePeer?: PeerId): Promise<SecuredConnection<Stream, NoiseExtensions>>
public async secureOutbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (...args: any[]): Promise<SecuredConnection<Stream, NoiseExtensions>> {
const { localPeer, connection, remotePeer } = this.parseArgs<Stream>(args)
const { localPeer, connection, remotePeer, signal } = this.parseArgs<Stream>(args)

const wrappedConnection = lpStream(
connection,
Expand All @@ -96,7 +96,9 @@ export class Noise implements INoiseConnection {
const handshake = await this.performHandshakeInitiator(
wrappedConnection,
privateKey,
remoteIdentityKey
remoteIdentityKey, {
signal
}
)
const conn = await this.createSecureConnection(wrappedConnection, handshake)

Expand All @@ -117,10 +119,10 @@ export class Noise implements INoiseConnection {
* @param connection - streaming iterable duplex that will be encrypted.
* @param remotePeer - optional PeerId of the initiating peer, if known. This may only exist during transport upgrades.
*/
public async secureInbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (connection: Stream, remotePeer?: PeerId): Promise<SecuredConnection<Stream, NoiseExtensions>>
public async secureInbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (connection: Stream, options?: { remotePeer?: PeerId, signal?: AbortSignal }): Promise<SecuredConnection<Stream, NoiseExtensions>>
public async secureInbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (localPeer: PeerId, connection: Stream, remotePeer?: PeerId): Promise<SecuredConnection<Stream, NoiseExtensions>>
public async secureInbound <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (...args: any[]): Promise<SecuredConnection<Stream, NoiseExtensions>> {
const { localPeer, connection, remotePeer } = this.parseArgs<Stream>(args)
const { localPeer, connection, remotePeer, signal } = this.parseArgs<Stream>(args)

const wrappedConnection = lpStream(
connection,
Expand All @@ -141,7 +143,9 @@ export class Noise implements INoiseConnection {
const handshake = await this.performHandshakeResponder(
wrappedConnection,
privateKey,
remoteIdentityKey
remoteIdentityKey, {
signal
}
)
const conn = await this.createSecureConnection(wrappedConnection, handshake)

Expand All @@ -162,7 +166,8 @@ export class Noise implements INoiseConnection {
connection: LengthPrefixedStream,
// TODO: pass private key in noise constructor via Components
privateKey: PrivateKey,
remoteIdentityKey?: Uint8Array | Uint8ArrayList
remoteIdentityKey?: Uint8Array | Uint8ArrayList,
options?: AbortOptions
): Promise<HandshakeResult> {
let result: HandshakeResult
try {
Expand All @@ -175,7 +180,7 @@ export class Noise implements INoiseConnection {
prologue: this.prologue,
s: this.staticKey,
extensions: this.extensions
})
}, options)
this.metrics?.xxHandshakeSuccesses.increment()
} catch (e: unknown) {
this.metrics?.xxHandshakeErrors.increment()
Expand All @@ -192,7 +197,8 @@ export class Noise implements INoiseConnection {
connection: LengthPrefixedStream,
// TODO: pass private key in noise constructor via Components
privateKey: PrivateKey,
remoteIdentityKey?: Uint8Array | Uint8ArrayList
remoteIdentityKey?: Uint8Array | Uint8ArrayList,
options?: AbortOptions
): Promise<HandshakeResult> {
let result: HandshakeResult
try {
Expand All @@ -205,7 +211,7 @@ export class Noise implements INoiseConnection {
prologue: this.prologue,
s: this.staticKey,
extensions: this.extensions
})
}, options)
this.metrics?.xxHandshakeSuccesses.increment()
} catch (e: unknown) {
this.metrics?.xxHandshakeErrors.increment()
Expand Down Expand Up @@ -241,7 +247,7 @@ export class Noise implements INoiseConnection {
* TODO: remove this after `[email protected]` is released and only support the
* newer style
*/
private parseArgs <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (args: any[]): { localPeer: PeerId, connection: Stream, remotePeer?: PeerId } {
private parseArgs <Stream extends Duplex<AsyncGenerator<Uint8Array | Uint8ArrayList>> = MultiaddrConnection> (args: any[]): { localPeer: PeerId, connection: Stream, remotePeer?: PeerId, signal?: AbortSignal } {
// if the first argument is a peer id, we're using the [email protected] style
if (isPeerId(args[0])) {
return {
Expand All @@ -256,7 +262,8 @@ export class Noise implements INoiseConnection {
return {
localPeer: this.components.peerId,
connection: args[0],
remotePeer: args[1]
remotePeer: args[1]?.remotePeer,
signal: args[1]?.signal
}
}
}
Expand Down
17 changes: 9 additions & 8 deletions src/performHandshake.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import {
import { ZEROLEN, XXHandshakeState } from './protocol.js'
import { createHandshakePayload, decodeHandshakePayload } from './utils.js'
import type { HandshakeResult, HandshakeParams } from './types.js'
import type { AbortOptions } from '@libp2p/interface'

export async function performHandshakeInitiator (init: HandshakeParams): Promise<HandshakeResult> {
export async function performHandshakeInitiator (init: HandshakeParams, options?: AbortOptions): Promise<HandshakeResult> {
const { log, connection, crypto, privateKey, prologue, s, remoteIdentityKey, extensions } = init

const payload = await createHandshakePayload(privateKey, s.publicKey, extensions)
Expand All @@ -23,12 +24,12 @@ export async function performHandshakeInitiator (init: HandshakeParams): Promise

logLocalStaticKeys(xx.s, log)
log.trace('Stage 0 - Initiator starting to send first message.')
await connection.write(xx.writeMessageA(ZEROLEN))
await connection.write(xx.writeMessageA(ZEROLEN), options)
log.trace('Stage 0 - Initiator finished sending first message.')
logLocalEphemeralKeys(xx.e, log)

log.trace('Stage 1 - Initiator waiting to receive first message from responder...')
const plaintext = xx.readMessageB(await connection.read())
const plaintext = xx.readMessageB(await connection.read(options))
log.trace('Stage 1 - Initiator received the message.')
logRemoteEphemeralKey(xx.re, log)
logRemoteStaticKey(xx.rs, log)
Expand All @@ -38,7 +39,7 @@ export async function performHandshakeInitiator (init: HandshakeParams): Promise
log.trace('All good with the signature!')

log.trace('Stage 2 - Initiator sending third handshake message.')
await connection.write(xx.writeMessageC(payload))
await connection.write(xx.writeMessageC(payload), options)
log.trace('Stage 2 - Initiator sent message with signed payload.')

const [cs1, cs2] = xx.ss.split()
Expand All @@ -51,7 +52,7 @@ export async function performHandshakeInitiator (init: HandshakeParams): Promise
}
}

export async function performHandshakeResponder (init: HandshakeParams): Promise<HandshakeResult> {
export async function performHandshakeResponder (init: HandshakeParams, options?: AbortOptions): Promise<HandshakeResult> {
const { log, connection, crypto, privateKey, prologue, s, remoteIdentityKey, extensions } = init

const payload = await createHandshakePayload(privateKey, s.publicKey, extensions)
Expand All @@ -65,17 +66,17 @@ export async function performHandshakeResponder (init: HandshakeParams): Promise

logLocalStaticKeys(xx.s, log)
log.trace('Stage 0 - Responder waiting to receive first message.')
xx.readMessageA(await connection.read())
xx.readMessageA(await connection.read(options))
log.trace('Stage 0 - Responder received first message.')
logRemoteEphemeralKey(xx.re, log)

log.trace('Stage 1 - Responder sending out first message with signed payload and static key.')
await connection.write(xx.writeMessageB(payload))
await connection.write(xx.writeMessageB(payload), options)
log.trace('Stage 1 - Responder sent the second handshake message with signed payload.')
logLocalEphemeralKeys(xx.e, log)

log.trace('Stage 2 - Responder waiting for third handshake message...')
const plaintext = xx.readMessageC(await connection.read())
const plaintext = xx.readMessageC(await connection.read(options))
log.trace('Stage 2 - Responder received the message, finished handshake.')
const receivedPayload = await decodeHandshakePayload(plaintext, xx.rs, remoteIdentityKey)

Expand Down
27 changes: 27 additions & 0 deletions test/noise.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,31 @@ describe('Noise', () => {
assert(false, err.message)
}
})

it('should abort noise handshake', async () => {
const abortController = new AbortController()
abortController.abort()

const noiseInit = new Noise({
peerId: localPeer,
logger: defaultLogger()
}, { staticNoiseKey: undefined, extensions: undefined })
const noiseResp = new Noise({
peerId: remotePeer,
logger: defaultLogger()
}, { staticNoiseKey: undefined, extensions: undefined })

const [inboundConnection, outboundConnection] = duplexPair<Uint8Array | Uint8ArrayList>()

await expect(Promise.all([
noiseInit.secureOutbound(outboundConnection, {
remotePeer,
signal: abortController.signal
}),
noiseResp.secureInbound(inboundConnection, {
remotePeer: localPeer
})
])).to.eventually.be.rejected
.with.property('name', 'AbortError')
})
})
Loading