From 066dc9a7e2bad5bcac56a8ccb349d35bca933265 Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Thu, 8 Jun 2023 16:16:34 +0300 Subject: [PATCH] Pull request 259: 5874-fix-fallback Merge in GO/dnsproxy from 5874-fix-fallback to master Updates AdguardTeam/AdGuardHome#5874. Squashed commit of the following: commit 6687dac3b6c06f17e1a88b855c0b8dc595e4f2a7 Author: Eugene Burkov Date: Thu Jun 8 16:10:06 2023 +0300 upstream: close plain connections commit 4162e21386d296c88504ffda3e7a0d8616bb92ec Author: Eugene Burkov Date: Thu Jun 8 15:32:49 2023 +0300 upstream: imp fallback test commit 2614653e5e96b3f3e8317cb62de2cabc5a0250b5 Author: Eugene Burkov Date: Thu Jun 8 15:16:40 2023 +0300 upstream: imp logging, test commit 7fa6a9d2ab7f25cf6ad2cd87184ec7cf4c7cc871 Author: Eugene Burkov Date: Thu Jun 8 14:53:51 2023 +0300 upstream: fix fallback to tcp --- upstream/upstream.go | 12 +++++++----- upstream/upstream_doh.go | 9 +++++++-- upstream/upstream_dot.go | 4 ++-- upstream/upstream_plain.go | 18 +++++++++++++++--- upstream/upstream_plain_test.go | 24 +++++++++++++++++++++--- 5 files changed, 52 insertions(+), 15 deletions(-) diff --git a/upstream/upstream.go b/upstream/upstream.go index 87e230a7c..74fd6e2a1 100644 --- a/upstream/upstream.go +++ b/upstream/upstream.go @@ -278,24 +278,26 @@ func addPort(u *url.URL, port int) { } } -// Write to log DNS request information that we are going to send -func logBegin(upstreamAddress string, req *dns.Msg) { +// logBegin logs the start of DNS request resolution. It should be called right +// before dialing the connection to the upstream. n is the [network] that will +// be used to send the request. +func logBegin(upstreamAddress string, n network, req *dns.Msg) { qtype := "" target := "" if len(req.Question) != 0 { qtype = dns.Type(req.Question[0].Qtype).String() target = req.Question[0].Name } - log.Debug("%s: sending request %s %s", upstreamAddress, qtype, target) + log.Debug("%s: sending request over %s: %s %s", upstreamAddress, n, qtype, target) } // Write to log about the result of DNS request -func logFinish(upstreamAddress string, err error) { +func logFinish(upstreamAddress string, n network, err error) { status := "ok" if err != nil { status = err.Error() } - log.Debug("%s: response: %s", upstreamAddress, status) + log.Debug("%s: response received over %s: %s", upstreamAddress, n, status) } // DialerInitializer returns the handler that it creates. All the subsequent diff --git a/upstream/upstream_doh.go b/upstream/upstream_doh.go index 719835b68..3238ef726 100644 --- a/upstream/upstream_doh.go +++ b/upstream/upstream_doh.go @@ -211,9 +211,14 @@ func (p *dnsOverHTTPS) closeClient(client *http.Client) (err error) { func (p *dnsOverHTTPS) exchangeHTTPS(client *http.Client, req *dns.Msg) (resp *dns.Msg, err error) { addr := p.Address() - logBegin(addr, req) + n := networkTCP + if isHTTP3(client) { + n = networkUDP + } + + logBegin(addr, n, req) resp, err = p.exchangeHTTPSClient(client, req) - logFinish(addr, err) + logFinish(addr, n, err) return resp, err } diff --git a/upstream/upstream_dot.go b/upstream/upstream_dot.go index 6e31f60b7..bd30bebe7 100644 --- a/upstream/upstream_dot.go +++ b/upstream/upstream_dot.go @@ -198,8 +198,8 @@ func (p *dnsOverTLS) putBack(conn net.Conn) { func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg, err error) { addr := p.Address() - logBegin(addr, m) - defer func() { logFinish(addr, err) }() + logBegin(addr, networkTCP, m) + defer func() { logFinish(addr, networkTCP, err) }() dnsConn := dns.Conn{Conn: conn} diff --git a/upstream/upstream_plain.go b/upstream/upstream_plain.go index b15b0e0df..9a65c0727 100644 --- a/upstream/upstream_plain.go +++ b/upstream/upstream_plain.go @@ -98,14 +98,15 @@ func (p *plainDNS) dialExchange( conn.UDPSize = dns.MinMsgSize } - logBegin(addr, req) - defer func() { logFinish(addr, err) }() + logBegin(addr, network, req) + defer func() { logFinish(addr, network, err) }() ctx := context.Background() conn.Conn, err = dial(ctx, string(network), "") if err != nil { return nil, fmt.Errorf("dialing %s over %s: %w", p.addr.Host, network, err) } + defer func(c net.Conn) { err = errors.WithDeferred(err, c.Close()) }(conn.Conn) resp, _, err = client.ExchangeWithConn(req, conn) if isExpectedConnErr(err) { @@ -113,6 +114,7 @@ func (p *plainDNS) dialExchange( if err != nil { return nil, fmt.Errorf("dialing %s over %s again: %w", p.addr.Host, network, err) } + defer func(c net.Conn) { err = errors.WithDeferred(err, c.Close()) }(conn.Conn) resp, _, err = client.ExchangeWithConn(req, conn) } @@ -144,20 +146,30 @@ func (p *plainDNS) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { resp, err = p.dialExchange(p.net, dial, req) if p.net != networkUDP { + // The network is already TCP. return resp, err } if resp == nil { + // There is likely an error with the upstream. return resp, err } if errors.Is(err, errQuestion) { + // The upstream responds with malformed messages, so try TCP. log.Debug("plain %s: %s, using tcp", addr, err) + + return p.dialExchange(networkTCP, dial, req) } else if resp.Truncated { + // Fallback to TCP on truncated responses. log.Debug("plain %s: resp for %s is truncated, using tcp", &req.Question[0], addr) + + return p.dialExchange(networkTCP, dial, req) } - return p.dialExchange(networkTCP, dial, req) + // There is either no error or the error isn't related to the received + // message. + return resp, err } // Close implements the [Upstream] interface for *plainDNS. diff --git a/upstream/upstream_plain_test.go b/upstream/upstream_plain_test.go index be4219df9..44f947fd9 100644 --- a/upstream/upstream_plain_test.go +++ b/upstream/upstream_plain_test.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net" + "sync/atomic" "testing" "time" @@ -62,9 +63,8 @@ func TestUpstream_plainDNS_badID(t *testing.T) { assert.Nil(t, resp) } -func TestUpstream_plainDNS_fallback(t *testing.T) { +func TestUpstream_plainDNS_fallbackToTCP(t *testing.T) { req := createTestMessage() - goodResp := respondToTestMessage(req) truncResp := goodResp.Copy() @@ -79,26 +79,41 @@ func TestUpstream_plainDNS_fallback(t *testing.T) { testCases := []struct { udpResp *dns.Msg name string + wantUDP int + wantTCP int }{{ udpResp: goodResp, name: "all_right", + wantUDP: 1, + wantTCP: 0, }, { udpResp: truncResp, name: "truncated_response", + wantUDP: 1, + wantTCP: 1, }, { udpResp: badQNameResp, name: "bad_qname", + wantUDP: 1, + wantTCP: 1, }, { udpResp: badQTypeResp, name: "bad_qtype", + wantUDP: 1, + wantTCP: 1, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + var udpReqNum, tcpReqNum atomic.Uint32 srv := startDNSServer(t, func(w dns.ResponseWriter, _ *dns.Msg) { - resp := goodResp + var resp *dns.Msg if w.RemoteAddr().Network() == string(networkUDP) { + udpReqNum.Add(1) resp = tc.udpResp + } else { + tcpReqNum.Add(1) + resp = goodResp } require.NoError(testutil.PanicT{}, w.WriteMsg(resp)) @@ -116,6 +131,9 @@ func TestUpstream_plainDNS_fallback(t *testing.T) { resp, err := u.Exchange(req) require.NoError(t, err) requireResponse(t, req, resp) + + assert.Equal(t, tc.wantUDP, int(udpReqNum.Load())) + assert.Equal(t, tc.wantTCP, int(tcpReqNum.Load())) }) } }