google.golang.org/grpc@v1.62.1/credentials/alts/internal/handshaker/handshaker.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package handshaker provides ALTS handshaking functionality for GCP.
    20  package handshaker
    21  
    22  import (
    23  	"context"
    24  	"errors"
    25  	"fmt"
    26  	"io"
    27  	"net"
    28  	"time"
    29  
    30  	"golang.org/x/sync/semaphore"
    31  	grpc "google.golang.org/grpc"
    32  	"google.golang.org/grpc/codes"
    33  	"google.golang.org/grpc/credentials"
    34  	core "google.golang.org/grpc/credentials/alts/internal"
    35  	"google.golang.org/grpc/credentials/alts/internal/authinfo"
    36  	"google.golang.org/grpc/credentials/alts/internal/conn"
    37  	altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    38  	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    39  	"google.golang.org/grpc/internal/envconfig"
    40  )
    41  
    42  const (
    43  	// The maximum byte size of receive frames.
    44  	frameLimit              = 64 * 1024 // 64 KB
    45  	rekeyRecordProtocolName = "ALTSRP_GCM_AES128_REKEY"
    46  )
    47  
    48  var (
    49  	hsProtocol      = altspb.HandshakeProtocol_ALTS
    50  	appProtocols    = []string{"grpc"}
    51  	recordProtocols = []string{rekeyRecordProtocolName}
    52  	keyLength       = map[string]int{
    53  		rekeyRecordProtocolName: 44,
    54  	}
    55  	altsRecordFuncs = map[string]conn.ALTSRecordFunc{
    56  		// ALTS handshaker protocols.
    57  		rekeyRecordProtocolName: func(s core.Side, keyData []byte) (conn.ALTSRecordCrypto, error) {
    58  			return conn.NewAES128GCMRekey(s, keyData)
    59  		},
    60  	}
    61  	// control number of concurrent created (but not closed) handshakes.
    62  	clientHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
    63  	serverHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
    64  	// errOutOfBound occurs when the handshake service returns a consumed
    65  	// bytes value larger than the buffer that was passed to it originally.
    66  	errOutOfBound = errors.New("handshaker service consumed bytes value is out-of-bound")
    67  )
    68  
    69  func init() {
    70  	for protocol, f := range altsRecordFuncs {
    71  		if err := conn.RegisterProtocol(protocol, f); err != nil {
    72  			panic(err)
    73  		}
    74  	}
    75  }
    76  
    77  // ClientHandshakerOptions contains the client handshaker options that can
    78  // provided by the caller.
    79  type ClientHandshakerOptions struct {
    80  	// ClientIdentity is the handshaker client local identity.
    81  	ClientIdentity *altspb.Identity
    82  	// TargetName is the server service account name for secure name
    83  	// checking.
    84  	TargetName string
    85  	// TargetServiceAccounts contains a list of expected target service
    86  	// accounts. One of these accounts should match one of the accounts in
    87  	// the handshaker results. Otherwise, the handshake fails.
    88  	TargetServiceAccounts []string
    89  	// RPCVersions specifies the gRPC versions accepted by the client.
    90  	RPCVersions *altspb.RpcProtocolVersions
    91  }
    92  
    93  // ServerHandshakerOptions contains the server handshaker options that can
    94  // provided by the caller.
    95  type ServerHandshakerOptions struct {
    96  	// RPCVersions specifies the gRPC versions accepted by the server.
    97  	RPCVersions *altspb.RpcProtocolVersions
    98  }
    99  
   100  // DefaultClientHandshakerOptions returns the default client handshaker options.
   101  func DefaultClientHandshakerOptions() *ClientHandshakerOptions {
   102  	return &ClientHandshakerOptions{}
   103  }
   104  
   105  // DefaultServerHandshakerOptions returns the default client handshaker options.
   106  func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
   107  	return &ServerHandshakerOptions{}
   108  }
   109  
   110  // altsHandshaker is used to complete an ALTS handshake between client and
   111  // server. This handshaker talks to the ALTS handshaker service in the metadata
   112  // server.
   113  type altsHandshaker struct {
   114  	// RPC stream used to access the ALTS Handshaker service.
   115  	stream altsgrpc.HandshakerService_DoHandshakeClient
   116  	// the connection to the peer.
   117  	conn net.Conn
   118  	// a virtual connection to the ALTS handshaker service.
   119  	clientConn *grpc.ClientConn
   120  	// client handshake options.
   121  	clientOpts *ClientHandshakerOptions
   122  	// server handshake options.
   123  	serverOpts *ServerHandshakerOptions
   124  	// defines the side doing the handshake, client or server.
   125  	side core.Side
   126  }
   127  
   128  // NewClientHandshaker creates a core.Handshaker that performs a client-side
   129  // ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
   130  // service in the metadata server.
   131  func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) {
   132  	return &altsHandshaker{
   133  		stream:     nil,
   134  		conn:       c,
   135  		clientConn: conn,
   136  		clientOpts: opts,
   137  		side:       core.ClientSide,
   138  	}, nil
   139  }
   140  
   141  // NewServerHandshaker creates a core.Handshaker that performs a server-side
   142  // ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
   143  // service in the metadata server.
   144  func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) {
   145  	return &altsHandshaker{
   146  		stream:     nil,
   147  		conn:       c,
   148  		clientConn: conn,
   149  		serverOpts: opts,
   150  		side:       core.ServerSide,
   151  	}, nil
   152  }
   153  
   154  // ClientHandshake starts and completes a client ALTS handshake for GCP. Once
   155  // done, ClientHandshake returns a secure connection.
   156  func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
   157  	if err := clientHandshakes.Acquire(ctx, 1); err != nil {
   158  		return nil, nil, err
   159  	}
   160  	defer clientHandshakes.Release(1)
   161  
   162  	if h.side != core.ClientSide {
   163  		return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker")
   164  	}
   165  
   166  	// TODO(matthewstevenson88): Change unit tests to use public APIs so
   167  	// that h.stream can unconditionally be set based on h.clientConn.
   168  	if h.stream == nil {
   169  		stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
   170  		if err != nil {
   171  			return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
   172  		}
   173  		h.stream = stream
   174  	}
   175  
   176  	// Create target identities from service account list.
   177  	targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts))
   178  	for _, account := range h.clientOpts.TargetServiceAccounts {
   179  		targetIdentities = append(targetIdentities, &altspb.Identity{
   180  			IdentityOneof: &altspb.Identity_ServiceAccount{
   181  				ServiceAccount: account,
   182  			},
   183  		})
   184  	}
   185  	req := &altspb.HandshakerReq{
   186  		ReqOneof: &altspb.HandshakerReq_ClientStart{
   187  			ClientStart: &altspb.StartClientHandshakeReq{
   188  				HandshakeSecurityProtocol: hsProtocol,
   189  				ApplicationProtocols:      appProtocols,
   190  				RecordProtocols:           recordProtocols,
   191  				TargetIdentities:          targetIdentities,
   192  				LocalIdentity:             h.clientOpts.ClientIdentity,
   193  				TargetName:                h.clientOpts.TargetName,
   194  				RpcVersions:               h.clientOpts.RPCVersions,
   195  			},
   196  		},
   197  	}
   198  
   199  	conn, result, err := h.doHandshake(req)
   200  	if err != nil {
   201  		return nil, nil, err
   202  	}
   203  	authInfo := authinfo.New(result)
   204  	return conn, authInfo, nil
   205  }
   206  
   207  // ServerHandshake starts and completes a server ALTS handshake for GCP. Once
   208  // done, ServerHandshake returns a secure connection.
   209  func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
   210  	if err := serverHandshakes.Acquire(ctx, 1); err != nil {
   211  		return nil, nil, err
   212  	}
   213  	defer serverHandshakes.Release(1)
   214  
   215  	if h.side != core.ServerSide {
   216  		return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")
   217  	}
   218  
   219  	// TODO(matthewstevenson88): Change unit tests to use public APIs so
   220  	// that h.stream can unconditionally be set based on h.clientConn.
   221  	if h.stream == nil {
   222  		stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
   223  		if err != nil {
   224  			return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
   225  		}
   226  		h.stream = stream
   227  	}
   228  
   229  	p := make([]byte, frameLimit)
   230  	n, err := h.conn.Read(p)
   231  	if err != nil {
   232  		return nil, nil, err
   233  	}
   234  
   235  	// Prepare server parameters.
   236  	params := make(map[int32]*altspb.ServerHandshakeParameters)
   237  	params[int32(altspb.HandshakeProtocol_ALTS)] = &altspb.ServerHandshakeParameters{
   238  		RecordProtocols: recordProtocols,
   239  	}
   240  	req := &altspb.HandshakerReq{
   241  		ReqOneof: &altspb.HandshakerReq_ServerStart{
   242  			ServerStart: &altspb.StartServerHandshakeReq{
   243  				ApplicationProtocols: appProtocols,
   244  				HandshakeParameters:  params,
   245  				InBytes:              p[:n],
   246  				RpcVersions:          h.serverOpts.RPCVersions,
   247  			},
   248  		},
   249  	}
   250  
   251  	conn, result, err := h.doHandshake(req)
   252  	if err != nil {
   253  		return nil, nil, err
   254  	}
   255  	authInfo := authinfo.New(result)
   256  	return conn, authInfo, nil
   257  }
   258  
   259  func (h *altsHandshaker) doHandshake(req *altspb.HandshakerReq) (net.Conn, *altspb.HandshakerResult, error) {
   260  	resp, err := h.accessHandshakerService(req)
   261  	if err != nil {
   262  		return nil, nil, err
   263  	}
   264  	// Check of the returned status is an error.
   265  	if resp.GetStatus() != nil {
   266  		if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want {
   267  			return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
   268  		}
   269  	}
   270  
   271  	var extra []byte
   272  	if req.GetServerStart() != nil {
   273  		if resp.GetBytesConsumed() > uint32(len(req.GetServerStart().GetInBytes())) {
   274  			return nil, nil, errOutOfBound
   275  		}
   276  		extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():]
   277  	}
   278  	result, extra, err := h.processUntilDone(resp, extra)
   279  	if err != nil {
   280  		return nil, nil, err
   281  	}
   282  	// The handshaker returns a 128 bytes key. It should be truncated based
   283  	// on the returned record protocol.
   284  	keyLen, ok := keyLength[result.RecordProtocol]
   285  	if !ok {
   286  		return nil, nil, fmt.Errorf("unknown resulted record protocol %v", result.RecordProtocol)
   287  	}
   288  	sc, err := conn.NewConn(h.conn, h.side, result.GetRecordProtocol(), result.KeyData[:keyLen], extra)
   289  	if err != nil {
   290  		return nil, nil, err
   291  	}
   292  	return sc, result, nil
   293  }
   294  
   295  func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*altspb.HandshakerResp, error) {
   296  	if err := h.stream.Send(req); err != nil {
   297  		return nil, err
   298  	}
   299  	resp, err := h.stream.Recv()
   300  	if err != nil {
   301  		return nil, err
   302  	}
   303  	return resp, nil
   304  }
   305  
   306  // processUntilDone processes the handshake until the handshaker service returns
   307  // the results. Handshaker service takes care of frame parsing, so we read
   308  // whatever received from the network and send it to the handshaker service.
   309  func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
   310  	var lastWriteTime time.Time
   311  	for {
   312  		if len(resp.OutFrames) > 0 {
   313  			lastWriteTime = time.Now()
   314  			if _, err := h.conn.Write(resp.OutFrames); err != nil {
   315  				return nil, nil, err
   316  			}
   317  		}
   318  		if resp.Result != nil {
   319  			return resp.Result, extra, nil
   320  		}
   321  		buf := make([]byte, frameLimit)
   322  		n, err := h.conn.Read(buf)
   323  		if err != nil && err != io.EOF {
   324  			return nil, nil, err
   325  		}
   326  		// If there is nothing to send to the handshaker service, and
   327  		// nothing is received from the peer, then we are stuck.
   328  		// This covers the case when the peer is not responding. Note
   329  		// that handshaker service connection issues are caught in
   330  		// accessHandshakerService before we even get here.
   331  		if len(resp.OutFrames) == 0 && n == 0 {
   332  			return nil, nil, core.PeerNotRespondingError
   333  		}
   334  		// Append extra bytes from the previous interaction with the
   335  		// handshaker service with the current buffer read from conn.
   336  		p := append(extra, buf[:n]...)
   337  		// Compute the time elapsed since the last write to the peer.
   338  		timeElapsed := time.Since(lastWriteTime)
   339  		timeElapsedMs := uint32(timeElapsed.Milliseconds())
   340  		// From here on, p and extra point to the same slice.
   341  		resp, err = h.accessHandshakerService(&altspb.HandshakerReq{
   342  			ReqOneof: &altspb.HandshakerReq_Next{
   343  				Next: &altspb.NextHandshakeMessageReq{
   344  					InBytes:          p,
   345  					NetworkLatencyMs: timeElapsedMs,
   346  				},
   347  			},
   348  		})
   349  		if err != nil {
   350  			return nil, nil, err
   351  		}
   352  		// Set extra based on handshaker service response.
   353  		if resp.GetBytesConsumed() > uint32(len(p)) {
   354  			return nil, nil, errOutOfBound
   355  		}
   356  		extra = p[resp.GetBytesConsumed():]
   357  	}
   358  }
   359  
   360  // Close terminates the Handshaker. It should be called when the caller obtains
   361  // the secure connection.
   362  func (h *altsHandshaker) Close() {
   363  	if h.stream != nil {
   364  		h.stream.CloseSend()
   365  	}
   366  }
   367  
   368  // ResetConcurrentHandshakeSemaphoreForTesting resets the handshake semaphores
   369  // to allow numberOfAllowedHandshakes concurrent handshakes each.
   370  func ResetConcurrentHandshakeSemaphoreForTesting(numberOfAllowedHandshakes int64) {
   371  	clientHandshakes = semaphore.NewWeighted(numberOfAllowedHandshakes)
   372  	serverHandshakes = semaphore.NewWeighted(numberOfAllowedHandshakes)
   373  }