github.com/grafana/pyroscope@v1.18.0/pkg/test/integration/cluster/cluster.go (about)

     1  package cluster
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log"
     7  	"math/rand"
     8  	"net"
     9  	"net/http"
    10  	"os"
    11  	"path/filepath"
    12  	"sync"
    13  	"time"
    14  
    15  	"golang.org/x/sync/errgroup"
    16  
    17  	"github.com/grafana/pyroscope/api/gen/proto/go/push/v1/pushv1connect"
    18  	"github.com/grafana/pyroscope/api/gen/proto/go/querier/v1/querierv1connect"
    19  	connectapi "github.com/grafana/pyroscope/pkg/api/connect"
    20  	"github.com/grafana/pyroscope/pkg/tenant"
    21  )
    22  
// listenAddr is the loopback address every component binds and advertises on.
const listenAddr = "127.0.0.1"
    24  
    25  func getFreeTCPPorts(address string, count int) ([]int, error) {
    26  	ports := make([]int, count)
    27  	for i := 0; i < count; i++ {
    28  		addr, err := net.ResolveTCPAddr("tcp", address+":0")
    29  		if err != nil {
    30  			return nil, err
    31  		}
    32  
    33  		l, err := net.ListenTCP("tcp", addr)
    34  		if err != nil {
    35  			return nil, err
    36  		}
    37  		defer l.Close()
    38  
    39  		if tcpAddr, ok := l.Addr().(*net.TCPAddr); ok {
    40  			ports[i] = tcpAddr.Port
    41  		} else {
    42  			return nil, fmt.Errorf("unable to retrieve tcp port from %v", l)
    43  		}
    44  	}
    45  
    46  	return ports, nil
    47  }
    48  
    49  func newComponent(target string) *Component {
    50  	return &Component{
    51  		Target: target,
    52  	}
    53  }
    54  
// testTransport is an http.RoundTripper for the in-process test cluster.
// It rewrites the virtual hostnames used by the connect clients
// ("push:80", "querier:80") to the address of a live component, and
// propagates the tenant ID from the request context as a header.
type testTransport struct {
	defaultDialContext func(ctx context.Context, network, addr string) (net.Conn, error) // dialer taken from http.DefaultTransport
	next               http.RoundTripper                                                 // transport that performs the actual request
	c                  *Cluster                                                          // cluster used to resolve component addresses
}
    60  
    61  // use custom http transport to resolve dynamically to healthy components
    62  func newTestTransport(c *Cluster) http.RoundTripper {
    63  	defaultTransport := http.DefaultTransport.(*http.Transport)
    64  	t := &testTransport{
    65  		defaultDialContext: defaultTransport.DialContext,
    66  		c:                  c,
    67  	}
    68  	t.next = &http.Transport{
    69  		Proxy:                 defaultTransport.Proxy,
    70  		TLSClientConfig:       defaultTransport.TLSClientConfig,
    71  		TLSHandshakeTimeout:   defaultTransport.TLSHandshakeTimeout,
    72  		ExpectContinueTimeout: defaultTransport.ExpectContinueTimeout,
    73  		MaxIdleConns:          defaultTransport.MaxIdleConns,
    74  		IdleConnTimeout:       defaultTransport.IdleConnTimeout,
    75  		ForceAttemptHTTP2:     defaultTransport.ForceAttemptHTTP2,
    76  		DialContext:           t.DialContext,
    77  	}
    78  	return t
    79  }
    80  
    81  func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    82  	tenantID, err := tenant.ExtractTenantIDFromContext(req.Context())
    83  	if err == nil {
    84  		req.Header.Set("X-Scope-OrgID", tenantID)
    85  	}
    86  	return t.next.RoundTrip(req)
    87  }
    88  
    89  func (t *testTransport) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
    90  	var err error
    91  	switch addr {
    92  	case "push:80":
    93  		addr, err = t.c.pickHealthyComponent("distributor")
    94  		if err != nil {
    95  			return nil, err
    96  		}
    97  	case "querier:80":
    98  		addr, err = t.c.pickHealthyComponent("query-frontend", "querier")
    99  		if err != nil {
   100  			return nil, err
   101  		}
   102  	default:
   103  		return nil, fmt.Errorf("unknown addr %s", addr)
   104  	}
   105  
   106  	return t.defaultDialContext(ctx, network, addr)
   107  }
   108  
