github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/credentials/alts/internal/handshaker/handshaker.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package handshaker provides ALTS handshaking functionality for GCP.
    20  package handshaker
    21  
    22  import (
    23  	"context"
    24  	"errors"
    25  	"fmt"
    26  	"io"
    27  	"net"
    28  	"sync"
    29  
    30  	grpc "github.com/hxx258456/ccgo/grpc"
    31  	"github.com/hxx258456/ccgo/grpc/codes"
    32  	"github.com/hxx258456/ccgo/grpc/credentials"
    33  	core "github.com/hxx258456/ccgo/grpc/credentials/alts/internal"
    34  	"github.com/hxx258456/ccgo/grpc/credentials/alts/internal/authinfo"
    35  	"github.com/hxx258456/ccgo/grpc/credentials/alts/internal/conn"
    36  	altsgrpc "github.com/hxx258456/ccgo/grpc/credentials/alts/internal/proto/grpc_gcp"
    37  	altspb "github.com/hxx258456/ccgo/grpc/credentials/alts/internal/proto/grpc_gcp"
    38  )
    39  
    40  const (
    41  	// The maximum byte size of receive frames.
    42  	frameLimit              = 64 * 1024 // 64 KB
    43  	rekeyRecordProtocolName = "ALTSRP_GCM_AES128_REKEY"
    44  	// maxPendingHandshakes represents the maximum number of concurrent
    45  	// handshakes.
    46  	maxPendingHandshakes = 100
    47  )
    48  
    49  var (
    50  	hsProtocol      = altspb.HandshakeProtocol_ALTS
    51  	appProtocols    = []string{"grpc"}
    52  	recordProtocols = []string{rekeyRecordProtocolName}
    53  	keyLength       = map[string]int{
    54  		rekeyRecordProtocolName: 44,
    55  	}
    56  	altsRecordFuncs = map[string]conn.ALTSRecordFunc{
    57  		// ALTS handshaker protocols.
    58  		rekeyRecordProtocolName: func(s core.Side, keyData []byte) (conn.ALTSRecordCrypto, error) {
    59  			return conn.NewAES128GCMRekey(s, keyData)
    60  		},
    61  	}
    62  	// control number of concurrent created (but not closed) handshakers.
    63  	mu                   sync.Mutex
    64  	concurrentHandshakes = int64(0)
    65  	// errDropped occurs when maxPendingHandshakes is reached.
    66  	errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached")
    67  	// errOutOfBound occurs when the handshake service returns a consumed
    68  	// bytes value larger than the buffer that was passed to it originally.
    69  	errOutOfBound = errors.New("handshaker service consumed bytes value is out-of-bound")
    70  )
    71  
    72  func init() {
    73  	for protocol, f := range altsRecordFuncs {
    74  		if err := conn.RegisterProtocol(protocol, f); err != nil {
    75  			panic(err)
    76  		}
    77  	}
    78  }
    79  
    80  func acquire() bool {
    81  	mu.Lock()
    82  	// If we need n to be configurable, we can pass it as an argument.
    83  	n := int64(1)
    84  	success := maxPendingHandshakes-concurrentHandshakes >= n
    85  	if success {
    86  		concurrentHandshakes += n
    87  	}
    88  	mu.Unlock()
    89  	return success
    90  }
    91  
    92  func release() {
    93  	mu.Lock()
    94  	// If we need n to be configurable, we can pass it as an argument.
    95  	n := int64(1)
    96  	concurrentHandshakes -= n
    97  	if concurrentHandshakes < 0 {
    98  		mu.Unlock()
    99  		panic("bad release")
   100  	}
   101  	mu.Unlock()
   102  }
   103  
   104  // ClientHandshakerOptions contains the client handshaker options that can
   105  // provided by the caller.
   106  type ClientHandshakerOptions struct {
   107  	// ClientIdentity is the handshaker client local identity.
   108  	ClientIdentity *altspb.Identity
   109  	// TargetName is the server service account name for secure name
   110  	// checking.
   111  	TargetName string
   112  	// TargetServiceAccounts contains a list of expected target service
   113  	// accounts. One of these accounts should match one of the accounts in
   114  	// the handshaker results. Otherwise, the handshake fails.
   115  	TargetServiceAccounts []string
   116  	// RPCVersions specifies the gRPC versions accepted by the client.
   117  	RPCVersions *altspb.RpcProtocolVersions
   118  }
   119  
   120  // ServerHandshakerOptions contains the server handshaker options that can
   121  // provided by the caller.
   122  type ServerHandshakerOptions struct {
   123  	// RPCVersions specifies the gRPC versions accepted by the server.
   124  	RPCVersions *altspb.RpcProtocolVersions
   125  }
   126  
   127  // DefaultClientHandshakerOptions returns the default client handshaker options.
   128  func DefaultClientHandshakerOptions() *ClientHandshakerOptions {
   129  	return &ClientHandshakerOptions{}
   130  }
   131  
   132  // DefaultServerHandshakerOptions returns the default client handshaker options.
   133  func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
   134  	return &ServerHandshakerOptions{}
   135  }
   136  
