diff --git a/route_linux.go b/route_linux.go index c69c595e..6b3f8666 100644 --- a/route_linux.go +++ b/route_linux.go @@ -33,6 +33,9 @@ const ( RT_FILTER_GW RT_FILTER_TABLE RT_FILTER_HOPLIMIT + RT_FILTER_PRIORITY + RT_FILTER_MARK + RT_FILTER_MASK ) const ( diff --git a/rule_linux.go b/rule_linux.go index 7e07d30e..71d62c6e 100644 --- a/rule_linux.go +++ b/rule_linux.go @@ -177,6 +177,19 @@ func RuleList(family int) ([]Rule, error) { // RuleList lists rules in the system. // Equivalent to: ip rule list func (h *Handle) RuleList(family int) ([]Rule, error) { + return h.RuleListFiltered(family, nil, 0) +} + +// RuleListFiltered gets a list of rules in the system filtered by the +// specified rule template `filter`. +// Equivalent to: ip rule list +func RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) { + return pkgHandle.RuleListFiltered(family, filter, filterMask) +} + +// RuleListFiltered lists rules in the system. +// Equivalent to: ip rule list +func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) { req := h.newNetlinkRequest(unix.RTM_GETRULE, unix.NLM_F_DUMP|unix.NLM_F_REQUEST) msg := nl.NewIfInfomsg(family) req.AddData(msg) @@ -246,6 +259,29 @@ func (h *Handle) RuleList(family int) ([]Rule, error) { rule.Sport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4])) } } + + if filter != nil { + switch { + case filterMask&RT_FILTER_SRC != 0 && + (rule.Src == nil || rule.Src.String() != filter.Src.String()): + continue + case filterMask&RT_FILTER_DST != 0 && + (rule.Dst == nil || rule.Dst.String() != filter.Dst.String()): + continue + case filterMask&RT_FILTER_TABLE != 0 && + filter.Table != unix.RT_TABLE_UNSPEC && rule.Table != filter.Table: + continue + case filterMask&RT_FILTER_TOS != 0 && rule.Tos != filter.Tos: + continue + case filterMask&RT_FILTER_PRIORITY != 0 && rule.Priority != filter.Priority: + continue + case filterMask&RT_FILTER_MARK != 0 && rule.Mark != filter.Mark: + continue + case filterMask&RT_FILTER_MASK != 0 && rule.Mask != filter.Mask: + continue + } + } + res = append(res, *rule) } diff --git a/rule_test.go b/rule_test.go index 630f10fb..2458a502 100644 --- a/rule_test.go +++ b/rule_test.go @@ -16,7 +16,7 @@ func TestRuleAddDel(t *testing.T) { srcNet := &net.IPNet{IP: net.IPv4(172, 16, 0, 1), Mask: net.CIDRMask(16, 32)} dstNet := &net.IPNet{IP: net.IPv4(172, 16, 1, 1), Mask: net.CIDRMask(24, 32)} - rulesBegin, err := RuleList(unix.AF_INET) + rulesBegin, err := RuleList(FAMILY_V4) if err != nil { t.Fatal(err) } @@ -36,7 +36,7 @@ func TestRuleAddDel(t *testing.T) { t.Fatal(err) } - rules, err := RuleList(unix.AF_INET) + rules, err := RuleList(FAMILY_V4) if err != nil { t.Fatal(err) } @@ -46,20 +46,7 @@ func TestRuleAddDel(t *testing.T) { } // find this rule - var found bool - for i := range rules { - if rules[i].Table == rule.Table && - rules[i].Src != nil && rules[i].Src.String() == srcNet.String() && - rules[i].Dst != nil && rules[i].Dst.String() == dstNet.String() && - rules[i].OifName == rule.OifName && - rules[i].Priority == rule.Priority && - rules[i].IifName == rule.IifName && - rules[i].Invert == rule.Invert && - rules[i].Tos == rule.Tos { - found = true - break - } - } + found := ruleExists(rules, *rule) if !found { t.Fatal("Rule has diffrent options than one added") } @@ -68,7 +55,7 @@ func TestRuleAddDel(t *testing.T) { t.Fatal(err) } - rulesEnd, err := RuleList(unix.AF_INET) + rulesEnd, err := RuleList(FAMILY_V4) if err != nil { t.Fatal(err) } @@ -77,3 +64,313 @@ func TestRuleAddDel(t *testing.T) { t.Fatal("Rule not removed properly") } } + +func TestRuleListFiltered(t *testing.T) { + skipUnlessRoot(t) + defer setUpNetlinkTest(t)() + + t.Run("IPv4", testRuleListFilteredIPv4) + t.Run("IPv6", testRuleListFilteredIPv6) +} + +func testRuleListFilteredIPv4(t *testing.T) { + srcNet := &net.IPNet{IP: net.IPv4(172, 16, 0, 1), Mask: net.CIDRMask(16, 32)} + dstNet := &net.IPNet{IP: net.IPv4(172, 16, 1, 1), Mask: net.CIDRMask(24, 32)} + runRuleListFiltered(t, FAMILY_V4, srcNet, dstNet) +} + +func testRuleListFilteredIPv6(t *testing.T) { + ip1 := net.ParseIP("fd56:6b58:db28:2913::") + ip2 := net.ParseIP("fde9:379f:3b35:6635::") + + srcNet := &net.IPNet{IP: ip1, Mask: net.CIDRMask(64, 128)} + dstNet := &net.IPNet{IP: ip2, Mask: net.CIDRMask(96, 128)} + runRuleListFiltered(t, FAMILY_V6, srcNet, dstNet) +} + +func runRuleListFiltered(t *testing.T, family int, srcNet, dstNet *net.IPNet) { + defaultRules, _ := RuleList(family) + + tests := []struct { + name string + ruleFilter *Rule + filterMask uint64 + preRun func() *Rule // Creates sample rule harness + postRun func(*Rule) // Deletes sample rule harness + setupWant func(*Rule) ([]Rule, bool) + }{ + { + name: "returns all rules", + ruleFilter: nil, + filterMask: 0, + preRun: func() *Rule { return nil }, + postRun: func(r *Rule) {}, + setupWant: func(_ *Rule) ([]Rule, bool) { + return defaultRules, false + }, + }, + { + name: "returns one rule filtered by Src", + ruleFilter: &Rule{Src: srcNet}, + filterMask: RT_FILTER_SRC, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 // Must add priority and table otherwise it's auto-assigned + r.Table = 1 + RuleAdd(r) + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + { + name: "returns one rule filtered by Dst", + ruleFilter: &Rule{Dst: dstNet}, + filterMask: RT_FILTER_DST, + preRun: func() *Rule { + r := NewRule() + r.Dst = dstNet + r.Priority = 1 // Must add priority and table otherwise it's auto-assigned + r.Table = 1 + RuleAdd(r) + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + { + name: "returns two rules filtered by Dst", + ruleFilter: &Rule{Dst: dstNet}, + filterMask: RT_FILTER_DST, + preRun: func() *Rule { + r := NewRule() + r.Dst = dstNet + r.Priority = 1 // Must add priority and table otherwise it's auto-assigned + r.Table = 1 + RuleAdd(r) + + rc := *r // Create almost identical copy + rc.Src = srcNet + RuleAdd(&rc) + + return r + }, + postRun: func(r *Rule) { + RuleDel(r) + + rc := *r // Delete the almost identical copy + rc.Src = srcNet + RuleDel(&rc) + }, + setupWant: func(r *Rule) ([]Rule, bool) { + rs := []Rule{} + rs = append(rs, *r) + + rc := *r // Append the almost identical copy + rc.Src = srcNet + rs = append(rs, rc) + + return rs, false + }, + }, + { + name: "returns one rule filtered by Src when two rules exist", + ruleFilter: &Rule{Src: srcNet}, + filterMask: RT_FILTER_SRC, + preRun: func() *Rule { + r := NewRule() + r.Dst = dstNet + r.Priority = 1 // Must add priority and table otherwise it's auto-assigned + r.Table = 1 + RuleAdd(r) + + rc := *r // Create almost identical copy + rc.Src = srcNet + RuleAdd(&rc) + + return r + }, + postRun: func(r *Rule) { + RuleDel(r) + + rc := *r // Delete the almost identical copy + rc.Src = srcNet + RuleDel(&rc) + }, + setupWant: func(r *Rule) ([]Rule, bool) { + rs := []Rule{} + // Do not append `r` + + rc := *r // Append the almost identical copy + rc.Src = srcNet + rs = append(rs, rc) + + return rs, false + }, + }, + { + name: "returns rules with specific priority", + ruleFilter: &Rule{Priority: 5}, + filterMask: RT_FILTER_PRIORITY, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 5 + r.Table = 1 + RuleAdd(r) + + for i := 2; i < 5; i++ { + rc := *r // Create almost identical copy + rc.Table = i + RuleAdd(&rc) + } + + return r + }, + postRun: func(r *Rule) { + RuleDel(r) + + for i := 2; i < 5; i++ { + rc := *r // Delete the almost identical copy + rc.Table = -1 + RuleDel(&rc) + } + }, + setupWant: func(r *Rule) ([]Rule, bool) { + rs := []Rule{} + rs = append(rs, *r) + + for i := 2; i < 5; i++ { + rc := *r // Append the almost identical copy + rc.Table = i + rs = append(rs, rc) + } + + return rs, false + }, + }, + { + name: "returns rules filtered by Table", + ruleFilter: &Rule{Table: 199}, + filterMask: RT_FILTER_TABLE, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 // Must add priority otherwise it's auto-assigned + r.Table = 199 + RuleAdd(r) + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + { + name: "returns rules filtered by Mask", + ruleFilter: &Rule{Mask: 0x5}, + filterMask: RT_FILTER_MASK, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 // Must add priority and table otherwise it's auto-assigned + r.Table = 1 + r.Mask = 0x5 + RuleAdd(r) + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + { + name: "returns rules filtered by Mark", + ruleFilter: &Rule{Mark: 0xbb}, + filterMask: RT_FILTER_MARK, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 // Must add priority, table, mask otherwise it's auto-assigned + r.Table = 1 + r.Mask = 0xff + r.Mark = 0xbb + RuleAdd(r) + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + { + name: "returns rules filtered by Tos", + ruleFilter: &Rule{Tos: 12}, + filterMask: RT_FILTER_TOS, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 // Must add priority, table, mask otherwise it's auto-assigned + r.Table = 12 + r.Tos = 12 // Tos must equal table + RuleAdd(r) + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := tt.preRun() + rules, err := RuleListFiltered(family, tt.ruleFilter, tt.filterMask) + tt.postRun(rule) + + wantRules, wantErr := tt.setupWant(rule) + + if len(wantRules) != len(rules) { + t.Errorf("Expected len: %d, got: %d", len(wantRules), len(rules)) + } else { + for i := range wantRules { + if !ruleEquals(wantRules[i], rules[i]) { + t.Errorf("Rules mismatch, want %v, got %v", wantRules[i], rules[i]) + } + } + } + + if (err != nil) != wantErr { + t.Errorf("Error expectation not met, want %v, got %v", (err != nil), wantErr) + } + }) + } +} + +func ruleExists(rules []Rule, rule Rule) bool { + for i := range rules { + if ruleEquals(rules[i], rule) { + return true + } + } + + return false +} + +func ruleEquals(a, b Rule) bool { + return a.Table == b.Table && + ((a.Src == nil && b.Src == nil) || + (a.Src != nil && b.Src != nil && a.Src.String() == b.Src.String())) && + ((a.Dst == nil && b.Dst == nil) || + (a.Dst != nil && b.Dst != nil && a.Dst.String() == b.Dst.String())) && + a.OifName == b.OifName && + a.Priority == b.Priority && + a.IifName == b.IifName && + a.Invert == b.Invert && + a.Tos == b.Tos +}