github.com/vishvananda/netlink@v1.3.1/rule_linux.go (about)

     1  package netlink
     2  
import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"net"

	"github.com/vishvananda/netlink/nl"
	"golang.org/x/sys/unix"
)
    12  
    13  const FibRuleInvert = 0x2
    14  
    15  // RuleAdd adds a rule to the system.
    16  // Equivalent to: ip rule add
    17  func RuleAdd(rule *Rule) error {
    18  	return pkgHandle.RuleAdd(rule)
    19  }
    20  
    21  // RuleAdd adds a rule to the system.
    22  // Equivalent to: ip rule add
    23  func (h *Handle) RuleAdd(rule *Rule) error {
    24  	req := h.newNetlinkRequest(unix.RTM_NEWRULE, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
    25  	return ruleHandle(rule, req)
    26  }
    27  
    28  // RuleDel deletes a rule from the system.
    29  // Equivalent to: ip rule del
    30  func RuleDel(rule *Rule) error {
    31  	return pkgHandle.RuleDel(rule)
    32  }
    33  
    34  // RuleDel deletes a rule from the system.
    35  // Equivalent to: ip rule del
    36  func (h *Handle) RuleDel(rule *Rule) error {
    37  	req := h.newNetlinkRequest(unix.RTM_DELRULE, unix.NLM_F_ACK)
    38  	return ruleHandle(rule, req)
    39  }
    40  
    41  func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
    42  	msg := nl.NewRtMsg()
    43  	msg.Family = unix.AF_INET
    44  	msg.Protocol = unix.RTPROT_BOOT
    45  	msg.Scope = unix.RT_SCOPE_UNIVERSE
    46  	msg.Table = unix.RT_TABLE_UNSPEC
    47  	msg.Type = rule.Type // usually 0, same as unix.RTN_UNSPEC
    48  	if msg.Type == 0 && req.NlMsghdr.Flags&unix.NLM_F_CREATE > 0 {
    49  		msg.Type = unix.RTN_UNICAST
    50  	}
    51  	if rule.Invert {
    52  		msg.Flags |= FibRuleInvert
    53  	}
    54  	if rule.Family != 0 {
    55  		msg.Family = uint8(rule.Family)
    56  	}
    57  	if rule.Table >= 0 && rule.Table < 256 {
    58  		msg.Table = uint8(rule.Table)
    59  	}
    60  	if rule.Tos != 0 {
    61  		msg.Tos = uint8(rule.Tos)
    62  	}
    63  
    64  	var dstFamily uint8
    65  	var rtAttrs []*nl.RtAttr
    66  	if rule.Dst != nil && rule.Dst.IP != nil {
    67  		dstLen, _ := rule.Dst.Mask.Size()
    68  		msg.Dst_len = uint8(dstLen)
    69  		msg.Family = uint8(nl.GetIPFamily(rule.Dst.IP))
    70  		dstFamily = msg.Family
    71  		var dstData []byte
    72  		if msg.Family == unix.AF_INET {
    73  			dstData = rule.Dst.IP.To4()
    74  		} else {
    75  			dstData = rule.Dst.IP.To16()
    76  		}
    77  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, dstData))
    78  	}
    79  
    80  	if rule.Src != nil && rule.Src.IP != nil {
    81  		msg.Family = uint8(nl.GetIPFamily(rule.Src.IP))
    82  		if dstFamily != 0 && dstFamily != msg.Family {
    83  			return fmt.Errorf("source and destination ip are not the same IP family")
    84  		}
    85  		srcLen, _ := rule.Src.Mask.Size()
    86  		msg.Src_len = uint8(srcLen)
    87  		var srcData []byte
    88  		if msg.Family == unix.AF_INET {
    89  			srcData = rule.Src.IP.To4()
    90  		} else {
    91  			srcData = rule.Src.IP.To16()
    92  		}
    93  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_SRC, srcData))
    94  	}
    95  
    96  	req.AddData(msg)
    97  	for i := range rtAttrs {
    98  		req.AddData(rtAttrs[i])
    99  	}
   100  
   101  	if rule.Priority >= 0 {
   102  		b := make([]byte, 4)
   103  		native.PutUint32(b, uint32(rule.Priority))
   104  		req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b))
   105  	}
   106  	if rule.Mark != 0 || rule.Mask != nil {
   107  		b := make([]byte, 4)
   108  		native.PutUint32(b, rule.Mark)
   109  		req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b))
   110  	}
   111  	if rule.Mask != nil {
   112  		b := make([]byte, 4)
   113  		native.PutUint32(b, *rule.Mask)
   114  		req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b))
   115  	}
   116  	if rule.Flow >= 0 {
   117  		b := make([]byte, 4)
   118  		native.PutUint32(b, uint32(rule.Flow))
   119  		req.AddData(nl.NewRtAttr(nl.FRA_FLOW, b))
   120  	}
   121  	if rule.TunID > 0 {
   122  		b := make([]byte, 4)
   123  		native.PutUint32(b, uint32(rule.TunID))
   124  		req.AddData(nl.NewRtAttr(nl.FRA_TUN_ID, b))
   125  	}
   126  	if rule.Table >= 256 {
   127  		b := make([]byte, 4)
   128  		native.PutUint32(b, uint32(rule.Table))
   129  		req.AddData(nl.NewRtAttr(nl.FRA_TABLE, b))
   130  	}
   131  	if msg.Table > 0 {
   132  		if rule.SuppressPrefixlen >= 0 {
   133  			b := make([]byte, 4)
   134  			native.PutUint32(b, uint32(rule.SuppressPrefixlen))
   135  			req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_PREFIXLEN, b))
   136  		}
   137  		if rule.SuppressIfgroup >= 0 {
   138  			b := make([]byte, 4)
   139  			native.PutUint32(b, uint32(rule.SuppressIfgroup))
   140  			req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_IFGROUP, b))
   141  		}
   142  	}
   143  	if rule.IifName != "" {
   144  		req.AddData(nl.NewRtAttr(nl.FRA_IIFNAME, []byte(rule.IifName+"\x00")))
   145  	}
   146  	if rule.OifName != "" {
   147  		req.AddData(nl.NewRtAttr(nl.FRA_OIFNAME, []byte(rule.OifName+"\x00")))
   148  	}
   149  	if rule.Goto >= 0 {
   150  		msg.Type = nl.FR_ACT_GOTO
   151  		b := make([]byte, 4)
   152  		native.PutUint32(b, uint32(rule.Goto))
   153  		req.AddData(nl.NewRtAttr(nl.FRA_GOTO, b))
   154  	}
   155  
   156  	if rule.IPProto > 0 {
   157  		b := make([]byte, 4)
   158  		native.PutUint32(b, uint32(rule.IPProto))
   159  		req.AddData(nl.NewRtAttr(nl.FRA_IP_PROTO, b))
   160  	}
   161  
   162  	if rule.Dport != nil {
   163  		b := rule.Dport.toRtAttrData()
   164  		req.AddData(nl.NewRtAttr(nl.FRA_DPORT_RANGE, b))
   165  	}
   166  
   167  	if rule.Sport != nil {
   168  		b := rule.Sport.toRtAttrData()
   169  		req.AddData(nl.NewRtAttr(nl.FRA_SPORT_RANGE, b))
   170  	}
   171  
   172  	if rule.UIDRange != nil {
   173  		b := rule.UIDRange.toRtAttrData()
   174  		req.AddData(nl.NewRtAttr(nl.FRA_UID_RANGE, b))
   175  	}
   176  
   177  	if rule.Protocol > 0 {
   178  		req.AddData(nl.NewRtAttr(nl.FRA_PROTOCOL, nl.Uint8Attr(rule.Protocol)))
   179  	}
   180  
   181  	_, err := req.Execute(unix.NETLINK_ROUTE, 0)
   182  	return err
   183  }
   184  
   185  // RuleList lists rules in the system.
   186  // Equivalent to: ip rule list
   187  //
   188  // If the returned error is [ErrDumpInterrupted], results may be inconsistent
   189  // or incomplete.
   190  func RuleList(family int) ([]Rule, error) {
   191  	return pkgHandle.RuleList(family)
   192  }
   193  
   194  // RuleList lists rules in the system.
   195  // Equivalent to: ip rule list
   196  //
   197  // If the returned error is [ErrDumpInterrupted], results may be inconsistent
   198  // or incomplete.
   199  func (h *Handle) RuleList(family int) ([]Rule, error) {
   200  	return h.RuleListFiltered(family, nil, 0)
   201  }
   202  
   203  // RuleListFiltered gets a list of rules in the system filtered by the
   204  // specified rule template `filter`.
   205  // Equivalent to: ip rule list
   206  //
   207  // If the returned error is [ErrDumpInterrupted], results may be inconsistent
   208  // or incomplete.
   209  func RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) {
   210  	return pkgHandle.RuleListFiltered(family, filter, filterMask)
   211  }
   212  
   213  // RuleListFiltered lists rules in the system.
   214  // Equivalent to: ip rule list
   215  //
   216  // If the returned error is [ErrDumpInterrupted], results may be inconsistent
   217  // or incomplete.
   218  func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) {
   219  	req := h.newNetlinkRequest(unix.RTM_GETRULE, unix.NLM_F_DUMP|unix.NLM_F_REQUEST)
   220  	msg := nl.NewIfInfomsg(family)
   221  	req.AddData(msg)
   222  
   223  	msgs, executeErr := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWRULE)
   224  	if executeErr != nil && !errors.Is(executeErr, ErrDumpInterrupted) {
   225  		return nil, executeErr
   226  	}
   227  
   228  	var res = make([]Rule, 0)
   229  	for i := range msgs {
   230  		msg := nl.DeserializeRtMsg(msgs[i])
   231  		attrs, err := nl.ParseRouteAttr(msgs[i][msg.Len():])
   232  		if err != nil {
   233  			return nil, err
   234  		}
   235  
   236  		rule := NewRule()
   237  		rule.Priority = 0 // The default priority from kernel
   238  
   239  		rule.Invert = msg.Flags&FibRuleInvert > 0
   240  		rule.Family = int(msg.Family)
   241  		rule.Tos = uint(msg.Tos)
   242  
   243  		for j := range attrs {
   244  			switch attrs[j].Attr.Type {
   245  			case unix.RTA_TABLE:
   246  				rule.Table = int(native.Uint32(attrs[j].Value[0:4]))
   247  			case nl.FRA_SRC:
   248  				rule.Src = &net.IPNet{
   249  					IP:   attrs[j].Value,
   250  					Mask: net.CIDRMask(int(msg.Src_len), 8*len(attrs[j].Value)),
   251  				}
   252  			case nl.FRA_DST:
   253  				rule.Dst = &net.IPNet{
   254  					IP:   attrs[j].Value,
   255  					Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)),
   256  				}
   257  			case nl.FRA_FWMARK:
   258  				rule.Mark = native.Uint32(attrs[j].Value[0:4])
   259  			case nl.FRA_FWMASK:
   260  				mask := native.Uint32(attrs[j].Value[0:4])
   261  				rule.Mask = &mask
   262  			case nl.FRA_TUN_ID:
   263  				rule.TunID = uint(native.Uint64(attrs[j].Value[0:8]))
   264  			case nl.FRA_IIFNAME:
   265  				rule.IifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
   266  			case nl.FRA_OIFNAME:
   267  				rule.OifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
   268  			case nl.FRA_SUPPRESS_PREFIXLEN:
   269  				i := native.Uint32(attrs[j].Value[0:4])
   270  				if i != 0xffffffff {
   271  					rule.SuppressPrefixlen = int(i)
   272  				}
   273  			case nl.FRA_SUPPRESS_IFGROUP:
   274  				i := native.Uint32(attrs[j].Value[0:4])
   275  				if i != 0xffffffff {
   276  					rule.SuppressIfgroup = int(i)
   277  				}
   278  			case nl.FRA_FLOW:
   279  				rule.Flow = int(native.Uint32(attrs[j].Value[0:4]))
   280  			case nl.FRA_GOTO:
   281  				rule.Goto = int(native.Uint32(attrs[j].Value[0:4]))
   282  			case nl.FRA_PRIORITY:
   283  				rule.Priority = int(native.Uint32(attrs[j].Value[0:4]))
   284  			case nl.FRA_IP_PROTO:
   285  				rule.IPProto = int(native.Uint32(attrs[j].Value[0:4]))
   286  			case nl.FRA_DPORT_RANGE:
   287  				rule.Dport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
   288  			case nl.FRA_SPORT_RANGE:
   289  				rule.Sport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
   290  			case nl.FRA_UID_RANGE:
   291  				rule.UIDRange = NewRuleUIDRange(native.Uint32(attrs[j].Value[0:4]), native.Uint32(attrs[j].Value[4:8]))
   292  			case nl.FRA_PROTOCOL:
   293  				rule.Protocol = uint8(attrs[j].Value[0])
   294  			}
   295  		}
   296  
   297  		if filter != nil {
   298  			switch {
   299  			case filterMask&RT_FILTER_SRC != 0 &&
   300  				(rule.Src == nil || rule.Src.String() != filter.Src.String()):
   301  				continue
   302  			case filterMask&RT_FILTER_DST != 0 &&
   303  				(rule.Dst == nil || rule.Dst.String() != filter.Dst.String()):
   304  				continue
   305  			case filterMask&RT_FILTER_TABLE != 0 &&
   306  				filter.Table != unix.RT_TABLE_UNSPEC && rule.Table != filter.Table:
   307  				continue
   308  			case filterMask&RT_FILTER_TOS != 0 && rule.Tos != filter.Tos:
   309  				continue
   310  			case filterMask&RT_FILTER_PRIORITY != 0 && rule.Priority != filter.Priority:
   311  				continue
   312  			case filterMask&RT_FILTER_MARK != 0 && rule.Mark != filter.Mark:
   313  				continue
   314  			case filterMask&RT_FILTER_MASK != 0 && !ptrEqual(rule.Mask, filter.Mask):
   315  				continue
   316  			}
   317  		}
   318  
   319  		res = append(res, *rule)
   320  	}
   321  
   322  	return res, executeErr
   323  }
   324  
   325  func (pr *RulePortRange) toRtAttrData() []byte {
   326  	b := [][]byte{make([]byte, 2), make([]byte, 2)}
   327  	native.PutUint16(b[0], pr.Start)
   328  	native.PutUint16(b[1], pr.End)
   329  	return bytes.Join(b, []byte{})
   330  }
   331  
   332  func (pr *RuleUIDRange) toRtAttrData() []byte {
   333  	b := [][]byte{make([]byte, 4), make([]byte, 4)}
   334  	native.PutUint32(b[0], pr.Start)
   335  	native.PutUint32(b[1], pr.End)
   336  	return bytes.Join(b, []byte{})
   337  }
   338  
   339  func ptrEqual(a, b *uint32) bool {
   340  	if a == b {
   341  		return true
   342  	}
   343  	if (a == nil) || (b == nil) {
   344  		return false
   345  	}
   346  	return *a == *b
   347  }
   348  
   349  func (r Rule) typeString() string {
   350  	switch r.Type {
   351  	case unix.RTN_UNSPEC: // zero
   352  		return ""
   353  	case unix.RTN_UNICAST:
   354  		return ""
   355  	case unix.RTN_LOCAL:
   356  		return "local"
   357  	case unix.RTN_BROADCAST:
   358  		return "broadcast"
   359  	case unix.RTN_ANYCAST:
   360  		return "anycast"
   361  	case unix.RTN_MULTICAST:
   362  		return "multicast"
   363  	case unix.RTN_BLACKHOLE:
   364  		return "blackhole"
   365  	case unix.RTN_UNREACHABLE:
   366  		return "unreachable"
   367  	case unix.RTN_PROHIBIT:
   368  		return "prohibit"
   369  	case unix.RTN_THROW:
   370  		return "throw"
   371  	case unix.RTN_NAT:
   372  		return "nat"
   373  	case unix.RTN_XRESOLVE:
   374  		return "xresolve"
   375  	default:
   376  		return fmt.Sprintf("type(0x%x)", r.Type)
   377  	}
   378  }