github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/credentials/credentials_test.go (about) 1 /* 2 * 3 * Copyright 2016 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package credentials 20 21 import ( 22 "context" 23 "net" 24 "strings" 25 "testing" 26 "time" 27 28 tls "github.com/hxx258456/ccgo/gmtls" 29 30 "github.com/hxx258456/ccgo/grpc/internal/grpctest" 31 "github.com/hxx258456/ccgo/grpc/testdata" 32 ) 33 34 const defaultTestTimeout = 10 * time.Second 35 36 type s struct { 37 grpctest.Tester 38 } 39 40 func Test(t *testing.T) { 41 grpctest.RunSubTests(t, s{}) 42 } 43 44 // A struct that implements AuthInfo interface but does not implement GetCommonAuthInfo() method. 45 type testAuthInfoNoGetCommonAuthInfoMethod struct{} 46 47 func (ta testAuthInfoNoGetCommonAuthInfoMethod) AuthType() string { 48 return "testAuthInfoNoGetCommonAuthInfoMethod" 49 } 50 51 // A struct that implements AuthInfo interface and implements CommonAuthInfo() method. 
52 type testAuthInfo struct { 53 CommonAuthInfo 54 } 55 56 func (ta testAuthInfo) AuthType() string { 57 return "testAuthInfo" 58 } 59 60 func (s) TestCheckSecurityLevel(t *testing.T) { 61 testCases := []struct { 62 authLevel SecurityLevel 63 testLevel SecurityLevel 64 want bool 65 }{ 66 { 67 authLevel: PrivacyAndIntegrity, 68 testLevel: PrivacyAndIntegrity, 69 want: true, 70 }, 71 { 72 authLevel: IntegrityOnly, 73 testLevel: PrivacyAndIntegrity, 74 want: false, 75 }, 76 { 77 authLevel: IntegrityOnly, 78 testLevel: NoSecurity, 79 want: true, 80 }, 81 { 82 authLevel: InvalidSecurityLevel, 83 testLevel: IntegrityOnly, 84 want: true, 85 }, 86 { 87 authLevel: InvalidSecurityLevel, 88 testLevel: PrivacyAndIntegrity, 89 want: true, 90 }, 91 } 92 for _, tc := range testCases { 93 err := CheckSecurityLevel(testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: tc.authLevel}}, tc.testLevel) 94 if tc.want && (err != nil) { 95 t.Fatalf("CheckSeurityLevel(%s, %s) returned failure but want success", tc.authLevel.String(), tc.testLevel.String()) 96 } else if !tc.want && (err == nil) { 97 t.Fatalf("CheckSeurityLevel(%s, %s) returned success but want failure", tc.authLevel.String(), tc.testLevel.String()) 98 99 } 100 } 101 } 102 103 func (s) TestCheckSecurityLevelNoGetCommonAuthInfoMethod(t *testing.T) { 104 if err := CheckSecurityLevel(testAuthInfoNoGetCommonAuthInfoMethod{}, PrivacyAndIntegrity); err != nil { 105 t.Fatalf("CheckSeurityLevel() returned failure but want success") 106 } 107 } 108 109 func (s) TestTLSOverrideServerName(t *testing.T) { 110 expectedServerName := "server.name" 111 c := NewTLS(nil) 112 c.OverrideServerName(expectedServerName) 113 if c.Info().ServerName != expectedServerName { 114 t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) 115 } 116 } 117 118 func (s) TestTLSClone(t *testing.T) { 119 expectedServerName := "server.name" 120 c := NewTLS(nil) 121 c.OverrideServerName(expectedServerName) 122 cc := c.Clone() 123 
if cc.Info().ServerName != expectedServerName { 124 t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName) 125 } 126 cc.OverrideServerName("") 127 if c.Info().ServerName != expectedServerName { 128 t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) 129 } 130 131 } 132 133 type serverHandshake func(net.Conn) (AuthInfo, error) 134 135 func (s) TestClientHandshakeReturnsAuthInfo(t *testing.T) { 136 tcs := []struct { 137 name string 138 address string 139 }{ 140 { 141 name: "localhost", 142 address: "localhost:0", 143 }, 144 { 145 name: "ipv4", 146 address: "127.0.0.1:0", 147 }, 148 { 149 name: "ipv6", 150 address: "[::1]:0", 151 }, 152 } 153 154 for _, tc := range tcs { 155 t.Run(tc.name, func(t *testing.T) { 156 done := make(chan AuthInfo, 1) 157 lis := launchServerOnListenAddress(t, tlsServerHandshake, done, tc.address) 158 defer lis.Close() 159 lisAddr := lis.Addr().String() 160 clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr) 161 // wait until server sends serverAuthInfo or fails. 162 serverAuthInfo, ok := <-done 163 if !ok { 164 t.Fatalf("Error at server-side") 165 } 166 if !compare(clientAuthInfo, serverAuthInfo) { 167 t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo) 168 } 169 }) 170 } 171 } 172 173 func (s) TestServerHandshakeReturnsAuthInfo(t *testing.T) { 174 done := make(chan AuthInfo, 1) 175 lis := launchServer(t, gRPCServerHandshake, done) 176 defer lis.Close() 177 clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String()) 178 // wait until server sends serverAuthInfo or fails. 
179 serverAuthInfo, ok := <-done 180 if !ok { 181 t.Fatalf("Error at server-side") 182 } 183 if !compare(clientAuthInfo, serverAuthInfo) { 184 t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo) 185 } 186 } 187 188 func (s) TestServerAndClientHandshake(t *testing.T) { 189 done := make(chan AuthInfo, 1) 190 lis := launchServer(t, gRPCServerHandshake, done) 191 defer lis.Close() 192 clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String()) 193 // wait until server sends serverAuthInfo or fails. 194 serverAuthInfo, ok := <-done 195 if !ok { 196 t.Fatalf("Error at server-side") 197 } 198 if !compare(clientAuthInfo, serverAuthInfo) { 199 t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo) 200 } 201 } 202 203 func compare(a1, a2 AuthInfo) bool { 204 if a1.AuthType() != a2.AuthType() { 205 return false 206 } 207 switch a1.AuthType() { 208 case "tls": 209 state1 := a1.(TLSInfo).State 210 state2 := a2.(TLSInfo).State 211 if state1.Version == state2.Version && 212 state1.HandshakeComplete == state2.HandshakeComplete && 213 state1.CipherSuite == state2.CipherSuite && 214 state1.NegotiatedProtocol == state2.NegotiatedProtocol { 215 return true 216 } 217 return false 218 default: 219 return false 220 } 221 } 222 223 func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener { 224 return launchServerOnListenAddress(t, hs, done, "localhost:0") 225 } 226 227 func launchServerOnListenAddress(t *testing.T, hs serverHandshake, done chan AuthInfo, address string) net.Listener { 228 lis, err := net.Listen("tcp", address) 229 if err != nil { 230 if strings.Contains(err.Error(), "bind: cannot assign requested address") || 231 strings.Contains(err.Error(), "socket: address family not supported by protocol") { 232 t.Skipf("no support for address %v", address) 233 } 234 t.Fatalf("Failed to listen: %v", err) 235 } 236 go serverHandle(t, hs, done, lis) 237 return lis 238 } 
// serverHandle accepts one connection on lis, runs the server-side handshake
// hs on it, and sends the resulting AuthInfo on done. It runs in its own
// goroutine. Failures are therefore reported with t.Errorf, because Fatalf
// may only be called from the test goroutine. The waiting test is signalled
// by closing done instead of sending on it.
func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) {
	serverRawConn, err := lis.Accept()
	if err != nil {
		t.Errorf("Server failed to accept connection: %v", err)
		close(done)
		return
	}
	// Release the accepted connection on every path, including a successful
	// handshake; nothing else in the test holds a reference to it.
	defer serverRawConn.Close()
	serverAuthInfo, err := hs(serverRawConn)
	if err != nil {
		t.Errorf("Server failed while handshake. Error: %v", err)
		close(done)
		return
	}
	done <- serverAuthInfo
}

// clientHandle dials lisAddr over TCP, runs the client-side handshake hs on
// the new connection, and returns the AuthInfo it produced. Any failure is
// fatal to the calling test. The connection is closed before returning.
func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo {
	t.Helper()
	conn, err := net.Dial("tcp", lisAddr)
	if err != nil {
		t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
	}
	defer conn.Close()
	clientAuthInfo, err := hs(conn, lisAddr)
	if err != nil {
		t.Fatalf("Error on client while handshake. Error: %v", err)
	}
	return clientAuthInfo
}

// gRPCServerHandshake performs the server side of the handshake through gRPC's
// TLS credentials, built from the server1 test certificate and key.
func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) {
	serverTLS, err := NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
	if err != nil {
		return nil, err
	}
	_, serverAuthInfo, err := serverTLS.ServerHandshake(conn)
	if err != nil {
		return nil, err
	}
	return serverAuthInfo, nil
}

// gRPCClientHandshake performs the client side of the handshake through gRPC's
// TLS credentials.
func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	// Certificate verification is skipped: the tests compare negotiated
	// state, not the server's identity.
	creds := NewTLS(&tls.Config{InsecureSkipVerify: true})
	_, info, err := creds.ClientHandshake(ctx, lisAddr, conn)
	if err != nil {
		return nil, err
	}
	return info, nil
}

// tlsServerHandshake performs the server side of the handshake directly with
// the gmtls package, bypassing gRPC's credentials. It uses the server1 test
// certificate and wraps the resulting connection state in a TLSInfo marked
// PrivacyAndIntegrity.
func tlsServerHandshake(conn net.Conn) (AuthInfo, error) {
	certFile := testdata.Path("x509/server1_cert.pem")
	keyFile := testdata.Path("x509/server1_key.pem")
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, err
	}
	srv := tls.Server(conn, &tls.Config{Certificates: []tls.Certificate{cert}})
	if err := srv.Handshake(); err != nil {
		return nil, err
	}
	return TLSInfo{
		State:          srv.ConnectionState(),
		CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity},
	}, nil
}

// tlsClientHandshake performs the client side of the handshake directly with
// the gmtls package, without verifying the server certificate. It wraps the
// resulting connection state in a TLSInfo marked PrivacyAndIntegrity. The
// address argument is unused.
func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
	cli := tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
	if err := cli.Handshake(); err != nil {
		return nil, err
	}
	info := TLSInfo{
		State:          cli.ConnectionState(),
		CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity},
	}
	return info, nil
}