github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/joinservice.go (about)

     1  /*
     2  Copyright 2022 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package client
    18  
    19  import (
    20  	"context"
    21  
    22  	"github.com/gravitational/trace"
    23  
    24  	"github.com/gravitational/teleport/api/client/proto"
    25  )
    26  
    27  // JoinServiceClient is a client for the JoinService, which runs on both the
    28  // auth and proxy.
    29  type JoinServiceClient struct {
    30  	grpcClient proto.JoinServiceClient
    31  }
    32  
    33  // NewJoinServiceClient returns a new JoinServiceClient wrapping the given grpc
    34  // client.
    35  func NewJoinServiceClient(grpcClient proto.JoinServiceClient) *JoinServiceClient {
    36  	return &JoinServiceClient{
    37  		grpcClient: grpcClient,
    38  	}
    39  }
    40  
    41  // RegisterIAMChallengeResponseFunc is a function type meant to be passed to
    42  // RegisterUsingIAMMethod. It must return a *proto.RegisterUsingIAMMethodRequest
    43  // for a given challenge, or an error.
    44  type RegisterIAMChallengeResponseFunc func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error)
    45  
    46  // RegisterAzureChallengeResponseFunc is a function type meant to be passed to
    47  // RegisterUsingAzureMethod. It must return a
    48  // *proto.RegisterUsingAzureMethodRequest for a given challenge, or an error.
    49  type RegisterAzureChallengeResponseFunc func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error)
    50  
    51  // RegisterTPMChallengeResponseFunc is a function type meant to be passed to
    52  // RegisterUsingTPMMethod. It must return a
    53  // *proto.RegisterUsingTPMMethodChallengeResponse for a given challenge, or an
    54  // error.
    55  type RegisterTPMChallengeResponseFunc func(challenge *proto.TPMEncryptedCredential) (*proto.RegisterUsingTPMMethodChallengeResponse, error)
    56  
    57  // RegisterUsingIAMMethod registers the caller using the IAM join method and
    58  // returns signed certs to join the cluster.
    59  //
    60  // The caller must provide a ChallengeResponseFunc which returns a
    61  // *types.RegisterUsingTokenRequest with a signed sts:GetCallerIdentity request
    62  // including the challenge as a signed header.
    63  func (c *JoinServiceClient) RegisterUsingIAMMethod(ctx context.Context, challengeResponse RegisterIAMChallengeResponseFunc) (*proto.Certs, error) {
    64  	// Make sure the gRPC stream is closed when this returns
    65  	ctx, cancel := context.WithCancel(ctx)
    66  	defer cancel()
    67  
    68  	// initiate the streaming rpc
    69  	iamJoinClient, err := c.grpcClient.RegisterUsingIAMMethod(ctx)
    70  	if err != nil {
    71  		return nil, trace.Wrap(err)
    72  	}
    73  
    74  	// wait for the challenge string from auth
    75  	challenge, err := iamJoinClient.Recv()
    76  	if err != nil {
    77  		return nil, trace.Wrap(err)
    78  	}
    79  
    80  	// get challenge response from the caller
    81  	req, err := challengeResponse(challenge.Challenge)
    82  	if err != nil {
    83  		return nil, trace.Wrap(err)
    84  	}
    85  
    86  	// forward the challenge response from the caller to auth
    87  	if err := iamJoinClient.Send(req); err != nil {
    88  		return nil, trace.Wrap(err)
    89  	}
    90  
    91  	// wait for the certs from auth and return to the caller
    92  	certsResp, err := iamJoinClient.Recv()
    93  	if err != nil {
    94  		return nil, trace.Wrap(err)
    95  	}
    96  	return certsResp.Certs, nil
    97  }
    98  
    99  // RegisterUsingAzureMethod registers the caller using the Azure join method and
   100  // returns signed certs to join the cluster.
   101  //
   102  // The caller must provide a ChallengeResponseFunc which returns a
   103  // *proto.RegisterUsingAzureMethodRequest with a signed attested data document
   104  // including the challenge as a nonce.
   105  func (c *JoinServiceClient) RegisterUsingAzureMethod(ctx context.Context, challengeResponse RegisterAzureChallengeResponseFunc) (*proto.Certs, error) {
   106  	ctx, cancel := context.WithCancel(ctx)
   107  	defer cancel()
   108  
   109  	azureJoinClient, err := c.grpcClient.RegisterUsingAzureMethod(ctx)
   110  	if err != nil {
   111  		return nil, trace.Wrap(err)
   112  	}
   113  
   114  	challenge, err := azureJoinClient.Recv()
   115  	if err != nil {
   116  		return nil, trace.Wrap(err)
   117  	}
   118  
   119  	req, err := challengeResponse(challenge.Challenge)
   120  	if err != nil {
   121  		return nil, trace.Wrap(err)
   122  	}
   123  
   124  	if err := azureJoinClient.Send(req); err != nil {
   125  		return nil, trace.Wrap(err)
   126  	}
   127  
   128  	certsResp, err := azureJoinClient.Recv()
   129  	if err != nil {
   130  		return nil, trace.Wrap(err)
   131  	}
   132  	return certsResp.Certs, nil
   133  }
   134  
   135  // RegisterUsingTPMMethod registers the caller using the TPM join method and
   136  // returns signed certs to join the cluster. The caller must provide a
   137  // ChallengeResponseFunc which returns a *proto.RegisterUsingTPMMethodRequest
   138  // for a given challenge, or an error.
   139  func (c *JoinServiceClient) RegisterUsingTPMMethod(
   140  	ctx context.Context,
   141  	initReq *proto.RegisterUsingTPMMethodInitialRequest,
   142  	solveChallenge RegisterTPMChallengeResponseFunc,
   143  ) (*proto.Certs, error) {
   144  	ctx, cancel := context.WithCancel(ctx)
   145  	defer cancel()
   146  
   147  	stream, err := c.grpcClient.RegisterUsingTPMMethod(ctx)
   148  	if err != nil {
   149  		return nil, trace.Wrap(err)
   150  	}
   151  	defer stream.CloseSend()
   152  
   153  	err = stream.Send(&proto.RegisterUsingTPMMethodRequest{
   154  		Payload: &proto.RegisterUsingTPMMethodRequest_Init{
   155  			Init: initReq,
   156  		},
   157  	})
   158  	if err != nil {
   159  		return nil, trace.Wrap(err, "sending initial request")
   160  	}
   161  
   162  	res, err := stream.Recv()
   163  	if err != nil {
   164  		return nil, trace.Wrap(err, "receiving challenge")
   165  	}
   166  
   167  	challenge := res.GetChallengeRequest()
   168  	if challenge == nil {
   169  		return nil, trace.BadParameter(
   170  			"expected ChallengeRequest payload, got %T",
   171  			res.Payload,
   172  		)
   173  	}
   174  
   175  	solution, err := solveChallenge(challenge)
   176  	if err != nil {
   177  		return nil, trace.Wrap(err, "solving challenge")
   178  	}
   179  
   180  	err = stream.Send(&proto.RegisterUsingTPMMethodRequest{
   181  		Payload: &proto.RegisterUsingTPMMethodRequest_ChallengeResponse{
   182  			ChallengeResponse: solution,
   183  		},
   184  	})
   185  	if err != nil {
   186  		return nil, trace.Wrap(err, "sending solution")
   187  	}
   188  
   189  	res, err = stream.Recv()
   190  	if err != nil {
   191  		return nil, trace.Wrap(err, "receiving certs")
   192  	}
   193  	certs := res.GetCerts()
   194  	if certs == nil {
   195  		return nil, trace.BadParameter(
   196  			"expected Certs payload, got %T",
   197  			res.Payload,
   198  		)
   199  	}
   200  
   201  	return certs, nil
   202  }