github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/nfqdatapath/datapath_test.go (about)

     1  // +build linux
     2  
     3  package nfqdatapath
     4  
     5  import (
     6  	"context"
     7  	"crypto/ecdsa"
     8  	"encoding/binary"
     9  	"errors"
    10  	"fmt"
    11  	"math/rand"
    12  	"net"
    13  	"reflect"
    14  	"strconv"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/golang/mock/gomock"
    19  	. "github.com/smartystreets/goconvey/convey"
    20  	"go.aporeto.io/enforcerd/trireme-lib/collector"
    21  	"go.aporeto.io/enforcerd/trireme-lib/common"
    22  	"go.aporeto.io/enforcerd/trireme-lib/controller/constants"
    23  	enforcerconstants "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/constants"
    24  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/packetgen"
    25  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/connection"
    26  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/counters"
    27  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/flowtracking/mockflowclient"
    28  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/packet"
    29  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/packettracing"
    30  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pucontext"
    31  	"go.aporeto.io/enforcerd/trireme-lib/policy"
    32  	"go.aporeto.io/enforcerd/trireme-lib/utils/portspec"
    33  	"gotest.tools/assert"
    34  )
    35  
    36  func TestEnforcerExternalNetworks(t *testing.T) {
    37  
    38  	ctrl := gomock.NewController(t)
    39  	defer ctrl.Finish()
    40  
    41  	testThePackets := func(enforcer *Datapath) {
    42  
    43  		PacketFlow := packetgen.NewTemplateFlow()
    44  
    45  		_, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate)
    46  		So(err, ShouldBeNil)
    47  
    48  		synackPacket, err := PacketFlow.GetFirstSynAckPacket().ToBytes()
    49  		So(err, ShouldBeNil)
    50  
    51  		tcpPacket, _ := packet.New(0, synackPacket, "0", true)
    52  		_, err1 := enforcer.processApplicationTCPPackets(tcpPacket)
    53  		So(err1, ShouldBeNil)
    54  
    55  	}
    56  
    57  	Convey("When the mode is RemoteConainter", t, func() {
    58  
    59  		enforcer, secrets, mockTokenAccessor, _, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true)
    60  
    61  		secrets.EXPECT().TransmittedKey().Return([]byte("dummy")).AnyTimes()
    62  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
    63  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
    64  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
    65  
    66  		iprules := policy.IPRuleList{policy.IPRule{
    67  			Addresses: []string{"10.1.10.76/32"},
    68  			Ports:     []string{"80"},
    69  			Protocols: []string{constants.TCPProtoNum},
    70  			Policy: &policy.FlowPolicy{
    71  				Action:   policy.Accept,
    72  				PolicyID: "tcp172/8"},
    73  		}}
    74  
    75  		contextID := "123456"
    76  		puInfo := policy.NewPUInfo(contextID, "/ns1", common.LinuxProcessPU)
    77  
    78  		context, err := pucontext.NewPU(contextID, puInfo, mockTokenAccessor, 10*time.Second)
    79  		So(err, ShouldBeNil)
    80  		enforcer.puFromContextID.AddOrUpdate(contextID, context)
    81  		s, _ := portspec.NewPortSpec(80, 80, contextID)
    82  		enforcer.contextIDFromTCPPort.AddPortSpec(s)
    83  
    84  		err = context.UpdateNetworkACLs(iprules)
    85  		So(err, ShouldBeNil)
    86  
    87  		testThePackets(enforcer)
    88  	})
    89  }
    90  
    91  func TestInvalidContext(t *testing.T) {
    92  
    93  	ctrl := gomock.NewController(t)
    94  	defer ctrl.Finish()
    95  
    96  	defer MockGetUDPRawSocket()()
    97  
    98  	Convey("Given I create a new enforcer instance", t, func() {
    99  
   100  		enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID", constants.LocalServer, []string{"0.0.0.0/0"}, true)
   101  
   102  		PacketFlow := packetgen.NewTemplateFlow()
   103  		_, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate)
   104  		So(err, ShouldBeNil)
   105  		synPacket, err := PacketFlow.GetFirstSynPacket().ToBytes()
   106  		So(err, ShouldBeNil)
   107  		tcpPacket, err := packet.New(0, synPacket, "0", true)
   108  		Convey("When I run a TCP Syn packet through a non existing context", func() {
   109  
   110  			_, err1 := enforcer.processApplicationTCPPackets(tcpPacket)
   111  			_, _, err2 := enforcer.processNetworkTCPPackets(tcpPacket)
   112  
   113  			Convey("Then I should see an error for non existing context", func() {
   114  
   115  				So(err, ShouldBeNil)
   116  				So(err1, ShouldNotBeNil)
   117  				So(err2, ShouldNotBeNil)
   118  			})
   119  		})
   120  	})
   121  }
   122  
   123  func TestPacketHandlingFirstThreePacketsHavePayload(t *testing.T) {
   124  
   125  	ctrl := gomock.NewController(t)
   126  	defer ctrl.Finish()
   127  
   128  	testThePackets := func(enforcer *Datapath) {
   129  		SIP := net.IPv4zero
   130  		firstSynAckProcessed := false
   131  		PacketFlow := packetgen.NewTemplateFlow()
   132  		_, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate)
   133  		So(err, ShouldBeNil)
   134  		for i := 0; i < PacketFlow.GetNumPackets(); i++ {
   135  			oldPacketFromFlow, err := PacketFlow.GetNthPacket(i).ToBytes()
   136  			So(err, ShouldBeNil)
   137  			oldPacket, err := packet.New(0, oldPacketFromFlow, "0", true)
   138  			if err == nil && oldPacket != nil {
   139  				oldPacket.UpdateIPv4Checksum()
   140  				oldPacket.UpdateTCPChecksum()
   141  			}
   142  			tcpPacketFromFlow, err := PacketFlow.GetNthPacket(i).ToBytes()
   143  			So(err, ShouldBeNil)
   144  			tcpPacket, err := packet.New(0, tcpPacketFromFlow, "0", true)
   145  			if err == nil && tcpPacket != nil {
   146  				tcpPacket.UpdateIPv4Checksum()
   147  				tcpPacket.UpdateTCPChecksum()
   148  			}
   149  			if debug {
   150  				fmt.Println("Input packet", i)
   151  				tcpPacket.Print(0, false)
   152  			}
   153  
   154  			So(err, ShouldBeNil)
   155  			So(tcpPacket, ShouldNotBeNil)
   156  
   157  			if reflect.DeepEqual(SIP, net.IPv4zero) {
   158  				SIP = tcpPacket.SourceAddress()
   159  			}
   160  			if !reflect.DeepEqual(SIP, tcpPacket.DestinationAddress()) &&
   161  				!reflect.DeepEqual(SIP, tcpPacket.SourceAddress()) {
   162  				t.Error("Invalid Test Packet")
   163  			}
   164  
   165  			_, err = enforcer.processApplicationTCPPackets(tcpPacket)
   166  			So(err, ShouldBeNil)
   167  
   168  			if debug {
   169  				fmt.Println("Intermediate packet", i)
   170  				tcpPacket.Print(0, false)
   171  			}
   172  
   173  			if tcpPacket.GetTCPFlags()&packet.TCPSynMask != 0 {
   174  				Convey("When I pass a packet with SYN or SYN/ACK flags for packet "+strconv.Itoa(i), func() {
   175  					Convey("Then I expect some data payload to exist on the packet "+strconv.Itoa(i), func() {
   176  						// In our 3 way security handshake syn and syn-ack packet should grow in length
   177  						So(tcpPacket.IPTotalLen(), ShouldBeGreaterThan, oldPacket.IPTotalLen())
   178  					})
   179  				})
   180  			}
   181  
   182  			if !firstSynAckProcessed && tcpPacket.GetTCPFlags()&packet.TCPSynAckMask == packet.TCPAckMask {
   183  				firstSynAckProcessed = true
   184  				Convey("When I pass the first packet with ACK flag for packet "+strconv.Itoa(i), func() {
   185  					Convey("Then I expect some data payload to exist on the packet "+strconv.Itoa(i), func() {
   186  						// In our 3 way security handshake first ack packet should grow in length
   187  						So(tcpPacket.IPTotalLen(), ShouldBeGreaterThan, oldPacket.IPTotalLen())
   188  					})
   189  				})
   190  			}
   191  
   192  			output := make([]byte, len(tcpPacket.GetTCPBytes()))
   193  			copy(output, tcpPacket.GetTCPBytes())
   194  
   195  			outPacket, errp := packet.New(0, output, "0", true)
   196  			So(len(tcpPacket.GetTCPBytes()), ShouldBeLessThanOrEqualTo, len(outPacket.GetTCPBytes()))
   197  			So(errp, ShouldBeNil)
   198  
   199  			_, f, err := enforcer.processNetworkTCPPackets(outPacket)
   200  			if f != nil {
   201  				f()
   202  			}
   203  
   204  			So(err, ShouldBeNil)
   205  
   206  			if debug {
   207  				fmt.Println("Output packet", i)
   208  				outPacket.Print(0, false)
   209  			}
   210  		}
   211  	}
   212  
   213  	flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Accept, "")
   214  
   215  	Convey("When the mode is RemoteConainter", t, func() {
   216  
   217  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer)
   218  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
   219  		testThePackets(enforcer)
   220  
   221  	})
   222  
   223  	Convey("When the mode is LocalServer", t, func() {
   224  
   225  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer)
   226  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
   227  		testThePackets(enforcer)
   228  
   229  	})
   230  }
   231  
   232  func TestInvalidIPContext(t *testing.T) {
   233  
   234  	ctrl := gomock.NewController(t)
   235  	defer ctrl.Finish()
   236  
   237  	defer MockGetUDPRawSocket()()
   238  
   239  	Convey("Given I create a new enforcer instance", t, func() {
   240  
   241  		enforcer, secrets, mockTokenAccessor, mockCollector, mockDNS := NewWithMocks(ctrl, "serverID", constants.LocalServer, []string{"0.0.0.0/0"}, true)
   242  
   243  		secrets.EXPECT().TransmittedKey().Return([]byte("dummy")).AnyTimes()
   244  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
   245  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
   246  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
   247  		mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).Times(1)
   248  		mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).Times(1)
   249  		mockDNS.EXPECT().Unenforce(gomock.Any(), gomock.Any()).Times(1)
   250  		mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(1)
   251  
   252  		puInfo := policy.NewPUInfo("SomeProcessingUnitId", "/ns2", common.LinuxProcessPU)
   253  
   254  		CounterReport := &collector.CounterReport{
   255  			PUID:      puInfo.Policy.ManagementID(),
   256  			Namespace: puInfo.Policy.ManagementNamespace(),
   257  		}
   258  		mockCollector.EXPECT().CollectCounterEvent(MyCounterMatcher(CounterReport)).MinTimes(1)
   259  
   260  		enforcer.Enforce(context.Background(), "serverID", puInfo) // nolint
   261  		defer func() {
   262  			if err := enforcer.Unenforce(context.Background(), "serverID"); err != nil {
   263  				fmt.Println("Error", err.Error())
   264  			}
   265  		}()
   266  
   267  		PacketFlow := packetgen.NewTemplateFlow()
   268  		_, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeMultipleGoodFlow)
   269  		So(err, ShouldBeNil)
   270  		synPacket, err := PacketFlow.GetFirstSynPacket().ToBytes()
   271  		So(err, ShouldBeNil)
   272  		tcpPacket, err := packet.New(0, synPacket, "0", true)
   273  
   274  		Convey("When I run a TCP Syn packet through an invalid existing context (missing IP)", func() {
   275  
   276  			_, err1 := enforcer.processApplicationTCPPackets(tcpPacket)
   277  			_, _, err2 := enforcer.processNetworkTCPPackets(tcpPacket)
   278  
   279  			Convey("Then I should see an error for missing IP", func() {
   280  
   281  				So(err, ShouldBeNil)
   282  				So(err1, ShouldNotBeNil)
   283  				So(err2, ShouldNotBeNil)
   284  			})
   285  		})
   286  	})
   287  }
   288  
   289  // TestEnforcerConnUnknownState test ensures that enforcer closes the
   290  // connection by converting packets to rst when it finds connection
   291  // to be in unknown state. This happens when enforcer has not seen the
   292  // 3way handshake for a connection.
   293  func TestEnforcerConnUnknownState(t *testing.T) {
   294  
   295  	ctrl := gomock.NewController(t)
   296  	defer ctrl.Finish()
   297  
   298  	testThePackets := func(enforcer *Datapath) {
   299  		Convey("If I send an ack packet from either PU to the other, it is converted into a Fin/Ack", func() {
   300  			PacketFlow := packetgen.NewTemplateFlow()
   301  			_, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate)
   302  			So(err, ShouldBeNil)
   303  
   304  			input, err := PacketFlow.GetFirstAckPacket().ToBytes()
   305  			So(err, ShouldBeNil)
   306  
   307  			tcpPacket, err := packet.New(0, input, "0", true)
   308  			// create a copy of the ack packet
   309  			tcpPacketCopy := *tcpPacket
   310  
   311  			if err == nil && tcpPacket != nil {
   312  				tcpPacket.UpdateIPv4Checksum()
   313  				tcpPacket.UpdateTCPChecksum()
   314  			}
   315  
   316  			_, err1 := enforcer.processApplicationTCPPackets(tcpPacket)
   317  
   318  			// Test whether the packet is modified with Fin/Ack
   319  			if tcpPacket.GetTCPFlags() != 0x04 {
   320  				t.Fail()
   321  			}
   322  
   323  			_, _, err2 := enforcer.processNetworkTCPPackets(&tcpPacketCopy)
   324  
   325  			if tcpPacket.GetTCPFlags() != 0x04 {
   326  				t.Fail()
   327  			}
   328  
   329  			So(err1, ShouldBeNil)
   330  			So(err2, ShouldBeNil)
   331  		})
   332  	}
   333  
   334  	Convey("When the mode is RemoteConainter", t, func() {
   335  
   336  		enforcer, _ := createEnforcerWithPolicy(ctrl, constants.RemoteContainer)
   337  		testThePackets(enforcer)
   338  
   339  	})
   340  
   341  	Convey("When the mode is LocalServer", t, func() {
   342  
   343  		enforcer, _ := createEnforcerWithPolicy(ctrl, constants.LocalServer)
   344  		testThePackets(enforcer)
   345  
   346  	})
   347  }
   348  
   349  func TestInvalidTokenContext(t *testing.T) {
   350  
   351  	ctrl := gomock.NewController(t)
   352  	defer ctrl.Finish()
   353  
   354  	defer MockGetUDPRawSocket()()
   355  
   356  	testThePackets := func(enforcer *Datapath) {
   357  
   358  		PacketFlow := packetgen.NewTemplateFlow()
   359  		_, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate)
   360  		So(err, ShouldBeNil)
   361  		synPacket, err := PacketFlow.GetFirstSynPacket().ToBytes()
   362  		So(err, ShouldBeNil)
   363  		tcpPacket, err := packet.New(0, synPacket, "0", true)
   364  
   365  		Convey("When I run a TCP Syn packet through an invalid existing context (missing IP)", func() {
   366  
   367  			_, err1 := enforcer.processApplicationTCPPackets(tcpPacket)
   368  			_, _, err2 := enforcer.processNetworkTCPPackets(tcpPacket)
   369  
   370  			Convey("Then I should see an error for missing Token", func() {
   371  
   372  				So(err, ShouldBeNil)
   373  				So(err1, ShouldNotBeNil)
   374  				So(err2, ShouldNotBeNil)
   375  			})
   376  		})
   377  	}
   378  
   379  	Convey("Given I create a new enforcer instance", t, func() {
   380  
   381  		puInfo := policy.NewPUInfo("SomeProcessingUnitId", "/ns2", common.LinuxProcessPU)
   382  
   383  		ip := policy.ExtendedMap{
   384  			"brige": testDstIP,
   385  		}
   386  		puInfo.Runtime.SetIPAddresses(ip)
   387  
   388  		enforcer, secrets, mockTokenAccessor, _, mockDNS := NewWithMocks(ctrl, "serverID", constants.LocalServer, []string{"0.0.0.0/0"}, true)
   389  
   390  		secrets.EXPECT().TransmittedKey().Return([]byte("dummy")).AnyTimes()
   391  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
   392  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
   393  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
   394  		mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).Times(1)
   395  		mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).Times(1)
   396  		mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(1)
   397  
   398  		enforcer.Enforce(context.Background(), "serverID", puInfo) // nolint
   399  
   400  		testThePackets(enforcer)
   401  	})
   402  }
   403  
   404  func TestPacketHandlingDstPortCacheBehavior(t *testing.T) {
   405  
   406  	ctrl := gomock.NewController(t)
   407  	defer ctrl.Finish()
   408  
   409  	testThePackets := func(enforcer *Datapath) {
   410  
   411  		SIP := net.IPv4zero
   412  
   413  		Convey("When I pass multiple packets through the enforcer", func() {
   414  
   415  			PacketFlow := packetgen.NewTemplateFlow()
   416  			_, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate)
   417  			So(err, ShouldBeNil)
   418  			for i := 0; i < PacketFlow.GetNumPackets(); i++ {
   419  				oldPacketFromFlow, err := PacketFlow.GetNthPacket(i).ToBytes()
   420  				So(err, ShouldBeNil)
   421  				oldPacket, err := packet.New(0, oldPacketFromFlow, "0", true)
   422  				if err == nil && oldPacket != nil {
   423  					oldPacket.UpdateIPv4Checksum()
   424  					oldPacket.UpdateTCPChecksum()
   425  				}
   426  				tcpPacketFromFlow, err := PacketFlow.GetNthPacket(i).ToBytes()
   427  				So(err, ShouldBeNil)
   428  				tcpPacket, err := packet.New(0, tcpPacketFromFlow, "0", true)
   429  				if err == nil && tcpPacket != nil {
   430  					tcpPacket.UpdateIPv4Checksum()
   431  					tcpPacket.UpdateTCPChecksum()
   432  				}
   433  
   434  				if debug {
   435  					fmt.Println("Input packet", i)
   436  					tcpPacket.Print(0, false)
   437  				}
   438  
   439  				So(err, ShouldBeNil)
   440  				So(tcpPacket, ShouldNotBeNil)
   441  
   442  				if reflect.DeepEqual(SIP, net.IPv4zero) {
   443  					SIP = tcpPacket.SourceAddress()
   444  				}
   445  				if !reflect.DeepEqual(SIP, tcpPacket.DestinationAddress()) &&
   446  					!reflect.DeepEqual(SIP, tcpPacket.SourceAddress()) {
   447  					t.Error("Invalid Test Packet")
   448  				}
   449  
   450  				_, err = enforcer.processApplicationTCPPackets(tcpPacket)
   451  				So(err, ShouldBeNil)
   452  
   453  				if debug {
   454  					fmt.Println("Intermediate packet", i)
   455  					tcpPacket.Print(0, false)
   456  				}
   457  
   458  				output := make([]byte, len(tcpPacket.GetTCPBytes()))
   459  				copy(output, tcpPacket.GetTCPBytes())
   460  
   461  				outPacket, errp := packet.New(0, output, "0", true)
   462  				So(len(tcpPacket.GetTCPBytes()), ShouldBeLessThanOrEqualTo, len(outPacket.GetTCPBytes()))
   463  				So(errp, ShouldBeNil)
   464  				_, f, err := enforcer.processNetworkTCPPackets(outPacket)
   465  				if f != nil {
   466  					f()
   467  				}
   468  
   469  				So(err, ShouldBeNil)
   470  
   471  				if debug {
   472  					fmt.Println("Output packet", i)
   473  					outPacket.Print(0, false)
   474  				}
   475  			}
   476  		})
   477  	}
   478  
   479  	flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Accept, "")
   480  
   481  	Convey("When the mode is RemoteConainter", t, func() {
   482  
   483  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer)
   484  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
   485  		testThePackets(enforcer)
   486  
   487  	})
   488  
   489  	Convey("When the mode is LocalServer", t, func() {
   490  
   491  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer)
   492  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
   493  		testThePackets(enforcer)
   494  
   495  	})
   496  }
   497  
   498  func TestAckLost(t *testing.T) {
   499  
   500  	ctrl := gomock.NewController(t)
   501  	defer ctrl.Finish()
   502  
   503  	testThePackets := func(enforcer *Datapath) {
   504  		PacketFlow := packetgen.NewTemplateFlow()
   505  		_, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate)
   506  		So(err, ShouldBeNil)
   507  
   508  		synPacket, err := PacketFlow.GetFirstSynPacket().ToBytes()
   509  		So(err, ShouldBeNil)
   510  		tcpPacket, err := packet.New(0, synPacket, "0", true)
   511  		if err == nil && tcpPacket != nil {
   512  			tcpPacket.UpdateIPv4Checksum()
   513  			tcpPacket.UpdateTCPChecksum()
   514  		}
   515  
   516  		_, err = enforcer.processApplicationTCPPackets(tcpPacket)
   517  		So(err, ShouldBeNil)
   518  
   519  		output := make([]byte, len(tcpPacket.GetTCPBytes()))
   520  		copy(output, tcpPacket.GetTCPBytes())
   521  
   522  		outPacket, errp := packet.New(0, output, "0", true)
   523  		So(len(tcpPacket.GetTCPBytes()), ShouldBeLessThanOrEqualTo, len(outPacket.GetTCPBytes()))
   524  		So(errp, ShouldBeNil)
   525  
   526  		_, f, err := enforcer.processNetworkTCPPackets(outPacket)
   527  		if f != nil {
   528  			f()
   529  		}
   530  
   531  		So(err, ShouldBeNil)
   532  
   533  		input, _ := PacketFlow.GetFirstSynAckPacket().ToBytes()
   534  
   535  		tcpPacket, _ = packet.New(0, input, "0", true)
   536  		if tcpPacket != nil {
   537  			tcpPacket.UpdateIPv4Checksum()
   538  			tcpPacket.UpdateTCPChecksum()
   539  		}
   540  
   541  		_, err = enforcer.processApplicationTCPPackets(tcpPacket)
   542  		So(err, ShouldBeNil)
   543  
   544  		output = make([]byte, len(tcpPacket.GetTCPBytes()))
   545  		copy(output, tcpPacket.GetTCPBytes())
   546  
   547  		outPacket, _ = packet.New(0, output, "0", true)
   548  		_, f, err = enforcer.processNetworkTCPPackets(outPacket)
   549  		if f != nil {
   550  			f()
   551  		}
   552  		So(err, ShouldBeNil)
   553  
   554  		input, _ = PacketFlow.GetFirstAckPacket().ToBytes()
   555  		tcpPacket, _ = packet.New(0, input, "0", true)
   556  		if tcpPacket != nil {
   557  			tcpPacket.UpdateIPv4Checksum()
   558  			tcpPacket.UpdateTCPChecksum()
   559  		}
   560  
   561  		_, err = enforcer.processApplicationTCPPackets(tcpPacket)
   562  		So(err, ShouldBeNil)
   563  		//simulate drop, and re-transmit packets.
   564  
   565  		input, _ = PacketFlow.GetFirstSynAckPacket().ToBytes()
   566  
   567  		tcpPacket, _ = packet.New(0, input, "0", true)
   568  		if tcpPacket != nil {
   569  			tcpPacket.UpdateIPv4Checksum()
   570  			tcpPacket.UpdateTCPChecksum()
   571  		}
   572  
   573  		_, err = enforcer.processApplicationTCPPackets(tcpPacket)
   574  		assert.Equal(t, err, nil, "error should be nil")
   575  
   576  		output = make([]byte, len(tcpPacket.GetTCPBytes()))
   577  		copy(output, tcpPacket.GetTCPBytes())
   578  
   579  		outPacket, _ = packet.New(0, output, "0", true)
   580  		_, f, err = enforcer.processNetworkTCPPackets(outPacket)
   581  		if f != nil {
   582  			f()
   583  		}
   584  		assert.Equal(t, err, nil, "error should be nil")
   585  
   586  		input, _ = PacketFlow.GetFirstAckPacket().ToBytes()
   587  
   588  		tcpPacket, _ = packet.New(0, input, "0", true)
   589  		if tcpPacket != nil {
   590  			tcpPacket.UpdateIPv4Checksum()
   591  			tcpPacket.UpdateTCPChecksum()
   592  		}
   593  
   594  		_, err = enforcer.processApplicationTCPPackets(tcpPacket)
   595  		assert.Equal(t, err, nil, "error should be nil")
   596  
   597  		output = make([]byte, len(tcpPacket.GetTCPBytes()))
   598  		copy(output, tcpPacket.GetTCPBytes())
   599  
   600  		outPacket, _ = packet.New(0, output, "0", true)
   601  
   602  		_, f, err = enforcer.processNetworkTCPPackets(outPacket)
   603  		if f != nil {
   604  			f()
   605  		}
   606  
   607  		assert.Equal(t, err, nil, "error should be nil")
   608  
   609  	}
   610  
   611  	flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Accept, "")
   612  
   613  	Convey("When the mode is RemoteConainter", t, func() {
   614  
   615  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer)
   616  		flowclient := mockflowclient.NewMockFlowClient(ctrl)
   617  		flowclient.EXPECT().UpdateApplicationFlowMark(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(nil).AnyTimes()
   618  		enforcer.conntrack = flowclient
   619  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
   620  
   621  		testThePackets(enforcer)
   622  
   623  	})
   624  }
   625  
   626  func TestConnectionTrackerStateLocalContainer(t *testing.T) {
   627  
   628  	ctrl := gomock.NewController(t)
   629  	defer ctrl.Finish()
   630  
   631  	testThePackets := func(enforcer *Datapath) {
   632  
   633  		PacketFlow := packetgen.NewTemplateFlow()
   634  		_, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate)
   635  		So(err, ShouldBeNil)
   636  		/*first packet in TCPFLOW slice is a syn packet*/
   637  		Convey("When i pass a syn packet through the enforcer", func() {
   638  
   639  			input, err := PacketFlow.GetFirstSynPacket().ToBytes()
   640  			So(err, ShouldBeNil)
   641  
   642  			tcpPacket, err := packet.New(0, input, "0", true)
   643  			if err == nil && tcpPacket != nil {
   644  				tcpPacket.UpdateIPv4Checksum()
   645  				tcpPacket.UpdateTCPChecksum()
   646  			}
   647  
   648  			_, err = enforcer.processApplicationTCPPackets(tcpPacket)
   649  			//After sending syn packet
   650  			CheckAfterAppSynPacket(enforcer, tcpPacket)
   651  			So(err, ShouldBeNil)
   652  			output := make([]byte, len(tcpPacket.GetTCPBytes()))
   653  			copy(output, tcpPacket.GetTCPBytes())
   654  
   655  			outPacket, err := packet.New(0, output, "0", true)
   656  			So(err, ShouldBeNil)
   657  			_, f, err := enforcer.processNetworkTCPPackets(outPacket)
   658  			if f != nil {
   659  				f()
   660  			}
   661  			So(err, ShouldBeNil)
   662  			//Check after processing networksyn packet
   663  			CheckAfterNetSynPacket(enforcer, tcpPacket, outPacket)
   664  
   665  		})
   666  		Convey("When i pass a SYN and SYN ACK packet through the enforcer", func() {
   667  
   668  			input, err := PacketFlow.GetFirstSynPacket().ToBytes()
   669  			So(err, ShouldBeNil)
   670  
   671  			tcpPacket, err := packet.New(0, input, "0", true)
   672  			if err == nil && tcpPacket != nil {
   673  				tcpPacket.UpdateIPv4Checksum()
   674  				tcpPacket.UpdateTCPChecksum()
   675  			}
   676  			_, err = enforcer.processApplicationTCPPackets(tcpPacket)
   677  			So(err, ShouldBeNil)
   678  
   679  			output := make([]byte, len(tcpPacket.GetTCPBytes()))
   680  			copy(output, tcpPacket.GetTCPBytes())
   681  
   682  			outPacket, err := packet.New(0, output, "0", true)
   683  			So(err, ShouldBeNil)
   684  			outPacket.Print(0, false)
   685  			_, f, err := enforcer.processNetworkTCPPackets(outPacket)
   686  			if f != nil {
   687  				f()
   688  			}
   689  			So(err, ShouldBeNil)
   690  
   691  			//Now lets send the synack packet from the server in response
   692  			input, err = PacketFlow.GetFirstSynAckPacket().ToBytes()
   693  			So(err, ShouldBeNil)
   694  
   695  			tcpPacket, err = packet.New(0, input, "0", true)
   696  			if err == nil && tcpPacket != nil {
   697  				tcpPacket.UpdateIPv4Checksum()
   698  				tcpPacket.UpdateTCPChecksum()
   699  			}
   700  			_, err = enforcer.processApplicationTCPPackets(tcpPacket)
   701  			So(err, ShouldBeNil)
   702  
   703  			output = make([]byte, len(tcpPacket.GetTCPBytes()))
   704  			copy(output, tcpPacket.GetTCPBytes())
   705  
   706  			outPacket, err = packet.New(0, output, "0", true)
   707  			So(err, ShouldBeNil)
   708  			outPacketcopy, _ := packet.New(0, output, "0", true)
   709  			_, f, err = enforcer.processNetworkTCPPackets(outPacket)
   710  			if f != nil {
   711  				f()
   712  			}
   713  			So(err, ShouldBeNil)
   714  
   715  			CheckAfterNetSynAckPacket(t, enforcer, outPacketcopy, outPacket)
   716  		})
   717  
   718  		Convey("When i pass a SYN and SYNACK and another ACK packet through the enforcer", func() {
   719  
   720  			input, err := PacketFlow.GetFirstSynPacket().ToBytes()
   721  			So(err, ShouldBeNil)
   722  			tcpPacket, err := packet.New(0, input, "0", true)
   723  			if err == nil && tcpPacket != nil {
   724  				tcpPacket.UpdateIPv4Checksum()
   725  				tcpPacket.UpdateTCPChecksum()
   726  			}
   727  			_, err = enforcer.processApplicationTCPPackets(tcpPacket)
   728  			So(err, ShouldBeNil)
   729  
   730  			output := make([]byte, len(tcpPacket.GetTCPBytes()))
   731  			copy(output, tcpPacket.GetTCPBytes())
   732  
   733  			outPacket, err := packet.New(0, output, "0", true)
   734  			So(err, ShouldBeNil)
   735  			_, f, err := enforcer.processNetworkTCPPackets(outPacket)
   736  			if f != nil {
   737  				f()
   738  			}
   739  			So(err, ShouldBeNil)
   740  
   741  			//Now lets send the synack packet from the server in response
   742  			input, err = PacketFlow.GetFirstSynAckPacket().ToBytes()
   743  			So(err, ShouldBeNil)
   744  
   745  			tcpPacket, err = packet.New(0, input, "0", true)
   746  			if err == nil && tcpPacket != nil {
   747  				tcpPacket.UpdateIPv4Checksum()
   748  				tcpPacket.UpdateTCPChecksum()
   749  			}
   750  			_, err = enforcer.processApplicationTCPPackets(tcpPacket)
   751  			So(err, ShouldBeNil)
   752  
   753  			output = make([]byte, len(tcpPacket.GetTCPBytes()))
   754  			copy(output, tcpPacket.GetTCPBytes())
   755  
   756  			outPacket, err = packet.New(0, output, "0", true)
   757  			So(err, ShouldBeNil)
   758  			_, f, err = enforcer.processNetworkTCPPackets(outPacket)
   759  			if f != nil {
   760  				f()
   761  			}
   762  			So(err, ShouldBeNil)
   763  
   764  			input, err = PacketFlow.GetFirstAckPacket().ToBytes()
   765  			So(err, ShouldBeNil)
   766  
   767  			tcpPacket, err = packet.New(0, input, "0", true)
   768  			if err == nil && tcpPacket != nil {
   769  				tcpPacket.UpdateIPv4Checksum()
   770  				tcpPacket.UpdateTCPChecksum()
   771  			}
   772  			_, err = enforcer.processApplicationTCPPackets(tcpPacket)
   773  			CheckAfterAppAckPacket(enforcer, tcpPacket)
   774  			So(err, ShouldBeNil)
   775  
   776  			output = make([]byte, len(tcpPacket.GetTCPBytes()))
   777  			copy(output, tcpPacket.GetTCPBytes())
   778  
   779  			outPacket, err = packet.New(0, output, "0", true)
   780  			So(err, ShouldBeNil)
   781  			CheckBeforeNetAckPacket(enforcer, tcpPacket, outPacket, false)
   782  			_, f, err = enforcer.processNetworkTCPPackets(outPacket)
   783  			if f != nil {
   784  				f()
   785  			}
   786  			So(err, ShouldBeNil)
   787  		})
   788  	}
   789  
   790  	flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Accept, "")
   791  
   792  	Convey("When the mode is RemoteConainter", t, func() {
   793  
   794  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer)
   795  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).AnyTimes()
   796  		testThePackets(enforcer)
   797  
   798  	})
   799  
   800  	Convey("When the mode is LocalServer", t, func() {
   801  
   802  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer)
   803  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).AnyTimes()
   804  		testThePackets(enforcer)
   805  
   806  	})
   807  }
   808  
   809  func CheckAfterAppSynPacket(enforcer *Datapath, tcpPacket *packet.Packet) {
   810  
   811  	appConn, _ := enforcer.tcpClient.Get(tcpPacket.L4FlowHash())
   812  	So(appConn.GetState(), ShouldEqual, connection.TCPSynSend)
   813  }
   814  
   815  func CheckAfterNetSynPacket(enforcer *Datapath, tcpPacket, outPacket *packet.Packet) {
   816  
   817  	appConn, _ := enforcer.tcpServer.Get(tcpPacket.L4FlowHash())
   818  	So(appConn.GetState(), ShouldEqual, connection.TCPSynReceived)
   819  }
   820  
   821  func CheckAfterNetSynAckPacket(t *testing.T, enforcer *Datapath, tcpPacket, outPacket *packet.Packet) {
   822  
   823  	netconn, _ := enforcer.tcpClient.Get(outPacket.L4ReverseFlowHash())
   824  	So(netconn.GetState(), ShouldEqual, connection.TCPSynAckReceived)
   825  }
   826  
   827  func CheckAfterAppAckPacket(enforcer *Datapath, tcpPacket *packet.Packet) {
   828  
   829  	appConn, _ := enforcer.tcpClient.Get(tcpPacket.L4FlowHash())
   830  	So(appConn.GetState(), ShouldEqual, connection.TCPAckSend)
   831  }
   832  
   833  func CheckBeforeNetAckPacket(enforcer *Datapath, tcpPacket, outPacket *packet.Packet, isReplay bool) {
   834  
   835  	appConn, _ := enforcer.tcpServer.Get(tcpPacket.L4FlowHash())
   836  	if !isReplay {
   837  		So(appConn.GetState(), ShouldEqual, connection.TCPSynAckSend)
   838  	} else {
   839  		So(appConn.GetState(), ShouldBeGreaterThan, connection.TCPSynAckSend)
   840  	}
   841  }
   842  
   843  func TestCacheState(t *testing.T) {
   844  
   845  	ctrl := gomock.NewController(t)
   846  	defer ctrl.Finish()
   847  
   848  	defer MockGetUDPRawSocket()()
   849  
   850  	Convey("Given I create a new enforcer instance", t, func() {
   851  
   852  		enforcer, secrets, mockTokenAccessor, mockCollector, mockDNS := NewWithMocks(ctrl, "serverID", constants.LocalServer, []string{"0.0.0.0/0"}, true)
   853  
   854  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
   855  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(2).Return([]byte("token"), nil).AnyTimes()
   856  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).Times(2)
   857  		mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).Times(1)
   858  		mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
   859  		mockDNS.EXPECT().Unenforce(gomock.Any(), gomock.Any()).Times(1)
   860  		mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(2)
   861  
   862  		contextID := "123"
   863  
   864  		puInfo := policy.NewPUInfo(contextID, "/ns1", common.ContainerPU)
   865  
   866  		CounterReport := &collector.CounterReport{
   867  			PUID:      puInfo.Policy.ManagementID(),
   868  			Namespace: puInfo.Policy.ManagementNamespace(),
   869  		}
   870  		mockCollector.EXPECT().CollectCounterEvent(MyCounterMatcher(CounterReport)).Times(2)
   871  
   872  		// Should fail: Not in cache
   873  		err := enforcer.Unenforce(context.Background(), contextID)
   874  		if err == nil {
   875  			t.Errorf("Expected failure, no contextID in cache")
   876  		}
   877  
   878  		ip := policy.ExtendedMap{"bridge": "127.0.0.1"}
   879  		puInfo.Runtime.SetIPAddresses(ip)
   880  		ipl := policy.ExtendedMap{"bridge": "127.0.0.1"}
   881  		puInfo.Policy.SetIPAddresses(ipl)
   882  
   883  		ip = policy.ExtendedMap{"bridge": "127.0.0.1"}
   884  		puInfo.Runtime.SetIPAddresses(ip)
   885  
   886  		ipl = policy.ExtendedMap{"bridge": "127.0.0.1"}
   887  		puInfo.Policy.SetIPAddresses(ipl)
   888  
   889  		// Should  not fail:  IP is valid
   890  		err = enforcer.Enforce(context.Background(), contextID, puInfo)
   891  		if err != nil {
   892  			t.Errorf("Expected no failure %s", err)
   893  		}
   894  
   895  		// Should  not fail:  Update
   896  		err = enforcer.Enforce(context.Background(), contextID, puInfo)
   897  		if err != nil {
   898  			t.Errorf("Expected no failure %s", err)
   899  		}
   900  
   901  		// Should  not fail:  IP is valid
   902  		err = enforcer.Unenforce(context.Background(), contextID)
   903  		if err != nil {
   904  			t.Errorf("Expected failure, no IP but passed %s", err)
   905  		}
   906  	})
   907  }
   908  
   909  func TestDoCreatePU(t *testing.T) {
   910  
   911  	ctrl := gomock.NewController(t)
   912  	defer ctrl.Finish()
   913  
   914  	defer MockGetUDPRawSocket()()
   915  
   916  	Convey("Given an initialized enforcer for Linux Processes", t, func() {
   917  
   918  		defer MockGetUDPRawSocket()()
   919  
   920  		enforcer, secrets, mockTokenAccessor, _, mockDNS := NewWithMocks(ctrl, "serverID", constants.LocalServer, []string{"0.0.0.0/0"}, true)
   921  
   922  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
   923  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
   924  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
   925  		mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).Times(1)
   926  		mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).Times(1)
   927  		mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(1)
   928  
   929  		contextID := "124"
   930  		puInfo := policy.NewPUInfo(contextID, "/ns1", common.LinuxProcessPU)
   931  
   932  		spec, _ := portspec.NewPortSpecFromString("80", nil)
   933  		puInfo.Runtime.SetOptions(policy.OptionsType{
   934  			CgroupMark: "100",
   935  			Services: []common.Service{
   936  				{
   937  					Protocol: uint8(6),
   938  					Ports:    spec,
   939  				},
   940  			},
   941  		})
   942  
   943  		Convey("When I create a new PU", func() {
   944  			err := enforcer.Enforce(context.Background(), contextID, puInfo)
   945  
   946  			Convey("It should succeed", func() {
   947  				So(err, ShouldBeNil)
   948  				_, err := enforcer.puFromContextID.Get(contextID)
   949  				So(err, ShouldBeNil)
   950  				_, err1 := enforcer.puFromMark.Get("100")
   951  				So(err1, ShouldBeNil)
   952  				_, err2 := enforcer.contextIDFromTCPPort.GetSpecValueFromPort(80)
   953  				So(err2, ShouldBeNil)
   954  				So(enforcer.puFromIP, ShouldBeNil)
   955  			})
   956  		})
   957  	})
   958  
   959  	Convey("Given an initialized enforcer for Linux Processes", t, func() {
   960  
   961  		enforcer, secrets, mockTokenAccessor, _, mockDNS := NewWithMocks(ctrl, "serverID", constants.LocalServer, []string{"0.0.0.0/0"}, true)
   962  
   963  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
   964  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
   965  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
   966  		mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).Times(1)
   967  		mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).Times(1)
   968  		mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(1)
   969  
   970  		contextID := "125"
   971  		puInfo := policy.NewPUInfo(contextID, "/ns1", common.LinuxProcessPU)
   972  
   973  		Convey("When I create a new PU without ports or mark", func() {
   974  			err := enforcer.Enforce(context.Background(), contextID, puInfo)
   975  
   976  			Convey("It should succeed", func() {
   977  				So(err, ShouldBeNil)
   978  				_, err := enforcer.puFromContextID.Get(contextID)
   979  				So(err, ShouldBeNil)
   980  				So(enforcer.puFromIP, ShouldBeNil)
   981  			})
   982  		})
   983  	})
   984  
   985  	Convey("Given an initialized enforcer for remote Linux Containers", t, func() {
   986  
   987  		enforcer, secrets, mockTokenAccessor, _, mockDNS := NewWithMocks(ctrl, "serverID", constants.RemoteContainer, []string{"0.0.0.0/0"}, true)
   988  
   989  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
   990  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
   991  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
   992  		mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).Times(1)
   993  		mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).Times(1)
   994  		mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(1)
   995  
   996  		contextID := "126"
   997  		puInfo := policy.NewPUInfo(contextID, "/ns1", common.ContainerPU)
   998  
   999  		Convey("When I create a new PU without an IP", func() {
  1000  			err := enforcer.Enforce(context.Background(), contextID, puInfo)
  1001  
  1002  			Convey("It should succeed ", func() {
  1003  				So(err, ShouldBeNil)
  1004  				So(enforcer.puFromIP, ShouldNotBeNil)
  1005  			})
  1006  		})
  1007  	})
  1008  }
  1009  
  1010  func TestContextFromIP(t *testing.T) {
  1011  
  1012  	ctrl := gomock.NewController(t)
  1013  	defer ctrl.Finish()
  1014  
  1015  	Convey("Given an initialized enforcer for Linux Processes", t, func() {
  1016  
  1017  		enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID", constants.RemoteContainer, []string{"0.0.0.0/0"}, true)
  1018  
  1019  		puInfo := policy.NewPUInfo("SomePU", "/ns", common.ContainerPU)
  1020  
  1021  		context, err := pucontext.NewPU("SomePU", puInfo, nil, 10*time.Second)
  1022  		contextID := "AporetoContext"
  1023  		So(err, ShouldBeNil)
  1024  
  1025  		Convey("If I try to get context based on IP and its  not there and its a local container it should fail ", func() {
  1026  			_, err := enforcer.contextFromIP(true, "", 0, packet.IPProtocolTCP)
  1027  			So(err, ShouldNotBeNil)
  1028  		})
  1029  
  1030  		Convey("If there is no IP match, it should try the mark for app packets ", func() {
  1031  			enforcer.puFromMark.AddOrUpdate("100", context)
  1032  			enforcer.mode = constants.LocalServer
  1033  			Convey("If the mark exists", func() {
  1034  				markVal := strconv.Itoa(100)
  1035  				ctx, err := enforcer.contextFromIP(true, markVal, 0, packet.IPProtocolTCP)
  1036  				So(err, ShouldBeNil)
  1037  				So(ctx, ShouldNotBeNil)
  1038  				So(ctx, ShouldEqual, context)
  1039  			})
  1040  
  1041  			Convey("If the mark doesn't exist", func() {
  1042  				_, err := enforcer.contextFromIP(true, "2000", 0, packet.IPProtocolTCP)
  1043  				So(err, ShouldNotBeNil)
  1044  			})
  1045  		})
  1046  
  1047  		Convey("If there is no IP match, it should try the port for net packets ", func() {
  1048  			s, _ := portspec.NewPortSpec(8000, 8000, contextID)
  1049  			enforcer.contextIDFromTCPPort.AddPortSpec(s)
  1050  			enforcer.puFromContextID.AddOrUpdate(contextID, context)
  1051  			enforcer.mode = constants.LocalServer
  1052  
  1053  			Convey("If the port exists", func() {
  1054  				ctx, err := enforcer.contextFromIP(false, "", 8000, packet.IPProtocolTCP)
  1055  				So(err, ShouldBeNil)
  1056  				So(ctx, ShouldNotBeNil)
  1057  				So(ctx, ShouldEqual, context)
  1058  			})
  1059  
  1060  			Convey("If the port doesn't exist", func() {
  1061  				_, err := enforcer.contextFromIP(false, "", 9000, packet.IPProtocolTCP)
  1062  				So(err, ShouldNotBeNil)
  1063  			})
  1064  		})
  1065  
  1066  	})
  1067  
  1068  	Convey("Given an initialized enforcer for HostPU", t, func() {
  1069  
  1070  		enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID", constants.RemoteContainer, []string{"0.0.0.0/0"}, true)
  1071  
  1072  		puInfo := policy.NewPUInfo("SomeHostPU", "/ns", common.HostPU)
  1073  
  1074  		context, err := pucontext.NewPU("SomeHostPU", puInfo, nil, 10*time.Second)
  1075  		So(err, ShouldBeNil)
  1076  
  1077  		enforcer.hostPU = context
  1078  
  1079  		Convey("If I try to get context for app ICMP for HostPU it should succeed ", func() {
  1080  			ctx, err := enforcer.contextFromIP(true, "", 0, packet.IPProtocolICMP)
  1081  			So(err, ShouldBeNil)
  1082  			So(ctx, ShouldNotBeNil)
  1083  			So(ctx, ShouldEqual, context)
  1084  		})
  1085  		Convey("If I try to get context for net ICMP for HostPU it should succeed ", func() {
  1086  			ctx, err := enforcer.contextFromIP(false, "", 0, packet.IPProtocolICMP)
  1087  			So(err, ShouldBeNil)
  1088  			So(ctx, ShouldNotBeNil)
  1089  			So(ctx, ShouldEqual, context)
  1090  		})
  1091  		Convey("If I try to get context for another protocol it should not return host context ", func() {
  1092  			_, err := enforcer.contextFromIP(true, "", 0, packet.IPProtocolTCP)
  1093  			So(err, ShouldNotBeNil)
  1094  		})
  1095  
  1096  	})
  1097  }
  1098  
  1099  func TestInvalidPacket(t *testing.T) {
  1100  
  1101  	ctrl := gomock.NewController(t)
  1102  	defer ctrl.Finish()
  1103  
  1104  	testThePackets := func(enforcer *Datapath) {
  1105  
  1106  		InvalidTCPFlow := [][]byte{
  1107  			{ /*0x4a, 0x1d, 0x70, 0xcf, 0xa6, 0xe5, 0xb8, 0xe8, 0x56, 0x32, 0x0b, 0xde, 0x08, 0x00,*/ 0x45, 0x00, 0x00, 0x40, 0xf4, 0x1f, 0x44, 0x00, 0x40, 0x06, 0xa9, 0x6f, 0x0a, 0x01, 0x0a, 0x4c, 0xa4, 0x43, 0xe4, 0x98, 0xe1, 0xa1, 0x00, 0x50, 0x4d, 0xa6, 0xac, 0x48, 0x00, 0x00, 0x00, 0x00, 0xb0, 0x02, 0xff, 0xff, 0x6b, 0x6c, 0x00, 0x00, 0x02, 0x04, 0x05, 0xb4, 0x01, 0x03, 0x03, 0x05, 0x01, 0x01, 0x08, 0x0a, 0x1b, 0x4f, 0x37, 0x38, 0x00, 0x00, 0x00, 0x00, 0x04, 0x02, 0x00, 0x00, 0x4a, 0x1d, 0x70, 0xcf},
  1108  		}
  1109  
  1110  		for _, p := range InvalidTCPFlow {
  1111  			tcpPacket, err := packet.New(0, p, "0", true)
  1112  			So(err, ShouldBeNil)
  1113  			_, err = enforcer.processApplicationTCPPackets(tcpPacket)
  1114  			So(err, ShouldBeNil)
  1115  			output := make([]byte, len(tcpPacket.GetTCPBytes()))
  1116  			copy(output, tcpPacket.GetTCPBytes())
  1117  			outpacket, err := packet.New(0, output, "0", true)
  1118  			So(err, ShouldBeNil)
  1119  			//Detach the data and parse token should fail
  1120  			outpacket.TCPDataDetach(binary.BigEndian.Uint16([]byte{0x0, p[32]})/4 - 20)
  1121  			So(err, ShouldBeNil)
  1122  			_, _, err = enforcer.processNetworkTCPPackets(outpacket)
  1123  			So(err, ShouldNotBeNil)
  1124  		}
  1125  
  1126  	}
  1127  
  1128  	flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Reject|policy.Log, collector.MissingToken)
  1129  
  1130  	Convey("When the mode is RemoteConainter", t, func() {
  1131  
  1132  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer)
  1133  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
  1134  		testThePackets(enforcer)
  1135  
  1136  	})
  1137  
  1138  	Convey("When the mode is LocalServer", t, func() {
  1139  
  1140  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer)
  1141  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
  1142  		testThePackets(enforcer)
  1143  
  1144  	})
  1145  }
  1146  
  1147  func TestFlowReportingInvalidSyn(t *testing.T) {
  1148  
  1149  	ctrl := gomock.NewController(t)
  1150  	defer ctrl.Finish()
  1151  
  1152  	testThePackets := func(enforcer *Datapath) {
  1153  
  1154  		SIP := net.IPv4zero
  1155  		packetDiffers := false
  1156  
  1157  		PacketFlow := packetgen.NewTemplateFlow()
  1158  		_, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate)
  1159  		So(err, ShouldBeNil)
  1160  		for i := 0; i < PacketFlow.GetSynPackets().GetNumPackets(); i++ {
  1161  
  1162  			start, err := PacketFlow.GetSynPackets().GetNthPacket(i).ToBytes()
  1163  			So(err, ShouldBeNil)
  1164  			oldPacket, err := packet.New(0, start, "0", true)
  1165  			if err == nil && oldPacket != nil {
  1166  				oldPacket.UpdateIPv4Checksum()
  1167  				oldPacket.UpdateTCPChecksum()
  1168  			}
  1169  
  1170  			input, err := PacketFlow.GetSynPackets().GetNthPacket(i).ToBytes()
  1171  			So(err, ShouldBeNil)
  1172  			tcpPacket, err := packet.New(0, input, "0", true)
  1173  			if err == nil && tcpPacket != nil {
  1174  				tcpPacket.UpdateIPv4Checksum()
  1175  				tcpPacket.UpdateTCPChecksum()
  1176  			}
  1177  
  1178  			if debug {
  1179  				fmt.Println("Input packet", i)
  1180  				tcpPacket.Print(0, false)
  1181  			}
  1182  
  1183  			So(err, ShouldBeNil)
  1184  			So(tcpPacket, ShouldNotBeNil)
  1185  
  1186  			if reflect.DeepEqual(SIP, net.IPv4zero) {
  1187  				SIP = tcpPacket.SourceAddress()
  1188  			}
  1189  			if !reflect.DeepEqual(SIP, tcpPacket.DestinationAddress()) &&
  1190  				!reflect.DeepEqual(SIP, tcpPacket.SourceAddress()) {
  1191  				t.Error("Invalid Test Packet")
  1192  			}
  1193  
  1194  			if debug {
  1195  				fmt.Println("Intermediate packet", i)
  1196  				tcpPacket.Print(0, false)
  1197  			}
  1198  
  1199  			output := make([]byte, len(tcpPacket.GetTCPBytes()))
  1200  			copy(output, tcpPacket.GetTCPBytes())
  1201  
  1202  			outPacket, errp := packet.New(0, output, "0", true)
  1203  			So(len(tcpPacket.GetTCPBytes()), ShouldBeLessThanOrEqualTo, len(outPacket.GetTCPBytes()))
  1204  			So(errp, ShouldBeNil)
  1205  			_, _, err = enforcer.processNetworkTCPPackets(outPacket)
  1206  			So(err, ShouldNotBeNil)
  1207  
  1208  			if debug {
  1209  				fmt.Println("Output packet", i)
  1210  				outPacket.Print(0, false)
  1211  			}
  1212  
  1213  			if !reflect.DeepEqual(oldPacket.GetTCPBytes(), outPacket.GetTCPBytes()) {
  1214  				packetDiffers = true
  1215  				fmt.Println("Error: packets dont match")
  1216  				fmt.Println("Input Packet")
  1217  				oldPacket.Print(0, false)
  1218  				fmt.Println("Output Packet")
  1219  				outPacket.Print(0, false)
  1220  				t.Errorf("Packet %d Input and output packet do not match", i)
  1221  				t.FailNow()
  1222  			}
  1223  		}
  1224  
  1225  		Convey("Then I expect all the input and output packets (after encoding and decoding) to be same", func() {
  1226  
  1227  			So(packetDiffers, ShouldEqual, false)
  1228  		})
  1229  	}
  1230  
  1231  	flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Reject|policy.Log, collector.MissingToken)
  1232  
  1233  	Convey("When the mode is RemoteConainter", t, func() {
  1234  
  1235  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer)
  1236  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
  1237  		testThePackets(enforcer)
  1238  
  1239  	})
  1240  
  1241  	Convey("When the mode is LocalServer", t, func() {
  1242  
  1243  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer)
  1244  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
  1245  		testThePackets(enforcer)
  1246  
  1247  	})
  1248  }
  1249  
  1250  func TestFlowReportingUptoInvalidSynAck(t *testing.T) {
  1251  
  1252  	ctrl := gomock.NewController(t)
  1253  	defer ctrl.Finish()
  1254  
  1255  	testThePackets := func(enforcer *Datapath) {
  1256  
  1257  		SIP := net.IPv4zero
  1258  		packetDiffers := false
  1259  
  1260  		PacketFlow := packetgen.NewTemplateFlow()
  1261  		_, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGoodFlowTemplate)
  1262  		So(err, ShouldBeNil)
  1263  		for i := 0; i < PacketFlow.GetUptoFirstSynAckPacket().GetNumPackets(); i++ {
  1264  			start, err := PacketFlow.GetUptoFirstSynAckPacket().GetNthPacket(i).ToBytes()
  1265  			So(err, ShouldBeNil)
  1266  
  1267  			oldPacket, err := packet.New(0, start, "0", true)
  1268  			if err == nil && oldPacket != nil {
  1269  				oldPacket.UpdateIPv4Checksum()
  1270  				oldPacket.UpdateTCPChecksum()
  1271  			}
  1272  			input, err := PacketFlow.GetUptoFirstSynAckPacket().GetNthPacket(i).ToBytes()
  1273  			So(err, ShouldBeNil)
  1274  			tcpPacket, err := packet.New(0, input, "0", true)
  1275  			if err == nil && tcpPacket != nil {
  1276  				tcpPacket.UpdateIPv4Checksum()
  1277  				tcpPacket.UpdateTCPChecksum()
  1278  			}
  1279  
  1280  			if debug {
  1281  				fmt.Println("Input packet", i)
  1282  				tcpPacket.Print(0, false)
  1283  			}
  1284  
  1285  			So(err, ShouldBeNil)
  1286  			So(tcpPacket, ShouldNotBeNil)
  1287  
  1288  			if reflect.DeepEqual(SIP, net.IPv4zero) {
  1289  				SIP = tcpPacket.SourceAddress()
  1290  			}
  1291  
  1292  			if !reflect.DeepEqual(SIP, tcpPacket.DestinationAddress()) &&
  1293  				!reflect.DeepEqual(SIP, tcpPacket.SourceAddress()) {
  1294  				t.Error("Invalid Test Packet")
  1295  			}
  1296  			if PacketFlow.GetNthPacket(i).GetTCPSyn() && !PacketFlow.GetNthPacket(i).GetTCPAck() {
  1297  				_, err = enforcer.processApplicationTCPPackets(tcpPacket)
  1298  
  1299  				So(err, ShouldBeNil)
  1300  			}
  1301  
  1302  			if debug {
  1303  				fmt.Println("Intermediate packet", i)
  1304  				tcpPacket.Print(0, false)
  1305  			}
  1306  
  1307  			output := make([]byte, len(tcpPacket.GetTCPBytes()))
  1308  			copy(output, tcpPacket.GetTCPBytes())
  1309  
  1310  			outPacket, errp := packet.New(0, output, "0", true)
  1311  			So(len(tcpPacket.GetTCPBytes()), ShouldBeLessThanOrEqualTo, len(outPacket.GetTCPBytes()))
  1312  			So(errp, ShouldBeNil)
  1313  
  1314  			if PacketFlow.GetNthPacket(i).GetTCPSyn() && !PacketFlow.GetNthPacket(i).GetTCPAck() {
  1315  				_, _, err = enforcer.processNetworkTCPPackets(outPacket)
  1316  				So(err, ShouldBeNil)
  1317  			}
  1318  			if PacketFlow.GetNthPacket(i).GetTCPSyn() && PacketFlow.GetNthPacket(i).GetTCPAck() {
  1319  				_, _, err = enforcer.processNetworkTCPPackets(outPacket)
  1320  				So(err, ShouldNotBeNil)
  1321  			}
  1322  
  1323  			if debug {
  1324  				fmt.Println("Output packet", i)
  1325  				outPacket.Print(0, false)
  1326  			}
  1327  
  1328  			if !reflect.DeepEqual(oldPacket.GetTCPBytes(), outPacket.GetTCPBytes()) {
  1329  				packetDiffers = true
  1330  				fmt.Println("Error: packets dont match")
  1331  				fmt.Println("Input Packet")
  1332  				oldPacket.Print(0, false)
  1333  				fmt.Println("Output Packet")
  1334  				outPacket.Print(0, false)
  1335  				t.Errorf("Packet %d Input and output packet do not match", i)
  1336  				t.FailNow()
  1337  			}
  1338  		}
  1339  
  1340  		Convey("Then I expect all the input and output packets (after encoding and decoding) to be same", func() {
  1341  
  1342  			So(packetDiffers, ShouldEqual, false)
  1343  		})
  1344  	}
  1345  
  1346  	flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Reject|policy.Log, "policy")
  1347  
  1348  	Convey("When the mode is RemoteConainter", t, func() {
  1349  
  1350  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer)
  1351  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
  1352  		testThePackets(enforcer)
  1353  
  1354  	})
  1355  
  1356  	Convey("When the mode is LocalServer", t, func() {
  1357  
  1358  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer)
  1359  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
  1360  		testThePackets(enforcer)
  1361  
  1362  	})
  1363  }
  1364  
  1365  func TestForPacketsWithRandomFlags(t *testing.T) {
  1366  
  1367  	ctrl := gomock.NewController(t)
  1368  	defer ctrl.Finish()
  1369  
  1370  	debug = true
  1371  
  1372  	defer MockGetUDPRawSocket()()
  1373  
  1374  	testThePackets := func(enforcer *Datapath) {
  1375  
  1376  		PacketFlow := packetgen.NewPacketFlow("aa:ff:aa:ff:aa:ff", "ff:aa:ff:aa:ff:aa", testSrcIP, testDstIP, 666, 80)
  1377  		_, err := PacketFlow.GenerateTCPFlow(packetgen.PacketFlowTypeGenerateGoodFlow)
  1378  		So(err, ShouldBeNil)
  1379  
  1380  		count := PacketFlow.GetNumPackets()
  1381  		for i := 0; i < count; i++ {
  1382  			//Setting random TCP flags for all the packets
  1383  			PacketFlow.GetNthPacket(i).SetTCPCwr()
  1384  			PacketFlow.GetNthPacket(i).SetTCPPsh()
  1385  			PacketFlow.GetNthPacket(i).SetTCPEce()
  1386  			input, err := PacketFlow.GetNthPacket(i).ToBytes()
  1387  			So(err, ShouldBeNil)
  1388  			tcpPacket, err := packet.New(0, input, "0", true)
  1389  			if err == nil && tcpPacket != nil {
  1390  				tcpPacket.UpdateIPv4Checksum()
  1391  				tcpPacket.UpdateTCPChecksum()
  1392  			}
  1393  
  1394  			if debug {
  1395  				fmt.Println("Input packet", i)
  1396  				tcpPacket.Print(0, false)
  1397  			}
  1398  
  1399  			So(err, ShouldBeNil)
  1400  			So(tcpPacket, ShouldNotBeNil)
  1401  
  1402  			SIP := tcpPacket.SourceAddress()
  1403  
  1404  			if !reflect.DeepEqual(SIP, tcpPacket.DestinationAddress()) &&
  1405  				!reflect.DeepEqual(SIP, tcpPacket.SourceAddress()) {
  1406  				t.Error("Invalid Test Packet")
  1407  			}
  1408  
  1409  			_, err = enforcer.processApplicationTCPPackets(tcpPacket)
  1410  			So(err, ShouldBeNil)
  1411  
  1412  			if debug {
  1413  				fmt.Println("Intermediate packet", i)
  1414  				tcpPacket.Print(0, false)
  1415  			}
  1416  
  1417  			output := make([]byte, len(tcpPacket.GetTCPBytes()))
  1418  			copy(output, tcpPacket.GetTCPBytes())
  1419  
  1420  			outPacket, errp := packet.New(0, output, "0", true)
  1421  			So(len(tcpPacket.GetTCPBytes()), ShouldBeLessThanOrEqualTo, len(outPacket.GetTCPBytes()))
  1422  			So(errp, ShouldBeNil)
  1423  
  1424  			_, f, err := enforcer.processNetworkTCPPackets(outPacket)
  1425  			if f != nil {
  1426  				f()
  1427  			}
  1428  
  1429  			So(err, ShouldBeNil)
  1430  
  1431  			if debug {
  1432  				fmt.Println("Output packet ", i)
  1433  				outPacket.Print(0, false)
  1434  			}
  1435  		}
  1436  	}
  1437  
  1438  	flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 666, 80, policy.Accept, "")
  1439  
  1440  	Convey("When the mode is RemoteConainter", t, func() {
  1441  
  1442  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.RemoteContainer)
  1443  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
  1444  		testThePackets(enforcer)
  1445  
  1446  	})
  1447  
  1448  	Convey("When the mode is LocalServer", t, func() {
  1449  
  1450  		enforcer, mockCollector := createEnforcerWithPolicy(ctrl, constants.LocalServer)
  1451  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
  1452  		testThePackets(enforcer)
  1453  	})
  1454  }
  1455  
  1456  func TestPUPortCreation(t *testing.T) {
  1457  
  1458  	ctrl := gomock.NewController(t)
  1459  	defer ctrl.Finish()
  1460  
  1461  	Convey("Given I setup an enforcer", t, func() {
  1462  
  1463  		defer MockGetUDPRawSocket()()
  1464  
  1465  		enforcer, secrets, mockTokenAccessor, _, mockDNS := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true)
  1466  		if enforcer == nil { // This avoids lint error SA5011: possible nil pointer dereference (staticcheck)
  1467  			So(enforcer != nil, ShouldBeTrue)
  1468  			return
  1469  		}
  1470  
  1471  		enforcer.packetLogs = true
  1472  
  1473  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
  1474  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
  1475  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
  1476  
  1477  		contextID := "1001"
  1478  		puInfo := policy.NewPUInfo(contextID, "/ns1", common.LinuxProcessPU)
  1479  		puInfo.Runtime.SetOptions(policy.OptionsType{
  1480  			CgroupMark: "100",
  1481  		})
  1482  
  1483  		mockDNS.EXPECT().StartDNSServer(gomock.Any(), contextID, gomock.Any()).Times(1)
  1484  		mockDNS.EXPECT().Enforce(gomock.Any(), contextID, puInfo)
  1485  		mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(1)
  1486  
  1487  		enforcer.Enforce(context.Background(), contextID, puInfo) // nolint
  1488  	})
  1489  }
  1490  
  1491  func TestCollectTCPPacket(t *testing.T) {
  1492  
  1493  	ctrl := gomock.NewController(t)
  1494  	defer ctrl.Finish()
  1495  
  1496  	Convey("Given I setup an enforcer", t, func() {
  1497  
  1498  		enforcer, secrets, mockTokenAccessor, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true)
  1499  		So(enforcer != nil, ShouldBeTrue)
  1500  
  1501  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
  1502  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
  1503  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
  1504  
  1505  		contextID := "dummy"
  1506  		_, err := CreatePUContext(enforcer, contextID, "/ns1", common.ContainerPU, mockTokenAccessor)
  1507  		So(err, ShouldBeNil)
  1508  
  1509  		tcpPacket, err := newPacket(1, packet.TCPSynMask, testSrcIP, testDstIP, srcPort, dstPort, true, false)
  1510  		So(err, ShouldBeNil)
  1511  
  1512  		Convey("We setup tcp network packet tracing for this pu with incomplete state", func() {
  1513  			interval := 10 * time.Second
  1514  			err := enforcer.EnableDatapathPacketTracing(context.TODO(), contextID, packettracing.NetworkOnly, interval)
  1515  			So(err, ShouldBeNil)
  1516  			packetreport := collector.PacketReport{
  1517  				DestinationIP: tcpPacket.DestinationAddress().String(),
  1518  				SourceIP:      tcpPacket.SourceAddress().String(),
  1519  			}
  1520  			mockCollector.EXPECT().CollectPacketEvent(PacketEventMatcher(&packetreport)).Times(0)
  1521  			enforcer.collectTCPPacket(&debugpacketmessage{
  1522  				Mark:    10,
  1523  				p:       tcpPacket,
  1524  				tcpConn: nil,
  1525  				udpConn: nil,
  1526  				err:     nil,
  1527  				network: true,
  1528  			})
  1529  		})
  1530  		Convey("We setup tcp network packet tracing for this pu with tcpConn != nil state", func() {
  1531  			interval := 10 * time.Second
  1532  			err := enforcer.EnableDatapathPacketTracing(context.TODO(), contextID, packettracing.NetworkOnly, interval)
  1533  			So(err, ShouldBeNil)
  1534  			packetreport := collector.PacketReport{
  1535  				DestinationIP: tcpPacket.DestinationAddress().String(),
  1536  				SourceIP:      tcpPacket.SourceAddress().String(),
  1537  			}
  1538  			context, _ := enforcer.puFromContextID.Get(contextID)
  1539  			tcpConn := connection.NewTCPConnection(context.(*pucontext.PUContext), nil)
  1540  
  1541  			mockCollector.EXPECT().CollectPacketEvent(PacketEventMatcher(&packetreport)).Times(1)
  1542  			enforcer.collectTCPPacket(&debugpacketmessage{
  1543  				Mark:    10,
  1544  				p:       tcpPacket,
  1545  				tcpConn: tcpConn,
  1546  				udpConn: nil,
  1547  				err:     nil,
  1548  				network: true,
  1549  			})
  1550  		})
  1551  		Convey("We setup tcp network packet tracing for this pu with tcpConn != nil and inject application packet", func() {
  1552  			interval := 10 * time.Second
  1553  			err := enforcer.EnableDatapathPacketTracing(context.TODO(), contextID, packettracing.NetworkOnly, interval)
  1554  			So(err, ShouldBeNil)
  1555  			packetreport := collector.PacketReport{
  1556  				DestinationIP: tcpPacket.DestinationAddress().String(),
  1557  				SourceIP:      tcpPacket.SourceAddress().String(),
  1558  			}
  1559  			context, _ := enforcer.puFromContextID.Get(contextID)
  1560  			tcpConn := connection.NewTCPConnection(context.(*pucontext.PUContext), nil)
  1561  			mockCollector.EXPECT().CollectPacketEvent(PacketEventMatcher(&packetreport)).Times(0)
  1562  			enforcer.collectTCPPacket(&debugpacketmessage{
  1563  				Mark:    10,
  1564  				p:       tcpPacket,
  1565  				tcpConn: tcpConn,
  1566  				udpConn: nil,
  1567  				err:     nil,
  1568  				network: false,
  1569  			})
  1570  		})
  1571  
  1572  	})
  1573  }
  1574  
  1575  func TestEnableDatapathPacketTracing(t *testing.T) {
  1576  
  1577  	ctrl := gomock.NewController(t)
  1578  	defer ctrl.Finish()
  1579  
  1580  	Convey("Given I setup an enforcer", t, func() {
  1581  
  1582  		enforcer, secrets, mockTokenAccessor, _, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true)
  1583  		if enforcer == nil { // This avoids lint error SA5011: possible nil pointer dereference (staticcheck)
  1584  			So(enforcer != nil, ShouldBeTrue)
  1585  			return
  1586  		}
  1587  
  1588  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
  1589  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
  1590  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
  1591  
  1592  		contextID := "dummy"
  1593  		_, err := CreatePUContext(enforcer, contextID, "/ns1", common.ContainerPU, mockTokenAccessor)
  1594  		So(err, ShouldBeNil)
  1595  
  1596  		err = enforcer.EnableDatapathPacketTracing(context.TODO(), contextID, packettracing.ApplicationOnly, 10*time.Second)
  1597  		So(err, ShouldBeNil)
  1598  		_, err = enforcer.packetTracingCache.Get(contextID)
  1599  		So(err, ShouldBeNil)
  1600  	})
  1601  }
  1602  
  1603  func Test_CheckCounterCollection(t *testing.T) {
  1604  	ctrl := gomock.NewController(t)
  1605  	defer ctrl.Finish()
  1606  	collectCounterInterval = 1 * time.Second
  1607  	Convey("Given I setup an enforcer", t, func() {
  1608  
  1609  		Convey("So When enforcer exits", func() {
  1610  
  1611  			enforcer, secrets, mockTokenAccessor, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true)
  1612  			So(enforcer != nil, ShouldBeTrue)
  1613  
  1614  			secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
  1615  			mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
  1616  			mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
  1617  
  1618  			puContext, err := CreatePUContext(enforcer, "dummy", "/ns1", common.ContainerPU, mockTokenAccessor)
  1619  			So(err, ShouldBeNil)
  1620  
  1621  			CounterReport := &collector.CounterReport{
  1622  				PUID:      puContext.ManagementID(),
  1623  				Namespace: puContext.ManagementNamespace(),
  1624  			}
  1625  			mockCollector.EXPECT().CollectCounterEvent(MyCounterMatcher(CounterReport)).MinTimes(1)
  1626  
  1627  			ctx, cancel := context.WithCancel(context.Background())
  1628  			go enforcer.counterCollector(ctx)
  1629  
  1630  			puErr := puContext.Counters().CounterError((counters.ErrNonPUTraffic), fmt.Errorf("error"))
  1631  
  1632  			So(puErr, ShouldNotBeNil)
  1633  			cancel()
  1634  		})
  1635  
  1636  		Convey("So When enforer exits and waits for stuff to exit", func() {
  1637  			enforcer, secrets, mockTokenAccessor, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true)
  1638  			So(enforcer != nil, ShouldBeTrue)
  1639  
  1640  			secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
  1641  			mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
  1642  			mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
  1643  
  1644  			puContext, err := CreatePUContext(enforcer, "dummy", "/ns1", common.ContainerPU, mockTokenAccessor)
  1645  			So(err, ShouldBeNil)
  1646  
  1647  			c := &collector.CounterReport{
  1648  				PUID:      puContext.ManagementID(),
  1649  				Namespace: puContext.ManagementNamespace(),
  1650  			}
  1651  
  1652  			mockCollector.EXPECT().CollectCounterEvent(MyCounterMatcher(c)).MinTimes(1)
  1653  
  1654  			ctx, cancel := context.WithCancel(context.Background())
  1655  			go enforcer.counterCollector(ctx)
  1656  
  1657  			puErr := puContext.Counters().CounterError(counters.ErrNonPUTraffic, fmt.Errorf("error"))
  1658  
  1659  			So(puErr, ShouldNotBeNil)
  1660  			cancel()
  1661  			<-time.After(5 * time.Second)
  1662  
  1663  		})
  1664  		Convey("So When an error is reported and the enforcer waits for collection interval", func() {
  1665  			enforcer, secrets, mockTokenAccessor, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true)
  1666  			So(enforcer != nil, ShouldBeTrue)
  1667  
  1668  			secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
  1669  			mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
  1670  			mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
  1671  
  1672  			puContext, err := CreatePUContext(enforcer, "dummy", "/ns1", common.ContainerPU, mockTokenAccessor)
  1673  			So(err, ShouldBeNil)
  1674  
  1675  			c := &collector.CounterReport{
  1676  				PUID:      puContext.ManagementID(),
  1677  				Namespace: puContext.ManagementNamespace(),
  1678  			}
  1679  
  1680  			mockCollector.EXPECT().CollectCounterEvent(MyCounterMatcher(c)).MinTimes(1)
  1681  
  1682  			ctx, cancel := context.WithCancel(context.Background())
  1683  			go enforcer.counterCollector(ctx)
  1684  			puErr := puContext.Counters().CounterError(counters.ErrNonPUTraffic, fmt.Errorf("error"))
  1685  			So(puErr, ShouldNotBeNil)
  1686  			<-time.After(5 * collectCounterInterval)
  1687  			cancel()
  1688  
  1689  		})
  1690  
  1691  	})
  1692  }
  1693  
  1694  func Test_CounterReportedOnAuthSetAppSyn(t *testing.T) {
  1695  	ctrl := gomock.NewController(t)
  1696  	defer ctrl.Finish()
  1697  
  1698  	Convey("Given I setup an enforcer", t, func() {
  1699  
  1700  		enforcer, secrets, mockTokenAccessor, _, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true)
  1701  		So(enforcer != nil, ShouldBeTrue)
  1702  
  1703  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
  1704  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
  1705  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
  1706  		mockTokenAccessor.EXPECT().Randomize(gomock.Any(), gomock.Any()).Times(2)
  1707  
  1708  		context, err := CreatePUContext(enforcer, "dummy", "/ns1", common.ContainerPU, mockTokenAccessor)
  1709  		So(err, ShouldBeNil)
  1710  
  1711  		p, err := newPacket(packet.PacketTypeApplication, packet.TCPSynMask, "1.1.1.1", "2.2.2.2", srcPort, dstPort, false, false)
  1712  		So(err, ShouldBeNil)
  1713  		conn := connection.NewTCPConnection(context, p)
  1714  		err = enforcer.processApplicationSynPacket(p, context, conn)
  1715  		So(err, ShouldBeNil)
  1716  
  1717  		c := conn.Context.Counters().GetErrorCounters()
  1718  		So(c[counters.ErrAppSynAuthOptionSet], ShouldBeZeroValue)
  1719  
  1720  		p, err = newPacket(packet.PacketTypeApplication, packet.TCPSynMask, "1.1.1.1", "2.2.2.2", srcPort, dstPort, true, false)
  1721  		So(err, ShouldBeNil)
  1722  		conn = connection.NewTCPConnection(context, p)
  1723  		err = enforcer.processApplicationSynPacket(p, context, conn)
  1724  		So(err, ShouldBeNil)
  1725  
  1726  		c = conn.Context.Counters().GetErrorCounters()
  1727  		So(c[counters.ErrAppSynAuthOptionSet], ShouldEqual, 1)
  1728  	})
  1729  }
  1730  
  1731  func Test_CounterOnSynCacheTimeout(t *testing.T) {
  1732  
  1733  	ctrl := gomock.NewController(t)
  1734  	defer ctrl.Finish()
  1735  
  1736  	Convey("Given I setup an enforcer", t, func() {
  1737  
  1738  		enforcer, secrets, mockTokenAccessor, _, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true)
  1739  		if enforcer == nil { // This avoids lint error SA5011: possible nil pointer dereference (staticcheck)
  1740  			So(enforcer != nil, ShouldBeTrue)
  1741  			return
  1742  		}
  1743  
  1744  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
  1745  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
  1746  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
  1747  		mockTokenAccessor.EXPECT().Randomize(gomock.Any(), gomock.Any()).Times(1)
  1748  
  1749  		context, err := CreatePUContext(enforcer, "dummy", "/ns1", common.ContainerPU, mockTokenAccessor)
  1750  		So(err, ShouldBeNil)
  1751  
  1752  		p, err := newPacket(packet.PacketTypeApplication, packet.TCPSynMask, "1.1.1.1", "2.2.2.2", srcPort, dstPort, false, false)
  1753  		So(err, ShouldBeNil)
  1754  
  1755  		// Update the connection timer for testing.
  1756  		conn := connection.NewTCPConnection(context, p)
  1757  		conn.ChangeConnectionTimeout(2 * time.Second)
  1758  
  1759  		err = enforcer.processApplicationSynPacket(p, context, conn)
  1760  		So(err, ShouldBeNil)
  1761  
  1762  		c := conn.Context.Counters().GetErrorCounters()
  1763  		So(c[counters.ErrTCPConnectionsExpired], ShouldBeZeroValue)
  1764  
  1765  		// Wait for the connection to expire.
  1766  		time.Sleep(3 * time.Second)
  1767  		_, exists := enforcer.tcpClient.Get(p.L4FlowHash())
  1768  		if exists {
  1769  			t.Fail()
  1770  		}
  1771  
  1772  		c = conn.Context.Counters().GetErrorCounters()
  1773  		So(c[counters.ErrTCPConnectionsExpired], ShouldEqual, 1)
  1774  	})
  1775  }
  1776  
  1777  func Test_NOClaims(t *testing.T) {
  1778  	ctrl := gomock.NewController(t)
  1779  	defer ctrl.Finish()
  1780  
  1781  	Convey("Given I setup an enforcer", t, func() {
  1782  
  1783  		enforcer, _, _, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true)
  1784  		So(enforcer != nil, ShouldBeTrue)
  1785  
  1786  		flowRecord := CreateFlowRecord(1, "1.1.1.1", "2.2.2.2", 2000, 80, policy.Reject|policy.Log, collector.PolicyDrop)
  1787  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
  1788  
  1789  		context, err := CreatePUContext(enforcer, "dummy", "/ns1", common.ContainerPU, nil)
  1790  		So(err, ShouldBeNil)
  1791  
  1792  		p, err := newPacket(packet.PacketTypeNetwork, packet.TCPSynAckMask, "2.2.2.2", "1.1.1.1", dstPort, srcPort, true, false)
  1793  		So(err, ShouldBeNil)
  1794  
  1795  		conn := connection.NewTCPConnection(context, p)
  1796  
  1797  		_, err = enforcer.processNetworkSynAckPacket(context, conn, p)
  1798  		So(err, ShouldNotBeNil)
  1799  	})
  1800  }
  1801  
  1802  func newPacket(context uint64, tcpFlags uint8, src, dst string, srcPort, desPort uint16, addOptions bool, addPayload bool) (*packet.Packet, error) { //nolint
  1803  
  1804  	p, err := packet.NewIpv4TCPPacket(context, tcpFlags, src, dst, srcPort, dstPort)
  1805  	if err != nil {
  1806  		return nil, err
  1807  	}
  1808  
  1809  	p.SetTCPSeq(rand.Uint32())
  1810  
  1811  	if addOptions {
  1812  		options := []byte{2 /*Maximum Segment Size*/, 4, 0x05, 0x8C, 34, enforcerconstants.TCPAuthenticationOptionBaseLen, 0, 0}
  1813  		buffer := append(p.GetBuffer(0), options...)
  1814  		err = p.UpdatePacketBuffer(buffer, uint16(len(options)))
  1815  	}
  1816  
  1817  	if addPayload {
  1818  		buffer := append(p.GetBuffer(0), []byte("dummy payload")...)
  1819  		err = p.UpdatePacketBuffer(buffer, 0)
  1820  	}
  1821  
  1822  	return p, err
  1823  }
  1824  
  1825  func TestCheckConnectionDeletion(t *testing.T) {
  1826  
  1827  	ctrl := gomock.NewController(t)
  1828  	defer ctrl.Finish()
  1829  
  1830  	Convey("Given I setup an enforcer", t, func() {
  1831  
  1832  		enforcer, secrets, mockTokenAccessor, _, _ := NewWithMocks(ctrl, "serverID1", constants.RemoteContainer, []string{"0.0.0.0/0"}, true)
  1833  
  1834  		secrets.EXPECT().TransmittedKey().Return([]byte("dummy")).AnyTimes()
  1835  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
  1836  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
  1837  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
  1838  
  1839  		err := CreatePortPolicy(enforcer, "dummy", "/ns1", common.ContainerPU, mockTokenAccessor, "2", dstPort, dstPort)
  1840  		So(err, ShouldBeNil)
  1841  
  1842  		tcpPacket, err := newPacket(1, packet.TCPSynMask, testSrcIP, testDstIP, srcPort, dstPort, true, false)
  1843  		So(err, ShouldBeNil)
  1844  
  1845  		conn := &connection.TCPConnection{
  1846  			ServiceConnection: true,
  1847  			MarkForDeletion:   true,
  1848  		}
  1849  
  1850  		hash := tcpPacket.L4FlowHash()
  1851  		enforcer.tcpClient.Put(hash, conn)
  1852  
  1853  		tcpPacket.Mark = "2"
  1854  
  1855  		conn1, err := enforcer.appSynRetrieveState(tcpPacket)
  1856  		So(err, ShouldBeNil)
  1857  		So(conn1.MarkForDeletion, ShouldBeFalse)
  1858  
  1859  		enforcer.tcpServer.Put(hash, conn)
  1860  		_, err = enforcer.netSynRetrieveState(tcpPacket)
  1861  		So(err, ShouldBeNil)
  1862  
  1863  		tcpSynAckPacket, err := newPacket(1, packet.TCPSynAckMask, testDstIP, testSrcIP, dstPort, srcPort, true, false)
  1864  		So(err, ShouldBeNil)
  1865  
  1866  		_, err = enforcer.netSynAckRetrieveState(tcpSynAckPacket)
  1867  		So(err, ShouldNotBeNil)
  1868  		ShouldEqual(err, errNonPUTraffic)
  1869  	})
  1870  }
  1871  
  1872  func TestNetSynRetrieveState(t *testing.T) {
  1873  
  1874  	ctrl := gomock.NewController(t)
  1875  	defer ctrl.Finish()
  1876  
  1877  	// Testing datapath.netSynRetrieveState
  1878  	// There are 4 different code branches in this functions
  1879  
  1880  	Convey("Given I setup an enforcer", t, func() {
  1881  
  1882  		defer MockGetUDPRawSocket()()
  1883  
  1884  		enforcer, secrets, mockTokenAccessor, _, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true)
  1885  
  1886  		secrets.EXPECT().TransmittedKey().Return([]byte("dummy")).AnyTimes()
  1887  		secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
  1888  		mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
  1889  		mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil)
  1890  
  1891  		err := CreatePortPolicy(enforcer, "123456", "/ns1", common.LinuxProcessPU, mockTokenAccessor, "2", 9000, 9000)
  1892  		So(err, ShouldBeNil)
  1893  
  1894  		// Test the error case
  1895  		p, err := packet.NewIpv4TCPPacket(1, 0x2, "127.0.0.1", "127.0.0.1", 43758, 8000)
  1896  		So(err, ShouldBeNil)
  1897  		_, err = enforcer.netSynRetrieveState(p)
  1898  		So(err, ShouldNotBeNil)
  1899  
  1900  		p, err = packet.NewIpv4TCPPacket(1, 0x2, "127.0.0.1", "127.0.0.1", 43758, 9000)
  1901  		So(err, ShouldBeNil)
  1902  
  1903  		conn, err := enforcer.netSynRetrieveState(p)
  1904  		So(err, ShouldBeNil)
  1905  
  1906  		enforcer.tcpServer.Put(p.L4FlowHash(), conn)
  1907  
  1908  		So(conn.GetInitialSequenceNumber(), ShouldEqual, p.TCPSequenceNumber())
  1909  		Convey("I retry the same packet", func() {
  1910  			retryconn, err := enforcer.netSynRetrieveState(p)
  1911  			assert.Equal(t, err, nil, "error should be nil")
  1912  			assert.Equal(t, retryconn, conn, "connection should be same")
  1913  		})
  1914  		Convey("Then i modify the sequence number and retry the packet", func() {
  1915  			p.IncreaseTCPSeq(10)
  1916  			conn1, err := enforcer.netSynRetrieveState(p)
  1917  			So(err, ShouldBeNil)
  1918  			So(conn1.GetInitialSequenceNumber(), ShouldNotEqual, conn.GetInitialSequenceNumber())
  1919  			_, exists := enforcer.tcpServer.Get(p.L4FlowHash())
  1920  			if exists {
  1921  				t.Fail()
  1922  			}
  1923  		})
  1924  
  1925  	})
  1926  }
  1927  
  1928  func TestAppSynRetrieveState(t *testing.T) {
  1929  
  1930  	ctrl := gomock.NewController(t)
  1931  	defer ctrl.Finish()
  1932  
  1933  	// Testing datapath.appSynRetrieveState
  1934  	// There are 4 different code branches in the function
  1935  
  1936  	Convey("Given I setup an enforcer", t, func() {
  1937  
  1938  		defer MockGetUDPRawSocket()()
  1939  
  1940  		enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true)
  1941  
  1942  		err := CreatePortPolicy(enforcer, "testContextID", "/ns1", common.LinuxProcessPU, nil, "2", 9000, 9000)
  1943  		So(err, ShouldBeNil)
  1944  
  1945  		// Create a Syn packet
  1946  		p, err := packet.NewIpv4TCPPacket(1, 0x2, "127.0.0.1", "127.0.0.1", 43758, 9000)
  1947  		So(err, ShouldBeNil)
  1948  
  1949  		// The error case "PU context doesn't exist for this syn, return error"
  1950  		_, err = enforcer.appSynRetrieveState(p)
  1951  		So(err, ShouldNotBeNil)
  1952  
  1953  		p.Mark = "2"
  1954  
  1955  		conn, err := enforcer.appSynRetrieveState(p)
  1956  		So(err, ShouldBeNil)
  1957  
  1958  		enforcer.tcpClient.Put(p.L4FlowHash(), conn)
  1959  
  1960  		Convey("I replay the same packet", func() {
  1961  			retryconn, err := enforcer.appSynRetrieveState(p)
  1962  			So(err, ShouldBeNil)
  1963  			So(retryconn, ShouldNotBeNil)
  1964  
  1965  		})
  1966  		Convey("I modify the sequence number and retransmit the packet", func() {
  1967  			p.IncreaseTCPSeq(10)
  1968  			retryconn, err := enforcer.appSynRetrieveState(p)
  1969  			So(retryconn, ShouldNotBeNil)
  1970  			So(err, ShouldBeNil)
  1971  			_, exists := enforcer.tcpClient.Get(p.L4FlowHash())
  1972  			if exists {
  1973  				t.Fail()
  1974  			}
  1975  		})
  1976  	})
  1977  }
  1978  
  1979  func TestAppSynAckRetrieveState(t *testing.T) {
  1980  
  1981  	ctrl := gomock.NewController(t)
  1982  	defer ctrl.Finish()
  1983  
  1984  	// Testing datapath.appSynAckRetrieveState
  1985  	// There are 2 different code branches in this functions
  1986  
  1987  	Convey("Given I setup an enforcer", t, func() {
  1988  
  1989  		defer MockGetUDPRawSocket()()
  1990  
  1991  		enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true)
  1992  
  1993  		// Create a SynAck packet
  1994  		p, err := packet.NewIpv4TCPPacket(1, packet.TCPSynAckMask, "127.0.0.1", "127.0.0.1", 43758, 9000)
  1995  		So(err, ShouldBeNil)
  1996  
  1997  		// The error case when nothing is in the cache
  1998  		_, err = enforcer.appSynAckRetrieveState(p)
  1999  		So(err, ShouldNotBeNil)
  2000  
  2001  		// add connection to the cache
  2002  		enforcer.tcpServer.Put(p.L4ReverseFlowHash(), &connection.TCPConnection{})
  2003  
  2004  		// Should be in the cache
  2005  		conn, err := enforcer.appSynAckRetrieveState(p)
  2006  		So(err, ShouldBeNil)
  2007  		So(conn, ShouldNotBeNil)
  2008  	})
  2009  }
  2010  
  2011  func TestNetSynAckRetrieveState(t *testing.T) {
  2012  
  2013  	ctrl := gomock.NewController(t)
  2014  	defer ctrl.Finish()
  2015  
  2016  	// Testing datapath.netSynAckRetrieveState
  2017  	// There are 3 different code branches in this functions
  2018  
  2019  	Convey("Given I setup an enforcer", t, func() {
  2020  
  2021  		defer MockGetUDPRawSocket()()
  2022  
  2023  		enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true)
  2024  
  2025  		// Create a SynAck packet
  2026  		p, err := packet.NewIpv4TCPPacket(1, packet.TCPSynAckMask, "127.0.0.1", "127.0.0.1", 43758, 9000)
  2027  		So(err, ShouldBeNil)
  2028  
  2029  		// The error case when nothing is in the cache
  2030  		_, err = enforcer.netSynAckRetrieveState(p)
  2031  		ShouldEqual(err, errNonPUTraffic)
  2032  
  2033  		// add connection to the cache
  2034  		enforcer.tcpClient.Put(p.L4ReverseFlowHash(), &connection.TCPConnection{})
  2035  
  2036  		// Should be in the cache
  2037  		conn, err := enforcer.netSynAckRetrieveState(p)
  2038  		So(err, ShouldBeNil)
  2039  		So(conn, ShouldNotBeNil)
  2040  
  2041  		// Mark the connection as deleted
  2042  		conn.MarkForDeletion = true
  2043  
  2044  		// We should get an error
  2045  		_, err = enforcer.netSynAckRetrieveState(p)
  2046  		ShouldEqual(err, errOutOfOrderSynAck)
  2047  	})
  2048  }
  2049  
