github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/grpc/interceptors/mfa_test.go (about)

     1  // Copyright 2023 Gravitational, Inc
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package interceptors_test
    16  
    17  import (
    18  	"context"
    19  	"net"
    20  	"testing"
    21  
    22  	"github.com/gravitational/trace"
    23  	"github.com/stretchr/testify/assert"
    24  	"github.com/stretchr/testify/require"
    25  	"google.golang.org/grpc"
    26  	"google.golang.org/grpc/credentials"
    27  
    28  	"github.com/gravitational/teleport/api/client/proto"
    29  	"github.com/gravitational/teleport/api/mfa"
    30  	"github.com/gravitational/teleport/api/testhelpers/mtls"
    31  	"github.com/gravitational/teleport/api/utils/grpc/interceptors"
    32  )
    33  
    34  const (
    35  	otpTestCode         = "otp-test-code"
    36  	otpTestCodeReusable = "otp-test-code-reusable"
    37  )
    38  
    39  type mfaService struct {
    40  	allowReuse bool
    41  	proto.UnimplementedAuthServiceServer
    42  }
    43  
    44  func (s *mfaService) Ping(ctx context.Context, req *proto.PingRequest) (*proto.PingResponse, error) {
    45  	if err := s.verifyMFAFromContext(ctx); err != nil {
    46  		return nil, trace.Wrap(err)
    47  	}
    48  	return &proto.PingResponse{}, nil
    49  }
    50  
    51  func (s *mfaService) verifyMFAFromContext(ctx context.Context) error {
    52  	mfaResp, err := mfa.CredentialsFromContext(ctx)
    53  	if err != nil {
    54  		// (In production consider logging err, so we don't swallow it silently.)
    55  		return trace.Wrap(&mfa.ErrAdminActionMFARequired)
    56  	}
    57  
    58  	switch r := mfaResp.Response.(type) {
    59  	case *proto.MFAAuthenticateResponse_TOTP:
    60  		switch r.TOTP.Code {
    61  		case otpTestCode:
    62  			return nil
    63  		case otpTestCodeReusable:
    64  			if s.allowReuse {
    65  				return nil
    66  			}
    67  			fallthrough
    68  		default:
    69  			return trace.Wrap(&mfa.ErrAdminActionMFARequired)
    70  		}
    71  	default:
    72  		return trace.BadParameter("unexpected mfa response type %T", r)
    73  	}
    74  }
    75  
    76  // TestGRPCErrorWrapping tests the error wrapping capability of the client
    77  // and server unary and stream interceptors
    78  func TestRetryWithMFA(t *testing.T) {
    79  	t.Parallel()
    80  	ctx := context.Background()
    81  
    82  	mtlsConfig := mtls.NewConfig(t)
    83  	listener, err := net.Listen("tcp", "localhost:0")
    84  	require.NoError(t, err)
    85  
    86  	server := grpc.NewServer(
    87  		grpc.Creds(credentials.NewTLS(mtlsConfig.ServerTLS)),
    88  		grpc.ChainUnaryInterceptor(interceptors.GRPCServerUnaryErrorInterceptor),
    89  	)
    90  	proto.RegisterAuthServiceServer(server, &mfaService{})
    91  	go func() {
    92  		server.Serve(listener)
    93  	}()
    94  	defer server.Stop()
    95  
    96  	t.Run("without interceptor", func(t *testing.T) {
    97  		conn, err := grpc.Dial(
    98  			listener.Addr().String(),
    99  			grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)),
   100  			grpc.WithUnaryInterceptor(interceptors.GRPCClientUnaryErrorInterceptor),
   101  		)
   102  		require.NoError(t, err)
   103  		defer conn.Close()
   104  
   105  		client := proto.NewAuthServiceClient(conn)
   106  		_, err = client.Ping(context.Background(), &proto.PingRequest{})
   107  		assert.ErrorIs(t, err, &mfa.ErrAdminActionMFARequired, "Ping error mismatch")
   108  	})
   109  
   110  	okMFACeremony := func(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
   111  		return &proto.MFAAuthenticateResponse{
   112  			Response: &proto.MFAAuthenticateResponse_TOTP{
   113  				TOTP: &proto.TOTPResponse{
   114  					Code: otpTestCode,
   115  				},
   116  			},
   117  		}, nil
   118  	}
   119  
   120  	mfaCeremonyErr := trace.BadParameter("client does not support mfa")
   121  	nokMFACeremony := func(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
   122  		return nil, mfaCeremonyErr
   123  	}
   124  
   125  	t.Run("with interceptor", func(t *testing.T) {
   126  		t.Run("ok mfa ceremony", func(t *testing.T) {
   127  			conn, err := grpc.Dial(
   128  				listener.Addr().String(),
   129  				grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)),
   130  				grpc.WithChainUnaryInterceptor(
   131  					interceptors.WithMFAUnaryInterceptor(okMFACeremony),
   132  					interceptors.GRPCClientUnaryErrorInterceptor,
   133  				),
   134  			)
   135  			require.NoError(t, err)
   136  			defer conn.Close()
   137  
   138  			client := proto.NewAuthServiceClient(conn)
   139  			_, err = client.Ping(ctx, &proto.PingRequest{})
   140  			assert.NoError(t, err)
   141  		})
   142  
   143  		t.Run("nok mfa ceremony", func(t *testing.T) {
   144  			conn, err := grpc.Dial(
   145  				listener.Addr().String(),
   146  				grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)),
   147  				grpc.WithChainUnaryInterceptor(
   148  					interceptors.WithMFAUnaryInterceptor(nokMFACeremony),
   149  					interceptors.GRPCClientUnaryErrorInterceptor,
   150  				),
   151  			)
   152  			require.NoError(t, err)
   153  			defer conn.Close()
   154  
   155  			client := proto.NewAuthServiceClient(conn)
   156  			_, err = client.Ping(ctx, &proto.PingRequest{})
   157  			assert.ErrorIs(t, err, &mfa.ErrAdminActionMFARequired, "Ping error mismatch")
   158  			assert.ErrorIs(t, err, mfaCeremonyErr, "Ping error mismatch")
   159  		})
   160  
   161  		t.Run("ok mfa in context", func(t *testing.T) {
   162  			conn, err := grpc.Dial(
   163  				listener.Addr().String(),
   164  				grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)),
   165  				grpc.WithChainUnaryInterceptor(
   166  					interceptors.WithMFAUnaryInterceptor(nokMFACeremony),
   167  					interceptors.GRPCClientUnaryErrorInterceptor,
   168  				),
   169  			)
   170  			require.NoError(t, err)
   171  			defer conn.Close()
   172  
   173  			mfaResp, _ := okMFACeremony(ctx, nil)
   174  			ctx := mfa.ContextWithMFAResponse(ctx, mfaResp)
   175  
   176  			client := proto.NewAuthServiceClient(conn)
   177  			_, err = client.Ping(ctx, &proto.PingRequest{})
   178  			assert.NoError(t, err)
   179  		})
   180  	})
   181  }
   182  
   183  func TestRetryWithMFA_Reuse(t *testing.T) {
   184  	t.Parallel()
   185  	ctx := context.Background()
   186  
   187  	mtlsConfig := mtls.NewConfig(t)
   188  	listener, err := net.Listen("tcp", "localhost:0")
   189  	require.NoError(t, err)
   190  
   191  	mfaService := &mfaService{}
   192  	server := grpc.NewServer(
   193  		grpc.Creds(credentials.NewTLS(mtlsConfig.ServerTLS)),
   194  		grpc.ChainUnaryInterceptor(interceptors.GRPCServerUnaryErrorInterceptor),
   195  	)
   196  	proto.RegisterAuthServiceServer(server, mfaService)
   197  	go func() {
   198  		server.Serve(listener)
   199  	}()
   200  	defer server.Stop()
   201  
   202  	okMFACeremony := func(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
   203  		return &proto.MFAAuthenticateResponse{
   204  			Response: &proto.MFAAuthenticateResponse_TOTP{
   205  				TOTP: &proto.TOTPResponse{
   206  					Code: otpTestCode,
   207  				},
   208  			},
   209  		}, nil
   210  	}
   211  
   212  	okMFACeremonyAllowReuse := func(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
   213  		return &proto.MFAAuthenticateResponse{
   214  			Response: &proto.MFAAuthenticateResponse_TOTP{
   215  				TOTP: &proto.TOTPResponse{
   216  					Code: otpTestCodeReusable,
   217  				},
   218  			},
   219  		}, nil
   220  	}
   221  
   222  	t.Run("ok allow reuse", func(t *testing.T) {
   223  		mfaService.allowReuse = true
   224  		conn, err := grpc.Dial(
   225  			listener.Addr().String(),
   226  			grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)),
   227  			grpc.WithChainUnaryInterceptor(
   228  				interceptors.WithMFAUnaryInterceptor(okMFACeremonyAllowReuse),
   229  				interceptors.GRPCClientUnaryErrorInterceptor,
   230  			),
   231  		)
   232  		require.NoError(t, err)
   233  		defer conn.Close()
   234  
   235  		client := proto.NewAuthServiceClient(conn)
   236  		_, err = client.Ping(ctx, &proto.PingRequest{})
   237  		assert.NoError(t, err)
   238  	})
   239  
   240  	t.Run("nok disallow reuse", func(t *testing.T) {
   241  		mfaService.allowReuse = false
   242  		conn, err := grpc.Dial(
   243  			listener.Addr().String(),
   244  			grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)),
   245  			grpc.WithChainUnaryInterceptor(
   246  				interceptors.WithMFAUnaryInterceptor(okMFACeremonyAllowReuse),
   247  				interceptors.GRPCClientUnaryErrorInterceptor,
   248  			),
   249  		)
   250  		require.NoError(t, err)
   251  		defer conn.Close()
   252  
   253  		client := proto.NewAuthServiceClient(conn)
   254  		_, err = client.Ping(ctx, &proto.PingRequest{})
   255  		assert.ErrorIs(t, err, &mfa.ErrAdminActionMFARequired, "Ping error mismatch")
   256  	})
   257  
   258  	t.Run("ok disallow reuse, retry with one-shot mfa", func(t *testing.T) {
   259  		mfaService.allowReuse = false
   260  		conn, err := grpc.Dial(
   261  			listener.Addr().String(),
   262  			grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)),
   263  			grpc.WithChainUnaryInterceptor(
   264  				interceptors.WithMFAUnaryInterceptor(okMFACeremony),
   265  				interceptors.GRPCClientUnaryErrorInterceptor,
   266  			),
   267  		)
   268  		require.NoError(t, err)
   269  		defer conn.Close()
   270  
   271  		// Pass reusable MFA through the context. The interceptor should
   272  		// catch the resulting ErrAdminActionMFARequired and retry with
   273  		// a one-shot mfa challenge.
   274  		mfaResp, _ := okMFACeremony(ctx, nil)
   275  		ctx := mfa.ContextWithMFAResponse(ctx, mfaResp)
   276  
   277  		client := proto.NewAuthServiceClient(conn)
   278  		_, err = client.Ping(ctx, &proto.PingRequest{})
   279  		assert.NoError(t, err)
   280  	})
   281  }