Skip to content

Commit

Permalink
device: use atomic access for unlocked keypair.next
Browse files Browse the repository at this point in the history
This code was attempting to use the "compare racily, then lock
and compare again" idiom to try and reduce lock contention.
However, that idiom is not safe to use unless the comparison
uses atomic operations, which this does not.

Reported-by: David Anderson <[email protected]>
Signed-off-by: Jason A. Donenfeld <[email protected]>
  • Loading branch information
zx2c4 committed May 2, 2020
1 parent fdba6c1 commit 8a3c04a
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
10 changes: 10 additions & 0 deletions device/keypair.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ package device
import (
"crypto/cipher"
"sync"
"sync/atomic"
"time"
"unsafe"

"golang.zx2c4.com/wireguard/replay"
)
Expand Down Expand Up @@ -38,6 +40,14 @@ type Keypairs struct {
next *Keypair
}

func (kp *Keypairs) storeNext(next *Keypair) {
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
}

func (kp *Keypairs) loadNext() *Keypair {
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
}

func (kp *Keypairs) Current() *Keypair {
kp.RLock()
defer kp.RUnlock()
Expand Down
16 changes: 9 additions & 7 deletions device/noise-protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"

"golang.zx2c4.com/wireguard/tai64n"
)

Expand Down Expand Up @@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock()

previous := keypairs.previous
next := keypairs.next
next := keypairs.loadNext()
current := keypairs.current

if isInitiator {
if next != nil {
keypairs.next = nil
keypairs.storeNext(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
Expand All @@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
keypairs.next = keypair
keypairs.storeNext(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
Expand All @@ -608,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {

func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
if keypairs.next != receivedKeypair {

if keypairs.loadNext() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
if keypairs.next != receivedKeypair {
if keypairs.loadNext() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
keypairs.current = keypairs.next
keypairs.next = nil
keypairs.current = keypairs.loadNext()
keypairs.storeNext(nil)
return true
}
2 changes: 1 addition & 1 deletion device/noise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err)
}

key1 := peer1.keypairs.next
key1 := peer1.keypairs.loadNext()
key2 := peer2.keypairs.current

// encrypting / decryption test
Expand Down
6 changes: 3 additions & 3 deletions device/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.next)
device.DeleteKeypair(keypairs.loadNext())
keypairs.previous = nil
keypairs.current = nil
keypairs.next = nil
keypairs.storeNext(nil)
keypairs.Unlock()

// clear handshake state
Expand Down Expand Up @@ -254,7 +254,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs.current.sendNonce = RejectAfterMessages
}
if keypairs.next != nil {
keypairs.next.sendNonce = RejectAfterMessages
keypairs.loadNext().sendNonce = RejectAfterMessages
}
keypairs.Unlock()
}
Expand Down

0 comments on commit 8a3c04a

Please sign in to comment.