Skip to content

Commit

Permalink
refactored the networkutils and windows watcher package
Browse files Browse the repository at this point in the history
Presently, we have setter methods present in the networkutils and watcher package instead of having mock methods for the same. This is inconsistent with rest of the agent package and goes against Go guidelines.
  • Loading branch information
Harsh Rawat authored and singholt committed Oct 18, 2022
1 parent 605a6d7 commit add9b0b
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 108 deletions.
96 changes: 96 additions & 0 deletions agent/eni/networkutils/mocks/utils_windows.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 5 additions & 21 deletions agent/eni/networkutils/utils_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,23 @@ import (
"time"
"unsafe"

"golang.org/x/sys/windows"

apierrors "github.com/aws/amazon-ecs-agent/agent/api/errors"
"github.com/aws/amazon-ecs-agent/agent/eni/netwrapper"
"github.com/aws/amazon-ecs-agent/agent/utils/retry"

"github.com/cihub/seelog"
"github.com/pkg/errors"
"golang.org/x/sys/windows"
)

//go:generate mockgen -destination=mocks/$GOFILE -copyright_file=../../../scripts/copyright_file github.com/aws/amazon-ecs-agent/agent/eni/networkutils NetworkUtils

// NetworkUtils is the interface used for accessing network related functionality on Windows.
// The methods declared in this package may or may not add any additional logic over the actual networking api calls.
type NetworkUtils interface {
GetInterfaceMACByIndex(int, context.Context, time.Duration) (string, error)
GetAllNetworkInterfaces() ([]net.Interface, error)
GetDNSServerAddressList(macAddress string) ([]string, error)
SetNetWrapper(netWrapper netwrapper.NetWrapper)
}

