github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/joinservice_test.go (about) 1 /* 2 Copyright 2021 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package client 18 19 import ( 20 "context" 21 "errors" 22 "net" 23 "testing" 24 25 "github.com/google/go-cmp/cmp" 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 "google.golang.org/grpc" 29 "google.golang.org/grpc/credentials/insecure" 30 "google.golang.org/grpc/test/bufconn" 31 32 "github.com/gravitational/teleport/api/client/proto" 33 "github.com/gravitational/teleport/api/types" 34 ) 35 36 type mockJoinServiceServer struct { 37 *proto.UnimplementedJoinServiceServer 38 registerUsingTPMMethod func(srv proto.JoinService_RegisterUsingTPMMethodServer) error 39 } 40 41 func (m *mockJoinServiceServer) RegisterUsingTPMMethod(srv proto.JoinService_RegisterUsingTPMMethodServer) error { 42 return m.registerUsingTPMMethod(srv) 43 } 44 45 func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) { 46 t.Parallel() 47 48 ctx, cancel := context.WithCancel(context.Background()) 49 t.Cleanup(cancel) 50 51 lis := bufconn.Listen(100) 52 t.Cleanup(func() { 53 assert.NoError(t, lis.Close()) 54 }) 55 56 mockInitReq := &proto.RegisterUsingTPMMethodInitialRequest{ 57 JoinRequest: &types.RegisterUsingTokenRequest{ 58 Token: "token", 59 }, 60 } 61 mockChallenge := &proto.TPMEncryptedCredential{ 62 CredentialBlob: []byte("cred-blob"), 63 Secret: []byte("secret"), 64 } 65 mockChallengeResp := &proto.RegisterUsingTPMMethodChallengeResponse{ 66 Solution: []byte("solution"), 67 } 68 mockCerts := &proto.Certs{ 69 TLS: []byte("cert"), 70 } 71 mockService := &mockJoinServiceServer{ 72 registerUsingTPMMethod: func(srv proto.JoinService_RegisterUsingTPMMethodServer) error { 73 req, err := srv.Recv() 74 if !assert.NoError(t, err) { 75 return err 76 } 77 assert.Empty(t, cmp.Diff(req.GetInit(), mockInitReq)) 78 79 err = srv.Send(&proto.RegisterUsingTPMMethodResponse{ 80 Payload: &proto.RegisterUsingTPMMethodResponse_ChallengeRequest{ 81 ChallengeRequest: mockChallenge, 82 }, 83 }) 84 if !assert.NoError(t, err) { 85 return err 86 } 87 88 req, err = srv.Recv() 89 if !assert.NoError(t, err) { 90 return err 91 } 92 assert.Empty(t, cmp.Diff(req.GetChallengeResponse(), mockChallengeResp)) 93 94 err = srv.Send(&proto.RegisterUsingTPMMethodResponse{ 95 Payload: &proto.RegisterUsingTPMMethodResponse_Certs{ 96 Certs: mockCerts, 97 }, 98 }) 99 if !assert.NoError(t, err) { 100 return err 101 } 102 return nil 103 }, 104 } 105 srv := grpc.NewServer() 106 t.Cleanup(func() { 107 srv.Stop() 108 }) 109 proto.RegisterJoinServiceServer(srv, mockService) 110 111 go func() { 112 err := srv.Serve(lis) 113 if err != nil && !errors.Is(err, grpc.ErrServerStopped) { 114 assert.NoError(t, err) 115 } 116 cancel() 117 }() 118 119 c, err := grpc.NewClient("unused.com", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { 120 return lis.DialContext(ctx) 121 }), grpc.WithTransportCredentials(insecure.NewCredentials())) 122 require.NoError(t, err) 123 124 joinClient := NewJoinServiceClient(proto.NewJoinServiceClient(c)) 125 126 certs, err := joinClient.RegisterUsingTPMMethod( 127 ctx, 128 mockInitReq, 129 func(challenge *proto.TPMEncryptedCredential) (*proto.RegisterUsingTPMMethodChallengeResponse, error) { 130 assert.Empty(t, cmp.Diff(mockChallenge, challenge)) 131 return mockChallengeResp, nil 132 }, 133 ) 134 if assert.NoError(t, err) { 135 assert.Empty(t, cmp.Diff(mockCerts, certs)) 136 } 137 }