github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/credentials/google/google_test.go (about)

     1  /*
     2   *
     3   * Copyright 2021 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package google
    20  
    21  import (
    22  	"context"
    23  	"net"
    24  	"testing"
    25  
    26  	"github.com/hxx258456/ccgo/grpc/credentials"
    27  	"github.com/hxx258456/ccgo/grpc/internal"
    28  	icredentials "github.com/hxx258456/ccgo/grpc/internal/credentials"
    29  	"github.com/hxx258456/ccgo/grpc/internal/grpctest"
    30  	"github.com/hxx258456/ccgo/grpc/resolver"
    31  )
    32  
    33  type s struct {
    34  	grpctest.Tester
    35  }
    36  
    37  func Test(t *testing.T) {
    38  	grpctest.RunSubTests(t, s{})
    39  }
    40  
    41  type testCreds struct {
    42  	credentials.TransportCredentials
    43  	typ string
    44  }
    45  
    46  func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
    47  	return nil, &testAuthInfo{typ: c.typ}, nil
    48  }
    49  
    50  func (c *testCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
    51  	return nil, &testAuthInfo{typ: c.typ}, nil
    52  }
    53  
    54  type testAuthInfo struct {
    55  	typ string
    56  }
    57  
    58  func (t *testAuthInfo) AuthType() string {
    59  	return t.typ
    60  }
    61  
    62  var (
    63  	testTLS  = &testCreds{typ: "tls"}
    64  	testALTS = &testCreds{typ: "alts"}
    65  )
    66  
    67  func overrideNewCredsFuncs() func() {
    68  	oldNewTLS := newTLS
    69  	newTLS = func() credentials.TransportCredentials {
    70  		return testTLS
    71  	}
    72  	oldNewALTS := newALTS
    73  	newALTS = func() credentials.TransportCredentials {
    74  		return testALTS
    75  	}
    76  	return func() {
    77  		newTLS = oldNewTLS
    78  		newALTS = oldNewALTS
    79  	}
    80  }
    81  
    82  // TestClientHandshakeBasedOnClusterName that by default (without switching
    83  // modes), ClientHandshake does either tls or alts base on the cluster name in
    84  // attributes.
    85  func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
    86  	defer overrideNewCredsFuncs()()
    87  	for bundleTyp, tc := range map[string]credentials.Bundle{
    88  		"defaultCredsWithOptions": NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{}),
    89  		"defaultCreds":            NewDefaultCredentials(),
    90  		"computeCreds":            NewComputeEngineCredentials(),
    91  	} {
    92  		tests := []struct {
    93  			name    string
    94  			ctx     context.Context
    95  			wantTyp string
    96  		}{
    97  			{
    98  				name:    "no cluster name",
    99  				ctx:     context.Background(),
   100  				wantTyp: "tls",
   101  			},
   102  			{
   103  				name: "with non-CFE cluster name",
   104  				ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
   105  					Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes,
   106  				}),
   107  				// non-CFE backends should use alts.
   108  				wantTyp: "alts",
   109  			},
   110  			{
   111  				name: "with CFE cluster name",
   112  				ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
   113  					Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "google_cfe_bigtable.googleapis.com").Attributes,
   114  				}),
   115  				// CFE should use tls.
   116  				wantTyp: "tls",
   117  			},
   118  		}
   119  		for _, tt := range tests {
   120  			t.Run(bundleTyp+" "+tt.name, func(t *testing.T) {
   121  				_, info, err := tc.TransportCredentials().ClientHandshake(tt.ctx, "", nil)
   122  				if err != nil {
   123  					t.Fatalf("ClientHandshake failed: %v", err)
   124  				}
   125  				if gotType := info.AuthType(); gotType != tt.wantTyp {
   126  					t.Fatalf("unexpected authtype: %v, want: %v", gotType, tt.wantTyp)
   127  				}
   128  
   129  				_, infoServer, err := tc.TransportCredentials().ServerHandshake(nil)
   130  				if err != nil {
   131  					t.Fatalf("ClientHandshake failed: %v", err)
   132  				}
   133  				// ServerHandshake should always do TLS.
   134  				if gotType := infoServer.AuthType(); gotType != "tls" {
   135  					t.Fatalf("unexpected server authtype: %v, want: %v", gotType, "tls")
   136  				}
   137  			})
   138  		}
   139  	}
   140  }