diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 9268bc15f..c066efa4e 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -74,7 +74,7 @@ type afWinRingBind struct { type WinRingBind struct { v4, v6 afWinRingBind mu sync.RWMutex - isOpen uint32 + isOpen atomic.Uint32 // 0, 1, or 2 } func NewDefaultBind() Bind { return NewWinRingBind() } @@ -212,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() { } func (bind *WinRingBind) closeAndZero() { - atomic.StoreUint32(&bind.isOpen, 0) + bind.isOpen.Store(0) bind.v4.CloseAndZero() bind.v6.CloseAndZero() } @@ -276,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort bind.closeAndZero() } }() - if atomic.LoadUint32(&bind.isOpen) != 0 { + if bind.isOpen.Load() != 0 { return nil, 0, ErrBindAlreadyOpen } var sa windows.Sockaddr @@ -299,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort return nil, 0, err } } - atomic.StoreUint32(&bind.isOpen, 1) + bind.isOpen.Store(1) return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err } func (bind *WinRingBind) Close() error { bind.mu.RLock() - if atomic.LoadUint32(&bind.isOpen) != 1 { + if bind.isOpen.Load() != 1 { bind.mu.RUnlock() return nil } - atomic.StoreUint32(&bind.isOpen, 2) + bind.isOpen.Store(2) windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) @@ -345,8 +345,8 @@ func (bind *afWinRingBind) InsertReceiveRequest() error { //go:linkname procyield runtime.procyield func procyield(cycles uint32) -func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) { - if atomic.LoadUint32(isOpen) != 1 { +func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) { + if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } bind.rx.mu.Lock() @@ -359,7 +359,7 @@ retry: count = 0 for tries := 0; count == 0 && tries < receiveSpins; tries++ { if tries > 0 { - if atomic.LoadUint32(isOpen) != 1 { + if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } procyield(1) @@ -378,7 +378,7 @@ retry: if err != nil { return 0, nil, err } - if atomic.LoadUint32(isOpen) != 1 { + if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } count = winrio.DequeueCompletion(bind.rx.cq, results[:]) @@ -395,7 +395,7 @@ retry: // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to // attacker bandwidth, just like the rest of the receive path. if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { - if atomic.LoadUint32(isOpen) != 1 { + if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } goto retry @@ -421,8 +421,8 @@ func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) { return bind.v6.Receive(buf, &bind.isOpen) } -func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error { - if atomic.LoadUint32(isOpen) != 1 { +func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error { + if isOpen.Load() != 1 { return net.ErrClosed } if len(buf) > bytesPerPacket { @@ -444,7 +444,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3 if err != nil { return err } - if atomic.LoadUint32(isOpen) != 1 { + if isOpen.Load() != 1 { return net.ErrClosed } count = winrio.DequeueCompletion(bind.tx.cq, results[:]) @@ -538,7 +538,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { bind.mu.RLock() defer bind.mu.RUnlock() - if atomic.LoadUint32(&bind.isOpen) != 1 { + if bind.isOpen.Load() != 1 { return net.ErrClosed } err := bindSocketToInterface4(bind.v4.sock, interfaceIndex) @@ -552,7 +552,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { bind.mu.RLock() defer bind.mu.RUnlock() - if atomic.LoadUint32(&bind.isOpen) != 1 { + if bind.isOpen.Load() != 1 { return net.ErrClosed } err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) diff --git a/device/alignment_test.go b/device/alignment_test.go index a918112bb..bd2b02ba9 100644 --- a/device/alignment_test.go +++ b/device/alignment_test.go @@ -18,30 +18,6 @@ func checkAlignment(t *testing.T, name string, offset uintptr) { } } -// TestPeerAlignment checks that atomically-accessed fields are -// aligned to 64-bit boundaries, as required by the atomic package. -// -// Unfortunately, violating this rule on 32-bit platforms results in a -// hard segfault at runtime. -func TestPeerAlignment(t *testing.T) { - var p Peer - - typ := reflect.TypeOf(&p).Elem() - t.Logf("Peer type size: %d, with fields:", typ.Size()) - for i := 0; i < typ.NumField(); i++ { - field := typ.Field(i) - t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", - field.Name, - field.Offset, - field.Type.Size(), - field.Type.Align(), - ) - } - - checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats)) - checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning)) -} - // TestDeviceAlignment checks that atomically-accessed fields are // aligned to 64-bit boundaries, as required by the atomic package. // diff --git a/device/device.go b/device/device.go index 3625608db..c2a2683be 100644 --- a/device/device.go +++ b/device/device.go @@ -30,7 +30,7 @@ type Device struct { // will become the actual state; Up can fail. // The device can also change state multiple times between time of check and time of use. // Unsynchronized uses of state must therefore be advisory/best-effort only. - state uint32 // actually a deviceState, but typed uint32 for convenience + state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience // stopping blocks until all inputs to Device have been closed. stopping sync.WaitGroup // mu protects state changes. @@ -60,7 +60,7 @@ type Device struct { // Keep this 8-byte aligned rate struct { - underLoadUntil int64 + underLoadUntil atomic.Int64 limiter ratelimiter.Ratelimiter } @@ -82,7 +82,7 @@ type Device struct { tun struct { device tun.Device - mtu int32 + mtu atomic.Int32 } ipcMutex sync.RWMutex @@ -94,10 +94,9 @@ type Device struct { // There are three states: down, up, closed. // Transitions: // -// down -----+ -// ↑↓ ↓ -// up -> closed -// +// down -----+ +// ↑↓ ↓ +// up -> closed type deviceState uint32 //go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState @@ -110,7 +109,7 @@ const ( // deviceState returns device.state.state as a deviceState // See those docs for how to interpret this value. func (device *Device) deviceState() deviceState { - return deviceState(atomic.LoadUint32(&device.state.state)) + return deviceState(device.state.state.Load()) } // isClosed reports whether the device is closed (or is closing). @@ -149,14 +148,14 @@ func (device *Device) changeState(want deviceState) (err error) { case old: return nil case deviceStateUp: - atomic.StoreUint32(&device.state.state, uint32(deviceStateUp)) + device.state.state.Store(uint32(deviceStateUp)) err = device.upLocked() if err == nil { break } fallthrough // up failed; bring the device all the way back down case deviceStateDown: - atomic.StoreUint32(&device.state.state, uint32(deviceStateDown)) + device.state.state.Store(uint32(deviceStateDown)) errDown := device.downLocked() if err == nil { err = errDown @@ -182,7 +181,7 @@ func (device *Device) upLocked() error { device.peers.RLock() for _, peer := range device.peers.keyMap { peer.Start() - if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { + if peer.persistentKeepaliveInterval.Load() > 0 { peer.SendKeepalive() } } @@ -219,11 +218,11 @@ func (device *Device) IsUnderLoad() bool { now := time.Now() underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8 if underLoad { - atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano()) + device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano()) return true } // check if recently under load - return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano() + return device.rate.underLoadUntil.Load() > now.UnixNano() } func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { @@ -283,7 +282,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { device := new(Device) - device.state.state = uint32(deviceStateDown) + device.state.state.Store(uint32(deviceStateDown)) device.closed = make(chan struct{}) device.log = logger device.net.bind = bind @@ -293,7 +292,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { device.log.Errorf("Trouble determining MTU, assuming default: %v", err) mtu = DefaultMTU } - device.tun.mtu = int32(mtu) + device.tun.mtu.Store(int32(mtu)) device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.rate.limiter.Init() device.indexTable.Init() @@ -359,7 +358,7 @@ func (device *Device) Close() { if device.isClosed() { return } - atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed)) + device.state.state.Store(uint32(deviceStateClosed)) device.log.Verbosef("Device closing") device.tun.device.Close() diff --git a/device/device_test.go b/device/device_test.go index ab7236efa..8cffe08d8 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -333,7 +333,7 @@ func BenchmarkThroughput(b *testing.B) { // Measure how long it takes to receive b.N packets, // starting when we receive the first packet. - var recv uint64 + var recv atomic.Uint64 var elapsed time.Duration var wg sync.WaitGroup wg.Add(1) @@ -342,7 +342,7 @@ func BenchmarkThroughput(b *testing.B) { var start time.Time for { <-pair[0].tun.Inbound - new := atomic.AddUint64(&recv, 1) + new := recv.Add(1) if new == 1 { start = time.Now() } @@ -358,7 +358,7 @@ func BenchmarkThroughput(b *testing.B) { ping := tuntest.Ping(pair[0].ip, pair[1].ip) pingc := pair[1].tun.Outbound var sent uint64 - for atomic.LoadUint64(&recv) != uint64(b.N) { + for recv.Load() != uint64(b.N) { sent++ pingc <- ping } diff --git a/device/keypair.go b/device/keypair.go index 788c947f4..206d7a906 100644 --- a/device/keypair.go +++ b/device/keypair.go @@ -10,7 +10,6 @@ import ( "sync" "sync/atomic" "time" - "unsafe" "golang.zx2c4.com/wireguard/replay" ) @@ -23,7 +22,7 @@ import ( */ type Keypair struct { - sendNonce uint64 // accessed atomically + sendNonce atomic.Uint64 send cipher.AEAD receive cipher.AEAD replayFilter replay.Filter @@ -37,15 +36,7 @@ type Keypairs struct { sync.RWMutex current *Keypair previous *Keypair - next *Keypair -} - -func (kp *Keypairs) storeNext(next *Keypair) { - atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next)) -} - -func (kp *Keypairs) loadNext() *Keypair { - return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)))) + next atomic.Pointer[Keypair] } func (kp *Keypairs) Current() *Keypair { diff --git a/device/misc.go b/device/misc.go deleted file mode 100644 index 4126704ca..000000000 --- a/device/misc.go +++ /dev/null @@ -1,41 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "sync/atomic" -) - -/* Atomic Boolean */ - -const ( - AtomicFalse = int32(iota) - AtomicTrue -) - -type AtomicBool struct { - int32 -} - -func (a *AtomicBool) Get() bool { - return atomic.LoadInt32(&a.int32) == AtomicTrue -} - -func (a *AtomicBool) Swap(val bool) bool { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - return atomic.SwapInt32(&a.int32, flag) == AtomicTrue -} - -func (a *AtomicBool) Set(val bool) { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - atomic.StoreInt32(&a.int32, flag) -} diff --git a/device/noise-protocol.go b/device/noise-protocol.go index ffa04528b..410926ea4 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -282,7 +282,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { // lookup peer peer := device.LookupPeer(peerPK) - if peer == nil || !peer.isRunning.Get() { + if peer == nil || !peer.isRunning.Load() { return nil } @@ -581,12 +581,12 @@ func (peer *Peer) BeginSymmetricSession() error { defer keypairs.Unlock() previous := keypairs.previous - next := keypairs.loadNext() + next := keypairs.next.Load() current := keypairs.current if isInitiator { if next != nil { - keypairs.storeNext(nil) + keypairs.next.Store(nil) keypairs.previous = next device.DeleteKeypair(current) } else { @@ -595,7 +595,7 @@ func (peer *Peer) BeginSymmetricSession() error { device.DeleteKeypair(previous) keypairs.current = keypair } else { - keypairs.storeNext(keypair) + keypairs.next.Store(keypair) device.DeleteKeypair(next) keypairs.previous = nil device.DeleteKeypair(previous) @@ -607,18 +607,18 @@ func (peer *Peer) BeginSymmetricSession() error { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { keypairs := &peer.keypairs - if keypairs.loadNext() != receivedKeypair { + if keypairs.next.Load() != receivedKeypair { return false } keypairs.Lock() defer keypairs.Unlock() - if keypairs.loadNext() != receivedKeypair { + if keypairs.next.Load() != receivedKeypair { return false } old := keypairs.previous keypairs.previous = keypairs.current peer.device.DeleteKeypair(old) - keypairs.current = keypairs.loadNext() - keypairs.storeNext(nil) + keypairs.current = keypairs.next.Load() + keypairs.next.Store(nil) return true } diff --git a/device/noise_test.go b/device/noise_test.go index e2f23c6c2..7c84efc0b 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -148,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) { t.Fatal("failed to derive keypair for peer 2", err) } - key1 := peer1.keypairs.loadNext() + key1 := peer1.keypairs.next.Load() key2 := peer2.keypairs.current // encrypting / decryption test diff --git a/device/peer.go b/device/peer.go index 5bd52df79..79feae795 100644 --- a/device/peer.go +++ b/device/peer.go @@ -16,24 +16,16 @@ import ( ) type Peer struct { - isRunning AtomicBool - sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer - keypairs Keypairs - handshake Handshake - device *Device - endpoint conn.Endpoint - stopping sync.WaitGroup // routines pending stop - - // These fields are accessed with atomic operations, which must be - // 64-bit aligned even on 32-bit platforms. Go guarantees that an - // allocated struct will be 64-bit aligned. So we place - // atomically-accessed fields up front, so that they can share in - // this alignment before smaller fields throw it off. - stats struct { - txBytes uint64 // bytes send to peer (endpoint) - rxBytes uint64 // bytes received from peer - lastHandshakeNano int64 // nano seconds since epoch - } + isRunning atomic.Bool + sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer + keypairs Keypairs + handshake Handshake + device *Device + endpoint conn.Endpoint + stopping sync.WaitGroup // routines pending stop + txBytes atomic.Uint64 // bytes send to peer (endpoint) + rxBytes atomic.Uint64 // bytes received from peer + lastHandshakeNano atomic.Int64 // nano seconds since epoch disableRoaming bool @@ -43,9 +35,9 @@ type Peer struct { newHandshake *Timer zeroKeyMaterial *Timer persistentKeepalive *Timer - handshakeAttempts uint32 - needAnotherKeepalive AtomicBool - sentLastMinuteHandshake AtomicBool + handshakeAttempts atomic.Uint32 + needAnotherKeepalive atomic.Bool + sentLastMinuteHandshake atomic.Bool } state struct { @@ -60,7 +52,7 @@ type Peer struct { cookieGenerator CookieGenerator trieEntries list.List - persistentKeepaliveInterval uint32 // accessed atomically + persistentKeepaliveInterval atomic.Uint32 } func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { @@ -133,7 +125,7 @@ func (peer *Peer) SendBuffer(buffer []byte) error { err := peer.device.net.bind.Send(buffer, peer.endpoint) if err == nil { - atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer))) + peer.txBytes.Add(uint64(len(buffer))) } return err } @@ -174,7 +166,7 @@ func (peer *Peer) Start() { peer.state.Lock() defer peer.state.Unlock() - if peer.isRunning.Get() { + if peer.isRunning.Load() { return } @@ -198,7 +190,7 @@ func (peer *Peer) Start() { go peer.RoutineSequentialSender() go peer.RoutineSequentialReceiver() - peer.isRunning.Set(true) + peer.isRunning.Store(true) } func (peer *Peer) ZeroAndFlushAll() { @@ -210,10 +202,10 @@ func (peer *Peer) ZeroAndFlushAll() { keypairs.Lock() device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.current) - device.DeleteKeypair(keypairs.loadNext()) + device.DeleteKeypair(keypairs.next.Load()) keypairs.previous = nil keypairs.current = nil - keypairs.storeNext(nil) + keypairs.next.Store(nil) keypairs.Unlock() // clear handshake state @@ -238,11 +230,10 @@ func (peer *Peer) ExpireCurrentKeypairs() { keypairs := &peer.keypairs keypairs.Lock() if keypairs.current != nil { - atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages) + keypairs.current.sendNonce.Store(RejectAfterMessages) } - if keypairs.next != nil { - next := keypairs.loadNext() - atomic.StoreUint64(&next.sendNonce, RejectAfterMessages) + if next := keypairs.next.Load(); next != nil { + next.sendNonce.Store(RejectAfterMessages) } keypairs.Unlock() } diff --git a/device/pools.go b/device/pools.go index f40477b66..9da0f7996 100644 --- a/device/pools.go +++ b/device/pools.go @@ -14,7 +14,7 @@ type WaitPool struct { pool sync.Pool cond sync.Cond lock sync.Mutex - count uint32 + count atomic.Uint32 max uint32 } @@ -27,10 +27,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool { func (p *WaitPool) Get() any { if p.max != 0 { p.lock.Lock() - for atomic.LoadUint32(&p.count) >= p.max { + for p.count.Load() >= p.max { p.cond.Wait() } - atomic.AddUint32(&p.count, 1) + p.count.Add(1) p.lock.Unlock() } return p.pool.Get() @@ -41,7 +41,7 @@ func (p *WaitPool) Put(x any) { if p.max == 0 { return } - atomic.AddUint32(&p.count, ^uint32(0)) + p.count.Add(^uint32(0)) p.cond.Signal() } diff --git a/device/pools_test.go b/device/pools_test.go index 17e2298f9..48a98b0f2 100644 --- a/device/pools_test.go +++ b/device/pools_test.go @@ -17,29 +17,31 @@ import ( func TestWaitPool(t *testing.T) { t.Skip("Currently disabled") var wg sync.WaitGroup - trials := int32(100000) + var trials atomic.Int32 + startTrials := int32(100000) if raceEnabled { // This test can be very slow with -race. - trials /= 10 + startTrials /= 10 } + trials.Store(startTrials) workers := runtime.NumCPU() + 2 if workers-4 <= 0 { t.Skip("Not enough cores") } p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) wg.Add(workers) - max := uint32(0) + var max atomic.Uint32 updateMax := func() { - count := atomic.LoadUint32(&p.count) + count := p.count.Load() if count > p.max { t.Errorf("count (%d) > max (%d)", count, p.max) } for { - old := atomic.LoadUint32(&max) + old := max.Load() if count <= old { break } - if atomic.CompareAndSwapUint32(&max, old, count) { + if max.CompareAndSwap(old, count) { break } } @@ -47,7 +49,7 @@ func TestWaitPool(t *testing.T) { for i := 0; i < workers; i++ { go func() { defer wg.Done() - for atomic.AddInt32(&trials, -1) > 0 { + for trials.Add(-1) > 0 { updateMax() x := p.Get() updateMax() @@ -59,14 +61,15 @@ func TestWaitPool(t *testing.T) { }() } wg.Wait() - if max != p.max { + if max.Load() != p.max { t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) } } func BenchmarkWaitPool(b *testing.B) { var wg sync.WaitGroup - trials := int32(b.N) + var trials atomic.Int32 + trials.Store(int32(b.N)) workers := runtime.NumCPU() + 2 if workers-4 <= 0 { b.Skip("Not enough cores") @@ -77,7 +80,7 @@ func BenchmarkWaitPool(b *testing.B) { for i := 0; i < workers; i++ { go func() { defer wg.Done() - for atomic.AddInt32(&trials, -1) > 0 { + for trials.Add(-1) > 0 { x := p.Get() time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) p.Put(x) diff --git a/device/receive.go b/device/receive.go index cc3449801..4dbf1e827 100644 --- a/device/receive.go +++ b/device/receive.go @@ -11,7 +11,6 @@ import ( "errors" "net" "sync" - "sync/atomic" "time" "golang.org/x/crypto/chacha20poly1305" @@ -52,12 +51,12 @@ func (elem *QueueInboundElement) clearPointers() { * NOTE: Not thread safe, but called by sequential receiver! */ func (peer *Peer) keepKeyFreshReceiving() { - if peer.timers.sentLastMinuteHandshake.Get() { + if peer.timers.sentLastMinuteHandshake.Load() { return } keypair := peer.keypairs.Current() if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { - peer.timers.sentLastMinuteHandshake.Set(true) + peer.timers.sentLastMinuteHandshake.Store(true) peer.SendHandshakeInitiation(false) } } @@ -163,7 +162,7 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { elem.Lock() // add to decryption queues - if peer.isRunning.Get() { + if peer.isRunning.Load() { peer.queue.inbound.c <- elem device.queue.decryption.c <- elem buffer = device.GetMessageBuffer() @@ -268,7 +267,7 @@ func (device *Device) RoutineHandshake(id int) { // consume reply - if peer := entry.peer; peer.isRunning.Get() { + if peer := entry.peer; peer.isRunning.Load() { device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) if !peer.cookieGenerator.ConsumeReply(&reply) { device.log.Verbosef("Could not decrypt invalid cookie response") @@ -341,7 +340,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SetEndpointFromPacket(elem.endpoint) device.log.Verbosef("%v - Received handshake initiation", peer) - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + peer.rxBytes.Add(uint64(len(elem.packet))) peer.SendHandshakeResponse() @@ -369,7 +368,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SetEndpointFromPacket(elem.endpoint) device.log.Verbosef("%v - Received handshake response", peer) - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + peer.rxBytes.Add(uint64(len(elem.packet))) // update timers @@ -426,7 +425,7 @@ func (peer *Peer) RoutineSequentialReceiver() { peer.keepKeyFreshReceiving() peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketReceived() - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize)) + peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) if len(elem.packet) == 0 { device.log.Verbosef("%v - Receiving keepalive packet", peer) diff --git a/device/send.go b/device/send.go index 0a7135f90..471c51c3b 100644 --- a/device/send.go +++ b/device/send.go @@ -12,7 +12,6 @@ import ( "net" "os" "sync" - "sync/atomic" "time" "golang.org/x/crypto/chacha20poly1305" @@ -76,7 +75,7 @@ func (elem *QueueOutboundElement) clearPointers() { /* Queues a keepalive if no packets are queued for peer */ func (peer *Peer) SendKeepalive() { - if len(peer.queue.staged) == 0 && peer.isRunning.Get() { + if len(peer.queue.staged) == 0 && peer.isRunning.Load() { elem := peer.device.NewOutboundElement() select { case peer.queue.staged <- elem: @@ -91,7 +90,7 @@ func (peer *Peer) SendKeepalive() { func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { if !isRetry { - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) + peer.timers.handshakeAttempts.Store(0) } peer.handshake.mutex.RLock() @@ -193,7 +192,7 @@ func (peer *Peer) keepKeyFreshSending() { if keypair == nil { return } - nonce := atomic.LoadUint64(&keypair.sendNonce) + nonce := keypair.sendNonce.Load() if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { peer.SendHandshakeInitiation(false) } @@ -269,7 +268,7 @@ func (device *Device) RoutineReadFromTUN() { if peer == nil { continue } - if peer.isRunning.Get() { + if peer.isRunning.Load() { peer.StagePacket(elem) elem = nil peer.SendStagedPackets() @@ -300,7 +299,7 @@ top: } keypair := peer.keypairs.Current() - if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { + if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { peer.SendHandshakeInitiation(false) return } @@ -309,9 +308,9 @@ top: select { case elem := <-peer.queue.staged: elem.peer = peer - elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1 + elem.nonce = keypair.sendNonce.Add(1) - 1 if elem.nonce >= RejectAfterMessages { - atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) + keypair.sendNonce.Store(RejectAfterMessages) peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans goto top } @@ -320,7 +319,7 @@ top: elem.Lock() // add to parallel and sequential queue - if peer.isRunning.Get() { + if peer.isRunning.Load() { peer.queue.outbound.c <- elem peer.device.queue.encryption.c <- elem } else { @@ -385,7 +384,7 @@ func (device *Device) RoutineEncryption(id int) { binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) // pad content to multiple of 16 - paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu))) + paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) // encrypt content and release to consumer @@ -419,7 +418,7 @@ func (peer *Peer) RoutineSequentialSender() { return } elem.Lock() - if !peer.isRunning.Get() { + if !peer.isRunning.Load() { // peer has been stopped; return re-usable elems to the shared pool. // This is an optimization only. It is possible for the peer to be stopped // immediately after this check, in which case, elem will get processed. diff --git a/device/timers.go b/device/timers.go index 4d2d0f88a..c8ef8877a 100644 --- a/device/timers.go +++ b/device/timers.go @@ -9,7 +9,6 @@ package device import ( "sync" - "sync/atomic" "time" _ "unsafe" ) @@ -74,11 +73,11 @@ func (timer *Timer) IsPending() bool { } func (peer *Peer) timersActive() bool { - return peer.isRunning.Get() && peer.device != nil && peer.device.isUp() + return peer.isRunning.Load() && peer.device != nil && peer.device.isUp() } func expiredRetransmitHandshake(peer *Peer) { - if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes { + if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes { peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2) if peer.timersActive() { @@ -97,8 +96,8 @@ func expiredRetransmitHandshake(peer *Peer) { peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) } } else { - atomic.AddUint32(&peer.timers.handshakeAttempts, 1) - peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) + peer.timers.handshakeAttempts.Add(1) + peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) /* We clear the endpoint address src address, in case this is the cause of trouble. */ peer.Lock() @@ -113,8 +112,8 @@ func expiredRetransmitHandshake(peer *Peer) { func expiredSendKeepalive(peer *Peer) { peer.SendKeepalive() - if peer.timers.needAnotherKeepalive.Get() { - peer.timers.needAnotherKeepalive.Set(false) + if peer.timers.needAnotherKeepalive.Load() { + peer.timers.needAnotherKeepalive.Store(false) if peer.timersActive() { peer.timers.sendKeepalive.Mod(KeepaliveTimeout) } @@ -138,7 +137,7 @@ func expiredZeroKeyMaterial(peer *Peer) { } func expiredPersistentKeepalive(peer *Peer) { - if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { + if peer.persistentKeepaliveInterval.Load() > 0 { peer.SendKeepalive() } } @@ -156,7 +155,7 @@ func (peer *Peer) timersDataReceived() { if !peer.timers.sendKeepalive.IsPending() { peer.timers.sendKeepalive.Mod(KeepaliveTimeout) } else { - peer.timers.needAnotherKeepalive.Set(true) + peer.timers.needAnotherKeepalive.Store(true) } } } @@ -187,9 +186,9 @@ func (peer *Peer) timersHandshakeComplete() { if peer.timersActive() { peer.timers.retransmitHandshake.Del() } - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) - peer.timers.sentLastMinuteHandshake.Set(false) - atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) + peer.timers.handshakeAttempts.Store(0) + peer.timers.sentLastMinuteHandshake.Store(false) + peer.lastHandshakeNano.Store(time.Now().UnixNano()) } /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ @@ -201,7 +200,7 @@ func (peer *Peer) timersSessionDerived() { /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { - keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval) + keepalive := peer.persistentKeepaliveInterval.Load() if keepalive > 0 && peer.timersActive() { peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second) } @@ -216,9 +215,9 @@ func (peer *Peer) timersInit() { } func (peer *Peer) timersStart() { - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) - peer.timers.sentLastMinuteHandshake.Set(false) - peer.timers.needAnotherKeepalive.Set(false) + peer.timers.handshakeAttempts.Store(0) + peer.timers.sentLastMinuteHandshake.Store(false) + peer.timers.needAnotherKeepalive.Store(false) } func (peer *Peer) timersStop() { diff --git a/device/tun.go b/device/tun.go index 4af954821..d94bde1e1 100644 --- a/device/tun.go +++ b/device/tun.go @@ -7,7 +7,6 @@ package device import ( "fmt" - "sync/atomic" "golang.zx2c4.com/wireguard/tun" ) @@ -33,7 +32,7 @@ func (device *Device) RoutineTUNEventReader() { tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize) mtu = MaxContentSize } - old := atomic.SwapInt32(&device.tun.mtu, int32(mtu)) + old := device.tun.mtu.Swap(int32(mtu)) if int(old) != mtu { device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge) } diff --git a/device/uapi.go b/device/uapi.go index 30dd97e8b..550a0323c 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -16,7 +16,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" "golang.zx2c4.com/wireguard/ipc" @@ -112,15 +111,15 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("endpoint=%s", peer.endpoint.DstToString()) } - nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) + nano := peer.lastHandshakeNano.Load() secs := nano / time.Second.Nanoseconds() nano %= time.Second.Nanoseconds() sendf("last_handshake_time_sec=%d", secs) sendf("last_handshake_time_nsec=%d", nano) - sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)) - sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)) - sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)) + sendf("tx_bytes=%d", peer.txBytes.Load()) + sendf("rx_bytes=%d", peer.rxBytes.Load()) + sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { sendf("allowed_ip=%s", prefix.String()) @@ -358,7 +357,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) } - old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs)) + old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) // Send immediate keepalive if we're turning it on and before it wasn't on. peer.pkaOn = old == 0 && secs != 0 diff --git a/go.mod b/go.mod index d0d58b30a..c180d1b49 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module golang.zx2c4.com/wireguard -go 1.18 +go 1.19 require ( golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd diff --git a/ipc/namedpipe/file.go b/ipc/namedpipe/file.go index c5dd48a18..ec9b8d44e 100644 --- a/ipc/namedpipe/file.go +++ b/ipc/namedpipe/file.go @@ -54,7 +54,7 @@ type file struct { handle windows.Handle wg sync.WaitGroup wgLock sync.RWMutex - closing uint32 // used as atomic boolean + closing atomic.Bool socket bool readDeadline deadlineHandler writeDeadline deadlineHandler @@ -65,7 +65,7 @@ type deadlineHandler struct { channel timeoutChan channelLock sync.RWMutex timer *time.Timer - timedout uint32 // used as atomic boolean + timedout atomic.Bool } // makeFile makes a new file from an existing file handle @@ -89,7 +89,7 @@ func makeFile(h windows.Handle) (*file, error) { func (f *file) closeHandle() { f.wgLock.Lock() // Atomically set that we are closing, releasing the resources only once. - if atomic.SwapUint32(&f.closing, 1) == 0 { + if f.closing.Swap(true) == false { f.wgLock.Unlock() // cancel all IO and wait for it to complete windows.CancelIoEx(f.handle, nil) @@ -112,7 +112,7 @@ func (f *file) Close() error { // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. func (f *file) prepareIo() (*ioOperation, error) { f.wgLock.RLock() - if atomic.LoadUint32(&f.closing) == 1 { + if f.closing.Load() { f.wgLock.RUnlock() return nil, os.ErrClosed } @@ -144,7 +144,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err return int(bytes), err } - if atomic.LoadUint32(&f.closing) == 1 { + if f.closing.Load() { windows.CancelIoEx(f.handle, &c.o) } @@ -160,7 +160,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err case r = <-c.ch: err = r.err if err == windows.ERROR_OPERATION_ABORTED { - if atomic.LoadUint32(&f.closing) == 1 { + if f.closing.Load() { err = os.ErrClosed } } else if err != nil && f.socket { @@ -192,7 +192,7 @@ func (f *file) Read(b []byte) (int, error) { } defer f.wg.Done() - if atomic.LoadUint32(&f.readDeadline.timedout) == 1 { + if f.readDeadline.timedout.Load() { return 0, os.ErrDeadlineExceeded } @@ -219,7 +219,7 @@ func (f *file) Write(b []byte) (int, error) { } defer f.wg.Done() - if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 { + if f.writeDeadline.timedout.Load() { return 0, os.ErrDeadlineExceeded } @@ -256,7 +256,7 @@ func (d *deadlineHandler) set(deadline time.Time) error { } d.timer = nil } - atomic.StoreUint32(&d.timedout, 0) + d.timedout.Store(false) select { case <-d.channel: @@ -271,7 +271,7 @@ func (d *deadlineHandler) set(deadline time.Time) error { } timeoutIO := func() { - atomic.StoreUint32(&d.timedout, 1) + d.timedout.Store(true) close(d.channel) } diff --git a/ipc/namedpipe/namedpipe.go b/ipc/namedpipe/namedpipe.go index 6db5ea31e..92cc1ee0f 100644 --- a/ipc/namedpipe/namedpipe.go +++ b/ipc/namedpipe/namedpipe.go @@ -29,7 +29,7 @@ type pipe struct { type messageBytePipe struct { pipe - writeClosed int32 + writeClosed atomic.Bool readEOF bool } @@ -51,17 +51,17 @@ func (f *pipe) SetDeadline(t time.Time) error { // CloseWrite closes the write side of a message pipe in byte mode. func (f *messageBytePipe) CloseWrite() error { - if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) { + if !f.writeClosed.CompareAndSwap(false, true) { return io.ErrClosedPipe } err := f.file.Flush() if err != nil { - atomic.StoreInt32(&f.writeClosed, 0) + f.writeClosed.Store(false) return err } _, err = f.file.Write(nil) if err != nil { - atomic.StoreInt32(&f.writeClosed, 0) + f.writeClosed.Store(false) return err } return nil @@ -70,7 +70,7 @@ func (f *messageBytePipe) CloseWrite() error { // Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since // they are used to implement CloseWrite. func (f *messageBytePipe) Write(b []byte) (int, error) { - if atomic.LoadInt32(&f.writeClosed) != 0 { + if f.writeClosed.Load() { return 0, io.ErrClosedPipe } if len(b) == 0 { diff --git a/tun/tun_windows.go b/tun/tun_windows.go index d0571508a..6782fd4cd 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -26,10 +26,10 @@ const ( ) type rateJuggler struct { - current uint64 - nextByteCount uint64 - nextStartTime int64 - changing int32 + current atomic.Uint64 + nextByteCount atomic.Uint64 + nextStartTime atomic.Int64 + changing atomic.Bool } type NativeTun struct { @@ -42,7 +42,7 @@ type NativeTun struct { events chan Event running sync.WaitGroup closeOnce sync.Once - close int32 + close atomic.Bool forcedMTU int } @@ -57,18 +57,14 @@ func procyield(cycles uint32) //go:linkname nanotime runtime.nanotime func nanotime() int64 -// // CreateTUN creates a Wintun interface with the given name. Should a Wintun // interface with the same name exist, it is reused. -// func CreateTUN(ifname string, mtu int) (Device, error) { return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) } -// // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and // a requested GUID. Should a Wintun interface with the same name exist, it is reused. -// func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) if err != nil { @@ -113,7 +109,7 @@ func (tun *NativeTun) Events() chan Event { func (tun *NativeTun) Close() error { var err error tun.closeOnce.Do(func() { - atomic.StoreInt32(&tun.close, 1) + tun.close.Store(true) windows.SetEvent(tun.readWait) tun.running.Wait() tun.session.End() @@ -144,13 +140,13 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { tun.running.Add(1) defer tun.running.Done() retry: - if atomic.LoadInt32(&tun.close) == 1 { + if tun.close.Load() { return 0, os.ErrClosed } start := nanotime() - shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 + shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 for { - if atomic.LoadInt32(&tun.close) == 1 { + if tun.close.Load() { return 0, os.ErrClosed } packet, err := tun.session.ReceivePacket() @@ -184,7 +180,7 @@ func (tun *NativeTun) Flush() error { func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { tun.running.Add(1) defer tun.running.Done() - if atomic.LoadInt32(&tun.close) == 1 { + if tun.close.Load() { return 0, os.ErrClosed } @@ -210,7 +206,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) LUID() uint64 { tun.running.Add(1) defer tun.running.Done() - if atomic.LoadInt32(&tun.close) == 1 { + if tun.close.Load() { return 0 } return tun.wt.LUID() @@ -223,15 +219,15 @@ func (tun *NativeTun) RunningVersion() (version uint32, err error) { func (rate *rateJuggler) update(packetLen uint64) { now := nanotime() - total := atomic.AddUint64(&rate.nextByteCount, packetLen) - period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) + total := rate.nextByteCount.Add(packetLen) + period := uint64(now - rate.nextStartTime.Load()) if period >= rateMeasurementGranularity { - if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { + if !rate.changing.CompareAndSwap(false, true) { return } - atomic.StoreInt64(&rate.nextStartTime, now) - atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) - atomic.StoreUint64(&rate.nextByteCount, 0) - atomic.StoreInt32(&rate.changing, 0) + rate.nextStartTime.Store(now) + rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) + rate.nextByteCount.Store(0) + rate.changing.Store(false) } }