diff --git a/internal/quic/conn.go b/internal/quic/conn.go index b3d6feabc..1292f2b20 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -86,7 +86,15 @@ type connTestHooks interface { timeNow() time.Time } -func newConn(now time.Time, side connSide, originalDstConnID, retrySrcConnID []byte, peerAddr netip.AddrPort, config *Config, l *Listener) (*Conn, error) { +// newServerConnIDs is connection IDs associated with a new server connection. +type newServerConnIDs struct { + srcConnID []byte // source from client's current Initial + dstConnID []byte // destination from client's current Initial + originalDstConnID []byte // destination from client's first Initial + retrySrcConnID []byte // source from server's Retry +} + +func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort, config *Config, l *Listener) (*Conn, error) { c := &Conn{ side: side, listener: l, @@ -115,11 +123,11 @@ func newConn(now time.Time, side connSide, originalDstConnID, retrySrcConnID []b } initialConnID, _ = c.connIDState.dstConnID() } else { - initialConnID = originalDstConnID - if retrySrcConnID != nil { - initialConnID = retrySrcConnID + initialConnID = cids.originalDstConnID + if cids.retrySrcConnID != nil { + initialConnID = cids.retrySrcConnID } - if err := c.connIDState.initServer(c, initialConnID); err != nil { + if err := c.connIDState.initServer(c, cids); err != nil { return nil, err } } @@ -134,8 +142,8 @@ func newConn(now time.Time, side connSide, originalDstConnID, retrySrcConnID []b if err := c.startTLS(now, initialConnID, transportParameters{ initialSrcConnID: c.connIDState.srcConnID(), - originalDstConnID: originalDstConnID, - retrySrcConnID: retrySrcConnID, + originalDstConnID: cids.originalDstConnID, + retrySrcConnID: cids.retrySrcConnID, ackDelayExponent: ackDelayExponent, maxUDPPayloadSize: maxUDPPayloadSize, maxAckDelay: maxAckDelay, diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go index 91ccaade1..b77ad8edf 100644 --- a/internal/quic/conn_id.go +++ b/internal/quic/conn_id.go @@ -96,8 +96,8 @@ func (s *connIDState) initClient(c *Conn) error { return nil } -func (s *connIDState) initServer(c *Conn, dstConnID []byte) error { - dstConnID = cloneBytes(dstConnID) +func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error { + dstConnID := cloneBytes(cids.dstConnID) // Client-chosen, transient connection ID received in the first Initial packet. // The server will not use this as the Source Connection ID of packets it sends, // but remembers it because it may receive packets sent to this destination. @@ -121,6 +121,14 @@ func (s *connIDState) initServer(c *Conn, dstConnID []byte) error { conns.addConnID(c, dstConnID) conns.addConnID(c, locid) }) + + // Client chose its own connection ID. + s.remote = append(s.remote, remoteConnID{ + connID: connID{ + seq: 0, + cid: cloneBytes(cids.srcConnID), + }, + }) return nil } diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go index 63feec992..314a6b384 100644 --- a/internal/quic/conn_id_test.go +++ b/internal/quic/conn_id_test.go @@ -578,8 +578,11 @@ func TestConnIDPeerWithZeroLengthIDProvidesPreferredAddr(t *testing.T) { p.preferredAddrV6 = netip.MustParseAddrPort("[::0]:0") p.preferredAddrConnID = testPeerConnID(1) p.preferredAddrResetToken = make([]byte, 16) + }, func(cids *newServerConnIDs) { + cids.srcConnID = []byte{} + }, func(tc *testConn) { + tc.peerConnID = []byte{} }) - tc.peerConnID = []byte{} tc.writeFrames(packetTypeInitial, debugFrameCrypto{ diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index df28907f4..248be9641 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -193,33 +193,38 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { TLSConfig: newTestTLSConfig(side), StatelessResetKey: testStatelessResetKey, } + var cids newServerConnIDs + if side == serverSide { + // The initial connection ID for the server is chosen by the client. + cids.srcConnID = testPeerConnID(0) + cids.dstConnID = testPeerConnID(-1) + } var configTransportParams []func(*transportParameters) + var configTestConn []func(*testConn) for _, o := range opts { switch o := o.(type) { case func(*Config): o(config) case func(*tls.Config): o(config.TLSConfig) + case func(cids *newServerConnIDs): + o(&cids) case func(p *transportParameters): configTransportParams = append(configTransportParams, o) + case func(p *testConn): + configTestConn = append(configTestConn, o) default: t.Fatalf("unknown newTestConn option %T", o) } } - var initialConnID []byte - if side == serverSide { - // The initial connection ID for the server is chosen by the client. - initialConnID = testPeerConnID(-1) - } - listener := newTestListener(t, config) listener.configTransportParams = configTransportParams + listener.configTestConn = configTestConn conn, err := listener.l.newConn( listener.now, side, - initialConnID, - nil, + cids, netip.MustParseAddrPort("127.0.0.1:443")) if err != nil { t.Fatal(err) @@ -244,6 +249,9 @@ func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testC recvDatagram: make(chan *datagram), } t.Cleanup(tc.cleanup) + for _, f := range listener.configTestConn { + f(tc) + } conn.testHooks = (*testConnHooks)(tc) if listener.peerTLSConn != nil { diff --git a/internal/quic/listener.go b/internal/quic/listener.go index 08f011092..24484eb6f 100644 --- a/internal/quic/listener.go +++ b/internal/quic/listener.go @@ -140,7 +140,7 @@ func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, er } addr := u.AddrPort() addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port()) - c, err := l.newConn(time.Now(), clientSide, nil, nil, addr) + c, err := l.newConn(time.Now(), clientSide, newServerConnIDs{}, addr) if err != nil { return nil, err } @@ -151,13 +151,13 @@ func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, er return c, nil } -func (l *Listener) newConn(now time.Time, side connSide, originalDstConnID, retrySrcConnID []byte, peerAddr netip.AddrPort) (*Conn, error) { +func (l *Listener) newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort) (*Conn, error) { l.connsMu.Lock() defer l.connsMu.Unlock() if l.closing { return nil, errors.New("listener closed") } - c, err := newConn(now, side, originalDstConnID, retrySrcConnID, peerAddr, l.config, l) + c, err := newConn(now, side, cids, peerAddr, l.config, l) if err != nil { return nil, err } @@ -296,19 +296,22 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { } else { now = time.Now() } - var originalDstConnID, retrySrcConnID []byte + cids := newServerConnIDs{ + srcConnID: p.srcConnID, + dstConnID: p.dstConnID, + } if l.config.RequireAddressValidation { var ok bool - retrySrcConnID = p.dstConnID - originalDstConnID, ok = l.validateInitialAddress(now, p, m.addr) + cids.retrySrcConnID = p.dstConnID + cids.originalDstConnID, ok = l.validateInitialAddress(now, p, m.addr) if !ok { return } } else { - originalDstConnID = p.dstConnID + cids.originalDstConnID = p.dstConnID } var err error - c, err := l.newConn(now, serverSide, originalDstConnID, retrySrcConnID, m.addr) + c, err := l.newConn(now, serverSide, cids, m.addr) if err != nil { // The accept queue is probably full. // We could send a CONNECTION_CLOSE to the peer to reject the connection. diff --git a/internal/quic/listener_test.go b/internal/quic/listener_test.go index 21717e251..674d4e4a1 100644 --- a/internal/quic/listener_test.go +++ b/internal/quic/listener_test.go @@ -19,12 +19,12 @@ import ( ) func TestConnect(t *testing.T) { - newLocalConnPair(t, &Config{}, &Config{}) + NewLocalConnPair(t, &Config{}, &Config{}) } func TestStreamTransfer(t *testing.T) { ctx := context.Background() - cli, srv := newLocalConnPair(t, &Config{}, &Config{}) + cli, srv := NewLocalConnPair(t, &Config{}, &Config{}) data := makeTestData(1 << 20) srvdone := make(chan struct{}) @@ -61,11 +61,11 @@ func TestStreamTransfer(t *testing.T) { } } -func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) { +func NewLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) { t.Helper() ctx := context.Background() - l1 := newLocalListener(t, serverSide, conf1) - l2 := newLocalListener(t, clientSide, conf2) + l1 := NewLocalListener(t, serverSide, conf1) + l2 := NewLocalListener(t, clientSide, conf2) c2, err := l2.Dial(ctx, "udp", l1.LocalAddr().String()) if err != nil { t.Fatal(err) @@ -77,9 +77,11 @@ func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverCon return c2, c1 } -func newLocalListener(t *testing.T, side connSide, conf *Config) *Listener { +func NewLocalListener(t *testing.T, side connSide, conf *Config) *Listener { t.Helper() if conf.TLSConfig == nil { + newConf := *conf + conf = &newConf conf.TLSConfig = newTestTLSConfig(side) } l, err := Listen("udp", "127.0.0.1:0", conf) @@ -101,6 +103,7 @@ type testListener struct { conns map[*Conn]*testConn acceptQueue []*testConn configTransportParams []func(*transportParameters) + configTestConn []func(*testConn) sentDatagrams [][]byte peerTLSConn *tls.QUICConn lastInitialDstConnID []byte // for parsing Retry packets @@ -251,33 +254,6 @@ func (tl *testListener) wantIdle(expectation string) { } } -func (tl *testListener) newClientTLS(srcConnID, dstConnID []byte) []byte { - peerProvidedParams := defaultTransportParameters() - peerProvidedParams.initialSrcConnID = srcConnID - peerProvidedParams.originalDstConnID = dstConnID - for _, f := range tl.configTransportParams { - f(&peerProvidedParams) - } - - config := &tls.QUICConfig{TLSConfig: newTestTLSConfig(clientSide)} - tl.peerTLSConn = tls.QUICClient(config) - tl.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams)) - tl.peerTLSConn.Start(context.Background()) - var data []byte - for { - e := tl.peerTLSConn.NextEvent() - switch e.Kind { - case tls.QUICNoEvent: - return data - case tls.QUICWriteData: - if e.Level != tls.QUICEncryptionLevelInitial { - tl.t.Fatal("initial data at unexpected level") - } - data = append(data, e.Data...) - } - } -} - // advance causes time to pass. func (tl *testListener) advance(d time.Duration) { tl.t.Helper()