google.golang.org/grpc@v1.62.1/credentials/credentials_test.go (about) 1 /* 2 * 3 * Copyright 2016 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package credentials 20 21 import ( 22 "context" 23 "crypto/tls" 24 "net" 25 "strings" 26 "testing" 27 "time" 28 29 "google.golang.org/grpc/internal/grpctest" 30 "google.golang.org/grpc/testdata" 31 ) 32 33 const defaultTestTimeout = 10 * time.Second 34 35 type s struct { 36 grpctest.Tester 37 } 38 39 func Test(t *testing.T) { 40 grpctest.RunSubTests(t, s{}) 41 } 42 43 // A struct that implements AuthInfo interface but does not implement GetCommonAuthInfo() method. 44 type testAuthInfoNoGetCommonAuthInfoMethod struct{} 45 46 func (ta testAuthInfoNoGetCommonAuthInfoMethod) AuthType() string { 47 return "testAuthInfoNoGetCommonAuthInfoMethod" 48 } 49 50 // A struct that implements AuthInfo interface and implements CommonAuthInfo() method. 51 type testAuthInfo struct { 52 CommonAuthInfo 53 } 54 55 func (ta testAuthInfo) AuthType() string { 56 return "testAuthInfo" 57 } 58 59 func (s) TestCheckSecurityLevel(t *testing.T) { 60 testCases := []struct { 61 authLevel SecurityLevel 62 testLevel SecurityLevel 63 want bool 64 }{ 65 { 66 authLevel: PrivacyAndIntegrity, 67 testLevel: PrivacyAndIntegrity, 68 want: true, 69 }, 70 { 71 authLevel: IntegrityOnly, 72 testLevel: PrivacyAndIntegrity, 73 want: false, 74 }, 75 { 76 authLevel: IntegrityOnly, 77 testLevel: NoSecurity, 78 want: true, 79 }, 80 { 81 authLevel: InvalidSecurityLevel, 82 testLevel: IntegrityOnly, 83 want: true, 84 }, 85 { 86 authLevel: InvalidSecurityLevel, 87 testLevel: PrivacyAndIntegrity, 88 want: true, 89 }, 90 } 91 for _, tc := range testCases { 92 err := CheckSecurityLevel(testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: tc.authLevel}}, tc.testLevel) 93 if tc.want && (err != nil) { 94 t.Fatalf("CheckSeurityLevel(%s, %s) returned failure but want success", tc.authLevel.String(), tc.testLevel.String()) 95 } else if !tc.want && (err == nil) { 96 t.Fatalf("CheckSeurityLevel(%s, %s) returned success but want failure", tc.authLevel.String(), tc.testLevel.String()) 97 98 } 99 } 100 } 101 102 func (s) TestCheckSecurityLevelNoGetCommonAuthInfoMethod(t *testing.T) { 103 if err := CheckSecurityLevel(testAuthInfoNoGetCommonAuthInfoMethod{}, PrivacyAndIntegrity); err != nil { 104 t.Fatalf("CheckSeurityLevel() returned failure but want success") 105 } 106 } 107 108 func (s) TestTLSOverrideServerName(t *testing.T) { 109 expectedServerName := "server.name" 110 c := NewTLS(nil) 111 c.OverrideServerName(expectedServerName) 112 if c.Info().ServerName != expectedServerName { 113 t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) 114 } 115 } 116 117 func (s) TestTLSClone(t *testing.T) { 118 expectedServerName := "server.name" 119 c := NewTLS(nil) 120 c.OverrideServerName(expectedServerName) 121 cc := c.Clone() 122 if cc.Info().ServerName != expectedServerName { 123 t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName) 124 } 125 cc.OverrideServerName("") 126 if c.Info().ServerName != expectedServerName { 127 t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) 128 } 129 130 } 131 132 type serverHandshake func(net.Conn) (AuthInfo, error) 133 134 func (s) TestClientHandshakeReturnsAuthInfo(t *testing.T) { 135 tcs := []struct { 136 name string 137 address string 138 }{ 139 { 140 name: "localhost", 141 address: "localhost:0", 142 }, 143 { 144 name: "ipv4", 145 address: "127.0.0.1:0", 146 }, 147 { 148 name: "ipv6", 149 address: "[::1]:0", 150 }, 151 } 152 153 for _, tc := range tcs { 154 t.Run(tc.name, func(t *testing.T) { 155 done := make(chan AuthInfo, 1) 156 lis := launchServerOnListenAddress(t, tlsServerHandshake, done, tc.address) 157 defer lis.Close() 158 lisAddr := lis.Addr().String() 159 clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr) 160 // wait until server sends serverAuthInfo or fails. 161 serverAuthInfo, ok := <-done 162 if !ok { 163 t.Fatalf("Error at server-side") 164 } 165 if !compare(clientAuthInfo, serverAuthInfo) { 166 t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo) 167 } 168 }) 169 } 170 } 171 172 func (s) TestServerHandshakeReturnsAuthInfo(t *testing.T) { 173 done := make(chan AuthInfo, 1) 174 lis := launchServer(t, gRPCServerHandshake, done) 175 defer lis.Close() 176 clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String()) 177 // wait until server sends serverAuthInfo or fails. 178 serverAuthInfo, ok := <-done 179 if !ok { 180 t.Fatalf("Error at server-side") 181 } 182 if !compare(clientAuthInfo, serverAuthInfo) { 183 t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo) 184 } 185 } 186 187 func (s) TestServerAndClientHandshake(t *testing.T) { 188 done := make(chan AuthInfo, 1) 189 lis := launchServer(t, gRPCServerHandshake, done) 190 defer lis.Close() 191 clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String()) 192 // wait until server sends serverAuthInfo or fails. 193 serverAuthInfo, ok := <-done 194 if !ok { 195 t.Fatalf("Error at server-side") 196 } 197 if !compare(clientAuthInfo, serverAuthInfo) { 198 t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo) 199 } 200 } 201 202 func compare(a1, a2 AuthInfo) bool { 203 if a1.AuthType() != a2.AuthType() { 204 return false 205 } 206 switch a1.AuthType() { 207 case "tls": 208 state1 := a1.(TLSInfo).State 209 state2 := a2.(TLSInfo).State 210 if state1.Version == state2.Version && 211 state1.HandshakeComplete == state2.HandshakeComplete && 212 state1.CipherSuite == state2.CipherSuite && 213 state1.NegotiatedProtocol == state2.NegotiatedProtocol { 214 return true 215 } 216 return false 217 default: 218 return false 219 } 220 } 221 222 func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener { 223 return launchServerOnListenAddress(t, hs, done, "localhost:0") 224 } 225 226 func launchServerOnListenAddress(t *testing.T, hs serverHandshake, done chan AuthInfo, address string) net.Listener { 227 lis, err := net.Listen("tcp", address) 228 if err != nil { 229 if strings.Contains(err.Error(), "bind: cannot assign requested address") || 230 strings.Contains(err.Error(), "socket: address family not supported by protocol") { 231 t.Skipf("no support for address %v", address) 232 } 233 t.Fatalf("Failed to listen: %v", err) 234 } 235 go serverHandle(t, hs, done, lis) 236 return lis 237 } 238 239 // Is run in a separate goroutine. 240 func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) { 241 serverRawConn, err := lis.Accept() 242 if err != nil { 243 t.Errorf("Server failed to accept connection: %v", err) 244 close(done) 245 return 246 } 247 serverAuthInfo, err := hs(serverRawConn) 248 if err != nil { 249 t.Errorf("Server failed while handshake. Error: %v", err) 250 serverRawConn.Close() 251 close(done) 252 return 253 } 254 done <- serverAuthInfo 255 } 256 257 func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo { 258 conn, err := net.Dial("tcp", lisAddr) 259 if err != nil { 260 t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err) 261 } 262 defer conn.Close() 263 clientAuthInfo, err := hs(conn, lisAddr) 264 if err != nil { 265 t.Fatalf("Error on client while handshake. Error: %v", err) 266 } 267 return clientAuthInfo 268 } 269 270 // Server handshake implementation in gRPC. 271 func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) { 272 serverTLS, err := NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 273 if err != nil { 274 return nil, err 275 } 276 _, serverAuthInfo, err := serverTLS.ServerHandshake(conn) 277 if err != nil { 278 return nil, err 279 } 280 return serverAuthInfo, nil 281 } 282 283 // Client handshake implementation in gRPC. 284 func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) { 285 clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true}) 286 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 287 defer cancel() 288 _, authInfo, err := clientTLS.ClientHandshake(ctx, lisAddr, conn) 289 if err != nil { 290 return nil, err 291 } 292 return authInfo, nil 293 } 294 295 func tlsServerHandshake(conn net.Conn) (AuthInfo, error) { 296 cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 297 if err != nil { 298 return nil, err 299 } 300 serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}} 301 serverConn := tls.Server(conn, serverTLSConfig) 302 err = serverConn.Handshake() 303 if err != nil { 304 return nil, err 305 } 306 return TLSInfo{State: serverConn.ConnectionState(), CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity}}, nil 307 } 308 309 func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) { 310 clientTLSConfig := &tls.Config{InsecureSkipVerify: true} 311 clientConn := tls.Client(conn, clientTLSConfig) 312 if err := clientConn.Handshake(); err != nil { 313 return nil, err 314 } 315 return TLSInfo{State: clientConn.ConnectionState(), CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity}}, nil 316 }