github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/credentials/google/google_test.go (about) 1 /* 2 * 3 * Copyright 2021 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package google 20 21 import ( 22 "context" 23 "net" 24 "testing" 25 26 "github.com/hxx258456/ccgo/grpc/credentials" 27 "github.com/hxx258456/ccgo/grpc/internal" 28 icredentials "github.com/hxx258456/ccgo/grpc/internal/credentials" 29 "github.com/hxx258456/ccgo/grpc/internal/grpctest" 30 "github.com/hxx258456/ccgo/grpc/resolver" 31 ) 32 33 type s struct { 34 grpctest.Tester 35 } 36 37 func Test(t *testing.T) { 38 grpctest.RunSubTests(t, s{}) 39 } 40 41 type testCreds struct { 42 credentials.TransportCredentials 43 typ string 44 } 45 46 func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 47 return nil, &testAuthInfo{typ: c.typ}, nil 48 } 49 50 func (c *testCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { 51 return nil, &testAuthInfo{typ: c.typ}, nil 52 } 53 54 type testAuthInfo struct { 55 typ string 56 } 57 58 func (t *testAuthInfo) AuthType() string { 59 return t.typ 60 } 61 62 var ( 63 testTLS = &testCreds{typ: "tls"} 64 testALTS = &testCreds{typ: "alts"} 65 ) 66 67 func overrideNewCredsFuncs() func() { 68 oldNewTLS := newTLS 69 newTLS = func() credentials.TransportCredentials { 70 return testTLS 71 } 72 oldNewALTS := newALTS 73 newALTS = func() credentials.TransportCredentials { 74 return testALTS 75 } 76 return func() { 77 newTLS = oldNewTLS 78 newALTS = oldNewALTS 79 } 80 } 81 82 // TestClientHandshakeBasedOnClusterName that by default (without switching 83 // modes), ClientHandshake does either tls or alts base on the cluster name in 84 // attributes. 85 func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) { 86 defer overrideNewCredsFuncs()() 87 for bundleTyp, tc := range map[string]credentials.Bundle{ 88 "defaultCredsWithOptions": NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{}), 89 "defaultCreds": NewDefaultCredentials(), 90 "computeCreds": NewComputeEngineCredentials(), 91 } { 92 tests := []struct { 93 name string 94 ctx context.Context 95 wantTyp string 96 }{ 97 { 98 name: "no cluster name", 99 ctx: context.Background(), 100 wantTyp: "tls", 101 }, 102 { 103 name: "with non-CFE cluster name", 104 ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{ 105 Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes, 106 }), 107 // non-CFE backends should use alts. 108 wantTyp: "alts", 109 }, 110 { 111 name: "with CFE cluster name", 112 ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{ 113 Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "google_cfe_bigtable.googleapis.com").Attributes, 114 }), 115 // CFE should use tls. 116 wantTyp: "tls", 117 }, 118 } 119 for _, tt := range tests { 120 t.Run(bundleTyp+" "+tt.name, func(t *testing.T) { 121 _, info, err := tc.TransportCredentials().ClientHandshake(tt.ctx, "", nil) 122 if err != nil { 123 t.Fatalf("ClientHandshake failed: %v", err) 124 } 125 if gotType := info.AuthType(); gotType != tt.wantTyp { 126 t.Fatalf("unexpected authtype: %v, want: %v", gotType, tt.wantTyp) 127 } 128 129 _, infoServer, err := tc.TransportCredentials().ServerHandshake(nil) 130 if err != nil { 131 t.Fatalf("ClientHandshake failed: %v", err) 132 } 133 // ServerHandshake should always do TLS. 134 if gotType := infoServer.AuthType(); gotType != "tls" { 135 t.Fatalf("unexpected server authtype: %v, want: %v", gotType, "tls") 136 } 137 }) 138 } 139 } 140 }