github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/xfrm_state_test.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package netlink
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/hex"
     9  	"net"
    10  	"testing"
    11  	"time"
    12  )
    13  
    14  func TestXfrmStateAddGetDel(t *testing.T) {
    15  	for _, s := range []*XfrmState{getBaseState(), getAeadState()} {
    16  		testXfrmStateAddGetDel(t, s)
    17  	}
    18  }
    19  
    20  func testXfrmStateAddGetDel(t *testing.T, state *XfrmState) {
    21  	tearDown := setUpNetlinkTest(t)
    22  	defer tearDown()
    23  	if err := XfrmStateAdd(state); err != nil {
    24  		t.Fatal(err)
    25  	}
    26  	states, err := XfrmStateList(FAMILY_ALL)
    27  	if err != nil {
    28  		t.Fatal(err)
    29  	}
    30  
    31  	if len(states) != 1 {
    32  		t.Fatal("State not added properly")
    33  	}
    34  
    35  	if !compareStates(state, &states[0]) {
    36  		t.Fatalf("unexpected states returned")
    37  	}
    38  
    39  	// Get specific state
    40  	sa, err := XfrmStateGet(state)
    41  	if err != nil {
    42  		t.Fatal(err)
    43  	}
    44  
    45  	if !compareStates(state, sa) {
    46  		t.Fatalf("unexpected state returned")
    47  	}
    48  
    49  	if err = XfrmStateDel(state); err != nil {
    50  		t.Fatal(err)
    51  	}
    52  
    53  	states, err = XfrmStateList(FAMILY_ALL)
    54  	if err != nil {
    55  		t.Fatal(err)
    56  	}
    57  	if len(states) != 0 {
    58  		t.Fatal("State not removed properly")
    59  	}
    60  
    61  	if _, err := XfrmStateGet(state); err == nil {
    62  		t.Fatalf("Unexpected success")
    63  	}
    64  }
    65  
    66  func TestXfrmStateAllocSpi(t *testing.T) {
    67  	defer setUpNetlinkTest(t)()
    68  
    69  	state := getBaseState()
    70  	state.Spi = 0
    71  	state.Auth = nil
    72  	state.Crypt = nil
    73  	rstate, err := XfrmStateAllocSpi(state)
    74  	if err != nil {
    75  		t.Fatal(err)
    76  	}
    77  	if rstate.Spi == 0 {
    78  		t.Fatalf("SPI is not allocated")
    79  	}
    80  	rstate.Spi = 0
    81  	if !compareStates(state, rstate) {
    82  		t.Fatalf("State not properly allocated")
    83  	}
    84  }
    85  
    86  func TestXfrmStateFlush(t *testing.T) {
    87  	defer setUpNetlinkTest(t)()
    88  
    89  	state1 := getBaseState()
    90  	state2 := getBaseState()
    91  	state2.Src = net.ParseIP("127.1.0.1")
    92  	state2.Dst = net.ParseIP("127.1.0.2")
    93  	state2.Proto = XFRM_PROTO_AH
    94  	state2.Mode = XFRM_MODE_TUNNEL
    95  	state2.Spi = 20
    96  	state2.Mark = nil
    97  	state2.Crypt = nil
    98  
    99  	if err := XfrmStateAdd(state1); err != nil {
   100  		t.Fatal(err)
   101  	}
   102  	if err := XfrmStateAdd(state2); err != nil {
   103  		t.Fatal(err)
   104  	}
   105  
   106  	// flushing proto for which no state is present should return silently
   107  	if err := XfrmStateFlush(XFRM_PROTO_COMP); err != nil {
   108  		t.Fatal(err)
   109  	}
   110  
   111  	if err := XfrmStateFlush(XFRM_PROTO_AH); err != nil {
   112  		t.Fatal(err)
   113  	}
   114  
   115  	if _, err := XfrmStateGet(state2); err == nil {
   116  		t.Fatalf("Unexpected success")
   117  	}
   118  
   119  	if err := XfrmStateAdd(state2); err != nil {
   120  		t.Fatal(err)
   121  	}
   122  
   123  	if err := XfrmStateFlush(0); err != nil {
   124  		t.Fatal(err)
   125  	}
   126  
   127  	states, err := XfrmStateList(FAMILY_ALL)
   128  	if err != nil {
   129  		t.Fatal(err)
   130  	}
   131  	if len(states) != 0 {
   132  		t.Fatal("State not flushed properly")
   133  	}
   134  
   135  }
   136  
   137  func TestXfrmStateUpdateLimits(t *testing.T) {
   138  	defer setUpNetlinkTest(t)()
   139  
   140  	// Program state with limits
   141  	state := getBaseState()
   142  	state.Limits.TimeHard = 3600
   143  	state.Limits.TimeSoft = 60
   144  	state.Limits.PacketHard = 1000
   145  	state.Limits.PacketSoft = 50
   146  	state.Limits.ByteHard = 1000000
   147  	state.Limits.ByteSoft = 50000
   148  	state.Limits.TimeUseHard = 3000
   149  	state.Limits.TimeUseSoft = 1500
   150  	if err := XfrmStateAdd(state); err != nil {
   151  		t.Fatal(err)
   152  	}
   153  	// Verify limits
   154  	s, err := XfrmStateGet(state)
   155  	if err != nil {
   156  		t.Fatal(err)
   157  	}
   158  	if !compareLimits(state, s) {
   159  		t.Fatalf("Incorrect time hard/soft retrieved: %s", s.Print(true))
   160  	}
   161  
   162  	// Update limits
   163  	state.Limits.TimeHard = 1800
   164  	state.Limits.TimeSoft = 30
   165  	state.Limits.PacketHard = 500
   166  	state.Limits.PacketSoft = 25
   167  	state.Limits.ByteHard = 500000
   168  	state.Limits.ByteSoft = 25000
   169  	state.Limits.TimeUseHard = 2000
   170  	state.Limits.TimeUseSoft = 1000
   171  	if err := XfrmStateUpdate(state); err != nil {
   172  		t.Fatal(err)
   173  	}
   174  
   175  	// Verify new limits
   176  	s, err = XfrmStateGet(state)
   177  	if err != nil {
   178  		t.Fatal(err)
   179  	}
   180  	if s.Limits.TimeHard != 1800 || s.Limits.TimeSoft != 30 {
   181  		t.Fatalf("Incorrect time hard retrieved: (%d, %d)", s.Limits.TimeHard, s.Limits.TimeSoft)
   182  	}
   183  }
   184  
   185  func TestXfrmStateStats(t *testing.T) {
   186  	defer setUpNetlinkTest(t)()
   187  
   188  	// Program state and record time
   189  	state := getBaseState()
   190  	now := time.Now()
   191  	if err := XfrmStateAdd(state); err != nil {
   192  		t.Fatal(err)
   193  	}
   194  	// Retrieve state
   195  	s, err := XfrmStateGet(state)
   196  	if err != nil {
   197  		t.Fatal(err)
   198  	}
   199  	// Verify stats: We expect zero counters, same second add time and unset use time
   200  	if s.Statistics.Bytes != 0 || s.Statistics.Packets != 0 || s.Statistics.AddTime != uint64(now.Unix()) || s.Statistics.UseTime != 0 {
   201  		t.Fatalf("Unexpected statistics (addTime: %s) for state:\n%s", now.Format(time.UnixDate), s.Print(true))
   202  	}
   203  }
   204  
   205  func TestXfrmStateWithIfid(t *testing.T) {
   206  	minKernelRequired(t, 4, 19)
   207  	defer setUpNetlinkTest(t)()
   208  
   209  	state := getBaseState()
   210  	state.Ifid = 54321
   211  	if err := XfrmStateAdd(state); err != nil {
   212  		t.Fatal(err)
   213  	}
   214  	s, err := XfrmStateGet(state)
   215  	if err != nil {
   216  		t.Fatal(err)
   217  	}
   218  	if !compareStates(state, s) {
   219  		t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s)
   220  	}
   221  	if err = XfrmStateDel(s); err != nil {
   222  		t.Fatal(err)
   223  	}
   224  }
   225  
   226  func TestXfrmStateWithOutputMark(t *testing.T) {
   227  	minKernelRequired(t, 4, 14)
   228  	defer setUpNetlinkTest(t)()
   229  
   230  	state := getBaseState()
   231  	state.OutputMark = &XfrmMark{
   232  		Value: 0x0000000a,
   233  	}
   234  	if err := XfrmStateAdd(state); err != nil {
   235  		t.Fatal(err)
   236  	}
   237  	s, err := XfrmStateGet(state)
   238  	if err != nil {
   239  		t.Fatal(err)
   240  	}
   241  	if !compareStates(state, s) {
   242  		t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s)
   243  	}
   244  	if err = XfrmStateDel(s); err != nil {
   245  		t.Fatal(err)
   246  	}
   247  }
   248  
   249  func TestXfrmStateWithOutputMarkAndMask(t *testing.T) {
   250  	minKernelRequired(t, 4, 19)
   251  	defer setUpNetlinkTest(t)()
   252  
   253  	state := getBaseState()
   254  	state.OutputMark = &XfrmMark{
   255  		Value: 0x0000000a,
   256  		Mask:  0x0000000f,
   257  	}
   258  	if err := XfrmStateAdd(state); err != nil {
   259  		t.Fatal(err)
   260  	}
   261  	s, err := XfrmStateGet(state)
   262  	if err != nil {
   263  		t.Fatal(err)
   264  	}
   265  	if !compareStates(state, s) {
   266  		t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s)
   267  	}
   268  	if err = XfrmStateDel(s); err != nil {
   269  		t.Fatal(err)
   270  	}
   271  }
   272  
   273  func getBaseState() *XfrmState {
   274  	return &XfrmState{
   275  		// Force 4 byte notation for the IPv4 addresses
   276  		Src:   net.ParseIP("127.0.0.1").To4(),
   277  		Dst:   net.ParseIP("127.0.0.2").To4(),
   278  		Proto: XFRM_PROTO_ESP,
   279  		Mode:  XFRM_MODE_TUNNEL,
   280  		Spi:   1,
   281  		Auth: &XfrmStateAlgo{
   282  			Name: "hmac(sha256)",
   283  			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
   284  		},
   285  		Crypt: &XfrmStateAlgo{
   286  			Name: "cbc(aes)",
   287  			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
   288  		},
   289  		Mark: &XfrmMark{
   290  			Value: 0x12340000,
   291  			Mask:  0xffff0000,
   292  		},
   293  	}
   294  }
   295  
   296  func getAeadState() *XfrmState {
   297  	// 128 key bits + 32 salt bits
   298  	k, _ := hex.DecodeString("d0562776bf0e75830ba3f7f8eb6c09b555aa1177")
   299  	return &XfrmState{
   300  		// Leave IPv4 addresses in Ipv4 in IPv6 notation
   301  		Src:   net.ParseIP("192.168.1.1"),
   302  		Dst:   net.ParseIP("192.168.2.2"),
   303  		Proto: XFRM_PROTO_ESP,
   304  		Mode:  XFRM_MODE_TUNNEL,
   305  		Spi:   2,
   306  		Aead: &XfrmStateAlgo{
   307  			Name:   "rfc4106(gcm(aes))",
   308  			Key:    k,
   309  			ICVLen: 64,
   310  		},
   311  	}
   312  }
   313  
   314  func compareStates(a, b *XfrmState) bool {
   315  	if a == b {
   316  		return true
   317  	}
   318  	if a == nil || b == nil {
   319  		return false
   320  	}
   321  	return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) &&
   322  		a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto &&
   323  		a.Ifid == b.Ifid &&
   324  		compareAlgo(a.Auth, b.Auth) &&
   325  		compareAlgo(a.Crypt, b.Crypt) &&
   326  		compareAlgo(a.Aead, b.Aead) &&
   327  		compareMarks(a.Mark, b.Mark) &&
   328  		compareMarks(a.OutputMark, b.OutputMark)
   329  }
   330  
   331  func compareLimits(a, b *XfrmState) bool {
   332  	return a.Limits.TimeHard == b.Limits.TimeHard &&
   333  		a.Limits.TimeSoft == b.Limits.TimeSoft &&
   334  		a.Limits.PacketHard == b.Limits.PacketHard &&
   335  		a.Limits.PacketSoft == b.Limits.PacketSoft &&
   336  		a.Limits.ByteHard == b.Limits.ByteHard &&
   337  		a.Limits.ByteSoft == b.Limits.ByteSoft &&
   338  		a.Limits.TimeUseHard == b.Limits.TimeUseHard &&
   339  		a.Limits.TimeUseSoft == b.Limits.TimeUseSoft
   340  }
   341  
   342  func compareAlgo(a, b *XfrmStateAlgo) bool {
   343  	if a == b {
   344  		return true
   345  	}
   346  	if a == nil || b == nil {
   347  		return false
   348  	}
   349  	return a.Name == b.Name && bytes.Equal(a.Key, b.Key) &&
   350  		(a.TruncateLen == 0 || a.TruncateLen == b.TruncateLen) &&
   351  		(a.ICVLen == 0 || a.ICVLen == b.ICVLen)
   352  }
   353  
   354  func compareMarks(a, b *XfrmMark) bool {
   355  	if a == b {
   356  		return true
   357  	}
   358  	if a == nil || b == nil {
   359  		return false
   360  	}
   361  	return a.Value == b.Value && a.Mask == b.Mask
   362  }