github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/utils/cri/cri_client_setup_linux_test.go (about)

     1  // +build linux
     2  
     3  package cri
     4  
     5  import (
     6  	"context"
     7  	"net"
     8  	"os"
     9  	"path/filepath"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	"google.golang.org/grpc"
    15  )
    16  
    17  func Test_DetectCRIRuntimeEndpoint(t *testing.T) {
    18  	wd, err := os.Getwd()
    19  	if err != nil {
    20  		panic(err)
    21  	}
    22  	path := filepath.Join(wd, "testdata", "var", "run", "crio", "crio.sock")
    23  
    24  	if err := os.RemoveAll(path); err != nil {
    25  		panic(err)
    26  	}
    27  	if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil {
    28  		panic(err)
    29  	}
    30  	l, err := net.Listen("unix", path)
    31  	if err != nil {
    32  		panic(err)
    33  	}
    34  	defer l.Close() // nolint
    35  
    36  	oldGetHostPath := getHostPath
    37  	defer func() {
    38  		getHostPath = oldGetHostPath
    39  	}()
    40  	tests := []struct {
    41  		name        string
    42  		getHostPath func(string) string
    43  		want        string
    44  		runType     Type
    45  		wantErr     bool
    46  	}{
    47  		{
    48  			name: "failed to detect a runtime",
    49  			getHostPath: func(path string) string {
    50  				return filepath.Join(wd, "does-not-exist", path)
    51  			},
    52  			want:    "",
    53  			runType: TypeNone,
    54  			wantErr: true,
    55  		},
    56  		{
    57  			name: "detected a runtime",
    58  			getHostPath: func(path string) string {
    59  				return filepath.Join(wd, "testdata", path)
    60  			},
    61  			want:    "unix://" + filepath.Join(wd, "testdata", "var", "run", "crio", "crio.sock"),
    62  			runType: TypeCRIO,
    63  			wantErr: false,
    64  		},
    65  	}
    66  	for _, tt := range tests {
    67  		t.Run(tt.name, func(t *testing.T) {
    68  			getHostPath = tt.getHostPath
    69  			got, rtype, err := DetectCRIRuntimeEndpoint()
    70  			if (err != nil) != tt.wantErr {
    71  				t.Errorf("DetectCRIRuntimeEndpoint() error = %v, wantErr %v", err, tt.wantErr)
    72  				return
    73  			}
    74  			if got != tt.want {
    75  				t.Errorf("DetectCRIRuntimeEndpoint() = %v, want %v", got, tt.want)
    76  			}
    77  			if rtype != tt.runType {
    78  				t.Errorf("DetectCRIRuntimeEndpoint() = %v, want %v", rtype, tt.runType)
    79  			}
    80  		})
    81  	}
    82  }
    83  
    84  func Test_getCRISocketAddr(t *testing.T) {
    85  	wd, err := os.Getwd()
    86  	if err != nil {
    87  		panic(err)
    88  	}
    89  
    90  	path := filepath.Join(wd, "testdata", "var", "run", "crio", "crio.sock")
    91  
    92  	if err := os.RemoveAll(path); err != nil {
    93  		panic(err)
    94  	}
    95  	if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil {
    96  		panic(err)
    97  	}
    98  	l, err := net.Listen("unix", path)
    99  	if err != nil {
   100  		panic(err)
   101  	}
   102  	defer l.Close() // nolint
   103  
   104  	oldGetHostPath := getHostPath
   105  	defer func() {
   106  		getHostPath = oldGetHostPath
   107  	}()
   108  	type args struct {
   109  		criRuntimeEndpoint string
   110  	}
   111  	tests := []struct {
   112  		name        string
   113  		getHostPath func(string) string
   114  		args        args
   115  		want        string
   116  		wantErr     bool
   117  	}{
   118  		{
   119  			name: "auto-detected runtime should return without any error if it succeeds",
   120  			args: args{
   121  				criRuntimeEndpoint: "", // empty string enables auto-detection
   122  			},
   123  			getHostPath: func(path string) string {
   124  				return filepath.Join(wd, "testdata", path)
   125  			},
   126  			want:    filepath.Join(wd, "testdata", "var", "run", "crio", "crio.sock"),
   127  			wantErr: false,
   128  		},
   129  		{
   130  			name: "if auto-detection is enabled and fails, we must fail",
   131  			args: args{
   132  				criRuntimeEndpoint: "", // empty string enables auto-detection
   133  			},
   134  			getHostPath: func(path string) string {
   135  				return filepath.Join(wd, "does-not-exist", path)
   136  			},
   137  			want:    "",
   138  			wantErr: true,
   139  		},
   140  		{
   141  			name: "we fail on tcp endpoints",
   142  			args: args{
   143  				criRuntimeEndpoint: "tcp://127.0.0.1:1234",
   144  			},
   145  			want:    "",
   146  			wantErr: true,
   147  		},
   148  		{
   149  			name: "correct file paths to a unix socket should work",
   150  			args: args{
   151  				criRuntimeEndpoint: filepath.Join(wd, "testdata", "var", "run", "crio", "crio.sock"),
   152  			},
   153  			want:    filepath.Join(wd, "testdata", "var", "run", "crio", "crio.sock"),
   154  			wantErr: false,
   155  		},
   156  		{
   157  			name: "frakti is not supported",
   158  			args: args{
   159  				criRuntimeEndpoint: "/var/run/frakti.sock",
   160  			},
   161  			want:    "",
   162  			wantErr: true,
   163  		},
   164  		{
   165  			name: "frakti is not supported",
   166  			args: args{
   167  				criRuntimeEndpoint: "/var/run/frakti.sock",
   168  			},
   169  			want:    "",
   170  			wantErr: true,
   171  		},
   172  		{
   173  			name: "URL parsing of endpoint fails",
   174  			args: args{
   175  				criRuntimeEndpoint: string([]byte{0x7f}),
   176  			},
   177  			want:    "",
   178  			wantErr: true,
   179  		},
   180  	}
   181  	for _, tt := range tests {
   182  		t.Run(tt.name, func(t *testing.T) {
   183  			getHostPath = tt.getHostPath
   184  			got, err := getCRISocketAddr(tt.args.criRuntimeEndpoint)
   185  			if (err != nil) != tt.wantErr {
   186  				t.Errorf("getCRISocketAddr() error = %v, wantErr %v", err, tt.wantErr)
   187  				return
   188  			}
   189  			if got != tt.want {
   190  				t.Errorf("getCRISocketAddr() = %v, want %v", got, tt.want)
   191  			}
   192  		})
   193  	}
   194  }
   195  
   196  func Test_connectCRISocket(t *testing.T) {
   197  	oldConnectTimeout := connectTimeout
   198  	defer func() {
   199  		connectTimeout = oldConnectTimeout
   200  	}()
   201  	type args struct {
   202  		ctx  context.Context
   203  		addr string
   204  	}
   205  	tests := []struct {
   206  		name           string
   207  		args           args
   208  		connectTimeout time.Duration
   209  		runServer      bool
   210  		wantErr        bool
   211  	}{
   212  		{
   213  			name: "no timeout produces a canceled context which must always error",
   214  			args: args{
   215  				ctx:  context.Background(),
   216  				addr: "",
   217  			},
   218  			connectTimeout: 0,
   219  			wantErr:        true,
   220  		},
   221  		{
   222  			name: "successful connection to a unix server listening",
   223  			args: args{
   224  				ctx:  context.Background(),
   225  				addr: "@aporeto_cri_grpc_connect_test",
   226  			},
   227  			runServer:      true,
   228  			connectTimeout: time.Second * 10,
   229  			wantErr:        false,
   230  		},
   231  	}
   232  	for _, tt := range tests {
   233  		t.Run(tt.name, func(t *testing.T) {
   234  			connectTimeout = tt.connectTimeout
   235  			ctx, cancel := context.WithCancel(tt.args.ctx)
   236  			defer cancel()
   237  			if tt.runServer {
   238  				s := grpc.NewServer()
   239  				defer s.Stop()
   240  				go func() {
   241  					l, err := (&net.ListenConfig{}).Listen(ctx, "unix", tt.args.addr)
   242  					if err != nil {
   243  						panic(err)
   244  					}
   245  					s.Serve(l) // nolint: errcheck
   246  				}()
   247  			}
   248  			_, err := connectCRISocket(ctx, tt.args.addr)
   249  			if (err != nil) != tt.wantErr {
   250  				t.Errorf("connectCRISocket() error = %v, wantErr %v", err, tt.wantErr)
   251  				return
   252  			}
   253  		})
   254  	}
   255  }
   256  
   257  func TestNewCRIRuntimeServiceClient(t *testing.T) {
   258  	oldConnectTimeout := connectTimeout
   259  	oldCallTimeout := callTimeout
   260  	defer func() {
   261  		connectTimeout = oldConnectTimeout
   262  		callTimeout = oldCallTimeout
   263  	}()
   264  	type args struct {
   265  		ctx                context.Context
   266  		criRuntimeEndpoint string
   267  	}
   268  	tests := []struct {
   269  		name           string
   270  		args           args
   271  		connectTimeout time.Duration
   272  		callTimeout    time.Duration
   273  		runServer      bool
   274  		wantErr        bool
   275  	}{
   276  		{
   277  			name: "fails on getting socket path",
   278  			args: args{
   279  				ctx:                context.Background(),
   280  				criRuntimeEndpoint: string([]byte{0x7f}),
   281  			},
   282  			runServer: false,
   283  			wantErr:   true,
   284  		},
   285  		{
   286  			name: "success",
   287  			args: args{
   288  				ctx:                context.Background(),
   289  				criRuntimeEndpoint: "unix:@aporeto_cri_grpc_connect_test1",
   290  			},
   291  			connectTimeout: time.Second * 10,
   292  			callTimeout:    time.Second * 5,
   293  			runServer:      true,
   294  			wantErr:        false,
   295  		},
   296  		{
   297  			name: "fails creating the ExtendedRuntimeService",
   298  			args: args{
   299  				ctx:                context.Background(),
   300  				criRuntimeEndpoint: "unix:@aporeto_cri_grpc_connect_test2",
   301  			},
   302  			connectTimeout: time.Second * 10,
   303  			callTimeout:    0, // call timeout must not be 0
   304  			runServer:      true,
   305  			wantErr:        true,
   306  		},
   307  		{
   308  			name: "fails connecting to the grpc socket",
   309  			args: args{
   310  				ctx:                context.Background(),
   311  				criRuntimeEndpoint: "unix:@aporeto_cri_grpc_connect_test3",
   312  			},
   313  			connectTimeout: 0,
   314  			runServer:      true,
   315  			wantErr:        true,
   316  		},
   317  	}
   318  	for _, tt := range tests {
   319  		t.Run(tt.name, func(t *testing.T) {
   320  			connectTimeout = tt.connectTimeout
   321  			callTimeout = tt.callTimeout
   322  			ctx, cancel := context.WithCancel(tt.args.ctx)
   323  			defer cancel()
   324  			if tt.runServer {
   325  				s := grpc.NewServer()
   326  				defer s.Stop()
   327  				go func() {
   328  					l, err := (&net.ListenConfig{}).Listen(ctx, "unix", strings.TrimPrefix(strings.TrimPrefix(tt.args.criRuntimeEndpoint, "unix:"), "//"))
   329  					if err != nil {
   330  						panic(err)
   331  					}
   332  					s.Serve(l) // nolint: errcheck
   333  				}()
   334  			}
   335  			_, err := NewCRIRuntimeServiceClient(ctx, tt.args.criRuntimeEndpoint)
   336  			if (err != nil) != tt.wantErr {
   337  				t.Errorf("NewCRIRuntimeServiceClient() error = %v, wantErr %v", err, tt.wantErr)
   338  				return
   339  			}
   340  		})
   341  	}
   342  }