github.com/hellobchain/newcryptosm@v0.0.0-20221019060107-edb949a317e9/tls/credentials/credentials.go (about)

     1  /*
     2  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  	http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package credentials
    17  
    18  import (
    19  	"errors"
    20  	"fmt"
    21  	"github.com/hellobchain/newcryptosm/tls"
    22  	"github.com/hellobchain/newcryptosm/x509"
    23  	"io/ioutil"
    24  	"net"
    25  	"strings"
    26  
    27  	"google.golang.org/grpc/credentials"
    28  
    29  	"golang.org/x/net/context"
    30  )
    31  
    32  var (
    33  	// alpnProtoStr are the specified application level protocols for gRPC.
    34  	alpnProtoStr = []string{"h2"}
    35  )
    36  
    37  // PerRPCCredentials defines the common interface for the credentials which need to
    38  // attach security information to every RPC (e.g., oauth2).
    39  type PerRPCCredentials interface {
    40  	// GetRequestMetadata gets the current request metadata, refreshing
    41  	// tokens if required. This should be called by the transport layer on
    42  	// each request, and the data should be populated in headers or other
    43  	// context. uri is the URI of the entry point for the request. When
    44  	// supported by the underlying implementation, ctx can be used for
    45  	// timeout and cancellation.
    46  	// TODO(zhaoq): Define the set of the qualified keys instead of leaving
    47  	// it as an arbitrary string.
    48  	GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error)
    49  	// RequireTransportSecurity indicates whether the credentials requires
    50  	// transport security.
    51  	RequireTransportSecurity() bool
    52  }
    53  
    54  // ProtocolInfo provides information regarding the gRPC wire protocol version,
    55  // security protocol, security protocol version in use, server name, etc.
    56  type ProtocolInfo struct {
    57  	// ProtocolVersion is the gRPC wire protocol version.
    58  	ProtocolVersion string
    59  	// SecurityProtocol is the security protocol in use.
    60  	SecurityProtocol string
    61  	// SecurityVersion is the security protocol version.
    62  	SecurityVersion string
    63  	// ServerName is the user-configured server name.
    64  	ServerName string
    65  }
    66  
    67  // AuthInfo defines the common interface for the auth information the users are interested in.
    68  type AuthInfo interface {
    69  	AuthType() string
    70  }
    71  
    72  var (
    73  	// ErrConnDispatched indicates that rawConn has been dispatched out of gRPC
    74  	// and the caller should not close rawConn.
    75  	ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC")
    76  )
    77  
    78  // TLSInfo contains the auth information for a TLS authenticated connection.
    79  // It implements the AuthInfo interface.
    80  type TLSInfo struct {
    81  	State tls.ConnectionState
    82  }
    83  
    84  // AuthType returns the type of TLSInfo as a string.
    85  func (t TLSInfo) AuthType() string {
    86  	return "tls"
    87  }
    88  
    89  // tlsCreds is the credentials required for authenticating a connection using TLS.
    90  type tlsCreds struct {
    91  	// TLS configuration
    92  	config *tls.Config
    93  }
    94  
    95  func (c tlsCreds) Info() credentials.ProtocolInfo {
    96  	return credentials.ProtocolInfo{
    97  		SecurityProtocol: "tls",
    98  		SecurityVersion:  "1.2",
    99  		ServerName:       c.config.ServerName,
   100  	}
   101  }
   102  
   103  func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
   104  	// use local cfg to avoid clobbering ServerName if using multiple endpoints
   105  	cfg := cloneTLSConfig(c.config)
   106  	if cfg.ServerName == "" {
   107  		colonPos := strings.LastIndex(addr, ":")
   108  		if colonPos == -1 {
   109  			colonPos = len(addr)
   110  		}
   111  		cfg.ServerName = addr[:colonPos]
   112  	}
   113  	conn := tls.Client(rawConn, cfg)
   114  	errChannel := make(chan error, 1)
   115  	go func() {
   116  		errChannel <- conn.Handshake()
   117  	}()
   118  	select {
   119  	case err := <-errChannel:
   120  		if err != nil {
   121  			return nil, nil, err
   122  		}
   123  	case <-ctx.Done():
   124  		return nil, nil, ctx.Err()
   125  	}
   126  	return conn, TLSInfo{conn.ConnectionState()}, nil
   127  }
   128  
   129  func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   130  	conn := tls.Server(rawConn, c.config)
   131  	if err := conn.Handshake(); err != nil {
   132  		return nil, nil, err
   133  	}
   134  	return conn, TLSInfo{conn.ConnectionState()}, nil
   135  }
   136  
   137  func (c *tlsCreds) Clone() credentials.TransportCredentials {
   138  
   139  	return NewTLS(c.config)
   140  }
   141  
   142  func (c *tlsCreds) OverrideServerName(serverNameOverride string) error {
   143  	c.config.ServerName = serverNameOverride
   144  	return nil
   145  }
   146  
   147  // NewTLS uses c to construct a TransportCredentials based on TLS.
   148  func NewTLS(c *tls.Config) credentials.TransportCredentials {
   149  	tc := &tlsCreds{cloneTLSConfig(c)}
   150  	tc.config.NextProtos = alpnProtoStr
   151  	return tc
   152  }
   153  
   154  // NewClientTLSFromCert constructs TLS credentials from the input certificate for client.
   155  // serverNameOverride is for testing only. If set to a non empty string,
   156  // it will override the virtual host name of authority (e.g. :authority header field) in requests.
   157  func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) credentials.TransportCredentials {
   158  	return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp})
   159  }
   160  
   161  // NewClientTLSFromFile constructs TLS credentials from the input certificate file for client.
   162  // serverNameOverride is for testing only. If set to a non empty string,
   163  // it will override the virtual host name of authority (e.g. :authority header field) in requests.
   164  func NewClientTLSFromFile(certFile, serverNameOverride string) (credentials.TransportCredentials, error) {
   165  	b, err := ioutil.ReadFile(certFile)
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  	cp := x509.NewCertPool()
   170  	if !cp.AppendCertsFromPEM(b) {
   171  		return nil, fmt.Errorf("credentials: failed to append certificates")
   172  	}
   173  	return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil
   174  }
   175  
   176  // NewServerTLSFromCert constructs TLS credentials from the input certificate for server.
   177  func NewServerTLSFromCert(cert *tls.Certificate) credentials.TransportCredentials {
   178  	return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}})
   179  }
   180  
   181  // NewServerTLSFromFile constructs TLS credentials from the input certificate file and key
   182  // file for server.
   183  func NewServerTLSFromFile(certFile, keyFile string) (credentials.TransportCredentials, error) {
   184  	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
   185  	if err != nil {
   186  		return nil, err
   187  	}
   188  	return NewTLS(&tls.Config{Certificates: []tls.Certificate{cert}}), nil
   189  }