google.golang.org/grpc@v1.74.2/credentials/alts/alts.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package alts implements the ALTS credential support by gRPC library, which
    20  // encapsulates all the state needed by a client to authenticate with a server
    21  // using ALTS and make various assertions, e.g., about the client's identity,
    22  // role, or whether it is authorized to make a particular call.
    23  // This package is experimental.
    24  package alts
    25  
    26  import (
    27  	"context"
    28  	"errors"
    29  	"fmt"
    30  	"net"
    31  	"sync"
    32  	"time"
    33  
    34  	"google.golang.org/grpc/credentials"
    35  	core "google.golang.org/grpc/credentials/alts/internal"
    36  	"google.golang.org/grpc/credentials/alts/internal/handshaker"
    37  	"google.golang.org/grpc/credentials/alts/internal/handshaker/service"
    38  	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    39  	"google.golang.org/grpc/grpclog"
    40  	"google.golang.org/grpc/internal/googlecloud"
    41  )
    42  
    43  const (
    44  	// hypervisorHandshakerServiceAddress represents the default ALTS gRPC
    45  	// handshaker service address in the hypervisor.
    46  	hypervisorHandshakerServiceAddress = "dns:///metadata.google.internal.:8080"
    47  	// defaultTimeout specifies the server handshake timeout.
    48  	defaultTimeout = 30.0 * time.Second
    49  	// The following constants specify the minimum and maximum acceptable
    50  	// protocol versions.
    51  	protocolVersionMaxMajor = 2
    52  	protocolVersionMaxMinor = 1
    53  	protocolVersionMinMajor = 2
    54  	protocolVersionMinMinor = 1
    55  )
    56  
    57  var (
    58  	vmOnGCP       bool
    59  	once          sync.Once
    60  	maxRPCVersion = &altspb.RpcProtocolVersions_Version{
    61  		Major: protocolVersionMaxMajor,
    62  		Minor: protocolVersionMaxMinor,
    63  	}
    64  	minRPCVersion = &altspb.RpcProtocolVersions_Version{
    65  		Major: protocolVersionMinMajor,
    66  		Minor: protocolVersionMinMinor,
    67  	}
    68  	// ErrUntrustedPlatform is returned from ClientHandshake and
    69  	// ServerHandshake is running on a platform where the trustworthiness of
    70  	// the handshaker service is not guaranteed.
    71  	ErrUntrustedPlatform = errors.New("ALTS: untrusted platform. ALTS is only supported on GCP")
    72  	logger               = grpclog.Component("alts")
    73  )
    74  
    75  // AuthInfo exposes security information from the ALTS handshake to the
    76  // application. This interface is to be implemented by ALTS. Users should not
    77  // need a brand new implementation of this interface. For situations like
    78  // testing, any new implementation should embed this interface. This allows
    79  // ALTS to add new methods to this interface.
    80  type AuthInfo interface {
    81  	// ApplicationProtocol returns application protocol negotiated for the
    82  	// ALTS connection.
    83  	ApplicationProtocol() string
    84  	// RecordProtocol returns the record protocol negotiated for the ALTS
    85  	// connection.
    86  	RecordProtocol() string
    87  	// SecurityLevel returns the security level of the created ALTS secure
    88  	// channel.
    89  	SecurityLevel() altspb.SecurityLevel
    90  	// PeerServiceAccount returns the peer service account.
    91  	PeerServiceAccount() string
    92  	// LocalServiceAccount returns the local service account.
    93  	LocalServiceAccount() string
    94  	// PeerRPCVersions returns the RPC version supported by the peer.
    95  	PeerRPCVersions() *altspb.RpcProtocolVersions
    96  }
    97  
    98  // ClientOptions contains the client-side options of an ALTS channel. These
    99  // options will be passed to the underlying ALTS handshaker.
   100  type ClientOptions struct {
   101  	// TargetServiceAccounts contains a list of expected target service
   102  	// accounts.
   103  	TargetServiceAccounts []string
   104  	// HandshakerServiceAddress represents the ALTS handshaker gRPC service
   105  	// address to connect to.
   106  	HandshakerServiceAddress string
   107  }
   108  
   109  // DefaultClientOptions creates a new ClientOptions object with the default
   110  // values.
   111  func DefaultClientOptions() *ClientOptions {
   112  	return &ClientOptions{
   113  		HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
   114  	}
   115  }
   116  
   117  // ServerOptions contains the server-side options of an ALTS channel. These
   118  // options will be passed to the underlying ALTS handshaker.
   119  type ServerOptions struct {
   120  	// HandshakerServiceAddress represents the ALTS handshaker gRPC service
   121  	// address to connect to.
   122  	HandshakerServiceAddress string
   123  }
   124  
   125  // DefaultServerOptions creates a new ServerOptions object with the default
   126  // values.
   127  func DefaultServerOptions() *ServerOptions {
   128  	return &ServerOptions{
   129  		HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
   130  	}
   131  }
   132  
   133  // altsTC is the credentials required for authenticating a connection using ALTS.
   134  // It implements credentials.TransportCredentials interface.
   135  type altsTC struct {
   136  	info             *credentials.ProtocolInfo
   137  	side             core.Side
   138  	accounts         []string
   139  	hsAddress        string
   140  	boundAccessToken string
   141  }
   142  
   143  // NewClientCreds constructs a client-side ALTS TransportCredentials object.
   144  func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
   145  	return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
   146  }
   147  
   148  // NewServerCreds constructs a server-side ALTS TransportCredentials object.
   149  func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
   150  	return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
   151  }
   152  
   153  func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
   154  	once.Do(func() {
   155  		vmOnGCP = googlecloud.OnGCE()
   156  	})
   157  	if hsAddress == "" {
   158  		hsAddress = hypervisorHandshakerServiceAddress
   159  	}
   160  	return &altsTC{
   161  		info: &credentials.ProtocolInfo{
   162  			SecurityProtocol: "alts",
   163  			SecurityVersion:  "1.0",
   164  		},
   165  		side:      side,
   166  		accounts:  accounts,
   167  		hsAddress: hsAddress,
   168  	}
   169  }
   170  
