github.com/vishvananda/netlink@v1.3.0/xfrm_state_linux_test.go (about)

     1  package netlink
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/hex"
     6  	"net"
     7  	"testing"
     8  	"time"
     9  )
    10  
    11  func TestXfrmStateAddGetDel(t *testing.T) {
    12  	for _, s := range []*XfrmState{
    13  		getBaseState(),
    14  		getAeadState(),
    15  		getBaseStateV6oV4(),
    16  		getBaseStateV4oV6(),
    17  	} {
    18  		testXfrmStateAddGetDel(t, s)
    19  	}
    20  }
    21  
    22  func testXfrmStateAddGetDel(t *testing.T, state *XfrmState) {
    23  	tearDown := setUpNetlinkTest(t)
    24  	defer tearDown()
    25  	if err := XfrmStateAdd(state); err != nil {
    26  		t.Fatal(err)
    27  	}
    28  	states, err := XfrmStateList(FAMILY_ALL)
    29  	if err != nil {
    30  		t.Fatal(err)
    31  	}
    32  	if len(states) != 1 {
    33  		t.Fatal("State not added properly")
    34  	}
    35  
    36  	if !compareStates(state, &states[0]) {
    37  		t.Fatalf("unexpected states returned")
    38  	}
    39  
    40  	// Get specific state
    41  	sa, err := XfrmStateGet(state)
    42  	if err != nil {
    43  		t.Fatal(err)
    44  	}
    45  
    46  	if !compareStates(state, sa) {
    47  		t.Fatalf("unexpected state returned")
    48  	}
    49  
    50  	if err = XfrmStateDel(state); err != nil {
    51  		t.Fatal(err)
    52  	}
    53  
    54  	states, err = XfrmStateList(FAMILY_ALL)
    55  	if err != nil {
    56  		t.Fatal(err)
    57  	}
    58  	if len(states) != 0 {
    59  		t.Fatal("State not removed properly")
    60  	}
    61  
    62  	if _, err := XfrmStateGet(state); err == nil {
    63  		t.Fatalf("Unexpected success")
    64  	}
    65  }
    66  
    67  func TestXfrmStateAllocSpi(t *testing.T) {
    68  	defer setUpNetlinkTest(t)()
    69  
    70  	state := getBaseState()
    71  	state.Spi = 0
    72  	state.Auth = nil
    73  	state.Crypt = nil
    74  	rstate, err := XfrmStateAllocSpi(state)
    75  	if err != nil {
    76  		t.Fatal(err)
    77  	}
    78  	if rstate.Spi == 0 {
    79  		t.Fatalf("SPI is not allocated")
    80  	}
    81  	rstate.Spi = 0
    82  
    83  	if !compareStates(state, rstate) {
    84  		t.Fatalf("State not properly allocated")
    85  	}
    86  }
    87  
    88  func TestXfrmStateFlush(t *testing.T) {
    89  	defer setUpNetlinkTest(t)()
    90  
    91  	state1 := getBaseState()
    92  	state2 := getBaseState()
    93  	state2.Src = net.ParseIP("127.1.0.1")
    94  	state2.Dst = net.ParseIP("127.1.0.2")
    95  	state2.Proto = XFRM_PROTO_AH
    96  	state2.Mode = XFRM_MODE_TUNNEL
    97  	state2.Spi = 20
    98  	state2.Mark = nil
    99  	state2.Crypt = nil
   100  
   101  	if err := XfrmStateAdd(state1); err != nil {
   102  		t.Fatal(err)
   103  	}
   104  	if err := XfrmStateAdd(state2); err != nil {
   105  		t.Fatal(err)
   106  	}
   107  
   108  	// flushing proto for which no state is present should return silently
   109  	if err := XfrmStateFlush(XFRM_PROTO_COMP); err != nil {
   110  		t.Fatal(err)
   111  	}
   112  
   113  	if err := XfrmStateFlush(XFRM_PROTO_AH); err != nil {
   114  		t.Fatal(err)
   115  	}
   116  
   117  	if _, err := XfrmStateGet(state2); err == nil {
   118  		t.Fatalf("Unexpected success")
   119  	}
   120  
   121  	if err := XfrmStateAdd(state2); err != nil {
   122  		t.Fatal(err)
   123  	}
   124  
   125  	if err := XfrmStateFlush(0); err != nil {
   126  		t.Fatal(err)
   127  	}
   128  
   129  	states, err := XfrmStateList(FAMILY_ALL)
   130  	if err != nil {
   131  		t.Fatal(err)
   132  	}
   133  	if len(states) != 0 {
   134  		t.Fatal("State not flushed properly")
   135  	}
   136  
   137  }
   138  
   139  func TestXfrmStateUpdateLimits(t *testing.T) {
   140  	defer setUpNetlinkTest(t)()
   141  
   142  	// Program state with limits
   143  	state := getBaseState()
   144  	state.Limits.TimeHard = 3600
   145  	state.Limits.TimeSoft = 60
   146  	state.Limits.PacketHard = 1000
   147  	state.Limits.PacketSoft = 50
   148  	state.Limits.ByteHard = 1000000
   149  	state.Limits.ByteSoft = 50000
   150  	state.Limits.TimeUseHard = 3000
   151  	state.Limits.TimeUseSoft = 1500
   152  	if err := XfrmStateAdd(state); err != nil {
   153  		t.Fatal(err)
   154  	}
   155  	// Verify limits
   156  	s, err := XfrmStateGet(state)
   157  	if err != nil {
   158  		t.Fatal(err)
   159  	}
   160  	if !compareLimits(state, s) {
   161  		t.Fatalf("Incorrect time hard/soft retrieved: %s", s.Print(true))
   162  	}
   163  
   164  	// Update limits
   165  	state.Limits.TimeHard = 1800
   166  	state.Limits.TimeSoft = 30
   167  	state.Limits.PacketHard = 500
   168  	state.Limits.PacketSoft = 25
   169  	state.Limits.ByteHard = 500000
   170  	state.Limits.ByteSoft = 25000
   171  	state.Limits.TimeUseHard = 2000
   172  	state.Limits.TimeUseSoft = 1000
   173  	if err := XfrmStateUpdate(state); err != nil {
   174  		t.Fatal(err)
   175  	}
   176  
   177  	// Verify new limits
   178  	s, err = XfrmStateGet(state)
   179  	if err != nil {
   180  		t.Fatal(err)
   181  	}
   182  	if s.Limits.TimeHard != 1800 || s.Limits.TimeSoft != 30 {
   183  		t.Fatalf("Incorrect time hard retrieved: (%d, %d)", s.Limits.TimeHard, s.Limits.TimeSoft)
   184  	}
   185  }
   186  
   187  func TestXfrmStateStats(t *testing.T) {
   188  	defer setUpNetlinkTest(t)()
   189  
   190  	// Program state and record time
   191  	state := getBaseState()
   192  	now := time.Now()
   193  	if err := XfrmStateAdd(state); err != nil {
   194  		t.Fatal(err)
   195  	}
   196  	// Retrieve state
   197  	s, err := XfrmStateGet(state)
   198  	if err != nil {
   199  		t.Fatal(err)
   200  	}
   201  	// Verify stats: We expect zero counters, same second add time and unset use time
   202  	if s.Statistics.Bytes != 0 || s.Statistics.Packets != 0 || s.Statistics.AddTime != uint64(now.Unix()) || s.Statistics.UseTime != 0 {
   203  		t.Fatalf("Unexpected statistics (addTime: %s) for state:\n%s", now.Format(time.UnixDate), s.Print(true))
   204  	}
   205  }
   206  
   207  func TestXfrmStateWithIfid(t *testing.T) {
   208  	minKernelRequired(t, 4, 19)
   209  	defer setUpNetlinkTest(t)()
   210  
   211  	state := getBaseState()
   212  	state.Ifid = 54321
   213  	if err := XfrmStateAdd(state); err != nil {
   214  		t.Fatal(err)
   215  	}
   216  	s, err := XfrmStateGet(state)
   217  	if err != nil {
   218  		t.Fatal(err)
   219  	}
   220  	if !compareStates(state, s) {
   221  		t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s)
   222  	}
   223  	if err = XfrmStateDel(s); err != nil {
   224  		t.Fatal(err)
   225  	}
   226  }
   227  
   228  func TestXfrmStateWithOutputMark(t *testing.T) {
   229  	minKernelRequired(t, 4, 14)
   230  	defer setUpNetlinkTest(t)()
   231  
   232  	state := getBaseState()
   233  	state.OutputMark = &XfrmMark{
   234  		Value: 0x0000000a,
   235  	}
   236  	if err := XfrmStateAdd(state); err != nil {
   237  		t.Fatal(err)
   238  	}
   239  	s, err := XfrmStateGet(state)
   240  	if err != nil {
   241  		t.Fatal(err)
   242  	}
   243  	if !compareStates(state, s) {
   244  		t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s)
   245  	}
   246  	if err = XfrmStateDel(s); err != nil {
   247  		t.Fatal(err)
   248  	}
   249  }
   250  
   251  func TestXfrmStateWithOutputMarkAndMask(t *testing.T) {
   252  	minKernelRequired(t, 4, 19)
   253  	defer setUpNetlinkTest(t)()
   254  
   255  	state := getBaseState()
   256  	state.OutputMark = &XfrmMark{
   257  		Value: 0x0000000a,
   258  		Mask:  0x0000000f,
   259  	}
   260  	if err := XfrmStateAdd(state); err != nil {
   261  		t.Fatal(err)
   262  	}
   263  	s, err := XfrmStateGet(state)
   264  	if err != nil {
   265  		t.Fatal(err)
   266  	}
   267  	if !compareStates(state, s) {
   268  		t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s)
   269  	}
   270  	if err = XfrmStateDel(s); err != nil {
   271  		t.Fatal(err)
   272  	}
   273  }
   274  func genStateSelectorForV6Payload() *XfrmPolicy {
   275  	_, wildcardV6Net, _ := net.ParseCIDR("::/0")
   276  	return &XfrmPolicy{
   277  		Src: wildcardV6Net,
   278  		Dst: wildcardV6Net,
   279  	}
   280  }
   281  
   282  func genStateSelectorForV4Payload() *XfrmPolicy {
   283  	_, wildcardV4Net, _ := net.ParseCIDR("0.0.0.0/0")
   284  	return &XfrmPolicy{
   285  		Src: wildcardV4Net,
   286  		Dst: wildcardV4Net,
   287  	}
   288  }
   289  
   290  func getBaseState() *XfrmState {
   291  	return &XfrmState{
   292  		// Force 4 byte notation for the IPv4 addresses
   293  		Src:   net.ParseIP("127.0.0.1").To4(),
   294  		Dst:   net.ParseIP("127.0.0.2").To4(),
   295  		Proto: XFRM_PROTO_ESP,
   296  		Mode:  XFRM_MODE_TUNNEL,
   297  		Spi:   1,
   298  		Auth: &XfrmStateAlgo{
   299  			Name: "hmac(sha256)",
   300  			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
   301  		},
   302  		Crypt: &XfrmStateAlgo{
   303  			Name: "cbc(aes)",
   304  			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
   305  		},
   306  		Mark: &XfrmMark{
   307  			Value: 0x12340000,
   308  			Mask:  0xffff0000,
   309  		},
   310  	}
   311  }
   312  
   313  func getBaseStateV4oV6() *XfrmState {
   314  	return &XfrmState{
   315  		// Force 4 byte notation for the IPv4 addressesd
   316  		Src:   net.ParseIP("2001:dead::1").To16(),
   317  		Dst:   net.ParseIP("2001:beef::1").To16(),
   318  		Proto: XFRM_PROTO_ESP,
   319  		Mode:  XFRM_MODE_TUNNEL,
   320  		Spi:   1,
   321  		Auth: &XfrmStateAlgo{
   322  			Name: "hmac(sha256)",
   323  			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
   324  		},
   325  		Crypt: &XfrmStateAlgo{
   326  			Name: "cbc(aes)",
   327  			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
   328  		},
   329  		Mark: &XfrmMark{
   330  			Value: 0x12340000,
   331  			Mask:  0xffff0000,
   332  		},
   333  		Selector: genStateSelectorForV4Payload(),
   334  	}
   335  }
   336  
   337  func getBaseStateV6oV4() *XfrmState {
   338  	return &XfrmState{
   339  		// Force 4 byte notation for the IPv4 addressesd
   340  		Src:   net.ParseIP("192.168.1.1").To4(),
   341  		Dst:   net.ParseIP("192.168.2.2").To4(),
   342  		Proto: XFRM_PROTO_ESP,
   343  		Mode:  XFRM_MODE_TUNNEL,
   344  		Spi:   1,
   345  		Auth: &XfrmStateAlgo{
   346  			Name: "hmac(sha256)",
   347  			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
   348  		},
   349  		Crypt: &XfrmStateAlgo{
   350  			Name: "cbc(aes)",
   351  			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
   352  		},
   353  		Mark: &XfrmMark{
   354  			Value: 0x12340000,
   355  			Mask:  0xffff0000,
   356  		},
   357  		Selector: genStateSelectorForV6Payload(),
   358  	}
   359  }
   360  
   361  func getAeadState() *XfrmState {
   362  	// 128 key bits + 32 salt bits
   363  	k, _ := hex.DecodeString("d0562776bf0e75830ba3f7f8eb6c09b555aa1177")
   364  	return &XfrmState{
   365  		// Leave IPv4 addresses in Ipv4 in IPv6 notation
   366  		Src:   net.ParseIP("192.168.1.1"),
   367  		Dst:   net.ParseIP("192.168.2.2"),
   368  		Proto: XFRM_PROTO_ESP,
   369  		Mode:  XFRM_MODE_TUNNEL,
   370  		Spi:   2,
   371  		Aead: &XfrmStateAlgo{
   372  			Name:   "rfc4106(gcm(aes))",
   373  			Key:    k,
   374  			ICVLen: 64,
   375  		},
   376  	}
   377  }
   378  func compareSelector(a, b *XfrmPolicy) bool {
   379  	return a.Src.String() == b.Src.String() &&
   380  		a.Dst.String() == b.Dst.String() &&
   381  		a.Proto == b.Proto &&
   382  		a.DstPort == b.DstPort &&
   383  		a.SrcPort == b.SrcPort &&
   384  		a.Ifindex == b.Ifindex
   385  }
   386  
   387  func compareStates(a, b *XfrmState) bool {
   388  	if a == b {
   389  		return true
   390  	}
   391  	if a == nil || b == nil {
   392  		return false
   393  	}
   394  	if a.Selector != nil && b.Selector != nil {
   395  		if !compareSelector(a.Selector, b.Selector) {
   396  			return false
   397  		}
   398  	}
   399  
   400  	return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) &&
   401  		a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto &&
   402  		a.Ifid == b.Ifid &&
   403  		compareAlgo(a.Auth, b.Auth) &&
   404  		compareAlgo(a.Crypt, b.Crypt) &&
   405  		compareAlgo(a.Aead, b.Aead) &&
   406  		compareMarks(a.Mark, b.Mark) &&
   407  		compareMarks(a.OutputMark, b.OutputMark)
   408  
   409  }
   410  
   411  func compareLimits(a, b *XfrmState) bool {
   412  	return a.Limits.TimeHard == b.Limits.TimeHard &&
   413  		a.Limits.TimeSoft == b.Limits.TimeSoft &&
   414  		a.Limits.PacketHard == b.Limits.PacketHard &&
   415  		a.Limits.PacketSoft == b.Limits.PacketSoft &&
   416  		a.Limits.ByteHard == b.Limits.ByteHard &&
   417  		a.Limits.ByteSoft == b.Limits.ByteSoft &&
   418  		a.Limits.TimeUseHard == b.Limits.TimeUseHard &&
   419  		a.Limits.TimeUseSoft == b.Limits.TimeUseSoft
   420  }
   421  
   422  func compareAlgo(a, b *XfrmStateAlgo) bool {
   423  	if a == b {
   424  		return true
   425  	}
   426  	if a == nil || b == nil {
   427  		return false
   428  	}
   429  	return a.Name == b.Name && bytes.Equal(a.Key, b.Key) &&
   430  		(a.TruncateLen == 0 || a.TruncateLen == b.TruncateLen) &&
   431  		(a.ICVLen == 0 || a.ICVLen == b.ICVLen)
   432  }
   433  
   434  func compareMarks(a, b *XfrmMark) bool {
   435  	if a == b {
   436  		return true
   437  	}
   438  	if a == nil || b == nil {
   439  		return false
   440  	}
   441  	return a.Value == b.Value && a.Mask == b.Mask
   442  }