// TestAppRetrieveState walks appRetrieveState through a sequence of cases:
//   - RST and SYN packets with no state;
//   - an ACK with no matching PU context;
//   - an ACK for a PU with no cached connection;
//   - ACKs after connections are placed in the server and then the client cache.
func TestAppRetrieveState(t *testing.T) {

	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Testing datapath.appRetrieveState
	// There are 6 branch conditions in this function.
	//
	// NOTE(review): the ShouldEqual/ShouldResemble calls below are not wrapped
	// in So(...). Called bare, a goconvey assertion only returns a failure
	// message, which is discarded here, so those checks can never fail the
	// test. Wrap them as So(actual, ShouldEqual, expected) once the expected
	// values have been confirmed against appRetrieveState. In particular,
	// confirm the "No context in app processing" error text, and which cache
	// is consulted first when both hold an entry.

	Convey("Given I setup an enforcer", t, func() {

		defer MockGetUDPRawSocket()()

		enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true)

		// Create a Rst packet
		p, err := packet.NewIpv4TCPPacket(1, packet.TCPRstMask, "127.0.0.1", "127.0.0.1", 43758, 9000)
		So(err, ShouldBeNil)

		// 1. We should get the errRstPacket error
		_, err = enforcer.appRetrieveState(p)
		ShouldEqual(err, errRstPacket)

		// Create a Syn packet
		p, err = packet.NewIpv4TCPPacket(1, packet.TCPSynMask, "127.0.0.1", "127.0.0.1", 43758, 9000)
		So(err, ShouldBeNil)

		// 2. We should get errNoConnection error
		_, err = enforcer.appRetrieveState(p)
		ShouldEqual(err, errNoConnection)

		// Create a Ack packet
		p, err = packet.NewIpv4TCPPacket(1, packet.TCPAckMask, "127.0.0.1", "127.0.0.1", 43758, 9000)
		So(err, ShouldBeNil)

		// 3. We should get error "No context in app processing"
		_, err = enforcer.appRetrieveState(p)
		ShouldResemble(err, errors.New("No context in app processing"))

		// Create port policy
		err = CreatePortPolicy(enforcer, "testContextID", "/ns1", common.LinuxProcessPU, nil, "2", 43758, 43758)
		So(err, ShouldBeNil)

		// Give the ACK the PU's mark so a context can now be resolved.
		p.Mark = "2"

		// 4. We should get a connection object with UnknownState
		conn, err := enforcer.appRetrieveState(p)
		So(err, ShouldBeNil)
		So(conn, ShouldNotBeNil)
		ShouldEqual(conn.GetState(), connection.UnknownState)

		// add connection to the server cache
		connServer := &connection.TCPConnection{}
		enforcer.tcpServer.Put(p.L4ReverseFlowHash(), connServer)

		// 5. Should be in the cache
		conn, err = enforcer.appRetrieveState(p)
		So(err, ShouldBeNil)
		ShouldEqual(conn, connServer)

		// add connection to the client cache
		connClient := &connection.TCPConnection{}
		enforcer.tcpClient.Put(p.L4FlowHash(), connClient)

		// 6. Should be in the cache
		conn, err = enforcer.appRetrieveState(p)
		So(err, ShouldBeNil)
		ShouldEqual(conn, connClient)
	})
}
  2119  
