github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/xfrm_state_linux.go (about)

     1  package netlink
     2  
     3  import (
     4  	"fmt"
     5  	"unsafe"
     6  
     7  	"github.com/sagernet/netlink/nl"
     8  	"golang.org/x/sys/unix"
     9  )
    10  
    11  func writeStateAlgo(a *XfrmStateAlgo) []byte {
    12  	algo := nl.XfrmAlgo{
    13  		AlgKeyLen: uint32(len(a.Key) * 8),
    14  		AlgKey:    a.Key,
    15  	}
    16  	end := len(a.Name)
    17  	if end > 64 {
    18  		end = 64
    19  	}
    20  	copy(algo.AlgName[:end], a.Name)
    21  	return algo.Serialize()
    22  }
    23  
    24  func writeStateAlgoAuth(a *XfrmStateAlgo) []byte {
    25  	algo := nl.XfrmAlgoAuth{
    26  		AlgKeyLen:   uint32(len(a.Key) * 8),
    27  		AlgTruncLen: uint32(a.TruncateLen),
    28  		AlgKey:      a.Key,
    29  	}
    30  	end := len(a.Name)
    31  	if end > 64 {
    32  		end = 64
    33  	}
    34  	copy(algo.AlgName[:end], a.Name)
    35  	return algo.Serialize()
    36  }
    37  
    38  func writeStateAlgoAead(a *XfrmStateAlgo) []byte {
    39  	algo := nl.XfrmAlgoAEAD{
    40  		AlgKeyLen: uint32(len(a.Key) * 8),
    41  		AlgICVLen: uint32(a.ICVLen),
    42  		AlgKey:    a.Key,
    43  	}
    44  	end := len(a.Name)
    45  	if end > 64 {
    46  		end = 64
    47  	}
    48  	copy(algo.AlgName[:end], a.Name)
    49  	return algo.Serialize()
    50  }
    51  
    52  func writeMark(m *XfrmMark) []byte {
    53  	mark := &nl.XfrmMark{
    54  		Value: m.Value,
    55  		Mask:  m.Mask,
    56  	}
    57  	if mark.Mask == 0 {
    58  		mark.Mask = ^uint32(0)
    59  	}
    60  	return mark.Serialize()
    61  }
    62  
    63  func writeReplayEsn(replayWindow int) []byte {
    64  	replayEsn := &nl.XfrmReplayStateEsn{
    65  		OSeq:         0,
    66  		Seq:          0,
    67  		OSeqHi:       0,
    68  		SeqHi:        0,
    69  		ReplayWindow: uint32(replayWindow),
    70  	}
    71  
    72  	// Linux stores the bitmap to identify the already received sequence packets in blocks of uint32 elements.
    73  	// Therefore bitmap length is the minimum number of uint32 elements needed. The following is a ceiling operation.
    74  	bytesPerElem := int(unsafe.Sizeof(replayEsn.BmpLen)) // Any uint32 variable is good for this
    75  	replayEsn.BmpLen = uint32((replayWindow + (bytesPerElem * 8) - 1) / (bytesPerElem * 8))
    76  
    77  	return replayEsn.Serialize()
    78  }
    79  
    80  func writeReplay(r *XfrmReplayState) []byte {
    81  	return (&nl.XfrmReplayState{
    82  		OSeq:   r.OSeq,
    83  		Seq:    r.Seq,
    84  		BitMap: r.BitMap,
    85  	}).Serialize()
    86  }
    87  
    88  // XfrmStateAdd will add an xfrm state to the system.
    89  // Equivalent to: `ip xfrm state add $state`
    90  func XfrmStateAdd(state *XfrmState) error {
    91  	return pkgHandle.XfrmStateAdd(state)
    92  }
    93  
    94  // XfrmStateAdd will add an xfrm state to the system.
    95  // Equivalent to: `ip xfrm state add $state`
    96  func (h *Handle) XfrmStateAdd(state *XfrmState) error {
    97  	return h.xfrmStateAddOrUpdate(state, nl.XFRM_MSG_NEWSA)
    98  }
    99  
   100  // XfrmStateAllocSpi will allocate an xfrm state in the system.
   101  // Equivalent to: `ip xfrm state allocspi`
   102  func XfrmStateAllocSpi(state *XfrmState) (*XfrmState, error) {
   103  	return pkgHandle.xfrmStateAllocSpi(state)
   104  }
   105  
   106  // XfrmStateUpdate will update an xfrm state to the system.
   107  // Equivalent to: `ip xfrm state update $state`
   108  func XfrmStateUpdate(state *XfrmState) error {
   109  	return pkgHandle.XfrmStateUpdate(state)
   110  }
   111  
   112  // XfrmStateUpdate will update an xfrm state to the system.
   113  // Equivalent to: `ip xfrm state update $state`
   114  func (h *Handle) XfrmStateUpdate(state *XfrmState) error {
   115  	return h.xfrmStateAddOrUpdate(state, nl.XFRM_MSG_UPDSA)
   116  }
   117  
   118  func (h *Handle) xfrmStateAddOrUpdate(state *XfrmState, nlProto int) error {
   119  
   120  	// A state with spi 0 can't be deleted so don't allow it to be set
   121  	if state.Spi == 0 {
   122  		return fmt.Errorf("Spi must be set when adding xfrm state")
   123  	}
   124  	req := h.newNetlinkRequest(nlProto, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
   125  
   126  	msg := xfrmUsersaInfoFromXfrmState(state)
   127  
   128  	if state.ESN {
   129  		if state.ReplayWindow == 0 {
   130  			return fmt.Errorf("ESN flag set without ReplayWindow")
   131  		}
   132  		msg.Flags |= nl.XFRM_STATE_ESN
   133  		msg.ReplayWindow = 0
   134  	}
   135  
   136  	limitsToLft(state.Limits, &msg.Lft)
   137  	req.AddData(msg)
   138  
   139  	if state.Auth != nil {
   140  		out := nl.NewRtAttr(nl.XFRMA_ALG_AUTH_TRUNC, writeStateAlgoAuth(state.Auth))
   141  		req.AddData(out)
   142  	}
   143  	if state.Crypt != nil {
   144  		out := nl.NewRtAttr(nl.XFRMA_ALG_CRYPT, writeStateAlgo(state.Crypt))
   145  		req.AddData(out)
   146  	}
   147  	if state.Aead != nil {
   148  		out := nl.NewRtAttr(nl.XFRMA_ALG_AEAD, writeStateAlgoAead(state.Aead))
   149  		req.AddData(out)
   150  	}
   151  	if state.Encap != nil {
   152  		encapData := make([]byte, nl.SizeofXfrmEncapTmpl)
   153  		encap := nl.DeserializeXfrmEncapTmpl(encapData)
   154  		encap.EncapType = uint16(state.Encap.Type)
   155  		encap.EncapSport = nl.Swap16(uint16(state.Encap.SrcPort))
   156  		encap.EncapDport = nl.Swap16(uint16(state.Encap.DstPort))
   157  		encap.EncapOa.FromIP(state.Encap.OriginalAddress)
   158  		out := nl.NewRtAttr(nl.XFRMA_ENCAP, encapData)
   159  		req.AddData(out)
   160  	}
   161  	if state.Mark != nil {
   162  		out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark))
   163  		req.AddData(out)
   164  	}
   165  	if state.ESN {
   166  		out := nl.NewRtAttr(nl.XFRMA_REPLAY_ESN_VAL, writeReplayEsn(state.ReplayWindow))
   167  		req.AddData(out)
   168  	}
   169  	if state.OutputMark != nil {
   170  		out := nl.NewRtAttr(nl.XFRMA_SET_MARK, nl.Uint32Attr(state.OutputMark.Value))
   171  		req.AddData(out)
   172  		if state.OutputMark.Mask != 0 {
   173  			out = nl.NewRtAttr(nl.XFRMA_SET_MARK_MASK, nl.Uint32Attr(state.OutputMark.Mask))
   174  			req.AddData(out)
   175  		}
   176  	}
   177  	if state.OSeqMayWrap || state.DontEncapDSCP {
   178  		var flags uint32
   179  		if state.DontEncapDSCP {
   180  			flags |= nl.XFRM_SA_XFLAG_DONT_ENCAP_DSCP
   181  		}
   182  		if state.OSeqMayWrap {
   183  			flags |= nl.XFRM_SA_XFLAG_OSEQ_MAY_WRAP
   184  		}
   185  		out := nl.NewRtAttr(nl.XFRMA_SA_EXTRA_FLAGS, nl.Uint32Attr(flags))
   186  		req.AddData(out)
   187  	}
   188  	if state.Replay != nil {
   189  		out := nl.NewRtAttr(nl.XFRMA_REPLAY_VAL, writeReplay(state.Replay))
   190  		req.AddData(out)
   191  	}
   192  
   193  	if state.Ifid != 0 {
   194  		ifId := nl.NewRtAttr(nl.XFRMA_IF_ID, nl.Uint32Attr(uint32(state.Ifid)))
   195  		req.AddData(ifId)
   196  	}
   197  
   198  	_, err := req.Execute(unix.NETLINK_XFRM, 0)
   199  	return err
   200  }
   201  
   202  func (h *Handle) xfrmStateAllocSpi(state *XfrmState) (*XfrmState, error) {
   203  	req := h.newNetlinkRequest(nl.XFRM_MSG_ALLOCSPI,
   204  		unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
   205  
   206  	msg := &nl.XfrmUserSpiInfo{}
   207  	msg.XfrmUsersaInfo = *(xfrmUsersaInfoFromXfrmState(state))
   208  	// 1-255 is reserved by IANA for future use
   209  	msg.Min = 0x100
   210  	msg.Max = 0xffffffff
   211  	req.AddData(msg)
   212  
   213  	if state.Mark != nil {
   214  		out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark))
   215  		req.AddData(out)
   216  	}
   217  
   218  	msgs, err := req.Execute(unix.NETLINK_XFRM, 0)
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  
   223  	return parseXfrmState(msgs[0], FAMILY_ALL)
   224  }
   225  
   226  // XfrmStateDel will delete an xfrm state from the system. Note that
   227  // the Algos are ignored when matching the state to delete.
   228  // Equivalent to: `ip xfrm state del $state`
   229  func XfrmStateDel(state *XfrmState) error {
   230  	return pkgHandle.XfrmStateDel(state)
   231  }
   232  
   233  // XfrmStateDel will delete an xfrm state from the system. Note that
   234  // the Algos are ignored when matching the state to delete.
   235  // Equivalent to: `ip xfrm state del $state`
   236  func (h *Handle) XfrmStateDel(state *XfrmState) error {
   237  	_, err := h.xfrmStateGetOrDelete(state, nl.XFRM_MSG_DELSA)
   238  	return err
   239  }
   240  
   241  // XfrmStateList gets a list of xfrm states in the system.
   242  // Equivalent to: `ip [-4|-6] xfrm state show`.
   243  // The list can be filtered by ip family.
   244  func XfrmStateList(family int) ([]XfrmState, error) {
   245  	return pkgHandle.XfrmStateList(family)
   246  }
   247  
   248  // XfrmStateList gets a list of xfrm states in the system.
   249  // Equivalent to: `ip xfrm state show`.
   250  // The list can be filtered by ip family.
   251  func (h *Handle) XfrmStateList(family int) ([]XfrmState, error) {
   252  	req := h.newNetlinkRequest(nl.XFRM_MSG_GETSA, unix.NLM_F_DUMP)
   253  
   254  	msgs, err := req.Execute(unix.NETLINK_XFRM, nl.XFRM_MSG_NEWSA)
   255  	if err != nil {
   256  		return nil, err
   257  	}
   258  
   259  	var res []XfrmState
   260  	for _, m := range msgs {
   261  		if state, err := parseXfrmState(m, family); err == nil {
   262  			res = append(res, *state)
   263  		} else if err == familyError {
   264  			continue
   265  		} else {
   266  			return nil, err
   267  		}
   268  	}
   269  	return res, nil
   270  }
   271  
   272  // XfrmStateGet gets the xfrm state described by the ID, if found.
   273  // Equivalent to: `ip xfrm state get ID [ mark MARK [ mask MASK ] ]`.
   274  // Only the fields which constitue the SA ID must be filled in:
   275  // ID := [ src ADDR ] [ dst ADDR ] [ proto XFRM-PROTO ] [ spi SPI ]
   276  // mark is optional
   277  func XfrmStateGet(state *XfrmState) (*XfrmState, error) {
   278  	return pkgHandle.XfrmStateGet(state)
   279  }
   280  
   281  // XfrmStateGet gets the xfrm state described by the ID, if found.
   282  // Equivalent to: `ip xfrm state get ID [ mark MARK [ mask MASK ] ]`.
   283  // Only the fields which constitue the SA ID must be filled in:
   284  // ID := [ src ADDR ] [ dst ADDR ] [ proto XFRM-PROTO ] [ spi SPI ]
   285  // mark is optional
   286  func (h *Handle) XfrmStateGet(state *XfrmState) (*XfrmState, error) {
   287  	return h.xfrmStateGetOrDelete(state, nl.XFRM_MSG_GETSA)
   288  }
   289  
   290  func (h *Handle) xfrmStateGetOrDelete(state *XfrmState, nlProto int) (*XfrmState, error) {
   291  	req := h.newNetlinkRequest(nlProto, unix.NLM_F_ACK)
   292  
   293  	msg := &nl.XfrmUsersaId{}
   294  	msg.Family = uint16(nl.GetIPFamily(state.Dst))
   295  	msg.Daddr.FromIP(state.Dst)
   296  	msg.Proto = uint8(state.Proto)
   297  	msg.Spi = nl.Swap32(uint32(state.Spi))
   298  	req.AddData(msg)
   299  
   300  	if state.Mark != nil {
   301  		out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark))
   302  		req.AddData(out)
   303  	}
   304  	if state.Src != nil {
   305  		out := nl.NewRtAttr(nl.XFRMA_SRCADDR, state.Src.To16())
   306  		req.AddData(out)
   307  	}
   308  
   309  	if state.Ifid != 0 {
   310  		ifId := nl.NewRtAttr(nl.XFRMA_IF_ID, nl.Uint32Attr(uint32(state.Ifid)))
   311  		req.AddData(ifId)
   312  	}
   313  
   314  	resType := nl.XFRM_MSG_NEWSA
   315  	if nlProto == nl.XFRM_MSG_DELSA {
   316  		resType = 0
   317  	}
   318  
   319  	msgs, err := req.Execute(unix.NETLINK_XFRM, uint16(resType))
   320  	if err != nil {
   321  		return nil, err
   322  	}
   323  
   324  	if nlProto == nl.XFRM_MSG_DELSA {
   325  		return nil, nil
   326  	}
   327  
   328  	s, err := parseXfrmState(msgs[0], FAMILY_ALL)
   329  	if err != nil {
   330  		return nil, err
   331  	}
   332  
   333  	return s, nil
   334  }
   335  
   336  var familyError = fmt.Errorf("family error")
   337  
   338  func xfrmStateFromXfrmUsersaInfo(msg *nl.XfrmUsersaInfo) *XfrmState {
   339  	var state XfrmState
   340  
   341  	state.Dst = msg.Id.Daddr.ToIP()
   342  	state.Src = msg.Saddr.ToIP()
   343  	state.Proto = Proto(msg.Id.Proto)
   344  	state.Mode = Mode(msg.Mode)
   345  	state.Spi = int(nl.Swap32(msg.Id.Spi))
   346  	state.Reqid = int(msg.Reqid)
   347  	state.ReplayWindow = int(msg.ReplayWindow)
   348  	lftToLimits(&msg.Lft, &state.Limits)
   349  	curToStats(&msg.Curlft, &msg.Stats, &state.Statistics)
   350  
   351  	return &state
   352  }
   353  
   354  func parseXfrmState(m []byte, family int) (*XfrmState, error) {
   355  	msg := nl.DeserializeXfrmUsersaInfo(m)
   356  
   357  	// This is mainly for the state dump
   358  	if family != FAMILY_ALL && family != int(msg.Family) {
   359  		return nil, familyError
   360  	}
   361  
   362  	state := xfrmStateFromXfrmUsersaInfo(msg)
   363  
   364  	attrs, err := nl.ParseRouteAttr(m[nl.SizeofXfrmUsersaInfo:])
   365  	if err != nil {
   366  		return nil, err
   367  	}
   368  
   369  	for _, attr := range attrs {
   370  		switch attr.Attr.Type {
   371  		case nl.XFRMA_ALG_AUTH, nl.XFRMA_ALG_CRYPT:
   372  			var resAlgo *XfrmStateAlgo
   373  			if attr.Attr.Type == nl.XFRMA_ALG_AUTH {
   374  				if state.Auth == nil {
   375  					state.Auth = new(XfrmStateAlgo)
   376  				}
   377  				resAlgo = state.Auth
   378  			} else {
   379  				state.Crypt = new(XfrmStateAlgo)
   380  				resAlgo = state.Crypt
   381  			}
   382  			algo := nl.DeserializeXfrmAlgo(attr.Value[:])
   383  			(*resAlgo).Name = nl.BytesToString(algo.AlgName[:])
   384  			(*resAlgo).Key = algo.AlgKey
   385  		case nl.XFRMA_ALG_AUTH_TRUNC:
   386  			if state.Auth == nil {
   387  				state.Auth = new(XfrmStateAlgo)
   388  			}
   389  			algo := nl.DeserializeXfrmAlgoAuth(attr.Value[:])
   390  			state.Auth.Name = nl.BytesToString(algo.AlgName[:])
   391  			state.Auth.Key = algo.AlgKey
   392  			state.Auth.TruncateLen = int(algo.AlgTruncLen)
   393  		case nl.XFRMA_ALG_AEAD:
   394  			state.Aead = new(XfrmStateAlgo)
   395  			algo := nl.DeserializeXfrmAlgoAEAD(attr.Value[:])
   396  			state.Aead.Name = nl.BytesToString(algo.AlgName[:])
   397  			state.Aead.Key = algo.AlgKey
   398  			state.Aead.ICVLen = int(algo.AlgICVLen)
   399  		case nl.XFRMA_ENCAP:
   400  			encap := nl.DeserializeXfrmEncapTmpl(attr.Value[:])
   401  			state.Encap = new(XfrmStateEncap)
   402  			state.Encap.Type = EncapType(encap.EncapType)
   403  			state.Encap.SrcPort = int(nl.Swap16(encap.EncapSport))
   404  			state.Encap.DstPort = int(nl.Swap16(encap.EncapDport))
   405  			state.Encap.OriginalAddress = encap.EncapOa.ToIP()
   406  		case nl.XFRMA_MARK:
   407  			mark := nl.DeserializeXfrmMark(attr.Value[:])
   408  			state.Mark = new(XfrmMark)
   409  			state.Mark.Value = mark.Value
   410  			state.Mark.Mask = mark.Mask
   411  		case nl.XFRMA_SA_EXTRA_FLAGS:
   412  			flags := native.Uint32(attr.Value)
   413  			if (flags & nl.XFRM_SA_XFLAG_DONT_ENCAP_DSCP) != 0 {
   414  				state.DontEncapDSCP = true
   415  			}
   416  			if (flags & nl.XFRM_SA_XFLAG_OSEQ_MAY_WRAP) != 0 {
   417  				state.OSeqMayWrap = true
   418  			}
   419  		case nl.XFRMA_SET_MARK:
   420  			if state.OutputMark == nil {
   421  				state.OutputMark = new(XfrmMark)
   422  			}
   423  			state.OutputMark.Value = native.Uint32(attr.Value)
   424  		case nl.XFRMA_SET_MARK_MASK:
   425  			if state.OutputMark == nil {
   426  				state.OutputMark = new(XfrmMark)
   427  			}
   428  			state.OutputMark.Mask = native.Uint32(attr.Value)
   429  			if state.OutputMark.Mask == 0xffffffff {
   430  				state.OutputMark.Mask = 0
   431  			}
   432  		case nl.XFRMA_IF_ID:
   433  			state.Ifid = int(native.Uint32(attr.Value))
   434  		case nl.XFRMA_REPLAY_VAL:
   435  			if state.Replay == nil {
   436  				state.Replay = new(XfrmReplayState)
   437  			}
   438  			replay := nl.DeserializeXfrmReplayState(attr.Value[:])
   439  			state.Replay.OSeq = replay.OSeq
   440  			state.Replay.Seq = replay.Seq
   441  			state.Replay.BitMap = replay.BitMap
   442  		}
   443  	}
   444  
   445  	return state, nil
   446  }
   447  
   448  // XfrmStateFlush will flush the xfrm state on the system.
   449  // proto = 0 means any transformation protocols
   450  // Equivalent to: `ip xfrm state flush [ proto XFRM-PROTO ]`
   451  func XfrmStateFlush(proto Proto) error {
   452  	return pkgHandle.XfrmStateFlush(proto)
   453  }
   454  
   455  // XfrmStateFlush will flush the xfrm state on the system.
   456  // proto = 0 means any transformation protocols
   457  // Equivalent to: `ip xfrm state flush [ proto XFRM-PROTO ]`
   458  func (h *Handle) XfrmStateFlush(proto Proto) error {
   459  	req := h.newNetlinkRequest(nl.XFRM_MSG_FLUSHSA, unix.NLM_F_ACK)
   460  
   461  	req.AddData(&nl.XfrmUsersaFlush{Proto: uint8(proto)})
   462  
   463  	_, err := req.Execute(unix.NETLINK_XFRM, 0)
   464  	return err
   465  }
   466  
   467  func limitsToLft(lmts XfrmStateLimits, lft *nl.XfrmLifetimeCfg) {
   468  	if lmts.ByteSoft != 0 {
   469  		lft.SoftByteLimit = lmts.ByteSoft
   470  	} else {
   471  		lft.SoftByteLimit = nl.XFRM_INF
   472  	}
   473  	if lmts.ByteHard != 0 {
   474  		lft.HardByteLimit = lmts.ByteHard
   475  	} else {
   476  		lft.HardByteLimit = nl.XFRM_INF
   477  	}
   478  	if lmts.PacketSoft != 0 {
   479  		lft.SoftPacketLimit = lmts.PacketSoft
   480  	} else {
   481  		lft.SoftPacketLimit = nl.XFRM_INF
   482  	}
   483  	if lmts.PacketHard != 0 {
   484  		lft.HardPacketLimit = lmts.PacketHard
   485  	} else {
   486  		lft.HardPacketLimit = nl.XFRM_INF
   487  	}
   488  	lft.SoftAddExpiresSeconds = lmts.TimeSoft
   489  	lft.HardAddExpiresSeconds = lmts.TimeHard
   490  	lft.SoftUseExpiresSeconds = lmts.TimeUseSoft
   491  	lft.HardUseExpiresSeconds = lmts.TimeUseHard
   492  }
   493  
   494  func lftToLimits(lft *nl.XfrmLifetimeCfg, lmts *XfrmStateLimits) {
   495  	*lmts = *(*XfrmStateLimits)(unsafe.Pointer(lft))
   496  }
   497  
   498  func curToStats(cur *nl.XfrmLifetimeCur, wstats *nl.XfrmStats, stats *XfrmStateStats) {
   499  	stats.Bytes = cur.Bytes
   500  	stats.Packets = cur.Packets
   501  	stats.AddTime = cur.AddTime
   502  	stats.UseTime = cur.UseTime
   503  	stats.ReplayWindow = wstats.ReplayWindow
   504  	stats.Replay = wstats.Replay
   505  	stats.Failed = wstats.IntegrityFailed
   506  }
   507  
   508  func xfrmUsersaInfoFromXfrmState(state *XfrmState) *nl.XfrmUsersaInfo {
   509  	msg := &nl.XfrmUsersaInfo{}
   510  	msg.Family = uint16(nl.GetIPFamily(state.Dst))
   511  	msg.Id.Daddr.FromIP(state.Dst)
   512  	msg.Saddr.FromIP(state.Src)
   513  	msg.Id.Proto = uint8(state.Proto)
   514  	msg.Mode = uint8(state.Mode)
   515  	msg.Id.Spi = nl.Swap32(uint32(state.Spi))
   516  	msg.Reqid = uint32(state.Reqid)
   517  	msg.ReplayWindow = uint8(state.ReplayWindow)
   518  
   519  	return msg
   520  }