github.com/grafana/pyroscope@v1.18.0/pkg/segmentwriter/client/client_test.go (about) 1 package segmentwriterclient 2 3 import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "net" 9 "os" 10 "testing" 11 "time" 12 13 "github.com/go-kit/log" 14 "github.com/grafana/dskit/grpcclient" 15 "github.com/grafana/dskit/ring" 16 "github.com/grafana/dskit/services" 17 "github.com/stretchr/testify/mock" 18 "github.com/stretchr/testify/suite" 19 "google.golang.org/grpc" 20 "google.golang.org/grpc/codes" 21 "google.golang.org/grpc/credentials/insecure" 22 "google.golang.org/grpc/status" 23 "google.golang.org/grpc/test/bufconn" 24 25 segmentwriterv1 "github.com/grafana/pyroscope/api/gen/proto/go/segmentwriter/v1" 26 "github.com/grafana/pyroscope/pkg/segmentwriter/client/distributor/placement" 27 "github.com/grafana/pyroscope/pkg/testhelper" 28 ) 29 30 type segwriterServerMock struct { 31 segmentwriterv1.UnimplementedSegmentWriterServiceServer 32 mock.Mock 33 } 34 35 func (m *segwriterServerMock) Push( 36 ctx context.Context, 37 req *segmentwriterv1.PushRequest, 38 ) (*segmentwriterv1.PushResponse, error) { 39 args := m.Called(ctx, req) 40 return args.Get(0).(*segmentwriterv1.PushResponse), args.Error(1) 41 } 42 43 type testPlacement struct{} 44 45 func (testPlacement) Policy(k placement.Key) placement.Policy { 46 return placement.Policy{ 47 TenantShards: 0, // Unlimited. 48 DatasetShards: 1, 49 PickShard: func(n int) int { 50 return int(k.Fingerprint % uint64(n)) 51 }, 52 } 53 } 54 55 type segwriterClientSuite struct { 56 suite.Suite 57 58 listener *bufconn.Listener 59 dialer func(context.Context, string) (net.Conn, error) 60 server *grpc.Server 61 service *segwriterServerMock 62 done chan struct{} 63 64 logger log.Logger 65 config grpcclient.Config 66 ring testhelper.MockRing 67 client *Client 68 } 69 70 func (s *segwriterClientSuite) SetupTest() { 71 listener := bufconn.Listen(256 << 10) 72 s.listener = listener 73 s.dialer = func(context.Context, string) (net.Conn, error) { return listener.Dial() } 74 s.server = grpc.NewServer() 75 s.service = new(segwriterServerMock) 76 segmentwriterv1.RegisterSegmentWriterServiceServer(s.server, s.service) 77 78 s.logger = log.NewLogfmtLogger(os.Stdout) 79 s.config = grpcclient.Config{} 80 s.config.RegisterFlags(flag.NewFlagSet("", flag.PanicOnError)) 81 instances := []ring.InstanceDesc{ 82 {Id: "a", Addr: "localhost", Tokens: make([]uint32, 1)}, 83 {Id: "b", Addr: "localhost", Tokens: make([]uint32, 1)}, 84 {Id: "c", Addr: "localhost", Tokens: make([]uint32, 1)}, 85 } 86 s.ring = testhelper.NewMockRing(instances, 1) 87 88 var err error 89 s.client, err = NewSegmentWriterClient( 90 s.config, s.logger, nil, s.ring, 91 testPlacement{}, 92 grpc.WithContextDialer(s.dialer)) 93 s.Require().NoError(err) 94 95 s.done = make(chan struct{}) 96 go func() { 97 defer close(s.done) 98 s.Require().NoError(s.server.Serve(listener)) 99 }() 100 101 // Wait for the server 102 conn, err := grpc.NewClient("", 103 grpc.WithContextDialer(s.dialer), 104 grpc.WithTransportCredentials(insecure.NewCredentials()), 105 ) 106 107 s.Require().NoError(err) 108 s.Require().NoError(conn.Close()) 109 } 110 111 func (s *segwriterClientSuite) BeforeTest(_, _ string) { 112 svc := s.client.Service() 113 s.Require().NoError(svc.StartAsync(context.Background())) 114 s.Require().NoError(svc.AwaitRunning(context.Background())) 115 s.Require().Equal(services.Running, svc.State()) 116 } 117 118 func (s *segwriterClientSuite) AfterTest(_, _ string) { 119 svc := s.client.Service() 120 svc.StopAsync() 121 s.Require().NoError(svc.AwaitTerminated(context.Background())) 122 s.Require().Equal(services.Terminated, svc.State()) 123 124 s.service.AssertExpectations(s.T()) 125 } 126 127 func (s *segwriterClientSuite) TearDownTest() { 128 s.server.GracefulStop() 129 <-s.done 130 } 131 132 func TestSegmentWriterClientSuite(t *testing.T) { suite.Run(t, new(segwriterClientSuite)) } 133 134 func (s *segwriterClientSuite) Test_Push_HappyPath() { 135 s.service.On("Push", mock.Anything, mock.Anything). 136 Return(&segmentwriterv1.PushResponse{}, nil). 137 Once() 138 139 _, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{}) 140 s.Assert().NoError(err) 141 } 142 143 func (s *segwriterClientSuite) Test_Push_EmptyRing() { 144 emptyRing := testhelper.NewMockRing(nil, 1) 145 var err error 146 s.client, err = NewSegmentWriterClient( 147 s.config, s.logger, nil, emptyRing, 148 testPlacement{}, 149 grpc.WithContextDialer(s.dialer)) 150 s.Require().NoError(err) 151 152 _, err = s.client.Push(context.Background(), &segmentwriterv1.PushRequest{}) 153 s.Assert().Equal(codes.Unavailable.String(), status.Code(err).String()) 154 } 155 156 func (s *segwriterClientSuite) Test_Push_ClientError_Cancellation() { 157 s.service.On("Push", mock.Anything, mock.Anything). 158 Return(new(segmentwriterv1.PushResponse), context.Canceled). 159 Once() 160 161 _, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{}) 162 s.Assert().Equal(codes.Canceled.String(), status.Code(err).String()) 163 } 164 165 func (s *segwriterClientSuite) Test_Push_Client_Deadline() { 166 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) 167 defer cancel() 168 _, err := s.client.Push(ctx, &segmentwriterv1.PushRequest{}) 169 s.Assert().ErrorIs(err, context.DeadlineExceeded) 170 } 171 172 func (s *segwriterClientSuite) Test_Push_NonClient_Deadline() { 173 s.service.On("Push", mock.Anything, mock.Anything). 174 Return(new(segmentwriterv1.PushResponse), context.DeadlineExceeded). 175 Once() 176 177 s.service.On("Push", mock.Anything, mock.Anything). 178 Return(new(segmentwriterv1.PushResponse), nil). 179 Once() 180 181 _, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{}) 182 s.Assert().NoError(err) 183 } 184 185 func (s *segwriterClientSuite) Test_Push_ClientError_InvalidArgument() { 186 s.service.On("Push", mock.Anything, mock.Anything). 187 Return(new(segmentwriterv1.PushResponse), status.Error(codes.InvalidArgument, errServiceUnavailableMsg)). 188 Once() 189 190 _, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{}) 191 s.Assert().Equal(codes.InvalidArgument.String(), status.Code(err).String()) 192 } 193 194 func (s *segwriterClientSuite) Test_Push_ServerError_NonRetryable() { 195 s.service.On("Push", mock.Anything, mock.Anything). 196 Return(new(segmentwriterv1.PushResponse), io.EOF). 197 Once() 198 199 _, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{}) 200 s.Assert().Equal(codes.Unavailable.String(), status.Code(err).String()) 201 } 202 203 func (s *segwriterClientSuite) Test_Push_ServerError_Retry_Unavailable() { 204 s.service.On("Push", mock.Anything, mock.Anything). 205 Return(new(segmentwriterv1.PushResponse), status.Error(codes.Unavailable, errServiceUnavailableMsg)). 206 Once() 207 208 s.service.On("Push", mock.Anything, mock.Anything). 209 Return(new(segmentwriterv1.PushResponse), nil). 210 Once() 211 212 _, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{}) 213 s.Assert().NoError(err) 214 } 215 216 func (s *segwriterClientSuite) Test_Push_ServerError_Retry_ResourceExhausted() { 217 s.service.On("Push", mock.Anything, mock.Anything). 218 Return(new(segmentwriterv1.PushResponse), status.Error(codes.ResourceExhausted, errServiceUnavailableMsg)). 219 Once() 220 221 s.service.On("Push", mock.Anything, mock.Anything). 222 Return(new(segmentwriterv1.PushResponse), nil). 223 Once() 224 225 _, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{}) 226 s.Assert().NoError(err) 227 } 228 229 func (s *segwriterClientSuite) Test_Push_DialError() { 230 dialer := func(ctx context.Context, s string) (net.Conn, error) { 231 return nil, io.EOF 232 } 233 var err error 234 s.client, err = NewSegmentWriterClient( 235 s.config, s.logger, nil, s.ring, 236 testPlacement{}, 237 grpc.WithContextDialer(dialer)) 238 s.Require().NoError(err) 239 240 _, err = s.client.Push(context.Background(), &segmentwriterv1.PushRequest{}) 241 s.Assert().Equal(codes.Unavailable.String(), status.Code(err).String()) 242 } 243 244 func (s *segwriterClientSuite) Test_Push_DialError_Retry() { 245 var failed bool 246 dialer := func(context.Context, string) (net.Conn, error) { 247 if failed { 248 return nil, net.UnknownNetworkError("network issue") 249 } 250 failed = true 251 return s.listener.Dial() 252 } 253 var err error 254 s.client, err = NewSegmentWriterClient( 255 s.config, s.logger, nil, s.ring, 256 testPlacement{}, 257 grpc.WithContextDialer(dialer)) 258 s.Require().NoError(err) 259 260 s.service.On("Push", mock.Anything, mock.Anything). 261 Return(new(segmentwriterv1.PushResponse), nil). 262 Once() 263 264 _, err = s.client.Push(context.Background(), &segmentwriterv1.PushRequest{}) 265 s.Assert().NoError(err) 266 } 267 268 func (s *segwriterClientSuite) Test_Push_AllInstancesUnavailable() { 269 s.service.On("Push", mock.Anything, mock.Anything). 270 Return(new(segmentwriterv1.PushResponse), status.Error(codes.Unavailable, errServiceUnavailableMsg)) 271 272 _, err := s.client.Push(context.Background(), &segmentwriterv1.PushRequest{}) 273 s.Assert().Equal(codes.Unavailable.String(), status.Code(err).String()) 274 } 275 276 func (s *segwriterClientSuite) Test_Push_ConnTimeout() { 277 dialer := func(ctx context.Context, _ string) (net.Conn, error) { 278 <-ctx.Done() 279 return nil, fmt.Errorf("dial error") 280 } 281 282 // Unfortunately, we can't set arbitrary timeout 283 // here: the minimal allowed value is 1s. 284 s.config.ConnectTimeout = time.Second 285 var err error 286 s.client, err = NewSegmentWriterClient( 287 s.config, s.logger, nil, s.ring, 288 testPlacement{}, 289 grpc.WithContextDialer(dialer)) 290 s.Require().NoError(err) 291 292 // Note that we use the background context: we do not 293 // want to wait for the context to expire, but fail 294 // fast, once the connection timeout expires. 295 _, err = s.client.Push(context.Background(), &segmentwriterv1.PushRequest{}) 296 // The client, however, won't see the underlying error. 297 s.Require().NotNil(err) 298 s.Assert().Contains(err.Error(), errServiceUnavailableMsg) 299 }