github.com/kubeshop/testkube@v1.17.23/pkg/agent/agent.go (about)

     1  package agent
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"fmt"
     8  	"math"
     9  	"os"
    10  	"time"
    11  
    12  	"google.golang.org/grpc/keepalive"
    13  
    14  	"github.com/kubeshop/testkube/pkg/executor/output"
    15  	"github.com/kubeshop/testkube/pkg/version"
    16  
    17  	"google.golang.org/grpc/credentials"
    18  	"google.golang.org/grpc/credentials/insecure"
    19  	"google.golang.org/grpc/encoding/gzip"
    20  
    21  	"github.com/pkg/errors"
    22  	"github.com/valyala/fasthttp"
    23  	"go.uber.org/zap"
    24  	"golang.org/x/sync/errgroup"
    25  	"google.golang.org/grpc"
    26  	"google.golang.org/grpc/metadata"
    27  
    28  	"github.com/kubeshop/testkube/internal/config"
    29  	"github.com/kubeshop/testkube/pkg/api/v1/testkube"
    30  	"github.com/kubeshop/testkube/pkg/cloud"
    31  	"github.com/kubeshop/testkube/pkg/featureflags"
    32  )
    33  
    34  const (
    35  	timeout            = 10 * time.Second
    36  	apiKeyMeta         = "api-key"
    37  	clusterIDMeta      = "cluster-id"
    38  	cloudMigrateMeta   = "migrate"
    39  	orgIdMeta          = "environment-id"
    40  	envIdMeta          = "organization-id"
    41  	healthcheckCommand = "healthcheck"
    42  )
    43  
    44  // buffer up to five messages per worker
    45  const bufferSizePerWorker = 5
    46  
    47  func NewGRPCConnection(
    48  	ctx context.Context,
    49  	isInsecure bool,
    50  	skipVerify bool,
    51  	server string,
    52  	certFile, keyFile, caFile string,
    53  	logger *zap.SugaredLogger,
    54  ) (*grpc.ClientConn, error) {
    55  	ctx, cancel := context.WithTimeout(ctx, timeout)
    56  	defer cancel()
    57  	tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
    58  	if skipVerify {
    59  		tlsConfig = &tls.Config{InsecureSkipVerify: true}
    60  	} else {
    61  		if certFile != "" && keyFile != "" {
    62  			if err := clientCert(tlsConfig, certFile, keyFile); err != nil {
    63  				return nil, err
    64  			}
    65  		}
    66  		if caFile != "" {
    67  			if err := rootCAs(tlsConfig, caFile); err != nil {
    68  				return nil, err
    69  			}
    70  		}
    71  	}
    72  
    73  	creds := credentials.NewTLS(tlsConfig)
    74  	if isInsecure {
    75  		creds = insecure.NewCredentials()
    76  	}
    77  
    78  	kacp := keepalive.ClientParameters{
    79  		Time:                10 * time.Second,
    80  		Timeout:             5 * time.Second,
    81  		PermitWithoutStream: true,
    82  	}
    83  
    84  	userAgent := version.Version + "/" + version.Commit
    85  	logger.Infow("initiating connection with agent api", "userAgent", userAgent, "server", server, "insecure", isInsecure, "skipVerify", skipVerify, "certFile", certFile, "keyFile", keyFile, "caFile", caFile)
    86  	// WithBlock, WithReturnConnectionError and FailOnNonTempDialError are recommended not to be used by gRPC go docs
    87  	// but given that Agent will not work if gRPC connection cannot be established, it is ok to use them and assert issues at dial time
    88  	return grpc.DialContext(
    89  		ctx,
    90  		server,
    91  		grpc.WithBlock(),
    92  		grpc.WithReturnConnectionError(),
    93  		grpc.FailOnNonTempDialError(true),
    94  		grpc.WithUserAgent(userAgent),
    95  		grpc.WithTransportCredentials(creds),
    96  		grpc.WithKeepaliveParams(kacp),
    97  	)
    98  }
    99  
   100  func rootCAs(tlsConfig *tls.Config, file ...string) error {
   101  	pool := x509.NewCertPool()
   102  	for _, f := range file {
   103  		rootPEM, err := os.ReadFile(f)
   104  		if err != nil || rootPEM == nil {
   105  			return fmt.Errorf("agent: error loading or parsing rootCA file: %v", err)
   106  		}
   107  		ok := pool.AppendCertsFromPEM(rootPEM)
   108  		if !ok {
   109  			return fmt.Errorf("agent: failed to parse root certificate from %q", f)
   110  		}
   111  	}
   112  	tlsConfig.RootCAs = pool
   113  	return nil
   114  }
   115  
   116  func clientCert(tlsConfig *tls.Config, certFile, keyFile string) error {
   117  	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
   118  	if err != nil {
   119  		return fmt.Errorf("agent: error loading client certificate: %v", err)
   120  	}
   121  	cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
   122  	if err != nil {
   123  		return fmt.Errorf("agent: error parsing client certificate: %v", err)
   124  	}
   125  	tlsConfig.Certificates = []tls.Certificate{cert}
   126  	return nil
   127  }
   128  
   129  type Agent struct {
   130  	client  cloud.TestKubeCloudAPIClient
   131  	handler fasthttp.RequestHandler
   132  	logger  *zap.SugaredLogger
   133  	apiKey  string
   134  
   135  	workerCount    int
   136  	requestBuffer  chan *cloud.ExecuteRequest
   137  	responseBuffer chan *cloud.ExecuteResponse
   138  
   139  	logStreamWorkerCount    int
   140  	logStreamRequestBuffer  chan *cloud.LogsStreamRequest
   141  	logStreamResponseBuffer chan *cloud.LogsStreamResponse
   142  	logStreamFunc           func(ctx context.Context, executionID string) (chan output.Output, error)
   143  
   144  	testWorkflowNotificationsWorkerCount    int
   145  	testWorkflowNotificationsRequestBuffer  chan *cloud.TestWorkflowNotificationsRequest
   146  	testWorkflowNotificationsResponseBuffer chan *cloud.TestWorkflowNotificationsResponse
   147  	testWorkflowNotificationsFunc           func(ctx context.Context, executionID string) (chan testkube.TestWorkflowExecutionNotification, error)
   148  
   149  	events              chan testkube.Event
   150  	sendTimeout         time.Duration
   151  	receiveTimeout      time.Duration
   152  	healthcheckInterval time.Duration
   153  
   154  	clusterID   string
   155  	clusterName string
   156  	envs        map[string]string
   157  	features    featureflags.FeatureFlags
   158  
   159  	proContext config.ProContext
   160  }
   161  
   162  func NewAgent(logger *zap.SugaredLogger,
   163  	handler fasthttp.RequestHandler,
   164  	client cloud.TestKubeCloudAPIClient,
   165  	logStreamFunc func(ctx context.Context, executionID string) (chan output.Output, error),
   166  	workflowNotificationsFunc func(ctx context.Context, executionID string) (chan testkube.TestWorkflowExecutionNotification, error),
   167  	clusterID string,
   168  	clusterName string,
   169  	envs map[string]string,
   170  	features featureflags.FeatureFlags,
   171  	proContext config.ProContext,
   172  ) (*Agent, error) {
   173  	return &Agent{
   174  		handler:                                 handler,
   175  		logger:                                  logger.With("service", "Agent", "environmentId", proContext.EnvID),
   176  		apiKey:                                  proContext.APIKey,
   177  		client:                                  client,
   178  		events:                                  make(chan testkube.Event),
   179  		workerCount:                             proContext.WorkerCount,
   180  		requestBuffer:                           make(chan *cloud.ExecuteRequest, bufferSizePerWorker*proContext.WorkerCount),
   181  		responseBuffer:                          make(chan *cloud.ExecuteResponse, bufferSizePerWorker*proContext.WorkerCount),
   182  		receiveTimeout:                          5 * time.Minute,
   183  		sendTimeout:                             30 * time.Second,
   184  		healthcheckInterval:                     30 * time.Second,
   185  		logStreamWorkerCount:                    proContext.LogStreamWorkerCount,
   186  		logStreamRequestBuffer:                  make(chan *cloud.LogsStreamRequest, bufferSizePerWorker*proContext.LogStreamWorkerCount),
   187  		logStreamResponseBuffer:                 make(chan *cloud.LogsStreamResponse, bufferSizePerWorker*proContext.LogStreamWorkerCount),
   188  		logStreamFunc:                           logStreamFunc,
   189  		testWorkflowNotificationsWorkerCount:    proContext.WorkflowNotificationsWorkerCount,
   190  		testWorkflowNotificationsRequestBuffer:  make(chan *cloud.TestWorkflowNotificationsRequest, bufferSizePerWorker*proContext.WorkflowNotificationsWorkerCount),
   191  		testWorkflowNotificationsResponseBuffer: make(chan *cloud.TestWorkflowNotificationsResponse, bufferSizePerWorker*proContext.WorkflowNotificationsWorkerCount),
   192  		testWorkflowNotificationsFunc:           workflowNotificationsFunc,
   193  		clusterID:                               clusterID,
   194  		clusterName:                             clusterName,
   195  		envs:                                    envs,
   196  		features:                                features,
   197  		proContext:                              proContext,
   198  	}, nil
   199  }
   200  
   201  func (ag *Agent) Run(ctx context.Context) error {
   202  	for {
   203  		if ctx.Err() != nil {
   204  			return ctx.Err()
   205  		}
   206  		err := ag.run(ctx)
   207  
   208  		ag.logger.Errorw("agent connection failed, reconnecting", "error", err)
   209  
   210  		// TODO: some smart back off strategy?
   211  		time.Sleep(5 * time.Second)
   212  	}
   213  }
   214  
   215  func (ag *Agent) run(ctx context.Context) (err error) {
   216  	g, groupCtx := errgroup.WithContext(ctx)
   217  	g.Go(func() error {
   218  		return ag.runCommandLoop(groupCtx)
   219  	})
   220  
   221  	g.Go(func() error {
   222  		return ag.runWorkers(groupCtx, ag.workerCount)
   223  	})
   224  
   225  	g.Go(func() error {
   226  		return ag.runEventLoop(groupCtx)
   227  	})
   228  
   229  	if !ag.features.LogsV2 {
   230  		g.Go(func() error {
   231  			return ag.runLogStreamLoop(groupCtx)
   232  		})
   233  		g.Go(func() error {
   234  			return ag.runLogStreamWorker(groupCtx, ag.logStreamWorkerCount)
   235  		})
   236  	}
   237  
   238  	g.Go(func() error {
   239  		return ag.runTestWorkflowNotificationsLoop(groupCtx)
   240  	})
   241  	g.Go(func() error {
   242  		return ag.runTestWorkflowNotificationsWorker(groupCtx, ag.testWorkflowNotificationsWorkerCount)
   243  	})
   244  
   245  	err = g.Wait()
   246  
   247  	return err
   248  }
   249  
   250  func (ag *Agent) sendResponse(ctx context.Context, stream cloud.TestKubeCloudAPI_ExecuteClient, resp *cloud.ExecuteResponse) error {
   251  	errChan := make(chan error, 1)
   252  	go func() {
   253  		errChan <- stream.Send(resp)
   254  		close(errChan)
   255  	}()
   256  
   257  	t := time.NewTimer(ag.sendTimeout)
   258  	select {
   259  	case err := <-errChan:
   260  		if !t.Stop() {
   261  			<-t.C
   262  		}
   263  		return err
   264  	case <-ctx.Done():
   265  		if !t.Stop() {
   266  			<-t.C
   267  		}
   268  
   269  		return ctx.Err()
   270  	case <-t.C:
   271  		return errors.New("send response too slow")
   272  	}
   273  }
   274  
   275  func (ag *Agent) receiveCommand(ctx context.Context, stream cloud.TestKubeCloudAPI_ExecuteClient) (*cloud.ExecuteRequest, error) {
   276  	respChan := make(chan cloudResponse, 1)
   277  	go func() {
   278  		cmd, err := stream.Recv()
   279  		respChan <- cloudResponse{resp: cmd, err: err}
   280  	}()
   281  
   282  	t := time.NewTimer(ag.receiveTimeout)
   283  	var cmd *cloud.ExecuteRequest
   284  	select {
   285  	case resp := <-respChan:
   286  		if !t.Stop() {
   287  			<-t.C
   288  		}
   289  
   290  		cmd = resp.resp
   291  		err := resp.err
   292  
   293  		if err != nil {
   294  			ag.logger.Errorf("agent stream receive: %v", err)
   295  			return nil, err
   296  		}
   297  	case <-ctx.Done():
   298  		if !t.Stop() {
   299  			<-t.C
   300  		}
   301  
   302  		return nil, ctx.Err()
   303  	case <-t.C:
   304  		return nil, errors.New("stream receive too slow")
   305  	}
   306  
   307  	return cmd, nil
   308  }
   309  
   310  func (ag *Agent) runCommandLoop(ctx context.Context) error {
   311  	ctx = AddAPIKeyMeta(ctx, ag.proContext.APIKey)
   312  
   313  	ctx = metadata.AppendToOutgoingContext(ctx, clusterIDMeta, ag.clusterID)
   314  	ctx = metadata.AppendToOutgoingContext(ctx, cloudMigrateMeta, ag.proContext.Migrate)
   315  	ctx = metadata.AppendToOutgoingContext(ctx, envIdMeta, ag.proContext.EnvID)
   316  	ctx = metadata.AppendToOutgoingContext(ctx, orgIdMeta, ag.proContext.OrgID)
   317  
   318  	ag.logger.Infow("initiating streaming connection with Pro API")
   319  	// creates a new Stream from the client side. ctx is used for the lifetime of the stream.
   320  	opts := []grpc.CallOption{grpc.UseCompressor(gzip.Name), grpc.MaxCallRecvMsgSize(math.MaxInt32)}
   321  	stream, err := ag.client.ExecuteAsync(ctx, opts...)
   322  	if err != nil {
   323  		ag.logger.Errorf("failed to execute: %w", err)
   324  		return errors.Wrap(err, "failed to setup stream")
   325  	}
   326  
   327  	// GRPC stream have special requirements for concurrency on SendMsg, and RecvMsg calls.
   328  	// Please check https://github.com/grpc/grpc-go/blob/master/Documentation/concurrency.md
   329  	g, groupCtx := errgroup.WithContext(ctx)
   330  	g.Go(func() error {
   331  		for {
   332  			cmd, err := ag.receiveCommand(groupCtx, stream)
   333  			if err != nil {
   334  				return err
   335  			}
   336  
   337  			ag.requestBuffer <- cmd
   338  		}
   339  	})
   340  
   341  	g.Go(func() error {
   342  		for {
   343  			select {
   344  			case resp := <-ag.responseBuffer:
   345  				err := ag.sendResponse(groupCtx, stream, resp)
   346  				if err != nil {
   347  					return err
   348  				}
   349  			case <-groupCtx.Done():
   350  				return groupCtx.Err()
   351  			}
   352  		}
   353  	})
   354  
   355  	err = g.Wait()
   356  
   357  	return err
   358  }
   359  
   360  func (ag *Agent) runWorkers(ctx context.Context, numWorkers int) error {
   361  	g, groupCtx := errgroup.WithContext(ctx)
   362  	for i := 0; i < numWorkers; i++ {
   363  		g.Go(func() error {
   364  			for {
   365  				select {
   366  				case cmd := <-ag.requestBuffer:
   367  					select {
   368  					case ag.responseBuffer <- ag.executeCommand(groupCtx, cmd):
   369  					case <-groupCtx.Done():
   370  						return groupCtx.Err()
   371  					}
   372  				case <-groupCtx.Done():
   373  					return groupCtx.Err()
   374  				}
   375  			}
   376  		})
   377  	}
   378  	return g.Wait()
   379  }
   380  
   381  func (ag *Agent) executeCommand(ctx context.Context, cmd *cloud.ExecuteRequest) *cloud.ExecuteResponse {
   382  	switch {
   383  	case cmd.Url == healthcheckCommand:
   384  		return &cloud.ExecuteResponse{MessageId: cmd.MessageId, Status: 0}
   385  	default:
   386  		req := &fasthttp.RequestCtx{}
   387  		r := fasthttp.AcquireRequest()
   388  		r.Header.SetHost("localhost")
   389  		r.Header.SetMethod(cmd.Method)
   390  
   391  		for k, values := range cmd.Headers {
   392  			for _, value := range values.Header {
   393  				r.Header.Add(k, value)
   394  			}
   395  		}
   396  		r.SetBody(cmd.Body)
   397  		uri := &fasthttp.URI{}
   398  
   399  		err := uri.Parse(nil, []byte(cmd.Url))
   400  		if err != nil {
   401  			ag.logger.Errorf("agent bad command url: %w", err)
   402  			resp := &cloud.ExecuteResponse{MessageId: cmd.MessageId, Status: 400, Body: []byte(fmt.Sprintf("bad command url: %s", err))}
   403  			return resp
   404  		}
   405  		r.SetURI(uri)
   406  
   407  		req.Init(r, nil, nil)
   408  		ag.handler(req)
   409  
   410  		fasthttp.ReleaseRequest(r)
   411  
   412  		headers := make(map[string]*cloud.HeaderValue)
   413  		req.Response.Header.VisitAll(func(key, value []byte) {
   414  			_, ok := headers[string(key)]
   415  			if !ok {
   416  				headers[string(key)] = &cloud.HeaderValue{Header: []string{string(value)}}
   417  				return
   418  			}
   419  
   420  			headers[string(key)].Header = append(headers[string(key)].Header, string(value))
   421  		})
   422  
   423  		resp := &cloud.ExecuteResponse{MessageId: cmd.MessageId, Headers: headers, Status: int64(req.Response.StatusCode()), Body: req.Response.Body()}
   424  
   425  		return resp
   426  	}
   427  }
   428  
   429  func AddAPIKeyMeta(ctx context.Context, apiKey string) context.Context {
   430  	md := metadata.Pairs(apiKeyMeta, apiKey)
   431  	return metadata.NewOutgoingContext(ctx, md)
   432  }
   433  
   434  type cloudResponse struct {
   435  	resp *cloud.ExecuteRequest
   436  	err  error
   437  }