Skip to content

Commit

Permalink
all: use Go 1.19 and its atomic types
Browse files Browse the repository at this point in the history
Signed-off-by: Brad Fitzpatrick <[email protected]>
Signed-off-by: Jason A. Donenfeld <[email protected]>
  • Loading branch information
bradfitz authored and zx2c4 committed Sep 4, 2022
1 parent d1d0842 commit 3a0dfef
Show file tree
Hide file tree
Showing 20 changed files with 156 additions and 246 deletions.
32 changes: 16 additions & 16 deletions conn/bind_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ type afWinRingBind struct {
type WinRingBind struct {
v4, v6 afWinRingBind
mu sync.RWMutex
isOpen uint32
isOpen atomic.Uint32 // 0, 1, or 2
}

func NewDefaultBind() Bind { return NewWinRingBind() }
Expand Down Expand Up @@ -212,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() {
}

func (bind *WinRingBind) closeAndZero() {
atomic.StoreUint32(&bind.isOpen, 0)
bind.isOpen.Store(0)
bind.v4.CloseAndZero()
bind.v6.CloseAndZero()
}
Expand Down Expand Up @@ -276,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
bind.closeAndZero()
}
}()
if atomic.LoadUint32(&bind.isOpen) != 0 {
if bind.isOpen.Load() != 0 {
return nil, 0, ErrBindAlreadyOpen
}
var sa windows.Sockaddr
Expand All @@ -299,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
return nil, 0, err
}
}
atomic.StoreUint32(&bind.isOpen, 1)
bind.isOpen.Store(1)
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
}

