github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/joinservice_test.go (about)

     1  /*
     2  Copyright 2021 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package client
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"net"
    23  	"testing"
    24  
    25  	"github.com/google/go-cmp/cmp"
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  	"google.golang.org/grpc"
    29  	"google.golang.org/grpc/credentials/insecure"
    30  	"google.golang.org/grpc/test/bufconn"
    31  
    32  	"github.com/gravitational/teleport/api/client/proto"
    33  	"github.com/gravitational/teleport/api/types"
    34  )
    35  
    36  type mockJoinServiceServer struct {
    37  	*proto.UnimplementedJoinServiceServer
    38  	registerUsingTPMMethod func(srv proto.JoinService_RegisterUsingTPMMethodServer) error
    39  }
    40  
    41  func (m *mockJoinServiceServer) RegisterUsingTPMMethod(srv proto.JoinService_RegisterUsingTPMMethodServer) error {
    42  	return m.registerUsingTPMMethod(srv)
    43  }
    44  
    45  func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) {
    46  	t.Parallel()
    47  
    48  	ctx, cancel := context.WithCancel(context.Background())
    49  	t.Cleanup(cancel)
    50  
    51  	lis := bufconn.Listen(100)
    52  	t.Cleanup(func() {
    53  		assert.NoError(t, lis.Close())
    54  	})
    55  
    56  	mockInitReq := &proto.RegisterUsingTPMMethodInitialRequest{
    57  		JoinRequest: &types.RegisterUsingTokenRequest{
    58  			Token: "token",
    59  		},
    60  	}
    61  	mockChallenge := &proto.TPMEncryptedCredential{
    62  		CredentialBlob: []byte("cred-blob"),
    63  		Secret:         []byte("secret"),
    64  	}
    65  	mockChallengeResp := &proto.RegisterUsingTPMMethodChallengeResponse{
    66  		Solution: []byte("solution"),
    67  	}
    68  	mockCerts := &proto.Certs{
    69  		TLS: []byte("cert"),
    70  	}
    71  	mockService := &mockJoinServiceServer{
    72  		registerUsingTPMMethod: func(srv proto.JoinService_RegisterUsingTPMMethodServer) error {
    73  			req, err := srv.Recv()
    74  			if !assert.NoError(t, err) {
    75  				return err
    76  			}
    77  			assert.Empty(t, cmp.Diff(req.GetInit(), mockInitReq))
    78  
    79  			err = srv.Send(&proto.RegisterUsingTPMMethodResponse{
    80  				Payload: &proto.RegisterUsingTPMMethodResponse_ChallengeRequest{
    81  					ChallengeRequest: mockChallenge,
    82  				},
    83  			})
    84  			if !assert.NoError(t, err) {
    85  				return err
    86  			}
    87  
    88  			req, err = srv.Recv()
    89  			if !assert.NoError(t, err) {
    90  				return err
    91  			}
    92  			assert.Empty(t, cmp.Diff(req.GetChallengeResponse(), mockChallengeResp))
    93  
    94  			err = srv.Send(&proto.RegisterUsingTPMMethodResponse{
    95  				Payload: &proto.RegisterUsingTPMMethodResponse_Certs{
    96  					Certs: mockCerts,
    97  				},
    98  			})
    99  			if !assert.NoError(t, err) {
   100  				return err
   101  			}
   102  			return nil
   103  		},
   104  	}
   105  	srv := grpc.NewServer()
   106  	t.Cleanup(func() {
   107  		srv.Stop()
   108  	})
   109  	proto.RegisterJoinServiceServer(srv, mockService)
   110  
   111  	go func() {
   112  		err := srv.Serve(lis)
   113  		if err != nil && !errors.Is(err, grpc.ErrServerStopped) {
   114  			assert.NoError(t, err)
   115  		}
   116  		cancel()
   117  	}()
   118  
   119  	c, err := grpc.NewClient("unused.com", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
   120  		return lis.DialContext(ctx)
   121  	}), grpc.WithTransportCredentials(insecure.NewCredentials()))
   122  	require.NoError(t, err)
   123  
   124  	joinClient := NewJoinServiceClient(proto.NewJoinServiceClient(c))
   125  
   126  	certs, err := joinClient.RegisterUsingTPMMethod(
   127  		ctx,
   128  		mockInitReq,
   129  		func(challenge *proto.TPMEncryptedCredential) (*proto.RegisterUsingTPMMethodChallengeResponse, error) {
   130  			assert.Empty(t, cmp.Diff(mockChallenge, challenge))
   131  			return mockChallengeResp, nil
   132  		},
   133  	)
   134  	if assert.NoError(t, err) {
   135  		assert.Empty(t, cmp.Diff(mockCerts, certs))
   136  	}
   137  }