github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/credentials/alts/alts.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package alts implements ALTS credential support for the gRPC library, which
    20  // encapsulates all the state needed by a client to authenticate with a server
    21  // using ALTS and make various assertions, e.g., about the client's identity,
    22  // role, or whether it is authorized to make a particular call.
    23  // This package is experimental.
    24  package alts
    25  
    26  import (
    27  	"context"
    28  	"errors"
    29  	"fmt"
    30  	"net"
    31  	"sync"
    32  	"time"
    33  
    34  	"github.com/hxx258456/ccgo/grpc/credentials"
    35  	core "github.com/hxx258456/ccgo/grpc/credentials/alts/internal"
    36  	"github.com/hxx258456/ccgo/grpc/credentials/alts/internal/handshaker"
    37  	"github.com/hxx258456/ccgo/grpc/credentials/alts/internal/handshaker/service"
    38  	altspb "github.com/hxx258456/ccgo/grpc/credentials/alts/internal/proto/grpc_gcp"
    39  	"github.com/hxx258456/ccgo/grpc/grpclog"
    40  	"github.com/hxx258456/ccgo/grpc/internal/googlecloud"
    41  )
    42  
    43  const (
    44  	// hypervisorHandshakerServiceAddress represents the default ALTS gRPC
    45  	// handshaker service address in the hypervisor.
    46  	hypervisorHandshakerServiceAddress = "metadata.google.internal.:8080"
    47  	// defaultTimeout specifies the server handshake timeout.
    48  	defaultTimeout = 30.0 * time.Second
    49  	// The following constants specify the minimum and maximum acceptable
    50  	// protocol versions.
    51  	protocolVersionMaxMajor = 2
    52  	protocolVersionMaxMinor = 1
    53  	protocolVersionMinMajor = 2
    54  	protocolVersionMinMinor = 1
    55  )
    56  
    57  var (
    58  	vmOnGCP       bool
    59  	once          sync.Once
    60  	maxRPCVersion = &altspb.RpcProtocolVersions_Version{
    61  		Major: protocolVersionMaxMajor,
    62  		Minor: protocolVersionMaxMinor,
    63  	}
    64  	minRPCVersion = &altspb.RpcProtocolVersions_Version{
    65  		Major: protocolVersionMinMajor,
    66  		Minor: protocolVersionMinMinor,
    67  	}
    68  	// ErrUntrustedPlatform is returned from ClientHandshake and
    69  	// ServerHandshake is running on a platform where the trustworthiness of
    70  	// the handshaker service is not guaranteed.
    71  	ErrUntrustedPlatform = errors.New("ALTS: untrusted platform. ALTS is only supported on GCP")
    72  	logger               = grpclog.Component("alts")
    73  )
    74  
    75  // AuthInfo exposes security information from the ALTS handshake to the
    76  // application. This interface is to be implemented by ALTS. Users should not
    77  // need a brand new implementation of this interface. For situations like
    78  // testing, any new implementation should embed this interface. This allows
    79  // ALTS to add new methods to this interface.
    80  type AuthInfo interface {
    81  	// ApplicationProtocol returns application protocol negotiated for the
    82  	// ALTS connection.
    83  	ApplicationProtocol() string
    84  	// RecordProtocol returns the record protocol negotiated for the ALTS
    85  	// connection.
    86  	RecordProtocol() string
    87  	// SecurityLevel returns the security level of the created ALTS secure
    88  	// channel.
    89  	SecurityLevel() altspb.SecurityLevel
    90  	// PeerServiceAccount returns the peer service account.
    91  	PeerServiceAccount() string
    92  	// LocalServiceAccount returns the local service account.
    93  	LocalServiceAccount() string
    94  	// PeerRPCVersions returns the RPC version supported by the peer.
    95  	PeerRPCVersions() *altspb.RpcProtocolVersions
    96  }
    97  
    98  // ClientOptions contains the client-side options of an ALTS channel. These
    99  // options will be passed to the underlying ALTS handshaker.
   100  type ClientOptions struct {
   101  	// TargetServiceAccounts contains a list of expected target service
   102  	// accounts.
   103  	TargetServiceAccounts []string
   104  	// HandshakerServiceAddress represents the ALTS handshaker gRPC service
   105  	// address to connect to.
   106  	HandshakerServiceAddress string
   107  }
   108  
   109  // DefaultClientOptions creates a new ClientOptions object with the default
   110  // values.
   111  func DefaultClientOptions() *ClientOptions {
   112  	return &ClientOptions{
   113  		HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
   114  	}
   115  }
   116  
   117  // ServerOptions contains the server-side options of an ALTS channel. These
   118  // options will be passed to the underlying ALTS handshaker.
   119  type ServerOptions struct {
   120  	// HandshakerServiceAddress represents the ALTS handshaker gRPC service
   121  	// address to connect to.
   122  	HandshakerServiceAddress string
   123  }
   124  
   125  // DefaultServerOptions creates a new ServerOptions object with the default
   126  // values.
   127  func DefaultServerOptions() *ServerOptions {
   128  	return &ServerOptions{
   129  		HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
   130  	}
   131  }
   132  
   133  // altsTC is the credentials required for authenticating a connection using ALTS.
   134  // It implements credentials.TransportCredentials interface.
   135  type altsTC struct {
   136  	info      *credentials.ProtocolInfo
   137  	side      core.Side
   138  	accounts  []string
   139  	hsAddress string
   140  }
   141  
   142  // NewClientCreds constructs a client-side ALTS TransportCredentials object.
   143  func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
   144  	return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
   145  }
   146  
   147  // NewServerCreds constructs a server-side ALTS TransportCredentials object.
   148  func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
   149  	return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
   150  }
   151  
   152  func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
   153  	once.Do(func() {
   154  		vmOnGCP = googlecloud.OnGCE()
   155  	})
   156  	if hsAddress == "" {
   157  		hsAddress = hypervisorHandshakerServiceAddress
   158  	}
   159  	return &altsTC{
   160  		info: &credentials.ProtocolInfo{
   161  			SecurityProtocol: "alts",
   162  			SecurityVersion:  "1.0",
   163  		},
   164  		side:      side,
   165  		accounts:  accounts,
   166  		hsAddress: hsAddress,
   167  	}
   168  }
   169  
   170  // ClientHandshake implements the client side handshake protocol.
   171  func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
   172  	if !vmOnGCP {
   173  		return nil, nil, ErrUntrustedPlatform
   174  	}
   175  
   176  	// Connecting to ALTS handshaker service.
   177  	hsConn, err := service.Dial(g.hsAddress)
   178  	if err != nil {
   179  		return nil, nil, err
   180  	}
   181  	// Do not close hsConn since it is shared with other handshakes.
   182  
   183  	// Possible context leak:
   184  	// The cancel function for the child context we create will only be
   185  	// called a non-nil error is returned.
   186  	var cancel context.CancelFunc
   187  	ctx, cancel = context.WithCancel(ctx)
   188  	defer func() {
   189  		if err != nil {
   190  			cancel()
   191  		}
   192  	}()
   193  
   194  	opts := handshaker.DefaultClientHandshakerOptions()
   195  	opts.TargetName = addr
   196  	opts.TargetServiceAccounts = g.accounts
   197  	opts.RPCVersions = &altspb.RpcProtocolVersions{
   198  		MaxRpcVersion: maxRPCVersion,
   199  		MinRpcVersion: minRPCVersion,
   200  	}
   201  	chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
   202  	if err != nil {
   203  		return nil, nil, err
   204  	}
   205  	defer func() {
   206  		if err != nil {
   207  			chs.Close()
   208  		}
   209  	}()
   210  	secConn, authInfo, err := chs.ClientHandshake(ctx)
   211  	if err != nil {
   212  		return nil, nil, err
   213  	}
   214  	altsAuthInfo, ok := authInfo.(AuthInfo)
   215  	if !ok {
   216  		return nil, nil, errors.New("client-side auth info is not of type alts.AuthInfo")
   217  	}
   218  	match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
   219  	if !match {
   220  		return nil, nil, fmt.Errorf("server-side RPC versions are not compatible with this client, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
   221  	}
   222  	return secConn, authInfo, nil
   223  }
   224  
   225  // ServerHandshake implements the server side ALTS handshaker.
   226  func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
   227  	if !vmOnGCP {
   228  		return nil, nil, ErrUntrustedPlatform
   229  	}
   230  	// Connecting to ALTS handshaker service.
   231  	hsConn, err := service.Dial(g.hsAddress)
   232  	if err != nil {
   233  		return nil, nil, err
   234  	}
   235  	// Do not close hsConn since it's shared with other handshakes.
   236  
   237  	ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
   238  	defer cancel()
   239  	opts := handshaker.DefaultServerHandshakerOptions()
   240  	opts.RPCVersions = &altspb.RpcProtocolVersions{
   241  		MaxRpcVersion: maxRPCVersion,
   242  		MinRpcVersion: minRPCVersion,
   243  	}
   244  	shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts)
   245  	if err != nil {
   246  		return nil, nil, err
   247  	}
   248  	defer func() {
   249  		if err != nil {
   250  			shs.Close()
   251  		}
   252  	}()
   253  	secConn, authInfo, err := shs.ServerHandshake(ctx)
   254  	if err != nil {
   255  		return nil, nil, err
   256  	}
   257  	altsAuthInfo, ok := authInfo.(AuthInfo)
   258  	if !ok {
   259  		return nil, nil, errors.New("server-side auth info is not of type alts.AuthInfo")
   260  	}
   261  	match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
   262  	if !match {
   263  		return nil, nil, fmt.Errorf("client-side RPC versions is not compatible with this server, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
   264  	}
   265  	return secConn, authInfo, nil
   266  }
   267  
   268  func (g *altsTC) Info() credentials.ProtocolInfo {
   269  	return *g.info
   270  }
   271  
   272  func (g *altsTC) Clone() credentials.TransportCredentials {
   273  	info := *g.info
   274  	var accounts []string
   275  	if g.accounts != nil {
   276  		accounts = make([]string, len(g.accounts))
   277  		copy(accounts, g.accounts)
   278  	}
   279  	return &altsTC{
   280  		info:      &info,
   281  		side:      g.side,
   282  		hsAddress: g.hsAddress,
   283  		accounts:  accounts,
   284  	}
   285  }
   286  
   287  func (g *altsTC) OverrideServerName(serverNameOverride string) error {
   288  	g.info.ServerName = serverNameOverride
   289  	return nil
   290  }
   291  
   292  // compareRPCVersion returns 0 if v1 == v2, 1 if v1 > v2 and -1 if v1 < v2.
   293  func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int {
   294  	switch {
   295  	case v1.GetMajor() > v2.GetMajor(),
   296  		v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor():
   297  		return 1
   298  	case v1.GetMajor() < v2.GetMajor(),
   299  		v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor():
   300  		return -1
   301  	}
   302  	return 0
   303  }
   304  
   305  // checkRPCVersions performs a version check between local and peer rpc protocol
   306  // versions. This function returns true if the check passes which means both
   307  // parties agreed on a common rpc protocol to use, and false otherwise. The
   308  // function also returns the highest common RPC protocol version both parties
   309  // agreed on.
   310  func checkRPCVersions(local, peer *altspb.RpcProtocolVersions) (bool, *altspb.RpcProtocolVersions_Version) {
   311  	if local == nil || peer == nil {
   312  		logger.Error("invalid checkRPCVersions argument, either local or peer is nil.")
   313  		return false, nil
   314  	}
   315  
   316  	// maxCommonVersion is MIN(local.max, peer.max).
   317  	maxCommonVersion := local.GetMaxRpcVersion()
   318  	if compareRPCVersions(local.GetMaxRpcVersion(), peer.GetMaxRpcVersion()) > 0 {
   319  		maxCommonVersion = peer.GetMaxRpcVersion()
   320  	}
   321  
   322  	// minCommonVersion is MAX(local.min, peer.min).
   323  	minCommonVersion := peer.GetMinRpcVersion()
   324  	if compareRPCVersions(local.GetMinRpcVersion(), peer.GetMinRpcVersion()) > 0 {
   325  		minCommonVersion = local.GetMinRpcVersion()
   326  	}
   327  
   328  	if compareRPCVersions(maxCommonVersion, minCommonVersion) < 0 {
   329  		return false, nil
   330  	}
   331  	return true, maxCommonVersion
   332  }