// ClientHandshake implements the client side handshake protocol.
//
// It fails with ErrUntrustedPlatform when vmOnGCP is false. Otherwise it
// dials the configured handshaker service, runs the ALTS client handshake over
// rawConn (using addr as the handshaker's target name) and then checks that
// the server's RPC protocol versions are compatible with the local ones.
//
// err is a named result so that the deferred cleanups below can tell whether
// the handshake ultimately failed.
func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
	if !vmOnGCP {
		return nil, nil, ErrUntrustedPlatform
	}

	// Connecting to ALTS handshaker service.
	hsConn, err := service.Dial(g.hsAddress)
	if err != nil {
		return nil, nil, err
	}
	// Do not close hsConn since it is shared with other handshakes.

	// Possible context leak:
	// The cancel function for the child context we create will only be
	// called if a non-nil error is returned. On success the child context is
	// left uncanceled and is only released once the parent ctx is done.
	var cancel context.CancelFunc
	ctx, cancel = context.WithCancel(ctx)
	defer func() {
		if err != nil {
			cancel()
		}
	}()

	opts := handshaker.DefaultClientHandshakerOptions()
	opts.TargetName = addr
	opts.TargetServiceAccounts = g.accounts
	opts.RPCVersions = &altspb.RpcProtocolVersions{
		MaxRpcVersion: maxRPCVersion,
		MinRpcVersion: minRPCVersion,
	}
	opts.BoundAccessToken = g.boundAccessToken
	chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
	if err != nil {
		return nil, nil, err
	}
	// Release the handshaker on any failure after it has been created.
	defer func() {
		if err != nil {
			chs.Close()
		}
	}()
	secConn, authInfo, err := chs.ClientHandshake(ctx)
	if err != nil {
		return nil, nil, err
	}
	altsAuthInfo, ok := authInfo.(AuthInfo)
	if !ok {
		return nil, nil, errors.New("client-side auth info is not of type alts.AuthInfo")
	}
	// Only the compatibility verdict matters here; the highest common
	// version returned by checkRPCVersions is discarded.
	match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	if !match {
		return nil, nil, fmt.Errorf("server-side RPC versions are not compatible with this client, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	}
	return secConn, authInfo, nil
}
   226  
// ServerHandshake implements the server side ALTS handshaker.
//
// It fails with ErrUntrustedPlatform when vmOnGCP is false. Otherwise it
// dials the configured handshaker service, runs the ALTS server handshake over
// rawConn within defaultTimeout, and then checks that the client's RPC
// protocol versions are compatible with the local ones.
//
// err is a named result so that the deferred cleanup below can tell whether
// the handshake ultimately failed.
func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
	if !vmOnGCP {
		return nil, nil, ErrUntrustedPlatform
	}
	// Connecting to ALTS handshaker service.
	hsConn, err := service.Dial(g.hsAddress)
	if err != nil {
		return nil, nil, err
	}
	// Do not close hsConn since it's shared with other handshakes.

	// No caller context is available on the server side, so the handshake is
	// bounded by defaultTimeout; the context is canceled on every return.
	ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
	defer cancel()
	opts := handshaker.DefaultServerHandshakerOptions()
	opts.RPCVersions = &altspb.RpcProtocolVersions{
		MaxRpcVersion: maxRPCVersion,
		MinRpcVersion: minRPCVersion,
	}
	shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts)
	if err != nil {
		return nil, nil, err
	}
	// Release the handshaker on any failure after it has been created.
	defer func() {
		if err != nil {
			shs.Close()
		}
	}()
	secConn, authInfo, err := shs.ServerHandshake(ctx)
	if err != nil {
		return nil, nil, err
	}
	altsAuthInfo, ok := authInfo.(AuthInfo)
	if !ok {
		return nil, nil, errors.New("server-side auth info is not of type alts.AuthInfo")
	}
	// Only the compatibility verdict matters here; the highest common
	// version returned by checkRPCVersions is discarded.
	match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	if !match {
		return nil, nil, fmt.Errorf("client-side RPC versions is not compatible with this server, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	}
	return secConn, authInfo, nil
}
   269  
   270  func (g *altsTC) Info() credentials.ProtocolInfo {
   271  	return *g.info
   272  }
   273  
   274  func (g *altsTC) Clone() credentials.TransportCredentials {
   275  	info := *g.info
   276  	var accounts []string
   277  	if g.accounts != nil {
   278  		accounts = make([]string, len(g.accounts))
   279  		copy(accounts, g.accounts)
   280  	}
   281  	return &altsTC{
   282  		info:      &info,
   283  		side:      g.side,
   284  		hsAddress: g.hsAddress,
   285  		accounts:  accounts,
   286  	}
   287  }
   288  
   289  func (g *altsTC) OverrideServerName(serverNameOverride string) error {
   290  	g.info.ServerName = serverNameOverride
   291  	return nil
   292  }
   293  
   294  // compareRPCVersion returns 0 if v1 == v2, 1 if v1 > v2 and -1 if v1 < v2.
   295  func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int {
   296  	switch {
   297  	case v1.GetMajor() > v2.GetMajor(),
   298  		v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor():
   299  		return 1
   300  	case v1.GetMajor() < v2.GetMajor(),
   301  		v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor():
   302  		return -1
   303  	}
   304  	return 0
   305  }
   306  
   307  // checkRPCVersions performs a version check between local and peer rpc protocol
   308  // versions. This function returns true if the check passes which means both
   309  // parties agreed on a common rpc protocol to use, and false otherwise. The
   310  // function also returns the highest common RPC protocol version both parties
   311  // agreed on.
   312  func checkRPCVersions(local, peer *altspb.RpcProtocolVersions) (bool, *altspb.RpcProtocolVersions_Version) {
   313  	if local == nil || peer == nil {
   314  		logger.Error("invalid checkRPCVersions argument, either local or peer is nil.")
   315  		return false, nil
   316  	}
   317  
   318  	// maxCommonVersion is MIN(local.max, peer.max).
   319  	maxCommonVersion := local.GetMaxRpcVersion()
   320  	if compareRPCVersions(local.GetMaxRpcVersion(), peer.GetMaxRpcVersion()) > 0 {
   321  		maxCommonVersion = peer.GetMaxRpcVersion()
   322  	}
   323  
   324  	// minCommonVersion is MAX(local.min, peer.min).
   325  	minCommonVersion := peer.GetMinRpcVersion()
   326  	if compareRPCVersions(local.GetMinRpcVersion(), peer.GetMinRpcVersion()) > 0 {
   327  		minCommonVersion = local.GetMinRpcVersion()
   328  	}
   329  
   330  	if compareRPCVersions(maxCommonVersion, minCommonVersion) < 0 {
   331  		return false, nil
   332  	}
   333  	return true, maxCommonVersion
   334  }