// TODO: add support for future local and remote endpoints in both the client
//       options and the server options. This note predates
//       ServerHandshakerOptions; once callers can provide endpoints, extend
//       both options structs accordingly.
   140  
   141  // altsHandshaker is used to complete a ALTS handshaking between client and
   142  // server. This handshaker talks to the ALTS handshaker service in the metadata
   143  // server.
   144  type altsHandshaker struct {
   145  	// RPC stream used to access the ALTS Handshaker service.
   146  	stream altsgrpc.HandshakerService_DoHandshakeClient
   147  	// the connection to the peer.
   148  	conn net.Conn
   149  	// client handshake options.
   150  	clientOpts *ClientHandshakerOptions
   151  	// server handshake options.
   152  	serverOpts *ServerHandshakerOptions
   153  	// defines the side doing the handshake, client or server.
   154  	side core.Side
   155  }
   156  
   157  // NewClientHandshaker creates a ALTS handshaker for GCP which contains an RPC
   158  // stub created using the passed conn and used to talk to the ALTS Handshaker
   159  // service in the metadata server.
   160  func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) {
   161  	stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.WaitForReady(true))
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  	return &altsHandshaker{
   166  		stream:     stream,
   167  		conn:       c,
   168  		clientOpts: opts,
   169  		side:       core.ClientSide,
   170  	}, nil
   171  }
   172  
   173  // NewServerHandshaker creates a ALTS handshaker for GCP which contains an RPC
   174  // stub created using the passed conn and used to talk to the ALTS Handshaker
   175  // service in the metadata server.
   176  func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) {
   177  	stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.WaitForReady(true))
   178  	if err != nil {
   179  		return nil, err
   180  	}
   181  	return &altsHandshaker{
   182  		stream:     stream,
   183  		conn:       c,
   184  		serverOpts: opts,
   185  		side:       core.ServerSide,
   186  	}, nil
   187  }
   188  
   189  // ClientHandshake starts and completes a client ALTS handshaking for GCP. Once
   190  // done, ClientHandshake returns a secure connection.
   191  func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
   192  	if !acquire() {
   193  		return nil, nil, errDropped
   194  	}
   195  	defer release()
   196  
   197  	if h.side != core.ClientSide {
   198  		return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker")
   199  	}
   200  
   201  	// Create target identities from service account list.
   202  	targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts))
   203  	for _, account := range h.clientOpts.TargetServiceAccounts {
   204  		targetIdentities = append(targetIdentities, &altspb.Identity{
   205  			IdentityOneof: &altspb.Identity_ServiceAccount{
   206  				ServiceAccount: account,
   207  			},
   208  		})
   209  	}
   210  	req := &altspb.HandshakerReq{
   211  		ReqOneof: &altspb.HandshakerReq_ClientStart{
   212  			ClientStart: &altspb.StartClientHandshakeReq{
   213  				HandshakeSecurityProtocol: hsProtocol,
   214  				ApplicationProtocols:      appProtocols,
   215  				RecordProtocols:           recordProtocols,
   216  				TargetIdentities:          targetIdentities,
   217  				LocalIdentity:             h.clientOpts.ClientIdentity,
   218  				TargetName:                h.clientOpts.TargetName,
   219  				RpcVersions:               h.clientOpts.RPCVersions,
   220  			},
   221  		},
   222  	}
   223  
   224  	conn, result, err := h.doHandshake(req)
   225  	if err != nil {
   226  		return nil, nil, err
   227  	}
   228  	authInfo := authinfo.New(result)
   229  	return conn, authInfo, nil
   230  }
   231  
   232  // ServerHandshake starts and completes a server ALTS handshaking for GCP. Once
   233  // done, ServerHandshake returns a secure connection.
   234  func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
   235  	if !acquire() {
   236  		return nil, nil, errDropped
   237  	}
   238  	defer release()
   239  
   240  	if h.side != core.ServerSide {
   241  		return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")
   242  	}
   243  
   244  	p := make([]byte, frameLimit)
   245  	n, err := h.conn.Read(p)
   246  	if err != nil {
   247  		return nil, nil, err
   248  	}
   249  
   250  	// Prepare server parameters.
   251  	// TODO: currently only ALTS parameters are provided. Might need to use
   252  	//       more options in the future.
   253  	params := make(map[int32]*altspb.ServerHandshakeParameters)
   254  	params[int32(altspb.HandshakeProtocol_ALTS)] = &altspb.ServerHandshakeParameters{
   255  		RecordProtocols: recordProtocols,
   256  	}
   257  	req := &altspb.HandshakerReq{
   258  		ReqOneof: &altspb.HandshakerReq_ServerStart{
   259  			ServerStart: &altspb.StartServerHandshakeReq{
   260  				ApplicationProtocols: appProtocols,
   261  				HandshakeParameters:  params,
   262  				InBytes:              p[:n],
   263  				RpcVersions:          h.serverOpts.RPCVersions,
   264  			},
   265  		},
   266  	}
   267  
   268  	conn, result, err := h.doHandshake(req)
   269  	if err != nil {
   270  		return nil, nil, err
   271  	}
   272  	authInfo := authinfo.New(result)
   273  	return conn, authInfo, nil
   274  }
   275  
   276  func (h *altsHandshaker) doHandshake(req *altspb.HandshakerReq) (net.Conn, *altspb.HandshakerResult, error) {
   277  	resp, err := h.accessHandshakerService(req)
   278  	if err != nil {
   279  		return nil, nil, err
   280  	}
   281  	// Check of the returned status is an error.
   282  	if resp.GetStatus() != nil {
   283  		if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want {
   284  			return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
   285  		}
   286  	}
   287  
   288  	var extra []byte
   289  	if req.GetServerStart() != nil {
   290  		if resp.GetBytesConsumed() > uint32(len(req.GetServerStart().GetInBytes())) {
   291  			return nil, nil, errOutOfBound
   292  		}
   293  		extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():]
   294  	}
   295  	result, extra, err := h.processUntilDone(resp, extra)
   296  	if err != nil {
   297  		return nil, nil, err
   298  	}
   299  	// The handshaker returns a 128 bytes key. It should be truncated based
   300  	// on the returned record protocol.
   301  	keyLen, ok := keyLength[result.RecordProtocol]
   302  	if !ok {
   303  		return nil, nil, fmt.Errorf("unknown resulted record protocol %v", result.RecordProtocol)
   304  	}
   305  	sc, err := conn.NewConn(h.conn, h.side, result.GetRecordProtocol(), result.KeyData[:keyLen], extra)
   306  	if err != nil {
   307  		return nil, nil, err
   308  	}
   309  	return sc, result, nil
   310  }
   311  
   312  func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*altspb.HandshakerResp, error) {
   313  	if err := h.stream.Send(req); err != nil {
   314  		return nil, err
   315  	}
   316  	resp, err := h.stream.Recv()
   317  	if err != nil {
   318  		return nil, err
   319  	}
   320  	return resp, nil
   321  }
   322  
   323  // processUntilDone processes the handshake until the handshaker service returns
   324  // the results. Handshaker service takes care of frame parsing, so we read
   325  // whatever received from the network and send it to the handshaker service.
   326  func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
   327  	for {
   328  		if len(resp.OutFrames) > 0 {
   329  			if _, err := h.conn.Write(resp.OutFrames); err != nil {
   330  				return nil, nil, err
   331  			}
   332  		}
   333  		if resp.Result != nil {
   334  			return resp.Result, extra, nil
   335  		}
   336  		buf := make([]byte, frameLimit)
   337  		n, err := h.conn.Read(buf)
   338  		if err != nil && err != io.EOF {
   339  			return nil, nil, err
   340  		}
   341  		// If there is nothing to send to the handshaker service, and
   342  		// nothing is received from the peer, then we are stuck.
   343  		// This covers the case when the peer is not responding. Note
   344  		// that handshaker service connection issues are caught in
   345  		// accessHandshakerService before we even get here.
   346  		if len(resp.OutFrames) == 0 && n == 0 {
   347  			return nil, nil, core.PeerNotRespondingError
   348  		}
   349  		// Append extra bytes from the previous interaction with the
   350  		// handshaker service with the current buffer read from conn.
   351  		p := append(extra, buf[:n]...)
   352  		// From here on, p and extra point to the same slice.
   353  		resp, err = h.accessHandshakerService(&altspb.HandshakerReq{
   354  			ReqOneof: &altspb.HandshakerReq_Next{
   355  				Next: &altspb.NextHandshakeMessageReq{
   356  					InBytes: p,
   357  				},
   358  			},
   359  		})
   360  		if err != nil {
   361  			return nil, nil, err
   362  		}
   363  		// Set extra based on handshaker service response.
   364  		if resp.GetBytesConsumed() > uint32(len(p)) {
   365  			return nil, nil, errOutOfBound
   366  		}
   367  		extra = p[resp.GetBytesConsumed():]
   368  	}
   369  }
   370  
   371  // Close terminates the Handshaker. It should be called when the caller obtains
   372  // the secure connection.
   373  func (h *altsHandshaker) Close() {
   374  	h.stream.CloseSend()
   375  }