gitee.com/lh-her-team/common@v1.5.1/crypto/tls/credentials/credentials.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package credentials
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io/ioutil"
    11  	"net"
    12  	"strings"
    13  
    14  	cmtls "gitee.com/lh-her-team/common/crypto/tls"
    15  	cmx509 "gitee.com/lh-her-team/common/crypto/x509"
    16  	"golang.org/x/net/context"
    17  	"google.golang.org/grpc/credentials"
    18  )
    19  
    20  var (
    21  	// alpnProtoStr are the specified application level protocols for gRPC.
    22  	alpnProtoStr = []string{"h2"}
    23  )
    24  
    25  // PerRPCCredentials defines the common interface for the credentials which need to
    26  // attach security information to every RPC (e.g., oauth2).
    27  type PerRPCCredentials interface {
    28  	// GetRequestMetadata gets the current request metadata, refreshing
    29  	// tokens if required. This should be called by the transport layer on
    30  	// each request, and the data should be populated in headers or other
    31  	// context. uri is the URI of the entry point for the request. When
    32  	// supported by the underlying implementation, ctx can be used for
    33  	// timeout and cancellation.
    34  	// TODO(zhaoq): Define the set of the qualified keys instead of leaving
    35  	// it as an arbitrary string.
    36  	GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error)
    37  	// RequireTransportSecurity indicates whether the credentials requires
    38  	// transport security.
    39  	RequireTransportSecurity() bool
    40  }
    41  
// ProtocolInfo provides information regarding the gRPC wire protocol version,
// security protocol, security protocol version in use, server name, etc.
//
// NOTE(review): nothing in this file references this type. tlsCreds.Info
// returns google.golang.org/grpc/credentials.ProtocolInfo instead. It may
// exist for callers elsewhere; verify before changing or removing it.
type ProtocolInfo struct {
	// ProtocolVersion is the gRPC wire protocol version.
	ProtocolVersion string
	// SecurityProtocol is the security protocol in use.
	SecurityProtocol string
	// SecurityVersion is the security protocol version.
	SecurityVersion string
	// ServerName is the user-configured server name.
	ServerName string
}
    54  
// AuthInfo defines the common interface for the auth information the users are interested in.
type AuthInfo interface {
	// AuthType returns a short name for the kind of authentication that
	// produced this information (TLSInfo in this file returns "tls").
	AuthType() string
}
    59  
    60  var (
    61  	// ErrConnDispatched indicates that rawConn has been dispatched out of gRPC
    62  	// and the caller should not close rawConn.
    63  	ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC")
    64  )
    65  
    66  // TLSInfo contains the auth information for a TLS authenticated connection.
    67  // It implements the AuthInfo interface.
    68  type TLSInfo struct {
    69  	State cmtls.ConnectionState
    70  }
    71  
    72  // AuthType returns the type of TLSInfo as a string.
    73  func (t TLSInfo) AuthType() string {
    74  	return "tls"
    75  }
    76  
// tlsCreds is the credentials required for authenticating a connection using TLS.
// It is the TransportCredentials implementation returned by NewTLS.
type tlsCreds struct {
	// TLS configuration. NewTLS stores its own clone here with NextProtos set
	// to alpnProtoStr; OverrideServerName mutates it in place.
	config *cmtls.Config
}
    82  
    83  func (c tlsCreds) Info() credentials.ProtocolInfo {
    84  	return credentials.ProtocolInfo{
    85  		SecurityProtocol: "tls",
    86  		SecurityVersion:  "1.2",
    87  		ServerName:       c.config.ServerName,
    88  	}
    89  }
    90  
    91  func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
    92  	// use local cfg to avoid clobbering ServerName if using multiple endpoints
    93  	cfg := cloneTLSConfig(c.config)
    94  	if cfg.ServerName == "" {
    95  		colonPos := strings.LastIndex(addr, ":")
    96  		if colonPos == -1 {
    97  			colonPos = len(addr)
    98  		}
    99  		cfg.ServerName = addr[:colonPos]
   100  	}
   101  	conn := cmtls.Client(rawConn, cfg)
   102  	errChannel := make(chan error, 1)
   103  	go func() {
   104  		errChannel <- conn.Handshake()
   105  	}()
   106  	select {
   107  	case err := <-errChannel:
   108  		if err != nil {
   109  			return nil, nil, err
   110  		}
   111  	case <-ctx.Done():
   112  		return nil, nil, ctx.Err()
   113  	}
   114  	return conn, TLSInfo{conn.ConnectionState()}, nil
   115  }
   116  
   117  func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   118  	conn := cmtls.Server(rawConn, c.config)
   119  	if err := conn.Handshake(); err != nil {
   120  		return nil, nil, err
   121  	}
   122  	return conn, TLSInfo{conn.ConnectionState()}, nil
   123  }
   124  
   125  func (c *tlsCreds) Clone() credentials.TransportCredentials {
   126  	return NewTLS(c.config)
   127  }
   128  
   129  func (c *tlsCreds) OverrideServerName(serverNameOverride string) error {
   130  	c.config.ServerName = serverNameOverride
   131  	return nil
   132  }
   133  
   134  // NewTLS uses c to construct a TransportCredentials based on TLS.
   135  func NewTLS(c *cmtls.Config) credentials.TransportCredentials {
   136  	tc := &tlsCreds{cloneTLSConfig(c)}
   137  	tc.config.NextProtos = alpnProtoStr
   138  	return tc
   139  }
   140  
   141  // NewClientTLSFromCert constructs TLS credentials from the input certificate for client.
   142  // serverNameOverride is for testing only. If set to a non empty string,
   143  // it will override the virtual host name of authority (e.g. :authority header field) in requests.
   144  func NewClientTLSFromCert(cp *cmx509.CertPool, serverNameOverride string) credentials.TransportCredentials {
   145  	return NewTLS(&cmtls.Config{ServerName: serverNameOverride, RootCAs: cp})
   146  }
   147  
   148  // NewClientTLSFromFile constructs TLS credentials from the input certificate file for client.
   149  // serverNameOverride is for testing only. If set to a non empty string,
   150  // it will override the virtual host name of authority (e.g. :authority header field) in requests.
   151  func NewClientTLSFromFile(certFile, serverNameOverride string) (credentials.TransportCredentials, error) {
   152  	b, err := ioutil.ReadFile(certFile)
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	cp := cmx509.NewCertPool()
   157  	if !cp.AppendCertsFromPEM(b) {
   158  		return nil, fmt.Errorf("credentials: failed to append certificates")
   159  	}
   160  	return NewTLS(&cmtls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil
   161  }
   162  
   163  // NewServerTLSFromCert constructs TLS credentials from the input certificate for server.
   164  func NewServerTLSFromCert(cert *cmtls.Certificate) credentials.TransportCredentials {
   165  	return NewTLS(&cmtls.Config{Certificates: []cmtls.Certificate{*cert}})
   166  }
   167  
   168  // NewServerTLSFromFile constructs TLS credentials from the input certificate file and key
   169  // file for server.
   170  func NewServerTLSFromFile(certFile, keyFile string) (credentials.TransportCredentials, error) {
   171  	cert, err := cmtls.LoadX509KeyPair(certFile, keyFile)
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  	return NewTLS(&cmtls.Config{Certificates: []cmtls.Certificate{cert}}), nil
   176  }