Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

Commit

Permalink
Updates on marten's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
cannium committed Jan 2, 2019
1 parent ec0a1c1 commit e0d65c5
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 64 deletions.
7 changes: 5 additions & 2 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,11 @@ func (l *listener) Close() error {
delete(connManager.ipv4Conns, l.conn)
delete(connManager.ipv6Conns, l.conn)
}
if l.conn == connManager.defaultConn {
connManager.defaultConn = nil
if l.conn == connManager.defaultIpv4Conn {
connManager.defaultIpv4Conn = nil
}
if l.conn == connManager.defaultIpv6Conn {
connManager.defaultIpv6Conn = nil
}
connManager.mu.Unlock()
return l.quicListener.Close()
Expand Down
122 changes: 68 additions & 54 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,16 @@ var quicConfig = &quic.Config{
type addrType int

const (
other addrType = 0
loopback addrType = 1
global addrType = 2
unspecified addrType = 3
addrTypeOther addrType = iota
addrTypeLoopback
addrTypeGlobal
addrTypeUnspecified
)

type connManager struct {
mu sync.Mutex
defaultConn net.PacketConn
mu sync.Mutex
defaultIpv4Conn net.PacketConn
defaultIpv6Conn net.PacketConn
// map underhood PacketConn -> connection remote address type(usage of the conn)
ipv4Conns map[net.PacketConn]addrType
ipv6Conns map[net.PacketConn]addrType
Expand All @@ -57,67 +58,79 @@ func newConnManager() *connManager {

func typeOfIP(ipAddr net.IP) addrType {
if ipAddr.IsLoopback() {
return loopback
return addrTypeLoopback
}
if ipAddr.IsUnspecified() {
return unspecified
return addrTypeUnspecified
}
if ipAddr.IsGlobalUnicast() {
return global
return addrTypeGlobal
}
return other
return addrTypeOther
}

func (c *connManager) pickDialConn(conns map[net.PacketConn]addrType, remoteHostIP net.IP) net.PacketConn {
remoteAddrType := typeOfIP(remoteHostIP)
for conn, connAddrType := range conns {
if c.defaultConn == nil && connAddrType == unspecified {
c.defaultConn = conn
// GetConnForAddr try to reuse exist connections when possible
func (c *connManager) GetConnForAddr(network, remoteHost string) (net.PacketConn, error) {
remoteAddr, err := net.ResolveUDPAddr(network, remoteHost)
if err != nil {
return nil, err
}

var listenHost string
var conns map[net.PacketConn]addrType
var defaultConn net.PacketConn
var setDefaultConn func(conn net.PacketConn)
switch network {
case "udp4":
listenHost = "0.0.0.0:0"
conns = c.ipv4Conns
defaultConn = c.defaultIpv4Conn
setDefaultConn = func(conn net.PacketConn) {
c.defaultIpv4Conn = conn
}
if connAddrType == remoteAddrType {
return conn
case "udp6":
listenHost = ":0"
conns = c.ipv6Conns
defaultConn = c.defaultIpv6Conn
setDefaultConn = func(conn net.PacketConn) {
c.defaultIpv6Conn = conn
}
default:
return nil, fmt.Errorf("unsupported network: %s", network)
}
return c.defaultConn
}

func (c *connManager) getConnForAddr(network, listenHost string, remoteHostIP net.IP,
conns map[net.PacketConn]addrType) (conn net.PacketConn, err error) {

// check if there exists a connection of expected type
var pickedConn net.PacketConn
c.mu.Lock()
conn = c.pickDialConn(conns, remoteHostIP)
remoteAddrType := typeOfIP(remoteAddr.IP)
for conn, connAddrType := range conns {
if defaultConn == nil && connAddrType == addrTypeUnspecified {
setDefaultConn(conn)
}
if connAddrType == remoteAddrType {
pickedConn = conn
break
}
}
c.mu.Unlock()
if conn != nil {
return conn, nil
if pickedConn != nil {
return pickedConn, nil
}

conn, err = c.createConn(network, listenHost)
if defaultConn != nil {
return defaultConn, nil
}
// could not reuse an exist connection, create a new one
conn, err := c.createConn(network, listenHost)
if err != nil {
return nil, err
}
connAddrType := typeOfIP(remoteHostIP)
connAddrType := typeOfIP(remoteAddr.IP)
c.mu.Lock()
conns[conn] = connAddrType
c.mu.Unlock()
return conn, nil
}

func (c *connManager) GetConnForAddr(network, remoteHost string) (net.PacketConn, error) {
remoteAddr, err := net.ResolveUDPAddr(network, remoteHost)
if err != nil {
return nil, err
}

switch network {
case "udp4":
return c.getConnForAddr(network, "0.0.0.0:0", remoteAddr.IP, c.ipv4Conns)
case "udp6":
return c.getConnForAddr(network, ":0", remoteAddr.IP, c.ipv6Conns)
default:
return nil, fmt.Errorf("unsupported network: %s", network)
}
}

func (c *connManager) createConn(network, host string) (net.PacketConn, error) {
addr, err := net.ResolveUDPAddr(network, host)
if err != nil {
Expand All @@ -132,17 +145,18 @@ func (c *connManager) listenUDP(addr ma.Multiaddr) (net.PacketConn, error) {
return nil, err
}
conn, err := c.createConn(network, host)
if err == nil {
c.mu.Lock()
switch network {
case "udp4":
c.ipv4Conns[conn] = unspecified
case "udp6":
c.ipv6Conns[conn] = unspecified
}
c.mu.Unlock()
if err != nil {
return conn, err
}
return conn, err
c.mu.Lock()
switch network {
case "udp4":
c.ipv4Conns[conn] = addrTypeUnspecified
case "udp6":
c.ipv6Conns[conn] = addrTypeUnspecified
}
c.mu.Unlock()
return conn, nil
}

// The Transport implements the tpt.Transport interface for QUIC connections.
Expand Down
16 changes: 8 additions & 8 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,27 @@ var _ = Describe("Port reuse", func() {
c = newConnManager()
})
It("reuse IPv4 port", func() {
addr, _ := ma.NewMultiaddr("/ip4/0.0.0.0/udp/40002/quic")
addr, _ := ma.NewMultiaddr("/ip4/0.0.0.0/udp/40000/quic")
listenConn, err := c.listenUDP(addr)
Expect(err).ToNot(HaveOccurred())

dialConn, err := c.GetConnForAddr("udp4", "127.0.0.1:40003")
dialConn, err := c.GetConnForAddr("udp4", "127.0.0.1:40004")
Expect(err).ToNot(HaveOccurred())
Expect(dialConn).To(Equal(listenConn))
listenConn.Close()
})
It("reuse IPv6 port", func() {
addr, _ := ma.NewMultiaddr("/ip6/::/udp/40002/quic")
addr, _ := ma.NewMultiaddr("/ip6/::/udp/40001/quic")
listenConn, err := c.listenUDP(addr)
Expect(err).ToNot(HaveOccurred())

dialConn, err := c.GetConnForAddr("udp6", "[::1]:40003")
dialConn, err := c.GetConnForAddr("udp6", "[::1]:40004")
Expect(err).ToNot(HaveOccurred())
Expect(dialConn).To(Equal(listenConn))
listenConn.Close()
})
It("listen after dial won't reuse conn", func() {
dialConn, err := c.GetConnForAddr("udp4", "127.0.0.1:40003")
dialConn, err := c.GetConnForAddr("udp4", "127.0.0.1:40004")
Expect(err).ToNot(HaveOccurred())

addr, _ := ma.NewMultiaddr("/ip4/0.0.0.0/udp/40002/quic")
Expand All @@ -69,14 +69,14 @@ var _ = Describe("Port reuse", func() {
listenConn.Close()
})
It("use listen conn by default", func() {
dialConn, err := c.GetConnForAddr("udp4", "127.0.0.1:40003")
dialConn, err := c.GetConnForAddr("udp4", "127.0.0.1:40004")
Expect(err).ToNot(HaveOccurred())

addr, _ := ma.NewMultiaddr("/ip4/0.0.0.0/udp/40002/quic")
addr, _ := ma.NewMultiaddr("/ip4/0.0.0.0/udp/40003/quic")
listenConn, err := c.listenUDP(addr)
Expect(err).ToNot(HaveOccurred())

dialConn2, err := c.GetConnForAddr("udp4", "1.2.3.4:40004")
dialConn2, err := c.GetConnForAddr("udp4", "1.2.3.4:40005")
Expect(err).ToNot(HaveOccurred())
Expect(dialConn2).To(Equal(listenConn))

Expand Down

0 comments on commit e0d65c5

Please sign in to comment.