Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client] Use the prerouting chain to mark for masquerading to support older systems #2808

Merged
merged 3 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions client/firewall/iptables/acl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,14 +352,14 @@ func (m *aclManager) seedInitialEntries() {
func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
position: 2,
},
}

m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
position: 1,
},
}
Expand Down
169 changes: 116 additions & 53 deletions client/firewall/iptables/router_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,24 @@ import (
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)

const (
ipv4Nat = "netbird-rt-nat"
nbnet "github.com/netbirdio/netbird/util/net"
)

// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableNat = "nat"
tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD"
chainRTPRE = "NETBIRD-RT-PRE"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"

jumpPre = "jump-pre"
jumpNat = "jump-nat"
matchSet = "--match-set"
)

Expand Down Expand Up @@ -323,24 +325,25 @@ func (r *router) Reset() error {
}

func (r *router) cleanUpDefaultForwardRules() error {
err := r.cleanJumpRules()
if err != nil {
return err
if err := r.cleanJumpRules(); err != nil {
return fmt.Errorf("clean jump rules: %w", err)
}

log.Debug("flushing routing related tables")
for _, chain := range []string{chainRTFWD, chainRTNAT} {
table := r.getTableForChain(chain)

ok, err := r.iptablesClient.ChainExists(table, chain)
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTNAT, tableNat},
{chainRTPRE, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil {
log.Errorf("failed check chain %s, error: %v", chain, err)
return err
return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
} else if ok {
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
if err != nil {
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
return err
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
}
Expand All @@ -349,23 +352,60 @@ func (r *router) cleanUpDefaultForwardRules() error {
}

func (r *router) createContainers() error {
for _, chain := range []string{chainRTFWD, chainRTNAT} {
if err := r.createAndSetupChain(chain); err != nil {
return fmt.Errorf("create chain %s: %w", chain, err)
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
} {
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}

if err := r.insertEstablishedRule(chainRTFWD); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}

if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}

if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}

return nil
}

func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %v", err)
}
r.rules["static-nat-outbound"] = rule1

// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %v", err)
}
r.rules["static-nat-return"] = rule2

return nil
}

func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)

Expand All @@ -377,10 +417,14 @@ func (r *router) createAndSetupChain(chain string) error {
}

func (r *router) getTableForChain(chain string) string {
if chain == chainRTNAT {
switch chain {
case chainRTNAT:
return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
}
return tableFilter
}

func (r *router) insertEstablishedRule(chain string) error {
Expand All @@ -398,59 +442,88 @@ func (r *router) insertEstablishedRule(chain string) error {
}

func (r *router) addJumpRules() error {
rule := []string{"-j", chainRTNAT}
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
// Jump to NAT chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat jump rule: %v", err)
}
r.rules[ipv4Nat] = rule
r.rules[jumpNat] = natRule

// Jump to prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err)
}
r.rules[jumpPre] = preRule

return nil
}

func (r *router) cleanJumpRules() error {
rule, found := r.rules[ipv4Nat]
if found {
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
for _, ruleKey := range []string{jumpNat, jumpPre} {
if rule, exists := r.rules[ruleKey]; exists {
table := tableNat
chain := chainPOSTROUTING
if ruleKey == jumpPre {
table = tableMangle
chain = chainPREROUTING
}

if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
}
delete(r.rules, ruleKey)
}
}

return nil
}

func (r *router) addNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)

if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
}

rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
markValue := nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}

r.rules[ruleKey] = rule
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}

rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
"-s", pair.Source.String(),
"-d", pair.Destination.String(),
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)

if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
}

r.rules[ruleKey] = rule
return nil
}

func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)

if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
}

delete(r.rules, ruleKey)
} else {
log.Debugf("nat rule %s not found", ruleKey)
log.Debugf("marking rule %s not found", ruleKey)
}

return nil
Expand Down Expand Up @@ -482,16 +555,6 @@ func (r *router) updateState() {
}
}

func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i"
lointdir := "-o"
if inverse {
intdir = "-o"
lointdir = "-i"
}
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
}

func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string

Expand Down
Loading
Loading