github.com/slackhq/nebula@v1.9.0/firewall_test.go (about) 1 package nebula 2 3 import ( 4 "bytes" 5 "errors" 6 "math" 7 "net" 8 "testing" 9 "time" 10 11 "github.com/slackhq/nebula/cert" 12 "github.com/slackhq/nebula/config" 13 "github.com/slackhq/nebula/firewall" 14 "github.com/slackhq/nebula/iputil" 15 "github.com/slackhq/nebula/test" 16 "github.com/stretchr/testify/assert" 17 ) 18 19 func TestNewFirewall(t *testing.T) { 20 l := test.NewLogger() 21 c := &cert.NebulaCertificate{} 22 fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) 23 conntrack := fw.Conntrack 24 assert.NotNil(t, conntrack) 25 assert.NotNil(t, conntrack.Conns) 26 assert.NotNil(t, conntrack.TimerWheel) 27 assert.NotNil(t, fw.InRules) 28 assert.NotNil(t, fw.OutRules) 29 assert.Equal(t, time.Second, fw.TCPTimeout) 30 assert.Equal(t, time.Minute, fw.UDPTimeout) 31 assert.Equal(t, time.Hour, fw.DefaultTimeout) 32 33 assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) 34 assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) 35 assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) 36 37 fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c) 38 assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) 39 assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) 40 41 fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c) 42 assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) 43 assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) 44 45 fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c) 46 assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) 47 assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) 48 49 fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c) 50 assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) 51 assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) 52 53 fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c) 54 assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) 55 assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) 56 } 57 58 func TestFirewall_AddRule(t *testing.T) { 59 l := test.NewLogger() 60 ob := &bytes.Buffer{} 61 l.SetOutput(ob) 62 63 c := &cert.NebulaCertificate{} 64 fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) 65 assert.NotNil(t, fw.InRules) 66 assert.NotNil(t, fw.OutRules) 67 68 _, ti, _ := net.ParseCIDR("1.2.3.4/32") 69 70 assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", "")) 71 // An empty rule is any 72 assert.True(t, fw.InRules.TCP[1].Any.Any.Any) 73 assert.Empty(t, fw.InRules.TCP[1].Any.Groups) 74 assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) 75 76 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) 77 assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "")) 78 assert.Nil(t, fw.InRules.UDP[1].Any.Any) 79 assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") 80 assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) 81 82 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) 83 assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", "")) 84 assert.Nil(t, fw.InRules.ICMP[1].Any.Any) 85 assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) 86 assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") 87 88 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) 89 assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", "")) 90 assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) 91 ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti) 92 assert.True(t, ok) 93 94 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) 95 assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) 96 assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) 97 ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti) 98 assert.True(t, ok) 99 100 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) 101 assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", "")) 102 assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") 103 104 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) 105 assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha")) 106 assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") 107 108 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) 109 assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) 110 assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) 111 112 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) 113 _, anyIp, _ := net.ParseCIDR("0.0.0.0/0") 114 assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", "")) 115 assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) 116 117 // Test error conditions 118 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) 119 assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", "")) 120 assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", "")) 121 } 122 123 func TestFirewall_Drop(t *testing.T) { 124 l := test.NewLogger() 125 ob := &bytes.Buffer{} 126 l.SetOutput(ob) 127 128 p := firewall.Packet{ 129 LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), 130 RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), 131 LocalPort: 10, 132 RemotePort: 90, 133 Protocol: firewall.ProtoUDP, 134 Fragment: false, 135 } 136 137 ipNet := net.IPNet{ 138 IP: net.IPv4(1, 2, 3, 4), 139 Mask: net.IPMask{255, 255, 255, 0}, 140 } 141 142 c := cert.NebulaCertificate{ 143 Details: cert.NebulaCertificateDetails{ 144 Name: "host1", 145 Ips: []*net.IPNet{&ipNet}, 146 Groups: []string{"default-group"}, 147 InvertedGroups: map[string]struct{}{"default-group": {}}, 148 Issuer: "signer-shasum", 149 }, 150 } 151 h := HostInfo{ 152 ConnectionState: &ConnectionState{ 153 peerCert: &c, 154 }, 155 vpnIp: iputil.Ip2VpnIp(ipNet.IP), 156 } 157 h.CreateRemoteCIDR(&c) 158 159 fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) 160 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) 161 cp := cert.NewCAPool() 162 163 // Drop outbound 164 assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) 165 // Allow inbound 166 resetConntrack(fw) 167 assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) 168 // Allow outbound because conntrack 169 assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) 170 171 // test remote mismatch 172 oldRemote := p.RemoteIP 173 p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10)) 174 assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) 175 p.RemoteIP = oldRemote 176 177 // ensure signer doesn't get in the way of group checks 178 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) 179 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum")) 180 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad")) 181 assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) 182 183 // test caSha doesn't drop on match 184 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) 185 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad")) 186 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum")) 187 assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) 188 189 // ensure ca name doesn't get in the way of group checks 190 cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} 191 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) 192 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", "")) 193 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", "")) 194 assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) 195 196 // test caName doesn't drop on match 197 cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} 198 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) 199 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", "")) 200 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", "")) 201 assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) 202 } 203 204 func BenchmarkFirewallTable_match(b *testing.B) { 205 f := &Firewall{} 206 ft := FirewallTable{ 207 TCP: firewallPort{}, 208 } 209 210 _, n, _ := net.ParseCIDR("172.1.1.1/32") 211 goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP) 212 _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "") 213 _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "") 214 cp := cert.NewCAPool() 215 216 b.Run("fail on proto", func(b *testing.B) { 217 // This benchmark is showing us the cost of failing to match the protocol 218 c := &cert.NebulaCertificate{} 219 for n := 0; n < b.N; n++ { 220 assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)) 221 } 222 }) 223 224 b.Run("pass proto, fail on port", func(b *testing.B) { 225 // This benchmark is showing us the cost of matching a specific protocol but failing to match the port 226 c := &cert.NebulaCertificate{} 227 for n := 0; n < b.N; n++ { 228 assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)) 229 } 230 }) 231 232 b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) { 233 c := &cert.NebulaCertificate{} 234 ip, _, _ := net.ParseCIDR("9.254.254.254/32") 235 lip := iputil.Ip2VpnIp(ip) 236 for n := 0; n < b.N; n++ { 237 assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp)) 238 } 239 }) 240 241 b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) { 242 _, ip, _ := net.ParseCIDR("9.254.254.254/32") 243 c := &cert.NebulaCertificate{ 244 Details: cert.NebulaCertificateDetails{ 245 InvertedGroups: map[string]struct{}{"nope": {}}, 246 Name: "nope", 247 Ips: []*net.IPNet{ip}, 248 }, 249 } 250 for n := 0; n < b.N; n++ { 251 assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) 252 } 253 }) 254 255 b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) { 256 _, ip, _ := net.ParseCIDR("9.254.254.254/32") 257 c := &cert.NebulaCertificate{ 258 Details: cert.NebulaCertificateDetails{ 259 InvertedGroups: map[string]struct{}{"nope": {}}, 260 Name: "nope", 261 Ips: []*net.IPNet{ip}, 262 }, 263 } 264 for n := 0; n < b.N; n++ { 265 assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) 266 } 267 }) 268 269 b.Run("pass on group on any local cidr", func(b *testing.B) { 270 c := &cert.NebulaCertificate{ 271 Details: cert.NebulaCertificateDetails{ 272 InvertedGroups: map[string]struct{}{"good-group": {}}, 273 Name: "nope", 274 }, 275 } 276 for n := 0; n < b.N; n++ { 277 assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) 278 } 279 }) 280 281 b.Run("pass on group on specific local cidr", func(b *testing.B) { 282 c := &cert.NebulaCertificate{ 283 Details: cert.NebulaCertificateDetails{ 284 InvertedGroups: map[string]struct{}{"good-group": {}}, 285 Name: "nope", 286 }, 287 } 288 for n := 0; n < b.N; n++ { 289 assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) 290 } 291 }) 292 293 b.Run("pass on name", func(b *testing.B) { 294 c := &cert.NebulaCertificate{ 295 Details: cert.NebulaCertificateDetails{ 296 InvertedGroups: map[string]struct{}{"nope": {}}, 297 Name: "good-host", 298 }, 299 } 300 for n := 0; n < b.N; n++ { 301 ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) 302 } 303 }) 304 // 305 //b.Run("pass on ip", func(b *testing.B) { 306 // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) 307 // c := &cert.NebulaCertificate{ 308 // Details: cert.NebulaCertificateDetails{ 309 // InvertedGroups: map[string]struct{}{"nope": {}}, 310 // Name: "good-host", 311 // }, 312 // } 313 // for n := 0; n < b.N; n++ { 314 // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp) 315 // } 316 //}) 317 // 318 //b.Run("pass on local ip", func(b *testing.B) { 319 // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) 320 // c := &cert.NebulaCertificate{ 321 // Details: cert.NebulaCertificateDetails{ 322 // InvertedGroups: map[string]struct{}{"nope": {}}, 323 // Name: "good-host", 324 // }, 325 // } 326 // for n := 0; n < b.N; n++ { 327 // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp) 328 // } 329 //}) 330 // 331 //_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "") 332 // 333 //b.Run("pass on ip with any port", func(b *testing.B) { 334 // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) 335 // c := &cert.NebulaCertificate{ 336 // Details: cert.NebulaCertificateDetails{ 337 // InvertedGroups: map[string]struct{}{"nope": {}}, 338 // Name: "good-host", 339 // }, 340 // } 341 // for n := 0; n < b.N; n++ { 342 // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp) 343 // } 344 //}) 345 // 346 //b.Run("pass on local ip with any port", func(b *testing.B) { 347 // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) 348 // c := &cert.NebulaCertificate{ 349 // Details: cert.NebulaCertificateDetails{ 350 // InvertedGroups: map[string]struct{}{"nope": {}}, 351 // Name: "good-host", 352 // }, 353 // } 354 // for n := 0; n < b.N; n++ { 355 // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp) 356 // } 357 //}) 358 } 359 360 func TestFirewall_Drop2(t *testing.T) { 361 l := test.NewLogger() 362 ob := &bytes.Buffer{} 363 l.SetOutput(ob) 364 365 p := firewall.Packet{ 366 LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), 367 RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), 368 LocalPort: 10, 369 RemotePort: 90, 370 Protocol: firewall.ProtoUDP, 371 Fragment: false, 372 } 373 374 ipNet := net.IPNet{ 375 IP: net.IPv4(1, 2, 3, 4), 376 Mask: net.IPMask{255, 255, 255, 0}, 377 } 378 379 c := cert.NebulaCertificate{ 380 Details: cert.NebulaCertificateDetails{ 381 Name: "host1", 382 Ips: []*net.IPNet{&ipNet}, 383 InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}}, 384 }, 385 } 386 h := HostInfo{ 387 ConnectionState: &ConnectionState{ 388 peerCert: &c, 389 }, 390 vpnIp: iputil.Ip2VpnIp(ipNet.IP), 391 } 392 h.CreateRemoteCIDR(&c) 393 394 c1 := cert.NebulaCertificate{ 395 Details: cert.NebulaCertificateDetails{ 396 Name: "host1", 397 Ips: []*net.IPNet{&ipNet}, 398 InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}}, 399 }, 400 } 401 h1 := HostInfo{ 402 ConnectionState: &ConnectionState{ 403 peerCert: &c1, 404 }, 405 } 406 h1.CreateRemoteCIDR(&c1) 407 408 fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) 409 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", "")) 410 cp := cert.NewCAPool() 411 412 // h1/c1 lacks the proper groups 413 assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) 414 // c has the proper groups 415 resetConntrack(fw) 416 assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) 417 } 418 419 func TestFirewall_Drop3(t *testing.T) { 420 l := test.NewLogger() 421 ob := &bytes.Buffer{} 422 l.SetOutput(ob) 423 424 p := firewall.Packet{ 425 LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), 426 RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), 427 LocalPort: 1, 428 RemotePort: 1, 429 Protocol: firewall.ProtoUDP, 430 Fragment: false, 431 } 432 433 ipNet := net.IPNet{ 434 IP: net.IPv4(1, 2, 3, 4), 435 Mask: net.IPMask{255, 255, 255, 0}, 436 } 437 438 c := cert.NebulaCertificate{ 439 Details: cert.NebulaCertificateDetails{ 440 Name: "host-owner", 441 Ips: []*net.IPNet{&ipNet}, 442 }, 443 } 444 445 c1 := cert.NebulaCertificate{ 446 Details: cert.NebulaCertificateDetails{ 447 Name: "host1", 448 Ips: []*net.IPNet{&ipNet}, 449 Issuer: "signer-sha-bad", 450 }, 451 } 452 h1 := HostInfo{ 453 ConnectionState: &ConnectionState{ 454 peerCert: &c1, 455 }, 456 vpnIp: iputil.Ip2VpnIp(ipNet.IP), 457 } 458 h1.CreateRemoteCIDR(&c1) 459 460 c2 := cert.NebulaCertificate{ 461 Details: cert.NebulaCertificateDetails{ 462 Name: "host2", 463 Ips: []*net.IPNet{&ipNet}, 464 Issuer: "signer-sha", 465 }, 466 } 467 h2 := HostInfo{ 468 ConnectionState: &ConnectionState{ 469 peerCert: &c2, 470 }, 471 vpnIp: iputil.Ip2VpnIp(ipNet.IP), 472 } 473 h2.CreateRemoteCIDR(&c2) 474 475 c3 := cert.NebulaCertificate{ 476 Details: cert.NebulaCertificateDetails{ 477 Name: "host3", 478 Ips: []*net.IPNet{&ipNet}, 479 Issuer: "signer-sha-bad", 480 }, 481 } 482 h3 := HostInfo{ 483 ConnectionState: &ConnectionState{ 484 peerCert: &c3, 485 }, 486 vpnIp: iputil.Ip2VpnIp(ipNet.IP), 487 } 488 h3.CreateRemoteCIDR(&c3) 489 490 fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) 491 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", "")) 492 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha")) 493 cp := cert.NewCAPool() 494 495 // c1 should pass because host match 496 assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) 497 // c2 should pass because ca sha match 498 resetConntrack(fw) 499 assert.NoError(t, fw.Drop(p, true, &h2, cp, nil)) 500 // c3 should fail because no match 501 resetConntrack(fw) 502 assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) 503 } 504 505 func TestFirewall_DropConntrackReload(t *testing.T) { 506 l := test.NewLogger() 507 ob := &bytes.Buffer{} 508 l.SetOutput(ob) 509 510 p := firewall.Packet{ 511 LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), 512 RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), 513 LocalPort: 10, 514 RemotePort: 90, 515 Protocol: firewall.ProtoUDP, 516 Fragment: false, 517 } 518 519 ipNet := net.IPNet{ 520 IP: net.IPv4(1, 2, 3, 4), 521 Mask: net.IPMask{255, 255, 255, 0}, 522 } 523 524 c := cert.NebulaCertificate{ 525 Details: cert.NebulaCertificateDetails{ 526 Name: "host1", 527 Ips: []*net.IPNet{&ipNet}, 528 Groups: []string{"default-group"}, 529 InvertedGroups: map[string]struct{}{"default-group": {}}, 530 Issuer: "signer-shasum", 531 }, 532 } 533 h := HostInfo{ 534 ConnectionState: &ConnectionState{ 535 peerCert: &c, 536 }, 537 vpnIp: iputil.Ip2VpnIp(ipNet.IP), 538 } 539 h.CreateRemoteCIDR(&c) 540 541 fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) 542 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) 543 cp := cert.NewCAPool() 544 545 // Drop outbound 546 assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) 547 // Allow inbound 548 resetConntrack(fw) 549 assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) 550 // Allow outbound because conntrack 551 assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) 552 553 oldFw := fw 554 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) 555 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", "")) 556 fw.Conntrack = oldFw.Conntrack 557 fw.rulesVersion = oldFw.rulesVersion + 1 558 559 // Allow outbound because conntrack and new rules allow port 10 560 assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) 561 562 oldFw = fw 563 fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) 564 assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", "")) 565 fw.Conntrack = oldFw.Conntrack 566 fw.rulesVersion = oldFw.rulesVersion + 1 567 568 // Drop outbound because conntrack doesn't match new ruleset 569 assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) 570 } 571 572 func BenchmarkLookup(b *testing.B) { 573 ml := func(m map[string]struct{}, a [][]string) { 574 for n := 0; n < b.N; n++ { 575 for _, sg := range a { 576 found := false 577 578 for _, g := range sg { 579 if _, ok := m[g]; !ok { 580 found = false 581 break 582 } 583 584 found = true 585 } 586 587 if found { 588 return 589 } 590 } 591 } 592 } 593 594 b.Run("array to map best", func(b *testing.B) { 595 m := map[string]struct{}{ 596 "1ne": {}, 597 "2wo": {}, 598 "3hr": {}, 599 "4ou": {}, 600 "5iv": {}, 601 "6ix": {}, 602 } 603 604 a := [][]string{ 605 {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"}, 606 {"one", "2wo", "3hr", "4ou", "5iv", "6ix"}, 607 {"one", "two", "3hr", "4ou", "5iv", "6ix"}, 608 {"one", "two", "thr", "4ou", "5iv", "6ix"}, 609 {"one", "two", "thr", "fou", "5iv", "6ix"}, 610 {"one", "two", "thr", "fou", "fiv", "6ix"}, 611 {"one", "two", "thr", "fou", "fiv", "six"}, 612 } 613 614 for n := 0; n < b.N; n++ { 615 ml(m, a) 616 } 617 }) 618 619 b.Run("array to map worst", func(b *testing.B) { 620 m := map[string]struct{}{ 621 "one": {}, 622 "two": {}, 623 "thr": {}, 624 "fou": {}, 625 "fiv": {}, 626 "six": {}, 627 } 628 629 a := [][]string{ 630 {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"}, 631 {"one", "2wo", "3hr", "4ou", "5iv", "6ix"}, 632 {"one", "two", "3hr", "4ou", "5iv", "6ix"}, 633 {"one", "two", "thr", "4ou", "5iv", "6ix"}, 634 {"one", "two", "thr", "fou", "5iv", "6ix"}, 635 {"one", "two", "thr", "fou", "fiv", "6ix"}, 636 {"one", "two", "thr", "fou", "fiv", "six"}, 637 } 638 639 for n := 0; n < b.N; n++ { 640 ml(m, a) 641 } 642 }) 643 644 //TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster 645 } 646 647 func Test_parsePort(t *testing.T) { 648 _, _, err := parsePort("") 649 assert.EqualError(t, err, "was not a number; ``") 650 651 _, _, err = parsePort(" ") 652 assert.EqualError(t, err, "was not a number; ` `") 653 654 _, _, err = parsePort("-") 655 assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`") 656 657 _, _, err = parsePort(" - ") 658 assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `") 659 660 _, _, err = parsePort("a-b") 661 assert.EqualError(t, err, "beginning range was not a number; `a`") 662 663 _, _, err = parsePort("1-b") 664 assert.EqualError(t, err, "ending range was not a number; `b`") 665 666 s, e, err := parsePort(" 1 - 2 ") 667 assert.Equal(t, int32(1), s) 668 assert.Equal(t, int32(2), e) 669 assert.Nil(t, err) 670 671 s, e, err = parsePort("0-1") 672 assert.Equal(t, int32(0), s) 673 assert.Equal(t, int32(0), e) 674 assert.Nil(t, err) 675 676 s, e, err = parsePort("9919") 677 assert.Equal(t, int32(9919), s) 678 assert.Equal(t, int32(9919), e) 679 assert.Nil(t, err) 680 681 s, e, err = parsePort("any") 682 assert.Equal(t, int32(0), s) 683 assert.Equal(t, int32(0), e) 684 assert.Nil(t, err) 685 } 686 687 func TestNewFirewallFromConfig(t *testing.T) { 688 l := test.NewLogger() 689 // Test a bad rule definition 690 c := &cert.NebulaCertificate{} 691 conf := config.NewC(l) 692 conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} 693 _, err := NewFirewallFromConfig(l, c, conf) 694 assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") 695 696 // Test both port and code 697 conf = config.NewC(l) 698 conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} 699 _, err = NewFirewallFromConfig(l, c, conf) 700 assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") 701 702 // Test missing host, group, cidr, ca_name and ca_sha 703 conf = config.NewC(l) 704 conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} 705 _, err = NewFirewallFromConfig(l, c, conf) 706 assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") 707 708 // Test code/port error 709 conf = config.NewC(l) 710 conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} 711 _, err = NewFirewallFromConfig(l, c, conf) 712 assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") 713 714 conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} 715 _, err = NewFirewallFromConfig(l, c, conf) 716 assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") 717 718 // Test proto error 719 conf = config.NewC(l) 720 conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} 721 _, err = NewFirewallFromConfig(l, c, conf) 722 assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") 723 724 // Test cidr parse error 725 conf = config.NewC(l) 726 conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} 727 _, err = NewFirewallFromConfig(l, c, conf) 728 assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") 729 730 // Test local_cidr parse error 731 conf = config.NewC(l) 732 conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} 733 _, err = NewFirewallFromConfig(l, c, conf) 734 assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh") 735 736 // Test both group and groups 737 conf = config.NewC(l) 738 conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} 739 _, err = NewFirewallFromConfig(l, c, conf) 740 assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") 741 } 742 743 func TestAddFirewallRulesFromConfig(t *testing.T) { 744 l := test.NewLogger() 745 // Test adding tcp rule 746 conf := config.NewC(l) 747 mf := &mockFirewall{} 748 conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} 749 assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) 750 assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) 751 752 // Test adding udp rule 753 conf = config.NewC(l) 754 mf = &mockFirewall{} 755 conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} 756 assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) 757 assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) 758 759 // Test adding icmp rule 760 conf = config.NewC(l) 761 mf = &mockFirewall{} 762 conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} 763 assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) 764 assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) 765 766 // Test adding any rule 767 conf = config.NewC(l) 768 mf = &mockFirewall{} 769 conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} 770 assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) 771 assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) 772 773 // Test adding rule with cidr 774 cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)} 775 conf = config.NewC(l) 776 mf = &mockFirewall{} 777 conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} 778 assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) 779 assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall) 780 781 // Test adding rule with local_cidr 782 conf = config.NewC(l) 783 mf = &mockFirewall{} 784 conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} 785 assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) 786 assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall) 787 788 // Test adding rule with ca_sha 789 conf = config.NewC(l) 790 mf = &mockFirewall{} 791 conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} 792 assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) 793 assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall) 794 795 // Test adding rule with ca_name 796 conf = config.NewC(l) 797 mf = &mockFirewall{} 798 conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} 799 assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) 800 assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall) 801 802 // Test single group 803 conf = config.NewC(l) 804 mf = &mockFirewall{} 805 conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} 806 assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) 807 assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) 808 809 // Test single groups 810 conf = config.NewC(l) 811 mf = &mockFirewall{} 812 conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} 813 assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) 814 assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) 815 816 // Test multiple AND groups 817 conf = config.NewC(l) 818 mf = &mockFirewall{} 819 conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} 820 assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) 821 assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall) 822 823 // Test Add error 824 conf = config.NewC(l) 825 mf = &mockFirewall{} 826 mf.nextCallReturn = errors.New("test error") 827 conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} 828 assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") 829 } 830 831 func TestFirewall_convertRule(t *testing.T) { 832 l := test.NewLogger() 833 ob := &bytes.Buffer{} 834 l.SetOutput(ob) 835 836 // Ensure group array of 1 is converted and a warning is printed 837 c := map[interface{}]interface{}{ 838 "group": []interface{}{"group1"}, 839 } 840 841 r, err := convertRule(l, c, "test", 1) 842 assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") 843 assert.Nil(t, err) 844 assert.Equal(t, "group1", r.Group) 845 846 // Ensure group array of > 1 is errord 847 ob.Reset() 848 c = map[interface{}]interface{}{ 849 "group": []interface{}{"group1", "group2"}, 850 } 851 852 r, err = convertRule(l, c, "test", 1) 853 assert.Equal(t, "", ob.String()) 854 assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided") 855 856 // Make sure a well formed group is alright 857 ob.Reset() 858 c = map[interface{}]interface{}{ 859 "group": "group1", 860 } 861 862 r, err = convertRule(l, c, "test", 1) 863 assert.Nil(t, err) 864 assert.Equal(t, "group1", r.Group) 865 } 866 867 type addRuleCall struct { 868 incoming bool 869 proto uint8 870 startPort int32 871 endPort int32 872 groups []string 873 host string 874 ip *net.IPNet 875 localIp *net.IPNet 876 caName string 877 caSha string 878 } 879 880 type mockFirewall struct { 881 lastCall addRuleCall 882 nextCallReturn error 883 } 884 885 func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { 886 mf.lastCall = addRuleCall{ 887 incoming: incoming, 888 proto: proto, 889 startPort: startPort, 890 endPort: endPort, 891 groups: groups, 892 host: host, 893 ip: ip, 894 localIp: localIp, 895 caName: caName, 896 caSha: caSha, 897 } 898 899 err := mf.nextCallReturn 900 mf.nextCallReturn = nil 901 return err 902 } 903 904 func resetConntrack(fw *Firewall) { 905 fw.Conntrack.Lock() 906 fw.Conntrack.Conns = map[firewall.Packet]*conn{} 907 fw.Conntrack.Unlock() 908 }