// ClusterOption mutates a Cluster during construction in NewMicroServiceCluster.
type ClusterOption func(c *Cluster)
   110  
   111  func NewMicroServiceCluster(opts ...ClusterOption) *Cluster {
   112  	c := &Cluster{}
   113  	WithV1()(c)
   114  
   115  	// apply options
   116  	for _, opt := range opts {
   117  		opt(c)
   118  	}
   119  
   120  	c.httpClient = &http.Client{Transport: newTestTransport(c)}
   121  	c.Components = make([]*Component, len(c.expectedComponents))
   122  	for idx := range c.expectedComponents {
   123  		c.Components[idx] = newComponent(c.expectedComponents[idx])
   124  	}
   125  
   126  	return c
   127  }
   128  
// Cluster models an in-process multi-component ("microservices" mode)
// Pyroscope deployment used by integration tests.
type Cluster struct {
	Components []*Component
	perTarget  map[string][]int // indexes replicas per target into Components slice

	wg sync.WaitGroup // components wait group

	v2                 bool     // is this a v2 cluster
	debuginfodURL      string   // debuginfod URL for symbolization
	expectedComponents []string // target names of the components to create, one entry per replica

	tmpDir     string       // scratch dir holding shared storage and per-component data dirs
	httpClient *http.Client // client whose transport resolves virtual component addresses
}
   142  
   143  func (c *Cluster) commonFlags(comp *Component) []string {
   144  	nodeName := comp.nodeName()
   145  	return []string{
   146  		"-auth.multitenancy-enabled=true",
   147  		"-tracing.enabled=false", // data race
   148  		"-self-profiling.disable-push=true",
   149  		fmt.Sprintf("-pyroscopedb.data-path=%s", c.dataDir(comp)),
   150  		"-storage.backend=filesystem",
   151  		fmt.Sprintf("-storage.filesystem.dir=%s", c.dataSharedDir()),
   152  		fmt.Sprintf("-target=%s", comp.Target),
   153  		fmt.Sprintf("-memberlist.advertise-port=%d", comp.memberlistPort),
   154  		fmt.Sprintf("-memberlist.bind-port=%d", comp.memberlistPort),
   155  		fmt.Sprintf("-memberlist.bind-addr=%s", listenAddr),
   156  		"-memberlist.leave-timeout=1s",
   157  		"-memberlist.advertise-addr=" + listenAddr,
   158  		"-memberlist.nodename=" + nodeName,
   159  		fmt.Sprintf("-server.http-listen-port=%d", comp.httpPort),
   160  		fmt.Sprintf("-server.http-listen-address=%s", listenAddr),
   161  		fmt.Sprintf("-server.grpc-listen-port=%d", comp.grpcPort),
   162  		fmt.Sprintf("-server.grpc-listen-address=%s", listenAddr),
   163  		"-distributor.ring.instance-addr=" + listenAddr,
   164  		"-distributor.ring.instance-id=" + nodeName,
   165  		"-distributor.ring.heartbeat-period=1s",
   166  		"-overrides-exporter.ring.instance-addr=" + listenAddr,
   167  		"-overrides-exporter.ring.instance-id=" + nodeName,
   168  		"-overrides-exporter.ring.heartbeat-period=1s",
   169  		"-query-frontend.instance-addr=" + listenAddr,
   170  	}
   171  }
   172  
   173  func (c *Cluster) pickHealthyComponent(targets ...string) (addr string, err error) {
   174  	results := make([][]string, len(targets))
   175  
   176  	for _, comp := range c.Components {
   177  		for i, target := range targets {
   178  			if comp.Target == target {
   179  				results[i] = append(results[i], fmt.Sprintf("%s:%d", listenAddr, comp.httpPort))
   180  			}
   181  		}
   182  	}
   183  
   184  	for _, result := range results {
   185  		if len(result) > 0 {
   186  			// pick random element of list
   187  			return result[rand.Intn(len(result))], nil
   188  		}
   189  	}
   190  
   191  	return "", fmt.Errorf("no healthy component found for targets %v", targets)
   192  }
   193  func (c *Cluster) dataSharedDir() string {
   194  	return filepath.Join(c.tmpDir, "data-shared")
   195  }
   196  
   197  func (c *Cluster) dataDir(comp *Component) string {
   198  	return filepath.Join(c.tmpDir, comp.nodeName(), "data")
   199  }
   200  
