github.com/Hyperledger-TWGC/tjfoc-gm@v1.4.0/gmtls/gmcredentials/credentials.go (about) 1 /* 2 Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 */ 15 16 package gmcredentials 17 18 import ( 19 "errors" 20 "fmt" 21 "io/ioutil" 22 "net" 23 "strings" 24 25 "github.com/Hyperledger-TWGC/tjfoc-gm/gmtls" 26 "github.com/Hyperledger-TWGC/tjfoc-gm/x509" 27 "golang.org/x/net/context" 28 "google.golang.org/grpc/credentials" 29 ) 30 31 var ( 32 // alpnProtoStr are the specified application level protocols for gRPC. 33 alpnProtoStr = []string{"h2"} 34 ) 35 36 // PerRPCCredentials defines the common interface for the credentials which need to 37 // attach security information to every RPC (e.g., oauth2). 38 type PerRPCCredentials interface { 39 // GetRequestMetadata gets the current request metadata, refreshing 40 // tokens if required. This should be called by the transport layer on 41 // each request, and the data should be populated in headers or other 42 // context. uri is the URI of the entry point for the request. When 43 // supported by the underlying implementation, ctx can be used for 44 // timeout and cancellation. 45 // TODO(zhaoq): Define the set of the qualified keys instead of leaving 46 // it as an arbitrary string. 47 GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) 48 // RequireTransportSecurity indicates whether the credentials requires 49 // transport security. 50 RequireTransportSecurity() bool 51 } 52 53 // ProtocolInfo provides information regarding the gRPC wire protocol version, 54 // security protocol, security protocol version in use, server name, etc. 55 type ProtocolInfo struct { 56 // ProtocolVersion is the gRPC wire protocol version. 57 ProtocolVersion string 58 // SecurityProtocol is the security protocol in use. 59 SecurityProtocol string 60 // SecurityVersion is the security protocol version. 61 SecurityVersion string 62 // ServerName is the user-configured server name. 63 ServerName string 64 } 65 66 // AuthInfo defines the common interface for the auth information the users are interested in. 67 type AuthInfo interface { 68 AuthType() string 69 } 70 71 var ( 72 // ErrConnDispatched indicates that rawConn has been dispatched out of gRPC 73 // and the caller should not close rawConn. 74 ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC") 75 ) 76 77 // TLSInfo contains the auth information for a TLS authenticated connection. 78 // It implements the AuthInfo interface. 79 type TLSInfo struct { 80 State gmtls.ConnectionState 81 } 82 83 // AuthType returns the type of TLSInfo as a string. 84 func (t TLSInfo) AuthType() string { 85 return "tls" 86 } 87 88 // tlsCreds is the credentials required for authenticating a connection using TLS. 89 type tlsCreds struct { 90 // TLS configuration 91 config *gmtls.Config 92 } 93 94 func (c tlsCreds) Info() credentials.ProtocolInfo { 95 return credentials.ProtocolInfo{ 96 SecurityProtocol: "tls", 97 SecurityVersion: "1.2", 98 ServerName: c.config.ServerName, 99 } 100 } 101 102 func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) { 103 // use local cfg to avoid clobbering ServerName if using multiple endpoints 104 cfg := cloneTLSConfig(c.config) 105 if cfg.ServerName == "" { 106 colonPos := strings.LastIndex(addr, ":") 107 if colonPos == -1 { 108 colonPos = len(addr) 109 } 110 cfg.ServerName = addr[:colonPos] 111 } 112 conn := gmtls.Client(rawConn, cfg) 113 errChannel := make(chan error, 1) 114 go func() { 115 errChannel <- conn.Handshake() 116 }() 117 select { 118 case err := <-errChannel: 119 if err != nil { 120 return nil, nil, err 121 } 122 case <-ctx.Done(): 123 return nil, nil, ctx.Err() 124 } 125 return conn, TLSInfo{conn.ConnectionState()}, nil 126 } 127 128 func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 129 conn := gmtls.Server(rawConn, c.config) 130 if err := conn.Handshake(); err != nil { 131 return nil, nil, err 132 } 133 return conn, TLSInfo{conn.ConnectionState()}, nil 134 } 135 136 func (c *tlsCreds) Clone() credentials.TransportCredentials { 137 return NewTLS(c.config) 138 } 139 140 func (c *tlsCreds) OverrideServerName(serverNameOverride string) error { 141 c.config.ServerName = serverNameOverride 142 return nil 143 } 144 145 // NewTLS uses c to construct a TransportCredentials based on TLS. 146 func NewTLS(c *gmtls.Config) credentials.TransportCredentials { 147 tc := &tlsCreds{cloneTLSConfig(c)} 148 tc.config.NextProtos = alpnProtoStr 149 return tc 150 } 151 152 // NewClientTLSFromCert constructs TLS credentials from the input certificate for client. 153 // serverNameOverride is for testing only. If set to a non empty string, 154 // it will override the virtual host name of authority (e.g. :authority header field) in requests. 155 func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) credentials.TransportCredentials { 156 return NewTLS(&gmtls.Config{GMSupport: &gmtls.GMSupport{}, ServerName: serverNameOverride, RootCAs: cp}) 157 } 158 159 // NewClientTLSFromFile constructs TLS credentials from the input certificate file for client. 160 // serverNameOverride is for testing only. If set to a non empty string, 161 // it will override the virtual host name of authority (e.g. :authority header field) in requests. 162 func NewClientTLSFromFile(certFile, serverNameOverride string) (credentials.TransportCredentials, error) { 163 b, err := ioutil.ReadFile(certFile) 164 if err != nil { 165 return nil, err 166 } 167 cp := x509.NewCertPool() 168 if !cp.AppendCertsFromPEM(b) { 169 return nil, fmt.Errorf("credentials: failed to append certificates") 170 } 171 return NewTLS(&gmtls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil 172 } 173 174 // NewServerTLSFromCert constructs TLS credentials from the input certificate for server. 175 func NewServerTLSFromCert(cert *gmtls.Certificate) credentials.TransportCredentials { 176 return NewTLS(&gmtls.Config{Certificates: []gmtls.Certificate{*cert}}) 177 } 178 179 // NewServerTLSFromFile constructs TLS credentials from the input certificate file and key 180 // file for server. 181 func NewServerTLSFromFile(certFile, keyFile string) (credentials.TransportCredentials, error) { 182 cert, err := gmtls.LoadX509KeyPair(certFile, keyFile) 183 if err != nil { 184 return nil, err 185 } 186 return NewTLS(&gmtls.Config{Certificates: []gmtls.Certificate{cert}}), nil 187 }