github.com/Hyperledger-TWGC/tjfoc-gm@v1.4.0/gmtls/gmcredentials/credentials.go (about)

     1  /*
     2  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  	http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package gmcredentials
    17  
    18  import (
    19  	"errors"
    20  	"fmt"
    21  	"io/ioutil"
    22  	"net"
    23  	"strings"
    24  
    25  	"github.com/Hyperledger-TWGC/tjfoc-gm/gmtls"
    26  	"github.com/Hyperledger-TWGC/tjfoc-gm/x509"
    27  	"golang.org/x/net/context"
    28  	"google.golang.org/grpc/credentials"
    29  )
    30  
    31  var (
    32  	// alpnProtoStr are the specified application level protocols for gRPC.
    33  	alpnProtoStr = []string{"h2"}
    34  )
    35  
    36  // PerRPCCredentials defines the common interface for the credentials which need to
    37  // attach security information to every RPC (e.g., oauth2).
    38  type PerRPCCredentials interface {
    39  	// GetRequestMetadata gets the current request metadata, refreshing
    40  	// tokens if required. This should be called by the transport layer on
    41  	// each request, and the data should be populated in headers or other
    42  	// context. uri is the URI of the entry point for the request. When
    43  	// supported by the underlying implementation, ctx can be used for
    44  	// timeout and cancellation.
    45  	// TODO(zhaoq): Define the set of the qualified keys instead of leaving
    46  	// it as an arbitrary string.
    47  	GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error)
    48  	// RequireTransportSecurity indicates whether the credentials requires
    49  	// transport security.
    50  	RequireTransportSecurity() bool
    51  }
    52  
    53  // ProtocolInfo provides information regarding the gRPC wire protocol version,
    54  // security protocol, security protocol version in use, server name, etc.
    55  type ProtocolInfo struct {
    56  	// ProtocolVersion is the gRPC wire protocol version.
    57  	ProtocolVersion string
    58  	// SecurityProtocol is the security protocol in use.
    59  	SecurityProtocol string
    60  	// SecurityVersion is the security protocol version.
    61  	SecurityVersion string
    62  	// ServerName is the user-configured server name.
    63  	ServerName string
    64  }
    65  
    66  // AuthInfo defines the common interface for the auth information the users are interested in.
    67  type AuthInfo interface {
    68  	AuthType() string
    69  }
    70  
    71  var (
    72  	// ErrConnDispatched indicates that rawConn has been dispatched out of gRPC
    73  	// and the caller should not close rawConn.
    74  	ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC")
    75  )
    76  
    77  // TLSInfo contains the auth information for a TLS authenticated connection.
    78  // It implements the AuthInfo interface.
    79  type TLSInfo struct {
    80  	State gmtls.ConnectionState
    81  }
    82  
    83  // AuthType returns the type of TLSInfo as a string.
    84  func (t TLSInfo) AuthType() string {
    85  	return "tls"
    86  }
    87  
    88  // tlsCreds is the credentials required for authenticating a connection using TLS.
    89  type tlsCreds struct {
    90  	// TLS configuration
    91  	config *gmtls.Config
    92  }
    93  
    94  func (c tlsCreds) Info() credentials.ProtocolInfo {
    95  	return credentials.ProtocolInfo{
    96  		SecurityProtocol: "tls",
    97  		SecurityVersion:  "1.2",
    98  		ServerName:       c.config.ServerName,
    99  	}
   100  }
   101  
   102  func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
   103  	// use local cfg to avoid clobbering ServerName if using multiple endpoints
   104  	cfg := cloneTLSConfig(c.config)
   105  	if cfg.ServerName == "" {
   106  		colonPos := strings.LastIndex(addr, ":")
   107  		if colonPos == -1 {
   108  			colonPos = len(addr)
   109  		}
   110  		cfg.ServerName = addr[:colonPos]
   111  	}
   112  	conn := gmtls.Client(rawConn, cfg)
   113  	errChannel := make(chan error, 1)
   114  	go func() {
   115  		errChannel <- conn.Handshake()
   116  	}()
   117  	select {
   118  	case err := <-errChannel:
   119  		if err != nil {
   120  			return nil, nil, err
   121  		}
   122  	case <-ctx.Done():
   123  		return nil, nil, ctx.Err()
   124  	}
   125  	return conn, TLSInfo{conn.ConnectionState()}, nil
   126  }
   127  
   128  func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   129  	conn := gmtls.Server(rawConn, c.config)
   130  	if err := conn.Handshake(); err != nil {
   131  		return nil, nil, err
   132  	}
   133  	return conn, TLSInfo{conn.ConnectionState()}, nil
   134  }
   135  
   136  func (c *tlsCreds) Clone() credentials.TransportCredentials {
   137  	return NewTLS(c.config)
   138  }
   139  
   140  func (c *tlsCreds) OverrideServerName(serverNameOverride string) error {
   141  	c.config.ServerName = serverNameOverride
   142  	return nil
   143  }
   144  
   145  // NewTLS uses c to construct a TransportCredentials based on TLS.
   146  func NewTLS(c *gmtls.Config) credentials.TransportCredentials {
   147  	tc := &tlsCreds{cloneTLSConfig(c)}
   148  	tc.config.NextProtos = alpnProtoStr
   149  	return tc
   150  }
   151  
   152  // NewClientTLSFromCert constructs TLS credentials from the input certificate for client.
   153  // serverNameOverride is for testing only. If set to a non empty string,
   154  // it will override the virtual host name of authority (e.g. :authority header field) in requests.
   155  func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) credentials.TransportCredentials {
   156  	return NewTLS(&gmtls.Config{GMSupport: &gmtls.GMSupport{}, ServerName: serverNameOverride, RootCAs: cp})
   157  }
   158  
   159  // NewClientTLSFromFile constructs TLS credentials from the input certificate file for client.
   160  // serverNameOverride is for testing only. If set to a non empty string,
   161  // it will override the virtual host name of authority (e.g. :authority header field) in requests.
   162  func NewClientTLSFromFile(certFile, serverNameOverride string) (credentials.TransportCredentials, error) {
   163  	b, err := ioutil.ReadFile(certFile)
   164  	if err != nil {
   165  		return nil, err
   166  	}
   167  	cp := x509.NewCertPool()
   168  	if !cp.AppendCertsFromPEM(b) {
   169  		return nil, fmt.Errorf("credentials: failed to append certificates")
   170  	}
   171  	return NewTLS(&gmtls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil
   172  }
   173  
   174  // NewServerTLSFromCert constructs TLS credentials from the input certificate for server.
   175  func NewServerTLSFromCert(cert *gmtls.Certificate) credentials.TransportCredentials {
   176  	return NewTLS(&gmtls.Config{Certificates: []gmtls.Certificate{*cert}})
   177  }
   178  
   179  // NewServerTLSFromFile constructs TLS credentials from the input certificate file and key
   180  // file for server.
   181  func NewServerTLSFromFile(certFile, keyFile string) (credentials.TransportCredentials, error) {
   182  	cert, err := gmtls.LoadX509KeyPair(certFile, keyFile)
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  	return NewTLS(&gmtls.Config{Certificates: []gmtls.Certificate{cert}}), nil
   187  }