// Prepare creates the temporary directory layout, allocates listen ports
// for every component and generates the version-specific configuration.
// It must be called before Start.
func (c *Cluster) Prepare(ctx context.Context) (err error) {
	// tmp dir
	c.tmpDir, err = os.MkdirTemp("", "pyroscope-test")
	if err != nil {
		return err
	}
	if err := os.Mkdir(c.dataSharedDir(), 0o755); err != nil {
		return err
	}

	// allocate tcp ports per component: http, grpc and memberlist,
	// plus one extra port for v2 clusters
	portsPerComponent := 3
	if c.v2 {
		portsPerComponent = 4
	}
	ports, err := getFreeTCPPorts(listenAddr, len(c.Components)*portsPerComponent)
	if err != nil {
		return err
	}

	// flags with all components that participate in memberlist
	memberlistJoin := []string{}
	c.perTarget = map[string][]int{}
	for compidx, comp := range c.Components {
		// record which replica of its target this component is
		c.perTarget[comp.Target] = append(c.perTarget[comp.Target], compidx)
		comp.replica = len(c.perTarget[comp.Target]) - 1

		// allocate ports (consume the next portsPerComponent entries)
		comp.addPorts(ports[0:portsPerComponent])
		ports = ports[portsPerComponent:]

		// add to memberlist join list
		memberlistJoin = append(memberlistJoin, fmt.Sprintf("%s:%d", listenAddr, comp.memberlistPort))

		if err := os.MkdirAll(c.dataDir(comp), 0o755); err != nil {
			return err
		}
	}

	if c.v2 {
		return c.v2Prepare(ctx, memberlistJoin)
	}

	return c.v1Prepare(ctx, memberlistJoin)
}
   246  
   247  func (c *Cluster) Stop() func(context.Context) error {
   248  	funcWaiters := make([]func(context.Context) error, 0, len(c.Components)+1)
   249  	for _, comp := range c.Components {
   250  		funcWaiters = append(funcWaiters, comp.Stop())
   251  	}
   252  
   253  	return func(ctx context.Context) error {
   254  		g, ctx := errgroup.WithContext(ctx)
   255  		for _, f := range funcWaiters {
   256  			f := f
   257  			g.Go(func() error {
   258  				return f(ctx)
   259  			})
   260  		}
   261  		return g.Wait()
   262  	}
   263  }
   264  
   265  func (c *Cluster) Start(ctx context.Context) (err error) {
   266  	notReady := make(map[*Component]error)
   267  
   268  	for _, comp := range c.Components {
   269  		p, err := comp.start(ctx)
   270  		if err != nil {
   271  			return err
   272  		}
   273  		comp.p = p
   274  
   275  		notReady[comp] = nil
   276  
   277  		c.wg.Add(1)
   278  		go func() {
   279  			defer c.wg.Done()
   280  			err := p.Run()
   281  			if err != nil {
   282  				log.Println(err)
   283  			}
   284  		}()
   285  
   286  	}
   287  
   288  	readyCh := make(chan struct{})
   289  	go func() {
   290  		rate := 200 * time.Millisecond
   291  		ticker := time.NewTicker(rate)
   292  		defer ticker.Stop()
   293  		for {
   294  			for t := range notReady {
   295  				if err := func() error {
   296  					ctx, cancel := context.WithTimeout(context.Background(), rate)
   297  					defer cancel()
   298  
   299  					var found bool
   300  					var err error
   301  
   302  					if c.v2 {
   303  						found, err = c.v2ReadyCheckComponent(ctx, t)
   304  					} else {
   305  						found, err = c.v1ReadyCheckComponent(ctx, t)
   306  					}
   307  					if found {
   308  						if err != nil {
   309  							return err
   310  						}
   311  						return nil
   312  					}
   313  
   314  					// fallback to http ready check
   315  					return t.httpReadyCheck(ctx)
   316  				}(); err != nil {
   317  					notReady[t] = err
   318  				} else {
   319  					delete(notReady, t)
   320  				}
   321  
   322  			}
   323  
   324  			if len(notReady) == 0 {
   325  				close(readyCh)
   326  				break
   327  			}
   328  
   329  			<-ticker.C
   330  		}
   331  	}()
   332  
   333  	<-readyCh
   334  
   335  	return nil
   336  }
   337  
// Wait blocks until every component process has exited. Typically called
// after invoking the shutdown function returned by Stop.
func (c *Cluster) Wait() {
	c.wg.Wait()
}
   341  
   342  func (c *Cluster) QueryClient() querierv1connect.QuerierServiceClient {
   343  	return querierv1connect.NewQuerierServiceClient(
   344  		c.httpClient,
   345  		"http://querier",
   346  		connectapi.DefaultClientOptions()...,
   347  	)
   348  }
   349  
   350  func (c *Cluster) PushClient() pushv1connect.PusherServiceClient {
   351  	return pushv1connect.NewPusherServiceClient(
   352  		c.httpClient,
   353  		"http://push",
   354  		connectapi.DefaultClientOptions()...,
   355  	)
   356  }