github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/monitor/internal/k8s/on_startup_test.go (about)

     1  package k8smonitor
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  	"sync"
     8  	"testing"
     9  
    10  	"github.com/golang/mock/gomock"
    11  
    12  	"go.aporeto.io/enforcerd/internal/extractors/containermetadata"
    13  	"go.aporeto.io/enforcerd/internal/extractors/containermetadata/mockcontainermetadata"
    14  
    15  	runtimeapi "k8s.io/cri-api/pkg/apis/runtime/v1alpha2"
    16  
    17  	"go.aporeto.io/enforcerd/trireme-lib/monitor/config"
    18  	"go.aporeto.io/enforcerd/trireme-lib/utils/cri/mockcri"
    19  )
    20  
    21  func Test_extractKmdFromCRISandbox(t *testing.T) {
    22  
    23  	type args struct {
    24  		sandboxID string
    25  	}
    26  	tests := []struct {
    27  		name    string
    28  		args    args
    29  		want    containermetadata.CommonKubernetesContainerMetadata
    30  		wantErr bool
    31  		prepare func(t *testing.T, extractor *mockcontainermetadata.MockCommonContainerMetadataExtractor)
    32  	}{
    33  		{
    34  			name: "sandbox ID empty",
    35  			args: args{
    36  				sandboxID: "",
    37  			},
    38  			want:    nil,
    39  			wantErr: true,
    40  			prepare: func(t *testing.T, extractor *mockcontainermetadata.MockCommonContainerMetadataExtractor) {
    41  				//nothing to be done here
    42  			},
    43  		},
    44  		{
    45  			name: "container not found with the extractor",
    46  			args: args{
    47  				sandboxID: "not-found",
    48  			},
    49  			want:    nil,
    50  			wantErr: true,
    51  			prepare: func(t *testing.T, extractor *mockcontainermetadata.MockCommonContainerMetadataExtractor) {
    52  				extractor.EXPECT().Has(
    53  					gomock.Eq(containermetadata.NewRuncArguments(containermetadata.StartAction, "not-found")),
    54  				).Return(false).Times(1)
    55  			},
    56  		},
    57  		{
    58  			name: "container extractor failed",
    59  			args: args{
    60  				sandboxID: "sandbox-id",
    61  			},
    62  			want:    nil,
    63  			wantErr: true,
    64  			prepare: func(t *testing.T, extractor *mockcontainermetadata.MockCommonContainerMetadataExtractor) {
    65  				ContainerArgs := containermetadata.NewRuncArguments(containermetadata.StartAction, "sandbox-id")
    66  				extractor.EXPECT().Has(gomock.Eq(ContainerArgs)).Return(true).Times(1)
    67  				extractor.EXPECT().Extract(gomock.Eq(ContainerArgs)).Return(nil, nil, fmt.Errorf("failed to extrat")).Times(1)
    68  			},
    69  		},
    70  		{
    71  			name: "container extractor succeeded but is not Kubernetes container",
    72  			args: args{
    73  				sandboxID: "sandbox-id",
    74  			},
    75  			want:    nil,
    76  			wantErr: true,
    77  			prepare: func(t *testing.T, extractor *mockcontainermetadata.MockCommonContainerMetadataExtractor) {
    78  				ContainerArgs := containermetadata.NewRuncArguments(containermetadata.StartAction, "sandbox-id")
    79  				extractor.EXPECT().Has(gomock.Eq(ContainerArgs)).Return(true).Times(1)
    80  				// technically the first result would need to be populated, but that doesn't matter for the test
    81  				extractor.EXPECT().Extract(gomock.Eq(ContainerArgs)).Return(nil, nil, nil).Times(1)
    82  			},
    83  		},
    84  		{
    85  			name: "container extractor succeeded",
    86  			args: args{
    87  				sandboxID: "sandbox-id",
    88  			},
    89  			want:    mockcontainermetadata.NewMockCommonKubernetesContainerMetadata(nil),
    90  			wantErr: false,
    91  			prepare: func(t *testing.T, extractor *mockcontainermetadata.MockCommonContainerMetadataExtractor) {
    92  				ContainerArgs := containermetadata.NewRuncArguments(containermetadata.StartAction, "sandbox-id")
    93  				extractor.EXPECT().Has(gomock.Eq(ContainerArgs)).Return(true).Times(1)
    94  				// technically the first result would need to be populated, but that doesn't matter for the test
    95  				extractor.EXPECT().Extract(gomock.Eq(ContainerArgs)).Return(
    96  					nil,
    97  					mockcontainermetadata.NewMockCommonKubernetesContainerMetadata(nil),
    98  					nil,
    99  				).Times(1)
   100  			},
   101  		},
   102  	}
   103  	for _, tt := range tests {
   104  		t.Run(tt.name, func(t *testing.T) {
   105  			ctrl := gomock.NewController(t)
   106  			mockExtractor := mockcontainermetadata.NewMockCommonContainerMetadataExtractor(ctrl)
   107  			extractor = mockExtractor
   108  			tt.prepare(t, mockExtractor)
   109  			got, err := extractKmdFromCRISandbox(tt.args.sandboxID)
   110  			if (err != nil) != tt.wantErr {
   111  				t.Errorf("extractKmdFromCRISandbox() error = %v, wantErr %v", err, tt.wantErr)
   112  				return
   113  			}
   114  			if !reflect.DeepEqual(got, tt.want) {
   115  				t.Errorf("extractKmdFromCRISandbox() = %v, want %v", got, tt.want)
   116  			}
   117  			ctrl.Finish()
   118  		})
   119  	}
   120  }
   121  
   122  type unitTestStartEvent interface {
   123  	f() startEventFunc
   124  	wait()
   125  	called() bool
   126  }
   127  type unitTestStartEventHandler struct {
   128  	sync.RWMutex
   129  	wg        sync.WaitGroup
   130  	wgCounter int
   131  	wasCalled bool
   132  	err       error
   133  }
   134  
   135  func (h *unitTestStartEventHandler) startEvent(ctx context.Context, kmd containermetadata.CommonKubernetesContainerMetadata, retry uint) error {
   136  	h.Lock()
   137  	defer h.Unlock()
   138  	h.wasCalled = true
   139  	if h.wgCounter > 0 {
   140  		h.wgCounter--
   141  	}
   142  	if h.wgCounter >= 0 {
   143  		h.wg.Done()
   144  	}
   145  	return h.err
   146  }
   147  
   148  func (h *unitTestStartEventHandler) f() startEventFunc {
   149  	return h.startEvent
   150  }
   151  
   152  func (h *unitTestStartEventHandler) wait() {
   153  	h.wg.Wait()
   154  }
   155  
   156  func (h *unitTestStartEventHandler) called() bool {
   157  	h.RLock()
   158  	defer h.RUnlock()
   159  	return h.wasCalled
   160  }
   161  
   162  func newUnitTestStartEventHandler(n int, err error) unitTestStartEvent {
   163  	h := &unitTestStartEventHandler{
   164  		err:       err,
   165  		wgCounter: n,
   166  	}
   167  	h.wg.Add(n)
   168  	return h
   169  }
   170  
   171  func TestK8sMonitor_onStartup(t *testing.T) {
   172  
   173  	listSandboxFilter := &runtimeapi.PodSandboxFilter{
   174  		State: &runtimeapi.PodSandboxStateValue{
   175  			State: runtimeapi.PodSandboxState_SANDBOX_READY,
   176  		},
   177  	}
   178  
   179  	tests := []struct {
   180  		name               string
   181  		startEventHandler  unitTestStartEvent
   182  		wantErr            bool
   183  		prepare            func(t *testing.T, extractor *mockcontainermetadata.MockCommonContainerMetadataExtractor, cri *mockcri.MockExtendedRuntimeService)
   184  		expectedStartEvent bool
   185  	}{
   186  		{
   187  			name:               "listing sandboxes fails",
   188  			startEventHandler:  newUnitTestStartEventHandler(0, nil),
   189  			wantErr:            true,
   190  			expectedStartEvent: false,
   191  			prepare: func(t *testing.T, extractor *mockcontainermetadata.MockCommonContainerMetadataExtractor, cri *mockcri.MockExtendedRuntimeService) {
   192  				cri.EXPECT().ListPodSandbox(gomock.Eq(listSandboxFilter)).Return(nil, fmt.Errorf("failed")).Times(1)
   193  			},
   194  		},
   195  		{
   196  			name:               "listing sandboxes succeeds, but extracting metadata fails",
   197  			startEventHandler:  newUnitTestStartEventHandler(0, nil),
   198  			wantErr:            false,
   199  			expectedStartEvent: false,
   200  			prepare: func(t *testing.T, extractor *mockcontainermetadata.MockCommonContainerMetadataExtractor, cri *mockcri.MockExtendedRuntimeService) {
   201  				cri.EXPECT().ListPodSandbox(gomock.Eq(listSandboxFilter)).Return(
   202  					[]*runtimeapi.PodSandbox{
   203  						{
   204  							Id: "sandbox-id",
   205  						},
   206  					},
   207  					nil,
   208  				).Times(1)
   209  
   210  				ContainerArgs := containermetadata.NewRuncArguments(containermetadata.StartAction, "sandbox-id")
   211  				extractor.EXPECT().Has(gomock.Eq(ContainerArgs)).Return(false).Times(1)
   212  			},
   213  		},
   214  		{
   215  			name:               "listing sandboxes succeeds, sending an event that fails",
   216  			startEventHandler:  newUnitTestStartEventHandler(1, fmt.Errorf("error")),
   217  			wantErr:            false,
   218  			expectedStartEvent: true,
   219  			prepare: func(t *testing.T, extractor *mockcontainermetadata.MockCommonContainerMetadataExtractor, cri *mockcri.MockExtendedRuntimeService) {
   220  				cri.EXPECT().ListPodSandbox(gomock.Eq(listSandboxFilter)).Return(
   221  					[]*runtimeapi.PodSandbox{
   222  						{
   223  							Id: "sandbox-id",
   224  						},
   225  					},
   226  					nil,
   227  				).Times(1)
   228  
   229  				ContainerArgs := containermetadata.NewRuncArguments(containermetadata.StartAction, "sandbox-id")
   230  				extractor.EXPECT().Has(gomock.Eq(ContainerArgs)).Return(true).Times(1)
   231  				// technically the first result would need to be populated, but that doesn't matter for the test
   232  				extractor.EXPECT().Extract(gomock.Eq(ContainerArgs)).Return(
   233  					nil,
   234  					mockcontainermetadata.NewMockCommonKubernetesContainerMetadata(nil),
   235  					nil,
   236  				).Times(1)
   237  			},
   238  		},
   239  		{
   240  			name:               "listing 2 sandboxes succeeds, sending 2 start events",
   241  			startEventHandler:  newUnitTestStartEventHandler(2, nil),
   242  			wantErr:            false,
   243  			expectedStartEvent: true,
   244  			prepare: func(t *testing.T, extractor *mockcontainermetadata.MockCommonContainerMetadataExtractor, cri *mockcri.MockExtendedRuntimeService) {
   245  				cri.EXPECT().ListPodSandbox(gomock.Eq(listSandboxFilter)).Return(
   246  					[]*runtimeapi.PodSandbox{
   247  						{
   248  							Id: "sandbox-id-1",
   249  						},
   250  						{
   251  							Id: "sandbox-id-2",
   252  						},
   253  					},
   254  					nil,
   255  				).Times(1)
   256  
   257  				ContainerArgs1 := containermetadata.NewRuncArguments(containermetadata.StartAction, "sandbox-id-1")
   258  				extractor.EXPECT().Has(gomock.Eq(ContainerArgs1)).Return(true).Times(1)
   259  				// technically the first result would need to be populated, but that doesn't matter for the test
   260  				extractor.EXPECT().Extract(gomock.Eq(ContainerArgs1)).Return(
   261  					nil,
   262  					mockcontainermetadata.NewMockCommonKubernetesContainerMetadata(nil),
   263  					nil,
   264  				).Times(1)
   265  				ContainerArgs2 := containermetadata.NewRuncArguments(containermetadata.StartAction, "sandbox-id-2")
   266  				extractor.EXPECT().Has(gomock.Eq(ContainerArgs2)).Return(true).Times(1)
   267  				// technically the first result would need to be populated, but that doesn't matter for the test
   268  				extractor.EXPECT().Extract(gomock.Eq(ContainerArgs2)).Return(
   269  					nil,
   270  					mockcontainermetadata.NewMockCommonKubernetesContainerMetadata(nil),
   271  					nil,
   272  				).Times(1)
   273  			},
   274  		},
   275  	}
   276  	for _, tt := range tests {
   277  		t.Run(tt.name, func(t *testing.T) {
   278  			ctrl := gomock.NewController(t)
   279  			mockExtractor := mockcontainermetadata.NewMockCommonContainerMetadataExtractor(ctrl)
   280  			extractor = mockExtractor
   281  			ctx, cancel := context.WithCancel(context.Background())
   282  			m := New(ctx)
   283  			m.SetupHandlers(&config.ProcessorConfig{
   284  				ResyncLock: &sync.RWMutex{},
   285  			})
   286  			m.SenderReady()
   287  			mockcri := mockcri.NewMockExtendedRuntimeService(ctrl)
   288  			m.criRuntimeService = mockcri
   289  			tt.prepare(t, mockExtractor, mockcri)
   290  			if err := m.onStartup(ctx, tt.startEventHandler.f()); (err != nil) != tt.wantErr {
   291  				t.Errorf("K8sMonitor.onStartup() error = %v, wantErr %v", err, tt.wantErr)
   292  			}
   293  			tt.startEventHandler.wait()
   294  			if tt.expectedStartEvent != tt.startEventHandler.called() {
   295  				t.Errorf("startEventHandler.called() = %v, want %v", tt.startEventHandler.called(), tt.expectedStartEvent)
   296  			}
   297  			cancel()
   298  			ctrl.Finish()
   299  		})
   300  	}
   301  }