github.com/vishvananda/netlink@v1.3.0/rule_linux.go (about)

     1  package netlink
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"net"
     7  
     8  	"github.com/vishvananda/netlink/nl"
     9  	"golang.org/x/sys/unix"
    10  )
    11  
    // FibRuleInvert is the rule-header flag bit (FIB_RULE_INVERT in the
// kernel's linux/fib_rules.h) that negates a rule's match. Rules created or
// deleted with Rule.Invert set carry this bit, and rule listings report it
// back through Rule.Invert.
const FibRuleInvert = 1 << 1
    13  
    14  // RuleAdd adds a rule to the system.
    15  // Equivalent to: ip rule add
    16  func RuleAdd(rule *Rule) error {
    17  	return pkgHandle.RuleAdd(rule)
    18  }
    19  
    20  // RuleAdd adds a rule to the system.
    21  // Equivalent to: ip rule add
    22  func (h *Handle) RuleAdd(rule *Rule) error {
    23  	req := h.newNetlinkRequest(unix.RTM_NEWRULE, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
    24  	return ruleHandle(rule, req)
    25  }
    26  
    27  // RuleDel deletes a rule from the system.
    28  // Equivalent to: ip rule del
    29  func RuleDel(rule *Rule) error {
    30  	return pkgHandle.RuleDel(rule)
    31  }
    32  
    33  // RuleDel deletes a rule from the system.
    34  // Equivalent to: ip rule del
    35  func (h *Handle) RuleDel(rule *Rule) error {
    36  	req := h.newNetlinkRequest(unix.RTM_DELRULE, unix.NLM_F_ACK)
    37  	return ruleHandle(rule, req)
    38  }
    39  
    40  func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
    41  	msg := nl.NewRtMsg()
    42  	msg.Family = unix.AF_INET
    43  	msg.Protocol = unix.RTPROT_BOOT
    44  	msg.Scope = unix.RT_SCOPE_UNIVERSE
    45  	msg.Table = unix.RT_TABLE_UNSPEC
    46  	msg.Type = rule.Type // usually 0, same as unix.RTN_UNSPEC
    47  	if msg.Type == 0 && req.NlMsghdr.Flags&unix.NLM_F_CREATE > 0 {
    48  		msg.Type = unix.RTN_UNICAST
    49  	}
    50  	if rule.Invert {
    51  		msg.Flags |= FibRuleInvert
    52  	}
    53  	if rule.Family != 0 {
    54  		msg.Family = uint8(rule.Family)
    55  	}
    56  	if rule.Table >= 0 && rule.Table < 256 {
    57  		msg.Table = uint8(rule.Table)
    58  	}
    59  	if rule.Tos != 0 {
    60  		msg.Tos = uint8(rule.Tos)
    61  	}
    62  
    63  	var dstFamily uint8
    64  	var rtAttrs []*nl.RtAttr
    65  	if rule.Dst != nil && rule.Dst.IP != nil {
    66  		dstLen, _ := rule.Dst.Mask.Size()
    67  		msg.Dst_len = uint8(dstLen)
    68  		msg.Family = uint8(nl.GetIPFamily(rule.Dst.IP))
    69  		dstFamily = msg.Family
    70  		var dstData []byte
    71  		if msg.Family == unix.AF_INET {
    72  			dstData = rule.Dst.IP.To4()
    73  		} else {
    74  			dstData = rule.Dst.IP.To16()
    75  		}
    76  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, dstData))
    77  	}
    78  
    79  	if rule.Src != nil && rule.Src.IP != nil {
    80  		msg.Family = uint8(nl.GetIPFamily(rule.Src.IP))
    81  		if dstFamily != 0 && dstFamily != msg.Family {
    82  			return fmt.Errorf("source and destination ip are not the same IP family")
    83  		}
    84  		srcLen, _ := rule.Src.Mask.Size()
    85  		msg.Src_len = uint8(srcLen)
    86  		var srcData []byte
    87  		if msg.Family == unix.AF_INET {
    88  			srcData = rule.Src.IP.To4()
    89  		} else {
    90  			srcData = rule.Src.IP.To16()
    91  		}
    92  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_SRC, srcData))
    93  	}
    94  
    95  	req.AddData(msg)
    96  	for i := range rtAttrs {
    97  		req.AddData(rtAttrs[i])
    98  	}
    99  
   100  	if rule.Priority >= 0 {
   101  		b := make([]byte, 4)
   102  		native.PutUint32(b, uint32(rule.Priority))
   103  		req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b))
   104  	}
   105  	if rule.Mark != 0 || rule.Mask != nil {
   106  		b := make([]byte, 4)
   107  		native.PutUint32(b, rule.Mark)
   108  		req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b))
   109  	}
   110  	if rule.Mask != nil {
   111  		b := make([]byte, 4)
   112  		native.PutUint32(b, *rule.Mask)
   113  		req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b))
   114  	}
   115  	if rule.Flow >= 0 {
   116  		b := make([]byte, 4)
   117  		native.PutUint32(b, uint32(rule.Flow))
   118  		req.AddData(nl.NewRtAttr(nl.FRA_FLOW, b))
   119  	}
   120  	if rule.TunID > 0 {
   121  		b := make([]byte, 4)
   122  		native.PutUint32(b, uint32(rule.TunID))
   123  		req.AddData(nl.NewRtAttr(nl.FRA_TUN_ID, b))
   124  	}
   125  	if rule.Table >= 256 {
   126  		b := make([]byte, 4)
   127  		native.PutUint32(b, uint32(rule.Table))
   128  		req.AddData(nl.NewRtAttr(nl.FRA_TABLE, b))
   129  	}
   130  	if msg.Table > 0 {
   131  		if rule.SuppressPrefixlen >= 0 {
   132  			b := make([]byte, 4)
   133  			native.PutUint32(b, uint32(rule.SuppressPrefixlen))
   134  			req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_PREFIXLEN, b))
   135  		}
   136  		if rule.SuppressIfgroup >= 0 {
   137  			b := make([]byte, 4)
   138  			native.PutUint32(b, uint32(rule.SuppressIfgroup))
   139  			req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_IFGROUP, b))
   140  		}
   141  	}
   142  	if rule.IifName != "" {
   143  		req.AddData(nl.NewRtAttr(nl.FRA_IIFNAME, []byte(rule.IifName+"\x00")))
   144  	}
   145  	if rule.OifName != "" {
   146  		req.AddData(nl.NewRtAttr(nl.FRA_OIFNAME, []byte(rule.OifName+"\x00")))
   147  	}
   148  	if rule.Goto >= 0 {
   149  		msg.Type = nl.FR_ACT_GOTO
   150  		b := make([]byte, 4)
   151  		native.PutUint32(b, uint32(rule.Goto))
   152  		req.AddData(nl.NewRtAttr(nl.FRA_GOTO, b))
   153  	}
   154  
   155  	if rule.IPProto > 0 {
   156  		b := make([]byte, 4)
   157  		native.PutUint32(b, uint32(rule.IPProto))
   158  		req.AddData(nl.NewRtAttr(nl.FRA_IP_PROTO, b))
   159  	}
   160  
   161  	if rule.Dport != nil {
   162  		b := rule.Dport.toRtAttrData()
   163  		req.AddData(nl.NewRtAttr(nl.FRA_DPORT_RANGE, b))
   164  	}
   165  
   166  	if rule.Sport != nil {
   167  		b := rule.Sport.toRtAttrData()
   168  		req.AddData(nl.NewRtAttr(nl.FRA_SPORT_RANGE, b))
   169  	}
   170  
   171  	if rule.UIDRange != nil {
   172  		b := rule.UIDRange.toRtAttrData()
   173  		req.AddData(nl.NewRtAttr(nl.FRA_UID_RANGE, b))
   174  	}
   175  
   176  	if rule.Protocol > 0 {
   177  		req.AddData(nl.NewRtAttr(nl.FRA_PROTOCOL, nl.Uint8Attr(rule.Protocol)))
   178  	}
   179  
   180  	_, err := req.Execute(unix.NETLINK_ROUTE, 0)
   181  	return err
   182  }
   183  
   184  // RuleList lists rules in the system.
   185  // Equivalent to: ip rule list
   186  func RuleList(family int) ([]Rule, error) {
   187  	return pkgHandle.RuleList(family)
   188  }
   189  
   190  // RuleList lists rules in the system.
   191  // Equivalent to: ip rule list
   192  func (h *Handle) RuleList(family int) ([]Rule, error) {
   193  	return h.RuleListFiltered(family, nil, 0)
   194  }
   195  
   196  // RuleListFiltered gets a list of rules in the system filtered by the
   197  // specified rule template `filter`.
   198  // Equivalent to: ip rule list
   199  func RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) {
   200  	return pkgHandle.RuleListFiltered(family, filter, filterMask)
   201  }
   202  
   203  // RuleListFiltered lists rules in the system.
   204  // Equivalent to: ip rule list
   205  func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) {
   206  	req := h.newNetlinkRequest(unix.RTM_GETRULE, unix.NLM_F_DUMP|unix.NLM_F_REQUEST)
   207  	msg := nl.NewIfInfomsg(family)
   208  	req.AddData(msg)
   209  
   210  	msgs, err := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWRULE)
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  
   215  	var res = make([]Rule, 0)
   216  	for i := range msgs {
   217  		msg := nl.DeserializeRtMsg(msgs[i])
   218  		attrs, err := nl.ParseRouteAttr(msgs[i][msg.Len():])
   219  		if err != nil {
   220  			return nil, err
   221  		}
   222  
   223  		rule := NewRule()
   224  		rule.Priority = 0 // The default priority from kernel
   225  
   226  		rule.Invert = msg.Flags&FibRuleInvert > 0
   227  		rule.Family = int(msg.Family)
   228  		rule.Tos = uint(msg.Tos)
   229  
   230  		for j := range attrs {
   231  			switch attrs[j].Attr.Type {
   232  			case unix.RTA_TABLE:
   233  				rule.Table = int(native.Uint32(attrs[j].Value[0:4]))
   234  			case nl.FRA_SRC:
   235  				rule.Src = &net.IPNet{
   236  					IP:   attrs[j].Value,
   237  					Mask: net.CIDRMask(int(msg.Src_len), 8*len(attrs[j].Value)),
   238  				}
   239  			case nl.FRA_DST:
   240  				rule.Dst = &net.IPNet{
   241  					IP:   attrs[j].Value,
   242  					Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)),
   243  				}
   244  			case nl.FRA_FWMARK:
   245  				rule.Mark = native.Uint32(attrs[j].Value[0:4])
   246  			case nl.FRA_FWMASK:
   247  				mask := native.Uint32(attrs[j].Value[0:4])
   248  				rule.Mask = &mask
   249  			case nl.FRA_TUN_ID:
   250  				rule.TunID = uint(native.Uint64(attrs[j].Value[0:8]))
   251  			case nl.FRA_IIFNAME:
   252  				rule.IifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
   253  			case nl.FRA_OIFNAME:
   254  				rule.OifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
   255  			case nl.FRA_SUPPRESS_PREFIXLEN:
   256  				i := native.Uint32(attrs[j].Value[0:4])
   257  				if i != 0xffffffff {
   258  					rule.SuppressPrefixlen = int(i)
   259  				}
   260  			case nl.FRA_SUPPRESS_IFGROUP:
   261  				i := native.Uint32(attrs[j].Value[0:4])
   262  				if i != 0xffffffff {
   263  					rule.SuppressIfgroup = int(i)
   264  				}
   265  			case nl.FRA_FLOW:
   266  				rule.Flow = int(native.Uint32(attrs[j].Value[0:4]))
   267  			case nl.FRA_GOTO:
   268  				rule.Goto = int(native.Uint32(attrs[j].Value[0:4]))
   269  			case nl.FRA_PRIORITY:
   270  				rule.Priority = int(native.Uint32(attrs[j].Value[0:4]))
   271  			case nl.FRA_IP_PROTO:
   272  				rule.IPProto = int(native.Uint32(attrs[j].Value[0:4]))
   273  			case nl.FRA_DPORT_RANGE:
   274  				rule.Dport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
   275  			case nl.FRA_SPORT_RANGE:
   276  				rule.Sport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
   277  			case nl.FRA_UID_RANGE:
   278  				rule.UIDRange = NewRuleUIDRange(native.Uint32(attrs[j].Value[0:4]), native.Uint32(attrs[j].Value[4:8]))
   279  			case nl.FRA_PROTOCOL:
   280  				rule.Protocol = uint8(attrs[j].Value[0])
   281  			}
   282  		}
   283  
   284  		if filter != nil {
   285  			switch {
   286  			case filterMask&RT_FILTER_SRC != 0 &&
   287  				(rule.Src == nil || rule.Src.String() != filter.Src.String()):
   288  				continue
   289  			case filterMask&RT_FILTER_DST != 0 &&
   290  				(rule.Dst == nil || rule.Dst.String() != filter.Dst.String()):
   291  				continue
   292  			case filterMask&RT_FILTER_TABLE != 0 &&
   293  				filter.Table != unix.RT_TABLE_UNSPEC && rule.Table != filter.Table:
   294  				continue
   295  			case filterMask&RT_FILTER_TOS != 0 && rule.Tos != filter.Tos:
   296  				continue
   297  			case filterMask&RT_FILTER_PRIORITY != 0 && rule.Priority != filter.Priority:
   298  				continue
   299  			case filterMask&RT_FILTER_MARK != 0 && rule.Mark != filter.Mark:
   300  				continue
   301  			case filterMask&RT_FILTER_MASK != 0 && !ptrEqual(rule.Mask, filter.Mask):
   302  				continue
   303  			}
   304  		}
   305  
   306  		res = append(res, *rule)
   307  	}
   308  
   309  	return res, nil
   310  }
   311  
   312  func (pr *RulePortRange) toRtAttrData() []byte {
   313  	b := [][]byte{make([]byte, 2), make([]byte, 2)}
   314  	native.PutUint16(b[0], pr.Start)
   315  	native.PutUint16(b[1], pr.End)
   316  	return bytes.Join(b, []byte{})
   317  }
   318  
   319  func (pr *RuleUIDRange) toRtAttrData() []byte {
   320  	b := [][]byte{make([]byte, 4), make([]byte, 4)}
   321  	native.PutUint32(b[0], pr.Start)
   322  	native.PutUint32(b[1], pr.End)
   323  	return bytes.Join(b, []byte{})
   324  }
   325  
   // ptrEqual reports whether two optional uint32 values are equal: either both
