Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

service discovery query filtering #1337

Merged
merged 5 commits into from
Aug 16, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 3 additions & 29 deletions cmd/skywire-cli/commands/vpn/vvpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
clirpc "github.com/skycoin/skywire/cmd/skywire-cli/commands/rpc"
"github.com/skycoin/skywire/cmd/skywire-cli/internal"
"github.com/skycoin/skywire/pkg/app/appserver"
"github.com/skycoin/skywire/pkg/servicedisc"
"github.com/skycoin/skywire/pkg/visor/visorconfig"
)

Expand Down Expand Up @@ -104,35 +103,10 @@ var vpnListCmd = &cobra.Command{
ver = ""
country = ""
}
// servers, err := client.VPNServers(ver, country) //query filtering
servers, err := client.VPNServers()
servers, err := client.VPNServers(ver, country)
if err != nil {
logger.Fatal("Failed to connect; is skywire running?\n", err)
logger.Fatal(err)
}

/*vv remove when query filtering is implemented vv*/
var a []servicedisc.Service
for _, i := range servers {
if (ver == "") || (ver == "unknown") || (strings.Replace(i.Version, "v", "", 1) == ver) {
a = append(a, i)
}
}
if len(a) > 0 {
servers = a
a = []servicedisc.Service{}
}
if country != "" {
for _, i := range servers {
if i.Geo != nil {
if i.Geo.Country == country {
a = append(a, i)
}
}
}
servers = a
}
/*^^ remove when query filtering is implemented ^^*/

