github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/utils/cri/cri_client_setup_linux_test.go (about) 1 // +build linux 2 3 package cri 4 5 import ( 6 "context" 7 "net" 8 "os" 9 "path/filepath" 10 "strings" 11 "testing" 12 "time" 13 14 "google.golang.org/grpc" 15 ) 16 17 func Test_DetectCRIRuntimeEndpoint(t *testing.T) { 18 wd, err := os.Getwd() 19 if err != nil { 20 panic(err) 21 } 22 path := filepath.Join(wd, "testdata", "var", "run", "crio", "crio.sock") 23 24 if err := os.RemoveAll(path); err != nil { 25 panic(err) 26 } 27 if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil { 28 panic(err) 29 } 30 l, err := net.Listen("unix", path) 31 if err != nil { 32 panic(err) 33 } 34 defer l.Close() // nolint 35 36 oldGetHostPath := getHostPath 37 defer func() { 38 getHostPath = oldGetHostPath 39 }() 40 tests := []struct { 41 name string 42 getHostPath func(string) string 43 want string 44 runType Type 45 wantErr bool 46 }{ 47 { 48 name: "failed to detect a runtime", 49 getHostPath: func(path string) string { 50 return filepath.Join(wd, "does-not-exist", path) 51 }, 52 want: "", 53 runType: TypeNone, 54 wantErr: true, 55 }, 56 { 57 name: "detected a runtime", 58 getHostPath: func(path string) string { 59 return filepath.Join(wd, "testdata", path) 60 }, 61 want: "unix://" + filepath.Join(wd, "testdata", "var", "run", "crio", "crio.sock"), 62 runType: TypeCRIO, 63 wantErr: false, 64 }, 65 } 66 for _, tt := range tests { 67 t.Run(tt.name, func(t *testing.T) { 68 getHostPath = tt.getHostPath 69 got, rtype, err := DetectCRIRuntimeEndpoint() 70 if (err != nil) != tt.wantErr { 71 t.Errorf("DetectCRIRuntimeEndpoint() error = %v, wantErr %v", err, tt.wantErr) 72 return 73 } 74 if got != tt.want { 75 t.Errorf("DetectCRIRuntimeEndpoint() = %v, want %v", got, tt.want) 76 } 77 if rtype != tt.runType { 78 t.Errorf("DetectCRIRuntimeEndpoint() = %v, want %v", rtype, tt.runType) 79 } 80 }) 81 } 82 } 83 84 func Test_getCRISocketAddr(t *testing.T) { 85 wd, err := os.Getwd() 86 if err != nil { 87 panic(err) 88 } 89 90 path := filepath.Join(wd, "testdata", "var", "run", "crio", "crio.sock") 91 92 if err := os.RemoveAll(path); err != nil { 93 panic(err) 94 } 95 if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil { 96 panic(err) 97 } 98 l, err := net.Listen("unix", path) 99 if err != nil { 100 panic(err) 101 } 102 defer l.Close() // nolint 103 104 oldGetHostPath := getHostPath 105 defer func() { 106 getHostPath = oldGetHostPath 107 }() 108 type args struct { 109 criRuntimeEndpoint string 110 } 111 tests := []struct { 112 name string 113 getHostPath func(string) string 114 args args 115 want string 116 wantErr bool 117 }{ 118 { 119 name: "auto-detected runtime should return without any error if it succeeds", 120 args: args{ 121 criRuntimeEndpoint: "", // empty string enables auto-detection 122 }, 123 getHostPath: func(path string) string { 124 return filepath.Join(wd, "testdata", path) 125 }, 126 want: filepath.Join(wd, "testdata", "var", "run", "crio", "crio.sock"), 127 wantErr: false, 128 }, 129 { 130 name: "if auto-detection is enabled and fails, we must fail", 131 args: args{ 132 criRuntimeEndpoint: "", // empty string enables auto-detection 133 }, 134 getHostPath: func(path string) string { 135 return filepath.Join(wd, "does-not-exist", path) 136 }, 137 want: "", 138 wantErr: true, 139 }, 140 { 141 name: "we fail on tcp endpoints", 142 args: args{ 143 criRuntimeEndpoint: "tcp://127.0.0.1:1234", 144 }, 145 want: "", 146 wantErr: true, 147 }, 148 { 149 name: "correct file paths to a unix socket should work", 150 args: args{ 151 criRuntimeEndpoint: filepath.Join(wd, "testdata", "var", "run", "crio", "crio.sock"), 152 }, 153 want: filepath.Join(wd, "testdata", "var", "run", "crio", "crio.sock"), 154 wantErr: false, 155 }, 156 { 157 name: "frakti is not supported", 158 args: args{ 159 criRuntimeEndpoint: "/var/run/frakti.sock", 160 }, 161 want: "", 162 wantErr: true, 163 }, 164 { 165 name: "frakti is not supported", 166 args: args{ 167 criRuntimeEndpoint: "/var/run/frakti.sock", 168 }, 169 want: "", 170 wantErr: true, 171 }, 172 { 173 name: "URL parsing of endpoint fails", 174 args: args{ 175 criRuntimeEndpoint: string([]byte{0x7f}), 176 }, 177 want: "", 178 wantErr: true, 179 }, 180 } 181 for _, tt := range tests { 182 t.Run(tt.name, func(t *testing.T) { 183 getHostPath = tt.getHostPath 184 got, err := getCRISocketAddr(tt.args.criRuntimeEndpoint) 185 if (err != nil) != tt.wantErr { 186 t.Errorf("getCRISocketAddr() error = %v, wantErr %v", err, tt.wantErr) 187 return 188 } 189 if got != tt.want { 190 t.Errorf("getCRISocketAddr() = %v, want %v", got, tt.want) 191 } 192 }) 193 } 194 } 195 196 func Test_connectCRISocket(t *testing.T) { 197 oldConnectTimeout := connectTimeout 198 defer func() { 199 connectTimeout = oldConnectTimeout 200 }() 201 type args struct { 202 ctx context.Context 203 addr string 204 } 205 tests := []struct { 206 name string 207 args args 208 connectTimeout time.Duration 209 runServer bool 210 wantErr bool 211 }{ 212 { 213 name: "no timeout produces a canceled context which must always error", 214 args: args{ 215 ctx: context.Background(), 216 addr: "", 217 }, 218 connectTimeout: 0, 219 wantErr: true, 220 }, 221 { 222 name: "successful connection to a unix server listening", 223 args: args{ 224 ctx: context.Background(), 225 addr: "@aporeto_cri_grpc_connect_test", 226 }, 227 runServer: true, 228 connectTimeout: time.Second * 10, 229 wantErr: false, 230 }, 231 } 232 for _, tt := range tests { 233 t.Run(tt.name, func(t *testing.T) { 234 connectTimeout = tt.connectTimeout 235 ctx, cancel := context.WithCancel(tt.args.ctx) 236 defer cancel() 237 if tt.runServer { 238 s := grpc.NewServer() 239 defer s.Stop() 240 go func() { 241 l, err := (&net.ListenConfig{}).Listen(ctx, "unix", tt.args.addr) 242 if err != nil { 243 panic(err) 244 } 245 s.Serve(l) // nolint: errcheck 246 }() 247 } 248 _, err := connectCRISocket(ctx, tt.args.addr) 249 if (err != nil) != tt.wantErr { 250 t.Errorf("connectCRISocket() error = %v, wantErr %v", err, tt.wantErr) 251 return 252 } 253 }) 254 } 255 } 256 257 func TestNewCRIRuntimeServiceClient(t *testing.T) { 258 oldConnectTimeout := connectTimeout 259 oldCallTimeout := callTimeout 260 defer func() { 261 connectTimeout = oldConnectTimeout 262 callTimeout = oldCallTimeout 263 }() 264 type args struct { 265 ctx context.Context 266 criRuntimeEndpoint string 267 } 268 tests := []struct { 269 name string 270 args args 271 connectTimeout time.Duration 272 callTimeout time.Duration 273 runServer bool 274 wantErr bool 275 }{ 276 { 277 name: "fails on getting socket path", 278 args: args{ 279 ctx: context.Background(), 280 criRuntimeEndpoint: string([]byte{0x7f}), 281 }, 282 runServer: false, 283 wantErr: true, 284 }, 285 { 286 name: "success", 287 args: args{ 288 ctx: context.Background(), 289 criRuntimeEndpoint: "unix:@aporeto_cri_grpc_connect_test1", 290 }, 291 connectTimeout: time.Second * 10, 292 callTimeout: time.Second * 5, 293 runServer: true, 294 wantErr: false, 295 }, 296 { 297 name: "fails creating the ExtendedRuntimeService", 298 args: args{ 299 ctx: context.Background(), 300 criRuntimeEndpoint: "unix:@aporeto_cri_grpc_connect_test2", 301 }, 302 connectTimeout: time.Second * 10, 303 callTimeout: 0, // call timeout must not be 0 304 runServer: true, 305 wantErr: true, 306 }, 307 { 308 name: "fails connecting to the grpc socket", 309 args: args{ 310 ctx: context.Background(), 311 criRuntimeEndpoint: "unix:@aporeto_cri_grpc_connect_test3", 312 }, 313 connectTimeout: 0, 314 runServer: true, 315 wantErr: true, 316 }, 317 } 318 for _, tt := range tests { 319 t.Run(tt.name, func(t *testing.T) { 320 connectTimeout = tt.connectTimeout 321 callTimeout = tt.callTimeout 322 ctx, cancel := context.WithCancel(tt.args.ctx) 323 defer cancel() 324 if tt.runServer { 325 s := grpc.NewServer() 326 defer s.Stop() 327 go func() { 328 l, err := (&net.ListenConfig{}).Listen(ctx, "unix", strings.TrimPrefix(strings.TrimPrefix(tt.args.criRuntimeEndpoint, "unix:"), "//")) 329 if err != nil { 330 panic(err) 331 } 332 s.Serve(l) // nolint: errcheck 333 }() 334 } 335 _, err := NewCRIRuntimeServiceClient(ctx, tt.args.criRuntimeEndpoint) 336 if (err != nil) != tt.wantErr { 337 t.Errorf("NewCRIRuntimeServiceClient() error = %v, wantErr %v", err, tt.wantErr) 338 return 339 } 340 }) 341 } 342 }