func (bind *WinRingBind) Close() error {
bind.mu.RLock()
if atomic.LoadUint32(&bind.isOpen) != 1 {
if bind.isOpen.Load() != 1 {
bind.mu.RUnlock()
return nil
}
atomic.StoreUint32(&bind.isOpen, 2)
bind.isOpen.Store(2)
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
Expand Down Expand Up @@ -345,8 +345,8 @@ func (bind *afWinRingBind) InsertReceiveRequest() error {
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
if atomic.LoadUint32(isOpen) != 1 {
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
bind.rx.mu.Lock()
Expand All @@ -359,7 +359,7 @@ retry:
count = 0
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
if tries > 0 {
if atomic.LoadUint32(isOpen) != 1 {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
procyield(1)
Expand All @@ -378,7 +378,7 @@ retry:
if err != nil {
return 0, nil, err
}
if atomic.LoadUint32(isOpen) != 1 {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
Expand All @@ -395,7 +395,7 @@ retry:
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
// attacker bandwidth, just like the rest of the receive path.
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
if atomic.LoadUint32(isOpen) != 1 {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
goto retry
Expand All @@ -421,8 +421,8 @@ func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
return bind.v6.Receive(buf, &bind.isOpen)
}

func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
if atomic.LoadUint32(isOpen) != 1 {
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
if isOpen.Load() != 1 {
return net.ErrClosed
}
if len(buf) > bytesPerPacket {
Expand All @@ -444,7 +444,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3
if err != nil {
return err
}
if atomic.LoadUint32(isOpen) != 1 {
if isOpen.Load() != 1 {
return net.ErrClosed
}
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
Expand Down Expand Up @@ -538,7 +538,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if atomic.LoadUint32(&bind.isOpen) != 1 {
if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
Expand All @@ -552,7 +552,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if atomic.LoadUint32(&bind.isOpen) != 1 {
if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
Expand Down
24 changes: 0 additions & 24 deletions device/alignment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,6 @@ func checkAlignment(t *testing.T, name string, offset uintptr) {
}
}

// TestPeerAlignment checks that atomically-accessed fields are
// aligned to 64-bit boundaries, as required by the atomic package.
//
// Unfortunately, violating this rule on 32-bit platforms results in a
// hard segfault at runtime.
func TestPeerAlignment(t *testing.T) {
var p Peer

typ := reflect.TypeOf(&p).Elem()
t.Logf("Peer type size: %d, with fields:", typ.Size())
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
field.Name,
field.Offset,
field.Type.Size(),
field.Type.Align(),
)
}

checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
}

// TestDeviceAlignment checks that atomically-accessed fields are
// aligned to 64-bit boundaries, as required by the atomic package.
//
Expand Down
31 changes: 15 additions & 16 deletions device/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type Device struct {
// will become the actual state; Up can fail.
// The device can also change state multiple times between time of check and time of use.
// Unsynchronized uses of state must therefore be advisory/best-effort only.
state uint32 // actually a deviceState, but typed uint32 for convenience
state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
// stopping blocks until all inputs to Device have been closed.
stopping sync.WaitGroup
// mu protects state changes.
Expand Down Expand Up @@ -60,7 +60,7 @@ type Device struct {

// Keep this 8-byte aligned
rate struct {
underLoadUntil int64
underLoadUntil atomic.Int64
limiter ratelimiter.Ratelimiter
}

Expand All @@ -82,7 +82,7 @@ type Device struct {

tun struct {
device tun.Device
mtu int32
mtu atomic.Int32
}

ipcMutex sync.RWMutex
Expand All @@ -94,10 +94,9 @@ type Device struct {
// There are three states: down, up, closed.
// Transitions:
//
// down -----+
// ↑↓ ↓
// up -> closed
//
// down -----+
// ↑↓ ↓
// up -> closed
type deviceState uint32

//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
Expand All @@ -110,7 +109,7 @@ const (
// deviceState returns device.state.state as a deviceState
// See those docs for how to interpret this value.
func (device *Device) deviceState() deviceState {
return deviceState(atomic.LoadUint32(&device.state.state))
return deviceState(device.state.state.Load())
}

// isClosed reports whether the device is closed (or is closing).
Expand Down Expand Up @@ -149,14 +148,14 @@ func (device *Device) changeState(want deviceState) (err error) {
case old:
return nil
case deviceStateUp:
atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
device.state.state.Store(uint32(deviceStateUp))
err = device.upLocked()
if err == nil {
break
}
fallthrough // up failed; bring the device all the way back down
case deviceStateDown:
atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
device.state.state.Store(uint32(deviceStateDown))
errDown := device.downLocked()
if err == nil {
err = errDown
Expand All @@ -182,7 +181,7 @@ func (device *Device) upLocked() error {
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Start()
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive()
}
}
Expand Down Expand Up @@ -219,11 +218,11 @@ func (device *Device) IsUnderLoad() bool {
now := time.Now()
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
if underLoad {
atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
return true
}
// check if recently under load
return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
return device.rate.underLoadUntil.Load() > now.UnixNano()
}

func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
Expand Down Expand Up @@ -283,7 +282,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {

func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device := new(Device)
device.state.state = uint32(deviceStateDown)
device.state.state.Store(uint32(deviceStateDown))
device.closed = make(chan struct{})
device.log = logger
device.net.bind = bind
Expand All @@ -293,7 +292,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
mtu = DefaultMTU
}
device.tun.mtu = int32(mtu)
device.tun.mtu.Store(int32(mtu))
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init()
device.indexTable.Init()
Expand Down Expand Up @@ -359,7 +358,7 @@ func (device *Device) Close() {
if device.isClosed() {
return
}
atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
device.state.state.Store(uint32(deviceStateClosed))
device.log.Verbosef("Device closing")

device.tun.device.Close()
Expand Down
6 changes: 3 additions & 3 deletions device/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ func BenchmarkThroughput(b *testing.B) {

// Measure how long it takes to receive b.N packets,
// starting when we receive the first packet.
var recv uint64
var recv atomic.Uint64
var elapsed time.Duration
var wg sync.WaitGroup
wg.Add(1)
Expand All @@ -342,7 +342,7 @@ func BenchmarkThroughput(b *testing.B) {
var start time.Time
for {
<-pair[0].tun.Inbound
new := atomic.AddUint64(&recv, 1)
new := recv.Add(1)
if new == 1 {
start = time.Now()
}
Expand All @@ -358,7 +358,7 @@ func BenchmarkThroughput(b *testing.B) {
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
pingc := pair[1].tun.Outbound
var sent uint64
for atomic.LoadUint64(&recv) != uint64(b.N) {
for recv.Load() != uint64(b.N) {
sent++
pingc <- ping
}
Expand Down
13 changes: 2 additions & 11 deletions device/keypair.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"sync"
"sync/atomic"
"time"
"unsafe"

"golang.zx2c4.com/wireguard/replay"
)
Expand All @@ -23,7 +22,7 @@ import (
*/

type Keypair struct {
sendNonce uint64 // accessed atomically
sendNonce atomic.Uint64
send cipher.AEAD
receive cipher.AEAD
replayFilter replay.Filter
Expand All @@ -37,15 +36,7 @@ type Keypairs struct {
sync.RWMutex
current *Keypair
previous *Keypair
next *Keypair
}

func (kp *Keypairs) storeNext(next *Keypair) {
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
}

func (kp *Keypairs) loadNext() *Keypair {
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
next atomic.Pointer[Keypair]
}

func (kp *Keypairs) Current() *Keypair {
Expand Down
41 changes: 0 additions & 41 deletions device/misc.go

This file was deleted.

Loading

0 comments on commit 3a0dfef

Please sign in to comment.