github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/grpc/interceptors/errors_test.go (about) 1 // Copyright 2023 Gravitational, Inc 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package interceptors_test 16 17 import ( 18 "context" 19 "errors" 20 "io" 21 "net" 22 "testing" 23 "time" 24 25 "github.com/gravitational/trace" 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 "google.golang.org/grpc" 29 "google.golang.org/grpc/credentials/insecure" 30 31 "github.com/gravitational/teleport/api/client/proto" 32 "github.com/gravitational/teleport/api/utils/grpc/interceptors" 33 ) 34 35 type errService struct { 36 proto.UnimplementedAuthServiceServer 37 } 38 39 func (s *errService) Ping(ctx context.Context, req *proto.PingRequest) (*proto.PingResponse, error) { 40 return nil, trace.NotFound("not found") 41 } 42 43 func (s *errService) AddMFADevice(stream proto.AuthService_AddMFADeviceServer) error { 44 return trace.AlreadyExists("already exists") 45 } 46 47 // TestGRPCErrorWrapping tests the error wrapping capability of the client 48 // and server unary and stream interceptors 49 func TestGRPCErrorWrapping(t *testing.T) { 50 t.Parallel() 51 52 listener, err := net.Listen("tcp", "localhost:0") 53 require.NoError(t, err) 54 55 server := grpc.NewServer( 56 grpc.ChainUnaryInterceptor(interceptors.GRPCServerUnaryErrorInterceptor), 57 grpc.ChainStreamInterceptor(interceptors.GRPCServerStreamErrorInterceptor), 58 ) 59 proto.RegisterAuthServiceServer(server, &errService{}) 60 go func() { 61 server.Serve(listener) 62 }() 63 defer server.Stop() 64 65 conn, err := grpc.Dial( 66 listener.Addr().String(), 67 grpc.WithTransportCredentials(insecure.NewCredentials()), 68 grpc.WithChainUnaryInterceptor(interceptors.GRPCClientUnaryErrorInterceptor), 69 grpc.WithChainStreamInterceptor(interceptors.GRPCClientStreamErrorInterceptor), 70 ) 71 require.NoError(t, err) 72 defer conn.Close() 73 74 client := proto.NewAuthServiceClient(conn) 75 76 t.Run("unary interceptor", func(t *testing.T) { 77 resp, err := client.Ping(context.Background(), &proto.PingRequest{}) 78 assert.Nil(t, resp, "resp is non-nil") 79 assert.True(t, trace.IsNotFound(err), "trace.IsNotFound failed: err=%v (%T)", err, trace.Unwrap(err)) 80 assert.Equal(t, "not found", err.Error()) 81 82 var traceErr *trace.TraceErr 83 assert.False(t, errors.As(err, &traceErr), "client error should not include traces originating in the middleware") 84 var remoteErr *interceptors.RemoteError 85 assert.ErrorAs(t, err, &remoteErr, "Remote error is not marked as an interceptors.RemoteError") 86 }) 87 88 t.Run("stream interceptor", func(t *testing.T) { 89 //nolint:staticcheck // SA1019. The specific stream used here doesn't matter. 90 stream, err := client.AddMFADevice(context.Background()) 91 require.NoError(t, err) 92 93 // Give the server time to close the stream. This allows us to more 94 // consistently hit the io.EOF error. 95 time.Sleep(100 * time.Millisecond) 96 97 //nolint:staticcheck // SA1019. The specific stream used here doesn't matter. 98 sendErr := stream.Send(&proto.AddMFADeviceRequest{}) 99 100 // Expect either a success (unlikely because of the Sleep) or an unwrapped 101 // io.EOF error (meaning the server errored and closed the stream). 102 // In either case, it is still safe to recv from the stream and check for 103 // the already exists error. 104 //nolint:errorlint //comparison != error comparison on purpose! 105 if sendErr != nil && sendErr != io.EOF { 106 t.Fatalf("Unexpected error: %q (%T)", sendErr, sendErr) 107 } 108 109 _, err = stream.Recv() 110 assert.True(t, trace.IsAlreadyExists(err), "trace.IsAlreadyExists failed: err=%v (%T)", err, trace.Unwrap(err)) 111 assert.Equal(t, "already exists", err.Error()) 112 var traceErr *trace.TraceErr 113 assert.False(t, errors.As(err, &traceErr), "client error should not include traces originating in the middleware") 114 assert.True(t, trace.IsAlreadyExists(err), "trace.IsAlreadyExists failed: err=%v (%T)", err, trace.Unwrap(err)) 115 var remoteErr *interceptors.RemoteError 116 assert.ErrorAs(t, err, &remoteErr, "Remote error is not marked as an interceptors.RemoteError") 117 }) 118 }