google.golang.org/grpc@v1.74.2/credentials/alts/internal/handshaker/handshaker.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package handshaker provides ALTS handshaking functionality for GCP.
    20  package handshaker
    21  
    22  import (
    23  	"context"
    24  	"errors"
    25  	"fmt"
    26  	"io"
    27  	"net"
    28  	"time"
    29  
    30  	"golang.org/x/sync/semaphore"
    31  	grpc "google.golang.org/grpc"
    32  	"google.golang.org/grpc/codes"
    33  	"google.golang.org/grpc/credentials"
    34  	core "google.golang.org/grpc/credentials/alts/internal"
    35  	"google.golang.org/grpc/credentials/alts/internal/authinfo"
    36  	"google.golang.org/grpc/credentials/alts/internal/conn"
    37  	altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    38  	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    39  	"google.golang.org/grpc/internal/envconfig"
    40  )
    41  
    42  const (
    43  	// The maximum byte size of receive frames.
    44  	frameLimit              = 64 * 1024 // 64 KB
    45  	rekeyRecordProtocolName = "ALTSRP_GCM_AES128_REKEY"
    46  )
    47  
    48  var (
    49  	hsProtocol      = altspb.HandshakeProtocol_ALTS
    50  	appProtocols    = []string{"grpc"}
    51  	recordProtocols = []string{rekeyRecordProtocolName}
    52  	keyLength       = map[string]int{
    53  		rekeyRecordProtocolName: 44,
    54  	}
    55  	altsRecordFuncs = map[string]conn.ALTSRecordFunc{
    56  		// ALTS handshaker protocols.
    57  		rekeyRecordProtocolName: func(s core.Side, keyData []byte) (conn.ALTSRecordCrypto, error) {
    58  			return conn.NewAES128GCMRekey(s, keyData)
    59  		},
    60  	}
    61  	// control number of concurrent created (but not closed) handshakes.
    62  	clientHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
    63  	serverHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
    64  	// errOutOfBound occurs when the handshake service returns a consumed
    65  	// bytes value larger than the buffer that was passed to it originally.
    66  	errOutOfBound = errors.New("handshaker service consumed bytes value is out-of-bound")
    67  )
    68  
    69  func init() {
    70  	for protocol, f := range altsRecordFuncs {
    71  		if err := conn.RegisterProtocol(protocol, f); err != nil {
    72  			panic(err)
    73  		}
    74  	}
    75  }
    76  
    77  // ClientHandshakerOptions contains the client handshaker options that can
    78  // provided by the caller.
    79  type ClientHandshakerOptions struct {
    80  	// ClientIdentity is the handshaker client local identity.
    81  	ClientIdentity *altspb.Identity
    82  	// TargetName is the server service account name for secure name
    83  	// checking.
    84  	TargetName string
    85  	// TargetServiceAccounts contains a list of expected target service
    86  	// accounts. One of these accounts should match one of the accounts in
    87  	// the handshaker results. Otherwise, the handshake fails.
    88  	TargetServiceAccounts []string
    89  	// RPCVersions specifies the gRPC versions accepted by the client.
    90  	RPCVersions *altspb.RpcProtocolVersions
    91  	// BoundAccessToken is a bound access token to be sent to the server for authentication.
    92  	BoundAccessToken string
    93  }
    94  
    95  // ServerHandshakerOptions contains the server handshaker options that can
    96  // provided by the caller.
    97  type ServerHandshakerOptions struct {
    98  	// RPCVersions specifies the gRPC versions accepted by the server.
    99  	RPCVersions *altspb.RpcProtocolVersions
   100  }
   101  
   102  // DefaultClientHandshakerOptions returns the default client handshaker options.
   103  func DefaultClientHandshakerOptions() *ClientHandshakerOptions {
   104  	return &ClientHandshakerOptions{}
   105  }
   106  
   107  // DefaultServerHandshakerOptions returns the default client handshaker options.
   108  func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
   109  	return &ServerHandshakerOptions{}
   110  }
   111  
