Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move netdev interface to use net/netip for IP addr/ports #622

Merged
merged 2 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions espat/espat.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"fmt"
"machine"
"net"
"net/netip"
"strconv"
"strings"
"sync"
Expand All @@ -44,8 +45,8 @@ type Config struct {
type socket struct {
inUse bool
protocol int
lip net.IP
lport int
lip netip.Addr
lport uint16
soypat marked this conversation as resolved.
Show resolved Hide resolved
}

type Device struct {
Expand Down Expand Up @@ -106,7 +107,7 @@ func (d *Device) NetConnect(params *netlink.ConnectParams) error {

fmt.Printf("CONNECTED\r\n")

ip, err := d.GetIPAddr()
ip, err := d.Addr()
if err != nil {
return err
}
Expand All @@ -125,28 +126,31 @@ func (d *Device) NetNotify(cb func(netlink.Event)) {
// Not supported
}

func (d *Device) GetHostByName(name string) (net.IP, error) {
func (d *Device) GetHostByName(name string) (netip.Addr, error) {
ip, err := d.GetDNS(name)
return net.ParseIP(ip), err
if err != nil {
return netip.Addr{}, err
}
return netip.ParseAddr(ip)
}

func (d *Device) GetHardwareAddr() (net.HardwareAddr, error) {
return net.HardwareAddr{}, netlink.ErrNotSupported
}

func (d *Device) GetIPAddr() (net.IP, error) {
func (d *Device) Addr() (netip.Addr, error) {
resp, err := d.GetClientIP()
if err != nil {
return net.IP{}, err
return netip.Addr{}, err
}
prefix := "+CIPSTA:ip:"
for _, line := range strings.Split(resp, "\n") {
if ok := strings.HasPrefix(line, prefix); ok {
ip := line[len(prefix)+1 : len(line)-2]
return net.ParseIP(ip), nil
return netip.ParseAddr(ip)
}
}
return net.IP{}, fmt.Errorf("Error getting IP address")
return netip.Addr{}, fmt.Errorf("Error getting IP address")
}

func (d *Device) Socket(domain int, stype int, protocol int) (int, error) {
Expand Down Expand Up @@ -175,17 +179,17 @@ func (d *Device) Socket(domain int, stype int, protocol int) (int, error) {
return 0, nil
}

func (d *Device) Bind(sockfd int, ip net.IP, port int) error {
d.socket.lip = ip
d.socket.lport = port
func (d *Device) Bind(sockfd int, ip netip.AddrPort) error {
d.socket.lip = ip.Addr()
d.socket.lport = ip.Port()
return nil
}

func (d *Device) Connect(sockfd int, host string, ip net.IP, port int) error {
func (d *Device) Connect(sockfd int, host string, ip netip.AddrPort) error {
var err error
var addr = ip.String()
var rport = strconv.Itoa(port)
var lport = strconv.Itoa(d.socket.lport)
var addr = ip.Addr().String()
var rport = strconv.Itoa(int(ip.Port()))
var lport = strconv.Itoa(int(d.socket.lport))

switch d.socket.protocol {
case netdev.IPPROTO_TCP:
Expand All @@ -198,9 +202,9 @@ func (d *Device) Connect(sockfd int, host string, ip net.IP, port int) error {

if err != nil {
if host == "" {
return fmt.Errorf("Connect to %s:%d timed out", ip, port)
return fmt.Errorf("Connect to %s timed out", ip)
} else {
return fmt.Errorf("Connect to %s:%d timed out", host, port)
return fmt.Errorf("Connect to %s:%d timed out", host, ip.Port())
}
}

Expand All @@ -216,7 +220,7 @@ func (d *Device) Listen(sockfd int, backlog int) error {
return nil
}

func (d *Device) Accept(sockfd int, ip net.IP, port int) (int, error) {
func (d *Device) Accept(sockfd int, ip netip.AddrPort) (int, error) {
return -1, netdev.ErrNotSupported
}

Expand Down
15 changes: 8 additions & 7 deletions netdev/netdev.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package netdev

import (
"errors"
"net"
"net/netip"
"time"
_ "unsafe" // to use go:linkname
)
Expand All @@ -28,6 +28,7 @@ const (
// GethostByName() errors
var (
ErrHostUnknown = errors.New("Host unknown")
ErrMalAddr = errors.New("Malformed address")
)

// Socket errors
Expand Down Expand Up @@ -70,18 +71,18 @@ type Netdever interface {

// GetHostByName returns the IP address of either a hostname or IPv4
// address in standard dot notation
GetHostByName(name string) (net.IP, error)
GetHostByName(name string) (netip.Addr, error)

// GetIPAddr returns IP address assigned to the interface, either by
// Addr returns IP address assigned to the interface, either by
// DHCP or statically
GetIPAddr() (net.IP, error)
Addr() (netip.Addr, error)

// Berkely Sockets-like interface, Go-ified. See man page for socket(2), etc.
Socket(domain int, stype int, protocol int) (int, error)
Bind(sockfd int, ip net.IP, port int) error
Connect(sockfd int, host string, ip net.IP, port int) error
Bind(sockfd int, ip netip.AddrPort) error
Connect(sockfd int, host string, ip netip.AddrPort) error
Listen(sockfd int, backlog int) error
Accept(sockfd int, ip net.IP, port int) (int, error)
Accept(sockfd int, ip netip.AddrPort) (int, error)
Send(sockfd int, buf []byte, flags int, deadline time.Time) (int, error)
Recv(sockfd int, buf []byte, flags int, deadline time.Time) (int, error)
Close(sockfd int) error
Expand Down
77 changes: 42 additions & 35 deletions rtl8720dn/rtl8720dn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"io"
"machine"
"net"
"net/netip"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -195,9 +196,9 @@ func (r *rtl8720dn) showIP() {
if debugging(debugBasic) {
ip, subnet, gateway, _ := r.getIP()
fmt.Printf("\r\n")
fmt.Printf("DHCP-assigned IP : %s\r\n", ip.String())
fmt.Printf("DHCP-assigned subnet : %s\r\n", subnet.String())
fmt.Printf("DHCP-assigned gateway : %s\r\n", gateway.String())
fmt.Printf("DHCP-assigned IP : %s\r\n", ip)
fmt.Printf("DHCP-assigned subnet : %s\r\n", subnet)
fmt.Printf("DHCP-assigned gateway : %s\r\n", gateway)
fmt.Printf("\r\n")
}
}
Expand Down Expand Up @@ -315,7 +316,7 @@ func (r *rtl8720dn) NetNotify(cb func(netlink.Event)) {
r.notifyCb = cb
}

func (r *rtl8720dn) GetHostByName(name string) (net.IP, error) {
func (r *rtl8720dn) GetHostByName(name string) (netip.Addr, error) {

if debugging(debugNetdev) {
fmt.Printf("[GetHostByName] name: %s\r\n", name)
Expand All @@ -327,10 +328,15 @@ func (r *rtl8720dn) GetHostByName(name string) (net.IP, error) {
var ip [4]byte
result := r.rpc_netconn_gethostbyname(name, ip[:])
if result == -1 {
return net.IP{}, netdev.ErrHostUnknown
return netip.Addr{}, netdev.ErrHostUnknown
}

return net.IP(ip[:]), nil
addr, ok := netip.AddrFromSlice(ip[:])
if !ok {
return netip.Addr{}, netdev.ErrMalAddr
}

return addr, nil
}

func (r *rtl8720dn) GetHardwareAddr() (net.HardwareAddr, error) {
Expand All @@ -348,7 +354,7 @@ func (r *rtl8720dn) GetHardwareAddr() (net.HardwareAddr, error) {
return net.HardwareAddr(addr), err
}

func (r *rtl8720dn) GetIPAddr() (net.IP, error) {
func (r *rtl8720dn) Addr() (netip.Addr, error) {

if debugging(debugNetdev) {
fmt.Printf("[GetIPAddr]\r\n")
Expand All @@ -359,7 +365,7 @@ func (r *rtl8720dn) GetIPAddr() (net.IP, error) {

ip, _, _, err := r.getIP()

return net.IP(ip), err
return ip, err
}

func (r *rtl8720dn) clientTLS() uint32 {
Expand Down Expand Up @@ -415,40 +421,40 @@ func (r *rtl8720dn) Socket(domain int, stype int, protocol int) (int, error) {
return int(newSock), nil
}

func addrToName(ip net.IP, port int) []byte {
func ipToName(ip netip.AddrPort) []byte {
name := make([]byte, 16)
name[0] = 0x00
name[1] = netdev.AF_INET
name[2] = byte(port >> 8)
name[3] = byte(port)
if len(ip) == 4 {
name[4] = byte(ip[0])
name[5] = byte(ip[1])
name[6] = byte(ip[2])
name[7] = byte(ip[3])
name[2] = byte(ip.Port() >> 8)
name[3] = byte(ip.Port())
if ip.Addr().Is4() {
addr := ip.Addr().As4()
name[4] = byte(addr[0])
name[5] = byte(addr[1])
name[6] = byte(addr[2])
name[7] = byte(addr[3])
}

return name
}

func (r *rtl8720dn) Bind(sockfd int, ip net.IP, port int) error {
func (r *rtl8720dn) Bind(sockfd int, ip netip.AddrPort) error {

if debugging(debugNetdev) {
fmt.Printf("[Bind] sockfd: %d, addr: %s:%d\r\n", sockfd, ip, port)
fmt.Printf("[Bind] sockfd: %d, addr: %s\r\n", sockfd, ip)
}

r.mu.Lock()
defer r.mu.Unlock()

var sock = sock(sockfd)
var socket = r.sockets[sock]
var name = addrToName(ip, port)
var name = ipToName(ip)

switch socket.protocol {
case netdev.IPPROTO_TCP, netdev.IPPROTO_UDP:
result := r.rpc_lwip_bind(int32(sock), name, uint32(len(name)))
if result == -1 {
return fmt.Errorf("Bind to %s:%d failed", ip, port)
return fmt.Errorf("Bind to %s failed", ip)
}
default:
return netdev.ErrProtocolNotSupported
Expand All @@ -457,11 +463,13 @@ func (r *rtl8720dn) Bind(sockfd int, ip net.IP, port int) error {
return nil
}

func (r *rtl8720dn) Connect(sockfd int, host string, ip net.IP, port int) error {
func (r *rtl8720dn) Connect(sockfd int, host string, ip netip.AddrPort) error {

port := ip.Port()

if debugging(debugNetdev) {
if host == "" {
fmt.Printf("[Connect] sockfd: %d, addr: %s:%d\r\n", sockfd, ip, port)
fmt.Printf("[Connect] sockfd: %d, addr: %s\r\n", sockfd, ip)
} else {
fmt.Printf("[Connect] sockfd: %d, host: %s:%d\r\n", sockfd, host, port)
}
Expand All @@ -472,14 +480,14 @@ func (r *rtl8720dn) Connect(sockfd int, host string, ip net.IP, port int) error

var sock = sock(sockfd)
var socket = r.sockets[sock]
var name = addrToName(ip, port)
var name = ipToName(ip)

// Start the connection
switch socket.protocol {
case netdev.IPPROTO_TCP, netdev.IPPROTO_UDP:
result := r.rpc_lwip_connect(int32(sock), name, uint32(len(name)))
if result == -1 {
return fmt.Errorf("Connect to %s:%d failed", ip, port)
return fmt.Errorf("Connect to %s failed", ip)
}
case netdev.IPPROTO_TLS:
result := r.rpc_wifi_start_ssl_client(uint32(sock),
Expand Down Expand Up @@ -526,10 +534,10 @@ func (r *rtl8720dn) Listen(sockfd int, backlog int) error {
return nil
}

func (r *rtl8720dn) Accept(sockfd int, ip net.IP, port int) (int, error) {
func (r *rtl8720dn) Accept(sockfd int, ip netip.AddrPort) (int, error) {

if debugging(debugNetdev) {
fmt.Printf("[Accept] sockfd: %d, peer: %s:%d\r\n", sockfd, ip, port)
fmt.Printf("[Accept] sockfd: %d, peer: %s\r\n", sockfd, ip)
}

r.mu.Lock()
Expand All @@ -538,7 +546,7 @@ func (r *rtl8720dn) Accept(sockfd int, ip net.IP, port int) (int, error) {
var newSock int32
var lsock = sock(sockfd)
var socket = r.sockets[lsock]
var addr = addrToName(ip, port)
var name = ipToName(ip)

switch socket.protocol {
case netdev.IPPROTO_TCP:
Expand All @@ -554,8 +562,8 @@ func (r *rtl8720dn) Accept(sockfd int, ip net.IP, port int) (int, error) {
r.mu.Lock()

// Check if a client connected. O_NONBLOCK is set on lsock.
addrlen := uint32(len(addr))
newSock = r.rpc_lwip_accept(int32(lsock), addr, &addrlen)
namelen := uint32(len(name))
newSock = r.rpc_lwip_accept(int32(lsock), name, &namelen)
if newSock == -1 {
// No new client
time.Sleep(100 * time.Millisecond)
Expand Down Expand Up @@ -761,16 +769,15 @@ func (r *rtl8720dn) getMACAddr() string {
return string(mac[:])
}

func (r *rtl8720dn) getIP() (ip, subnet, gateway net.IP, err error) {
func (r *rtl8720dn) getIP() (ip, subnet, gateway netip.Addr, err error) {
var ip_info [12]byte
result := r.rpc_tcpip_adapter_get_ip_info(0, ip_info[:])
if result == -1 {
err = fmt.Errorf("Get IP info failed")
return
}
ip, subnet, gateway = make([]byte, 4), make([]byte, 4), make([]byte, 4)
copy(ip[:], ip_info[0:4])
copy(subnet[:], ip_info[4:8])
copy(gateway[:], ip_info[8:12])
ip, _ = netip.AddrFromSlice(ip_info[0:4])
subnet, _ = netip.AddrFromSlice(ip_info[4:8])
gateway, _ = netip.AddrFromSlice(ip_info[8:12])
return
}
Loading
Loading