github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/rule_linux.go (about)

     1  package netlink
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"net/netip"
     7  
     8  	"github.com/sagernet/netlink/nl"
     9  	"golang.org/x/sys/unix"
    10  )
    11  
    12  const FibRuleInvert = 0x2
    13  
    14  // RuleAdd adds a rule to the system.
    15  // Equivalent to: ip rule add
    16  func RuleAdd(rule *Rule) error {
    17  	return pkgHandle.RuleAdd(rule)
    18  }
    19  
    20  // RuleAdd adds a rule to the system.
    21  // Equivalent to: ip rule add
    22  func (h *Handle) RuleAdd(rule *Rule) error {
    23  	req := h.newNetlinkRequest(unix.RTM_NEWRULE, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
    24  	return ruleHandle(rule, req)
    25  }
    26  
    27  // RuleDel deletes a rule from the system.
    28  // Equivalent to: ip rule del
    29  func RuleDel(rule *Rule) error {
    30  	return pkgHandle.RuleDel(rule)
    31  }
    32  
    33  // RuleDel deletes a rule from the system.
    34  // Equivalent to: ip rule del
    35  func (h *Handle) RuleDel(rule *Rule) error {
    36  	req := h.newNetlinkRequest(unix.RTM_DELRULE, unix.NLM_F_ACK)
    37  	return ruleHandle(rule, req)
    38  }
    39  
    40  func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
    41  	msg := nl.NewRtMsg()
    42  	msg.Family = unix.AF_UNSPEC
    43  	msg.Protocol = unix.RTPROT_BOOT
    44  	msg.Scope = unix.RT_SCOPE_UNIVERSE
    45  	msg.Table = unix.RT_TABLE_UNSPEC
    46  	msg.Type = unix.RTN_UNSPEC
    47  	if rule.Type > 0 {
    48  		msg.Type = rule.Type
    49  	} else if rule.Table >= 0 {
    50  		msg.Type = unix.FR_ACT_TO_TBL
    51  	} else if rule.Goto >= 0 {
    52  		msg.Type = unix.FR_ACT_GOTO
    53  	} else if req.NlMsghdr.Flags&unix.NLM_F_CREATE > 0 {
    54  		msg.Type = unix.FR_ACT_NOP
    55  	}
    56  	if rule.Invert {
    57  		msg.Flags |= FibRuleInvert
    58  	}
    59  	if rule.Family != 0 {
    60  		msg.Family = uint8(rule.Family)
    61  	}
    62  	if rule.Table >= 0 && rule.Table < 256 {
    63  		msg.Table = uint8(rule.Table)
    64  	}
    65  	if rule.Tos != 0 {
    66  		msg.Tos = uint8(rule.Tos)
    67  	}
    68  
    69  	var dstFamily uint8
    70  	var rtAttrs []*nl.RtAttr
    71  
    72  	if rule.Dst.IsValid() {
    73  		msg.Dst_len = uint8(rule.Dst.Bits())
    74  		msg.Family = uint8(nl.GetIPFamily(rule.Dst.Addr().AsSlice()))
    75  		dstFamily = msg.Family
    76  		var dstData []byte
    77  		if msg.Family == unix.AF_INET {
    78  			dstData = netip.AddrFrom4(rule.Dst.Addr().As4()).AsSlice()
    79  		} else {
    80  			dstData = rule.Dst.Addr().AsSlice()
    81  		}
    82  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, dstData))
    83  	}
    84  
    85  	if rule.Src.IsValid() {
    86  		msg.Src_len = uint8(rule.Src.Bits())
    87  		msg.Family = uint8(nl.GetIPFamily(rule.Src.Addr().AsSlice()))
    88  		if dstFamily != 0 && dstFamily != msg.Family {
    89  			return fmt.Errorf("source and destination ip are not the same IP family")
    90  		}
    91  		var srcData []byte
    92  		if msg.Family == unix.AF_INET {
    93  			srcData = netip.AddrFrom4(rule.Src.Addr().As4()).AsSlice()
    94  		} else {
    95  			srcData = rule.Src.Addr().AsSlice()
    96  		}
    97  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_SRC, srcData))
    98  	}
    99  
   100  	req.AddData(msg)
   101  	for i := range rtAttrs {
   102  		req.AddData(rtAttrs[i])
   103  	}
   104  
   105  	if rule.Priority >= 0 {
   106  		b := make([]byte, 4)
   107  		native.PutUint32(b, uint32(rule.Priority))
   108  		req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b))
   109  	}
   110  	if rule.MarkSet {
   111  		b := make([]byte, 4)
   112  		native.PutUint32(b, rule.Mark)
   113  		req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b))
   114  	}
   115  	if rule.Mask >= 0 {
   116  		b := make([]byte, 4)
   117  		native.PutUint32(b, uint32(rule.Mask))
   118  		req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b))
   119  	}
   120  	if rule.Flow >= 0 {
   121  		b := make([]byte, 4)
   122  		native.PutUint32(b, uint32(rule.Flow))
   123  		req.AddData(nl.NewRtAttr(nl.FRA_FLOW, b))
   124  	}
   125  	if rule.TunID > 0 {
   126  		b := make([]byte, 4)
   127  		native.PutUint32(b, uint32(rule.TunID))
   128  		req.AddData(nl.NewRtAttr(nl.FRA_TUN_ID, b))
   129  	}
   130  	if rule.Table >= 0 {
   131  		b := make([]byte, 4)
   132  		native.PutUint32(b, uint32(rule.Table))
   133  		req.AddData(nl.NewRtAttr(nl.FRA_TABLE, b))
   134  		if rule.SuppressPrefixlen >= 0 {
   135  			b := make([]byte, 4)
   136  			native.PutUint32(b, uint32(rule.SuppressPrefixlen))
   137  			req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_PREFIXLEN, b))
   138  		}
   139  		if rule.SuppressIfgroup >= 0 {
   140  			b := make([]byte, 4)
   141  			native.PutUint32(b, uint32(rule.SuppressIfgroup))
   142  			req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_IFGROUP, b))
   143  		}
   144  	}
   145  	if rule.IifName != "" {
   146  		req.AddData(nl.NewRtAttr(nl.FRA_IIFNAME, []byte(rule.IifName+"\x00")))
   147  	}
   148  	if rule.OifName != "" {
   149  		req.AddData(nl.NewRtAttr(nl.FRA_OIFNAME, []byte(rule.OifName+"\x00")))
   150  	}
   151  	if rule.Goto >= 0 {
   152  		b := make([]byte, 4)
   153  		native.PutUint32(b, uint32(rule.Goto))
   154  		req.AddData(nl.NewRtAttr(nl.FRA_GOTO, b))
   155  	}
   156  
   157  	if rule.IPProto > 0 {
   158  		b := make([]byte, 4)
   159  		native.PutUint32(b, uint32(rule.IPProto))
   160  		req.AddData(nl.NewRtAttr(nl.FRA_IP_PROTO, b))
   161  	}
   162  
   163  	if rule.Dport != nil {
   164  		b := rule.Dport.toRtAttrData()
   165  		req.AddData(nl.NewRtAttr(nl.FRA_DPORT_RANGE, b))
   166  	}
   167  
   168  	if rule.Sport != nil {
   169  		b := rule.Sport.toRtAttrData()
   170  		req.AddData(nl.NewRtAttr(nl.FRA_SPORT_RANGE, b))
   171  	}
   172  
   173  	if rule.UIDRange != nil {
   174  		b := rule.UIDRange.toRtAttrData()
   175  		req.AddData(nl.NewRtAttr(nl.FRA_UID_RANGE, b))
   176  	}
   177  
   178  	_, err := req.Execute(unix.NETLINK_ROUTE, 0)
   179  	return err
   180  }
   181  
   182  // RuleList lists rules in the system.
   183  // Equivalent to: ip rule list
   184  func RuleList(family int) ([]Rule, error) {
   185  	return pkgHandle.RuleList(family)
   186  }
   187  
   188  // RuleList lists rules in the system.
   189  // Equivalent to: ip rule list
   190  func (h *Handle) RuleList(family int) ([]Rule, error) {
   191  	return h.RuleListFiltered(family, nil, 0)
   192  }
   193  
   194  // RuleListFiltered gets a list of rules in the system filtered by the
   195  // specified rule template `filter`.
   196  // Equivalent to: ip rule list
   197  func RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) {
   198  	return pkgHandle.RuleListFiltered(family, filter, filterMask)
   199  }
   200  
   201  // RuleListFiltered lists rules in the system.
   202  // Equivalent to: ip rule list
   203  func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) {
   204  	req := h.newNetlinkRequest(unix.RTM_GETRULE, unix.NLM_F_DUMP|unix.NLM_F_REQUEST)
   205  	msg := nl.NewIfInfomsg(family)
   206  	req.AddData(msg)
   207  
   208  	msgs, err := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWRULE)
   209  	if err != nil {
   210  		return nil, err
   211  	}
   212  
   213  	var res = make([]Rule, 0)
   214  	for i := range msgs {
   215  		msg := nl.DeserializeRtMsg(msgs[i])
   216  		attrs, err := nl.ParseRouteAttr(msgs[i][msg.Len():])
   217  		if err != nil {
   218  			return nil, err
   219  		}
   220  
   221  		rule := NewRule()
   222  		rule.Family = int(msg.Family)
   223  		rule.Invert = msg.Flags&FibRuleInvert > 0
   224  		rule.Tos = uint(msg.Tos)
   225  
   226  		for j := range attrs {
   227  			switch attrs[j].Attr.Type {
   228  			case unix.RTA_TABLE:
   229  				rule.Table = int(native.Uint32(attrs[j].Value[0:4]))
   230  			case nl.FRA_SRC:
   231  				addr, _ := netip.AddrFromSlice(attrs[j].Value)
   232  				if addr.Is4In6() {
   233  					addr = netip.AddrFrom4(addr.As4())
   234  				}
   235  				rule.Src = netip.PrefixFrom(addr, int(msg.Src_len))
   236  			case nl.FRA_DST:
   237  				addr, _ := netip.AddrFromSlice(attrs[j].Value)
   238  				if addr.Is4In6() {
   239  					addr = netip.AddrFrom4(addr.As4())
   240  				}
   241  				rule.Dst = netip.PrefixFrom(addr, int(msg.Dst_len))
   242  			case nl.FRA_FWMARK:
   243  				rule.Mark = native.Uint32(attrs[j].Value[0:4])
   244  			case nl.FRA_FWMASK:
   245  				rule.Mask = int(native.Uint32(attrs[j].Value[0:4]))
   246  			case nl.FRA_TUN_ID:
   247  				rule.TunID = uint(native.Uint64(attrs[j].Value[0:8]))
   248  			case nl.FRA_IIFNAME:
   249  				rule.IifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
   250  			case nl.FRA_OIFNAME:
   251  				rule.OifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
   252  			case nl.FRA_SUPPRESS_PREFIXLEN:
   253  				i := native.Uint32(attrs[j].Value[0:4])
   254  				if i != 0xffffffff {
   255  					rule.SuppressPrefixlen = int(i)
   256  				}
   257  			case nl.FRA_SUPPRESS_IFGROUP:
   258  				i := native.Uint32(attrs[j].Value[0:4])
   259  				if i != 0xffffffff {
   260  					rule.SuppressIfgroup = int(i)
   261  				}
   262  			case nl.FRA_FLOW:
   263  				rule.Flow = int(native.Uint32(attrs[j].Value[0:4]))
   264  			case nl.FRA_GOTO:
   265  				rule.Goto = int(native.Uint32(attrs[j].Value[0:4]))
   266  			case nl.FRA_PRIORITY:
   267  				rule.Priority = int(native.Uint32(attrs[j].Value[0:4]))
   268  			case nl.FRA_IP_PROTO:
   269  				rule.IPProto = int(native.Uint32(attrs[j].Value[0:4]))
   270  			case nl.FRA_DPORT_RANGE:
   271  				rule.Dport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
   272  			case nl.FRA_SPORT_RANGE:
   273  				rule.Sport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
   274  			case nl.FRA_UID_RANGE:
   275  				rule.UIDRange = NewRuleUIDRange(native.Uint32(attrs[j].Value[0:4]), native.Uint32(attrs[j].Value[4:8]))
   276  			}
   277  		}
   278  
   279  		if filter != nil {
   280  			switch {
   281  			case filterMask&RT_FILTER_SRC != 0 &&
   282  				(rule.Src.IsValid() || rule.Src.String() != filter.Src.String()):
   283  				continue
   284  			case filterMask&RT_FILTER_DST != 0 &&
   285  				(rule.Dst.IsValid() || rule.Dst.String() != filter.Dst.String()):
   286  				continue
   287  			case filterMask&RT_FILTER_TABLE != 0 &&
   288  				filter.Table != unix.RT_TABLE_UNSPEC && rule.Table != filter.Table:
   289  				continue
   290  			case filterMask&RT_FILTER_TOS != 0 && rule.Tos != filter.Tos:
   291  				continue
   292  			case filterMask&RT_FILTER_PRIORITY != 0 && rule.Priority != filter.Priority:
   293  				continue
   294  			case filterMask&RT_FILTER_MARK != 0 && rule.Mark != filter.Mark:
   295  				continue
   296  			case filterMask&RT_FILTER_MASK != 0 && rule.Mask != filter.Mask:
   297  				continue
   298  			}
   299  		}
   300  
   301  		res = append(res, *rule)
   302  	}
   303  
   304  	return res, nil
   305  }
   306  
   307  func (pr *RulePortRange) toRtAttrData() []byte {
   308  	b := [][]byte{make([]byte, 2), make([]byte, 2)}
   309  	native.PutUint16(b[0], pr.Start)
   310  	native.PutUint16(b[1], pr.End)
   311  	return bytes.Join(b, []byte{})
   312  }
   313  
   314  func (pr *RuleUIDRange) toRtAttrData() []byte {
   315  	b := [][]byte{make([]byte, 4), make([]byte, 4)}
   316  	native.PutUint32(b[0], pr.Start)
   317  	native.PutUint32(b[1], pr.End)
   318  	return bytes.Join(b, []byte{})
   319  }