type networkUtils struct {
Expand Down Expand Up @@ -118,11 +119,6 @@ func (utils *networkUtils) GetAllNetworkInterfaces() ([]net.Interface, error) {
return utils.netWrapper.GetAllNetworkInterfaces()
}

// SetNetWrapper is used to inject netWrapper instance. This will be handy while testing to inject mocks.
func (utils *networkUtils) SetNetWrapper(netWrapper netwrapper.NetWrapper) {
utils.netWrapper = netWrapper
}

// GetDNSServerAddressList returns the DNS server addresses of the queried interface.
func (utils *networkUtils) GetDNSServerAddressList(macAddress string) ([]string, error) {
addresses, err := funcGetAdapterAddresses()
Expand All @@ -141,7 +137,7 @@ func (utils *networkUtils) GetDNSServerAddressList(macAddress string) ([]string,

dnsServerAddressList := make([]string, 0)
for firstDnsNode != nil {
dnsServerAddressList = append(dnsServerAddressList, utils.parseSocketAddress(firstDnsNode.Address))
dnsServerAddressList = append(dnsServerAddressList, firstDnsNode.Address.IP().String())
firstDnsNode = firstDnsNode.Next
}

Expand All @@ -158,18 +154,6 @@ func (utils *networkUtils) parseMACAddress(adapterAddress *windows.IpAdapterAddr
return hardwareAddr
}

// parseSocketAddress parses the SocketAddress into its string representation.
// This method needs to be deprecated in favour of IP() method of SocketAdress introduced in Go 1.13+.
// The method details have been taken from https://github.com/golang/sys/blob/release-branch.go1.13/windows/types_windows.go
func (utils *networkUtils) parseSocketAddress(addr windows.SocketAddress) string {
var ipAddr string
if uintptr(addr.SockaddrLength) >= unsafe.Sizeof(syscall.RawSockaddrInet4{}) && addr.Sockaddr.Addr.Family == syscall.AF_INET {
ip := net.IP((*syscall.RawSockaddrInet4)(unsafe.Pointer(addr.Sockaddr)).Addr[:])
ipAddr = ip.String()
}
return ipAddr
}

// getAdapterAddresses returns a list of IP adapter and address
// structures. The structure contains an IP adapter and flattened
// multiple IP addresses including unicast, anycast and multicast
Expand Down
21 changes: 7 additions & 14 deletions agent/eni/networkutils/utils_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ func TestGetInterfaceMACByIndex(t *testing.T) {

ctx := context.TODO()
mocknetwrapper := mock_netwrapper.NewMockNetWrapper(mockCtrl)
netUtils := New()
netUtils.SetNetWrapper(mocknetwrapper)
netUtils := &networkUtils{netWrapper: mocknetwrapper}
hardwareAddr, err := net.ParseMAC(macAddress)

mocknetwrapper.EXPECT().FindInterfaceByIndex(interfaceIndex).Return(
Expand All @@ -68,8 +67,7 @@ func TestGetInterfaceMACByIndexEmptyAddress(t *testing.T) {

ctx := context.TODO()
mocknetwrapper := mock_netwrapper.NewMockNetWrapper(mockCtrl)
netUtils := New()
netUtils.SetNetWrapper(mocknetwrapper)
netUtils := &networkUtils{netWrapper: mocknetwrapper}

mocknetwrapper.EXPECT().FindInterfaceByIndex(interfaceIndex).Return(
&net.Interface{
Expand All @@ -91,8 +89,7 @@ func TestGetInterfaceMACByIndexRetries(t *testing.T) {

ctx := context.TODO()
mocknetwrapper := mock_netwrapper.NewMockNetWrapper(mockCtrl)
netUtils := New()
netUtils.SetNetWrapper(mocknetwrapper)
netUtils := &networkUtils{netWrapper: mocknetwrapper}
hardwareAddr, err := net.ParseMAC(macAddress)
emptyaddr := make([]byte, 0)

Expand Down Expand Up @@ -123,8 +120,7 @@ func TestGetInterfaceMACByIndexContextTimeout(t *testing.T) {

ctx := context.TODO()
mocknetwrapper := mock_netwrapper.NewMockNetWrapper(mockCtrl)
netUtils := New()
netUtils.SetNetWrapper(mocknetwrapper)
netUtils := &networkUtils{netWrapper: mocknetwrapper}

mocknetwrapper.EXPECT().FindInterfaceByIndex(interfaceIndex).Return(
&net.Interface{
Expand All @@ -146,8 +142,7 @@ func TestGetInterfaceMACByIndexWithGolangNetError(t *testing.T) {

ctx := context.TODO()
mocknetwrapper := mock_netwrapper.NewMockNetWrapper(mockCtrl)
netUtils := New()
netUtils.SetNetWrapper(mocknetwrapper)
netUtils := &networkUtils{netWrapper: mocknetwrapper}

mocknetwrapper.EXPECT().FindInterfaceByIndex(interfaceIndex).Return(
nil, errors.New("unable to retrieve interface"))
Expand All @@ -164,8 +159,7 @@ func TestGetAllNetworkInterfaces(t *testing.T) {
defer mockCtrl.Finish()

mocknetwrapper := mock_netwrapper.NewMockNetWrapper(mockCtrl)
netUtils := New()
netUtils.SetNetWrapper(mocknetwrapper)
netUtils := &networkUtils{netWrapper: mocknetwrapper}

expectedIface := make([]net.Interface, 1)

Expand All @@ -191,8 +185,7 @@ func TestGetAllNetworkInterfacesError(t *testing.T) {
defer mockCtrl.Finish()

mocknetwrapper := mock_netwrapper.NewMockNetWrapper(mockCtrl)
netUtils := New()
netUtils.SetNetWrapper(mocknetwrapper)
netUtils := &networkUtils{netWrapper: mocknetwrapper}

mocknetwrapper.EXPECT().GetAllNetworkInterfaces().Return(
nil, errors.New("error occurred while fetching interfaces"),
Expand Down
9 changes: 1 addition & 8 deletions agent/eni/watcher/watcher_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ func newWatcher(ctx context.Context,
state dockerstate.TaskEngineState,
stateChangeEvents chan<- statechange.Event) (*ENIWatcher, error) {

derivedContext, cancel := context.WithCancel(ctx)

eniMonitor := iphelperwrapper.NewMonitor()
notificationChannel := make(chan int)
err := eniMonitor.Start(notificationChannel)
Expand All @@ -69,6 +67,7 @@ func newWatcher(ctx context.Context,
}
log.Info("windows eni watcher has been initialized")

derivedContext, cancel := context.WithCancel(ctx)
return &ENIWatcher{
ctx: derivedContext,
cancel: cancel,
Expand Down Expand Up @@ -172,9 +171,3 @@ func (eniWatcher *ENIWatcher) getAllInterfaces() (state map[string]int, err erro
}
return state, nil
}

// SetNetworkUtils is used for injecting NetworkUtils instance in eniWatcher
// This will be handy while testing to inject mock objects
func (eniWatcher *ENIWatcher) SetNetworkUtils(utils networkutils.NetworkUtils) {
eniWatcher.netutils = utils
}
Loading

0 comments on commit add9b0b

Please sign in to comment.