github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/nfqdatapath/datapath_test.go (about) 1 // +build linux 2 3 package nfqdatapath 4 5 import ( 6 "context" 7 "crypto/ecdsa" 8 "encoding/binary" 9 "errors" 10 "fmt" 11 "math/rand" 12 "net" 13 "reflect" 14 "strconv" 15 "testing" 16 "time" 17 18 "github.com/golang/mock/gomock" 19 . "github.com/smartystreets/goconvey/convey" 20 "go.aporeto.io/enforcerd/trireme-lib/collector" 21 "go.aporeto.io/enforcerd/trireme-lib/common" 22 "go.aporeto.io/enforcerd/trireme-lib/controller/constants" 23 enforcerconstants "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/constants" 24 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/packetgen" 25 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/connection" 26 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/counters" 27 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/flowtracking/mockflowclient" 28 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/packet" 29 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/packettracing" 30 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pucontext" 31 "go.aporeto.io/enforcerd/trireme-lib/policy" 32 "go.aporeto.io/enforcerd/trireme-lib/utils/portspec" 33 "gotest.tools/assert" 34 ) 35 36 func TestEnforcerExternalNetworks(t *testing.T) { 37 38 ctrl := gomock.NewController(t) 39 defer ctrl.Finish() 40 41 testThePackets := func(enforcer *Datapath) { 42 43 PacketFlow := packetgen.NewTemplateFlow() 44 45 _, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate) 46 So(err, ShouldBeNil) 47 48 synackPacket, err := PacketFlow.GetFirstSynAckPacket().ToBytes() 49 So(err, ShouldBeNil) 50 51 tcpPacket, _ := packet.New(0, synackPacket, "0", true) 52 _, err1 := enforcer.processApplicationTCPPackets(tcpPacket) 53 So(err1, ShouldBeNil) 54 55 } 56 57 Convey("When the mode is RemoteConainter", t, func() { 58 59 enforcer, secrets, mockTokenAccessor, _, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true) 60 61 secrets.EXPECT().TransmittedKey().Return([]byte("dummy")).AnyTimes() 62 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 63 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 64 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 65 66 iprules := policy.IPRuleList{policy.IPRule{ 67 Addresses: []string{"10.1.10.76/32"}, 68 Ports: []string{"80"}, 69 Protocols: []string{constants.TCPProtoNum}, 70 Policy: &policy.FlowPolicy{ 71 Action: policy.Accept, 72 PolicyID: "tcp172/8"}, 73 }} 74 75 contextID := "123456" 76 puInfo := policy.NewPUInfo(contextID, "/ns1", common.LinuxProcessPU) 77 78 context, err := pucontext.NewPU(contextID, puInfo, mockTokenAccessor, 10*time.Second) 79 So(err, ShouldBeNil) 80 enforcer.puFromContextID.AddOrUpdate(contextID, context) 81 s, _ := portspec.NewPortSpec(80, 80, contextID) 82 enforcer.contextIDFromTCPPort.AddPortSpec(s) 83 84 err = context.UpdateNetworkACLs(iprules) 85 So(err, ShouldBeNil) 86 87 testThePackets(enforcer) 88 }) 89 } 90 91 func TestInvalidContext(t *testing.T) { 92 93 ctrl := gomock.NewController(t) 94 defer ctrl.Finish() 95 96 defer MockGetUDPRawSocket()() 97 98 Convey("Given I create a new enforcer instance", t, func() { 99 100 enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID", constants.LocalServer, []string{"0.0.0.0/0"}, true) 101 102 PacketFlow := packetgen.NewTemplateFlow() 103 _, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate) 104 So(err, ShouldBeNil) 105 synPacket, err := PacketFlow.GetFirstSynPacket().ToBytes() 106 So(err, ShouldBeNil) 107 tcpPacket, err := packet.New(0, synPacket, "0", true) 108 Convey("When I run a TCP Syn packet through a non existing context", func() { 109 110 _, err1 := enforcer.processApplicationTCPPackets(tcpPacket) 111 _, _, err2 := enforcer.processNetworkTCPPackets(tcpPacket) 112 113 Convey("Then I should see an error for non existing context", func() { 114 115 So(err, ShouldBeNil) 116 So(err1, ShouldNotBeNil) 117 So(err2, ShouldNotBeNil) 118 }) 119 }) 120 }) 121 } 122 123 func TestPacketHandlingFirstThreePacketsHavePayload(t *testing.T) { 124 125 ctrl := gomock.NewController(t) 126 defer ctrl.Finish() 127 128 testThePackets := func(enforcer *Datapath) { 129 SIP := net.IPv4zero 130 firstSynAckProcessed := false 131 PacketFlow := packetgen.NewTemplateFlow() 132 _, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate) 133 So(err, ShouldBeNil) 134 for i := 0; i < PacketFlow.GetNumPackets(); i++ { 135 oldPacketFromFlow, err := PacketFlow.GetNthPacket(i).ToBytes() 136 So(err, ShouldBeNil) 137 oldPacket, err := packet.New(0, oldPacketFromFlow, "0", true) 138 if err == nil && oldPacket != nil { 139 oldPacket.UpdateIPv4Checksum() 140 oldPacket.UpdateTCPChecksum() 141 } 142 tcpPacketFromFlow, err := PacketFlow.GetNthPacket(i).ToBytes() 143 So(err, ShouldBeNil) 144 tcpPacket, err := packet.New(0, tcpPacketFromFlow, "0", true) 145 if err == nil && tcpPacket != nil { 146 tcpPacket.UpdateIPv4Checksum() 147 tcpPacket.UpdateTCPChecksum() 148 } 149 if debug { 150 fmt.Println("Input packet", i) 151 tcpPacket.Print(0, false) 152 } 153 154 So(err, ShouldBeNil) 155 So(tcpPacket, ShouldNotBeNil) 156 157 if reflect.DeepEqual(SIP, net.IPv4zero) { 158 SIP = tcpPacket.SourceAddress() 159 } 160 if !reflect.DeepEqual(SIP, tcpPacket.DestinationAddress()) && 161 !reflect.DeepEqual(SIP, tcpPacket.SourceAddress()) { 162 t.Error("Invalid Test Packet") 163 } 164 165 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 166 So(err, ShouldBeNil) 167 168 if debug { 169 fmt.Println("Intermediate packet", i) 170 tcpPacket.Print(0, false) 171 } 172 173 if tcpPacket.GetTCPFlags()&packet.TCPSynMask != 0 { 174 Convey("When I pass a packet with SYN or SYN/ACK flags for packet "+strconv.Itoa(i), func() { 175 Convey("Then I expect some data payload to exist on the packet "+strconv.Itoa(i), func() { 176 // In our 3 way security handshake syn and syn-ack packet should grow in length 177 So(tcpPacket.IPTotalLen(), ShouldBeGreaterThan, oldPacket.IPTotalLen()) 178 }) 179 }) 180 } 181 182 if !firstSynAckProcessed && tcpPacket.GetTCPFlags()&packet.TCPSynAckMask == packet.TCPAckMask { 183 firstSynAckProcessed = true 184 Convey("When I pass the first packet with ACK flag for packet "+strconv.Itoa(i), func() { 185 Convey("Then I expect some data payload to exist on the packet "+strconv.Itoa(i), func() { 186 // In our 3 way security handshake first ack packet should grow in length 187 So(tcpPacket.IPTotalLen(), ShouldBeGreaterThan, oldPacket.IPTotalLen()) 188 }) 189 }) 190 } 191 192 output := make([]byte, len(tcpPacket.GetTCPBytes())) 193 copy(output, tcpPacket.GetTCPBytes()) 194 195 outPacket, errp := packet.New(0, output, "0", true) 196 So(len(tcpPacket.GetTCPBytes()), ShouldBeLessThanOrEqualTo, len(outPacket.GetTCPBytes())) 197 So(errp, ShouldBeNil) 198 199 _, f, err := enforcer.processNetworkTCPPackets(outPacket) 200 if f != nil { 201 f() 202 } 203 204 So(err, ShouldBeNil) 205 206 if debug { 207 fmt.Println("Output packet", i) 208 outPacket.Print(0, false) 209 } 210 } 211 } 212 213 flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Accept, "") 214 215 Convey("When the mode is RemoteConainter", t, func() { 216 217 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer) 218 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 219 testThePackets(enforcer) 220 221 }) 222 223 Convey("When the mode is LocalServer", t, func() { 224 225 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer) 226 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 227 testThePackets(enforcer) 228 229 }) 230 } 231 232 func TestInvalidIPContext(t *testing.T) { 233 234 ctrl := gomock.NewController(t) 235 defer ctrl.Finish() 236 237 defer MockGetUDPRawSocket()() 238 239 Convey("Given I create a new enforcer instance", t, func() { 240 241 enforcer, secrets, mockTokenAccessor, mockCollector, mockDNS := NewWithMocks(ctrl, "serverID", constants.LocalServer, []string{"0.0.0.0/0"}, true) 242 243 secrets.EXPECT().TransmittedKey().Return([]byte("dummy")).AnyTimes() 244 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 245 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 246 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 247 mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).Times(1) 248 mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).Times(1) 249 mockDNS.EXPECT().Unenforce(gomock.Any(), gomock.Any()).Times(1) 250 mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(1) 251 252 puInfo := policy.NewPUInfo("SomeProcessingUnitId", "/ns2", common.LinuxProcessPU) 253 254 CounterReport := &collector.CounterReport{ 255 PUID: puInfo.Policy.ManagementID(), 256 Namespace: puInfo.Policy.ManagementNamespace(), 257 } 258 mockCollector.EXPECT().CollectCounterEvent(MyCounterMatcher(CounterReport)).MinTimes(1) 259 260 enforcer.Enforce(context.Background(), "serverID", puInfo) // nolint 261 defer func() { 262 if err := enforcer.Unenforce(context.Background(), "serverID"); err != nil { 263 fmt.Println("Error", err.Error()) 264 } 265 }() 266 267 PacketFlow := packetgen.NewTemplateFlow() 268 _, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeMultipleGoodFlow) 269 So(err, ShouldBeNil) 270 synPacket, err := PacketFlow.GetFirstSynPacket().ToBytes() 271 So(err, ShouldBeNil) 272 tcpPacket, err := packet.New(0, synPacket, "0", true) 273 274 Convey("When I run a TCP Syn packet through an invalid existing context (missing IP)", func() { 275 276 _, err1 := enforcer.processApplicationTCPPackets(tcpPacket) 277 _, _, err2 := enforcer.processNetworkTCPPackets(tcpPacket) 278 279 Convey("Then I should see an error for missing IP", func() { 280 281 So(err, ShouldBeNil) 282 So(err1, ShouldNotBeNil) 283 So(err2, ShouldNotBeNil) 284 }) 285 }) 286 }) 287 } 288 289 // TestEnforcerConnUnknownState test ensures that enforcer closes the 290 // connection by converting packets to rst when it finds connection 291 // to be in unknown state. This happens when enforcer has not seen the 292 // 3way handshake for a connection. 293 func TestEnforcerConnUnknownState(t *testing.T) { 294 295 ctrl := gomock.NewController(t) 296 defer ctrl.Finish() 297 298 testThePackets := func(enforcer *Datapath) { 299 Convey("If I send an ack packet from either PU to the other, it is converted into a Fin/Ack", func() { 300 PacketFlow := packetgen.NewTemplateFlow() 301 _, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate) 302 So(err, ShouldBeNil) 303 304 input, err := PacketFlow.GetFirstAckPacket().ToBytes() 305 So(err, ShouldBeNil) 306 307 tcpPacket, err := packet.New(0, input, "0", true) 308 // create a copy of the ack packet 309 tcpPacketCopy := *tcpPacket 310 311 if err == nil && tcpPacket != nil { 312 tcpPacket.UpdateIPv4Checksum() 313 tcpPacket.UpdateTCPChecksum() 314 } 315 316 _, err1 := enforcer.processApplicationTCPPackets(tcpPacket) 317 318 // Test whether the packet is modified with Fin/Ack 319 if tcpPacket.GetTCPFlags() != 0x04 { 320 t.Fail() 321 } 322 323 _, _, err2 := enforcer.processNetworkTCPPackets(&tcpPacketCopy) 324 325 if tcpPacket.GetTCPFlags() != 0x04 { 326 t.Fail() 327 } 328 329 So(err1, ShouldBeNil) 330 So(err2, ShouldBeNil) 331 }) 332 } 333 334 Convey("When the mode is RemoteConainter", t, func() { 335 336 enforcer, _ := createEnforcerWithPolicy(ctrl, constants.RemoteContainer) 337 testThePackets(enforcer) 338 339 }) 340 341 Convey("When the mode is LocalServer", t, func() { 342 343 enforcer, _ := createEnforcerWithPolicy(ctrl, constants.LocalServer) 344 testThePackets(enforcer) 345 346 }) 347 } 348 349 func TestInvalidTokenContext(t *testing.T) { 350 351 ctrl := gomock.NewController(t) 352 defer ctrl.Finish() 353 354 defer MockGetUDPRawSocket()() 355 356 testThePackets := func(enforcer *Datapath) { 357 358 PacketFlow := packetgen.NewTemplateFlow() 359 _, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate) 360 So(err, ShouldBeNil) 361 synPacket, err := PacketFlow.GetFirstSynPacket().ToBytes() 362 So(err, ShouldBeNil) 363 tcpPacket, err := packet.New(0, synPacket, "0", true) 364 365 Convey("When I run a TCP Syn packet through an invalid existing context (missing IP)", func() { 366 367 _, err1 := enforcer.processApplicationTCPPackets(tcpPacket) 368 _, _, err2 := enforcer.processNetworkTCPPackets(tcpPacket) 369 370 Convey("Then I should see an error for missing Token", func() { 371 372 So(err, ShouldBeNil) 373 So(err1, ShouldNotBeNil) 374 So(err2, ShouldNotBeNil) 375 }) 376 }) 377 } 378 379 Convey("Given I create a new enforcer instance", t, func() { 380 381 puInfo := policy.NewPUInfo("SomeProcessingUnitId", "/ns2", common.LinuxProcessPU) 382 383 ip := policy.ExtendedMap{ 384 "brige": testDstIP, 385 } 386 puInfo.Runtime.SetIPAddresses(ip) 387 388 enforcer, secrets, mockTokenAccessor, _, mockDNS := NewWithMocks(ctrl, "serverID", constants.LocalServer, []string{"0.0.0.0/0"}, true) 389 390 secrets.EXPECT().TransmittedKey().Return([]byte("dummy")).AnyTimes() 391 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 392 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 393 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 394 mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).Times(1) 395 mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).Times(1) 396 mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(1) 397 398 enforcer.Enforce(context.Background(), "serverID", puInfo) // nolint 399 400 testThePackets(enforcer) 401 }) 402 } 403 404 func TestPacketHandlingDstPortCacheBehavior(t *testing.T) { 405 406 ctrl := gomock.NewController(t) 407 defer ctrl.Finish() 408 409 testThePackets := func(enforcer *Datapath) { 410 411 SIP := net.IPv4zero 412 413 Convey("When I pass multiple packets through the enforcer", func() { 414 415 PacketFlow := packetgen.NewTemplateFlow() 416 _, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate) 417 So(err, ShouldBeNil) 418 for i := 0; i < PacketFlow.GetNumPackets(); i++ { 419 oldPacketFromFlow, err := PacketFlow.GetNthPacket(i).ToBytes() 420 So(err, ShouldBeNil) 421 oldPacket, err := packet.New(0, oldPacketFromFlow, "0", true) 422 if err == nil && oldPacket != nil { 423 oldPacket.UpdateIPv4Checksum() 424 oldPacket.UpdateTCPChecksum() 425 } 426 tcpPacketFromFlow, err := PacketFlow.GetNthPacket(i).ToBytes() 427 So(err, ShouldBeNil) 428 tcpPacket, err := packet.New(0, tcpPacketFromFlow, "0", true) 429 if err == nil && tcpPacket != nil { 430 tcpPacket.UpdateIPv4Checksum() 431 tcpPacket.UpdateTCPChecksum() 432 } 433 434 if debug { 435 fmt.Println("Input packet", i) 436 tcpPacket.Print(0, false) 437 } 438 439 So(err, ShouldBeNil) 440 So(tcpPacket, ShouldNotBeNil) 441 442 if reflect.DeepEqual(SIP, net.IPv4zero) { 443 SIP = tcpPacket.SourceAddress() 444 } 445 if !reflect.DeepEqual(SIP, tcpPacket.DestinationAddress()) && 446 !reflect.DeepEqual(SIP, tcpPacket.SourceAddress()) { 447 t.Error("Invalid Test Packet") 448 } 449 450 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 451 So(err, ShouldBeNil) 452 453 if debug { 454 fmt.Println("Intermediate packet", i) 455 tcpPacket.Print(0, false) 456 } 457 458 output := make([]byte, len(tcpPacket.GetTCPBytes())) 459 copy(output, tcpPacket.GetTCPBytes()) 460 461 outPacket, errp := packet.New(0, output, "0", true) 462 So(len(tcpPacket.GetTCPBytes()), ShouldBeLessThanOrEqualTo, len(outPacket.GetTCPBytes())) 463 So(errp, ShouldBeNil) 464 _, f, err := enforcer.processNetworkTCPPackets(outPacket) 465 if f != nil { 466 f() 467 } 468 469 So(err, ShouldBeNil) 470 471 if debug { 472 fmt.Println("Output packet", i) 473 outPacket.Print(0, false) 474 } 475 } 476 }) 477 } 478 479 flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Accept, "") 480 481 Convey("When the mode is RemoteConainter", t, func() { 482 483 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer) 484 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 485 testThePackets(enforcer) 486 487 }) 488 489 Convey("When the mode is LocalServer", t, func() { 490 491 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer) 492 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 493 testThePackets(enforcer) 494 495 }) 496 } 497 498 func TestAckLost(t *testing.T) { 499 500 ctrl := gomock.NewController(t) 501 defer ctrl.Finish() 502 503 testThePackets := func(enforcer *Datapath) { 504 PacketFlow := packetgen.NewTemplateFlow() 505 _, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate) 506 So(err, ShouldBeNil) 507 508 synPacket, err := PacketFlow.GetFirstSynPacket().ToBytes() 509 So(err, ShouldBeNil) 510 tcpPacket, err := packet.New(0, synPacket, "0", true) 511 if err == nil && tcpPacket != nil { 512 tcpPacket.UpdateIPv4Checksum() 513 tcpPacket.UpdateTCPChecksum() 514 } 515 516 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 517 So(err, ShouldBeNil) 518 519 output := make([]byte, len(tcpPacket.GetTCPBytes())) 520 copy(output, tcpPacket.GetTCPBytes()) 521 522 outPacket, errp := packet.New(0, output, "0", true) 523 So(len(tcpPacket.GetTCPBytes()), ShouldBeLessThanOrEqualTo, len(outPacket.GetTCPBytes())) 524 So(errp, ShouldBeNil) 525 526 _, f, err := enforcer.processNetworkTCPPackets(outPacket) 527 if f != nil { 528 f() 529 } 530 531 So(err, ShouldBeNil) 532 533 input, _ := PacketFlow.GetFirstSynAckPacket().ToBytes() 534 535 tcpPacket, _ = packet.New(0, input, "0", true) 536 if tcpPacket != nil { 537 tcpPacket.UpdateIPv4Checksum() 538 tcpPacket.UpdateTCPChecksum() 539 } 540 541 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 542 So(err, ShouldBeNil) 543 544 output = make([]byte, len(tcpPacket.GetTCPBytes())) 545 copy(output, tcpPacket.GetTCPBytes()) 546 547 outPacket, _ = packet.New(0, output, "0", true) 548 _, f, err = enforcer.processNetworkTCPPackets(outPacket) 549 if f != nil { 550 f() 551 } 552 So(err, ShouldBeNil) 553 554 input, _ = PacketFlow.GetFirstAckPacket().ToBytes() 555 tcpPacket, _ = packet.New(0, input, "0", true) 556 if tcpPacket != nil { 557 tcpPacket.UpdateIPv4Checksum() 558 tcpPacket.UpdateTCPChecksum() 559 } 560 561 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 562 So(err, ShouldBeNil) 563 //simulate drop, and re-transmit packets. 564 565 input, _ = PacketFlow.GetFirstSynAckPacket().ToBytes() 566 567 tcpPacket, _ = packet.New(0, input, "0", true) 568 if tcpPacket != nil { 569 tcpPacket.UpdateIPv4Checksum() 570 tcpPacket.UpdateTCPChecksum() 571 } 572 573 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 574 assert.Equal(t, err, nil, "error should be nil") 575 576 output = make([]byte, len(tcpPacket.GetTCPBytes())) 577 copy(output, tcpPacket.GetTCPBytes()) 578 579 outPacket, _ = packet.New(0, output, "0", true) 580 _, f, err = enforcer.processNetworkTCPPackets(outPacket) 581 if f != nil { 582 f() 583 } 584 assert.Equal(t, err, nil, "error should be nil") 585 586 input, _ = PacketFlow.GetFirstAckPacket().ToBytes() 587 588 tcpPacket, _ = packet.New(0, input, "0", true) 589 if tcpPacket != nil { 590 tcpPacket.UpdateIPv4Checksum() 591 tcpPacket.UpdateTCPChecksum() 592 } 593 594 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 595 assert.Equal(t, err, nil, "error should be nil") 596 597 output = make([]byte, len(tcpPacket.GetTCPBytes())) 598 copy(output, tcpPacket.GetTCPBytes()) 599 600 outPacket, _ = packet.New(0, output, "0", true) 601 602 _, f, err = enforcer.processNetworkTCPPackets(outPacket) 603 if f != nil { 604 f() 605 } 606 607 assert.Equal(t, err, nil, "error should be nil") 608 609 } 610 611 flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Accept, "") 612 613 Convey("When the mode is RemoteConainter", t, func() { 614 615 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer) 616 flowclient := mockflowclient.NewMockFlowClient(ctrl) 617 flowclient.EXPECT().UpdateApplicationFlowMark(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(nil).AnyTimes() 618 enforcer.conntrack = flowclient 619 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 620 621 testThePackets(enforcer) 622 623 }) 624 } 625 626 func TestConnectionTrackerStateLocalContainer(t *testing.T) { 627 628 ctrl := gomock.NewController(t) 629 defer ctrl.Finish() 630 631 testThePackets := func(enforcer *Datapath) { 632 633 PacketFlow := packetgen.NewTemplateFlow() 634 _, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate) 635 So(err, ShouldBeNil) 636 /*first packet in TCPFLOW slice is a syn packet*/ 637 Convey("When i pass a syn packet through the enforcer", func() { 638 639 input, err := PacketFlow.GetFirstSynPacket().ToBytes() 640 So(err, ShouldBeNil) 641 642 tcpPacket, err := packet.New(0, input, "0", true) 643 if err == nil && tcpPacket != nil { 644 tcpPacket.UpdateIPv4Checksum() 645 tcpPacket.UpdateTCPChecksum() 646 } 647 648 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 649 //After sending syn packet 650 CheckAfterAppSynPacket(enforcer, tcpPacket) 651 So(err, ShouldBeNil) 652 output := make([]byte, len(tcpPacket.GetTCPBytes())) 653 copy(output, tcpPacket.GetTCPBytes()) 654 655 outPacket, err := packet.New(0, output, "0", true) 656 So(err, ShouldBeNil) 657 _, f, err := enforcer.processNetworkTCPPackets(outPacket) 658 if f != nil { 659 f() 660 } 661 So(err, ShouldBeNil) 662 //Check after processing networksyn packet 663 CheckAfterNetSynPacket(enforcer, tcpPacket, outPacket) 664 665 }) 666 Convey("When i pass a SYN and SYN ACK packet through the enforcer", func() { 667 668 input, err := PacketFlow.GetFirstSynPacket().ToBytes() 669 So(err, ShouldBeNil) 670 671 tcpPacket, err := packet.New(0, input, "0", true) 672 if err == nil && tcpPacket != nil { 673 tcpPacket.UpdateIPv4Checksum() 674 tcpPacket.UpdateTCPChecksum() 675 } 676 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 677 So(err, ShouldBeNil) 678 679 output := make([]byte, len(tcpPacket.GetTCPBytes())) 680 copy(output, tcpPacket.GetTCPBytes()) 681 682 outPacket, err := packet.New(0, output, "0", true) 683 So(err, ShouldBeNil) 684 outPacket.Print(0, false) 685 _, f, err := enforcer.processNetworkTCPPackets(outPacket) 686 if f != nil { 687 f() 688 } 689 So(err, ShouldBeNil) 690 691 //Now lets send the synack packet from the server in response 692 input, err = PacketFlow.GetFirstSynAckPacket().ToBytes() 693 So(err, ShouldBeNil) 694 695 tcpPacket, err = packet.New(0, input, "0", true) 696 if err == nil && tcpPacket != nil { 697 tcpPacket.UpdateIPv4Checksum() 698 tcpPacket.UpdateTCPChecksum() 699 } 700 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 701 So(err, ShouldBeNil) 702 703 output = make([]byte, len(tcpPacket.GetTCPBytes())) 704 copy(output, tcpPacket.GetTCPBytes()) 705 706 outPacket, err = packet.New(0, output, "0", true) 707 So(err, ShouldBeNil) 708 outPacketcopy, _ := packet.New(0, output, "0", true) 709 _, f, err = enforcer.processNetworkTCPPackets(outPacket) 710 if f != nil { 711 f() 712 } 713 So(err, ShouldBeNil) 714 715 CheckAfterNetSynAckPacket(t, enforcer, outPacketcopy, outPacket) 716 }) 717 718 Convey("When i pass a SYN and SYNACK and another ACK packet through the enforcer", func() { 719 720 input, err := PacketFlow.GetFirstSynPacket().ToBytes() 721 So(err, ShouldBeNil) 722 tcpPacket, err := packet.New(0, input, "0", true) 723 if err == nil && tcpPacket != nil { 724 tcpPacket.UpdateIPv4Checksum() 725 tcpPacket.UpdateTCPChecksum() 726 } 727 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 728 So(err, ShouldBeNil) 729 730 output := make([]byte, len(tcpPacket.GetTCPBytes())) 731 copy(output, tcpPacket.GetTCPBytes()) 732 733 outPacket, err := packet.New(0, output, "0", true) 734 So(err, ShouldBeNil) 735 _, f, err := enforcer.processNetworkTCPPackets(outPacket) 736 if f != nil { 737 f() 738 } 739 So(err, ShouldBeNil) 740 741 //Now lets send the synack packet from the server in response 742 input, err = PacketFlow.GetFirstSynAckPacket().ToBytes() 743 So(err, ShouldBeNil) 744 745 tcpPacket, err = packet.New(0, input, "0", true) 746 if err == nil && tcpPacket != nil { 747 tcpPacket.UpdateIPv4Checksum() 748 tcpPacket.UpdateTCPChecksum() 749 } 750 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 751 So(err, ShouldBeNil) 752 753 output = make([]byte, len(tcpPacket.GetTCPBytes())) 754 copy(output, tcpPacket.GetTCPBytes()) 755 756 outPacket, err = packet.New(0, output, "0", true) 757 So(err, ShouldBeNil) 758 _, f, err = enforcer.processNetworkTCPPackets(outPacket) 759 if f != nil { 760 f() 761 } 762 So(err, ShouldBeNil) 763 764 input, err = PacketFlow.GetFirstAckPacket().ToBytes() 765 So(err, ShouldBeNil) 766 767 tcpPacket, err = packet.New(0, input, "0", true) 768 if err == nil && tcpPacket != nil { 769 tcpPacket.UpdateIPv4Checksum() 770 tcpPacket.UpdateTCPChecksum() 771 } 772 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 773 CheckAfterAppAckPacket(enforcer, tcpPacket) 774 So(err, ShouldBeNil) 775 776 output = make([]byte, len(tcpPacket.GetTCPBytes())) 777 copy(output, tcpPacket.GetTCPBytes()) 778 779 outPacket, err = packet.New(0, output, "0", true) 780 So(err, ShouldBeNil) 781 CheckBeforeNetAckPacket(enforcer, tcpPacket, outPacket, false) 782 _, f, err = enforcer.processNetworkTCPPackets(outPacket) 783 if f != nil { 784 f() 785 } 786 So(err, ShouldBeNil) 787 }) 788 } 789 790 flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Accept, "") 791 792 Convey("When the mode is RemoteConainter", t, func() { 793 794 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer) 795 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).AnyTimes() 796 testThePackets(enforcer) 797 798 }) 799 800 Convey("When the mode is LocalServer", t, func() { 801 802 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer) 803 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).AnyTimes() 804 testThePackets(enforcer) 805 806 }) 807 } 808 809 func CheckAfterAppSynPacket(enforcer *Datapath, tcpPacket *packet.Packet) { 810 811 appConn, _ := enforcer.tcpClient.Get(tcpPacket.L4FlowHash()) 812 So(appConn.GetState(), ShouldEqual, connection.TCPSynSend) 813 } 814 815 func CheckAfterNetSynPacket(enforcer *Datapath, tcpPacket, outPacket *packet.Packet) { 816 817 appConn, _ := enforcer.tcpServer.Get(tcpPacket.L4FlowHash()) 818 So(appConn.GetState(), ShouldEqual, connection.TCPSynReceived) 819 } 820 821 func CheckAfterNetSynAckPacket(t *testing.T, enforcer *Datapath, tcpPacket, outPacket *packet.Packet) { 822 823 netconn, _ := enforcer.tcpClient.Get(outPacket.L4ReverseFlowHash()) 824 So(netconn.GetState(), ShouldEqual, connection.TCPSynAckReceived) 825 } 826 827 func CheckAfterAppAckPacket(enforcer *Datapath, tcpPacket *packet.Packet) { 828 829 appConn, _ := enforcer.tcpClient.Get(tcpPacket.L4FlowHash()) 830 So(appConn.GetState(), ShouldEqual, connection.TCPAckSend) 831 } 832 833 func CheckBeforeNetAckPacket(enforcer *Datapath, tcpPacket, outPacket *packet.Packet, isReplay bool) { 834 835 appConn, _ := enforcer.tcpServer.Get(tcpPacket.L4FlowHash()) 836 if !isReplay { 837 So(appConn.GetState(), ShouldEqual, connection.TCPSynAckSend) 838 } else { 839 So(appConn.GetState(), ShouldBeGreaterThan, connection.TCPSynAckSend) 840 } 841 } 842 843 func TestCacheState(t *testing.T) { 844 845 ctrl := gomock.NewController(t) 846 defer ctrl.Finish() 847 848 defer MockGetUDPRawSocket()() 849 850 Convey("Given I create a new enforcer instance", t, func() { 851 852 enforcer, secrets, mockTokenAccessor, mockCollector, mockDNS := NewWithMocks(ctrl, "serverID", constants.LocalServer, []string{"0.0.0.0/0"}, true) 853 854 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 855 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(2).Return([]byte("token"), nil).AnyTimes() 856 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).Times(2) 857 mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).Times(1) 858 mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).Times(2) 859 mockDNS.EXPECT().Unenforce(gomock.Any(), gomock.Any()).Times(1) 860 mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(2) 861 862 contextID := "123" 863 864 puInfo := policy.NewPUInfo(contextID, "/ns1", common.ContainerPU) 865 866 CounterReport := &collector.CounterReport{ 867 PUID: puInfo.Policy.ManagementID(), 868 Namespace: puInfo.Policy.ManagementNamespace(), 869 } 870 mockCollector.EXPECT().CollectCounterEvent(MyCounterMatcher(CounterReport)).Times(2) 871 872 // Should fail: Not in cache 873 err := enforcer.Unenforce(context.Background(), contextID) 874 if err == nil { 875 t.Errorf("Expected failure, no contextID in cache") 876 } 877 878 ip := policy.ExtendedMap{"bridge": "127.0.0.1"} 879 puInfo.Runtime.SetIPAddresses(ip) 880 ipl := policy.ExtendedMap{"bridge": "127.0.0.1"} 881 puInfo.Policy.SetIPAddresses(ipl) 882 883 ip = policy.ExtendedMap{"bridge": "127.0.0.1"} 884 puInfo.Runtime.SetIPAddresses(ip) 885 886 ipl = policy.ExtendedMap{"bridge": "127.0.0.1"} 887 puInfo.Policy.SetIPAddresses(ipl) 888 889 // Should not fail: IP is valid 890 err = enforcer.Enforce(context.Background(), contextID, puInfo) 891 if err != nil { 892 t.Errorf("Expected no failure %s", err) 893 } 894 895 // Should not fail: Update 896 err = enforcer.Enforce(context.Background(), contextID, puInfo) 897 if err != nil { 898 t.Errorf("Expected no failure %s", err) 899 } 900 901 // Should not fail: IP is valid 902 err = enforcer.Unenforce(context.Background(), contextID) 903 if err != nil { 904 t.Errorf("Expected failure, no IP but passed %s", err) 905 } 906 }) 907 } 908 909 func TestDoCreatePU(t *testing.T) { 910 911 ctrl := gomock.NewController(t) 912 defer ctrl.Finish() 913 914 defer MockGetUDPRawSocket()() 915 916 Convey("Given an initialized enforcer for Linux Processes", t, func() { 917 918 defer MockGetUDPRawSocket()() 919 920 enforcer, secrets, mockTokenAccessor, _, mockDNS := NewWithMocks(ctrl, "serverID", constants.LocalServer, []string{"0.0.0.0/0"}, true) 921 922 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 923 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 924 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 925 mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).Times(1) 926 mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).Times(1) 927 mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(1) 928 929 contextID := "124" 930 puInfo := policy.NewPUInfo(contextID, "/ns1", common.LinuxProcessPU) 931 932 spec, _ := portspec.NewPortSpecFromString("80", nil) 933 puInfo.Runtime.SetOptions(policy.OptionsType{ 934 CgroupMark: "100", 935 Services: []common.Service{ 936 { 937 Protocol: uint8(6), 938 Ports: spec, 939 }, 940 }, 941 }) 942 943 Convey("When I create a new PU", func() { 944 err := enforcer.Enforce(context.Background(), contextID, puInfo) 945 946 Convey("It should succeed", func() { 947 So(err, ShouldBeNil) 948 _, err := enforcer.puFromContextID.Get(contextID) 949 So(err, ShouldBeNil) 950 _, err1 := enforcer.puFromMark.Get("100") 951 So(err1, ShouldBeNil) 952 _, err2 := enforcer.contextIDFromTCPPort.GetSpecValueFromPort(80) 953 So(err2, ShouldBeNil) 954 So(enforcer.puFromIP, ShouldBeNil) 955 }) 956 }) 957 }) 958 959 Convey("Given an initialized enforcer for Linux Processes", t, func() { 960 961 enforcer, secrets, mockTokenAccessor, _, mockDNS := NewWithMocks(ctrl, "serverID", constants.LocalServer, []string{"0.0.0.0/0"}, true) 962 963 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 964 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 965 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 966 mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).Times(1) 967 mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).Times(1) 968 mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(1) 969 970 contextID := "125" 971 puInfo := policy.NewPUInfo(contextID, "/ns1", common.LinuxProcessPU) 972 973 Convey("When I create a new PU without ports or mark", func() { 974 err := enforcer.Enforce(context.Background(), contextID, puInfo) 975 976 Convey("It should succeed", func() { 977 So(err, ShouldBeNil) 978 _, err := enforcer.puFromContextID.Get(contextID) 979 So(err, ShouldBeNil) 980 So(enforcer.puFromIP, ShouldBeNil) 981 }) 982 }) 983 }) 984 985 Convey("Given an initialized enforcer for remote Linux Containers", t, func() { 986 987 enforcer, secrets, mockTokenAccessor, _, mockDNS := NewWithMocks(ctrl, "serverID", constants.RemoteContainer, []string{"0.0.0.0/0"}, true) 988 989 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 990 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 991 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 992 mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).Times(1) 993 mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).Times(1) 994 mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(1) 995 996 contextID := "126" 997 puInfo := policy.NewPUInfo(contextID, "/ns1", common.ContainerPU) 998 999 Convey("When I create a new PU without an IP", func() { 1000 err := enforcer.Enforce(context.Background(), contextID, puInfo) 1001 1002 Convey("It should succeed ", func() { 1003 So(err, ShouldBeNil) 1004 So(enforcer.puFromIP, ShouldNotBeNil) 1005 }) 1006 }) 1007 }) 1008 } 1009 1010 func TestContextFromIP(t *testing.T) { 1011 1012 ctrl := gomock.NewController(t) 1013 defer ctrl.Finish() 1014 1015 Convey("Given an initialized enforcer for Linux Processes", t, func() { 1016 1017 enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID", constants.RemoteContainer, []string{"0.0.0.0/0"}, true) 1018 1019 puInfo := policy.NewPUInfo("SomePU", "/ns", common.ContainerPU) 1020 1021 context, err := pucontext.NewPU("SomePU", puInfo, nil, 10*time.Second) 1022 contextID := "AporetoContext" 1023 So(err, ShouldBeNil) 1024 1025 Convey("If I try to get context based on IP and its not there and its a local container it should fail ", func() { 1026 _, err := enforcer.contextFromIP(true, "", 0, packet.IPProtocolTCP) 1027 So(err, ShouldNotBeNil) 1028 }) 1029 1030 Convey("If there is no IP match, it should try the mark for app packets ", func() { 1031 enforcer.puFromMark.AddOrUpdate("100", context) 1032 enforcer.mode = constants.LocalServer 1033 Convey("If the mark exists", func() { 1034 markVal := strconv.Itoa(100) 1035 ctx, err := enforcer.contextFromIP(true, markVal, 0, packet.IPProtocolTCP) 1036 So(err, ShouldBeNil) 1037 So(ctx, ShouldNotBeNil) 1038 So(ctx, ShouldEqual, context) 1039 }) 1040 1041 Convey("If the mark doesn't exist", func() { 1042 _, err := enforcer.contextFromIP(true, "2000", 0, packet.IPProtocolTCP) 1043 So(err, ShouldNotBeNil) 1044 }) 1045 }) 1046 1047 Convey("If there is no IP match, it should try the port for net packets ", func() { 1048 s, _ := portspec.NewPortSpec(8000, 8000, contextID) 1049 enforcer.contextIDFromTCPPort.AddPortSpec(s) 1050 enforcer.puFromContextID.AddOrUpdate(contextID, context) 1051 enforcer.mode = constants.LocalServer 1052 1053 Convey("If the port exists", func() { 1054 ctx, err := enforcer.contextFromIP(false, "", 8000, packet.IPProtocolTCP) 1055 So(err, ShouldBeNil) 1056 So(ctx, ShouldNotBeNil) 1057 So(ctx, ShouldEqual, context) 1058 }) 1059 1060 Convey("If the port doesn't exist", func() { 1061 _, err := enforcer.contextFromIP(false, "", 9000, packet.IPProtocolTCP) 1062 So(err, ShouldNotBeNil) 1063 }) 1064 }) 1065 1066 }) 1067 1068 Convey("Given an initialized enforcer for HostPU", t, func() { 1069 1070 enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID", constants.RemoteContainer, []string{"0.0.0.0/0"}, true) 1071 1072 puInfo := policy.NewPUInfo("SomeHostPU", "/ns", common.HostPU) 1073 1074 context, err := pucontext.NewPU("SomeHostPU", puInfo, nil, 10*time.Second) 1075 So(err, ShouldBeNil) 1076 1077 enforcer.hostPU = context 1078 1079 Convey("If I try to get context for app ICMP for HostPU it should succeed ", func() { 1080 ctx, err := enforcer.contextFromIP(true, "", 0, packet.IPProtocolICMP) 1081 So(err, ShouldBeNil) 1082 So(ctx, ShouldNotBeNil) 1083 So(ctx, ShouldEqual, context) 1084 }) 1085 Convey("If I try to get context for net ICMP for HostPU it should succeed ", func() { 1086 ctx, err := enforcer.contextFromIP(false, "", 0, packet.IPProtocolICMP) 1087 So(err, ShouldBeNil) 1088 So(ctx, ShouldNotBeNil) 1089 So(ctx, ShouldEqual, context) 1090 }) 1091 Convey("If I try to get context for another protocol it should not return host context ", func() { 1092 _, err := enforcer.contextFromIP(true, "", 0, packet.IPProtocolTCP) 1093 So(err, ShouldNotBeNil) 1094 }) 1095 1096 }) 1097 } 1098 1099 func TestInvalidPacket(t *testing.T) { 1100 1101 ctrl := gomock.NewController(t) 1102 defer ctrl.Finish() 1103 1104 testThePackets := func(enforcer *Datapath) { 1105 1106 InvalidTCPFlow := [][]byte{ 1107 { /*0x4a, 0x1d, 0x70, 0xcf, 0xa6, 0xe5, 0xb8, 0xe8, 0x56, 0x32, 0x0b, 0xde, 0x08, 0x00,*/ 0x45, 0x00, 0x00, 0x40, 0xf4, 0x1f, 0x44, 0x00, 0x40, 0x06, 0xa9, 0x6f, 0x0a, 0x01, 0x0a, 0x4c, 0xa4, 0x43, 0xe4, 0x98, 0xe1, 0xa1, 0x00, 0x50, 0x4d, 0xa6, 0xac, 0x48, 0x00, 0x00, 0x00, 0x00, 0xb0, 0x02, 0xff, 0xff, 0x6b, 0x6c, 0x00, 0x00, 0x02, 0x04, 0x05, 0xb4, 0x01, 0x03, 0x03, 0x05, 0x01, 0x01, 0x08, 0x0a, 0x1b, 0x4f, 0x37, 0x38, 0x00, 0x00, 0x00, 0x00, 0x04, 0x02, 0x00, 0x00, 0x4a, 0x1d, 0x70, 0xcf}, 1108 } 1109 1110 for _, p := range InvalidTCPFlow { 1111 tcpPacket, err := packet.New(0, p, "0", true) 1112 So(err, ShouldBeNil) 1113 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 1114 So(err, ShouldBeNil) 1115 output := make([]byte, len(tcpPacket.GetTCPBytes())) 1116 copy(output, tcpPacket.GetTCPBytes()) 1117 outpacket, err := packet.New(0, output, "0", true) 1118 So(err, ShouldBeNil) 1119 //Detach the data and parse token should fail 1120 outpacket.TCPDataDetach(binary.BigEndian.Uint16([]byte{0x0, p[32]})/4 - 20) 1121 So(err, ShouldBeNil) 1122 _, _, err = enforcer.processNetworkTCPPackets(outpacket) 1123 So(err, ShouldNotBeNil) 1124 } 1125 1126 } 1127 1128 flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Reject|policy.Log, collector.MissingToken) 1129 1130 Convey("When the mode is RemoteConainter", t, func() { 1131 1132 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer) 1133 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 1134 testThePackets(enforcer) 1135 1136 }) 1137 1138 Convey("When the mode is LocalServer", t, func() { 1139 1140 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer) 1141 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 1142 testThePackets(enforcer) 1143 1144 }) 1145 } 1146 1147 func TestFlowReportingInvalidSyn(t *testing.T) { 1148 1149 ctrl := gomock.NewController(t) 1150 defer ctrl.Finish() 1151 1152 testThePackets := func(enforcer *Datapath) { 1153 1154 SIP := net.IPv4zero 1155 packetDiffers := false 1156 1157 PacketFlow := packetgen.NewTemplateFlow() 1158 _, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate) 1159 So(err, ShouldBeNil) 1160 for i := 0; i < PacketFlow.GetSynPackets().GetNumPackets(); i++ { 1161 1162 start, err := PacketFlow.GetSynPackets().GetNthPacket(i).ToBytes() 1163 So(err, ShouldBeNil) 1164 oldPacket, err := packet.New(0, start, "0", true) 1165 if err == nil && oldPacket != nil { 1166 oldPacket.UpdateIPv4Checksum() 1167 oldPacket.UpdateTCPChecksum() 1168 } 1169 1170 input, err := PacketFlow.GetSynPackets().GetNthPacket(i).ToBytes() 1171 So(err, ShouldBeNil) 1172 tcpPacket, err := packet.New(0, input, "0", true) 1173 if err == nil && tcpPacket != nil { 1174 tcpPacket.UpdateIPv4Checksum() 1175 tcpPacket.UpdateTCPChecksum() 1176 } 1177 1178 if debug { 1179 fmt.Println("Input packet", i) 1180 tcpPacket.Print(0, false) 1181 } 1182 1183 So(err, ShouldBeNil) 1184 So(tcpPacket, ShouldNotBeNil) 1185 1186 if reflect.DeepEqual(SIP, net.IPv4zero) { 1187 SIP = tcpPacket.SourceAddress() 1188 } 1189 if !reflect.DeepEqual(SIP, tcpPacket.DestinationAddress()) && 1190 !reflect.DeepEqual(SIP, tcpPacket.SourceAddress()) { 1191 t.Error("Invalid Test Packet") 1192 } 1193 1194 if debug { 1195 fmt.Println("Intermediate packet", i) 1196 tcpPacket.Print(0, false) 1197 } 1198 1199 output := make([]byte, len(tcpPacket.GetTCPBytes())) 1200 copy(output, tcpPacket.GetTCPBytes()) 1201 1202 outPacket, errp := packet.New(0, output, "0", true) 1203 So(len(tcpPacket.GetTCPBytes()), ShouldBeLessThanOrEqualTo, len(outPacket.GetTCPBytes())) 1204 So(errp, ShouldBeNil) 1205 _, _, err = enforcer.processNetworkTCPPackets(outPacket) 1206 So(err, ShouldNotBeNil) 1207 1208 if debug { 1209 fmt.Println("Output packet", i) 1210 outPacket.Print(0, false) 1211 } 1212 1213 if !reflect.DeepEqual(oldPacket.GetTCPBytes(), outPacket.GetTCPBytes()) { 1214 packetDiffers = true 1215 fmt.Println("Error: packets dont match") 1216 fmt.Println("Input Packet") 1217 oldPacket.Print(0, false) 1218 fmt.Println("Output Packet") 1219 outPacket.Print(0, false) 1220 t.Errorf("Packet %d Input and output packet do not match", i) 1221 t.FailNow() 1222 } 1223 } 1224 1225 Convey("Then I expect all the input and output packets (after encoding and decoding) to be same", func() { 1226 1227 So(packetDiffers, ShouldEqual, false) 1228 }) 1229 } 1230 1231 flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Reject|policy.Log, collector.MissingToken) 1232 1233 Convey("When the mode is RemoteConainter", t, func() { 1234 1235 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer) 1236 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 1237 testThePackets(enforcer) 1238 1239 }) 1240 1241 Convey("When the mode is LocalServer", t, func() { 1242 1243 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer) 1244 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 1245 testThePackets(enforcer) 1246 1247 }) 1248 } 1249 1250 func TestFlowReportingUptoInvalidSynAck(t *testing.T) { 1251 1252 ctrl := gomock.NewController(t) 1253 defer ctrl.Finish() 1254 1255 testThePackets := func(enforcer *Datapath) { 1256 1257 SIP := net.IPv4zero 1258 packetDiffers := false 1259 1260 PacketFlow := packetgen.NewTemplateFlow() 1261 _, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate) 1262 So(err, ShouldBeNil) 1263 for i := 0; i < PacketFlow.GetUptoFirstSynAckPacket().GetNumPackets(); i++ { 1264 start, err := PacketFlow.GetUptoFirstSynAckPacket().GetNthPacket(i).ToBytes() 1265 So(err, ShouldBeNil) 1266 1267 oldPacket, err := packet.New(0, start, "0", true) 1268 if err == nil && oldPacket != nil { 1269 oldPacket.UpdateIPv4Checksum() 1270 oldPacket.UpdateTCPChecksum() 1271 } 1272 input, err := PacketFlow.GetUptoFirstSynAckPacket().GetNthPacket(i).ToBytes() 1273 So(err, ShouldBeNil) 1274 tcpPacket, err := packet.New(0, input, "0", true) 1275 if err == nil && tcpPacket != nil { 1276 tcpPacket.UpdateIPv4Checksum() 1277 tcpPacket.UpdateTCPChecksum() 1278 } 1279 1280 if debug { 1281 fmt.Println("Input packet", i) 1282 tcpPacket.Print(0, false) 1283 } 1284 1285 So(err, ShouldBeNil) 1286 So(tcpPacket, ShouldNotBeNil) 1287 1288 if reflect.DeepEqual(SIP, net.IPv4zero) { 1289 SIP = tcpPacket.SourceAddress() 1290 } 1291 1292 if !reflect.DeepEqual(SIP, tcpPacket.DestinationAddress()) && 1293 !reflect.DeepEqual(SIP, tcpPacket.SourceAddress()) { 1294 t.Error("Invalid Test Packet") 1295 } 1296 if PacketFlow.GetNthPacket(i).GetTCPSyn() && !PacketFlow.GetNthPacket(i).GetTCPAck() { 1297 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 1298 1299 So(err, ShouldBeNil) 1300 } 1301 1302 if debug { 1303 fmt.Println("Intermediate packet", i) 1304 tcpPacket.Print(0, false) 1305 } 1306 1307 output := make([]byte, len(tcpPacket.GetTCPBytes())) 1308 copy(output, tcpPacket.GetTCPBytes()) 1309 1310 outPacket, errp := packet.New(0, output, "0", true) 1311 So(len(tcpPacket.GetTCPBytes()), ShouldBeLessThanOrEqualTo, len(outPacket.GetTCPBytes())) 1312 So(errp, ShouldBeNil) 1313 1314 if PacketFlow.GetNthPacket(i).GetTCPSyn() && !PacketFlow.GetNthPacket(i).GetTCPAck() { 1315 _, _, err = enforcer.processNetworkTCPPackets(outPacket) 1316 So(err, ShouldBeNil) 1317 } 1318 if PacketFlow.GetNthPacket(i).GetTCPSyn() && PacketFlow.GetNthPacket(i).GetTCPAck() { 1319 _, _, err = enforcer.processNetworkTCPPackets(outPacket) 1320 So(err, ShouldNotBeNil) 1321 } 1322 1323 if debug { 1324 fmt.Println("Output packet", i) 1325 outPacket.Print(0, false) 1326 } 1327 1328 if !reflect.DeepEqual(oldPacket.GetTCPBytes(), outPacket.GetTCPBytes()) { 1329 packetDiffers = true 1330 fmt.Println("Error: packets dont match") 1331 fmt.Println("Input Packet") 1332 oldPacket.Print(0, false) 1333 fmt.Println("Output Packet") 1334 outPacket.Print(0, false) 1335 t.Errorf("Packet %d Input and output packet do not match", i) 1336 t.FailNow() 1337 } 1338 } 1339 1340 Convey("Then I expect all the input and output packets (after encoding and decoding) to be same", func() { 1341 1342 So(packetDiffers, ShouldEqual, false) 1343 }) 1344 } 1345 1346 flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Reject|policy.Log, "policy") 1347 1348 Convey("When the mode is RemoteConainter", t, func() { 1349 1350 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer) 1351 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 1352 testThePackets(enforcer) 1353 1354 }) 1355 1356 Convey("When the mode is LocalServer", t, func() { 1357 1358 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer) 1359 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 1360 testThePackets(enforcer) 1361 1362 }) 1363 } 1364 1365 func TestForPacketsWithRandomFlags(t *testing.T) { 1366 1367 ctrl := gomock.NewController(t) 1368 defer ctrl.Finish() 1369 1370 debug = true 1371 1372 defer MockGetUDPRawSocket()() 1373 1374 testThePackets := func(enforcer *Datapath) { 1375 1376 PacketFlow := packetgen.NewPacketFlow("aa:ff:aa:ff:aa:ff", "ff:aa:ff:aa:ff:aa", testSrcIP, testDstIP, 666, 80) 1377 _, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGenerateGoodFlow) 1378 So(err, ShouldBeNil) 1379 1380 count := PacketFlow.GetNumPackets() 1381 for i := 0; i < count; i++ { 1382 //Setting random TCP flags for all the packets 1383 PacketFlow.GetNthPacket(i).SetTCPCwr() 1384 PacketFlow.GetNthPacket(i).SetTCPPsh() 1385 PacketFlow.GetNthPacket(i).SetTCPEce() 1386 input, err := PacketFlow.GetNthPacket(i).ToBytes() 1387 So(err, ShouldBeNil) 1388 tcpPacket, err := packet.New(0, input, "0", true) 1389 if err == nil && tcpPacket != nil { 1390 tcpPacket.UpdateIPv4Checksum() 1391 tcpPacket.UpdateTCPChecksum() 1392 } 1393 1394 if debug { 1395 fmt.Println("Input packet", i) 1396 tcpPacket.Print(0, false) 1397 } 1398 1399 So(err, ShouldBeNil) 1400 So(tcpPacket, ShouldNotBeNil) 1401 1402 SIP := tcpPacket.SourceAddress() 1403 1404 if !reflect.DeepEqual(SIP, tcpPacket.DestinationAddress()) && 1405 !reflect.DeepEqual(SIP, tcpPacket.SourceAddress()) { 1406 t.Error("Invalid Test Packet") 1407 } 1408 1409 _, err = enforcer.processApplicationTCPPackets(tcpPacket) 1410 So(err, ShouldBeNil) 1411 1412 if debug { 1413 fmt.Println("Intermediate packet", i) 1414 tcpPacket.Print(0, false) 1415 } 1416 1417 output := make([]byte, len(tcpPacket.GetTCPBytes())) 1418 copy(output, tcpPacket.GetTCPBytes()) 1419 1420 outPacket, errp := packet.New(0, output, "0", true) 1421 So(len(tcpPacket.GetTCPBytes()), ShouldBeLessThanOrEqualTo, len(outPacket.GetTCPBytes())) 1422 So(errp, ShouldBeNil) 1423 1424 _, f, err := enforcer.processNetworkTCPPackets(outPacket) 1425 if f != nil { 1426 f() 1427 } 1428 1429 So(err, ShouldBeNil) 1430 1431 if debug { 1432 fmt.Println("Output packet ", i) 1433 outPacket.Print(0, false) 1434 } 1435 } 1436 } 1437 1438 flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Accept, "") 1439 1440 Convey("When the mode is RemoteConainter", t, func() { 1441 1442 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer) 1443 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 1444 testThePackets(enforcer) 1445 1446 }) 1447 1448 Convey("When the mode is LocalServer", t, func() { 1449 1450 enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer) 1451 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 1452 testThePackets(enforcer) 1453 }) 1454 } 1455 1456 func TestPUPortCreation(t *testing.T) { 1457 1458 ctrl := gomock.NewController(t) 1459 defer ctrl.Finish() 1460 1461 Convey("Given I setup an enforcer", t, func() { 1462 1463 defer MockGetUDPRawSocket()() 1464 1465 enforcer, secrets, mockTokenAccessor, _, mockDNS := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true) 1466 if enforcer == nil { // This avoids lint error SA5011: possible nil pointer dereference (staticcheck) 1467 So(enforcer != nil, ShouldBeTrue) 1468 return 1469 } 1470 1471 enforcer.packetLogs = true 1472 1473 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 1474 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 1475 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 1476 1477 contextID := "1001" 1478 puInfo := policy.NewPUInfo(contextID, "/ns1", common.LinuxProcessPU) 1479 puInfo.Runtime.SetOptions(policy.OptionsType{ 1480 CgroupMark: "100", 1481 }) 1482 1483 mockDNS.EXPECT().StartDNSServer(gomock.Any(), contextID, gomock.Any()).Times(1) 1484 mockDNS.EXPECT().Enforce(gomock.Any(), contextID, puInfo) 1485 mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(1) 1486 1487 enforcer.Enforce(context.Background(), contextID, puInfo) // nolint 1488 }) 1489 } 1490 1491 func TestCollectTCPPacket(t *testing.T) { 1492 1493 ctrl := gomock.NewController(t) 1494 defer ctrl.Finish() 1495 1496 Convey("Given I setup an enforcer", t, func() { 1497 1498 enforcer, secrets, mockTokenAccessor, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true) 1499 So(enforcer != nil, ShouldBeTrue) 1500 1501 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 1502 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 1503 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 1504 1505 contextID := "dummy" 1506 _, err := CreatePUContext(enforcer, contextID, "/ns1", common.ContainerPU, mockTokenAccessor) 1507 So(err, ShouldBeNil) 1508 1509 tcpPacket, err := newPacket(1, packet.TCPSynMask, testSrcIP, testDstIP, srcPort, dstPort, true, false) 1510 So(err, ShouldBeNil) 1511 1512 Convey("We setup tcp network packet tracing for this pu with incomplete state", func() { 1513 interval := 10 * time.Second 1514 err := enforcer.EnableDatapathPacketTracing(context.TODO(), contextID, packettracing.NetworkOnly, interval) 1515 So(err, ShouldBeNil) 1516 packetreport := collector.PacketReport{ 1517 DestinationIP: tcpPacket.DestinationAddress().String(), 1518 SourceIP: tcpPacket.SourceAddress().String(), 1519 } 1520 mockCollector.EXPECT().CollectPacketEvent(PacketEventMatcher(&packetreport)).Times(0) 1521 enforcer.collectTCPPacket(&debugpacketmessage{ 1522 Mark: 10, 1523 p: tcpPacket, 1524 tcpConn: nil, 1525 udpConn: nil, 1526 err: nil, 1527 network: true, 1528 }) 1529 }) 1530 Convey("We setup tcp network packet tracing for this pu with tcpConn != nil state", func() { 1531 interval := 10 * time.Second 1532 err := enforcer.EnableDatapathPacketTracing(context.TODO(), contextID, packettracing.NetworkOnly, interval) 1533 So(err, ShouldBeNil) 1534 packetreport := collector.PacketReport{ 1535 DestinationIP: tcpPacket.DestinationAddress().String(), 1536 SourceIP: tcpPacket.SourceAddress().String(), 1537 } 1538 context, _ := enforcer.puFromContextID.Get(contextID) 1539 tcpConn := connection.NewTCPConnection(context.(*pucontext.PUContext), nil) 1540 1541 mockCollector.EXPECT().CollectPacketEvent(PacketEventMatcher(&packetreport)).Times(1) 1542 enforcer.collectTCPPacket(&debugpacketmessage{ 1543 Mark: 10, 1544 p: tcpPacket, 1545 tcpConn: tcpConn, 1546 udpConn: nil, 1547 err: nil, 1548 network: true, 1549 }) 1550 }) 1551 Convey("We setup tcp network packet tracing for this pu with tcpConn != nil and inject application packet", func() { 1552 interval := 10 * time.Second 1553 err := enforcer.EnableDatapathPacketTracing(context.TODO(), contextID, packettracing.NetworkOnly, interval) 1554 So(err, ShouldBeNil) 1555 packetreport := collector.PacketReport{ 1556 DestinationIP: tcpPacket.DestinationAddress().String(), 1557 SourceIP: tcpPacket.SourceAddress().String(), 1558 } 1559 context, _ := enforcer.puFromContextID.Get(contextID) 1560 tcpConn := connection.NewTCPConnection(context.(*pucontext.PUContext), nil) 1561 mockCollector.EXPECT().CollectPacketEvent(PacketEventMatcher(&packetreport)).Times(0) 1562 enforcer.collectTCPPacket(&debugpacketmessage{ 1563 Mark: 10, 1564 p: tcpPacket, 1565 tcpConn: tcpConn, 1566 udpConn: nil, 1567 err: nil, 1568 network: false, 1569 }) 1570 }) 1571 1572 }) 1573 } 1574 1575 func TestEnableDatapathPacketTracing(t *testing.T) { 1576 1577 ctrl := gomock.NewController(t) 1578 defer ctrl.Finish() 1579 1580 Convey("Given I setup an enforcer", t, func() { 1581 1582 enforcer, secrets, mockTokenAccessor, _, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true) 1583 if enforcer == nil { // This avoids lint error SA5011: possible nil pointer dereference (staticcheck) 1584 So(enforcer != nil, ShouldBeTrue) 1585 return 1586 } 1587 1588 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 1589 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 1590 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 1591 1592 contextID := "dummy" 1593 _, err := CreatePUContext(enforcer, contextID, "/ns1", common.ContainerPU, mockTokenAccessor) 1594 So(err, ShouldBeNil) 1595 1596 err = enforcer.EnableDatapathPacketTracing(context.TODO(), contextID, packettracing.ApplicationOnly, 10*time.Second) 1597 So(err, ShouldBeNil) 1598 _, err = enforcer.packetTracingCache.Get(contextID) 1599 So(err, ShouldBeNil) 1600 }) 1601 } 1602 1603 func Test_CheckCounterCollection(t *testing.T) { 1604 ctrl := gomock.NewController(t) 1605 defer ctrl.Finish() 1606 collectCounterInterval = 1 * time.Second 1607 Convey("Given I setup an enforcer", t, func() { 1608 1609 Convey("So When enforcer exits", func() { 1610 1611 enforcer, secrets, mockTokenAccessor, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true) 1612 So(enforcer != nil, ShouldBeTrue) 1613 1614 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 1615 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 1616 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 1617 1618 puContext, err := CreatePUContext(enforcer, "dummy", "/ns1", common.ContainerPU, mockTokenAccessor) 1619 So(err, ShouldBeNil) 1620 1621 CounterReport := &collector.CounterReport{ 1622 PUID: puContext.ManagementID(), 1623 Namespace: puContext.ManagementNamespace(), 1624 } 1625 mockCollector.EXPECT().CollectCounterEvent(MyCounterMatcher(CounterReport)).MinTimes(1) 1626 1627 ctx, cancel := context.WithCancel(context.Background()) 1628 go enforcer.counterCollector(ctx) 1629 1630 puErr := puContext.Counters().CounterError((counters.ErrNonPUTraffic), fmt.Errorf("error")) 1631 1632 So(puErr, ShouldNotBeNil) 1633 cancel() 1634 }) 1635 1636 Convey("So When enforer exits and waits for stuff to exit", func() { 1637 enforcer, secrets, mockTokenAccessor, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true) 1638 So(enforcer != nil, ShouldBeTrue) 1639 1640 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 1641 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 1642 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 1643 1644 puContext, err := CreatePUContext(enforcer, "dummy", "/ns1", common.ContainerPU, mockTokenAccessor) 1645 So(err, ShouldBeNil) 1646 1647 c := &collector.CounterReport{ 1648 PUID: puContext.ManagementID(), 1649 Namespace: puContext.ManagementNamespace(), 1650 } 1651 1652 mockCollector.EXPECT().CollectCounterEvent(MyCounterMatcher(c)).MinTimes(1) 1653 1654 ctx, cancel := context.WithCancel(context.Background()) 1655 go enforcer.counterCollector(ctx) 1656 1657 puErr := puContext.Counters().CounterError(counters.ErrNonPUTraffic, fmt.Errorf("error")) 1658 1659 So(puErr, ShouldNotBeNil) 1660 cancel() 1661 <-time.After(5 * time.Second) 1662 1663 }) 1664 Convey("So When an error is reported and the enforcer waits for collection interval", func() { 1665 enforcer, secrets, mockTokenAccessor, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true) 1666 So(enforcer != nil, ShouldBeTrue) 1667 1668 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 1669 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 1670 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 1671 1672 puContext, err := CreatePUContext(enforcer, "dummy", "/ns1", common.ContainerPU, mockTokenAccessor) 1673 So(err, ShouldBeNil) 1674 1675 c := &collector.CounterReport{ 1676 PUID: puContext.ManagementID(), 1677 Namespace: puContext.ManagementNamespace(), 1678 } 1679 1680 mockCollector.EXPECT().CollectCounterEvent(MyCounterMatcher(c)).MinTimes(1) 1681 1682 ctx, cancel := context.WithCancel(context.Background()) 1683 go enforcer.counterCollector(ctx) 1684 puErr := puContext.Counters().CounterError(counters.ErrNonPUTraffic, fmt.Errorf("error")) 1685 So(puErr, ShouldNotBeNil) 1686 <-time.After(5 * collectCounterInterval) 1687 cancel() 1688 1689 }) 1690 1691 }) 1692 } 1693 1694 func Test_CounterReportedOnAuthSetAppSyn(t *testing.T) { 1695 ctrl := gomock.NewController(t) 1696 defer ctrl.Finish() 1697 1698 Convey("Given I setup an enforcer", t, func() { 1699 1700 enforcer, secrets, mockTokenAccessor, _, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true) 1701 So(enforcer != nil, ShouldBeTrue) 1702 1703 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 1704 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 1705 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 1706 mockTokenAccessor.EXPECT().Randomize(gomock.Any(), gomock.Any()).Times(2) 1707 1708 context, err := CreatePUContext(enforcer, "dummy", "/ns1", common.ContainerPU, mockTokenAccessor) 1709 So(err, ShouldBeNil) 1710 1711 p, err := newPacket(packet.PacketTypeApplication, packet.TCPSynMask, "1.1.1.1", "2.2.2.2", srcPort, dstPort, false, false) 1712 So(err, ShouldBeNil) 1713 conn := connection.NewTCPConnection(context, p) 1714 err = enforcer.processApplicationSynPacket(p, context, conn) 1715 So(err, ShouldBeNil) 1716 1717 c := conn.Context.Counters().GetErrorCounters() 1718 So(c[counters.ErrAppSynAuthOptionSet], ShouldBeZeroValue) 1719 1720 p, err = newPacket(packet.PacketTypeApplication, packet.TCPSynMask, "1.1.1.1", "2.2.2.2", srcPort, dstPort, true, false) 1721 So(err, ShouldBeNil) 1722 conn = connection.NewTCPConnection(context, p) 1723 err = enforcer.processApplicationSynPacket(p, context, conn) 1724 So(err, ShouldBeNil) 1725 1726 c = conn.Context.Counters().GetErrorCounters() 1727 So(c[counters.ErrAppSynAuthOptionSet], ShouldEqual, 1) 1728 }) 1729 } 1730 1731 func Test_CounterOnSynCacheTimeout(t *testing.T) { 1732 1733 ctrl := gomock.NewController(t) 1734 defer ctrl.Finish() 1735 1736 Convey("Given I setup an enforcer", t, func() { 1737 1738 enforcer, secrets, mockTokenAccessor, _, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true) 1739 if enforcer == nil { // This avoids lint error SA5011: possible nil pointer dereference (staticcheck) 1740 So(enforcer != nil, ShouldBeTrue) 1741 return 1742 } 1743 1744 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 1745 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 1746 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 1747 mockTokenAccessor.EXPECT().Randomize(gomock.Any(), gomock.Any()).Times(1) 1748 1749 context, err := CreatePUContext(enforcer, "dummy", "/ns1", common.ContainerPU, mockTokenAccessor) 1750 So(err, ShouldBeNil) 1751 1752 p, err := newPacket(packet.PacketTypeApplication, packet.TCPSynMask, "1.1.1.1", "2.2.2.2", srcPort, dstPort, false, false) 1753 So(err, ShouldBeNil) 1754 1755 // Update the connection timer for testing. 1756 conn := connection.NewTCPConnection(context, p) 1757 conn.ChangeConnectionTimeout(2 * time.Second) 1758 1759 err = enforcer.processApplicationSynPacket(p, context, conn) 1760 So(err, ShouldBeNil) 1761 1762 c := conn.Context.Counters().GetErrorCounters() 1763 So(c[counters.ErrTCPConnectionsExpired], ShouldBeZeroValue) 1764 1765 // Wait for the connection to expire. 1766 time.Sleep(3 * time.Second) 1767 _, exists := enforcer.tcpClient.Get(p.L4FlowHash()) 1768 if exists { 1769 t.Fail() 1770 } 1771 1772 c = conn.Context.Counters().GetErrorCounters() 1773 So(c[counters.ErrTCPConnectionsExpired], ShouldEqual, 1) 1774 }) 1775 } 1776 1777 func Test_NOClaims(t *testing.T) { 1778 ctrl := gomock.NewController(t) 1779 defer ctrl.Finish() 1780 1781 Convey("Given I setup an enforcer", t, func() { 1782 1783 enforcer, _, _, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true) 1784 So(enforcer != nil, ShouldBeTrue) 1785 1786 flowRecord := CreateFlowRecord(1, "1.1.1.1", "2.2.2.2", 2000, 80, policy.Reject|policy.Log, collector.PolicyDrop) 1787 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 1788 1789 context, err := CreatePUContext(enforcer, "dummy", "/ns1", common.ContainerPU, nil) 1790 So(err, ShouldBeNil) 1791 1792 p, err := newPacket(packet.PacketTypeNetwork, packet.TCPSynAckMask, "2.2.2.2", "1.1.1.1", dstPort, srcPort, true, false) 1793 So(err, ShouldBeNil) 1794 1795 conn := connection.NewTCPConnection(context, p) 1796 1797 _, err = enforcer.processNetworkSynAckPacket(context, conn, p) 1798 So(err, ShouldNotBeNil) 1799 }) 1800 } 1801 1802 func newPacket(context uint64, tcpFlags uint8, src, dst string, srcPort, desPort uint16, addOptions bool, addPayload bool) (*packet.Packet, error) { //nolint 1803 1804 p, err := packet.NewIpv4TCPPacket(context, tcpFlags, src, dst, srcPort, dstPort) 1805 if err != nil { 1806 return nil, err 1807 } 1808 1809 p.SetTCPSeq(rand.Uint32()) 1810 1811 if addOptions { 1812 options := []byte{2 /*Maximum Segment Size*/, 4, 0x05, 0x8C, 34, enforcerconstants.TCPAuthenticationOptionBaseLen, 0, 0} 1813 buffer := append(p.GetBuffer(0), options...) 1814 err = p.UpdatePacketBuffer(buffer, uint16(len(options))) 1815 } 1816 1817 if addPayload { 1818 buffer := append(p.GetBuffer(0), []byte("dummy payload")...) 1819 err = p.UpdatePacketBuffer(buffer, 0) 1820 } 1821 1822 return p, err 1823 } 1824 1825 func TestCheckConnectionDeletion(t *testing.T) { 1826 1827 ctrl := gomock.NewController(t) 1828 defer ctrl.Finish() 1829 1830 Convey("Given I setup an enforcer", t, func() { 1831 1832 enforcer, secrets, mockTokenAccessor, _, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true) 1833 1834 secrets.EXPECT().TransmittedKey().Return([]byte("dummy")).AnyTimes() 1835 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 1836 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 1837 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 1838 1839 err := CreatePortPolicy(enforcer, "dummy", "/ns1", common.ContainerPU, mockTokenAccessor, "2", dstPort, dstPort) 1840 So(err, ShouldBeNil) 1841 1842 tcpPacket, err := newPacket(1, packet.TCPSynMask, testSrcIP, testDstIP, srcPort, dstPort, true, false) 1843 So(err, ShouldBeNil) 1844 1845 conn := &connection.TCPConnection{ 1846 ServiceConnection: true, 1847 MarkForDeletion: true, 1848 } 1849 1850 hash := tcpPacket.L4FlowHash() 1851 enforcer.tcpClient.Put(hash, conn) 1852 1853 tcpPacket.Mark = "2" 1854 1855 conn1, err := enforcer.appSynRetrieveState(tcpPacket) 1856 So(err, ShouldBeNil) 1857 So(conn1.MarkForDeletion, ShouldBeFalse) 1858 1859 enforcer.tcpServer.Put(hash, conn) 1860 _, err = enforcer.netSynRetrieveState(tcpPacket) 1861 So(err, ShouldBeNil) 1862 1863 tcpSynAckPacket, err := newPacket(1, packet.TCPSynAckMask, testDstIP, testSrcIP, dstPort, srcPort, true, false) 1864 So(err, ShouldBeNil) 1865 1866 _, err = enforcer.netSynAckRetrieveState(tcpSynAckPacket) 1867 So(err, ShouldNotBeNil) 1868 ShouldEqual(err, errNonPUTraffic) 1869 }) 1870 } 1871 1872 func TestNetSynRetrieveState(t *testing.T) { 1873 1874 ctrl := gomock.NewController(t) 1875 defer ctrl.Finish() 1876 1877 // Testing datapath.netSynRetrieveState 1878 // There are 4 different code branches in this functions 1879 1880 Convey("Given I setup an enforcer", t, func() { 1881 1882 defer MockGetUDPRawSocket()() 1883 1884 enforcer, secrets, mockTokenAccessor, _, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true) 1885 1886 secrets.EXPECT().TransmittedKey().Return([]byte("dummy")).AnyTimes() 1887 secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes() 1888 mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes() 1889 mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil) 1890 1891 err := CreatePortPolicy(enforcer, "123456", "/ns1", common.LinuxProcessPU, mockTokenAccessor, "2", 9000, 9000) 1892 So(err, ShouldBeNil) 1893 1894 // Test the error case 1895 p, err := packet.NewIpv4TCPPacket(1, 0x2, "127.0.0.1", "127.0.0.1", 43758, 8000) 1896 So(err, ShouldBeNil) 1897 _, err = enforcer.netSynRetrieveState(p) 1898 So(err, ShouldNotBeNil) 1899 1900 p, err = packet.NewIpv4TCPPacket(1, 0x2, "127.0.0.1", "127.0.0.1", 43758, 9000) 1901 So(err, ShouldBeNil) 1902 1903 conn, err := enforcer.netSynRetrieveState(p) 1904 So(err, ShouldBeNil) 1905 1906 enforcer.tcpServer.Put(p.L4FlowHash(), conn) 1907 1908 So(conn.GetInitialSequenceNumber(), ShouldEqual, p.TCPSequenceNumber()) 1909 Convey("I retry the same packet", func() { 1910 retryconn, err := enforcer.netSynRetrieveState(p) 1911 assert.Equal(t, err, nil, "error should be nil") 1912 assert.Equal(t, retryconn, conn, "connection should be same") 1913 }) 1914 Convey("Then i modify the sequence number and retry the packet", func() { 1915 p.IncreaseTCPSeq(10) 1916 conn1, err := enforcer.netSynRetrieveState(p) 1917 So(err, ShouldBeNil) 1918 So(conn1.GetInitialSequenceNumber(), ShouldNotEqual, conn.GetInitialSequenceNumber()) 1919 _, exists := enforcer.tcpServer.Get(p.L4FlowHash()) 1920 if exists { 1921 t.Fail() 1922 } 1923 }) 1924 1925 }) 1926 } 1927 1928 func TestAppSynRetrieveState(t *testing.T) { 1929 1930 ctrl := gomock.NewController(t) 1931 defer ctrl.Finish() 1932 1933 // Testing datapath.appSynRetrieveState 1934 // There are 4 different code branches in the function 1935 1936 Convey("Given I setup an enforcer", t, func() { 1937 1938 defer MockGetUDPRawSocket()() 1939 1940 enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true) 1941 1942 err := CreatePortPolicy(enforcer, "testContextID", "/ns1", common.LinuxProcessPU, nil, "2", 9000, 9000) 1943 So(err, ShouldBeNil) 1944 1945 // Create a Syn packet 1946 p, err := packet.NewIpv4TCPPacket(1, 0x2, "127.0.0.1", "127.0.0.1", 43758, 9000) 1947 So(err, ShouldBeNil) 1948 1949 // The error case "PU context doesn't exist for this syn, return error" 1950 _, err = enforcer.appSynRetrieveState(p) 1951 So(err, ShouldNotBeNil) 1952 1953 p.Mark = "2" 1954 1955 conn, err := enforcer.appSynRetrieveState(p) 1956 So(err, ShouldBeNil) 1957 1958 enforcer.tcpClient.Put(p.L4FlowHash(), conn) 1959 1960 Convey("I replay the same packet", func() { 1961 retryconn, err := enforcer.appSynRetrieveState(p) 1962 So(err, ShouldBeNil) 1963 So(retryconn, ShouldNotBeNil) 1964 1965 }) 1966 Convey("I modify the sequence number and retransmit the packet", func() { 1967 p.IncreaseTCPSeq(10) 1968 retryconn, err := enforcer.appSynRetrieveState(p) 1969 So(retryconn, ShouldNotBeNil) 1970 So(err, ShouldBeNil) 1971 _, exists := enforcer.tcpClient.Get(p.L4FlowHash()) 1972 if exists { 1973 t.Fail() 1974 } 1975 }) 1976 }) 1977 } 1978 1979 func TestAppSynAckRetrieveState(t *testing.T) { 1980 1981 ctrl := gomock.NewController(t) 1982 defer ctrl.Finish() 1983 1984 // Testing datapath.appSynAckRetrieveState 1985 // There are 2 different code branches in this functions 1986 1987 Convey("Given I setup an enforcer", t, func() { 1988 1989 defer MockGetUDPRawSocket()() 1990 1991 enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true) 1992 1993 // Create a SynAck packet 1994 p, err := packet.NewIpv4TCPPacket(1, packet.TCPSynAckMask, "127.0.0.1", "127.0.0.1", 43758, 9000) 1995 So(err, ShouldBeNil) 1996 1997 // The error case when nothing is in the cache 1998 _, err = enforcer.appSynAckRetrieveState(p) 1999 So(err, ShouldNotBeNil) 2000 2001 // add connection to the cache 2002 enforcer.tcpServer.Put(p.L4ReverseFlowHash(), &connection.TCPConnection{}) 2003 2004 // Should be in the cache 2005 conn, err := enforcer.appSynAckRetrieveState(p) 2006 So(err, ShouldBeNil) 2007 So(conn, ShouldNotBeNil) 2008 }) 2009 } 2010 2011 func TestNetSynAckRetrieveState(t *testing.T) { 2012 2013 ctrl := gomock.NewController(t) 2014 defer ctrl.Finish() 2015 2016 // Testing datapath.netSynAckRetrieveState 2017 // There are 3 different code branches in this functions 2018 2019 Convey("Given I setup an enforcer", t, func() { 2020 2021 defer MockGetUDPRawSocket()() 2022 2023 enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true) 2024 2025 // Create a SynAck packet 2026 p, err := packet.NewIpv4TCPPacket(1, packet.TCPSynAckMask, "127.0.0.1", "127.0.0.1", 43758, 9000) 2027 So(err, ShouldBeNil) 2028 2029 // The error case when nothing is in the cache 2030 _, err = enforcer.netSynAckRetrieveState(p) 2031 ShouldEqual(err, errNonPUTraffic) 2032 2033 // add connection to the cache 2034 enforcer.tcpClient.Put(p.L4ReverseFlowHash(), &connection.TCPConnection{}) 2035 2036 // Should be in the cache 2037 conn, err := enforcer.netSynAckRetrieveState(p) 2038 So(err, ShouldBeNil) 2039 So(conn, ShouldNotBeNil) 2040 2041 // Mark the connection as deleted 2042 conn.MarkForDeletion = true 2043 2044 // We should get an error 2045 _, err = enforcer.netSynAckRetrieveState(p) 2046 ShouldEqual(err, errOutOfOrderSynAck) 2047 }) 2048 } 2049 2050 func TestAppRetrieveState(t *testing.T) { 2051 2052 ctrl := gomock.NewController(t) 2053 defer ctrl.Finish() 2054 2055 // Testing datapath.appRetrieveState 2056 // There are 6 branch conditions in this function. 2057 2058 Convey("Given I setup an enforcer", t, func() { 2059 2060 defer MockGetUDPRawSocket()() 2061 2062 enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true) 2063 2064 // Create a Rst packet 2065 p, err := packet.NewIpv4TCPPacket(1, packet.TCPRstMask, "127.0.0.1", "127.0.0.1", 43758, 9000) 2066 So(err, ShouldBeNil) 2067 2068 // 1. We should get the errRstPacket error 2069 _, err = enforcer.appRetrieveState(p) 2070 ShouldEqual(err, errRstPacket) 2071 2072 // Create a Syn packet 2073 p, err = packet.NewIpv4TCPPacket(1, packet.TCPSynMask, "127.0.0.1", "127.0.0.1", 43758, 9000) 2074 So(err, ShouldBeNil) 2075 2076 // 2. We should get errNoConnection error 2077 _, err = enforcer.appRetrieveState(p) 2078 ShouldEqual(err, errNoConnection) 2079 2080 // Create a Ack packet 2081 p, err = packet.NewIpv4TCPPacket(1, packet.TCPAckMask, "127.0.0.1", "127.0.0.1", 43758, 9000) 2082 So(err, ShouldBeNil) 2083 2084 // 3. We should get error "No context in app processing" 2085 _, err = enforcer.appRetrieveState(p) 2086 ShouldResemble(err, errors.New("No context in app processing")) 2087 2088 // Create port policy 2089 err = CreatePortPolicy(enforcer, "testContextID", "/ns1", common.LinuxProcessPU, nil, "2", 43758, 43758) 2090 So(err, ShouldBeNil) 2091 2092 p.Mark = "2" 2093 2094 // 4. We should get a connection object with UnknownState 2095 conn, err := enforcer.appRetrieveState(p) 2096 So(err, ShouldBeNil) 2097 So(conn, ShouldNotBeNil) 2098 ShouldEqual(conn.GetState(), connection.UnknownState) 2099 2100 // add connection to the server cache 2101 connServer := &connection.TCPConnection{} 2102 enforcer.tcpServer.Put(p.L4ReverseFlowHash(), connServer) 2103 2104 // 5. Should be in the cache 2105 conn, err = enforcer.appRetrieveState(p) 2106 So(err, ShouldBeNil) 2107 ShouldEqual(conn, connServer) 2108 2109 // add connection to the client cache 2110 connClient := &connection.TCPConnection{} 2111 enforcer.tcpClient.Put(p.L4FlowHash(), connClient) 2112 2113 // 6. Should be in the cache 2114 conn, err = enforcer.appRetrieveState(p) 2115 So(err, ShouldBeNil) 2116 ShouldEqual(conn, connClient) 2117 }) 2118 } 2119 2120 func TestNetRetrieveState(t *testing.T) { 2121 2122 ctrl := gomock.NewController(t) 2123 defer ctrl.Finish() 2124 2125 // Testing datapath.netRetrieveState 2126 // There are 7 branch conditions in this function. 2127 2128 Convey("Given I setup an enforcer", t, func() { 2129 2130 defer MockGetUDPRawSocket()() 2131 2132 enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true) 2133 2134 // Create a Rst packet 2135 p, err := packet.NewIpv4TCPPacket(1, packet.TCPRstMask, "127.0.0.1", "127.0.0.1", 43758, 9000) 2136 So(err, ShouldBeNil) 2137 2138 // 1. We should get the errRstPacket error 2139 _, err = enforcer.netRetrieveState(p) 2140 ShouldEqual(err, errRstPacket) 2141 2142 // Create a Syn packet 2143 p, err = packet.NewIpv4TCPPacket(1, packet.TCPSynMask, "127.0.0.1", "127.0.0.1", 43758, 9000) 2144 So(err, ShouldBeNil) 2145 2146 // 2. We should get errNoConnection error 2147 _, err = enforcer.netRetrieveState(p) 2148 ShouldEqual(err, errNoConnection) 2149 2150 // Create a Ack packet 2151 p, err = packet.NewIpv4TCPPacket(1, packet.TCPAckMask, "127.0.0.1", "127.0.0.1", 43758, 9000) 2152 So(err, ShouldBeNil) 2153 2154 // 3. We should get error " TCP Port Not Found 9000" 2155 _, err = enforcer.netRetrieveState(p) 2156 ShouldResemble(err, errors.New(" TCP Port Not Found 9000")) 2157 2158 // Create port policy 2159 err = CreatePortPolicy(enforcer, "testContextID", "/ns1", common.LinuxProcessPU, nil, "2", 9000, 9000) 2160 So(err, ShouldBeNil) 2161 2162 p.Mark = "2" 2163 2164 // 4. We should get a connection object with UnknownState 2165 conn, err := enforcer.netRetrieveState(p) 2166 So(err, ShouldBeNil) 2167 So(conn, ShouldNotBeNil) 2168 ShouldEqual(conn.GetState(), connection.UnknownState) 2169 2170 // add connection to the server cache 2171 connServer := &connection.TCPConnection{} 2172 enforcer.tcpServer.Put(p.L4FlowHash(), connServer) 2173 2174 // 5. Should be in the cache 2175 conn, err = enforcer.netRetrieveState(p) 2176 So(err, ShouldBeNil) 2177 ShouldEqual(conn, connServer) 2178 2179 // add connection to the client cache 2180 connClient := &connection.TCPConnection{} 2181 enforcer.tcpClient.Put(p.L4ReverseFlowHash(), connClient) 2182 2183 // 6. Should be in the cache 2184 conn, err = enforcer.netRetrieveState(p) 2185 So(err, ShouldBeNil) 2186 ShouldEqual(conn, connClient) 2187 2188 // Change to a Rst packet 2189 p.SetTCPFlags(packet.TCPRstMask) 2190 2191 // 7. Should be in the cache, but should get error errRstPacket 2192 _, err = enforcer.netRetrieveState(p) 2193 So(err, ShouldNotBeNil) 2194 ShouldEqual(err, errRstPacket) 2195 }) 2196 } 2197 2198 // This is to ensure that if we get tcp fo packet with no identity payload that we drop the packet 2199 func TestProcessNetworkSynPacket(t *testing.T) { 2200 2201 ctrl := gomock.NewController(t) 2202 defer ctrl.Finish() 2203 2204 Convey("When I setup an enforcer", t, func() { 2205 2206 defer MockGetUDPRawSocket()() 2207 2208 enforcer, _, _, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true) 2209 2210 flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 43758, 80, policy.Reject|policy.Log, collector.MissingToken) 2211 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1) 2212 2213 Convey("So I received a packet with tcp fast open option set but no payload", func() { 2214 2215 p, err := packet.NewIpv4TCPPacket(1, 0x2, testSrcIP, testDstIP, 43758, 80) 2216 So(err, ShouldBeNil) 2217 So(p, ShouldNotBeNil) 2218 2219 // Add the fast open option 2220 buffer := append(p.GetBuffer(0), []byte{packet.TCPAuthenticationOption, enforcerconstants.TCPAuthenticationOptionBaseLen, 0, 0}...) 2221 err = p.UpdatePacketBuffer(buffer, 4) 2222 So(err, ShouldBeNil) 2223 2224 err = p.CheckTCPAuthenticationOption(enforcerconstants.TCPAuthenticationOptionBaseLen) 2225 So(err, ShouldBeNil) 2226 So(p.IsEmptyTCPPayload(), ShouldBeTrue) 2227 2228 context, err := CreatePUContext(enforcer, "dummyContext", "/ns1", common.LinuxProcessPU, nil) 2229 So(err, ShouldBeNil) 2230 So(context, ShouldNotBeNil) 2231 2232 _, err = enforcer.processNetworkSynPacket(context, connection.NewTCPConnection(context, p), p) 2233 So(err, ShouldNotBeNil) 2234 }) 2235 }) 2236 } 2237 2238 func TestProcessNetworkSynAckPacket(t *testing.T) { 2239 2240 ctrl := gomock.NewController(t) 2241 defer ctrl.Finish() 2242 2243 Convey("When I setup an enforcer", t, func() { 2244 2245 defer MockGetUDPRawSocket()() 2246 2247 enforcer, _, _, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true) 2248 2249 flowRecord1 := CreateFlowRecord(1, testDstIP, testSrcIP, 80, 43758, policy.Reject|policy.Log, collector.PolicyDrop) 2250 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord1)).Times(1) 2251 2252 flowRecord2 := CreateFlowRecord(1, testDstIP, testSrcIP, 80, 43758, policy.Accept, "") 2253 mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord2)).Times(1) 2254 2255 Convey("So I received a packet with tcp fast open option set but no payload", func() { 2256 2257 p, err := packet.NewIpv4TCPPacket(1, 0x2, testSrcIP, testDstIP, 43758, 80) 2258 So(err, ShouldBeNil) 2259 So(p, ShouldNotBeNil) 2260 2261 // Add the fast open option 2262 buffer := append(p.GetBuffer(0), []byte{packet.TCPAuthenticationOption, enforcerconstants.TCPAuthenticationOptionBaseLen, 0, 0}...) 2263 2264 err = p.UpdatePacketBuffer(buffer, 4) 2265 So(err, ShouldBeNil) 2266 2267 err = p.CheckTCPAuthenticationOption(enforcerconstants.TCPAuthenticationOptionBaseLen) 2268 So(err, ShouldBeNil) 2269 So(p.IsEmptyTCPPayload(), ShouldBeTrue) 2270 2271 context, err := CreatePUContext(enforcer, "dummyContext", "/ns1", common.LinuxProcessPU, nil) 2272 So(err, ShouldBeNil) 2273 So(context, ShouldNotBeNil) 2274 2275 _, err = enforcer.processNetworkSynAckPacket(context, connection.NewTCPConnection(context, p), p) 2276 So(err, ShouldNotBeNil) 2277 2278 Convey("Then i add ip acl rule.", func() { 2279 iprules := policy.IPRuleList{policy.IPRule{ 2280 Addresses: []string{"10.1.10.76/32"}, 2281 Ports: []string{"43758"}, 2282 Protocols: []string{constants.TCPProtoNum}, 2283 Policy: &policy.FlowPolicy{ 2284 Action: policy.Accept, 2285 PolicyID: "tcp172/8"}, 2286 }} 2287 err = context.UpdateApplicationACLs(iprules) 2288 So(err, ShouldBeNil) 2289 2290 _, err = enforcer.processNetworkSynAckPacket(context, connection.NewTCPConnection(context, p), p) 2291 So(err, ShouldBeNil) 2292 }) 2293 }) 2294 2295 }) 2296 }