if len(servers) == 0 {
fmt.Printf("No VPN Servers found\n")
os.Exit(0)
Expand Down Expand Up @@ -168,7 +142,7 @@ var vpnStartCmd = &cobra.Command{
Short: "start the vpn for <public-key>",
Args: cobra.MinimumNArgs(1),
Run: func(_ *cobra.Command, args []string) {
fmt.Println("%s", args[0])
fmt.Println(args[0])
internal.Catch(clirpc.Client().StartVPNClient(args[0]))
fmt.Println("OK")
},
Expand Down
2 changes: 1 addition & 1 deletion internal/gui/gui.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ func getAvailPublicVPNServers(conf *visorconfig.V1, httpC *http.Client, logger *
DiscAddr: conf.Launcher.ServiceDisc,
}
sdClient := servicedisc.NewClient(log, log, svrConfig, httpC, "")
vpnServers, err := sdClient.Services(context.Background(), 0)
vpnServers, err := sdClient.Services(context.Background(), 0, "", "")
if err != nil {
logger.Error("Error getting vpn servers: ", err)
return nil
Expand Down
3 changes: 1 addition & 2 deletions pkg/servicedisc/autoconnect.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ func (a *autoconnector) fetchPubAddresses(ctx context.Context) ([]cipher.PubKey,
var services []Service
fetch := func() (err error) {
// "return" services up from the closure
//services, err = a.client.Services(ctx, a.maxConns, "", "") //query filtering
services, err = a.client.Services(ctx, a.maxConns)
services, err = a.client.Services(ctx, a.maxConns, "", "")
if err != nil {
return err
}
Expand Down
41 changes: 16 additions & 25 deletions pkg/servicedisc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ import (
var ErrVisorUnreachable = errors.New("visor is unreachable")

const (
updateRetryDelay = 5 * time.Second
discServiceTypeParam = "type"
discServiceQtyParam = "quantity"

// discServiceCountryParam = "country" //query filtering
// discServiceVersionParam = "version" //query filtering
updateRetryDelay = 5 * time.Second
discServiceTypeParam = "type"
discServiceQtyParam = "quantity"
discServiceCountryParam = "country"
discServiceVersionParam = "version"
)

// Config configures the HTTPClient.
Expand Down Expand Up @@ -70,8 +69,7 @@ func NewClient(log logrus.FieldLogger, mLog *logging.MasterLogger, conf Config,
}
}

func (c *HTTPClient) addr(path, serviceType string, quantity int) (string, error) {
//func (c *HTTPClient) addr(path, serviceType, version, country string, quantity int) (string, error) { //query filtering
func (c *HTTPClient) addr(path, serviceType, version, country string, quantity int) (string, error) {
addr := c.conf.DiscAddr
url, err := url.Parse(addr)
if err != nil {
Expand All @@ -85,15 +83,12 @@ func (c *HTTPClient) addr(path, serviceType string, quantity int) (string, error
if quantity > 1 {
q.Set(discServiceQtyParam, strconv.Itoa(quantity))
}
//query filtering
/*
if version != "" {
q.Set(discServiceVersionParam, version)
}
if country != "" {
q.Set(discServiceCountryParam, country)
}
*/
if version != "" {
q.Set(discServiceVersionParam, version)
}
if country != "" {
q.Set(discServiceCountryParam, country)
}
url.RawQuery = q.Encode()
return url.String(), nil
}
Expand Down Expand Up @@ -124,10 +119,8 @@ func (c *HTTPClient) Auth(ctx context.Context) (*httpauth.Client, error) {
}

// Services calls 'GET /api/services'.
func (c *HTTPClient) Services(ctx context.Context, quantity int) (out []Service, err error) {
//func (c *HTTPClient) Services(ctx context.Context, quantity int, version, country string) (out []Service, err error) { //query filtering
//url, err := c.addr("/api/services", c.entry.Type, version, country, quantity)
url, err := c.addr("/api/services", c.entry.Type, quantity)
func (c *HTTPClient) Services(ctx context.Context, quantity int, version, country string) (out []Service, err error) {
url, err := c.addr("/api/services", c.entry.Type, version, country, quantity)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -202,8 +195,7 @@ func (c *HTTPClient) postEntry(ctx context.Context) (Service, error) {
return Service{}, err
}

// url, err := c.addr("/api/services", "", "", "", 1) //query filtering
url, err := c.addr("/api/services", "", 1)
url, err := c.addr("/api/services", "", "", "", 1)
if err != nil {
return Service{}, nil
}
Expand Down Expand Up @@ -260,8 +252,7 @@ func (c *HTTPClient) DeleteEntry(ctx context.Context) (err error) {
return err
}

// url, err := c.addr("/api/services/"+c.entry.Addr.String(), c.entry.Type, "", "", 1) //query filtering
url, err := c.addr("/api/services/"+c.entry.Addr.String(), c.entry.Type, 1)
url, err := c.addr("/api/services/"+c.entry.Addr.String(), c.entry.Type, "", "", 1)
if err != nil {
return err
}
Expand Down
9 changes: 3 additions & 6 deletions pkg/visor/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ type API interface {
GetAppStats(appName string) (appserver.AppStats, error)
GetAppError(appName string) (string, error)
GetAppConnectionsSummary(appName string) ([]appserver.ConnectionSummary, error)
// VPNServers(version, country string) ([]servicedisc.Service, error) //query filtering
VPNServers() ([]servicedisc.Service, error)
VPNServers(version, country string) ([]servicedisc.Service, error)
RemoteVisors() ([]string, error)

TransportTypes() ([]string, error)
Expand Down Expand Up @@ -624,8 +623,7 @@ func (v *Visor) GetAppConnectionsSummary(appName string) ([]appserver.Connection
}

// VPNServers gets available public VPN server from service discovery URL
func (v *Visor) VPNServers() ([]servicedisc.Service, error) {
//func (v *Visor) VPNServers(version, country string) ([]servicedisc.Service, error) { //query filtering
func (v *Visor) VPNServers(version, country string) ([]servicedisc.Service, error) {
log := logging.MustGetLogger("vpnservers")
vlog := logging.NewMasterLogger()
vlog.SetLevel(logrus.InfoLevel)
Expand All @@ -636,8 +634,7 @@ func (v *Visor) VPNServers() ([]servicedisc.Service, error) {
SK: v.conf.SK,
DiscAddr: v.conf.Launcher.ServiceDisc,
}, &http.Client{Timeout: time.Duration(1) * time.Second}, "")
// vpnServers, err := sdClient.Services(context.Background(), 0, version, country) //query filtering
vpnServers, err := sdClient.Services(context.Background(), 0)
vpnServers, err := sdClient.Services(context.Background(), 0, version, country)
if err != nil {
v.log.Error("Error getting public vpn servers: ", err)
return nil, err
Expand Down
8 changes: 2 additions & 6 deletions pkg/visor/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,20 +556,16 @@ func (r *RPC) SetPublicAutoconnect(pAc *bool, _ *struct{}) (err error) {
return err
}

/* //query filtering
// FilterVPNServersIn is input for VPNServers
type FilterVPNServersIn struct {
Version string
Country string
}
*/

// VPNServers gets available public VPN server from service discovery URL
func (r *RPC) VPNServers(_ *struct{}, out *[]servicedisc.Service) (err error) {
//func (r *RPC) VPNServers(vc *FilterVPNServersIn, _ *struct{}, out *[]servicedisc.Service) (err error) { //query filtering
func (r *RPC) VPNServers(vc *FilterVPNServersIn, out *[]servicedisc.Service) (err error) {
defer rpcutil.LogCall(r.log, "VPNServers", nil)(out, &err)
// vpnServers, err := r.visor.VPNServers(vc.Version, vc.Country) //query filtering
vpnServers, err := r.visor.VPNServers()
vpnServers, err := r.visor.VPNServers(vc.Version, vc.Country)
if vpnServers != nil {
*out = vpnServers
}
Expand Down
15 changes: 5 additions & 10 deletions pkg/visor/rpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,17 +409,13 @@ type StatusMessage struct {
}

// VPNServers calls VPNServers.
func (rc *rpcClient) VPNServers() ([]servicedisc.Service, error) {
//func (rc *rpcClient) VPNServers(version, country string) ([]servicedisc.Service, error) { //query filtering
func (rc *rpcClient) VPNServers(version, country string) ([]servicedisc.Service, error) {
output := []servicedisc.Service{}
/* //query filtering
rc.Call("VPNServers", &FilterVPNServersIn{
err := rc.Call("VPNServers", &FilterVPNServersIn{ // nolint
Version: version,
Country: country,
}, &output) // nolint
*/
rc.Call("VPNServers", &struct{}{}, &output) // nolint
return output, nil
}, &output)
return output, err
}

// RemoteVisors calls RemoteVisors.
Expand Down Expand Up @@ -982,8 +978,7 @@ func (mc *mockRPCClient) GetPersistentTransports() ([]transport.PersistentTransp
}

// VPNServers implements API
func (mc *mockRPCClient) VPNServers() ([]servicedisc.Service, error) {
//func (mc *mockRPCClient) VPNServers(_, _ string) ([]servicedisc.Service, error) { //query filtering
func (mc *mockRPCClient) VPNServers(_, _ string) ([]servicedisc.Service, error) {
return []servicedisc.Service{}, nil
}

Expand Down