diff --git a/pkg/core/client.go b/pkg/core/client.go index 8f9fbbd1dc..822d3f20ea 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -61,20 +61,25 @@ func (c *Client) connectToServer() error { if err != nil { return err } - packetConn, err := net.ListenPacket("udp", "") + udpConn, err := net.ListenUDP("udp", nil) if err != nil { return err } + var qs quic.Session if c.obfuscator != nil { // Wrap PacketConn with obfuscator - packetConn = &obfsPacketConn{ - Orig: packetConn, + qs, err = quic.Dial(&obfsUDPConn{ + Orig: udpConn, Obfuscator: c.obfuscator, + }, serverUDPAddr, c.serverAddr, c.tlsConfig, c.quicConfig) + if err != nil { + return err + } + } else { + qs, err = quic.Dial(udpConn, serverUDPAddr, c.serverAddr, c.tlsConfig, c.quicConfig) + if err != nil { + return err } - } - qs, err := quic.Dial(packetConn, serverUDPAddr, c.serverAddr, c.tlsConfig, c.quicConfig) - if err != nil { - return err } // Control stream ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout) diff --git a/pkg/core/obfs.go b/pkg/core/obfs.go index 8454dcf6e2..68ff95c593 100644 --- a/pkg/core/obfs.go +++ b/pkg/core/obfs.go @@ -2,6 +2,8 @@ package core import ( "net" + "os" + "syscall" "time" ) @@ -10,12 +12,16 @@ type Obfuscator interface { Obfuscate(p []byte) []byte } -type obfsPacketConn struct { - Orig net.PacketConn +type obfsUDPConn struct { + Orig *net.UDPConn Obfuscator Obfuscator } -func (c *obfsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { +func (c *obfsUDPConn) SyscallConn() (syscall.RawConn, error) { + return c.Orig.SyscallConn() +} + +func (c *obfsUDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { oldN, addr, err := c.Orig.ReadFrom(p) if oldN > 0 { newN := c.Obfuscator.Deobfuscate(p, oldN) @@ -25,7 +31,7 @@ func (c *obfsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } } -func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { +func (c *obfsUDPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { np := c.Obfuscator.Obfuscate(p) _, err = c.Orig.WriteTo(np, addr) if err != nil { @@ -35,22 +41,34 @@ func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { } } -func (c *obfsPacketConn) Close() error { +func (c *obfsUDPConn) Close() error { return c.Orig.Close() } -func (c *obfsPacketConn) LocalAddr() net.Addr { +func (c *obfsUDPConn) LocalAddr() net.Addr { return c.Orig.LocalAddr() } -func (c *obfsPacketConn) SetDeadline(t time.Time) error { +func (c *obfsUDPConn) SetDeadline(t time.Time) error { return c.Orig.SetDeadline(t) } -func (c *obfsPacketConn) SetReadDeadline(t time.Time) error { +func (c *obfsUDPConn) SetReadDeadline(t time.Time) error { return c.Orig.SetReadDeadline(t) } -func (c *obfsPacketConn) SetWriteDeadline(t time.Time) error { +func (c *obfsUDPConn) SetWriteDeadline(t time.Time) error { return c.Orig.SetWriteDeadline(t) } + +func (c *obfsUDPConn) SetReadBuffer(bytes int) error { + return c.Orig.SetReadBuffer(bytes) +} + +func (c *obfsUDPConn) SetWriteBuffer(bytes int) error { + return c.Orig.SetWriteBuffer(bytes) +} + +func (c *obfsUDPConn) File() (f *os.File, err error) { + return c.Orig.File() +} diff --git a/pkg/core/server.go b/pkg/core/server.go index 8651d1b61d..f9cb85f0d2 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -38,20 +38,29 @@ func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine, obfuscator Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc, udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc) (*Server, error) { - packetConn, err := net.ListenPacket("udp", addr) + udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, err + } + var listener quic.Listener if obfuscator != nil { // Wrap PacketConn with obfuscator - packetConn = &obfsPacketConn{ - Orig: packetConn, + listener, err = quic.Listen(&obfsUDPConn{ + Orig: udpConn, Obfuscator: obfuscator, + }, tlsConfig, quicConfig) + if err != nil { + return nil, err + } + } else { + listener, err = quic.Listen(udpConn, tlsConfig, quicConfig) + if err != nil { + return nil, err } - } - listener, err := quic.Listen(packetConn, tlsConfig, quicConfig) - if err != nil { - return nil, err } s := &Server{ listener: listener,