Skip to content

Commit

Permalink
Merge pull request #465 from evanlinjin/feature/fix-routing-logic
Browse files Browse the repository at this point in the history
Fixed some routing logic.
  • Loading branch information
志宇 authored Jul 3, 2019
2 parents 91283de + ddd2593 commit ddcb391
Show file tree
Hide file tree
Showing 13 changed files with 258 additions and 150 deletions.
4 changes: 2 additions & 2 deletions cmd/apps/skychat/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ func messageHandler(w http.ResponseWriter, req *http.Request) {

addr := &app.Addr{PubKey: pk, Port: 1}
connsMu.Lock()
conn := chatConns[pk]
conn, ok := chatConns[pk]
connsMu.Unlock()

if conn == nil {
if !ok {
var err error
err = r.Do(func() error {
conn, err = chatApp.Dial(addr)
Expand Down
3 changes: 1 addition & 2 deletions integration/test-messaging.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!/usr/bin/env bash
source ./integration/generic/env-vars.sh
# curl --data {'"recipient":"'$PK_A'", "message":"Hello Joe!"}' -X POST $CHAT_C
curl --data {'"recipient":"'$PK_A'", "message":"Hello Joe!"}' -X POST $CHAT_C
curl --data {'"recipient":"'$PK_C'", "message":"Hello Mike!"}' -X POST $CHAT_A
16 changes: 8 additions & 8 deletions pkg/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ func (app *App) handleProto() {
}

func (app *App) serveConn(addr *LoopAddr, conn io.ReadWriteCloser) {
defer conn.Close()

for {
buf := make([]byte, 32*1024)
n, err := conn.Read(buf)
Expand All @@ -179,11 +181,10 @@ func (app *App) serveConn(addr *LoopAddr, conn io.ReadWriteCloser) {
}
}

if app.conns[*addr] != nil {
app.mu.Lock()
if _, ok := app.conns[*addr]; ok {
app.proto.Send(FrameClose, &addr, nil) // nolint: errcheck
}

app.mu.Lock()
delete(app.conns, *addr)
app.mu.Unlock()
}
Expand Down Expand Up @@ -251,13 +252,12 @@ func (app *App) confirmLoop(data []byte) error {

type appConn struct {
net.Conn
rw io.ReadWriteCloser
laddr *Addr
raddr *Addr
}

func newAppConn(conn net.Conn, laddr, raddr *Addr) *appConn {
return &appConn{conn, conn, laddr, raddr}
return &appConn{conn, laddr, raddr}
}

func (conn *appConn) LocalAddr() net.Addr {
Expand All @@ -269,13 +269,13 @@ func (conn *appConn) RemoteAddr() net.Addr {
}

func (conn *appConn) Write(p []byte) (n int, err error) {
return conn.rw.Write(p)
return conn.Conn.Write(p)
}

func (conn *appConn) Read(p []byte) (n int, err error) {
return conn.rw.Read(p)
return conn.Conn.Read(p)
}

func (conn *appConn) Close() error {
return conn.rw.Close()
return conn.Conn.Close()
}
2 changes: 1 addition & 1 deletion pkg/node/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func newTransportSummary(tm *transport.Manager, tp *transport.ManagedTransport,
}

summary := &TransportSummary{
ID: tp.ID,
ID: tp.Entry.ID,
Local: tm.Local(),
Remote: remote,
Type: tp.Type(),
Expand Down
2 changes: 1 addition & 1 deletion pkg/node/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func TestRPC(t *testing.T) {
t.Run("Transport", func(t *testing.T) {
var ids []uuid.UUID
node.tm.WalkTransports(func(tp *transport.ManagedTransport) bool {
ids = append(ids, tp.ID)
ids = append(ids, tp.Entry.ID)
return true
})

Expand Down
3 changes: 2 additions & 1 deletion pkg/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ func (r *Router) Serve(ctx context.Context) error {
}

go func(tp transport.Transport) {
defer tp.Close()
for {
if err := serve(tp); err != nil {
if err != io.EOF {
Expand Down Expand Up @@ -423,7 +424,7 @@ func (r *Router) setupProto(ctx context.Context) (*setup.Protocol, transport.Tra
// TODO(evanlinjin): need string constant for tp type.
tr, err := r.tm.CreateTransport(ctx, r.config.SetupNodes[0], dmsg.Type, false)
if err != nil {
return nil, nil, fmt.Errorf("transport: %s", err)
return nil, nil, fmt.Errorf("setup transport: %s", err)
}

sProto := setup.NewSetupProtocol(tr)
Expand Down
14 changes: 7 additions & 7 deletions pkg/router/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestRouterForwarding(t *testing.T) {
tr3, err := m3.CreateTransport(context.TODO(), pk2, "mock2", true)
require.NoError(t, err)

rule := routing.ForwardRule(time.Now().Add(time.Hour), 4, tr3.ID)
rule := routing.ForwardRule(time.Now().Add(time.Hour), 4, tr3.Entry.ID)
routeID, err := rt.AddRule(rule)
require.NoError(t, err)

Expand Down Expand Up @@ -197,9 +197,9 @@ func TestRouterApp(t *testing.T) {

ni1, ni2 := noiseInstances(t, pk1, pk2, sk1, sk2)
raddr := &app.Addr{PubKey: pk2, Port: 5}
require.NoError(t, r.pm.SetLoop(6, raddr, &loop{tr.ID, 4, ni1}))
require.NoError(t, r.pm.SetLoop(6, raddr, &loop{tr.Entry.ID, 4, ni1}))

tr2 := m2.Transport(tr.ID)
tr2 := m2.Transport(tr.Entry.ID)
go proto.Send(app.FrameSend, &app.Packet{Addr: &app.LoopAddr{Port: 6, Remote: *raddr}, Payload: []byte("bar")}, nil) // nolint: errcheck

packet := make(routing.Packet, 29)
Expand Down Expand Up @@ -333,13 +333,13 @@ func TestRouterSetup(t *testing.T) {

var routeID routing.RouteID
t.Run("add route", func(t *testing.T) {
routeID, err = setup.AddRule(sProto, routing.ForwardRule(time.Now().Add(time.Hour), 2, tr.ID))
routeID, err = setup.AddRule(sProto, routing.ForwardRule(time.Now().Add(time.Hour), 2, tr.Entry.ID))
require.NoError(t, err)

rule, err := rt.Rule(routeID)
require.NoError(t, err)
assert.Equal(t, routing.RouteID(2), rule.RouteID())
assert.Equal(t, tr.ID, rule.TransportID())
assert.Equal(t, tr.Entry.ID, rule.TransportID())
})

t.Run("`confirm loop - responder", func(t *testing.T) {
Expand Down Expand Up @@ -371,7 +371,7 @@ func TestRouterSetup(t *testing.T) {
loop, err := r.pm.GetLoop(2, &app.Addr{PubKey: pk2, Port: 1})
require.NoError(t, err)
require.NotNil(t, loop)
assert.Equal(t, tr.ID, loop.trID)
assert.Equal(t, tr.Entry.ID, loop.trID)
assert.Equal(t, routing.RouteID(2), loop.routeID)

addrs := [2]*app.Addr{}
Expand Down Expand Up @@ -427,7 +427,7 @@ func TestRouterSetup(t *testing.T) {
l, err := r.pm.GetLoop(2, &app.Addr{PubKey: pk2, Port: 1})
require.NoError(t, err)
require.NotNil(t, l)
assert.Equal(t, tr.ID, l.trID)
assert.Equal(t, tr.Entry.ID, l.trID)
assert.Equal(t, routing.RouteID(2), l.routeID)

addrs := [2]*app.Addr{}
Expand Down
16 changes: 8 additions & 8 deletions pkg/setup/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ func TestCreateLoop(t *testing.T) {

l := &routing.Loop{LocalPort: 1, RemotePort: 2, Expiry: time.Now().Add(time.Hour),
Forward: routing.Route{
&routing.Hop{From: pk1, To: pk2, Transport: tr1.ID},
&routing.Hop{From: pk2, To: pk3, Transport: tr3.ID},
&routing.Hop{From: pk1, To: pk2, Transport: tr1.Entry.ID},
&routing.Hop{From: pk2, To: pk3, Transport: tr3.Entry.ID},
},
Reverse: routing.Route{
&routing.Hop{From: pk3, To: pk2, Transport: tr3.ID},
&routing.Hop{From: pk2, To: pk1, Transport: tr1.ID},
&routing.Hop{From: pk3, To: pk2, Transport: tr3.Entry.ID},
&routing.Hop{From: pk2, To: pk1, Transport: tr1.Entry.ID},
},
}

Expand Down Expand Up @@ -132,25 +132,25 @@ func TestCreateLoop(t *testing.T) {
assert.Equal(t, uint16(1), rule.LocalPort())
rule = rules[2]
assert.Equal(t, routing.RuleForward, rule.Type())
assert.Equal(t, tr1.ID, rule.TransportID())
assert.Equal(t, tr1.Entry.ID, rule.TransportID())
assert.Equal(t, routing.RouteID(2), rule.RouteID())

rules = n2.getRules()
require.Len(t, rules, 2)
rule = rules[1]
assert.Equal(t, routing.RuleForward, rule.Type())
assert.Equal(t, tr1.ID, rule.TransportID())
assert.Equal(t, tr1.Entry.ID, rule.TransportID())
assert.Equal(t, routing.RouteID(1), rule.RouteID())
rule = rules[2]
assert.Equal(t, routing.RuleForward, rule.Type())
assert.Equal(t, tr3.ID, rule.TransportID())
assert.Equal(t, tr3.Entry.ID, rule.TransportID())
assert.Equal(t, routing.RouteID(2), rule.RouteID())

rules = n3.getRules()
require.Len(t, rules, 2)
rule = rules[1]
assert.Equal(t, routing.RuleForward, rule.Type())
assert.Equal(t, tr3.ID, rule.TransportID())
assert.Equal(t, tr3.Entry.ID, rule.TransportID())
assert.Equal(t, routing.RouteID(1), rule.RouteID())
rule = rules[2]
assert.Equal(t, routing.RuleApp, rule.Type())
Expand Down
64 changes: 59 additions & 5 deletions pkg/transport/log.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,72 @@
package transport

import (
"bytes"
"encoding/gob"
"encoding/json"
"errors"
"fmt"
"math/big"
"os"
"path/filepath"
"strconv"
"sync"
"sync/atomic"

"github.com/google/uuid"
)

// LogEntry represents a logging entry for a given Transport.
// The entry is updated every time a packet is received or sent.
type LogEntry struct {
ReceivedBytes *big.Int `json:"received"` // Total received bytes.
SentBytes *big.Int `json:"sent"` // Total sent bytes.
RecvBytes uint64 `json:"recv"` // Total received bytes.
SentBytes uint64 `json:"sent"` // Total sent bytes.
}

// AddRecv records read.
func (le *LogEntry) AddRecv(n uint64) {
atomic.AddUint64(&le.RecvBytes, n)
}

// AddSent records write.
func (le *LogEntry) AddSent(n uint64) {
atomic.AddUint64(&le.SentBytes, n)
}

// MarshalJSON implements json.Marshaller
func (le *LogEntry) MarshalJSON() ([]byte, error) {
rb := strconv.FormatUint(atomic.LoadUint64(&le.RecvBytes), 10)
sb := strconv.FormatUint(atomic.LoadUint64(&le.SentBytes), 10)
return []byte(`{"recv":` + rb + `,"sent":` + sb + `}`), nil
}

// GobEncode implements gob.GobEncoder
func (le *LogEntry) GobEncode() ([]byte, error) {
var b bytes.Buffer
enc := gob.NewEncoder(&b)
if err := enc.Encode(le.RecvBytes); err != nil {
return nil, err
}
if err := enc.Encode(le.SentBytes); err != nil {
return nil, err
}
return b.Bytes(), nil
}

// GobDecode implements gob.GobDecoder
func (le *LogEntry) GobDecode(b []byte) error {
r := bytes.NewReader(b)
dec := gob.NewDecoder(r)
var rb uint64
if err := dec.Decode(&rb); err != nil {
return err
}
var sb uint64
if err := dec.Decode(&sb); err != nil {
return err
}
atomic.StoreUint64(&le.RecvBytes, rb)
atomic.StoreUint64(&le.SentBytes, sb)
return nil
}

// LogStore stores transport log entries.
Expand All @@ -32,14 +83,17 @@ type inMemoryTransportLogStore struct {
// InMemoryTransportLogStore implements in-memory TransportLogStore.
func InMemoryTransportLogStore() LogStore {
return &inMemoryTransportLogStore{
entries: map[uuid.UUID]*LogEntry{},
entries: make(map[uuid.UUID]*LogEntry),
}
}

func (tls *inMemoryTransportLogStore) Entry(id uuid.UUID) (*LogEntry, error) {
tls.mu.Lock()
entry := tls.entries[id]
entry, ok := tls.entries[id]
tls.mu.Unlock()
if !ok {
return entry, errors.New("transport log entry not found")
}

return entry, nil
}
Expand Down
37 changes: 32 additions & 5 deletions pkg/transport/log_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package transport_test

import (
"encoding/json"
"fmt"
"io/ioutil"
"math/big"
"os"
"testing"

Expand All @@ -17,17 +18,22 @@ func testTransportLogStore(t *testing.T, logStore transport.LogStore) {
t.Helper()

id1 := uuid.New()
entry1 := &transport.LogEntry{big.NewInt(100), big.NewInt(200)}
entry1 := new(transport.LogEntry)
entry1.AddRecv(100)
entry1.AddSent(200)

id2 := uuid.New()
entry2 := &transport.LogEntry{big.NewInt(300), big.NewInt(400)}
entry2 := new(transport.LogEntry)
entry2.AddRecv(300)
entry2.AddSent(400)

require.NoError(t, logStore.Record(id1, entry1))
require.NoError(t, logStore.Record(id2, entry2))

entry, err := logStore.Entry(id2)
require.NoError(t, err)
assert.Equal(t, int64(300), entry.ReceivedBytes.Int64())
assert.Equal(t, int64(400), entry.SentBytes.Int64())
assert.Equal(t, uint64(300), entry.RecvBytes)
assert.Equal(t, uint64(400), entry.SentBytes)
}

func TestInMemoryTransportLogStore(t *testing.T) {
Expand All @@ -43,3 +49,24 @@ func TestFileTransportLogStore(t *testing.T) {
require.NoError(t, err)
testTransportLogStore(t, ls)
}

func TestLogEntry_MarshalJSON(t *testing.T) {
entry := new(transport.LogEntry)
entry.AddSent(10)
entry.AddRecv(100)
b, err := json.Marshal(entry)
require.NoError(t, err)
fmt.Println(string(b))
b, err = json.MarshalIndent(entry, "", "\t")
require.NoError(t, err)
fmt.Println(string(b))
}

func TestLogEntry_GobEncode(t *testing.T) {
var entry transport.LogEntry

enc, err := entry.GobEncode()
require.NoError(t, err)

require.NoError(t, entry.GobDecode(enc))
}
Loading

0 comments on commit ddcb391

Please sign in to comment.