diff --git a/NOTICE b/NOTICE index c4351b8..d10de77 100644 --- a/NOTICE +++ b/NOTICE @@ -45,3 +45,24 @@ This project uses font from https://www.onlygfx.com/newspaper-cutout-font-white- This project uses software from https://github.com/oschwald/maxminddb-golang (ISC) * Copyright (c) 2015, Gregory J. Oschwald + +This project uses software from https://git.zx2c4.com/wireguard-go/about/ (MIT) +* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md index 3f63a3c..5098649 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,27 @@ Fabric that holds application components together. * `state` - where data persists on disk through restarts of the application. Default is `.slrp/data` of your home directory. * `sync` - how often data is synchronised to disk, pending availability of any updates of component state. Default is every minute. +## dialer + +[WireGuard](https://www.wireguard.com/) userspace VPN dialer configuration. Embeds the official [Go implementation](https://git.zx2c4.com/wireguard-go). Disabled by default. + +* `wireguard_config_file` - [configuration file](https://www.wireguard.com/#cryptokey-routing) from WireGuard. IPv6 address parsing is ignored at the moment. +* `wireguard_verbose` - verbose logging mode for WireGuard tunnel. + +Sample WireGuard configuration file: + +```ini +[Interface] +PrivateKey = gI6EdUSYvn8ugXOt8QQD6Yc+JyiZxIhp3GInSWRfWGE= +Address = 1.2.3.4/24 +DNS = 1.2.3.4 + +[Peer] +PublicKey = HIgo9xNzJMWLKASShiTqIybxZ0U3wGLiUeJ1PKf8ykw= +Endpoint = 1.2.3.4:51820 +AllowedIPs = 0.0.0.0/0 +``` + ## log Structured logging meta-components. diff --git a/checker/checker.go b/checker/checker.go index 9b5e90c..7b5d005 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -6,6 +6,7 @@ import ( "fmt" "io/ioutil" "math/rand" + "net" "net/http" "regexp" "strings" @@ -13,6 +14,7 @@ import ( "github.com/nfx/slrp/app" "github.com/nfx/slrp/pmux" + "github.com/rs/zerolog/log" "github.com/corpix/uarand" "github.com/microcosm-cc/bluemonday" @@ -22,6 +24,14 @@ type Checker interface { Check(ctx context.Context, proxy pmux.Proxy) (time.Duration, error) } +type dialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +type httpClient interface { + Do(req *http.Request) (*http.Response, error) +} + var ( firstPass = []string{ // these check for ext ip, but don't show headers @@ -44,57 +54,81 @@ var ( ErrNotAnonymous = fmt.Errorf("this IP address found") ) -var defaultClient httpClient = pmux.DefaultHttpClient - -func init() { - defaultClient = &http.Client{ - Transport: pmux.ContextualHttpTransport(), - Timeout: 5 * time.Second, - } -} - -func NewChecker() Checker { - ip, err := thisIP() - if err != nil { - panic(fmt.Errorf("cannot get this IP: %w", err)) - } +func NewChecker(dialer dialer) Checker { return &configurableChecker{ - ip: ip, - client: defaultClient, - strategies: map[string]Checker{ - "twopass": newTwoPass(ip, defaultClient), - "simple": newFederated(firstPass, defaultClient, ip), - "headers": newFederated([]string{ - "https://ifconfig.me/all", - "https://ifconfig.io/all.json", - }, defaultClient, ip), + client: &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + DialContext: dialer.DialContext, + TLSClientConfig: pmux.DefaultTlsConfig, + Proxy: pmux.ProxyFromContext, + }, }, - strategy: "simple", } } type configurableChecker struct { - ip string - client httpClient - strategies map[string]Checker - strategy string + ip string + client httpClient + strategy Checker } func (cc *configurableChecker) Configure(conf app.Config) error { - cc.strategy = conf.StrOr("strategy", "simple") - _, invalidStrategy := cc.strategies[cc.strategy] - if !invalidStrategy { - return fmt.Errorf("invalid strategy: %s", cc.strategy) + ip, err := cc.thisIP() + if ip == "" { + return fmt.Errorf("IP is empty") + } + if err != nil { + return fmt.Errorf("cannot get this IP: %w", err) + } + cc.ip = ip + strategies := map[string]Checker{ + "twopass": newTwoPass(ip, cc.client), + "simple": newFederated(firstPass, cc.client, ip), + "headers": newFederated([]string{ + "https://ifconfig.me/all", + "https://ifconfig.io/all.json", + }, cc.client, ip), + } + strategyName := conf.StrOr("strategy", "simple") + strategy, ok := strategies[strategyName] + if !ok { + return fmt.Errorf("invalid strategy: %s", strategyName) } + cc.strategy = strategy + timeout := conf.DurOr("timeout", 5*time.Second) original, ok := cc.client.(*http.Client) if ok { - original.Timeout = conf.DurOr("timeout", 5*time.Second) + original.Timeout = timeout } + log.Info(). + Str("ip", ip). + Str("strategy", strategyName). + Dur("timeout", timeout). + Msg("configured proxy checker") return nil } +func (cc *configurableChecker) thisIP() (string, error) { + req, err := http.NewRequest("GET", "https://ifconfig.me/ip", nil) + if err != nil { + return "", err + } + r, err := cc.client.Do(req) + if err != nil { + return "", err + } + defer r.Body.Close() + s := bufio.NewScanner(r.Body) + s.Scan() + return s.Text(), nil +} + func (cc *configurableChecker) Check(ctx context.Context, proxy pmux.Proxy) (time.Duration, error) { - return cc.strategies[cc.strategy].Check(ctx, proxy) + if cc.strategy == nil { + return 0, fmt.Errorf("no strategy") + } + return cc.strategy.Check(ctx, proxy) } func newTwoPass(ip string, client httpClient) twoPass { @@ -158,10 +192,6 @@ func (f federated) Check(ctx context.Context, proxy pmux.Proxy) (time.Duration, return f[choice].Check(ctx, proxy) } -type httpClient interface { - Do(req *http.Request) (*http.Response, error) -} - type simple struct { client httpClient page string @@ -233,17 +263,6 @@ func truncatedBody(body string) string { return body } -func thisIP() (string, error) { - r, err := http.Get("https://ifconfig.me/ip") - if err != nil { - return "", err - } - defer r.Body.Close() - s := bufio.NewScanner(r.Body) - s.Scan() - return s.Text(), nil -} - type temporary string func (t temporary) Temporary() bool { diff --git a/checker/checker_test.go b/checker/checker_test.go index a75f192..f60d738 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "io" + "net" "net/http" "strings" "testing" @@ -16,39 +17,37 @@ import ( ) func TestFailure(t *testing.T) { - defaultClient = &staticResponseClient{ + c := NewChecker(&checkerShim{ err: fmt.Errorf("fails"), - } - c := NewChecker() - + }) ctx := context.Background() _, err := c.Check(ctx, pmux.HttpProxy("127.0.0.1:1")) - assert.EqualError(t, err, "fails") + assert.EqualError(t, err, "no strategy") } func TestConfigurableChecker(t *testing.T) { client := http.DefaultClient - c := configurableChecker{ + c := &configurableChecker{ client: client, - strategies: map[string]Checker{ - "simple": &simple{}, // just for tests - }, } err := c.Configure(app.Config{}) assert.NoError(t, err) - assert.Equal(t, "simple", c.strategy) assert.Equal(t, time.Second*5, client.Timeout) } -type staticResponseClient struct { +type checkerShim struct { http.Response err error } -func (r staticResponseClient) Do(req *http.Request) (*http.Response, error) { +func (r checkerShim) Do(req *http.Request) (*http.Response, error) { return &r.Response, r.err } +func (r checkerShim) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return nil, r.err +} + func body(x string) io.ReadCloser { return io.NopCloser(bytes.NewBufferString(x)) } @@ -90,7 +89,7 @@ func TestTwoPassCheck(t *testing.T) { &simple{ ip: "XYZ", valid: "..", - client: staticResponseClient{ + client: checkerShim{ Response: http.Response{ Body: body(tt.firstBody), StatusCode: 200, @@ -103,7 +102,7 @@ func TestTwoPassCheck(t *testing.T) { &simple{ ip: "XYZ", valid: "..", - client: staticResponseClient{ + client: checkerShim{ Response: http.Response{ Body: body(tt.secondBody), StatusCode: 200, @@ -197,7 +196,7 @@ func TestSimpleCheck(t *testing.T) { ip: "255.0.0.1", valid: tt.valid, page: tt.page, - client: staticResponseClient{ + client: checkerShim{ Response: http.Response{ Body: tt.body, StatusCode: 200, diff --git a/dialer/dialer.go b/dialer/dialer.go new file mode 100644 index 0000000..d63ae9b --- /dev/null +++ b/dialer/dialer.go @@ -0,0 +1,187 @@ +package dialer + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + "net" + "net/netip" + "strings" + + "github.com/nfx/slrp/app" + "github.com/nfx/slrp/dialer/ini" + "github.com/rs/zerolog/log" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/netstack" +) + +// wireGuardDialer implements the app.Service interface and represents the WireGuard dialer. +type wireGuardDialer struct { + standard net.Dialer + tunnel *netstack.Net + conf ini.Config + verbose bool +} + +// NewDialer creates a new instance of the WireGuard dialer. +func NewDialer() *wireGuardDialer { + return &wireGuardDialer{} +} + +// Configure initializes the WireGuard dialer with the provided configuration. +func (d *wireGuardDialer) Configure(c app.Config) error { + configFile := c.StrOr("wireguard_config_file", "") + if configFile == "" { + // If no WireGuard config file is specified, use the standard net.Dialer. + log.Warn().Msg("using clear dialer") + return nil + } + d.verbose = c.BoolOr("wireguard_verbose", false) + conf, err := ini.ParseINI(configFile) + if err != nil { + return fmt.Errorf("parse %s: %w", configFile, err) + } + d.conf = conf + + log.Info(). + Str("config", configFile). + Str("endpoint", d.conf["Peer"]["Endpoint"]). + Msg("configured WireGuard dialer") + + // https://www.wireguard.com/xplatform/ + // Create the WireGuard tunnel and device based on the configuration. + tun, tnet, err := d.createNetTUN() + if err != nil { + return fmt.Errorf("create net tun: %w", err) + } + bind := conn.NewDefaultBind() + verboseF := func(format string, args ...any) {} + if d.verbose { + verboseF = func(format string, args ...any) { + log.Debug().Str("service", "wireguard").Msgf(format, args...) + } + } + dev := device.NewDevice(tun, bind, &device.Logger{ + // Define custom logger functions for WireGuard device logging. + Errorf: func(format string, args ...any) { + log.Error().Str("service", "wireguard").Msgf(format, args...) + }, + Verbosef: verboseF, + }) + ipc, err := d.getIpsSetShim() + if err != nil { + return fmt.Errorf("ipc shim: %w", err) + } + err = dev.IpcSetOperation(ipc) + if err != nil { + return fmt.Errorf("ipc set: %w", err) + } + err = dev.Up() + if err != nil { + return fmt.Errorf("up: %w", err) + } + d.tunnel = tnet + return nil +} + +// addrsFromConfig parses a comma-separated list of IP addresses from the configuration section and key. +func (d *wireGuardDialer) addrsFromConfig(section, key string) (addrs []netip.Addr, err error) { + // Fetch the comma-separated value from the configuration. + value := d.conf[section][key] + for _, v := range strings.Split(value, ",") { + v = strings.TrimSpace(v) + if strings.Contains(v, ":") { + // Skip IPv6 addresses for now (not supported). + continue + } + if strings.Contains(v, "/") { + // Parse the IP address with subnet prefix if present. + addr, err := netip.ParsePrefix(v) + if err != nil { + return nil, err + } + addrs = append(addrs, addr.Addr()) + continue + } + // Parse the IP address without subnet prefix. + addr, err := netip.ParseAddr(v) + if err != nil { + return nil, err + } + addrs = append(addrs, addr) + } + return addrs, nil +} + +// createNetTUN creates a network TUN interface with the specified IP addresses and DNS servers. +func (d *wireGuardDialer) createNetTUN() (tun.Device, *netstack.Net, error) { + addrs, err := d.addrsFromConfig("Interface", "Address") + if err != nil { + return nil, nil, err + } + dns, err := d.addrsFromConfig("Interface", "DNS") + if err != nil { + return nil, nil, err + } + // Create the network TUN interface using the netstack package with the obtained IP addresses and DNS servers. + return netstack.CreateNetTUN(addrs, dns, 1420) +} + +// writeHexKeyAs writes a base64-encoded key to the buffer with the specified label. +func (d *wireGuardDialer) writeHexKeyAs(b *bytes.Buffer, section, key, as string) error { + // Decode the base64-encoded key from the configuration. + raw, err := base64.StdEncoding.DecodeString(d.conf[section][key]) + if err != nil { + return err + } + // Write the key to the buffer as a hexadecimal value with the given label. + _, err = b.WriteString(fmt.Sprintf("%s=%x\n", as, raw)) + return err +} + +func (d *wireGuardDialer) getIpsSetShim() (*bytes.Buffer, error) { + b := &bytes.Buffer{} + // Write private, public, and preshared keys to the buffer as hexadecimal values. + err := d.writeHexKeyAs(b, "Interface", "PrivateKey", "private_key") + if err != nil { + return nil, err + } + err = d.writeHexKeyAs(b, "Peer", "PublicKey", "public_key") + if err != nil { + return nil, err + } + err = d.writeHexKeyAs(b, "Peer", "PresharedKey", "preshared_key") + if err != nil { + return nil, err + } + // Add the allowed IP and endpoint information to the buffer. + _, err = b.WriteString("allowed_ip=0.0.0.0/0\n") + //_, err = b.WriteString(fmt.Sprintf("allowed_ip=%s\n", d.conf["Peer"]["AllowedIPs"])) + if err != nil { + return nil, err + } + _, err = b.WriteString(fmt.Sprintf("endpoint=%s\n", d.conf["Peer"]["Endpoint"])) + if err != nil { + return nil, err + } + return b, nil +} + +// DialContext establishes a network connection using the WireGuard tunnel if available, +// otherwise, it uses the standard net.Dialer. +func (d *wireGuardDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if d.tunnel != nil { + // If the WireGuard tunnel is available, use it to establish the connection. + return d.tunnel.DialContext(ctx, network, address) + } + // If there is no WireGuard tunnel, fall back to the standard net.Dialer. + return d.standard.DialContext(ctx, network, address) +} + +// Dial is a convenience function that calls DialContext with a background context. +func (d *wireGuardDialer) Dial(network, addr string) (net.Conn, error) { + return d.DialContext(context.Background(), network, addr) +} diff --git a/dialer/ini/ini.go b/dialer/ini/ini.go new file mode 100644 index 0000000..565f15f --- /dev/null +++ b/dialer/ini/ini.go @@ -0,0 +1,48 @@ +package ini + +import ( + "bufio" + "fmt" + "os" + "strings" +) + +type Config map[string]map[string]string + +func ParseINI(filename string) (Config, error) { + file, err := os.Open(filename) + if err != nil { + return nil, err + } + defer file.Close() + + cfg := make(Config) + var currentSection string + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if len(line) == 0 || line[0] == ';' || line[0] == '#' { + continue // Skip empty lines and comments + } + + if line[0] == '[' && line[len(line)-1] == ']' { + currentSection = line[1 : len(line)-1] + cfg[currentSection] = make(map[string]string) + } else if idx := strings.Index(line, "="); idx > 0 { + key := strings.TrimSpace(line[:idx]) + value := strings.TrimSpace(line[idx+1:]) + if currentSection != "" { + cfg[currentSection][key] = value + } + } else { + return nil, fmt.Errorf("invalid line: %s", line) + } + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return cfg, nil +} diff --git a/dialer/ini/ini_test.go b/dialer/ini/ini_test.go new file mode 100644 index 0000000..d40b7e4 --- /dev/null +++ b/dialer/ini/ini_test.go @@ -0,0 +1,99 @@ +package ini + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestParseINI tests the ParseINI function. +func TestParseINI(t *testing.T) { + tests := []struct { + name string + input string + expectedCfg Config + expectedErr bool + expectedKeys int + }{ + { + name: "ValidINIFile", + input: `[Section1] +Key1=Value1 +Key2=Value2 + +[Section2] +KeyA=ValueA +KeyB=ValueB +`, + expectedCfg: Config{ + "Section1": { + "Key1": "Value1", + "Key2": "Value2", + }, + "Section2": { + "KeyA": "ValueA", + "KeyB": "ValueB", + }, + }, + expectedErr: false, + expectedKeys: 4, + }, + { + name: "EmptyINIFile", + input: "", + expectedCfg: Config{}, + expectedErr: false, + expectedKeys: 0, + }, + { + name: "InvalidINIFile", + input: `[Section1] +Key1=Value1 +MissingKey +`, + expectedErr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + file, err := createTempFile(test.input) + assert.NoError(t, err) + defer os.Remove(file) + + cfg, err := ParseINI(file) + + if test.expectedErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, test.expectedCfg, cfg) + assert.Equal(t, test.expectedKeys, countKeys(cfg)) + }) + } +} + +// Helper function to create a temporary file with the given content. +func createTempFile(content string) (string, error) { + file, err := os.CreateTemp("", "testfile*.ini") + if err != nil { + return "", err + } + _, err = file.WriteString(content) + if err != nil { + return "", err + } + return file.Name(), nil +} + +// Helper function to count the total number of keys in a Config object. +func countKeys(cfg Config) int { + count := 0 + for _, section := range cfg { + count += len(section) + } + return count +} diff --git a/go.mod b/go.mod index da9bbf7..0e3df61 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/yosssi/gohtml v0.0.0-20201013000340-ee4748c638f4 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 golang.org/x/net v0.12.0 + golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 gopkg.in/natefinch/lumberjack.v2 v2.2.1 ) @@ -30,6 +31,7 @@ require ( github.com/dlclark/regexp2 v1.10.0 // indirect github.com/dsnet/compress v0.0.1 // indirect github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect + github.com/google/btree v1.0.1 // indirect github.com/google/go-github v17.0.0+incompatible // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect @@ -44,8 +46,12 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/ulikunitz/xz v0.5.11 // indirect go4.org/netipx v0.0.0-20230303233057-f1b76eb4bb35 // indirect + golang.org/x/crypto v0.11.0 // indirect golang.org/x/sys v0.10.0 // indirect golang.org/x/text v0.11.0 // indirect + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect + golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 // indirect ) diff --git a/go.sum b/go.sum index 1fcc83a..fe1b907 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5Nq github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY= @@ -150,6 +152,8 @@ go4.org/netipx v0.0.0-20230303233057-f1b76eb4bb35/go.mod h1:TQvodOM+hJTioNQJilmL golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -190,11 +194,17 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= +golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 h1:EY138uSo1JYlDq+97u1FtcOUwPpIU6WL1Lkt7WpYjPA= +golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -213,3 +223,5 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY= +gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA= diff --git a/main.go b/main.go index 99ce4d7..1bfd395 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "github.com/nfx/slrp/app" "github.com/nfx/slrp/checker" + "github.com/nfx/slrp/dialer" "github.com/nfx/slrp/history" "github.com/nfx/slrp/internal/updater" "github.com/nfx/slrp/ipinfo" @@ -20,7 +21,7 @@ import ( var version = "devel" -//go:embed ui/build +//go:embed ui/build/* var embedFrontend embed.FS func main() { @@ -35,6 +36,7 @@ func main() { "blacklist": probe.NewBlacklistApi, "checker": checker.NewChecker, "dashboard": serve.NewDashboard, + "dialer": dialer.NewDialer, "history": history.NewHistory, "ipinfo": ipinfo.NewLookup, "mitm": serve.NewMitmProxyServer, diff --git a/main_test.go b/main_test.go index dc5262e..c57658d 100644 --- a/main_test.go +++ b/main_test.go @@ -9,12 +9,14 @@ import ( func TestMain(t *testing.T) { qa.RunOnlyInDebug(t) - if false { - os.Setenv("SLRP_PPROF_ENABLE", "true") - os.Setenv("SLRP_HISTORY_LIMIT", "10000") - os.Setenv("SLRP_LOG_LEVEL", "debug") - os.Setenv("SLRP_LOG_FORMAT", "file") // TODO: eek, make it better - os.Setenv("SLRP_LOG_FILE", "/tmp/$APP.log") // TODO: eek, make it better + if true { + // os.Setenv("SLRP_PPROF_ENABLE", "true") + os.Setenv("SLRP_DIALER_WIREGUARD_CONFIG_FILE", "$HOME/.$APP/wireguard.conf") + os.Setenv("SLRP_HISTORY_LIMIT", "100000") + //os.Setenv("SLRP_REFRESHER_ENABLED", "false") + // os.Setenv("SLRP_LOG_LEVEL", "debug") + // os.Setenv("SLRP_LOG_FORMAT", "file") // TODO: eek, make it better + // os.Setenv("SLRP_LOG_FILE", "/tmp/$APP.log") // TODO: eek, make it better } main() diff --git a/pmux/proxy.go b/pmux/proxy.go index eea3f17..b28bdb7 100644 --- a/pmux/proxy.go +++ b/pmux/proxy.go @@ -196,31 +196,42 @@ func dialProxiedConnection(ctx context.Context, network, addr string) (net.Conn, } } -func pickHttpProxyFromContext(r *http.Request) (*url.URL, error) { +func ProxyFromContext(r *http.Request) (*url.URL, error) { p := GetProxyFromContext(r.Context()) if p == 0 { return nil, nil } - if p.IsTunnel() { - // handled in DialContext - return nil, nil - } + // if p.IsTunnel() { + // // handled in DialContext + // return nil, nil + // } + // TODO: free-proxy.cz is not liking HTTPS dialer, so it needs only HTTP forwarder return p.URL(), nil } +var contextualTransport = &http.Transport{ + // If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS + // requests and the TLSClientConfig and TLSHandshakeTimeout + // are ignored. The returned net.Conn is assumed to already be + // past the TLS handshake. + // DialTLSContext: dialProxiedConnection, + TLSClientConfig: DefaultTlsConfig, + // TLSHandshakeTimeout: DefaultDialer.Timeout, + Proxy: ProxyFromContext, + // DisableKeepAlives: true, + // MaxIdleConns: 0, +} + func ContextualHttpTransport() *http.Transport { - return &http.Transport{ - // If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS - // requests and the TLSClientConfig and TLSHandshakeTimeout - // are ignored. The returned net.Conn is assumed to already be - // past the TLS handshake. - DialTLSContext: dialProxiedConnection, - TLSClientConfig: DefaultTlsConfig, - TLSHandshakeTimeout: DefaultDialer.Timeout, - Proxy: pickHttpProxyFromContext, - DisableKeepAlives: true, - MaxIdleConns: 0, + return contextualTransport +} + +func NewProxyFromURL(url string) Proxy { + split := strings.Split(url, "://") + if len(split) != 2 { + return 0 } + return NewProxy(split[1], split[0]) } func NewProxy(addr string, t string) Proxy { diff --git a/pmux/proxy_test.go b/pmux/proxy_test.go index 5bb7291..6c04ee2 100644 --- a/pmux/proxy_test.go +++ b/pmux/proxy_test.go @@ -83,20 +83,20 @@ func TestDialProxiedConnection_SOCKS(t *testing.T) { func TestPickProxyFromContext(t *testing.T) { p := HttpProxy("127.0.0.1:0") r := p.MustNewGetRequest("https://ifconfig.me") - u, _ := pickHttpProxyFromContext(r) + u, _ := ProxyFromContext(r) assert.Equal(t, u.String(), p.String()) } func TestPickProxyFromContext_Tunnel(t *testing.T) { p := Socks5Proxy("127.0.0.1:0") r := p.MustNewGetRequest("https://ifconfig.me") - u, err := pickHttpProxyFromContext(r) + u, err := ProxyFromContext(r) assert.Nil(t, u) assert.NoError(t, err) } func TestPickProxyFromContext_NoProxy(t *testing.T) { - u, err := pickHttpProxyFromContext(&http.Request{}) + u, err := ProxyFromContext(&http.Request{}) assert.Nil(t, u) assert.NoError(t, err) } diff --git a/pool/pool.go b/pool/pool.go index a3621ac..e2860e5 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -6,6 +6,7 @@ import ( "encoding/gob" "fmt" "math/rand" + "net" "net/http" "sync" "time" @@ -37,7 +38,11 @@ type httpClient interface { var poolWorkSize = 128 var poolShards = 32 -func NewPool(history *history.History, ipLookup ipinfo.IpInfoGetter) *Pool { +type dialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +func NewPool(history *history.History, ipLookup ipinfo.IpInfoGetter, dialer dialer) *Pool { return &Pool{ ipLookup: ipLookup, serial: make(chan int), @@ -46,8 +51,11 @@ func NewPool(history *history.History, ipLookup ipinfo.IpInfoGetter) *Pool { halt: make(chan time.Duration), shards: make([]shard, poolShards), client: &http.Client{ - Transport: history.Wrap(pmux.ContextualHttpTransport()), - Timeout: 10 * time.Second, // TODO: make timeouts configurable + Transport: history.Wrap(&http.Transport{ + DialContext: dialer.DialContext, + Proxy: pmux.ProxyFromContext, + TLSClientConfig: pmux.DefaultTlsConfig, + }), }, } } diff --git a/pool/pool_test.go b/pool/pool_test.go index 63a0195..873ed0c 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/gob" "fmt" + "net" "net/http" "net/url" "os" @@ -20,15 +21,10 @@ import ( "github.com/stretchr/testify/assert" ) -func init() { - poolShards = 1 - poolWorkSize = 1 -} - func TestSimpleAddAndRemove(t *testing.T) { pool, runtime := app.MockStartSpin(NewPool(history.NewHistory(), ipinfo.NoopIpInfo{ Country: "Zimbabwe", - })) + }, &net.Dialer{})) defer runtime.Stop() ctx := context.Background() @@ -44,7 +40,7 @@ func TestMarshallAndUnmarshall(t *testing.T) { history := history.NewHistory() pool, first := app.MockStartSpin(NewPool(history, ipinfo.NoopIpInfo{ Country: "Zimbabwe", - })) + }, &net.Dialer{})) defer first.Stop() ctx := context.Background() @@ -57,7 +53,7 @@ func TestMarshallAndUnmarshall(t *testing.T) { loaded := NewPool(history, ipinfo.NoopIpInfo{ Country: "Zimbabwe", - }) + }, &net.Dialer{}) err = loaded.UnmarshalBinary(raw) assert.NoError(t, err) @@ -80,7 +76,7 @@ func (r staticResponseClient) Do(req *http.Request) (*http.Response, error) { func TestRoundTrip(t *testing.T) { pool, runtime := app.MockStartSpin(NewPool(history.NewHistory(), ipinfo.NoopIpInfo{ Country: "Zimbabwe", - })) + }, &net.Dialer{})) defer runtime.Stop() pool.client = staticResponseClient{ @@ -110,7 +106,7 @@ func TestSession(t *testing.T) { hist := history.NewHistory() pool, runtime := app.MockStartSpin(NewPool(hist, ipinfo.NoopIpInfo{ Country: "Zimbabwe", - }), hist) + }, &net.Dialer{}), hist) defer runtime.Stop() ctx := context.Background() @@ -139,7 +135,7 @@ func TestSession(t *testing.T) { func TestHttpGet(t *testing.T) { pool, runtime := app.MockStartSpin(NewPool(history.NewHistory(), ipinfo.NoopIpInfo{ Country: "Zimbabwe", - })) + }, &net.Dialer{})) defer runtime.Stop() ctx := context.Background() @@ -167,7 +163,7 @@ func load(t *testing.T) *Pool { dec := gob.NewDecoder(f) pool := NewPool(history.NewHistory(), ipinfo.NoopIpInfo{ Country: "Zimbabwe", - }) + }, &net.Dialer{}) dec.Decode(pool) return pool } @@ -219,7 +215,7 @@ func TestSelection(t *testing.T) { func TestReceiveHalt(t *testing.T) { pool, runtime := app.MockStartSpin(NewPool(history.NewHistory(), ipinfo.NoopIpInfo{ Country: "Zimbabwe", - })) + }, &net.Dialer{})) defer runtime.Stop() for i := 0; i < 33; i++ { @@ -233,7 +229,7 @@ func TestReceiveHalt(t *testing.T) { func TestCounterOnHalt(t *testing.T) { pool, runtime := app.MockStartSpin(NewPool(history.NewHistory(), ipinfo.NoopIpInfo{ Country: "Zimbabwe", - })) + }, &net.Dialer{})) defer runtime.Stop() serial := <-pool.serial @@ -257,7 +253,7 @@ func TestCounterOnHalt(t *testing.T) { func TestRandomFast(t *testing.T) { pool, runtime := app.MockStartSpin(NewPool(history.NewHistory(), ipinfo.NoopIpInfo{ Country: "Zimbabwe", - })) + }, &net.Dialer{})) defer runtime.Stop() x := pmux.HttpProxy("127.0.0.1:1024") @@ -275,7 +271,7 @@ func TestRandomFast(t *testing.T) { func TestRoundTripCtxErr(t *testing.T) { pool, runtime := app.MockStartSpin(NewPool(history.NewHistory(), ipinfo.NoopIpInfo{ Country: "Zimbabwe", - })) + }, &net.Dialer{})) defer runtime.Stop() ctx, cancel := context.WithCancel(context.Background()) @@ -291,7 +287,7 @@ func TestRoundTripCtxErr(t *testing.T) { func TestRoundTripNilResponseFromOut(t *testing.T) { pool, runtime := app.MockStartSpin(NewPool(history.NewHistory(), ipinfo.NoopIpInfo{ Country: "Zimbabwe", - })) + }, &net.Dialer{})) defer runtime.Stop() ctx, cancel := context.WithCancel(context.Background()) diff --git a/probe/blacklist_test.go b/probe/blacklist_test.go index 6d17b6a..c051497 100644 --- a/probe/blacklist_test.go +++ b/probe/blacklist_test.go @@ -2,6 +2,7 @@ package probe import ( "fmt" + "net" "net/http" "testing" @@ -26,7 +27,7 @@ func TestBlacklist(t *testing.T) { history := history.NewHistory() pool := pool.NewPool(history, ipinfo.NoopIpInfo{ Country: "Zimbabwe", - }) + }, &net.Dialer{}) probe := NewProbe(stats, pool, checker) runtime := app.Singletons{ diff --git a/probe/probe_test.go b/probe/probe_test.go index bcfde42..d9c9e71 100644 --- a/probe/probe_test.go +++ b/probe/probe_test.go @@ -3,6 +3,7 @@ package probe import ( "context" "fmt" + "net" "net/http" "net/url" "testing" @@ -42,7 +43,7 @@ func TestBasicProbe(t *testing.T) { history := history.NewHistory() pool := pool.NewPool(history, ipinfo.NoopIpInfo{ Country: "Zimbabwe", - }) + }, &net.Dialer{}) probe := NewProbe(stats, pool, checker) runtime := app.Singletons{ @@ -86,7 +87,7 @@ func TestProbeMarshaling(t *testing.T) { history := history.NewHistory() pool := pool.NewPool(history, ipinfo.NoopIpInfo{ Country: "Zimbabwe", - }) + }, &net.Dialer{}) probe := NewProbe(stats, pool, checker) runtime := app.Singletons{ @@ -125,7 +126,7 @@ func TestProbeDeleting(t *testing.T) { history := history.NewHistory() pool := pool.NewPool(history, ipinfo.NoopIpInfo{ Country: "Zimbabwe", - }) + }, &net.Dialer{}) probe := NewProbe(stats, pool, checker) runtime := app.Singletons{ diff --git a/serve/dashboard_test.go b/serve/dashboard_test.go index 247dfc3..c4f3a0b 100644 --- a/serve/dashboard_test.go +++ b/serve/dashboard_test.go @@ -3,6 +3,7 @@ package serve import ( "context" "fmt" + "net" "testing" "time" @@ -46,7 +47,7 @@ func TestDashboardRenders(t *testing.T) { history := history.NewHistory() pool := pool.NewPool(history, ipinfo.NoopIpInfo{ Country: "Zimbabwe", - }) + }, &net.Dialer{}) probe := probe.NewProbe(stats, pool, checker) refresher := refresher.NewRefresher(stats, pool, probe) dashboard := NewDashboard(refresher, probe, stats) diff --git a/serve/mitm_proxy_test.go b/serve/mitm_proxy_test.go index 8098717..45b7d2b 100644 --- a/serve/mitm_proxy_test.go +++ b/serve/mitm_proxy_test.go @@ -1,6 +1,7 @@ package serve import ( + "net" "net/http" "net/http/httptest" "testing" @@ -63,7 +64,7 @@ func TestFlows(t *testing.T) { history := history.NewHistory() pool := pool.NewPool(history, ipinfo.NoopIpInfo{ Country: "Zimbabwe", - }) + }, &net.Dialer{}) mitm, runtime := app.MockStartSpin( NewMitmProxyServer(pool, *defaultCA), history, pool, tt.Via) @@ -97,7 +98,7 @@ func TestMitm_HTTP_viaHTTP_toHTTP(t *testing.T) { // TODO: rename history := history.NewHistory() pool := pool.NewPool(history, ipinfo.NoopIpInfo{ Country: "Zimbabwe", - }) + }, &net.Dialer{}) mitm, runtime := app.MockStartSpin( NewMitmProxyServer(pool, *defaultCA), history, pool, transparentHttp)