Skip to content

Commit

Permalink
Update network handler usages
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Oct 20, 2024
1 parent 77e0b10 commit 68f92d1
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 43 deletions.
6 changes: 3 additions & 3 deletions brutal.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/varbin"
)

const (
Expand All @@ -32,7 +32,7 @@ func WriteBrutalResponse(writer io.Writer, receiveBPS uint64, ok bool, message s
if ok {
common.Must(binary.Write(buffer, binary.BigEndian, receiveBPS))
} else {
err := rw.WriteVString(buffer, message)
err := varbin.Write(buffer, binary.BigEndian, message)
if err != nil {
return err
}
Expand All @@ -52,7 +52,7 @@ func ReadBrutalResponse(reader io.Reader) (uint64, error) {
return receiveBPS, err
} else {
var message string
message, err = rw.ReadVString(reader)
message, err = varbin.ReadValue[string](reader, binary.BigEndian)
if err != nil {
return 0, err
}
Expand Down
12 changes: 6 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
module github.com/sagernet/sing-mux

go 1.18
go 1.20

require (
github.com/hashicorp/yamux v0.1.1
github.com/sagernet/sing v0.3.0
github.com/hashicorp/yamux v0.1.2
github.com/sagernet/sing v0.5.0-rc.4.0.20241020060022-1270938dd44a
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7
golang.org/x/net v0.19.0
golang.org/x/sys v0.16.0
golang.org/x/net v0.30.0
golang.org/x/sys v0.26.0
)

require golang.org/x/text v0.14.0 // indirect
require golang.org/x/text v0.19.0 // indirect
22 changes: 11 additions & 11 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE=
github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8=
github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
github.com/sagernet/sing v0.3.0 h1:PIDVFZHnQAAYRL1UYqNM+0k5s8f/tb1lUW6UDcQiOc8=
github.com/sagernet/sing v0.3.0/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g=
github.com/sagernet/sing v0.5.0-rc.4.0.20241020060022-1270938dd44a h1:6qlFfBvLZT/MhDpUr4cKY6RxYTnaCcFgOrJEnf/0+io=
github.com/sagernet/sing v0.5.0-rc.4.0.20241020060022-1270938dd44a/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxeq2z31Fpv8CQqqrUwTbrIRumZqQ=
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7/go.mod h1:FP9X2xjT/Az1EsG/orYYoC+5MojWnuI7hrffz8fGwwo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
16 changes: 10 additions & 6 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/varbin"
)

const (
Expand Down Expand Up @@ -41,14 +42,18 @@ type Request struct {
}

func ReadRequest(reader io.Reader) (*Request, error) {
version, err := rw.ReadByte(reader)
var (
version byte
protocol byte
)
err := binary.Read(reader, binary.BigEndian, &version)
if err != nil {
return nil, err
}
if version < Version0 || version > Version1 {
return nil, E.New("unsupported version: ", version)
}
protocol, err := rw.ReadByte(reader)
err = binary.Read(reader, binary.BigEndian, &protocol)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -166,13 +171,12 @@ type StreamResponse struct {

func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) {
var response StreamResponse
status, err := rw.ReadByte(reader)
err := binary.Read(reader, binary.BigEndian, &response.Status)
if err != nil {
return nil, err
}
response.Status = status
if status == statusError {
response.Message, err = rw.ReadVString(reader)
if response.Status == statusError {
response.Message, err = varbin.ReadValue[string](reader, binary.BigEndian)
if err != nil {
return nil, err
}
Expand Down
55 changes: 45 additions & 10 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,24 @@ import (
"github.com/sagernet/sing/common/task"
)

// Deprecated: Use ServiceHandlerEx instead.
//
//nolint:staticcheck
type ServiceHandler interface {
N.TCPConnectionHandler
N.UDPConnectionHandler
}

type ServiceHandlerEx interface {
N.TCPConnectionHandlerEx
N.UDPConnectionHandlerEx
}

type Service struct {
newStreamContext func(context.Context, net.Conn) context.Context
logger logger.ContextLogger
handler ServiceHandler
handlerEx ServiceHandlerEx
padding bool
brutal BrutalOptions
}
Expand All @@ -30,6 +39,7 @@ type ServiceOptions struct {
NewStreamContext func(context.Context, net.Conn) context.Context
Logger logger.ContextLogger
Handler ServiceHandler
HandlerEx ServiceHandlerEx
Padding bool
Brutal BrutalOptions
}
Expand All @@ -42,12 +52,28 @@ func NewService(options ServiceOptions) (*Service, error) {
newStreamContext: options.NewStreamContext,
logger: options.Logger,
handler: options.Handler,
handlerEx: options.HandlerEx,
padding: options.Padding,
brutal: options.Brutal,
}, nil
}

// Deprecated: Use NewConnectionEx instead.
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
return s.newConnection(ctx, conn, metadata.Source)
}

func (s *Service) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandler) {
err := s.newConnection(ctx, conn, source)
if err != nil {
N.ReportHandshakeFailure(conn, err)
conn.Close()
onClose(err)
s.logger.ErrorContext(ctx, E.Cause(err, "process multiplex connection from ", source))
}
}

func (s *Service) newConnection(ctx context.Context, conn net.Conn, source M.Socksaddr) error {
request, err := ReadRequest(conn)
if err != nil {
return err
Expand All @@ -71,9 +97,10 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
}
streamCtx := s.newStreamContext(ctx, stream)
go func() {
hErr := s.newConnection(streamCtx, conn, stream, metadata)
hErr := s.newSession(streamCtx, conn, stream, source)
if hErr != nil {
s.logger.ErrorContext(streamCtx, E.Cause(hErr, "handle connection"))
stream.Close()
s.logger.ErrorContext(streamCtx, E.Cause(hErr, "process multiplex stream"))
}
}()
}
Expand All @@ -84,13 +111,13 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
return group.Run(ctx)
}

func (s *Service) newConnection(ctx context.Context, sessionConn net.Conn, stream net.Conn, metadata M.Metadata) error {
func (s *Service) newSession(ctx context.Context, sessionConn net.Conn, stream net.Conn, source M.Socksaddr) error {
stream = &wrapStream{stream}
request, err := ReadStreamRequest(stream)
if err != nil {
return E.Cause(err, "read multiplex stream request")
}
metadata.Destination = request.Destination
destination := request.Destination
if request.Network == N.NetworkTCP {
conn := &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)}
if request.Destination.Fqdn == BrutalExchangeDomain {
Expand Down Expand Up @@ -128,20 +155,28 @@ func (s *Service) newConnection(ctx context.Context, sessionConn net.Conn, strea
}
return nil
}
s.logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination)
s.handler.NewConnection(ctx, conn, metadata)
stream.Close()
s.logger.InfoContext(ctx, "inbound multiplex connection to ", destination)
if s.handler != nil {
//nolint:staticcheck
s.handler.NewConnection(ctx, conn, M.Metadata{Source: source, Destination: destination})
} else {
s.handlerEx.NewConnectionEx(ctx, conn, source, destination, nil)
}
} else {
var packetConn N.PacketConn
if !request.PacketAddr {
s.logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination)
s.logger.InfoContext(ctx, "inbound multiplex packet connection to ", destination)
packetConn = &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination}
} else {
s.logger.InfoContext(ctx, "inbound multiplex packet connection")
packetConn = &serverPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)}
}
s.handler.NewPacketConnection(ctx, packetConn, metadata)
stream.Close()
if s.handler != nil {
//nolint:staticcheck
s.handler.NewPacketConnection(ctx, packetConn, M.Metadata{Source: source, Destination: destination})
} else {
s.handlerEx.NewPacketConnectionEx(ctx, packetConn, source, destination, nil)
}
}
return nil
}
14 changes: 7 additions & 7 deletions server_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/varbin"
)

