Skip to content

Commit

Permalink
Rewrite ipset-test
Browse files Browse the repository at this point in the history
  • Loading branch information
corny committed Sep 30, 2018
1 parent 04cb360 commit f56c1f0
Showing 1 changed file with 97 additions and 70 deletions.
167 changes: 97 additions & 70 deletions cmd/ipset-test/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,95 +2,126 @@ package main

import (
"flag"
"fmt"
"log"
"net"
"os"
"sort"

"github.com/vishvananda/netlink"
)

type command struct {
Function func([]string)
Description string
ArgCount int
}

var (
commands = map[string]command{
"protocol": {cmdProtocol, "prints the protocol version", 0},
"create": {cmdCreate, "creates a new ipset", 2},
"destroy": {cmdCreate, "creates a new ipset", 1},
"list": {cmdCreate, "list specific ipset", 1},
"listall": {cmdCreate, "list all ipsets", 0},
"add": {cmdAddDel(netlink.IpsetAdd), "add entry", 1},
"del": {cmdAddDel(netlink.IpsetDel), "delete entry", 1},
}

timeoutVal *uint32
timeout = flag.Int("timeout", -1, "timeout, negative means omit the argument")
comment = flag.String("comment", "", "comment")
withComments = flag.Bool("with-comments", false, "create set with comment support")
withCounters = flag.Bool("with-counters", false, "create set with counters support")
withSkbinfo = flag.Bool("with-skbinfo", false, "create set with skbinfo support")
replace = flag.Bool("replace", false, "replace existing set/entry")
)

func main() {
timeout := flag.Int("timeout", -1, "timeout, negative means omit the argument")
comment := flag.String("comment", "", "comment")
withComments := flag.Bool("with-comments", false, "create set with comment support")
withCounters := flag.Bool("with-counters", false, "create set with counters support")
withSkbinfo := flag.Bool("with-skbinfo", false, "create set with skbinfo support")
replace := flag.Bool("replace", false, "replace existing set/entry")
flag.Parse()
args := flag.Args()

if len(args) < 1 {
panic("invalid arguments")
printUsage()
os.Exit(1)
}

var timeoutVal *uint32
if *timeout >= 0 {
v := uint32(*timeout)
timeoutVal = &v
}

log.SetFlags(log.Lshortfile)

cmd := args[0]
cmdName := args[0]
args = args[1:]

switch cmd {
case "protocol":
protocol, err := netlink.IpsetProtocol()
if err != nil {
panic(err)
}
log.Println("Protocol:", protocol)
cmd, exist := commands[cmdName]
if !exist {
fmt.Println("unknown command")
os.Exit(1)
}

case "create":
if len(args) != 2 {
panic("invalid arguments")
}
if cmd.ArgCount != len(args) {
fmt.Printf("invalid number of arguments. expected=%d given=%d\n", cmd.ArgCount, len(args))
os.Exit(1)
}

err := netlink.IpsetCreate(args[0], args[1], netlink.IpsetCreateOptions{
Replace: *replace,
Timeout: timeoutVal,
Comments: *withComments,
Counters: *withCounters,
Skbinfo: *withSkbinfo,
})
if err != nil {
panic(err)
}
cmd.Function(args)
}

case "destroy":
if len(args) != 1 {
panic("invalid arguments")
}
err := netlink.IpsetDestroy(args[0])
if err != nil {
panic(err)
}
func printUsage() {
fmt.Printf("Usage: %s COMMAND [args] [-flags]\n\n", os.Args[0])
names := make([]string, 0, len(commands))
for name := range commands {
names = append(names, name)
}
sort.Strings(names)
fmt.Println("Available commands:")
for _, name := range names {
fmt.Printf(" %-15v %s\n", name, commands[name].Description)
}
fmt.Println("\nAvailable flags:")
flag.PrintDefaults()
}

case "list":
if len(args) != 1 {
panic("invalid arguments")
}
func cmdProtocol(_ []string) {
protocol, err := netlink.IpsetProtocol()
check(err)
log.Println("Protocol:", protocol)
}

result, err := netlink.IpsetList(args[0])
if err != nil {
panic(err)
}
log.Printf("%+v", result)
func cmdCreate(args []string) {
err := netlink.IpsetCreate(args[0], args[1], netlink.IpsetCreateOptions{
Replace: *replace,
Timeout: timeoutVal,
Comments: *withComments,
Counters: *withCounters,
Skbinfo: *withSkbinfo,
})
check(err)
}

case "listall":
result, err := netlink.IpsetListAll()
if err != nil {
panic(err)
}
for _, ipset := range result {
log.Printf("%+v", ipset)
}
func cmdDestroy(args []string) {
check(netlink.IpsetDestroy(args[0]))
}

case "add", "del":
if len(args) != 2 {
panic("invalid arguments")
}
func cmdList(args []string) {
result, err := netlink.IpsetList(args[0])
check(err)
log.Printf("%+v", result)
}

func cmdListAll(args []string) {
result, err := netlink.IpsetListAll()
check(err)
for _, ipset := range result {
log.Printf("%+v", ipset)
}
}

func cmdAddDel(f func(string, *netlink.IPSetEntry) error) func([]string) {
return func(args []string) {
setName := args[0]
element := args[1]

Expand All @@ -102,17 +133,13 @@ func main() {
Replace: *replace,
}

var err error
if cmd == "add" {
err = netlink.IpsetAdd(setName, &entry)
} else {
err = netlink.IpsetDel(setName, &entry)
}
check(f(setName, &entry))
}
}

if err != nil {
panic(err)
}
default:
panic("invalid command")
// panic on error
func check(err error) {
if err != nil {
panic(err)
}
}

0 comments on commit f56c1f0

Please sign in to comment.