// TestNetRetrieveState walks netRetrieveState through a sequence of cases:
//   - RST and SYN packets with no state;
//   - an ACK to a port with no PU policy;
//   - an ACK for a PU with no cached connection;
//   - ACKs after connections are placed in the server and then the client cache;
//   - a RST once state is cached.
func TestNetRetrieveState(t *testing.T) {

	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Testing datapath.netRetrieveState
	// There are 7 branch conditions in this function.
	//
	// NOTE(review): the ShouldEqual/ShouldResemble calls below are not wrapped
	// in So(...). Called bare, a goconvey assertion only returns a failure
	// message, which is discarded here, so those checks can never fail the
	// test. Wrap them as So(actual, ShouldEqual, expected) once the expected
	// values have been confirmed against netRetrieveState. In particular,
	// confirm the " TCP Port Not Found 9000" text (note the leading space), and
	// which cache is consulted first when both hold an entry.

	Convey("Given I setup an enforcer", t, func() {

		defer MockGetUDPRawSocket()()

		enforcer, _, _, _, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true)

		// Create a Rst packet
		p, err := packet.NewIpv4TCPPacket(1, packet.TCPRstMask, "127.0.0.1", "127.0.0.1", 43758, 9000)
		So(err, ShouldBeNil)

		// 1. We should get the errRstPacket error
		_, err = enforcer.netRetrieveState(p)
		ShouldEqual(err, errRstPacket)

		// Create a Syn packet
		p, err = packet.NewIpv4TCPPacket(1, packet.TCPSynMask, "127.0.0.1", "127.0.0.1", 43758, 9000)
		So(err, ShouldBeNil)

		// 2. We should get errNoConnection error
		_, err = enforcer.netRetrieveState(p)
		ShouldEqual(err, errNoConnection)

		// Create a Ack packet
		p, err = packet.NewIpv4TCPPacket(1, packet.TCPAckMask, "127.0.0.1", "127.0.0.1", 43758, 9000)
		So(err, ShouldBeNil)

		// 3. We should get error " TCP Port Not Found 9000"
		_, err = enforcer.netRetrieveState(p)
		ShouldResemble(err, errors.New(" TCP Port Not Found 9000"))

		// Create port policy
		err = CreatePortPolicy(enforcer, "testContextID", "/ns1", common.LinuxProcessPU, nil, "2", 9000, 9000)
		So(err, ShouldBeNil)

		// Give the ACK the PU's mark so a context can now be resolved.
		p.Mark = "2"

		// 4. We should get a connection object with UnknownState
		conn, err := enforcer.netRetrieveState(p)
		So(err, ShouldBeNil)
		So(conn, ShouldNotBeNil)
		ShouldEqual(conn.GetState(), connection.UnknownState)

		// add connection to the server cache
		connServer := &connection.TCPConnection{}
		enforcer.tcpServer.Put(p.L4FlowHash(), connServer)

		// 5. Should be in the cache
		conn, err = enforcer.netRetrieveState(p)
		So(err, ShouldBeNil)
		ShouldEqual(conn, connServer)

		// add connection to the client cache
		connClient := &connection.TCPConnection{}
		enforcer.tcpClient.Put(p.L4ReverseFlowHash(), connClient)

		// 6. Should be in the cache
		conn, err = enforcer.netRetrieveState(p)
		So(err, ShouldBeNil)
		ShouldEqual(conn, connClient)

		// Change to a Rst packet
		p.SetTCPFlags(packet.TCPRstMask)

		// 7. Should be in the cache, but should get error errRstPacket
		_, err = enforcer.netRetrieveState(p)
		So(err, ShouldNotBeNil)
		ShouldEqual(err, errRstPacket)
	})
}
  2197  
  2198  // This is to ensure that if we get tcp fo packet with no identity payload that we drop the packet
  2199  func TestProcessNetworkSynPacket(t *testing.T) {
  2200  
  2201  	ctrl := gomock.NewController(t)
  2202  	defer ctrl.Finish()
  2203  
  2204  	Convey("When I setup an enforcer", t, func() {
  2205  
  2206  		defer MockGetUDPRawSocket()()
  2207  
  2208  		enforcer, _, _, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true)
  2209  
  2210  		flowRecord := CreateFlowRecord(1, testSrcIP, testDstIP, 43758, 80, policy.Reject|policy.Log, collector.MissingToken)
  2211  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord)).Times(1)
  2212  
  2213  		Convey("So I received a packet with tcp fast open option set but no payload", func() {
  2214  
  2215  			p, err := packet.NewIpv4TCPPacket(1, 0x2, testSrcIP, testDstIP, 43758, 80)
  2216  			So(err, ShouldBeNil)
  2217  			So(p, ShouldNotBeNil)
  2218  
  2219  			// Add the fast open option
  2220  			buffer := append(p.GetBuffer(0), []byte{packet.TCPAuthenticationOption, enforcerconstants.TCPAuthenticationOptionBaseLen, 0, 0}...)
  2221  			err = p.UpdatePacketBuffer(buffer, 4)
  2222  			So(err, ShouldBeNil)
  2223  
  2224  			err = p.CheckTCPAuthenticationOption(enforcerconstants.TCPAuthenticationOptionBaseLen)
  2225  			So(err, ShouldBeNil)
  2226  			So(p.IsEmptyTCPPayload(), ShouldBeTrue)
  2227  
  2228  			context, err := CreatePUContext(enforcer, "dummyContext", "/ns1", common.LinuxProcessPU, nil)
  2229  			So(err, ShouldBeNil)
  2230  			So(context, ShouldNotBeNil)
  2231  
  2232  			_, err = enforcer.processNetworkSynPacket(context, connection.NewTCPConnection(context, p), p)
  2233  			So(err, ShouldNotBeNil)
  2234  		})
  2235  	})
  2236  }
  2237  
  2238  func TestProcessNetworkSynAckPacket(t *testing.T) {
  2239  
  2240  	ctrl := gomock.NewController(t)
  2241  	defer ctrl.Finish()
  2242  
  2243  	Convey("When I setup an enforcer", t, func() {
  2244  
  2245  		defer MockGetUDPRawSocket()()
  2246  
  2247  		enforcer, _, _, mockCollector, _ := NewWithMocks(ctrl, "serverID1", constants.LocalServer, []string{"0.0.0.0/0"}, true)
  2248  
  2249  		flowRecord1 := CreateFlowRecord(1, testDstIP, testSrcIP, 80, 43758, policy.Reject|policy.Log, collector.PolicyDrop)
  2250  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord1)).Times(1)
  2251  
  2252  		flowRecord2 := CreateFlowRecord(1, testDstIP, testSrcIP, 80, 43758, policy.Accept, "")
  2253  		mockCollector.EXPECT().CollectFlowEvent(MyMatcher(&flowRecord2)).Times(1)
  2254  
  2255  		Convey("So I received a packet with tcp fast open option set but no payload", func() {
  2256  
  2257  			p, err := packet.NewIpv4TCPPacket(1, 0x2, testSrcIP, testDstIP, 43758, 80)
  2258  			So(err, ShouldBeNil)
  2259  			So(p, ShouldNotBeNil)
  2260  
  2261  			// Add the fast open option
  2262  			buffer := append(p.GetBuffer(0), []byte{packet.TCPAuthenticationOption, enforcerconstants.TCPAuthenticationOptionBaseLen, 0, 0}...)
  2263  
  2264  			err = p.UpdatePacketBuffer(buffer, 4)
  2265  			So(err, ShouldBeNil)
  2266  
  2267  			err = p.CheckTCPAuthenticationOption(enforcerconstants.TCPAuthenticationOptionBaseLen)
  2268  			So(err, ShouldBeNil)
  2269  			So(p.IsEmptyTCPPayload(), ShouldBeTrue)
  2270  
  2271  			context, err := CreatePUContext(enforcer, "dummyContext", "/ns1", common.LinuxProcessPU, nil)
  2272  			So(err, ShouldBeNil)
  2273  			So(context, ShouldNotBeNil)
  2274  
  2275  			_, err = enforcer.processNetworkSynAckPacket(context, connection.NewTCPConnection(context, p), p)
  2276  			So(err, ShouldNotBeNil)
  2277  
  2278  			Convey("Then i add ip acl rule.", func() {
  2279  				iprules := policy.IPRuleList{policy.IPRule{
  2280  					Addresses: []string{"10.1.10.76/32"},
  2281  					Ports:     []string{"43758"},
  2282  					Protocols: []string{constants.TCPProtoNum},
  2283  					Policy: &policy.FlowPolicy{
  2284  						Action:   policy.Accept,
  2285  						PolicyID: "tcp172/8"},
  2286  				}}
  2287  				err = context.UpdateApplicationACLs(iprules)
  2288  				So(err, ShouldBeNil)
  2289  
  2290  				_, err = enforcer.processNetworkSynAckPacket(context, connection.NewTCPConnection(context, p), p)
  2291  				So(err, ShouldBeNil)
  2292  			})
  2293  		})
  2294  
  2295  	})
  2296  }