github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/joinservice.go (about) 1 /* 2 Copyright 2022 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package client 18 19 import ( 20 "context" 21 22 "github.com/gravitational/trace" 23 24 "github.com/gravitational/teleport/api/client/proto" 25 ) 26 27 // JoinServiceClient is a client for the JoinService, which runs on both the 28 // auth and proxy. 29 type JoinServiceClient struct { 30 grpcClient proto.JoinServiceClient 31 } 32 33 // NewJoinServiceClient returns a new JoinServiceClient wrapping the given grpc 34 // client. 35 func NewJoinServiceClient(grpcClient proto.JoinServiceClient) *JoinServiceClient { 36 return &JoinServiceClient{ 37 grpcClient: grpcClient, 38 } 39 } 40 41 // RegisterIAMChallengeResponseFunc is a function type meant to be passed to 42 // RegisterUsingIAMMethod. It must return a *proto.RegisterUsingIAMMethodRequest 43 // for a given challenge, or an error. 44 type RegisterIAMChallengeResponseFunc func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) 45 46 // RegisterAzureChallengeResponseFunc is a function type meant to be passed to 47 // RegisterUsingAzureMethod. It must return a 48 // *proto.RegisterUsingAzureMethodRequest for a given challenge, or an error. 49 type RegisterAzureChallengeResponseFunc func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) 50 51 // RegisterTPMChallengeResponseFunc is a function type meant to be passed to 52 // RegisterUsingTPMMethod. It must return a 53 // *proto.RegisterUsingTPMMethodChallengeResponse for a given challenge, or an 54 // error. 55 type RegisterTPMChallengeResponseFunc func(challenge *proto.TPMEncryptedCredential) (*proto.RegisterUsingTPMMethodChallengeResponse, error) 56 57 // RegisterUsingIAMMethod registers the caller using the IAM join method and 58 // returns signed certs to join the cluster. 59 // 60 // The caller must provide a ChallengeResponseFunc which returns a 61 // *types.RegisterUsingTokenRequest with a signed sts:GetCallerIdentity request 62 // including the challenge as a signed header. 63 func (c *JoinServiceClient) RegisterUsingIAMMethod(ctx context.Context, challengeResponse RegisterIAMChallengeResponseFunc) (*proto.Certs, error) { 64 // Make sure the gRPC stream is closed when this returns 65 ctx, cancel := context.WithCancel(ctx) 66 defer cancel() 67 68 // initiate the streaming rpc 69 iamJoinClient, err := c.grpcClient.RegisterUsingIAMMethod(ctx) 70 if err != nil { 71 return nil, trace.Wrap(err) 72 } 73 74 // wait for the challenge string from auth 75 challenge, err := iamJoinClient.Recv() 76 if err != nil { 77 return nil, trace.Wrap(err) 78 } 79 80 // get challenge response from the caller 81 req, err := challengeResponse(challenge.Challenge) 82 if err != nil { 83 return nil, trace.Wrap(err) 84 } 85 86 // forward the challenge response from the caller to auth 87 if err := iamJoinClient.Send(req); err != nil { 88 return nil, trace.Wrap(err) 89 } 90 91 // wait for the certs from auth and return to the caller 92 certsResp, err := iamJoinClient.Recv() 93 if err != nil { 94 return nil, trace.Wrap(err) 95 } 96 return certsResp.Certs, nil 97 } 98 99 // RegisterUsingAzureMethod registers the caller using the Azure join method and 100 // returns signed certs to join the cluster. 101 // 102 // The caller must provide a ChallengeResponseFunc which returns a 103 // *proto.RegisterUsingAzureMethodRequest with a signed attested data document 104 // including the challenge as a nonce. 105 func (c *JoinServiceClient) RegisterUsingAzureMethod(ctx context.Context, challengeResponse RegisterAzureChallengeResponseFunc) (*proto.Certs, error) { 106 ctx, cancel := context.WithCancel(ctx) 107 defer cancel() 108 109 azureJoinClient, err := c.grpcClient.RegisterUsingAzureMethod(ctx) 110 if err != nil { 111 return nil, trace.Wrap(err) 112 } 113 114 challenge, err := azureJoinClient.Recv() 115 if err != nil { 116 return nil, trace.Wrap(err) 117 } 118 119 req, err := challengeResponse(challenge.Challenge) 120 if err != nil { 121 return nil, trace.Wrap(err) 122 } 123 124 if err := azureJoinClient.Send(req); err != nil { 125 return nil, trace.Wrap(err) 126 } 127 128 certsResp, err := azureJoinClient.Recv() 129 if err != nil { 130 return nil, trace.Wrap(err) 131 } 132 return certsResp.Certs, nil 133 } 134 135 // RegisterUsingTPMMethod registers the caller using the TPM join method and 136 // returns signed certs to join the cluster. The caller must provide a 137 // ChallengeResponseFunc which returns a *proto.RegisterUsingTPMMethodRequest 138 // for a given challenge, or an error. 139 func (c *JoinServiceClient) RegisterUsingTPMMethod( 140 ctx context.Context, 141 initReq *proto.RegisterUsingTPMMethodInitialRequest, 142 solveChallenge RegisterTPMChallengeResponseFunc, 143 ) (*proto.Certs, error) { 144 ctx, cancel := context.WithCancel(ctx) 145 defer cancel() 146 147 stream, err := c.grpcClient.RegisterUsingTPMMethod(ctx) 148 if err != nil { 149 return nil, trace.Wrap(err) 150 } 151 defer stream.CloseSend() 152 153 err = stream.Send(&proto.RegisterUsingTPMMethodRequest{ 154 Payload: &proto.RegisterUsingTPMMethodRequest_Init{ 155 Init: initReq, 156 }, 157 }) 158 if err != nil { 159 return nil, trace.Wrap(err, "sending initial request") 160 } 161 162 res, err := stream.Recv() 163 if err != nil { 164 return nil, trace.Wrap(err, "receiving challenge") 165 } 166 167 challenge := res.GetChallengeRequest() 168 if challenge == nil { 169 return nil, trace.BadParameter( 170 "expected ChallengeRequest payload, got %T", 171 res.Payload, 172 ) 173 } 174 175 solution, err := solveChallenge(challenge) 176 if err != nil { 177 return nil, trace.Wrap(err, "solving challenge") 178 } 179 180 err = stream.Send(&proto.RegisterUsingTPMMethodRequest{ 181 Payload: &proto.RegisterUsingTPMMethodRequest_ChallengeResponse{ 182 ChallengeResponse: solution, 183 }, 184 }) 185 if err != nil { 186 return nil, trace.Wrap(err, "sending solution") 187 } 188 189 res, err = stream.Recv() 190 if err != nil { 191 return nil, trace.Wrap(err, "receiving certs") 192 } 193 certs := res.GetCerts() 194 if certs == nil { 195 return nil, trace.BadParameter( 196 "expected Certs payload, got %T", 197 res.Payload, 198 ) 199 } 200 201 return certs, nil 202 }