Skip to content

Commit

Permalink
Merge pull request #1016 from xchacha20-poly1305/dev-android-protect
Browse files Browse the repository at this point in the history
feat: support Android protect path
  • Loading branch information
tobyxdd authored Apr 13, 2024
2 parents 6e00aa3 + 6b5486f commit 234dc45
Show file tree
Hide file tree
Showing 9 changed files with 346 additions and 16 deletions.
41 changes: 32 additions & 9 deletions app/cmd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/apernet/hysteria/app/internal/http"
"github.com/apernet/hysteria/app/internal/proxymux"
"github.com/apernet/hysteria/app/internal/redirect"
"github.com/apernet/hysteria/app/internal/sockopts"
"github.com/apernet/hysteria/app/internal/socks5"
"github.com/apernet/hysteria/app/internal/tproxy"
"github.com/apernet/hysteria/app/internal/tun"
Expand Down Expand Up @@ -100,13 +101,20 @@ type clientConfigTLS struct {
}

type clientConfigQUIC struct {
InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"`
MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"`
InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"`
MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"`
MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"`
KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"`
DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"`
InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"`
MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"`
InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"`
MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"`
MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"`
KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"`
DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"`
Sockopts clientConfigQUICSockopts `mapstructure:"sockopts"`
}

type clientConfigQUICSockopts struct {
BindInterface *string `mapstructure:"bindInterface"`
FirewallMark *uint32 `mapstructure:"fwmark"`
FdControlUnixSocket *string `mapstructure:"fdControlUnixSocket"`
}

type clientConfigBandwidth struct {
Expand Down Expand Up @@ -196,18 +204,33 @@ func (c *clientConfig) fillServerAddr(hyConfig *client.Config) error {
// fillConnFactory must be called after fillServerAddr, as we have different logic
// for ConnFactory depending on whether we have a port hopping address.
func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error {
so := &sockopts.SocketOptions{
BindInterface: c.QUIC.Sockopts.BindInterface,
FirewallMark: c.QUIC.Sockopts.FirewallMark,
FdControlUnixSocket: c.QUIC.Sockopts.FdControlUnixSocket,
}
if err := so.CheckSupported(); err != nil {
var unsupportedErr *sockopts.UnsupportedError
if errors.As(err, &unsupportedErr) {
return configError{
Field: "quic.sockopts." + unsupportedErr.Field,
Err: errors.New("unsupported on this platform"),
}
}
return configError{Field: "quic.sockopts", Err: err}
}
// Inner PacketConn
var newFunc func(addr net.Addr) (net.PacketConn, error)
switch strings.ToLower(c.Transport.Type) {
case "", "udp":
if hyConfig.ServerAddr.Network() == "udphop" {
hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr)
newFunc = func(addr net.Addr) (net.PacketConn, error) {
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, nil)
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, so.ListenUDP)
}
} else {
newFunc = func(addr net.Addr) (net.PacketConn, error) {
return net.ListenUDP("udp", nil)
return so.ListenUDP()
}
}
default:
Expand Down
13 changes: 13 additions & 0 deletions app/cmd/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ func TestClientConfig(t *testing.T) {
MaxIdleTimeout: 10 * time.Second,
KeepAlivePeriod: 4 * time.Second,
DisablePathMTUDiscovery: true,
Sockopts: clientConfigQUICSockopts{
BindInterface: stringRef("eth0"),
FirewallMark: uint32Ref(1234),
FdControlUnixSocket: stringRef("test.sock"),
},
},
Bandwidth: clientConfigBandwidth{
Up: "200 mbps",
Expand Down Expand Up @@ -189,3 +194,11 @@ func TestClientConfigURI(t *testing.T) {
})
}
}

func stringRef(s string) *string {
return &s
}

func uint32Ref(i uint32) *uint32 {
return &i
}
12 changes: 8 additions & 4 deletions app/cmd/client_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ quic:
maxIdleTimeout: 10s
keepAlivePeriod: 4s
disablePathMTUDiscovery: true
sockopts:
bindInterface: eth0
fwmark: 1234
fdControlUnixSocket: test.sock

bandwidth:
up: 200 mbps
Expand Down Expand Up @@ -75,7 +79,7 @@ tun:
ipv6: 2001::ffff:ffff:ffff:fff1/126
route:
strict: true
ipv4: [0.0.0.0/0]
ipv6: ["2000::/3"]
ipv4Exclude: [192.0.2.1/32]
ipv6Exclude: ["2001:db8::1/128"]
ipv4: [ 0.0.0.0/0 ]
ipv6: [ "2000::/3" ]
ipv4Exclude: [ 192.0.2.1/32 ]
ipv6Exclude: [ "2001:db8::1/128" ]
4 changes: 2 additions & 2 deletions app/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ require (
github.com/stretchr/testify v1.8.4
github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301
go.uber.org/zap v1.24.0
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db
golang.org/x/sys v0.17.0
)

require (
Expand Down Expand Up @@ -54,10 +56,8 @@ require (
go.uber.org/multierr v1.11.0 // indirect
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
golang.org/x/crypto v0.19.0 // indirect
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.11.1 // indirect
google.golang.org/protobuf v1.33.0 // indirect
Expand Down
65 changes: 65 additions & 0 deletions app/internal/sockopts/fd_control_unix_socket_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import socket
import array
import os
import struct
import sys


def serve(path):
try:
os.unlink(path)
except OSError:
if os.path.exists(path):
raise

server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server.bind(path)
server.listen()
print(f"Listening on {path}")

try:
while True:
connection, client_address = server.accept()
print(f"Client connected")

try:
# Receiving fd from client
fds = array.array("i")
msg, ancdata, flags, addr = connection.recvmsg(1, socket.CMSG_LEN(struct.calcsize('i')))
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])

