github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/monitor/internal/k8s/runtime_cache_test.go (about)

     1  // +build linux
     2  
     3  package k8smonitor
     4  
     5  // TODO: make compatible with Windows
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"reflect"
    11  	"sync"
    12  	"syscall"
    13  	"testing"
    14  	"time"
    15  
    16  	"go.aporeto.io/enforcerd/trireme-lib/common"
    17  	"go.aporeto.io/enforcerd/trireme-lib/policy"
    18  )
    19  
    20  func Test_runtimeCache_Delete(t *testing.T) {
    21  	stopEvent := func(context.Context, string) error {
    22  		return nil
    23  	}
    24  
    25  	// override globals for unit tests
    26  	oldDefaultLoopWait := defaultLoopWait
    27  	oldSyscallKill := syscallKill
    28  	defer func() {
    29  		defaultLoopWait = oldDefaultLoopWait
    30  		syscallKill = oldSyscallKill
    31  	}()
    32  	defaultLoopWait = time.Duration(0)
    33  	syscallKill = func(int, syscall.Signal) (err error) {
    34  		return nil
    35  	}
    36  
    37  	tests := []struct {
    38  		name      string
    39  		c         *runtimeCache
    40  		sandboxID string
    41  	}{
    42  		{
    43  			name:      "cache uninitialized",
    44  			c:         nil,
    45  			sandboxID: "does-not-matter",
    46  		},
    47  		{
    48  			name:      "cache initialized",
    49  			c:         newRuntimeCache(context.TODO(), stopEvent),
    50  			sandboxID: "does-not-mater",
    51  		},
    52  	}
    53  	for _, tt := range tests {
    54  		t.Run(tt.name, func(t *testing.T) {
    55  			tt.c.Delete(tt.sandboxID)
    56  		})
    57  	}
    58  }
    59  
    60  func Test_runtimeCache_Set(t *testing.T) {
    61  	stopEvent := func(context.Context, string) error {
    62  		return nil
    63  	}
    64  
    65  	// override globals for unit tests
    66  	oldDefaultLoopWait := defaultLoopWait
    67  	oldSyscallKill := syscallKill
    68  	defer func() {
    69  		defaultLoopWait = oldDefaultLoopWait
    70  		syscallKill = oldSyscallKill
    71  	}()
    72  	defaultLoopWait = time.Duration(0)
    73  	syscallKill = func(int, syscall.Signal) (err error) {
    74  		return nil
    75  	}
    76  
    77  	type args struct {
    78  		sandboxID string
    79  		runtime   policy.RuntimeReader
    80  	}
    81  	tests := []struct {
    82  		name         string
    83  		c            *runtimeCache
    84  		args         args
    85  		wantErr      bool
    86  		wantErrError error
    87  	}{
    88  		{
    89  			name:         "cache uninitialized",
    90  			c:            nil,
    91  			wantErr:      true,
    92  			wantErrError: errCacheUninitialized,
    93  		},
    94  		{
    95  			name:         "cache has unintialized map",
    96  			c:            &runtimeCache{},
    97  			wantErr:      true,
    98  			wantErrError: errCacheUninitialized,
    99  			args: args{
   100  				sandboxID: "does-not-matter",
   101  				runtime:   policy.NewPURuntimeWithDefaults(),
   102  			},
   103  		},
   104  		{
   105  			name:         "no sandboxID",
   106  			c:            newRuntimeCache(context.TODO(), stopEvent),
   107  			wantErr:      true,
   108  			wantErrError: errSandboxEmpty,
   109  			args: args{
   110  				sandboxID: "",
   111  				runtime:   policy.NewPURuntimeWithDefaults(),
   112  			},
   113  		},
   114  		{
   115  			name:         "runtime is nil",
   116  			c:            newRuntimeCache(context.TODO(), stopEvent),
   117  			wantErr:      true,
   118  			wantErrError: errRuntimeNil,
   119  			args: args{
   120  				sandboxID: "does-not-matter",
   121  				runtime:   nil,
   122  			},
   123  		},
   124  		{
   125  			name:    "successful update entry",
   126  			c:       newRuntimeCache(context.TODO(), stopEvent),
   127  			wantErr: false,
   128  			args: args{
   129  				sandboxID: "does-not-matter",
   130  				runtime:   policy.NewPURuntimeWithDefaults(),
   131  			},
   132  		},
   133  	}
   134  	for _, tt := range tests {
   135  		t.Run(tt.name, func(t *testing.T) {
   136  			err := tt.c.Set(tt.args.sandboxID, tt.args.runtime)
   137  			if (err != nil) != tt.wantErr {
   138  				t.Errorf("runtimeCache.Set() error = %v, wantErr %v", err, tt.wantErr)
   139  			}
   140  			if tt.wantErr {
   141  				if err != tt.wantErrError {
   142  					t.Errorf("runtimeCache.Set() error = %v, wantErrError %v", err, tt.wantErrError)
   143  				}
   144  			}
   145  		})
   146  	}
   147  }
   148  
   149  func Test_runtimeCache_Get(t *testing.T) {
   150  	stopEvent := func(context.Context, string) error {
   151  		return nil
   152  	}
   153  
   154  	// override globals for unit tests
   155  	oldDefaultLoopWait := defaultLoopWait
   156  	oldSyscallKill := syscallKill
   157  	defer func() {
   158  		defaultLoopWait = oldDefaultLoopWait
   159  		syscallKill = oldSyscallKill
   160  	}()
   161  	defaultLoopWait = time.Duration(0)
   162  	syscallKill = func(int, syscall.Signal) (err error) {
   163  		return nil
   164  	}
   165  
   166  	cacheWithEntry := newRuntimeCache(context.TODO(), stopEvent)
   167  	if err := cacheWithEntry.Set("entry", policy.NewPURuntimeWithDefaults()); err != nil {
   168  		panic(err)
   169  	}
   170  	tests := []struct {
   171  		name      string
   172  		sandboxID string
   173  		c         *runtimeCache
   174  		want      policy.RuntimeReader
   175  	}{
   176  		{
   177  			name: "uninitialized runtimeCache",
   178  			c:    nil,
   179  			want: nil,
   180  		},
   181  		{
   182  			name: "uninitialized map in runtimeCache",
   183  			c:    &runtimeCache{},
   184  			want: nil,
   185  		},
   186  		{
   187  			name:      "entry does not exist",
   188  			c:         newRuntimeCache(context.TODO(), stopEvent),
   189  			sandboxID: "does-not-exist",
   190  			want:      nil,
   191  		},
   192  		{
   193  			name:      "entry exists",
   194  			c:         cacheWithEntry,
   195  			sandboxID: "entry",
   196  			want:      policy.NewPURuntimeWithDefaults(),
   197  		},
   198  	}
   199  	for _, tt := range tests {
   200  		t.Run(tt.name, func(t *testing.T) {
   201  			if got := tt.c.Get(tt.sandboxID); !reflect.DeepEqual(got, tt.want) {
   202  				t.Errorf("runtimeCache.Get() = %v, want %v", got, tt.want)
   203  			}
   204  		})
   205  	}
   206  }
   207  
   208  func Test_makeSnapshot(t *testing.T) {
   209  	type args struct {
   210  		m map[string]runtimeCacheEntry
   211  	}
   212  	tests := []struct {
   213  		name string
   214  		args args
   215  		want map[string]policy.RuntimeReader
   216  	}{
   217  		{
   218  			name: "empty",
   219  			args: args{
   220  				m: map[string]runtimeCacheEntry{},
   221  			},
   222  			want: map[string]policy.RuntimeReader{},
   223  		},
   224  		{
   225  			name: "not-running entry",
   226  			args: args{
   227  				m: map[string]runtimeCacheEntry{
   228  					"entry": {
   229  						runtime: policy.NewPURuntimeWithDefaults(),
   230  						running: false,
   231  					},
   232  				},
   233  			},
   234  			want: map[string]policy.RuntimeReader{},
   235  		},
   236  		{
   237  			name: "running entry",
   238  			args: args{
   239  				m: map[string]runtimeCacheEntry{
   240  					"entry": {
   241  						runtime: policy.NewPURuntimeWithDefaults(),
   242  						running: true,
   243  					},
   244  				},
   245  			},
   246  			want: map[string]policy.RuntimeReader{
   247  				"entry": policy.NewPURuntimeWithDefaults(),
   248  			},
   249  		},
   250  	}
   251  	for _, tt := range tests {
   252  		t.Run(tt.name, func(t *testing.T) {
   253  			if got := makeSnapshot(tt.args.m); !reflect.DeepEqual(got, tt.want) {
   254  				t.Errorf("makeSnapshot() = %v, want %v", got, tt.want)
   255  			}
   256  		})
   257  	}
   258  }
   259  
   260  type unitTestStopEvent interface {
   261  	f() stopEventFunc
   262  	wait()
   263  	called() bool
   264  }
   265  type unitTestStopEventHandler struct {
   266  	sync.RWMutex
   267  	wg        sync.WaitGroup
   268  	wgCounter int
   269  	wasCalled bool
   270  	err       error
   271  }
   272  
   273  func (h *unitTestStopEventHandler) stopEvent(context.Context, string) error {
   274  	h.Lock()
   275  	defer h.Unlock()
   276  	h.wasCalled = true
   277  	if h.wgCounter > 0 {
   278  		h.wgCounter--
   279  	}
   280  	if h.wgCounter >= 0 {
   281  		h.wg.Done()
   282  	}
   283  	return h.err
   284  }
   285  
   286  func (h *unitTestStopEventHandler) f() stopEventFunc {
   287  	return h.stopEvent
   288  }
   289  
   290  func (h *unitTestStopEventHandler) wait() {
   291  	h.wg.Wait()
   292  }
   293  
   294  func (h *unitTestStopEventHandler) called() bool {
   295  	h.RLock()
   296  	defer h.RUnlock()
   297  	return h.wasCalled
   298  }
   299  
   300  func newUnitTestStopEventHandler(n int, err error) unitTestStopEvent {
   301  	h := &unitTestStopEventHandler{
   302  		err:       err,
   303  		wgCounter: n,
   304  	}
   305  	h.wg.Add(n)
   306  	return h
   307  }
   308  
   309  func Test_runtimeCache_processRuntimes(t *testing.T) {
   310  	// override globals for unit tests
   311  	oldDefaultLoopWait := defaultLoopWait
   312  	oldSyscallKill := syscallKill
   313  	defer func() {
   314  		defaultLoopWait = oldDefaultLoopWait
   315  		syscallKill = oldSyscallKill
   316  	}()
   317  	defaultLoopWait = time.Duration(0)
   318  	syscallKill = func(int, syscall.Signal) (err error) {
   319  		return nil
   320  	}
   321  
   322  	type fields struct {
   323  		runtimes map[string]runtimeCacheEntry
   324  	}
   325  	type args struct {
   326  		ctx  context.Context
   327  		snap map[string]policy.RuntimeReader
   328  	}
   329  
   330  	runtime := policy.NewPURuntime("entry", 42, "", nil, nil, common.ContainerPU, policy.None, nil)
   331  	tests := []struct {
   332  		name              string
   333  		syscallKill       func(int, syscall.Signal) error
   334  		stopEventHandler  unitTestStopEvent
   335  		fields            fields
   336  		args              args
   337  		expectedStopEvent bool
   338  		expectedRuntimes  map[string]runtimeCacheEntry
   339  	}{
   340  		{
   341  			name: "process still running",
   342  			syscallKill: func(int, syscall.Signal) error {
   343  				return nil
   344  			},
   345  			stopEventHandler: newUnitTestStopEventHandler(0, nil),
   346  			fields: fields{
   347  				runtimes: map[string]runtimeCacheEntry{
   348  					"entry": {
   349  						runtime: runtime,
   350  						running: true,
   351  					},
   352  				},
   353  			},
   354  			args: args{
   355  				ctx: context.Background(),
   356  				snap: map[string]policy.RuntimeReader{
   357  					"entry": runtime,
   358  				},
   359  			},
   360  			expectedStopEvent: false,
   361  			expectedRuntimes: map[string]runtimeCacheEntry{
   362  				"entry": {
   363  					runtime: runtime,
   364  					running: true,
   365  				},
   366  			},
   367  		},
   368  		{
   369  			name: "syscall returns unexpected error",
   370  			syscallKill: func(int, syscall.Signal) error {
   371  				return fmt.Errorf("unexpected error")
   372  			},
   373  			stopEventHandler: newUnitTestStopEventHandler(0, nil),
   374  			fields: fields{
   375  				runtimes: map[string]runtimeCacheEntry{
   376  					"entry": {
   377  						runtime: runtime,
   378  						running: true,
   379  					},
   380  				},
   381  			},
   382  			args: args{
   383  				ctx: context.Background(),
   384  				snap: map[string]policy.RuntimeReader{
   385  					"entry": runtime,
   386  				},
   387  			},
   388  			expectedStopEvent: false,
   389  			expectedRuntimes: map[string]runtimeCacheEntry{
   390  				"entry": {
   391  					runtime: runtime,
   392  					running: true,
   393  				},
   394  			},
   395  		},
   396  		{
   397  			name: "process not running anymore",
   398  			syscallKill: func(int, syscall.Signal) error {
   399  				return syscall.ESRCH
   400  			},
   401  			stopEventHandler: newUnitTestStopEventHandler(1, fmt.Errorf("more test coverage")),
   402  			fields: fields{
   403  				runtimes: map[string]runtimeCacheEntry{
   404  					"entry": {
   405  						runtime: runtime,
   406  						running: true,
   407  					},
   408  				},
   409  			},
   410  			args: args{
   411  				ctx: context.Background(),
   412  				snap: map[string]policy.RuntimeReader{
   413  					"entry": runtime,
   414  				},
   415  			},
   416  			expectedStopEvent: true,
   417  			expectedRuntimes: map[string]runtimeCacheEntry{
   418  				"entry": {
   419  					runtime: runtime,
   420  					running: false,
   421  				},
   422  			},
   423  		},
   424  	}
   425  	for _, tt := range tests {
   426  		t.Run(tt.name, func(t *testing.T) {
   427  			syscallKill = tt.syscallKill
   428  			c := &runtimeCache{
   429  				runtimes:  tt.fields.runtimes,
   430  				stopEvent: tt.stopEventHandler.f(),
   431  			}
   432  			ctx, cancel := context.WithCancel(tt.args.ctx)
   433  			defer cancel()
   434  			c.processRuntimes(ctx, tt.args.snap)
   435  			tt.stopEventHandler.wait()
   436  			if !reflect.DeepEqual(c.runtimes, tt.expectedRuntimes) {
   437  				t.Errorf("c.runtimes = %v, want %v", c.runtimes, tt.expectedRuntimes)
   438  			}
   439  			if tt.expectedStopEvent != tt.stopEventHandler.called() {
   440  				t.Errorf("stopEventHandler.called() = %v, want %v", tt.stopEventHandler.called(), tt.expectedStopEvent)
   441  			}
   442  		})
   443  	}
   444  }
   445  
   446  func Test_runtimeCache_loop(t *testing.T) {
   447  	stopEventHandler := newUnitTestStopEventHandler(1, nil)
   448  
   449  	// override globals for unit tests
   450  	oldDefaultLoopWait := defaultLoopWait
   451  	oldSyscallKill := syscallKill
   452  	defer func() {
   453  		defaultLoopWait = oldDefaultLoopWait
   454  		syscallKill = oldSyscallKill
   455  	}()
   456  	defaultLoopWait = time.Duration(1)
   457  	syscallKill = func(int, syscall.Signal) error {
   458  		return syscall.ESRCH
   459  	}
   460  
   461  	tests := []struct {
   462  		name             string
   463  		stopEventHandler unitTestStopEvent
   464  	}{
   465  		{
   466  			name:             "successful loop",
   467  			stopEventHandler: stopEventHandler,
   468  		},
   469  	}
   470  	for _, tt := range tests {
   471  		t.Run(tt.name, func(t *testing.T) {
   472  			// this starts the loop already
   473  			ctx, cancel := context.WithCancel(context.Background())
   474  			c := newRuntimeCache(ctx, tt.stopEventHandler.f())
   475  
   476  			// TODO: to get that last inch of coverage :) not sure how else to get that
   477  			time.Sleep(time.Millisecond * 10)
   478  
   479  			// add a runtime
   480  			runtime := policy.NewPURuntime("entry", 42, "", nil, nil, common.ContainerPU, policy.None, nil)
   481  			c.Set("entry", runtime) // nolint: errcheck
   482  			c.RLock()
   483  			if c.runtimes["entry"].running != true { // nolint
   484  				t.Errorf("entry is not marked as running")
   485  			}
   486  			c.RUnlock()
   487  
   488  			// wait until the stop event was called
   489  			tt.stopEventHandler.wait()
   490  			cancel()
   491  
   492  			c.RLock()
   493  			if c.runtimes["entry"].running != false { // nolint
   494  				t.Errorf("entry is still marked as running")
   495  			}
   496  			c.RUnlock()
   497  		})
   498  	}
   499  }