google.golang.org/grpc@v1.72.2/experimental/credentials/credentials_test.go (about) 1 /* 2 * 3 * Copyright 2025 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package credentials 20 21 import ( 22 "context" 23 "crypto/tls" 24 "net" 25 "strings" 26 "testing" 27 "time" 28 29 "google.golang.org/grpc/credentials" 30 "google.golang.org/grpc/internal/grpctest" 31 "google.golang.org/grpc/testdata" 32 ) 33 34 const defaultTestTimeout = 10 * time.Second 35 36 type s struct { 37 grpctest.Tester 38 } 39 40 func Test(t *testing.T) { 41 grpctest.RunSubTests(t, s{}) 42 } 43 44 func (s) TestTLSOverrideServerName(t *testing.T) { 45 expectedServerName := "server.name" 46 c := NewTLSWithALPNDisabled(nil) 47 c.OverrideServerName(expectedServerName) 48 if c.Info().ServerName != expectedServerName { 49 t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) 50 } 51 } 52 53 func (s) TestTLSClone(t *testing.T) { 54 expectedServerName := "server.name" 55 c := NewTLSWithALPNDisabled(nil) 56 c.OverrideServerName(expectedServerName) 57 cc := c.Clone() 58 if cc.Info().ServerName != expectedServerName { 59 t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName) 60 } 61 cc.OverrideServerName("") 62 if c.Info().ServerName != expectedServerName { 63 t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) 64 } 65 66 } 67 68 type serverHandshake func(net.Conn) (credentials.AuthInfo, error) 69 70 func (s) TestClientHandshakeReturnsAuthInfo(t *testing.T) { 71 tcs := []struct { 72 name string 73 address string 74 }{ 75 { 76 name: "localhost", 77 address: "localhost:0", 78 }, 79 { 80 name: "ipv4", 81 address: "127.0.0.1:0", 82 }, 83 { 84 name: "ipv6", 85 address: "[::1]:0", 86 }, 87 } 88 89 for _, tc := range tcs { 90 t.Run(tc.name, func(t *testing.T) { 91 done := make(chan credentials.AuthInfo, 1) 92 lis := launchServerOnListenAddress(t, tlsServerHandshake, done, tc.address) 93 defer lis.Close() 94 lisAddr := lis.Addr().String() 95 clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr) 96 // wait until server sends serverAuthInfo or fails. 97 serverAuthInfo, ok := <-done 98 if !ok { 99 t.Fatalf("Error at server-side") 100 } 101 if !compare(clientAuthInfo, serverAuthInfo) { 102 t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo) 103 } 104 }) 105 } 106 } 107 108 func (s) TestServerHandshakeReturnsAuthInfo(t *testing.T) { 109 done := make(chan credentials.AuthInfo, 1) 110 lis := launchServer(t, gRPCServerHandshake, done) 111 defer lis.Close() 112 clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String()) 113 // wait until server sends serverAuthInfo or fails. 114 serverAuthInfo, ok := <-done 115 if !ok { 116 t.Fatalf("Error at server-side") 117 } 118 if !compare(clientAuthInfo, serverAuthInfo) { 119 t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo) 120 } 121 } 122 123 func (s) TestServerAndClientHandshake(t *testing.T) { 124 done := make(chan credentials.AuthInfo, 1) 125 lis := launchServer(t, gRPCServerHandshake, done) 126 defer lis.Close() 127 clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String()) 128 // wait until server sends serverAuthInfo or fails. 129 serverAuthInfo, ok := <-done 130 if !ok { 131 t.Fatalf("Error at server-side") 132 } 133 if !compare(clientAuthInfo, serverAuthInfo) { 134 t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo) 135 } 136 } 137 138 func compare(a1, a2 credentials.AuthInfo) bool { 139 if a1.AuthType() != a2.AuthType() { 140 return false 141 } 142 switch a1.AuthType() { 143 case "tls": 144 state1 := a1.(credentials.TLSInfo).State 145 state2 := a2.(credentials.TLSInfo).State 146 if state1.Version == state2.Version && 147 state1.HandshakeComplete == state2.HandshakeComplete && 148 state1.CipherSuite == state2.CipherSuite && 149 state1.NegotiatedProtocol == state2.NegotiatedProtocol { 150 return true 151 } 152 return false 153 default: 154 return false 155 } 156 } 157 158 func launchServer(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo) net.Listener { 159 return launchServerOnListenAddress(t, hs, done, "localhost:0") 160 } 161 162 func launchServerOnListenAddress(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo, address string) net.Listener { 163 lis, err := net.Listen("tcp", address) 164 if err != nil { 165 if strings.Contains(err.Error(), "bind: cannot assign requested address") || 166 strings.Contains(err.Error(), "socket: address family not supported by protocol") { 167 t.Skipf("no support for address %v", address) 168 } 169 t.Fatalf("Failed to listen: %v", err) 170 } 171 go serverHandle(t, hs, done, lis) 172 return lis 173 } 174 175 // Is run in a separate goroutine. 176 func serverHandle(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo, lis net.Listener) { 177 serverRawConn, err := lis.Accept() 178 if err != nil { 179 t.Errorf("Server failed to accept connection: %v", err) 180 close(done) 181 return 182 } 183 serverAuthInfo, err := hs(serverRawConn) 184 if err != nil { 185 t.Errorf("Server failed while handshake. Error: %v", err) 186 serverRawConn.Close() 187 close(done) 188 return 189 } 190 done <- serverAuthInfo 191 } 192 193 func clientHandle(t *testing.T, hs func(net.Conn, string) (credentials.AuthInfo, error), lisAddr string) credentials.AuthInfo { 194 conn, err := net.Dial("tcp", lisAddr) 195 if err != nil { 196 t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err) 197 } 198 defer conn.Close() 199 clientAuthInfo, err := hs(conn, lisAddr) 200 if err != nil { 201 t.Fatalf("Error on client while handshake. Error: %v", err) 202 } 203 return clientAuthInfo 204 } 205 206 // Server handshake implementation in gRPC. 207 func gRPCServerHandshake(conn net.Conn) (credentials.AuthInfo, error) { 208 serverTLS, err := NewServerTLSFromFileWithALPNDisabled(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 209 if err != nil { 210 return nil, err 211 } 212 _, serverAuthInfo, err := serverTLS.ServerHandshake(conn) 213 if err != nil { 214 return nil, err 215 } 216 return serverAuthInfo, nil 217 } 218 219 // Client handshake implementation in gRPC. 220 func gRPCClientHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) { 221 clientTLS := NewTLSWithALPNDisabled(&tls.Config{InsecureSkipVerify: true}) 222 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 223 defer cancel() 224 _, authInfo, err := clientTLS.ClientHandshake(ctx, lisAddr, conn) 225 if err != nil { 226 return nil, err 227 } 228 return authInfo, nil 229 } 230 231 func tlsServerHandshake(conn net.Conn) (credentials.AuthInfo, error) { 232 cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 233 if err != nil { 234 return nil, err 235 } 236 serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}} 237 serverConn := tls.Server(conn, serverTLSConfig) 238 err = serverConn.Handshake() 239 if err != nil { 240 return nil, err 241 } 242 return credentials.TLSInfo{State: serverConn.ConnectionState(), CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}, nil 243 } 244 245 func tlsClientHandshake(conn net.Conn, _ string) (credentials.AuthInfo, error) { 246 clientTLSConfig := &tls.Config{InsecureSkipVerify: true} 247 clientConn := tls.Client(conn, clientTLSConfig) 248 if err := clientConn.Handshake(); err != nil { 249 return nil, err 250 } 251 return credentials.TLSInfo{State: clientConn.ConnectionState(), CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}, nil 252 }