gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/credentials/local/local_test.go (about) 1 /* 2 * 3 * Copyright 2020 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package local 20 21 import ( 22 "context" 23 "fmt" 24 "net" 25 "runtime" 26 "strings" 27 "testing" 28 "time" 29 30 "gitee.com/ks-custle/core-gm/grpc/credentials" 31 "gitee.com/ks-custle/core-gm/grpc/internal/grpctest" 32 ) 33 34 const defaultTestTimeout = 10 * time.Second 35 36 type s struct { 37 grpctest.Tester 38 } 39 40 func Test(t *testing.T) { 41 grpctest.RunSubTests(t, s{}) 42 } 43 44 func (s) TestGetSecurityLevel(t *testing.T) { 45 testCases := []struct { 46 testNetwork string 47 testAddr string 48 want credentials.SecurityLevel 49 }{ 50 { 51 testNetwork: "tcp", 52 testAddr: "127.0.0.1:10000", 53 want: credentials.NoSecurity, 54 }, 55 { 56 testNetwork: "tcp", 57 testAddr: "[::1]:10000", 58 want: credentials.NoSecurity, 59 }, 60 { 61 testNetwork: "unix", 62 testAddr: "/tmp/grpc_fullstack_test", 63 want: credentials.PrivacyAndIntegrity, 64 }, 65 { 66 testNetwork: "tcp", 67 testAddr: "192.168.0.1:10000", 68 want: credentials.InvalidSecurityLevel, 69 }, 70 } 71 for _, tc := range testCases { 72 got, _ := getSecurityLevel(tc.testNetwork, tc.testAddr) 73 if got != tc.want { 74 t.Fatalf("GetSeurityLevel(%s, %s) returned %s but want %s", tc.testNetwork, tc.testAddr, got.String(), tc.want.String()) 75 } 76 } 77 } 78 79 type serverHandshake func(net.Conn) (credentials.AuthInfo, error) 80 81 func getSecurityLevelFromAuthInfo(ai credentials.AuthInfo) credentials.SecurityLevel { 82 if c, ok := ai.(interface { 83 GetCommonAuthInfo() credentials.CommonAuthInfo 84 }); ok { 85 return c.GetCommonAuthInfo().SecurityLevel 86 } 87 return credentials.InvalidSecurityLevel 88 } 89 90 // Server local handshake implementation. 91 func serverLocalHandshake(conn net.Conn) (credentials.AuthInfo, error) { 92 cred := NewCredentials() 93 _, authInfo, err := cred.ServerHandshake(conn) 94 if err != nil { 95 return nil, err 96 } 97 return authInfo, nil 98 } 99 100 // Client local handshake implementation. 101 func clientLocalHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) { 102 cred := NewCredentials() 103 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 104 defer cancel() 105 106 _, authInfo, err := cred.ClientHandshake(ctx, lisAddr, conn) 107 if err != nil { 108 return nil, err 109 } 110 return authInfo, nil 111 } 112 113 // Client connects to a server with local credentials. 114 func clientHandle(hs func(net.Conn, string) (credentials.AuthInfo, error), network, lisAddr string) (credentials.AuthInfo, error) { 115 conn, _ := net.Dial(network, lisAddr) 116 defer conn.Close() 117 clientAuthInfo, err := hs(conn, lisAddr) 118 if err != nil { 119 return nil, fmt.Errorf("Error on client while handshake") 120 } 121 return clientAuthInfo, nil 122 } 123 124 type testServerHandleResult struct { 125 authInfo credentials.AuthInfo 126 err error 127 } 128 129 // Server accepts a client's connection with local credentials. 130 func serverHandle(hs serverHandshake, done chan testServerHandleResult, lis net.Listener) { 131 serverRawConn, err := lis.Accept() 132 if err != nil { 133 done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed to accept connection. Error: %v", err)} 134 return 135 } 136 serverAuthInfo, err := hs(serverRawConn) 137 if err != nil { 138 serverRawConn.Close() 139 done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed while handshake. Error: %v", err)} 140 return 141 } 142 done <- testServerHandleResult{authInfo: serverAuthInfo, err: nil} 143 } 144 145 func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, error) { 146 done := make(chan testServerHandleResult, 1) 147 const timeout = 5 * time.Second 148 timer := time.NewTimer(timeout) 149 defer timer.Stop() 150 go serverHandle(serverLocalHandshake, done, lis) 151 defer lis.Close() 152 clientAuthInfo, err := clientHandle(clientLocalHandshake, lis.Addr().Network(), lis.Addr().String()) 153 if err != nil { 154 return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: %v", err) 155 } 156 select { 157 case <-timer.C: 158 return credentials.InvalidSecurityLevel, fmt.Errorf("Test didn't finish in time") 159 case serverHandleResult := <-done: 160 if serverHandleResult.err != nil { 161 return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: %v", serverHandleResult.err) 162 } 163 clientSecLevel := getSecurityLevelFromAuthInfo(clientAuthInfo) 164 serverSecLevel := getSecurityLevelFromAuthInfo(serverHandleResult.authInfo) 165 166 if clientSecLevel == credentials.InvalidSecurityLevel { 167 return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: client's AuthInfo does not implement GetCommonAuthInfo()") 168 } 169 if serverSecLevel == credentials.InvalidSecurityLevel { 170 return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: server's AuthInfo does not implement GetCommonAuthInfo()") 171 } 172 if clientSecLevel != serverSecLevel { 173 return credentials.InvalidSecurityLevel, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String()) 174 } 175 return clientSecLevel, nil 176 } 177 } 178 179 func (s) TestServerAndClientHandshake(t *testing.T) { 180 testCases := []struct { 181 testNetwork string 182 testAddr string 183 want credentials.SecurityLevel 184 }{ 185 { 186 testNetwork: "tcp", 187 testAddr: "127.0.0.1:0", 188 want: credentials.NoSecurity, 189 }, 190 { 191 testNetwork: "tcp", 192 testAddr: "[::1]:0", 193 want: credentials.NoSecurity, 194 }, 195 { 196 testNetwork: "tcp", 197 testAddr: "localhost:0", 198 want: credentials.NoSecurity, 199 }, 200 { 201 testNetwork: "unix", 202 testAddr: fmt.Sprintf("/tmp/grpc_fullstck_test%d", time.Now().UnixNano()), 203 want: credentials.PrivacyAndIntegrity, 204 }, 205 } 206 for _, tc := range testCases { 207 if runtime.GOOS == "windows" && tc.testNetwork == "unix" { 208 t.Skip("skipping tests for unix connections on Windows") 209 } 210 t.Run("serverAndClientHandshakeResult", func(t *testing.T) { 211 lis, err := net.Listen(tc.testNetwork, tc.testAddr) 212 if err != nil { 213 if strings.Contains(err.Error(), "bind: cannot assign requested address") || 214 strings.Contains(err.Error(), "socket: address family not supported by protocol") { 215 t.Skipf("no support for address %v", tc.testAddr) 216 } 217 t.Fatalf("Failed to listen: %v", err) 218 } 219 got, err := serverAndClientHandshake(lis) 220 if got != tc.want { 221 t.Fatalf("serverAndClientHandshake(%s, %s) = %v, %v; want %v, nil", tc.testNetwork, tc.testAddr, got, err, tc.want) 222 } 223 }) 224 } 225 }