// altsHandshaker is used to complete an ALTS handshake between client and
// server. This handshaker talks to the ALTS handshaker service in the metadata
// server, relaying handshake bytes between that service and the peer.
//
// Instances are built by NewClientHandshaker or NewServerHandshaker; side
// records which constructor was used and therefore which of ClientHandshake
// or ServerHandshake may be called.
type altsHandshaker struct {
	// RPC stream used to access the ALTS Handshaker service. When nil, it is
	// opened on demand by ClientHandshake/ServerHandshake.
	stream altsgrpc.HandshakerService_DoHandshakeClient
	// the connection to the peer.
	conn net.Conn
	// a virtual connection to the ALTS handshaker service.
	clientConn *grpc.ClientConn
	// client handshake options; set only by NewClientHandshaker.
	clientOpts *ClientHandshakerOptions
	// server handshake options; set only by NewServerHandshaker.
	serverOpts *ServerHandshakerOptions
	// defines the side doing the handshake, client or server.
	side core.Side
}
   129  
   130  // NewClientHandshaker creates a core.Handshaker that performs a client-side
   131  // ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
   132  // service in the metadata server.
   133  func NewClientHandshaker(_ context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) {
   134  	return &altsHandshaker{
   135  		stream:     nil,
   136  		conn:       c,
   137  		clientConn: conn,
   138  		clientOpts: opts,
   139  		side:       core.ClientSide,
   140  	}, nil
   141  }
   142  
   143  // NewServerHandshaker creates a core.Handshaker that performs a server-side
   144  // ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
   145  // service in the metadata server.
   146  func NewServerHandshaker(_ context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) {
   147  	return &altsHandshaker{
   148  		stream:     nil,
   149  		conn:       c,
   150  		clientConn: conn,
   151  		serverOpts: opts,
   152  		side:       core.ServerSide,
   153  	}, nil
   154  }
   155  
   156  // ClientHandshake starts and completes a client ALTS handshake for GCP. Once
   157  // done, ClientHandshake returns a secure connection.
   158  func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
   159  	if err := clientHandshakes.Acquire(ctx, 1); err != nil {
   160  		return nil, nil, err
   161  	}
   162  	defer clientHandshakes.Release(1)
   163  
   164  	if h.side != core.ClientSide {
   165  		return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker")
   166  	}
   167  
   168  	// TODO(matthewstevenson88): Change unit tests to use public APIs so
   169  	// that h.stream can unconditionally be set based on h.clientConn.
   170  	if h.stream == nil {
   171  		stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
   172  		if err != nil {
   173  			return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
   174  		}
   175  		h.stream = stream
   176  	}
   177  
   178  	// Create target identities from service account list.
   179  	targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts))
   180  	for _, account := range h.clientOpts.TargetServiceAccounts {
   181  		targetIdentities = append(targetIdentities, &altspb.Identity{
   182  			IdentityOneof: &altspb.Identity_ServiceAccount{
   183  				ServiceAccount: account,
   184  			},
   185  		})
   186  	}
   187  	req := &altspb.HandshakerReq{
   188  		ReqOneof: &altspb.HandshakerReq_ClientStart{
   189  			ClientStart: &altspb.StartClientHandshakeReq{
   190  				HandshakeSecurityProtocol: hsProtocol,
   191  				ApplicationProtocols:      appProtocols,
   192  				RecordProtocols:           recordProtocols,
   193  				TargetIdentities:          targetIdentities,
   194  				LocalIdentity:             h.clientOpts.ClientIdentity,
   195  				TargetName:                h.clientOpts.TargetName,
   196  				RpcVersions:               h.clientOpts.RPCVersions,
   197  			},
   198  		},
   199  	}
   200  	if h.clientOpts.BoundAccessToken != "" {
   201  		req.GetClientStart().AccessToken = h.clientOpts.BoundAccessToken
   202  	}
   203  	conn, result, err := h.doHandshake(req)
   204  	if err != nil {
   205  		return nil, nil, err
   206  	}
   207  	authInfo := authinfo.New(result)
   208  	return conn, authInfo, nil
   209  }
   210  
   211  // ServerHandshake starts and completes a server ALTS handshake for GCP. Once
   212  // done, ServerHandshake returns a secure connection.
   213  func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
   214  	if err := serverHandshakes.Acquire(ctx, 1); err != nil {
   215  		return nil, nil, err
   216  	}
   217  	defer serverHandshakes.Release(1)
   218  
   219  	if h.side != core.ServerSide {
   220  		return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")
   221  	}
   222  
   223  	// TODO(matthewstevenson88): Change unit tests to use public APIs so
   224  	// that h.stream can unconditionally be set based on h.clientConn.
   225  	if h.stream == nil {
   226  		stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
   227  		if err != nil {
   228  			return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
   229  		}
   230  		h.stream = stream
   231  	}
   232  
   233  	p := make([]byte, frameLimit)
   234  	n, err := h.conn.Read(p)
   235  	if err != nil {
   236  		return nil, nil, err
   237  	}
   238  
   239  	// Prepare server parameters.
   240  	params := make(map[int32]*altspb.ServerHandshakeParameters)
   241  	params[int32(altspb.HandshakeProtocol_ALTS)] = &altspb.ServerHandshakeParameters{
   242  		RecordProtocols: recordProtocols,
   243  	}
   244  	req := &altspb.HandshakerReq{
   245  		ReqOneof: &altspb.HandshakerReq_ServerStart{
   246  			ServerStart: &altspb.StartServerHandshakeReq{
   247  				ApplicationProtocols: appProtocols,
   248  				HandshakeParameters:  params,
   249  				InBytes:              p[:n],
   250  				RpcVersions:          h.serverOpts.RPCVersions,
   251  			},
   252  		},
   253  	}
   254  
   255  	conn, result, err := h.doHandshake(req)
   256  	if err != nil {
   257  		return nil, nil, err
   258  	}
   259  	authInfo := authinfo.New(result)
   260  	return conn, authInfo, nil
   261  }
   262  
   263  func (h *altsHandshaker) doHandshake(req *altspb.HandshakerReq) (net.Conn, *altspb.HandshakerResult, error) {
   264  	resp, err := h.accessHandshakerService(req)
   265  	if err != nil {
   266  		return nil, nil, err
   267  	}
   268  	// Check of the returned status is an error.
   269  	if resp.GetStatus() != nil {
   270  		if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want {
   271  			return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
   272  		}
   273  	}
   274  
   275  	var extra []byte
   276  	if req.GetServerStart() != nil {
   277  		if resp.GetBytesConsumed() > uint32(len(req.GetServerStart().GetInBytes())) {
   278  			return nil, nil, errOutOfBound
   279  		}
   280  		extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():]
   281  	}
   282  	result, extra, err := h.processUntilDone(resp, extra)
   283  	if err != nil {
   284  		return nil, nil, err
   285  	}
   286  	// The handshaker returns a 128 bytes key. It should be truncated based
   287  	// on the returned record protocol.
   288  	keyLen, ok := keyLength[result.RecordProtocol]
   289  	if !ok {
   290  		return nil, nil, fmt.Errorf("unknown resulted record protocol %v", result.RecordProtocol)
   291  	}
   292  	sc, err := conn.NewConn(h.conn, h.side, result.GetRecordProtocol(), result.KeyData[:keyLen], extra)
   293  	if err != nil {
   294  		return nil, nil, err
   295  	}
   296  	return sc, result, nil
   297  }
   298  
   299  func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*altspb.HandshakerResp, error) {
   300  	if err := h.stream.Send(req); err != nil {
   301  		return nil, err
   302  	}
   303  	resp, err := h.stream.Recv()
   304  	if err != nil {
   305  		return nil, err
   306  	}
   307  	return resp, nil
   308  }
   309  
   310  // processUntilDone processes the handshake until the handshaker service returns
   311  // the results. Handshaker service takes care of frame parsing, so we read
   312  // whatever received from the network and send it to the handshaker service.
   313  func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
   314  	var lastWriteTime time.Time
   315  	buf := make([]byte, frameLimit)
   316  	for {
   317  		if len(resp.OutFrames) > 0 {
   318  			lastWriteTime = time.Now()
   319  			if _, err := h.conn.Write(resp.OutFrames); err != nil {
   320  				return nil, nil, err
   321  			}
   322  		}
   323  		if resp.Result != nil {
   324  			return resp.Result, extra, nil
   325  		}
   326  		n, err := h.conn.Read(buf)
   327  		if err != nil && err != io.EOF {
   328  			return nil, nil, err
   329  		}
   330  		// If there is nothing to send to the handshaker service, and
   331  		// nothing is received from the peer, then we are stuck.
   332  		// This covers the case when the peer is not responding. Note
   333  		// that handshaker service connection issues are caught in
   334  		// accessHandshakerService before we even get here.
   335  		if len(resp.OutFrames) == 0 && n == 0 {
   336  			return nil, nil, core.PeerNotRespondingError
   337  		}
   338  		// Append extra bytes from the previous interaction with the
   339  		// handshaker service with the current buffer read from conn.
   340  		p := append(extra, buf[:n]...)
   341  		// Compute the time elapsed since the last write to the peer.
   342  		timeElapsed := time.Since(lastWriteTime)
   343  		timeElapsedMs := uint32(timeElapsed.Milliseconds())
   344  		// From here on, p and extra point to the same slice.
   345  		resp, err = h.accessHandshakerService(&altspb.HandshakerReq{
   346  			ReqOneof: &altspb.HandshakerReq_Next{
   347  				Next: &altspb.NextHandshakeMessageReq{
   348  					InBytes:          p,
   349  					NetworkLatencyMs: timeElapsedMs,
   350  				},
   351  			},
   352  		})
   353  		if err != nil {
   354  			return nil, nil, err
   355  		}
   356  		// Set extra based on handshaker service response.
   357  		if resp.GetBytesConsumed() > uint32(len(p)) {
   358  			return nil, nil, errOutOfBound
   359  		}
   360  		extra = p[resp.GetBytesConsumed():]
   361  	}
   362  }
   363  
   364  // Close terminates the Handshaker. It should be called when the caller obtains
   365  // the secure connection.
   366  func (h *altsHandshaker) Close() {
   367  	if h.stream != nil {
   368  		h.stream.CloseSend()
   369  	}
   370  }
   371  
   372  // ResetConcurrentHandshakeSemaphoreForTesting resets the handshake semaphores
   373  // to allow numberOfAllowedHandshakes concurrent handshakes each.
   374  func ResetConcurrentHandshakeSemaphoreForTesting(numberOfAllowedHandshakes int64) {
   375  	clientHandshakes = semaphore.NewWeighted(numberOfAllowedHandshakes)
   376  	serverHandshakes = semaphore.NewWeighted(numberOfAllowedHandshakes)
   377  }