github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/xfrm_policy_linux.go (about)

     1  package netlink
     2  
     3  import (
     4  	"github.com/sagernet/netlink/nl"
     5  	"golang.org/x/sys/unix"
     6  )
     7  
     8  func selFromPolicy(sel *nl.XfrmSelector, policy *XfrmPolicy) {
     9  	sel.Family = uint16(nl.FAMILY_V4)
    10  	if policy.Dst != nil {
    11  		sel.Family = uint16(nl.GetIPFamily(policy.Dst.IP))
    12  		sel.Daddr.FromIP(policy.Dst.IP)
    13  		prefixlenD, _ := policy.Dst.Mask.Size()
    14  		sel.PrefixlenD = uint8(prefixlenD)
    15  	}
    16  	if policy.Src != nil {
    17  		sel.Saddr.FromIP(policy.Src.IP)
    18  		prefixlenS, _ := policy.Src.Mask.Size()
    19  		sel.PrefixlenS = uint8(prefixlenS)
    20  	}
    21  	sel.Proto = uint8(policy.Proto)
    22  	sel.Dport = nl.Swap16(uint16(policy.DstPort))
    23  	sel.Sport = nl.Swap16(uint16(policy.SrcPort))
    24  	if sel.Dport != 0 {
    25  		sel.DportMask = ^uint16(0)
    26  	}
    27  	if sel.Sport != 0 {
    28  		sel.SportMask = ^uint16(0)
    29  	}
    30  	sel.Ifindex = int32(policy.Ifindex)
    31  }
    32  
    33  // XfrmPolicyAdd will add an xfrm policy to the system.
    34  // Equivalent to: `ip xfrm policy add $policy`
    35  func XfrmPolicyAdd(policy *XfrmPolicy) error {
    36  	return pkgHandle.XfrmPolicyAdd(policy)
    37  }
    38  
    39  // XfrmPolicyAdd will add an xfrm policy to the system.
    40  // Equivalent to: `ip xfrm policy add $policy`
    41  func (h *Handle) XfrmPolicyAdd(policy *XfrmPolicy) error {
    42  	return h.xfrmPolicyAddOrUpdate(policy, nl.XFRM_MSG_NEWPOLICY)
    43  }
    44  
    45  // XfrmPolicyUpdate will update an xfrm policy to the system.
    46  // Equivalent to: `ip xfrm policy update $policy`
    47  func XfrmPolicyUpdate(policy *XfrmPolicy) error {
    48  	return pkgHandle.XfrmPolicyUpdate(policy)
    49  }
    50  
    51  // XfrmPolicyUpdate will update an xfrm policy to the system.
    52  // Equivalent to: `ip xfrm policy update $policy`
    53  func (h *Handle) XfrmPolicyUpdate(policy *XfrmPolicy) error {
    54  	return h.xfrmPolicyAddOrUpdate(policy, nl.XFRM_MSG_UPDPOLICY)
    55  }
    56  
    57  func (h *Handle) xfrmPolicyAddOrUpdate(policy *XfrmPolicy, nlProto int) error {
    58  	req := h.newNetlinkRequest(nlProto, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
    59  
    60  	msg := &nl.XfrmUserpolicyInfo{}
    61  	selFromPolicy(&msg.Sel, policy)
    62  	msg.Priority = uint32(policy.Priority)
    63  	msg.Index = uint32(policy.Index)
    64  	msg.Dir = uint8(policy.Dir)
    65  	msg.Action = uint8(policy.Action)
    66  	msg.Lft.SoftByteLimit = nl.XFRM_INF
    67  	msg.Lft.HardByteLimit = nl.XFRM_INF
    68  	msg.Lft.SoftPacketLimit = nl.XFRM_INF
    69  	msg.Lft.HardPacketLimit = nl.XFRM_INF
    70  	req.AddData(msg)
    71  
    72  	tmplData := make([]byte, nl.SizeofXfrmUserTmpl*len(policy.Tmpls))
    73  	for i, tmpl := range policy.Tmpls {
    74  		start := i * nl.SizeofXfrmUserTmpl
    75  		userTmpl := nl.DeserializeXfrmUserTmpl(tmplData[start : start+nl.SizeofXfrmUserTmpl])
    76  		userTmpl.XfrmId.Daddr.FromIP(tmpl.Dst)
    77  		userTmpl.Saddr.FromIP(tmpl.Src)
    78  		userTmpl.XfrmId.Proto = uint8(tmpl.Proto)
    79  		userTmpl.XfrmId.Spi = nl.Swap32(uint32(tmpl.Spi))
    80  		userTmpl.Mode = uint8(tmpl.Mode)
    81  		userTmpl.Reqid = uint32(tmpl.Reqid)
    82  		userTmpl.Optional = uint8(tmpl.Optional)
    83  		userTmpl.Aalgos = ^uint32(0)
    84  		userTmpl.Ealgos = ^uint32(0)
    85  		userTmpl.Calgos = ^uint32(0)
    86  	}
    87  	if len(tmplData) > 0 {
    88  		tmpls := nl.NewRtAttr(nl.XFRMA_TMPL, tmplData)
    89  		req.AddData(tmpls)
    90  	}
    91  	if policy.Mark != nil {
    92  		out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(policy.Mark))
    93  		req.AddData(out)
    94  	}
    95  
    96  	if policy.Ifid != 0 {
    97  		ifId := nl.NewRtAttr(nl.XFRMA_IF_ID, nl.Uint32Attr(uint32(policy.Ifid)))
    98  		req.AddData(ifId)
    99  	}
   100  
   101  	_, err := req.Execute(unix.NETLINK_XFRM, 0)
   102  	return err
   103  }
   104  
   105  // XfrmPolicyDel will delete an xfrm policy from the system. Note that
   106  // the Tmpls are ignored when matching the policy to delete.
   107  // Equivalent to: `ip xfrm policy del $policy`
   108  func XfrmPolicyDel(policy *XfrmPolicy) error {
   109  	return pkgHandle.XfrmPolicyDel(policy)
   110  }
   111  
   112  // XfrmPolicyDel will delete an xfrm policy from the system. Note that
   113  // the Tmpls are ignored when matching the policy to delete.
   114  // Equivalent to: `ip xfrm policy del $policy`
   115  func (h *Handle) XfrmPolicyDel(policy *XfrmPolicy) error {
   116  	_, err := h.xfrmPolicyGetOrDelete(policy, nl.XFRM_MSG_DELPOLICY)
   117  	return err
   118  }
   119  
   120  // XfrmPolicyList gets a list of xfrm policies in the system.
   121  // Equivalent to: `ip xfrm policy show`.
   122  // The list can be filtered by ip family.
   123  func XfrmPolicyList(family int) ([]XfrmPolicy, error) {
   124  	return pkgHandle.XfrmPolicyList(family)
   125  }
   126  
   127  // XfrmPolicyList gets a list of xfrm policies in the system.
   128  // Equivalent to: `ip xfrm policy show`.
   129  // The list can be filtered by ip family.
   130  func (h *Handle) XfrmPolicyList(family int) ([]XfrmPolicy, error) {
   131  	req := h.newNetlinkRequest(nl.XFRM_MSG_GETPOLICY, unix.NLM_F_DUMP)
   132  
   133  	msg := nl.NewIfInfomsg(family)
   134  	req.AddData(msg)
   135  
   136  	msgs, err := req.Execute(unix.NETLINK_XFRM, nl.XFRM_MSG_NEWPOLICY)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  
   141  	var res []XfrmPolicy
   142  	for _, m := range msgs {
   143  		if policy, err := parseXfrmPolicy(m, family); err == nil {
   144  			res = append(res, *policy)
   145  		} else if err == familyError {
   146  			continue
   147  		} else {
   148  			return nil, err
   149  		}
   150  	}
   151  	return res, nil
   152  }
   153  
   154  // XfrmPolicyGet gets a the policy described by the index or selector, if found.
   155  // Equivalent to: `ip xfrm policy get { SELECTOR | index INDEX } dir DIR [ctx CTX ] [ mark MARK [ mask MASK ] ] [ ptype PTYPE ]`.
   156  func XfrmPolicyGet(policy *XfrmPolicy) (*XfrmPolicy, error) {
   157  	return pkgHandle.XfrmPolicyGet(policy)
   158  }
   159  
   160  // XfrmPolicyGet gets a the policy described by the index or selector, if found.
   161  // Equivalent to: `ip xfrm policy get { SELECTOR | index INDEX } dir DIR [ctx CTX ] [ mark MARK [ mask MASK ] ] [ ptype PTYPE ]`.
   162  func (h *Handle) XfrmPolicyGet(policy *XfrmPolicy) (*XfrmPolicy, error) {
   163  	return h.xfrmPolicyGetOrDelete(policy, nl.XFRM_MSG_GETPOLICY)
   164  }
   165  
   166  // XfrmPolicyFlush will flush the policies on the system.
   167  // Equivalent to: `ip xfrm policy flush`
   168  func XfrmPolicyFlush() error {
   169  	return pkgHandle.XfrmPolicyFlush()
   170  }
   171  
   172  // XfrmPolicyFlush will flush the policies on the system.
   173  // Equivalent to: `ip xfrm policy flush`
   174  func (h *Handle) XfrmPolicyFlush() error {
   175  	req := h.newNetlinkRequest(nl.XFRM_MSG_FLUSHPOLICY, unix.NLM_F_ACK)
   176  	_, err := req.Execute(unix.NETLINK_XFRM, 0)
   177  	return err
   178  }
   179  
   180  func (h *Handle) xfrmPolicyGetOrDelete(policy *XfrmPolicy, nlProto int) (*XfrmPolicy, error) {
   181  	req := h.newNetlinkRequest(nlProto, unix.NLM_F_ACK)
   182  
   183  	msg := &nl.XfrmUserpolicyId{}
   184  	selFromPolicy(&msg.Sel, policy)
   185  	msg.Index = uint32(policy.Index)
   186  	msg.Dir = uint8(policy.Dir)
   187  	req.AddData(msg)
   188  
   189  	if policy.Mark != nil {
   190  		out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(policy.Mark))
   191  		req.AddData(out)
   192  	}
   193  
   194  	if policy.Ifid != 0 {
   195  		ifId := nl.NewRtAttr(nl.XFRMA_IF_ID, nl.Uint32Attr(uint32(policy.Ifid)))
   196  		req.AddData(ifId)
   197  	}
   198  
   199  	resType := nl.XFRM_MSG_NEWPOLICY
   200  	if nlProto == nl.XFRM_MSG_DELPOLICY {
   201  		resType = 0
   202  	}
   203  
   204  	msgs, err := req.Execute(unix.NETLINK_XFRM, uint16(resType))
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  
   209  	if nlProto == nl.XFRM_MSG_DELPOLICY {
   210  		return nil, err
   211  	}
   212  
   213  	return parseXfrmPolicy(msgs[0], FAMILY_ALL)
   214  }
   215  
   216  func parseXfrmPolicy(m []byte, family int) (*XfrmPolicy, error) {
   217  	msg := nl.DeserializeXfrmUserpolicyInfo(m)
   218  
   219  	// This is mainly for the policy dump
   220  	if family != FAMILY_ALL && family != int(msg.Sel.Family) {
   221  		return nil, familyError
   222  	}
   223  
   224  	var policy XfrmPolicy
   225  
   226  	policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD)
   227  	policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS)
   228  	policy.Proto = Proto(msg.Sel.Proto)
   229  	policy.DstPort = int(nl.Swap16(msg.Sel.Dport))
   230  	policy.SrcPort = int(nl.Swap16(msg.Sel.Sport))
   231  	policy.Ifindex = int(msg.Sel.Ifindex)
   232  	policy.Priority = int(msg.Priority)
   233  	policy.Index = int(msg.Index)
   234  	policy.Dir = Dir(msg.Dir)
   235  	policy.Action = PolicyAction(msg.Action)
   236  
   237  	attrs, err := nl.ParseRouteAttr(m[msg.Len():])
   238  	if err != nil {
   239  		return nil, err
   240  	}
   241  
   242  	for _, attr := range attrs {
   243  		switch attr.Attr.Type {
   244  		case nl.XFRMA_TMPL:
   245  			max := len(attr.Value)
   246  			for i := 0; i < max; i += nl.SizeofXfrmUserTmpl {
   247  				var resTmpl XfrmPolicyTmpl
   248  				tmpl := nl.DeserializeXfrmUserTmpl(attr.Value[i : i+nl.SizeofXfrmUserTmpl])
   249  				resTmpl.Dst = tmpl.XfrmId.Daddr.ToIP()
   250  				resTmpl.Src = tmpl.Saddr.ToIP()
   251  				resTmpl.Proto = Proto(tmpl.XfrmId.Proto)
   252  				resTmpl.Mode = Mode(tmpl.Mode)
   253  				resTmpl.Spi = int(nl.Swap32(tmpl.XfrmId.Spi))
   254  				resTmpl.Reqid = int(tmpl.Reqid)
   255  				resTmpl.Optional = int(tmpl.Optional)
   256  				policy.Tmpls = append(policy.Tmpls, resTmpl)
   257  			}
   258  		case nl.XFRMA_MARK:
   259  			mark := nl.DeserializeXfrmMark(attr.Value[:])
   260  			policy.Mark = new(XfrmMark)
   261  			policy.Mark.Value = mark.Value
   262  			policy.Mark.Mask = mark.Mask
   263  		case nl.XFRMA_IF_ID:
   264  			policy.Ifid = int(native.Uint32(attr.Value))
   265  		}
   266  	}
   267  
   268  	return &policy, nil
   269  }