github.com/kubeshop/testkube@v1.17.23/pkg/agent/agent.go

package agent

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"math"
	"os"
	"time"

	"google.golang.org/grpc/keepalive"

	"github.com/kubeshop/testkube/pkg/executor/output"
	"github.com/kubeshop/testkube/pkg/version"

	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/encoding/gzip"

	"github.com/pkg/errors"
	"github.com/valyala/fasthttp"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"

	"github.com/kubeshop/testkube/internal/config"
	"github.com/kubeshop/testkube/pkg/api/v1/testkube"
	"github.com/kubeshop/testkube/pkg/cloud"
	"github.com/kubeshop/testkube/pkg/featureflags"
)

const (
	timeout            = 10 * time.Second
	apiKeyMeta         = "api-key"
	clusterIDMeta      = "cluster-id"
	cloudMigrateMeta   = "migrate"
	orgIdMeta          = "organization-id"
	envIdMeta          = "environment-id"
	healthcheckCommand = "healthcheck"
)

// bufferSizePerWorker is the number of messages buffered per worker.
const bufferSizePerWorker = 5

func NewGRPCConnection(
	ctx context.Context,
	isInsecure bool,
	skipVerify bool,
	server string,
	certFile, keyFile, caFile string,
	logger *zap.SugaredLogger,
) (*grpc.ClientConn, error) {
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
	if skipVerify {
		tlsConfig = &tls.Config{InsecureSkipVerify: true}
	} else {
		if certFile != "" && keyFile != "" {
			if err := clientCert(tlsConfig, certFile, keyFile); err != nil {
				return nil, err
			}
		}
		if caFile != "" {
			if err := rootCAs(tlsConfig, caFile); err != nil {
				return nil, err
			}
		}
	}

	creds := credentials.NewTLS(tlsConfig)
	if isInsecure {
		creds = insecure.NewCredentials()
	}

	kacp := keepalive.ClientParameters{
		Time:                10 * time.Second,
		Timeout:             5 * time.Second,
		PermitWithoutStream: true,
	}

	userAgent := version.Version + "/" + version.Commit
	logger.Infow("initiating connection with agent api", "userAgent", userAgent, "server", server, "insecure", isInsecure, "skipVerify", skipVerify, "certFile", certFile, "keyFile", keyFile, "caFile", caFile)
	// The gRPC Go docs recommend against using WithBlock, WithReturnConnectionError and FailOnNonTempDialError,
	// but since the Agent cannot operate without a gRPC connection, it is acceptable to use them and surface issues at dial time.
	return grpc.DialContext(
		ctx,
		server,
		grpc.WithBlock(),
		grpc.WithReturnConnectionError(),
		grpc.FailOnNonTempDialError(true),
		grpc.WithUserAgent(userAgent),
		grpc.WithTransportCredentials(creds),
		grpc.WithKeepaliveParams(kacp),
	)
}

func rootCAs(tlsConfig *tls.Config, file ...string) error {
	pool := x509.NewCertPool()
	for _, f := range file {
		rootPEM, err := os.ReadFile(f)
		if err != nil || rootPEM == nil {
			return fmt.Errorf("agent: error loading or parsing rootCA file: %v", err)
		}
		ok := pool.AppendCertsFromPEM(rootPEM)
		if !ok {
			return fmt.Errorf("agent: failed to parse root certificate from %q", f)
		}
	}
	tlsConfig.RootCAs = pool
	return nil
}

func clientCert(tlsConfig *tls.Config, certFile, keyFile string) error {
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return fmt.Errorf("agent: error loading client certificate: %v", err)
	}
	cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
	if err != nil {
		return fmt.Errorf("agent: error parsing client certificate: %v", err)
	}
	tlsConfig.Certificates = []tls.Certificate{cert}
	return nil
}
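// Illustrative sketch, not part of the original file: one way a caller might use
// NewGRPCConnection to dial the Control Plane. The function name dialControlPlane
// and the server address are hypothetical.
func dialControlPlane(ctx context.Context, logger *zap.SugaredLogger) (*grpc.ClientConn, error) {
	// Dial with TLS verification enabled and no client certificate or custom CA bundle.
	return NewGRPCConnection(ctx, false, false, "agent.testkube.example:443", "", "", "", logger)
}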
type Agent struct {
	client  cloud.TestKubeCloudAPIClient
	handler fasthttp.RequestHandler
	logger  *zap.SugaredLogger
	apiKey  string

	workerCount    int
	requestBuffer  chan *cloud.ExecuteRequest
	responseBuffer chan *cloud.ExecuteResponse

	logStreamWorkerCount    int
	logStreamRequestBuffer  chan *cloud.LogsStreamRequest
	logStreamResponseBuffer chan *cloud.LogsStreamResponse
	logStreamFunc           func(ctx context.Context, executionID string) (chan output.Output, error)

	testWorkflowNotificationsWorkerCount    int
	testWorkflowNotificationsRequestBuffer  chan *cloud.TestWorkflowNotificationsRequest
	testWorkflowNotificationsResponseBuffer chan *cloud.TestWorkflowNotificationsResponse
	testWorkflowNotificationsFunc           func(ctx context.Context, executionID string) (chan testkube.TestWorkflowExecutionNotification, error)

	events              chan testkube.Event
	sendTimeout         time.Duration
	receiveTimeout      time.Duration
	healthcheckInterval time.Duration

	clusterID   string
	clusterName string
	envs        map[string]string
	features    featureflags.FeatureFlags

	proContext config.ProContext
}

func NewAgent(logger *zap.SugaredLogger,
	handler fasthttp.RequestHandler,
	client cloud.TestKubeCloudAPIClient,
	logStreamFunc func(ctx context.Context, executionID string) (chan output.Output, error),
	workflowNotificationsFunc func(ctx context.Context, executionID string) (chan testkube.TestWorkflowExecutionNotification, error),
	clusterID string,
	clusterName string,
	envs map[string]string,
	features featureflags.FeatureFlags,
	proContext config.ProContext,
) (*Agent, error) {
	return &Agent{
		handler:              handler,
		logger:               logger.With("service", "Agent", "environmentId", proContext.EnvID),
		apiKey:               proContext.APIKey,
		client:               client,
		events:               make(chan testkube.Event),
		workerCount:          proContext.WorkerCount,
		requestBuffer:        make(chan *cloud.ExecuteRequest, bufferSizePerWorker*proContext.WorkerCount),
		responseBuffer:       make(chan *cloud.ExecuteResponse, bufferSizePerWorker*proContext.WorkerCount),
		receiveTimeout:       5 * time.Minute,
		sendTimeout:          30 * time.Second,
		healthcheckInterval:  30 * time.Second,
		logStreamWorkerCount: proContext.LogStreamWorkerCount,
		logStreamRequestBuffer:  make(chan *cloud.LogsStreamRequest, bufferSizePerWorker*proContext.LogStreamWorkerCount),
		logStreamResponseBuffer: make(chan *cloud.LogsStreamResponse, bufferSizePerWorker*proContext.LogStreamWorkerCount),
		logStreamFunc:           logStreamFunc,
		testWorkflowNotificationsWorkerCount:    proContext.WorkflowNotificationsWorkerCount,
		testWorkflowNotificationsRequestBuffer:  make(chan *cloud.TestWorkflowNotificationsRequest, bufferSizePerWorker*proContext.WorkflowNotificationsWorkerCount),
		testWorkflowNotificationsResponseBuffer: make(chan *cloud.TestWorkflowNotificationsResponse, bufferSizePerWorker*proContext.WorkflowNotificationsWorkerCount),
		testWorkflowNotificationsFunc:           workflowNotificationsFunc,
		clusterID:   clusterID,
		clusterName: clusterName,
		envs:        envs,
		features:    features,
		proContext:  proContext,
	}, nil
}
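// Illustrative sketch, not part of the original file: how a caller might assemble an
// Agent and start its reconnect loop. The function name runAgent and the literal
// cluster values are hypothetical; in practice they come from the Agent's configuration
// and API server wiring, and buffer sizes are derived as bufferSizePerWorker*WorkerCount
// (e.g. 10 workers buffer up to 50 requests and 50 responses).
func runAgent(
	ctx context.Context,
	logger *zap.SugaredLogger,
	handler fasthttp.RequestHandler,
	client cloud.TestKubeCloudAPIClient,
	logStreamFunc func(ctx context.Context, executionID string) (chan output.Output, error),
	notificationsFunc func(ctx context.Context, executionID string) (chan testkube.TestWorkflowExecutionNotification, error),
	proContext config.ProContext,
) error {
	ag, err := NewAgent(logger, handler, client, logStreamFunc, notificationsFunc,
		"cluster-id", "cluster-name", map[string]string{}, featureflags.FeatureFlags{}, proContext)
	if err != nil {
		return err
	}
	// Run blocks, reconnecting after failures, until ctx is cancelled.
	return ag.Run(ctx)
}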
func (ag *Agent) Run(ctx context.Context) error {
	for {
		if ctx.Err() != nil {
			return ctx.Err()
		}
		err := ag.run(ctx)

		ag.logger.Errorw("agent connection failed, reconnecting", "error", err)

		// TODO: some smart back off strategy?
		time.Sleep(5 * time.Second)
	}
}

func (ag *Agent) run(ctx context.Context) (err error) {
	g, groupCtx := errgroup.WithContext(ctx)
	g.Go(func() error {
		return ag.runCommandLoop(groupCtx)
	})

	g.Go(func() error {
		return ag.runWorkers(groupCtx, ag.workerCount)
	})

	g.Go(func() error {
		return ag.runEventLoop(groupCtx)
	})

	if !ag.features.LogsV2 {
		g.Go(func() error {
			return ag.runLogStreamLoop(groupCtx)
		})
		g.Go(func() error {
			return ag.runLogStreamWorker(groupCtx, ag.logStreamWorkerCount)
		})
	}

	g.Go(func() error {
		return ag.runTestWorkflowNotificationsLoop(groupCtx)
	})
	g.Go(func() error {
		return ag.runTestWorkflowNotificationsWorker(groupCtx, ag.testWorkflowNotificationsWorkerCount)
	})

	err = g.Wait()

	return err
}

func (ag *Agent) sendResponse(ctx context.Context, stream cloud.TestKubeCloudAPI_ExecuteClient, resp *cloud.ExecuteResponse) error {
	errChan := make(chan error, 1)
	go func() {
		errChan <- stream.Send(resp)
		close(errChan)
	}()

	t := time.NewTimer(ag.sendTimeout)
	select {
	case err := <-errChan:
		if !t.Stop() {
			<-t.C
		}
		return err
	case <-ctx.Done():
		if !t.Stop() {
			<-t.C
		}

		return ctx.Err()
	case <-t.C:
		return errors.New("send response too slow")
	}
}

func (ag *Agent) receiveCommand(ctx context.Context, stream cloud.TestKubeCloudAPI_ExecuteClient) (*cloud.ExecuteRequest, error) {
	respChan := make(chan cloudResponse, 1)
	go func() {
		cmd, err := stream.Recv()
		respChan <- cloudResponse{resp: cmd, err: err}
	}()

	t := time.NewTimer(ag.receiveTimeout)
	var cmd *cloud.ExecuteRequest
	select {
	case resp := <-respChan:
		if !t.Stop() {
			<-t.C
		}

		cmd = resp.resp
		err := resp.err

		if err != nil {
			ag.logger.Errorf("agent stream receive: %v", err)
			return nil, err
		}
	case <-ctx.Done():
		if !t.Stop() {
			<-t.C
		}

		return nil, ctx.Err()
	case <-t.C:
		return nil, errors.New("stream receive too slow")
	}

	return cmd, nil
}
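// Illustrative sketch, not part of the original file: the timer-guarded pattern used by
// sendResponse and receiveCommand above, written out for an arbitrary blocking call.
// The helper name callWithTimeout and its parameters are hypothetical. The error channel
// is buffered (size 1) so the spawned goroutine can always complete its send and exit,
// even after a timeout or cancellation has already been returned to the caller.
func callWithTimeout(ctx context.Context, d time.Duration, call func() error) error {
	errChan := make(chan error, 1)
	go func() { errChan <- call() }()

	t := time.NewTimer(d)
	defer t.Stop()
	select {
	case err := <-errChan:
		return err
	case <-ctx.Done():
		return ctx.Err()
	case <-t.C:
		return errors.New("operation too slow")
	}
}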
func (ag *Agent) runCommandLoop(ctx context.Context) error {
	ctx = AddAPIKeyMeta(ctx, ag.proContext.APIKey)

	ctx = metadata.AppendToOutgoingContext(ctx, clusterIDMeta, ag.clusterID)
	ctx = metadata.AppendToOutgoingContext(ctx, cloudMigrateMeta, ag.proContext.Migrate)
	ctx = metadata.AppendToOutgoingContext(ctx, envIdMeta, ag.proContext.EnvID)
	ctx = metadata.AppendToOutgoingContext(ctx, orgIdMeta, ag.proContext.OrgID)

	ag.logger.Infow("initiating streaming connection with Pro API")
	// Create a new stream from the client side; ctx is used for the lifetime of the stream.
	opts := []grpc.CallOption{grpc.UseCompressor(gzip.Name), grpc.MaxCallRecvMsgSize(math.MaxInt32)}
	stream, err := ag.client.ExecuteAsync(ctx, opts...)
	if err != nil {
		ag.logger.Errorf("failed to execute: %v", err)
		return errors.Wrap(err, "failed to setup stream")
	}

	// gRPC streams have special requirements for concurrent SendMsg and RecvMsg calls.
	// Please check https://github.com/grpc/grpc-go/blob/master/Documentation/concurrency.md
	g, groupCtx := errgroup.WithContext(ctx)
	g.Go(func() error {
		for {
			cmd, err := ag.receiveCommand(groupCtx, stream)
			if err != nil {
				return err
			}

			ag.requestBuffer <- cmd
		}
	})

	g.Go(func() error {
		for {
			select {
			case resp := <-ag.responseBuffer:
				err := ag.sendResponse(groupCtx, stream, resp)
				if err != nil {
					return err
				}
			case <-groupCtx.Done():
				return groupCtx.Err()
			}
		}
	})

	err = g.Wait()

	return err
}

func (ag *Agent) runWorkers(ctx context.Context, numWorkers int) error {
	g, groupCtx := errgroup.WithContext(ctx)
	for i := 0; i < numWorkers; i++ {
		g.Go(func() error {
			for {
				select {
				case cmd := <-ag.requestBuffer:
					select {
					case ag.responseBuffer <- ag.executeCommand(groupCtx, cmd):
					case <-groupCtx.Done():
						return groupCtx.Err()
					}
				case <-groupCtx.Done():
					return groupCtx.Err()
				}
			}
		})
	}
	return g.Wait()
}

func (ag *Agent) executeCommand(ctx context.Context, cmd *cloud.ExecuteRequest) *cloud.ExecuteResponse {
	switch {
	case cmd.Url == healthcheckCommand:
		return &cloud.ExecuteResponse{MessageId: cmd.MessageId, Status: 0}
	default:
		req := &fasthttp.RequestCtx{}
		r := fasthttp.AcquireRequest()
		r.Header.SetHost("localhost")
		r.Header.SetMethod(cmd.Method)

		for k, values := range cmd.Headers {
			for _, value := range values.Header {
				r.Header.Add(k, value)
			}
		}
		r.SetBody(cmd.Body)
		uri := &fasthttp.URI{}

		err := uri.Parse(nil, []byte(cmd.Url))
		if err != nil {
			ag.logger.Errorf("agent bad command url: %v", err)
			resp := &cloud.ExecuteResponse{MessageId: cmd.MessageId, Status: 400, Body: []byte(fmt.Sprintf("bad command url: %s", err))}
			return resp
		}
		r.SetURI(uri)

		req.Init(r, nil, nil)
		ag.handler(req)

		fasthttp.ReleaseRequest(r)

		headers := make(map[string]*cloud.HeaderValue)
		req.Response.Header.VisitAll(func(key, value []byte) {
			_, ok := headers[string(key)]
			if !ok {
				headers[string(key)] = &cloud.HeaderValue{Header: []string{string(value)}}
				return
			}

			headers[string(key)].Header = append(headers[string(key)].Header, string(value))
		})

		resp := &cloud.ExecuteResponse{MessageId: cmd.MessageId, Headers: headers, Status: int64(req.Response.StatusCode()), Body: req.Response.Body()}

		return resp
	}
}

func AddAPIKeyMeta(ctx context.Context, apiKey string) context.Context {
	md := metadata.Pairs(apiKeyMeta, apiKey)
	return metadata.NewOutgoingContext(ctx, md)
}

type cloudResponse struct {
	resp *cloud.ExecuteRequest
	err  error
}
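// Illustrative sketch, not part of the original file: executeCommand answers a request
// whose Url equals healthcheckCommand directly, without invoking the HTTP handler.
// The function name exampleHealthcheck and the message ID are hypothetical.
func exampleHealthcheck(ctx context.Context, ag *Agent) bool {
	resp := ag.executeCommand(ctx, &cloud.ExecuteRequest{MessageId: "healthcheck-1", Url: healthcheckCommand})
	// Status 0 is the value executeCommand returns for a healthcheck request.
	return resp.Status == 0
}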