diff --git a/cmd/apps/vpn-server/vpn-server.go b/cmd/apps/vpn-server/vpn-server.go index b934addee..db1931556 100644 --- a/cmd/apps/vpn-server/vpn-server.go +++ b/cmd/apps/vpn-server/vpn-server.go @@ -30,6 +30,7 @@ var ( localPKStr = flag.String("pk", "", "Local PubKey") localSKStr = flag.String("sk", "", "Local SecKey") passcode = flag.String("passcode", "", "Passcode to authenticate connecting users") + secure = flag.Bool("secure", true, "Forbid connections from clients to server local network") ) func main() { @@ -73,6 +74,7 @@ func main() { srvCfg := vpn.ServerConfig{ Passcode: *passcode, + Secure: *secure, } srv, err := vpn.NewServer(srvCfg, log) if err != nil { diff --git a/internal/vpn/ip_generator.go b/internal/vpn/ip_generator.go index 6ec61b7e3..10795c22b 100644 --- a/internal/vpn/ip_generator.go +++ b/internal/vpn/ip_generator.go @@ -17,7 +17,8 @@ type IPGenerator struct { func NewIPGenerator() *IPGenerator { return &IPGenerator{ ranges: []*subnetIPIncrementer{ - newSubnetIPIncrementer([4]uint8{192, 168, 0, 0}, [4]uint8{192, 168, 255, 255}, 8), + // exclude some most commonly used addresses in local networks + newSubnetIPIncrementer([4]uint8{192, 168, 2, 0}, [4]uint8{192, 168, 255, 255}, 8), newSubnetIPIncrementer([4]uint8{172, 16, 0, 0}, [4]uint8{172, 31, 255, 255}, 8), newSubnetIPIncrementer([4]uint8{10, 0, 0, 0}, [4]uint8{10, 255, 255, 255}, 8), }, diff --git a/internal/vpn/os.go b/internal/vpn/os.go index 58a9caecf..d70364592 100644 --- a/internal/vpn/os.go +++ b/internal/vpn/os.go @@ -7,6 +7,71 @@ import ( "os/exec" ) +// LocalNetworkInterfaceIPs gets IPs of all local interfaces. +func LocalNetworkInterfaceIPs() ([]net.IP, error) { + ips, _, err := localNetworkInterfaceIPs("") + return ips, err +} + +// NetworkInterfaceIPs gets IPs of network interface with name `name`. +func NetworkInterfaceIPs(name string) ([]net.IP, error) { + _, ifcIPs, err := localNetworkInterfaceIPs(name) + return ifcIPs, err +} + +// localNetworkInterfaceIPs gets IPs of all local interfaces. Separately returns list of IPs +// of interface `ifcName`. +func localNetworkInterfaceIPs(ifcName string) ([]net.IP, []net.IP, error) { + var ifcIPs []net.IP + + ifaces, err := net.Interfaces() + if err != nil { + return nil, nil, fmt.Errorf("error getting network interfaces: %w", err) + } + + var ips []net.IP + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 { + continue // interface down + } + if iface.Flags&net.FlagLoopback != 0 { + continue // loopback interface + } + + addrs, err := iface.Addrs() + if err != nil { + return nil, nil, fmt.Errorf("error getting addresses for interface %s: %w", iface.Name, err) + } + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + + if ip == nil || ip.IsLoopback() { + continue + } + + ip = ip.To4() + if ip == nil { + continue // not an ipv4 address + } + + ips = append(ips, ip) + + if ifcName != "" && iface.Name == ifcName { + ifcIPs = append(ifcIPs, ip) + } + } + } + + return ips, ifcIPs, nil +} + func parseCIDR(ipCIDR string) (ipStr, netmask string, err error) { ip, net, err := net.ParseCIDR(ipCIDR) if err != nil { diff --git a/internal/vpn/os_client.go b/internal/vpn/os_client.go deleted file mode 100644 index 4e1aab4f0..000000000 --- a/internal/vpn/os_client.go +++ /dev/null @@ -1,52 +0,0 @@ -package vpn - -import ( - "fmt" - "net" -) - -// LocalNetworkInterfaceIPs gets IPs of all local interfaces. -func LocalNetworkInterfaceIPs() ([]net.IP, error) { - ifaces, err := net.Interfaces() - if err != nil { - return nil, fmt.Errorf("error getting network interfaces: %w", err) - } - - var ips []net.IP - for _, iface := range ifaces { - if iface.Flags&net.FlagUp == 0 { - continue // interface down - } - if iface.Flags&net.FlagLoopback != 0 { - continue // loopback interface - } - - addrs, err := iface.Addrs() - if err != nil { - return nil, fmt.Errorf("error getting addresses for interface %s: %w", iface.Name, err) - } - - for _, addr := range addrs { - var ip net.IP - switch v := addr.(type) { - case *net.IPNet: - ip = v.IP - case *net.IPAddr: - ip = v.IP - } - - if ip == nil || ip.IsLoopback() { - continue - } - - ip = ip.To4() - if ip == nil { - continue // not an ipv4 address - } - - ips = append(ips, ip) - } - } - - return ips, nil -} diff --git a/internal/vpn/os_server.go b/internal/vpn/os_server.go index 0449a28b6..8daa14e44 100644 --- a/internal/vpn/os_server.go +++ b/internal/vpn/os_server.go @@ -4,12 +4,35 @@ package vpn import ( "errors" + "net" ) var ( errServerMethodsNotSupported = errors.New("server related methods are not supported for this OS") ) +// AllowSSH allows all SSH traffic (via default 22 port) between `src` and `dst`. +func AllowSSH(_, _ net.IP, _ []net.IP) error { + return errServerMethodsNotSupported +} + +// BlockSSH blocks all SSH traffic (via default 22 port) between `src` and `dst`. +func BlockSSH(_, _ net.IP, _ []net.IP) error { + return errServerMethodsNotSupported +} + +// AllowIPToLocalNetwork allows all the packets coming from `source` +// to private IP ranges. +func AllowIPToLocalNetwork(_, _ net.IP) error { + return errServerMethodsNotSupported +} + +// BlockIPToLocalNetwork blocks all the packets coming from `source` +// to private IP ranges. +func BlockIPToLocalNetwork(_, _ net.IP) error { + return errServerMethodsNotSupported +} + // DefaultNetworkInterface fetches default network interface name. func DefaultNetworkInterface() (string, error) { return "", errServerMethodsNotSupported diff --git a/internal/vpn/os_server_linux.go b/internal/vpn/os_server_linux.go index 7444fae67..073187357 100644 --- a/internal/vpn/os_server_linux.go +++ b/internal/vpn/os_server_linux.go @@ -5,6 +5,7 @@ package vpn import ( "bytes" "fmt" + "net" "os/exec" ) @@ -16,8 +17,32 @@ const ( setIPv6ForwardingCMDFmt = "sysctl -w net.ipv6.conf.all.forwarding=%s" enableIPMasqueradingCMDFmt = "iptables -t nat -A POSTROUTING -o %s -j MASQUERADE" disableIPMasqueradingCMDFmt = "iptables -t nat -D POSTROUTING -o %s -j MASQUERADE" + blockIPToLocalNetCMDFmt = "iptables -I FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -I INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP" + allowIPToLocalNetCMDFmt = "iptables -D FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -D INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP" ) +// AllowIPToLocalNetwork allows all the packets coming from `source` +// to private IP ranges. +func AllowIPToLocalNetwork(src, dst net.IP) error { + cmd := fmt.Sprintf(allowIPToLocalNetCMDFmt, src, src) + if err := exec.Command("sh", "-c", cmd).Run(); err != nil { //nolint:gosec + return fmt.Errorf("error running command %s: %w", cmd, err) + } + + return nil +} + +// BlockIPToLocalNetwork blocks all the packets coming from `source` +// to private IP ranges. +func BlockIPToLocalNetwork(src, dst net.IP) error { + cmd := fmt.Sprintf(blockIPToLocalNetCMDFmt, src, src) + if err := exec.Command("sh", "-c", cmd).Run(); err != nil { //nolint:gosec + return fmt.Errorf("error running command %s: %w", cmd, err) + } + + return nil +} + // DefaultNetworkInterface fetches default network interface name. func DefaultNetworkInterface() (string, error) { outputBytes, err := exec.Command("sh", "-c", defaultNetworkInterfaceCMD).Output() diff --git a/internal/vpn/server.go b/internal/vpn/server.go index 01f7633a5..3038b242e 100644 --- a/internal/vpn/server.go +++ b/internal/vpn/server.go @@ -12,15 +12,16 @@ import ( // Server is a VPN server. type Server struct { - cfg ServerConfig - lisMx sync.Mutex - lis net.Listener - log logrus.FieldLogger - serveOnce sync.Once - ipGen *IPGenerator - defaultNetworkInterface string - ipv4ForwardingVal string - ipv6ForwardingVal string + cfg ServerConfig + lisMx sync.Mutex + lis net.Listener + log logrus.FieldLogger + serveOnce sync.Once + ipGen *IPGenerator + defaultNetworkInterface string + defaultNetworkInterfaceIPs []net.IP + ipv4ForwardingVal string + ipv6ForwardingVal string } // NewServer creates VPN server instance. @@ -38,6 +39,13 @@ func NewServer(cfg ServerConfig, l logrus.FieldLogger) (*Server, error) { l.Infof("Got default network interface: %s", defaultNetworkIfc) + defaultNetworkIfcIPs, err := NetworkInterfaceIPs(defaultNetworkIfc) + if err != nil { + return nil, fmt.Errorf("error getting IPs of interface %s: %w", defaultNetworkIfc, err) + } + + l.Infof("Got IPs of interface %s: %v", defaultNetworkIfc, defaultNetworkIfcIPs) + ipv4ForwardingVal, err := GetIPv4ForwardingValue() if err != nil { return nil, fmt.Errorf("error getting IPv4 forwarding value: %w", err) @@ -51,6 +59,7 @@ func NewServer(cfg ServerConfig, l logrus.FieldLogger) (*Server, error) { l.Infof("IPv4: %s, IPv6: %s", ipv4ForwardingVal, ipv6ForwardingVal) s.defaultNetworkInterface = defaultNetworkIfc + s.defaultNetworkInterfaceIPs = defaultNetworkIfcIPs s.ipv4ForwardingVal = ipv4ForwardingVal s.ipv6ForwardingVal = ipv6ForwardingVal @@ -144,11 +153,12 @@ func (s *Server) closeConn(conn net.Conn) { func (s *Server) serveConn(conn net.Conn) { defer s.closeConn(conn) - tunIP, tunGateway, err := s.shakeHands(conn) + tunIP, tunGateway, allowTrafficToLocalNet, err := s.shakeHands(conn) if err != nil { s.log.WithError(err).Errorf("Error negotiating with client %s", conn.RemoteAddr()) return } + defer allowTrafficToLocalNet() tun, err := newTUNDevice() if err != nil { @@ -193,55 +203,40 @@ func (s *Server) serveConn(conn net.Conn) { } } -func (s *Server) shakeHands(conn net.Conn) (tunIP, tunGateway net.IP, err error) { +func (s *Server) shakeHands(conn net.Conn) (tunIP, tunGateway net.IP, unsecureVPN func(), err error) { var cHello ClientHello if err := ReadJSON(conn, &cHello); err != nil { - return nil, nil, fmt.Errorf("error reading client hello: %w", err) + return nil, nil, nil, fmt.Errorf("error reading client hello: %w", err) } - s.log.Debugf("Got client hello: %v", cHello) + // default value + unsecureVPN = func() {} - var sHello ServerHello + s.log.Debugf("Got client hello: %v", cHello) if s.cfg.Passcode != "" && cHello.Passcode != s.cfg.Passcode { - sHello.Status = HandshakeStatusForbidden - if err := WriteJSON(conn, &sHello); err != nil { - s.log.WithError(err).Errorln("Error sending server hello") - } - - return nil, nil, errors.New("got wrong passcode from client") + s.sendServerErrHello(conn, HandshakeStatusForbidden) + return nil, nil, nil, errors.New("got wrong passcode from client") } for _, ip := range cHello.UnavailablePrivateIPs { if err := s.ipGen.Reserve(ip); err != nil { // this happens only on malformed IP - sHello.Status = HandshakeStatusBadRequest - if err := WriteJSON(conn, &sHello); err != nil { - s.log.WithError(err).Errorln("Error sending server hello") - } - - return nil, nil, fmt.Errorf("error reserving IP %s: %w", ip.String(), err) + s.sendServerErrHello(conn, HandshakeStatusBadRequest) + return nil, nil, nil, fmt.Errorf("error reserving IP %s: %w", ip.String(), err) } } subnet, err := s.ipGen.Next() if err != nil { - sHello.Status = HandshakeNoFreeIPs - if err := WriteJSON(conn, &sHello); err != nil { - s.log.WithError(err).Errorln("Error sending server hello") - } - - return nil, nil, fmt.Errorf("error getting free subnet IP: %w", err) + s.sendServerErrHello(conn, HandshakeNoFreeIPs) + return nil, nil, nil, fmt.Errorf("error getting free subnet IP: %w", err) } subnetOctets, err := fetchIPv4Octets(subnet) if err != nil { - sHello.Status = HandshakeStatusInternalError - if err := WriteJSON(conn, &sHello); err != nil { - s.log.WithError(err).Errorln("Error sending server hello") - } - - return nil, nil, fmt.Errorf("error breaking IP into octets: %w", err) + s.sendServerErrHello(conn, HandshakeStatusInternalError) + return nil, nil, nil, fmt.Errorf("error breaking IP into octets: %w", err) } // basically IP address comprised of `subnetOctets` items is the IP address of the subnet, @@ -258,12 +253,40 @@ func (s *Server) shakeHands(conn net.Conn) (tunIP, tunGateway net.IP, err error) cTUNIP := net.IPv4(subnetOctets[0], subnetOctets[1], subnetOctets[2], subnetOctets[3]+4) cTUNGateway := net.IPv4(subnetOctets[0], subnetOctets[1], subnetOctets[2], subnetOctets[3]+3) - sHello.TUNIP = cTUNIP - sHello.TUNGateway = cTUNGateway + if s.cfg.Secure { + if err := BlockIPToLocalNetwork(cTUNIP, sTUNIP); err != nil { + s.sendServerErrHello(conn, HandshakeStatusInternalError) + return nil, nil, nil, + fmt.Errorf("error securing local network for IP %s: %w", cTUNIP, err) + } + + unsecureVPN = func() { + if err := AllowIPToLocalNetwork(cTUNIP, sTUNIP); err != nil { + s.log.WithError(err).Errorln("Error allowing traffic to local network") + } + } + } + + sHello := ServerHello{ + Status: HandshakeStatusOK, + TUNIP: cTUNIP, + TUNGateway: cTUNGateway, + } if err := WriteJSON(conn, &sHello); err != nil { - return nil, nil, fmt.Errorf("error finishing hadnshake: error sending server hello: %w", err) + unsecureVPN() + return nil, nil, nil, fmt.Errorf("error finishing hadnshake: error sending server hello: %w", err) } - return sTUNIP, sTUNGateway, nil + return sTUNIP, sTUNGateway, unsecureVPN, nil +} + +func (s *Server) sendServerErrHello(conn net.Conn, status HandshakeStatus) { + sHello := ServerHello{ + Status: status, + } + + if err := WriteJSON(conn, &sHello); err != nil { + s.log.WithError(err).Errorln("Error sending server hello") + } } diff --git a/internal/vpn/server_config.go b/internal/vpn/server_config.go index e1e6a5177..b95476598 100644 --- a/internal/vpn/server_config.go +++ b/internal/vpn/server_config.go @@ -3,4 +3,5 @@ package vpn // ServerConfig is a configuration for VPN server. type ServerConfig struct { Passcode string + Secure bool }