github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/xfrm_state_test.go (about) 1 //go:build linux 2 // +build linux 3 4 package netlink 5 6 import ( 7 "bytes" 8 "encoding/hex" 9 "net" 10 "testing" 11 "time" 12 ) 13 14 func TestXfrmStateAddGetDel(t *testing.T) { 15 for _, s := range []*XfrmState{getBaseState(), getAeadState()} { 16 testXfrmStateAddGetDel(t, s) 17 } 18 } 19 20 func testXfrmStateAddGetDel(t *testing.T, state *XfrmState) { 21 tearDown := setUpNetlinkTest(t) 22 defer tearDown() 23 if err := XfrmStateAdd(state); err != nil { 24 t.Fatal(err) 25 } 26 states, err := XfrmStateList(FAMILY_ALL) 27 if err != nil { 28 t.Fatal(err) 29 } 30 31 if len(states) != 1 { 32 t.Fatal("State not added properly") 33 } 34 35 if !compareStates(state, &states[0]) { 36 t.Fatalf("unexpected states returned") 37 } 38 39 // Get specific state 40 sa, err := XfrmStateGet(state) 41 if err != nil { 42 t.Fatal(err) 43 } 44 45 if !compareStates(state, sa) { 46 t.Fatalf("unexpected state returned") 47 } 48 49 if err = XfrmStateDel(state); err != nil { 50 t.Fatal(err) 51 } 52 53 states, err = XfrmStateList(FAMILY_ALL) 54 if err != nil { 55 t.Fatal(err) 56 } 57 if len(states) != 0 { 58 t.Fatal("State not removed properly") 59 } 60 61 if _, err := XfrmStateGet(state); err == nil { 62 t.Fatalf("Unexpected success") 63 } 64 } 65 66 func TestXfrmStateAllocSpi(t *testing.T) { 67 defer setUpNetlinkTest(t)() 68 69 state := getBaseState() 70 state.Spi = 0 71 state.Auth = nil 72 state.Crypt = nil 73 rstate, err := XfrmStateAllocSpi(state) 74 if err != nil { 75 t.Fatal(err) 76 } 77 if rstate.Spi == 0 { 78 t.Fatalf("SPI is not allocated") 79 } 80 rstate.Spi = 0 81 if !compareStates(state, rstate) { 82 t.Fatalf("State not properly allocated") 83 } 84 } 85 86 func TestXfrmStateFlush(t *testing.T) { 87 defer setUpNetlinkTest(t)() 88 89 state1 := getBaseState() 90 state2 := getBaseState() 91 state2.Src = net.ParseIP("127.1.0.1") 92 state2.Dst = net.ParseIP("127.1.0.2") 93 state2.Proto = XFRM_PROTO_AH 94 state2.Mode = XFRM_MODE_TUNNEL 95 state2.Spi = 20 96 state2.Mark = nil 97 state2.Crypt = nil 98 99 if err := XfrmStateAdd(state1); err != nil { 100 t.Fatal(err) 101 } 102 if err := XfrmStateAdd(state2); err != nil { 103 t.Fatal(err) 104 } 105 106 // flushing proto for which no state is present should return silently 107 if err := XfrmStateFlush(XFRM_PROTO_COMP); err != nil { 108 t.Fatal(err) 109 } 110 111 if err := XfrmStateFlush(XFRM_PROTO_AH); err != nil { 112 t.Fatal(err) 113 } 114 115 if _, err := XfrmStateGet(state2); err == nil { 116 t.Fatalf("Unexpected success") 117 } 118 119 if err := XfrmStateAdd(state2); err != nil { 120 t.Fatal(err) 121 } 122 123 if err := XfrmStateFlush(0); err != nil { 124 t.Fatal(err) 125 } 126 127 states, err := XfrmStateList(FAMILY_ALL) 128 if err != nil { 129 t.Fatal(err) 130 } 131 if len(states) != 0 { 132 t.Fatal("State not flushed properly") 133 } 134 135 } 136 137 func TestXfrmStateUpdateLimits(t *testing.T) { 138 defer setUpNetlinkTest(t)() 139 140 // Program state with limits 141 state := getBaseState() 142 state.Limits.TimeHard = 3600 143 state.Limits.TimeSoft = 60 144 state.Limits.PacketHard = 1000 145 state.Limits.PacketSoft = 50 146 state.Limits.ByteHard = 1000000 147 state.Limits.ByteSoft = 50000 148 state.Limits.TimeUseHard = 3000 149 state.Limits.TimeUseSoft = 1500 150 if err := XfrmStateAdd(state); err != nil { 151 t.Fatal(err) 152 } 153 // Verify limits 154 s, err := XfrmStateGet(state) 155 if err != nil { 156 t.Fatal(err) 157 } 158 if !compareLimits(state, s) { 159 t.Fatalf("Incorrect time hard/soft retrieved: %s", s.Print(true)) 160 } 161 162 // Update limits 163 state.Limits.TimeHard = 1800 164 state.Limits.TimeSoft = 30 165 state.Limits.PacketHard = 500 166 state.Limits.PacketSoft = 25 167 state.Limits.ByteHard = 500000 168 state.Limits.ByteSoft = 25000 169 state.Limits.TimeUseHard = 2000 170 state.Limits.TimeUseSoft = 1000 171 if err := XfrmStateUpdate(state); err != nil { 172 t.Fatal(err) 173 } 174 175 // Verify new limits 176 s, err = XfrmStateGet(state) 177 if err != nil { 178 t.Fatal(err) 179 } 180 if s.Limits.TimeHard != 1800 || s.Limits.TimeSoft != 30 { 181 t.Fatalf("Incorrect time hard retrieved: (%d, %d)", s.Limits.TimeHard, s.Limits.TimeSoft) 182 } 183 } 184 185 func TestXfrmStateStats(t *testing.T) { 186 defer setUpNetlinkTest(t)() 187 188 // Program state and record time 189 state := getBaseState() 190 now := time.Now() 191 if err := XfrmStateAdd(state); err != nil { 192 t.Fatal(err) 193 } 194 // Retrieve state 195 s, err := XfrmStateGet(state) 196 if err != nil { 197 t.Fatal(err) 198 } 199 // Verify stats: We expect zero counters, same second add time and unset use time 200 if s.Statistics.Bytes != 0 || s.Statistics.Packets != 0 || s.Statistics.AddTime != uint64(now.Unix()) || s.Statistics.UseTime != 0 { 201 t.Fatalf("Unexpected statistics (addTime: %s) for state:\n%s", now.Format(time.UnixDate), s.Print(true)) 202 } 203 } 204 205 func TestXfrmStateWithIfid(t *testing.T) { 206 minKernelRequired(t, 4, 19) 207 defer setUpNetlinkTest(t)() 208 209 state := getBaseState() 210 state.Ifid = 54321 211 if err := XfrmStateAdd(state); err != nil { 212 t.Fatal(err) 213 } 214 s, err := XfrmStateGet(state) 215 if err != nil { 216 t.Fatal(err) 217 } 218 if !compareStates(state, s) { 219 t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s) 220 } 221 if err = XfrmStateDel(s); err != nil { 222 t.Fatal(err) 223 } 224 } 225 226 func TestXfrmStateWithOutputMark(t *testing.T) { 227 minKernelRequired(t, 4, 14) 228 defer setUpNetlinkTest(t)() 229 230 state := getBaseState() 231 state.OutputMark = &XfrmMark{ 232 Value: 0x0000000a, 233 } 234 if err := XfrmStateAdd(state); err != nil { 235 t.Fatal(err) 236 } 237 s, err := XfrmStateGet(state) 238 if err != nil { 239 t.Fatal(err) 240 } 241 if !compareStates(state, s) { 242 t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s) 243 } 244 if err = XfrmStateDel(s); err != nil { 245 t.Fatal(err) 246 } 247 } 248 249 func TestXfrmStateWithOutputMarkAndMask(t *testing.T) { 250 minKernelRequired(t, 4, 19) 251 defer setUpNetlinkTest(t)() 252 253 state := getBaseState() 254 state.OutputMark = &XfrmMark{ 255 Value: 0x0000000a, 256 Mask: 0x0000000f, 257 } 258 if err := XfrmStateAdd(state); err != nil { 259 t.Fatal(err) 260 } 261 s, err := XfrmStateGet(state) 262 if err != nil { 263 t.Fatal(err) 264 } 265 if !compareStates(state, s) { 266 t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s) 267 } 268 if err = XfrmStateDel(s); err != nil { 269 t.Fatal(err) 270 } 271 } 272 273 func getBaseState() *XfrmState { 274 return &XfrmState{ 275 // Force 4 byte notation for the IPv4 addresses 276 Src: net.ParseIP("127.0.0.1").To4(), 277 Dst: net.ParseIP("127.0.0.2").To4(), 278 Proto: XFRM_PROTO_ESP, 279 Mode: XFRM_MODE_TUNNEL, 280 Spi: 1, 281 Auth: &XfrmStateAlgo{ 282 Name: "hmac(sha256)", 283 Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"), 284 }, 285 Crypt: &XfrmStateAlgo{ 286 Name: "cbc(aes)", 287 Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"), 288 }, 289 Mark: &XfrmMark{ 290 Value: 0x12340000, 291 Mask: 0xffff0000, 292 }, 293 } 294 } 295 296 func getAeadState() *XfrmState { 297 // 128 key bits + 32 salt bits 298 k, _ := hex.DecodeString("d0562776bf0e75830ba3f7f8eb6c09b555aa1177") 299 return &XfrmState{ 300 // Leave IPv4 addresses in Ipv4 in IPv6 notation 301 Src: net.ParseIP("192.168.1.1"), 302 Dst: net.ParseIP("192.168.2.2"), 303 Proto: XFRM_PROTO_ESP, 304 Mode: XFRM_MODE_TUNNEL, 305 Spi: 2, 306 Aead: &XfrmStateAlgo{ 307 Name: "rfc4106(gcm(aes))", 308 Key: k, 309 ICVLen: 64, 310 }, 311 } 312 } 313 314 func compareStates(a, b *XfrmState) bool { 315 if a == b { 316 return true 317 } 318 if a == nil || b == nil { 319 return false 320 } 321 return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) && 322 a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto && 323 a.Ifid == b.Ifid && 324 compareAlgo(a.Auth, b.Auth) && 325 compareAlgo(a.Crypt, b.Crypt) && 326 compareAlgo(a.Aead, b.Aead) && 327 compareMarks(a.Mark, b.Mark) && 328 compareMarks(a.OutputMark, b.OutputMark) 329 } 330 331 func compareLimits(a, b *XfrmState) bool { 332 return a.Limits.TimeHard == b.Limits.TimeHard && 333 a.Limits.TimeSoft == b.Limits.TimeSoft && 334 a.Limits.PacketHard == b.Limits.PacketHard && 335 a.Limits.PacketSoft == b.Limits.PacketSoft && 336 a.Limits.ByteHard == b.Limits.ByteHard && 337 a.Limits.ByteSoft == b.Limits.ByteSoft && 338 a.Limits.TimeUseHard == b.Limits.TimeUseHard && 339 a.Limits.TimeUseSoft == b.Limits.TimeUseSoft 340 } 341 342 func compareAlgo(a, b *XfrmStateAlgo) bool { 343 if a == b { 344 return true 345 } 346 if a == nil || b == nil { 347 return false 348 } 349 return a.Name == b.Name && bytes.Equal(a.Key, b.Key) && 350 (a.TruncateLen == 0 || a.TruncateLen == b.TruncateLen) && 351 (a.ICVLen == 0 || a.ICVLen == b.ICVLen) 352 } 353 354 func compareMarks(a, b *XfrmMark) bool { 355 if a == b { 356 return true 357 } 358 if a == nil || b == nil { 359 return false 360 } 361 return a.Value == b.Value && a.Mask == b.Mask 362 }