github.com/grafana/pyroscope@v1.18.0/pkg/segmentwriter/client/client_test.go (about)

     1  package segmentwriterclient
     2  
     3  import (
     4  	"context"
     5  	"flag"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"os"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/go-kit/log"
    14  	"github.com/grafana/dskit/grpcclient"
    15  	"github.com/grafana/dskit/ring"
    16  	"github.com/grafana/dskit/services"
    17  	"github.com/stretchr/testify/mock"
    18  	"github.com/stretchr/testify/suite"
    19  	"google.golang.org/grpc"
    20  	"google.golang.org/grpc/codes"
    21  	"google.golang.org/grpc/credentials/insecure"
    22  	"google.golang.org/grpc/status"
    23  	"google.golang.org/grpc/test/bufconn"
    24  
    25  	segmentwriterv1 "github.com/grafana/pyroscope/api/gen/proto/go/segmentwriter/v1"
    26  	"github.com/grafana/pyroscope/pkg/segmentwriter/client/distributor/placement"
    27  	"github.com/grafana/pyroscope/pkg/testhelper"
    28  )
    29  
    30  type segwriterServerMock struct {
    31  	segmentwriterv1.UnimplementedSegmentWriterServiceServer
    32  	mock.Mock
    33  }
    34  
    35  func (m *segwriterServerMock) Push(
    36  	ctx context.Context,
    37  	req *segmentwriterv1.PushRequest,
    38  ) (*segmentwriterv1.PushResponse, error) {
    39  	args := m.Called(ctx, req)
    40  	return args.Get(0).(*segmentwriterv1.PushResponse), args.Error(1)
    41  }
    42  
    43  type testPlacement struct{}
    44  
    45  func (testPlacement) Policy(k placement.Key) placement.Policy {
    46  	return placement.Policy{
    47  		TenantShards:  0, // Unlimited.
    48  		DatasetShards: 1,
    49  		PickShard: func(n int) int {
    50  			return int(k.Fingerprint % uint64(n))
    51  		},
    52  	}
    53  }
    54  
    55  type segwriterClientSuite struct {
    56  	suite.Suite
    57  
    58  	listener *bufconn.Listener
    59  	dialer   func(context.Context, string) (net.Conn, error)
    60  	server   *grpc.Server
    61  	service  *segwriterServerMock
    62  	done     chan struct{}
    63  
    64  	logger log.Logger
    65  	config grpcclient.Config
    66  	ring   testhelper.MockRing
    67  	client *Client
    68  }
    69  
    70  func (s *segwriterClientSuite) SetupTest() {
    71  	listener := bufconn.Listen(256 << 10)
    72  	s.listener = listener
    73  	s.dialer = func(context.Context, string) (net.Conn, error) { return listener.Dial() }
    74  	s.server = grpc.NewServer()
    75  	s.service = new(segwriterServerMock)
    76  	segmentwriterv1.RegisterSegmentWriterServiceServer(s.server, s.service)
    77  
    78  	s.logger = log.NewLogfmtLogger(os.Stdout)
    79  	s.config = grpcclient.Config{}
    80  	s.config.RegisterFlags(flag.NewFlagSet("", flag.PanicOnError))
    81  	instances := []ring.InstanceDesc{
    82  		{Id: "a", Addr: "localhost", Tokens: make([]uint32, 1)},
    83  		{Id: "b", Addr: "localhost", Tokens: make([]uint32, 1)},
    84  		{Id: "c", Addr: "localhost", Tokens: make([]uint32, 1)},
    85  	}
    86  	s.ring = testhelper.NewMockRing(instances, 1)
    87  
    88  	var err error
    89  	s.client, err = NewSegmentWriterClient(
    90  		s.config, s.logger, nil, s.ring,
    91  		testPlacement{},
    92  		grpc.WithContextDialer(s.dialer))
    93  	s.Require().NoError(err)
    94  
    95  	s.done = make(chan struct{})
    96  	go func() {
    97  		defer close(s.done)
    98  		s.Require().NoError(s.server.Serve(listener))
    99  	}()
   100  
   101  	// Wait for the server
   102  	conn, err := grpc.NewClient("",
   103  		grpc.WithContextDialer(s.dialer),
   104  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   105  	)
   106  
   107  	s.Require().NoError(err)
   108  	s.Require().NoError(conn.Close())
   109  }
   110  
   111  func (s *segwriterClientSuite) BeforeTest(_, _ string) {
   112  	svc := s.client.Service()
   113  	s.Require().NoError(svc.StartAsync(context.Background()))
   114  	s.Require().NoError(svc.AwaitRunning(context.Background()))
   115  	s.Require().Equal(services.Running, svc.State())
   116  }
   117  
   118  func (s *segwriterClientSuite) AfterTest(_, _ string) {
   119  	svc := s.client.Service()
   120  	svc.StopAsync()
   121  	s.Require().NoError(svc.AwaitTerminated(context.Background()))
   122  	s.Require().Equal(services.Terminated, svc.State())
   123  
   124  	s.service.AssertExpectations(s.T())
   125  }
   126  
   127  func (s *segwriterClientSuite) TearDownTest() {
   128  	s.server.GracefulStop()
   129  	<-s.done
   130  }
   131  
   132  func TestSegmentWriterClientSuite(t *testing.T) { suite.Run(t, new(segwriterClientSuite)) }
   133  
   134  func (s *segwriterClientSuite) Test_Push_HappyPath() {
   135  	s.service.On("Push", mock.Anything, mock.Anything).
   136  		Return(&segmentwriterv1.PushResponse{}, nil).
   137  		Once()
   138  
   139  	_, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{})
   140  	s.Assert().NoError(err)
   141  }
   142  
   143  func (s *segwriterClientSuite) Test_Push_EmptyRing() {
   144  	emptyRing := testhelper.NewMockRing(nil, 1)
   145  	var err error
   146  	s.client, err = NewSegmentWriterClient(
   147  		s.config, s.logger, nil, emptyRing,
   148  		testPlacement{},
   149  		grpc.WithContextDialer(s.dialer))
   150  	s.Require().NoError(err)
   151  
   152  	_, err = s.client.Push(context.Background(), &segmentwriterv1.PushRequest{})
   153  	s.Assert().Equal(codes.Unavailable.String(), status.Code(err).String())
   154  }
   155  
   156  func (s *segwriterClientSuite) Test_Push_ClientError_Cancellation() {
   157  	s.service.On("Push", mock.Anything, mock.Anything).
   158  		Return(new(segmentwriterv1.PushResponse), context.Canceled).
   159  		Once()
   160  
   161  	_, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{})
   162  	s.Assert().Equal(codes.Canceled.String(), status.Code(err).String())
   163  }
   164  
   165  func (s *segwriterClientSuite) Test_Push_Client_Deadline() {
   166  	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
   167  	defer cancel()
   168  	_, err := s.client.Push(ctx, &segmentwriterv1.PushRequest{})
   169  	s.Assert().ErrorIs(err, context.DeadlineExceeded)
   170  }
   171  
   172  func (s *segwriterClientSuite) Test_Push_NonClient_Deadline() {
   173  	s.service.On("Push", mock.Anything, mock.Anything).
   174  		Return(new(segmentwriterv1.PushResponse), context.DeadlineExceeded).
   175  		Once()
   176  
   177  	s.service.On("Push", mock.Anything, mock.Anything).
   178  		Return(new(segmentwriterv1.PushResponse), nil).
   179  		Once()
   180  
   181  	_, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{})
   182  	s.Assert().NoError(err)
   183  }
   184  
   185  func (s *segwriterClientSuite) Test_Push_ClientError_InvalidArgument() {
   186  	s.service.On("Push", mock.Anything, mock.Anything).
   187  		Return(new(segmentwriterv1.PushResponse), status.Error(codes.InvalidArgument, errServiceUnavailableMsg)).
   188  		Once()
   189  
   190  	_, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{})
   191  	s.Assert().Equal(codes.InvalidArgument.String(), status.Code(err).String())
   192  }
   193  
   194  func (s *segwriterClientSuite) Test_Push_ServerError_NonRetryable() {
   195  	s.service.On("Push", mock.Anything, mock.Anything).
   196  		Return(new(segmentwriterv1.PushResponse), io.EOF).
   197  		Once()
   198  
   199  	_, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{})
   200  	s.Assert().Equal(codes.Unavailable.String(), status.Code(err).String())
   201  }
   202  
   203  func (s *segwriterClientSuite) Test_Push_ServerError_Retry_Unavailable() {
   204  	s.service.On("Push", mock.Anything, mock.Anything).
   205  		Return(new(segmentwriterv1.PushResponse), status.Error(codes.Unavailable, errServiceUnavailableMsg)).
   206  		Once()
   207  
   208  	s.service.On("Push", mock.Anything, mock.Anything).
   209  		Return(new(segmentwriterv1.PushResponse), nil).
   210  		Once()
   211  
   212  	_, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{})
   213  	s.Assert().NoError(err)
   214  }
   215  
   216  func (s *segwriterClientSuite) Test_Push_ServerError_Retry_ResourceExhausted() {
   217  	s.service.On("Push", mock.Anything, mock.Anything).
   218  		Return(new(segmentwriterv1.PushResponse), status.Error(codes.ResourceExhausted, errServiceUnavailableMsg)).
   219  		Once()
   220  
   221  	s.service.On("Push", mock.Anything, mock.Anything).
   222  		Return(new(segmentwriterv1.PushResponse), nil).
   223  		Once()
   224  
   225  	_, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{})
   226  	s.Assert().NoError(err)
   227  }
   228  
   229  func (s *segwriterClientSuite) Test_Push_DialError() {
   230  	dialer := func(ctx context.Context, s string) (net.Conn, error) {
   231  		return nil, io.EOF
   232  	}
   233  	var err error
   234  	s.client, err = NewSegmentWriterClient(
   235  		s.config, s.logger, nil, s.ring,
   236  		testPlacement{},
   237  		grpc.WithContextDialer(dialer))
   238  	s.Require().NoError(err)
   239  
   240  	_, err = s.client.Push(context.Background(), &segmentwriterv1.PushRequest{})
   241  	s.Assert().Equal(codes.Unavailable.String(), status.Code(err).String())
   242  }
   243  
   244  func (s *segwriterClientSuite) Test_Push_DialError_Retry() {
   245  	var failed bool
   246  	dialer := func(context.Context, string) (net.Conn, error) {
   247  		if failed {
   248  			return nil, net.UnknownNetworkError("network issue")
   249  		}
   250  		failed = true
   251  		return s.listener.Dial()
   252  	}
   253  	var err error
   254  	s.client, err = NewSegmentWriterClient(
   255  		s.config, s.logger, nil, s.ring,
   256  		testPlacement{},
   257  		grpc.WithContextDialer(dialer))
   258  	s.Require().NoError(err)
   259  
   260  	s.service.On("Push", mock.Anything, mock.Anything).
   261  		Return(new(segmentwriterv1.PushResponse), nil).
   262  		Once()
   263  
   264  	_, err = s.client.Push(context.Background(), &segmentwriterv1.PushRequest{})
   265  	s.Assert().NoError(err)
   266  }
   267  
   268  func (s *segwriterClientSuite) Test_Push_AllInstancesUnavailable() {
   269  	s.service.On("Push", mock.Anything, mock.Anything).
   270  		Return(new(segmentwriterv1.PushResponse), status.Error(codes.Unavailable, errServiceUnavailableMsg))
   271  
   272  	_, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{})
   273  	s.Assert().Equal(codes.Unavailable.String(), status.Code(err).String())
   274  }
   275  
   276  func (s *segwriterClientSuite) Test_Push_ConnTimeout() {
   277  	dialer := func(ctx context.Context, _ string) (net.Conn, error) {
   278  		<-ctx.Done()
   279  		return nil, fmt.Errorf("dial error")
   280  	}
   281  
   282  	// Unfortunately, we can't set arbitrary timeout
   283  	// here: the minimal allowed value is 1s.
   284  	s.config.ConnectTimeout = time.Second
   285  	var err error
   286  	s.client, err = NewSegmentWriterClient(
   287  		s.config, s.logger, nil, s.ring,
   288  		testPlacement{},
   289  		grpc.WithContextDialer(dialer))
   290  	s.Require().NoError(err)
   291  
   292  	// Note that we use the background context: we do not
   293  	// want to wait for the context to expire, but fail
   294  	// fast, once the connection timeout expires.
   295  	_, err = s.client.Push(context.Background(), &segmentwriterv1.PushRequest{})
   296  	// The client, however, won't see the underlying error.
   297  	s.Require().NotNil(err)
   298  	s.Assert().Contains(err.Error(), errServiceUnavailableMsg)
   299  }