github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/grpc/interceptors/errors_test.go (about)

     1  // Copyright 2023 Gravitational, Inc
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package interceptors_test
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"io"
    21  	"net"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/gravitational/trace"
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  	"google.golang.org/grpc"
    29  	"google.golang.org/grpc/credentials/insecure"
    30  
    31  	"github.com/gravitational/teleport/api/client/proto"
    32  	"github.com/gravitational/teleport/api/utils/grpc/interceptors"
    33  )
    34  
    35  type errService struct {
    36  	proto.UnimplementedAuthServiceServer
    37  }
    38  
    39  func (s *errService) Ping(ctx context.Context, req *proto.PingRequest) (*proto.PingResponse, error) {
    40  	return nil, trace.NotFound("not found")
    41  }
    42  
    43  func (s *errService) AddMFADevice(stream proto.AuthService_AddMFADeviceServer) error {
    44  	return trace.AlreadyExists("already exists")
    45  }
    46  
    47  // TestGRPCErrorWrapping tests the error wrapping capability of the client
    48  // and server unary and stream interceptors
    49  func TestGRPCErrorWrapping(t *testing.T) {
    50  	t.Parallel()
    51  
    52  	listener, err := net.Listen("tcp", "localhost:0")
    53  	require.NoError(t, err)
    54  
    55  	server := grpc.NewServer(
    56  		grpc.ChainUnaryInterceptor(interceptors.GRPCServerUnaryErrorInterceptor),
    57  		grpc.ChainStreamInterceptor(interceptors.GRPCServerStreamErrorInterceptor),
    58  	)
    59  	proto.RegisterAuthServiceServer(server, &errService{})
    60  	go func() {
    61  		server.Serve(listener)
    62  	}()
    63  	defer server.Stop()
    64  
    65  	conn, err := grpc.Dial(
    66  		listener.Addr().String(),
    67  		grpc.WithTransportCredentials(insecure.NewCredentials()),
    68  		grpc.WithChainUnaryInterceptor(interceptors.GRPCClientUnaryErrorInterceptor),
    69  		grpc.WithChainStreamInterceptor(interceptors.GRPCClientStreamErrorInterceptor),
    70  	)
    71  	require.NoError(t, err)
    72  	defer conn.Close()
    73  
    74  	client := proto.NewAuthServiceClient(conn)
    75  
    76  	t.Run("unary interceptor", func(t *testing.T) {
    77  		resp, err := client.Ping(context.Background(), &proto.PingRequest{})
    78  		assert.Nil(t, resp, "resp is non-nil")
    79  		assert.True(t, trace.IsNotFound(err), "trace.IsNotFound failed: err=%v (%T)", err, trace.Unwrap(err))
    80  		assert.Equal(t, "not found", err.Error())
    81  
    82  		var traceErr *trace.TraceErr
    83  		assert.False(t, errors.As(err, &traceErr), "client error should not include traces originating in the middleware")
    84  		var remoteErr *interceptors.RemoteError
    85  		assert.ErrorAs(t, err, &remoteErr, "Remote error is not marked as an interceptors.RemoteError")
    86  	})
    87  
    88  	t.Run("stream interceptor", func(t *testing.T) {
    89  		//nolint:staticcheck // SA1019. The specific stream used here doesn't matter.
    90  		stream, err := client.AddMFADevice(context.Background())
    91  		require.NoError(t, err)
    92  
    93  		// Give the server time to close the stream. This allows us to more
    94  		// consistently hit the io.EOF error.
    95  		time.Sleep(100 * time.Millisecond)
    96  
    97  		//nolint:staticcheck // SA1019. The specific stream used here doesn't matter.
    98  		sendErr := stream.Send(&proto.AddMFADeviceRequest{})
    99  
   100  		// Expect either a success (unlikely because of the Sleep) or an unwrapped
   101  		// io.EOF error (meaning the server errored and closed the stream).
   102  		// In either case, it is still safe to recv from the stream and check for
   103  		// the already exists error.
   104  		//nolint:errorlint //comparison != error comparison on purpose!
   105  		if sendErr != nil && sendErr != io.EOF {
   106  			t.Fatalf("Unexpected error: %q (%T)", sendErr, sendErr)
   107  		}
   108  
   109  		_, err = stream.Recv()
   110  		assert.True(t, trace.IsAlreadyExists(err), "trace.IsAlreadyExists failed: err=%v (%T)", err, trace.Unwrap(err))
   111  		assert.Equal(t, "already exists", err.Error())
   112  		var traceErr *trace.TraceErr
   113  		assert.False(t, errors.As(err, &traceErr), "client error should not include traces originating in the middleware")
   114  		assert.True(t, trace.IsAlreadyExists(err), "trace.IsAlreadyExists failed: err=%v (%T)", err, trace.Unwrap(err))
   115  		var remoteErr *interceptors.RemoteError
   116  		assert.ErrorAs(t, err, &remoteErr, "Remote error is not marked as an interceptors.RemoteError")
   117  	})
   118  }