github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/grpc/interceptors/mfa_test.go (about) 1 // Copyright 2023 Gravitational, Inc 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package interceptors_test 16 17 import ( 18 "context" 19 "net" 20 "testing" 21 22 "github.com/gravitational/trace" 23 "github.com/stretchr/testify/assert" 24 "github.com/stretchr/testify/require" 25 "google.golang.org/grpc" 26 "google.golang.org/grpc/credentials" 27 28 "github.com/gravitational/teleport/api/client/proto" 29 "github.com/gravitational/teleport/api/mfa" 30 "github.com/gravitational/teleport/api/testhelpers/mtls" 31 "github.com/gravitational/teleport/api/utils/grpc/interceptors" 32 ) 33 34 const ( 35 otpTestCode = "otp-test-code" 36 otpTestCodeReusable = "otp-test-code-reusable" 37 ) 38 39 type mfaService struct { 40 allowReuse bool 41 proto.UnimplementedAuthServiceServer 42 } 43 44 func (s *mfaService) Ping(ctx context.Context, req *proto.PingRequest) (*proto.PingResponse, error) { 45 if err := s.verifyMFAFromContext(ctx); err != nil { 46 return nil, trace.Wrap(err) 47 } 48 return &proto.PingResponse{}, nil 49 } 50 51 func (s *mfaService) verifyMFAFromContext(ctx context.Context) error { 52 mfaResp, err := mfa.CredentialsFromContext(ctx) 53 if err != nil { 54 // (In production consider logging err, so we don't swallow it silently.) 55 return trace.Wrap(&mfa.ErrAdminActionMFARequired) 56 } 57 58 switch r := mfaResp.Response.(type) { 59 case *proto.MFAAuthenticateResponse_TOTP: 60 switch r.TOTP.Code { 61 case otpTestCode: 62 return nil 63 case otpTestCodeReusable: 64 if s.allowReuse { 65 return nil 66 } 67 fallthrough 68 default: 69 return trace.Wrap(&mfa.ErrAdminActionMFARequired) 70 } 71 default: 72 return trace.BadParameter("unexpected mfa response type %T", r) 73 } 74 } 75 76 // TestGRPCErrorWrapping tests the error wrapping capability of the client 77 // and server unary and stream interceptors 78 func TestRetryWithMFA(t *testing.T) { 79 t.Parallel() 80 ctx := context.Background() 81 82 mtlsConfig := mtls.NewConfig(t) 83 listener, err := net.Listen("tcp", "localhost:0") 84 require.NoError(t, err) 85 86 server := grpc.NewServer( 87 grpc.Creds(credentials.NewTLS(mtlsConfig.ServerTLS)), 88 grpc.ChainUnaryInterceptor(interceptors.GRPCServerUnaryErrorInterceptor), 89 ) 90 proto.RegisterAuthServiceServer(server, &mfaService{}) 91 go func() { 92 server.Serve(listener) 93 }() 94 defer server.Stop() 95 96 t.Run("without interceptor", func(t *testing.T) { 97 conn, err := grpc.Dial( 98 listener.Addr().String(), 99 grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)), 100 grpc.WithUnaryInterceptor(interceptors.GRPCClientUnaryErrorInterceptor), 101 ) 102 require.NoError(t, err) 103 defer conn.Close() 104 105 client := proto.NewAuthServiceClient(conn) 106 _, err = client.Ping(context.Background(), &proto.PingRequest{}) 107 assert.ErrorIs(t, err, &mfa.ErrAdminActionMFARequired, "Ping error mismatch") 108 }) 109 110 okMFACeremony := func(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) { 111 return &proto.MFAAuthenticateResponse{ 112 Response: &proto.MFAAuthenticateResponse_TOTP{ 113 TOTP: &proto.TOTPResponse{ 114 Code: otpTestCode, 115 }, 116 }, 117 }, nil 118 } 119 120 mfaCeremonyErr := trace.BadParameter("client does not support mfa") 121 nokMFACeremony := func(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) { 122 return nil, mfaCeremonyErr 123 } 124 125 t.Run("with interceptor", func(t *testing.T) { 126 t.Run("ok mfa ceremony", func(t *testing.T) { 127 conn, err := grpc.Dial( 128 listener.Addr().String(), 129 grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)), 130 grpc.WithChainUnaryInterceptor( 131 interceptors.WithMFAUnaryInterceptor(okMFACeremony), 132 interceptors.GRPCClientUnaryErrorInterceptor, 133 ), 134 ) 135 require.NoError(t, err) 136 defer conn.Close() 137 138 client := proto.NewAuthServiceClient(conn) 139 _, err = client.Ping(ctx, &proto.PingRequest{}) 140 assert.NoError(t, err) 141 }) 142 143 t.Run("nok mfa ceremony", func(t *testing.T) { 144 conn, err := grpc.Dial( 145 listener.Addr().String(), 146 grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)), 147 grpc.WithChainUnaryInterceptor( 148 interceptors.WithMFAUnaryInterceptor(nokMFACeremony), 149 interceptors.GRPCClientUnaryErrorInterceptor, 150 ), 151 ) 152 require.NoError(t, err) 153 defer conn.Close() 154 155 client := proto.NewAuthServiceClient(conn) 156 _, err = client.Ping(ctx, &proto.PingRequest{}) 157 assert.ErrorIs(t, err, &mfa.ErrAdminActionMFARequired, "Ping error mismatch") 158 assert.ErrorIs(t, err, mfaCeremonyErr, "Ping error mismatch") 159 }) 160 161 t.Run("ok mfa in context", func(t *testing.T) { 162 conn, err := grpc.Dial( 163 listener.Addr().String(), 164 grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)), 165 grpc.WithChainUnaryInterceptor( 166 interceptors.WithMFAUnaryInterceptor(nokMFACeremony), 167 interceptors.GRPCClientUnaryErrorInterceptor, 168 ), 169 ) 170 require.NoError(t, err) 171 defer conn.Close() 172 173 mfaResp, _ := okMFACeremony(ctx, nil) 174 ctx := mfa.ContextWithMFAResponse(ctx, mfaResp) 175 176 client := proto.NewAuthServiceClient(conn) 177 _, err = client.Ping(ctx, &proto.PingRequest{}) 178 assert.NoError(t, err) 179 }) 180 }) 181 } 182 183 func TestRetryWithMFA_Reuse(t *testing.T) { 184 t.Parallel() 185 ctx := context.Background() 186 187 mtlsConfig := mtls.NewConfig(t) 188 listener, err := net.Listen("tcp", "localhost:0") 189 require.NoError(t, err) 190 191 mfaService := &mfaService{} 192 server := grpc.NewServer( 193 grpc.Creds(credentials.NewTLS(mtlsConfig.ServerTLS)), 194 grpc.ChainUnaryInterceptor(interceptors.GRPCServerUnaryErrorInterceptor), 195 ) 196 proto.RegisterAuthServiceServer(server, mfaService) 197 go func() { 198 server.Serve(listener) 199 }() 200 defer server.Stop() 201 202 okMFACeremony := func(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) { 203 return &proto.MFAAuthenticateResponse{ 204 Response: &proto.MFAAuthenticateResponse_TOTP{ 205 TOTP: &proto.TOTPResponse{ 206 Code: otpTestCode, 207 }, 208 }, 209 }, nil 210 } 211 212 okMFACeremonyAllowReuse := func(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) { 213 return &proto.MFAAuthenticateResponse{ 214 Response: &proto.MFAAuthenticateResponse_TOTP{ 215 TOTP: &proto.TOTPResponse{ 216 Code: otpTestCodeReusable, 217 }, 218 }, 219 }, nil 220 } 221 222 t.Run("ok allow reuse", func(t *testing.T) { 223 mfaService.allowReuse = true 224 conn, err := grpc.Dial( 225 listener.Addr().String(), 226 grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)), 227 grpc.WithChainUnaryInterceptor( 228 interceptors.WithMFAUnaryInterceptor(okMFACeremonyAllowReuse), 229 interceptors.GRPCClientUnaryErrorInterceptor, 230 ), 231 ) 232 require.NoError(t, err) 233 defer conn.Close() 234 235 client := proto.NewAuthServiceClient(conn) 236 _, err = client.Ping(ctx, &proto.PingRequest{}) 237 assert.NoError(t, err) 238 }) 239 240 t.Run("nok disallow reuse", func(t *testing.T) { 241 mfaService.allowReuse = false 242 conn, err := grpc.Dial( 243 listener.Addr().String(), 244 grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)), 245 grpc.WithChainUnaryInterceptor( 246 interceptors.WithMFAUnaryInterceptor(okMFACeremonyAllowReuse), 247 interceptors.GRPCClientUnaryErrorInterceptor, 248 ), 249 ) 250 require.NoError(t, err) 251 defer conn.Close() 252 253 client := proto.NewAuthServiceClient(conn) 254 _, err = client.Ping(ctx, &proto.PingRequest{}) 255 assert.ErrorIs(t, err, &mfa.ErrAdminActionMFARequired, "Ping error mismatch") 256 }) 257 258 t.Run("ok disallow reuse, retry with one-shot mfa", func(t *testing.T) { 259 mfaService.allowReuse = false 260 conn, err := grpc.Dial( 261 listener.Addr().String(), 262 grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)), 263 grpc.WithChainUnaryInterceptor( 264 interceptors.WithMFAUnaryInterceptor(okMFACeremony), 265 interceptors.GRPCClientUnaryErrorInterceptor, 266 ), 267 ) 268 require.NoError(t, err) 269 defer conn.Close() 270 271 // Pass reusable MFA through the context. The interceptor should 272 // catch the resulting ErrAdminActionMFARequired and retry with 273 // a one-shot mfa challenge. 274 mfaResp, _ := okMFACeremony(ctx, nil) 275 ctx := mfa.ContextWithMFAResponse(ctx, mfaResp) 276 277 client := proto.NewAuthServiceClient(conn) 278 _, err = client.Ping(ctx, &proto.PingRequest{}) 279 assert.NoError(t, err) 280 }) 281 }