github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/nfqdatapath/datapath_common_test.go (about)

     1  package nfqdatapath
     2  
     3  import (
     4  	"context"
     5  	"crypto/ecdsa"
     6  	"time"
     7  
     8  	"github.com/blang/semver"
     9  	"github.com/golang/mock/gomock"
    10  	. "github.com/smartystreets/goconvey/convey"
    11  	"go.aporeto.io/enforcerd/trireme-lib/collector"
    12  	"go.aporeto.io/enforcerd/trireme-lib/collector/mockcollector"
    13  	"go.aporeto.io/enforcerd/trireme-lib/common"
    14  	"go.aporeto.io/enforcerd/trireme-lib/controller/constants"
    15  	enforcerconstants "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/constants"
    16  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/dnsproxy/mockdnsproxy"
    17  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/nfqdatapath/afinetrawsocket"
    18  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/nfqdatapath/tokenaccessor"
    19  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/nfqdatapath/tokenaccessor/mocktokenaccessor"
    20  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/claimsheader"
    21  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/connection"
    22  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/flowtracking/mockflowclient"
    23  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pkiverifier"
    24  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pucontext"
    25  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/secrets"
    26  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/secrets/mocksecrets"
    27  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/tokens"
    28  	"go.aporeto.io/enforcerd/trireme-lib/controller/runtime"
    29  	"go.aporeto.io/enforcerd/trireme-lib/policy"
    30  	"go.aporeto.io/enforcerd/trireme-lib/utils/cache"
    31  	"go.aporeto.io/enforcerd/trireme-lib/utils/portspec"
    32  )
    33  
    34  const (
    35  	testSrcIP = "10.1.10.76"
    36  	testDstIP = "164.67.228.152"
    37  )
    38  
    39  var (
    40  	debug bool
    41  )
    42  
    43  func procSetValueMock(procName string, value int) error {
    44  	return nil
    45  }
    46  
    47  // NewWithDefaults create a new data path with most things used by default
    48  func newWithDefaults(
    49  	ctrl *gomock.Controller,
    50  	serverID string,
    51  	collector collector.EventCollector,
    52  	secrets secrets.Secrets,
    53  	mode constants.ModeType,
    54  	targetNetworks []string,
    55  	testExpirationNotifier bool,
    56  ) *Datapath {
    57  
    58  	// Override so that you don't have to run as root
    59  	procSetValuePtr = procSetValueMock
    60  
    61  	mockTokenAccessor := mocktokenaccessor.NewMockTokenAccessor(ctrl)
    62  	flowclient := mockflowclient.NewMockFlowClient(ctrl)
    63  	puFromContextID := cache.NewCache("puFromContextID")
    64  	mockDNS := mockdnsproxy.NewMockDNSProxy(ctrl)
    65  
    66  	mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
    67  	mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
    68  	mockTokenAccessor.EXPECT().CreateSynAckPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
    69  	mockTokenAccessor.EXPECT().Randomize(gomock.Any(), gomock.Any()).AnyTimes()
    70  	mockTokenAccessor.EXPECT().ParseAckToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
    71  	mockTokenAccessor.EXPECT().ParsePacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(
    72  		func(privateKey, data, secrets, c, b interface{}) interface{} {
    73  
    74  			claims := c.(*tokens.ConnectionClaims)
    75  			claims.T = policy.NewTagStore()
    76  			claims.T.AppendKeyValue(enforcerconstants.TransmitterLabel, "value")
    77  			return nil
    78  		},
    79  	).Return(nil, &claimsheader.ClaimsHeader{}, &pkiverifier.PKIControllerInfo{}, []byte("remoteNonce"), "", false, nil).AnyTimes()
    80  
    81  	mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
    82  	mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
    83  	mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).AnyTimes()
    84  
    85  	e := New(
    86  		false,
    87  		nil,
    88  		collector,
    89  		serverID,
    90  		10*time.Minute,
    91  		secrets,
    92  		mode,
    93  		"/proc",
    94  		500*time.Millisecond,
    95  		false,
    96  		mockTokenAccessor,
    97  		puFromContextID,
    98  		&runtime.Configuration{TCPTargetNetworks: targetNetworks},
    99  		false,
   100  		semver.Version{},
   101  		policy.None,
   102  	)
   103  
   104  	e.conntrack = flowclient
   105  	e.dnsProxy = mockDNS
   106  
   107  	if testExpirationNotifier {
   108  		e.tcpConnectionExpirationNotifier = testConnectionExpirationNotifier
   109  	}
   110  
   111  	return e
   112  }
   113  
   114  // NewWithMocks create a new data path using mock objects
   115  func NewWithMocks(
   116  	ctrl *gomock.Controller,
   117  	serverID string,
   118  	mode constants.ModeType,
   119  	targetNetworks []string,
   120  	testExpirationNotifier bool,
   121  ) (*Datapath, *mocksecrets.MockSecrets, *mocktokenaccessor.MockTokenAccessor,
   122  	*mockcollector.MockEventCollector, *mockdnsproxy.MockDNSProxy) {
   123  
   124  	// Override so that you don't have to run as root
   125  	procSetValuePtr = procSetValueMock
   126  
   127  	secrets := mocksecrets.NewMockSecrets(ctrl)
   128  	tokenAccessor := mocktokenaccessor.NewMockTokenAccessor(ctrl)
   129  	collector := mockcollector.NewMockEventCollector(ctrl)
   130  	flowclient := mockflowclient.NewMockFlowClient(ctrl)
   131  	puFromContextID := cache.NewCache("puFromContextID")
   132  	dnsproxy := mockdnsproxy.NewMockDNSProxy(ctrl)
   133  
   134  	secrets.EXPECT().AckSize().Return(uint32(300)).Times(1)
   135  
   136  	e := New(
   137  		false,
   138  		nil,
   139  		collector,
   140  		serverID,
   141  		10*time.Minute,
   142  		secrets,
   143  		mode,
   144  		"/proc",
   145  		500*time.Millisecond,
   146  		false,
   147  		tokenAccessor,
   148  		puFromContextID,
   149  		&runtime.Configuration{TCPTargetNetworks: targetNetworks},
   150  		false,
   151  		semver.Version{},
   152  		policy.None,
   153  	)
   154  
   155  	e.conntrack = flowclient
   156  	e.dnsProxy = dnsproxy
   157  
   158  	if testExpirationNotifier {
   159  		e.tcpConnectionExpirationNotifier = testConnectionExpirationNotifier
   160  	}
   161  
   162  	return e, secrets, tokenAccessor, collector, dnsproxy
   163  }
   164  
   165  func testConnectionExpirationNotifier(conn *connection.TCPConnection) {
   166  
   167  	conn.Cleanup()
   168  }
   169  
   170  // MockGetUDPRawSocket mocks the GetUDPRawSocket function. Usage "defer MockGetUDPRawSocket()()"
   171  func MockGetUDPRawSocket() func() {
   172  	prevRawSocket := GetUDPRawSocket
   173  	GetUDPRawSocket = func(mark int, device string) (afinetrawsocket.SocketWriter, error) {
   174  		return nil, nil
   175  	}
   176  	return func() {
   177  		GetUDPRawSocket = prevRawSocket
   178  	}
   179  }
   180  
   181  // CreatePUContext creates a policy
   182  func CreatePUContext(enforcer *Datapath, contextID, namespace string, puType common.PUType, tokenAccessor tokenaccessor.TokenAccessor) (*pucontext.PUContext, error) {
   183  	puInfo := policy.NewPUInfo(contextID, namespace, puType)
   184  	context, err := pucontext.NewPU(contextID, puInfo, tokenAccessor, 10*time.Second)
   185  	if err != nil {
   186  		return nil, err
   187  	}
   188  	enforcer.puFromContextID.AddOrUpdate(contextID, context) // nolint
   189  	return context, nil
   190  }
   191  
   192  // CreatePortPolicy creates a port range policy
   193  func CreatePortPolicy(enforcer *Datapath, contextID, namespace string, puType common.PUType, tokenAccessor tokenaccessor.TokenAccessor, mark string, portMin, portMax uint16) error {
   194  
   195  	context, err := CreatePUContext(enforcer, contextID, namespace, puType, tokenAccessor)
   196  	if err != nil {
   197  		return err
   198  	}
   199  
   200  	err = enforcer.puFromMark.Add(mark, context)
   201  	if err != nil {
   202  		return err
   203  	}
   204  
   205  	portspec, err := portspec.NewPortSpec(portMin, portMax, contextID)
   206  	if err != nil {
   207  		return err
   208  	}
   209  	enforcer.contextIDFromTCPPort.AddPortSpec(portspec)
   210  	return nil
   211  }
   212  
   213  // CreateFlowRecord creates a basic flow report
   214  func CreateFlowRecord(count int, srcIP, destIP string, srcPort, destPort uint16, action policy.ActionType, dropReason string) collector.FlowRecord {
   215  	var flowRecord collector.FlowRecord
   216  	var srcEndPoint collector.EndPoint
   217  	var dstEndPoint collector.EndPoint
   218  
   219  	srcEndPoint.IP = srcIP
   220  	srcEndPoint.Port = srcPort
   221  
   222  	dstEndPoint.IP = destIP
   223  	dstEndPoint.Port = destPort
   224  
   225  	flowRecord.Count = count
   226  	flowRecord.Source = srcEndPoint
   227  	flowRecord.Destination = dstEndPoint
   228  	flowRecord.Action = action
   229  	flowRecord.DropReason = dropReason
   230  	return flowRecord
   231  }
   232  
   233  func createEnforcerWithPolicy(ctrl *gomock.Controller, mode constants.ModeType) (*Datapath, *mockcollector.MockEventCollector) {
   234  
   235  	puInfo1, puInfo2 := createPolicies(testSrcIP, testDstIP)
   236  	So(puInfo1, ShouldNotBeNil)
   237  	So(puInfo2, ShouldNotBeNil)
   238  
   239  	enforcer, mockTokenAccessor := createEnforcer(ctrl, mode)
   240  
   241  	err := enforcer.Enforce(context.Background(), puInfo1.ContextID, puInfo1)
   242  	So(err, ShouldBeNil)
   243  
   244  	err = enforcer.Enforce(context.Background(), puInfo2.ContextID, puInfo2)
   245  	So(err, ShouldBeNil)
   246  
   247  	return enforcer, mockTokenAccessor
   248  }
   249  
   250  func createEnforcer(ctrl *gomock.Controller, mode constants.ModeType) (*Datapath, *mockcollector.MockEventCollector) {
   251  
   252  	enforcer, secrets, mockTokenAccessor, mockCollector, mockDNS := NewWithMocks(ctrl, "serverID", mode, []string{"0.0.0.0/0"}, true)
   253  	So(enforcer != nil, ShouldBeTrue)
   254  
   255  	secrets.EXPECT().EncodingKey().Return(&ecdsa.PrivateKey{}).AnyTimes()
   256  	mockTokenAccessor.EXPECT().Sign(gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
   257  	mockTokenAccessor.EXPECT().CreateSynPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
   258  	mockTokenAccessor.EXPECT().CreateSynAckPacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return([]byte("token"), nil).AnyTimes()
   259  	mockTokenAccessor.EXPECT().Randomize(gomock.Any(), gomock.Any()).AnyTimes()
   260  	mockTokenAccessor.EXPECT().ParseAckToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
   261  	mockTokenAccessor.EXPECT().ParsePacketToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(
   262  		func(privateKey, data, secrets, c, b interface{}) interface{} {
   263  
   264  			claims := c.(*tokens.ConnectionClaims)
   265  			claims.T = policy.NewTagStore()
   266  			claims.T.AppendKeyValue(enforcerconstants.TransmitterLabel, "value")
   267  			return nil
   268  		},
   269  	).Return(nil, &claimsheader.ClaimsHeader{}, &pkiverifier.PKIControllerInfo{}, []byte("remoteNonce"), "", false, nil).AnyTimes()
   270  
   271  	mockDNS.EXPECT().StartDNSServer(gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
   272  	mockDNS.EXPECT().Enforce(gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
   273  	mockDNS.EXPECT().SyncWithPlatformCache(gomock.Any(), gomock.Any()).Times(2)
   274  	return enforcer, mockCollector
   275  }
   276  
   277  func createPolicies(srcIP, dstIP string) (*policy.PUInfo, *policy.PUInfo) {
   278  	tagSelector := policy.TagSelector{
   279  		Clause: []policy.KeyValueOperator{
   280  			{
   281  				Key:      enforcerconstants.TransmitterLabel,
   282  				Value:    []string{"value"},
   283  				Operator: policy.Equal,
   284  			},
   285  		},
   286  		Policy: &policy.FlowPolicy{Action: policy.Accept},
   287  	}
   288  
   289  	puID1 := "SomeProcessingUnitId1"
   290  	puID2 := "SomeProcessingUnitId2"
   291  
   292  	puIP1 := dstIP
   293  	puIP2 := srcIP
   294  
   295  	// Create ProcessingUnit 1
   296  	puInfo1 := policy.NewPUInfo(puID1, "/ns1", common.ContainerPU)
   297  
   298  	ip1 := policy.ExtendedMap{}
   299  	ip1["bridge"] = puIP1
   300  	puInfo1.Runtime.SetIPAddresses(ip1)
   301  	ipl1 := policy.ExtendedMap{policy.DefaultNamespace: puIP1}
   302  	puInfo1.Policy.SetIPAddresses(ipl1)
   303  	puInfo1.Policy.AddIdentityTag(enforcerconstants.TransmitterLabel, "value")
   304  	puInfo1.Policy.AddReceiverRules(tagSelector)
   305  
   306  	// Create processing unit 2
   307  	puInfo2 := policy.NewPUInfo(puID2, "/ns2", common.ContainerPU)
   308  	ip2 := policy.ExtendedMap{"bridge": puIP2}
   309  	puInfo2.Runtime.SetIPAddresses(ip2)
   310  	ipl2 := policy.ExtendedMap{policy.DefaultNamespace: puIP2}
   311  	puInfo2.Policy.SetIPAddresses(ipl2)
   312  	puInfo2.Policy.AddIdentityTag(enforcerconstants.TransmitterLabel, "value")
   313  	puInfo2.Policy.AddReceiverRules(tagSelector)
   314  
   315  	return puInfo1, puInfo2
   316  }