type serverConn struct {
Expand All @@ -24,11 +24,11 @@ func (c *serverConn) NeedHandshake() bool {

func (c *serverConn) HandshakeFailure(err error) error {
errMessage := err.Error()
buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
buffer := buf.NewSize(1 + varbin.UvarintLen(uint64(len(errMessage))) + len(errMessage))
defer buffer.Release()
common.Must(
buffer.WriteByte(statusError),
rw.WriteVString(buffer, errMessage),
varbin.Write(buffer, binary.BigEndian, errMessage),
)
return common.Error(c.ExtendedConn.Write(buffer.Bytes()))
}
Expand Down Expand Up @@ -88,11 +88,11 @@ func (c *serverPacketConn) NeedHandshake() bool {

func (c *serverPacketConn) HandshakeFailure(err error) error {
errMessage := err.Error()
buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
buffer := buf.NewSize(1 + varbin.UvarintLen(uint64(len(errMessage))) + len(errMessage))
defer buffer.Release()
common.Must(
buffer.WriteByte(statusError),
rw.WriteVString(buffer, errMessage),
varbin.Write(buffer, binary.BigEndian, errMessage),
)
return common.Error(c.ExtendedConn.Write(buffer.Bytes()))
}
Expand Down Expand Up @@ -188,11 +188,11 @@ func (c *serverPacketAddrConn) NeedHandshake() bool {

func (c *serverPacketAddrConn) HandshakeFailure(err error) error {
errMessage := err.Error()
buffer := buf.NewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
buffer := buf.NewSize(1 + varbin.UvarintLen(uint64(len(errMessage))) + len(errMessage))
defer buffer.Release()
common.Must(
buffer.WriteByte(statusError),
rw.WriteVString(buffer, errMessage),
varbin.Write(buffer, binary.BigEndian, errMessage),
)
return common.Error(c.ExtendedConn.Write(buffer.Bytes()))
}
Expand Down

0 comments on commit 68f92d1

Please sign in to comment.