Skip to content

Commit

Permalink
Add support for TCP_USER_TIMEOUT setting
Browse files Browse the repository at this point in the history
See https://blog.cloudflare.com/when-tcp-sockets-refuse-to-die/ for
technical details.

Signed-off-by: Andrey Smirnov <[email protected]>
  • Loading branch information
smira committed May 22, 2023
1 parent 91f8614 commit 8bea9a4
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 0 deletions.
74 changes: 74 additions & 0 deletions tcp_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// +build linux !appengine

/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package tcpproxy

import (
"fmt"
"net"
"syscall"
"time"

"golang.org/x/sys/unix"
)

// SetTCPUserTimeout sets the TCP user timeout on a connection's socket
func SetTCPUserTimeout(conn net.Conn, timeout time.Duration) error {
tcpconn, ok := conn.(*net.TCPConn)
if !ok {
// not a TCP connection. exit early
return nil
}
rawConn, err := tcpconn.SyscallConn()
if err != nil {
return fmt.Errorf("error getting raw connection: %v", err)
}
err = rawConn.Control(func(fd uintptr) {
err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int(timeout/time.Millisecond))
})
if err != nil {
return fmt.Errorf("error setting option on socket: %v", err)
}

return nil
}

// GetTCPUserTimeout gets the TCP user timeout on a connection's socket
func GetTCPUserTimeout(conn net.Conn) (opt int, err error) {
tcpconn, ok := conn.(*net.TCPConn)
if !ok {
err = fmt.Errorf("conn is not *net.TCPConn. got %T", conn)
return
}
rawConn, err := tcpconn.SyscallConn()
if err != nil {
err = fmt.Errorf("error getting raw connection: %v", err)
return
}
err = rawConn.Control(func(fd uintptr) {
opt, err = syscall.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_USER_TIMEOUT)
})
if err != nil {
err = fmt.Errorf("error getting option on socket: %v", err)
return
}

return
}
37 changes: 37 additions & 0 deletions tcp_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// +build !linux appengine

/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package tcpproxy

import (
"net"
"time"
)

// SetTCPUserTimeout is a no-op function under non-linux or appengine environments
func SetTCPUserTimeout(conn net.Conn, timeout time.Duration) error {
return nil
}

// GetTCPUserTimeout is a no-op function under non-linux or appengine environments
// a negative return value indicates the operation is not supported
func GetTCPUserTimeout(conn net.Conn) (int, error) {
return -1, nil
}
9 changes: 9 additions & 0 deletions tcpproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ type DialProxy struct {
// If negative, the timeout is disabled.
DialTimeout time.Duration

// TCPUserTimeout optionally specifies a TCP_USER_TIMEOUT (only on Linux).
// If zero, TCP_USER_TIMEOUT is not set.
TCPUserTimeout time.Duration

// DialContext optionally specifies an alternate dial function
// for TCP targets. If nil, the standard
// net.Dialer.DialContext method is used.
Expand Down Expand Up @@ -381,6 +385,11 @@ func (dp *DialProxy) HandleConn(src net.Conn) {
}
}

if dp.TCPUserTimeout > 0 {
SetTCPUserTimeout(src, dp.TCPUserTimeout)
SetTCPUserTimeout(dst, dp.TCPUserTimeout)
}

errc := make(chan error, 1)
go proxyCopy(errc, src, dst)
go proxyCopy(errc, dst, src)
Expand Down
1 change: 1 addition & 0 deletions tcpproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ func TestProxyPROXYOut(t *testing.T) {
p.AddRoute(testFrontAddr, &DialProxy{
Addr: back.Addr().String(),
ProxyProtocolVersion: 1,
TCPUserTimeout: time.Second,
})
if err := p.Start(); err != nil {
t.Fatal(err)
Expand Down

0 comments on commit 8bea9a4

Please sign in to comment.