// are nil, or both are non-nil and point at equal values.
func ptrEqual(a, b *uint32) bool {
	if a == nil || b == nil {
		// At least one is absent: equal only if both are.
		return a == b
	}
	return *a == *b
}
   335  
   336  func (r Rule) typeString() string {
   337  	switch r.Type {
   338  	case unix.RTN_UNSPEC: // zero
   339  		return ""
   340  	case unix.RTN_UNICAST:
   341  		return ""
   342  	case unix.RTN_LOCAL:
   343  		return "local"
   344  	case unix.RTN_BROADCAST:
   345  		return "broadcast"
   346  	case unix.RTN_ANYCAST:
   347  		return "anycast"
   348  	case unix.RTN_MULTICAST:
   349  		return "multicast"
   350  	case unix.RTN_BLACKHOLE:
   351  		return "blackhole"
   352  	case unix.RTN_UNREACHABLE:
   353  		return "unreachable"
   354  	case unix.RTN_PROHIBIT:
   355  		return "prohibit"
   356  	case unix.RTN_THROW:
   357  		return "throw"
   358  	case unix.RTN_NAT:
   359  		return "nat"
   360  	case unix.RTN_XRESOLVE:
   361  		return "xresolve"
   362  	default:
   363  		return fmt.Sprintf("type(0x%x)", r.Type)
   364  	}
   365  }