diff --git a/device/noise-protocol.go b/device/noise-protocol.go index e8f6145e5..f7d38ee30 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -6,6 +6,7 @@ package device import ( + "encoding/binary" "errors" "fmt" "sync" @@ -115,6 +116,53 @@ type MessageCookieReply struct { Cookie [blake2s.Size128 + poly1305.TagSize]byte } +var errMessageTooShort = errors.New("message too short") + +func (msg *MessageInitiation) unmarshal(b []byte) error { + if len(b) < MessageInitiationSize { + return errMessageTooShort + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Sender = binary.LittleEndian.Uint32(b[4:]) + copy(msg.Ephemeral[:], b[8:]) + copy(msg.Static[:], b[8+len(msg.Ephemeral):]) + copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):]) + copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):]) + copy(msg.MAC2[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):]) + + return nil +} + +func (msg *MessageResponse) unmarshal(b []byte) error { + if len(b) < MessageResponseSize { + return errMessageTooShort + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Sender = binary.LittleEndian.Uint32(b[4:]) + msg.Receiver = binary.LittleEndian.Uint32(b[8:]) + copy(msg.Ephemeral[:], b[12:]) + copy(msg.Empty[:], b[12+len(msg.Ephemeral):]) + copy(msg.MAC1[:], b[12+len(msg.Ephemeral)+len(msg.Empty):]) + copy(msg.MAC2[:], b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):]) + + return nil +} + +func (msg *MessageCookieReply) unmarshal(b []byte) error { + if len(b) < MessageCookieReplySize { + return errMessageTooShort + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Receiver = binary.LittleEndian.Uint32(b[4:]) + copy(msg.Nonce[:], b[8:]) + copy(msg.Cookie[:], b[8+len(msg.Nonce):]) + + return nil +} + type Handshake struct { state handshakeState mutex sync.RWMutex diff --git a/device/receive.go b/device/receive.go index 1ab3e2945..c4c6f567a 100644 --- a/device/receive.go +++ b/device/receive.go @@ -6,7 +6,6 @@ package device import ( - "bytes" "encoding/binary" "errors" "net" @@ -276,6 +275,11 @@ func (device *Device) RoutineHandshake(id int) { }() device.log.Verbosef("Routine: handshake worker %d - started", id) + var ( + msgCookieReply MessageCookieReply + msgInitiation MessageInitiation + msgResponse MessageResponse + ) for elem := range device.queue.handshake.c { // handle cookie fields and ratelimiting @@ -286,9 +290,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal packet - var reply MessageCookieReply - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &reply) + err := msgCookieReply.unmarshal(elem.packet) if err != nil { device.log.Verbosef("Failed to decode cookie reply") goto skip @@ -296,7 +298,7 @@ func (device *Device) RoutineHandshake(id int) { // lookup peer from index - entry := device.indexTable.Lookup(reply.Receiver) + entry := device.indexTable.Lookup(msgCookieReply.Receiver) if entry.peer == nil { goto skip @@ -306,7 +308,7 @@ func (device *Device) RoutineHandshake(id int) { if peer := entry.peer; peer.isRunning.Load() { device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) - if !peer.cookieGenerator.ConsumeReply(&reply) { + if !peer.cookieGenerator.ConsumeReply(&msgCookieReply) { device.log.Verbosef("Could not decrypt invalid cookie response") } } @@ -352,9 +354,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal - var msg MessageInitiation - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) + err := msgInitiation.unmarshal(elem.packet) if err != nil { device.log.Errorf("Failed to decode initiation message") goto skip @@ -362,7 +362,7 @@ func (device *Device) RoutineHandshake(id int) { // consume initiation - peer := device.ConsumeMessageInitiation(&msg) + peer := device.ConsumeMessageInitiation(&msgInitiation) if peer == nil { device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) goto skip @@ -385,9 +385,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal - var msg MessageResponse - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) + err := msgResponse.unmarshal(elem.packet) if err != nil { device.log.Errorf("Failed to decode response message") goto skip @@ -395,7 +393,7 @@ func (device *Device) RoutineHandshake(id int) { // consume response - peer := device.ConsumeMessageResponse(&msg) + peer := device.ConsumeMessageResponse(&msgResponse) if peer == nil { device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) goto skip