fd = fds[0]

# We make a call to setsockopt(2) here, so client can verify we have received the fd
# In the real scenario, the server would set things like SO_MARK,
# we use SO_RCVBUF as it doesn't require any special capabilities.
nbytes = struct.pack("i", 2500)
fdsocket = fd_to_socket(fd)
fdsocket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, nbytes)
fdsocket.close()

# The only protocol-like thing specified in the client implementation.
connection.send(b'\x01')
finally:
connection.close()
print("Connection closed")

except KeyboardInterrupt:
print("Exit")

finally:
server.close()
os.unlink(path)


def fd_to_socket(fd):
return socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM)


if __name__ == "__main__":
if len(sys.argv) < 2:
raise ValueError("unix socket path is required")

serve(sys.argv[1])
76 changes: 76 additions & 0 deletions app/internal/sockopts/sockopts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package sockopts

import (
"fmt"
"net"
)

type SocketOptions struct {
BindInterface *string
FirewallMark *uint32
FdControlUnixSocket *string
}

// implemented in platform-specific files
var (
bindInterfaceFunc func(c *net.UDPConn, device string) error
firewallMarkFunc func(c *net.UDPConn, fwmark uint32) error
fdControlUnixSocketFunc func(c *net.UDPConn, path string) error
)

func (o *SocketOptions) CheckSupported() (err error) {
if o.BindInterface != nil && bindInterfaceFunc == nil {
return &UnsupportedError{"bindInterface"}
}
if o.FirewallMark != nil && firewallMarkFunc == nil {
return &UnsupportedError{"fwmark"}
}
if o.FdControlUnixSocket != nil && fdControlUnixSocketFunc == nil {
return &UnsupportedError{"fdControlUnixSocket"}
}
return nil
}

type UnsupportedError struct {
Field string
}

func (e *UnsupportedError) Error() string {
return fmt.Sprintf("%s is not supported on this platform", e.Field)
}

func (o *SocketOptions) ListenUDP() (uconn net.PacketConn, err error) {
uconn, err = net.ListenUDP("udp", nil)
if err != nil {
return
}
err = o.applyToUDPConn(uconn.(*net.UDPConn))
if err != nil {
uconn.Close()
uconn = nil
return
}
return
}

func (o *SocketOptions) applyToUDPConn(c *net.UDPConn) error {
if o.BindInterface != nil && bindInterfaceFunc != nil {
err := bindInterfaceFunc(c, *o.BindInterface)
if err != nil {
return fmt.Errorf("failed to bind to interface: %w", err)
}
}
if o.FirewallMark != nil && firewallMarkFunc != nil {
err := firewallMarkFunc(c, *o.FirewallMark)
if err != nil {
return fmt.Errorf("failed to set fwmark: %w", err)
}
}
if o.FdControlUnixSocket != nil && fdControlUnixSocketFunc != nil {
err := fdControlUnixSocketFunc(c, *o.FdControlUnixSocket)
if err != nil {
return fmt.Errorf("failed to send fd to control unix socket: %w", err)
}
}
return nil
}
96 changes: 96 additions & 0 deletions app/internal/sockopts/sockopts_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
//go:build linux

package sockopts

import (
"fmt"
"net"
"time"

"golang.org/x/exp/constraints"
"golang.org/x/sys/unix"
)

const (
fdControlUnixTimeout = 3 * time.Second
)

func init() {
bindInterfaceFunc = bindInterfaceImpl
firewallMarkFunc = firewallMarkImpl
fdControlUnixSocketFunc = fdControlUnixSocketImpl
}

func controlUDPConn(c *net.UDPConn, cb func(fd int) error) (err error) {
rconn, err := c.SyscallConn()
if err != nil {
return
}
cerr := rconn.Control(func(fd uintptr) {
err = cb(int(fd))
})
if err != nil {
return
}
if cerr != nil {
err = fmt.Errorf("failed to control fd: %w", cerr)
return
}
return
}

func bindInterfaceImpl(c *net.UDPConn, device string) error {
return controlUDPConn(c, func(fd int) error {
return unix.BindToDevice(fd, device)
})
}

func firewallMarkImpl(c *net.UDPConn, fwmark uint32) error {
return controlUDPConn(c, func(fd int) error {
return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, int(fwmark))
})
}

func fdControlUnixSocketImpl(c *net.UDPConn, path string) error {
return controlUDPConn(c, func(fd int) error {
socketFd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, 0)
if err != nil {
return fmt.Errorf("failed to create unix socket: %w", err)
}
defer unix.Close(socketFd)

var timeout unix.Timeval
timeUsec := fdControlUnixTimeout.Microseconds()
castAssignInteger(timeUsec/1e6, &timeout.Sec)
// Specifying the type explicitly is not necessary here, but it makes GoLand happy.
castAssignInteger[int64](timeUsec%1e6, &timeout.Usec)

_ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &timeout)
_ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_SNDTIMEO, &timeout)

err = unix.Connect(socketFd, &unix.SockaddrUnix{Name: path})
if err != nil {
return fmt.Errorf("failed to connect: %w", err)
}

err = unix.Sendmsg(socketFd, nil, unix.UnixRights(fd), nil, 0)
if err != nil {
return fmt.Errorf("failed to send: %w", err)
}

dummy := []byte{1}
n, err := unix.Read(socketFd, dummy)
if err != nil {
return fmt.Errorf("failed to receive: %w", err)
}
if n != 1 {
return fmt.Errorf("socket closed unexpectedly")
}

return nil
})
}

func castAssignInteger[F, T constraints.Integer](from F, to *T) {
*to = T(from)
}
Loading

0 comments on commit 234dc45

Please sign in to comment.