google.golang.org/grpc@v1.72.2/credentials/google/google_test.go (about)

     1  /*
     2   *
     3   * Copyright 2021 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package google
    20  
    21  import (
    22  	"context"
    23  	"net"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/google/go-cmp/cmp"
    28  	"google.golang.org/grpc/credentials"
    29  	icredentials "google.golang.org/grpc/internal/credentials"
    30  	"google.golang.org/grpc/internal/grpctest"
    31  	"google.golang.org/grpc/internal/xds"
    32  	"google.golang.org/grpc/resolver"
    33  )
    34  
    35  var defaultTestTimeout = 10 * time.Second
    36  
    37  type s struct {
    38  	grpctest.Tester
    39  }
    40  
    41  func Test(t *testing.T) {
    42  	grpctest.RunSubTests(t, s{})
    43  }
    44  
    45  type testCreds struct {
    46  	credentials.TransportCredentials
    47  	typ string
    48  }
    49  
    50  func (c *testCreds) ClientHandshake(context.Context, string, net.Conn) (net.Conn, credentials.AuthInfo, error) {
    51  	return nil, &testAuthInfo{typ: c.typ}, nil
    52  }
    53  
    54  func (c *testCreds) ServerHandshake(net.Conn) (net.Conn, credentials.AuthInfo, error) {
    55  	return nil, &testAuthInfo{typ: c.typ}, nil
    56  }
    57  
    58  type testAuthInfo struct {
    59  	typ string
    60  }
    61  
    62  func (t *testAuthInfo) AuthType() string {
    63  	return t.typ
    64  }
    65  
    66  type testPerRPCCreds struct {
    67  	md map[string]string
    68  }
    69  
    70  func (c *testPerRPCCreds) RequireTransportSecurity() bool {
    71  	return true
    72  }
    73  
    74  func (c *testPerRPCCreds) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
    75  	return c.md, nil
    76  }
    77  
    78  var (
    79  	testTLS  = &testCreds{typ: "tls"}
    80  	testALTS = &testCreds{typ: "alts"}
    81  )
    82  
    83  func overrideNewCredsFuncs() func() {
    84  	origNewTLS := newTLS
    85  	newTLS = func() credentials.TransportCredentials {
    86  		return testTLS
    87  	}
    88  	origNewALTS := newALTS
    89  	newALTS = func() credentials.TransportCredentials {
    90  		return testALTS
    91  	}
    92  	origNewADC := newADC
    93  	newADC = func(context.Context) (credentials.PerRPCCredentials, error) {
    94  		// We do not use perRPC creds in this test. It is safe to return nil here.
    95  		return nil, nil
    96  	}
    97  
    98  	return func() {
    99  		newTLS = origNewTLS
   100  		newALTS = origNewALTS
   101  		newADC = origNewADC
   102  	}
   103  }
   104  
   105  // TestClientHandshakeBasedOnClusterName that by default (without switching
   106  // modes), ClientHandshake does either tls or alts base on the cluster name in
   107  // attributes.
   108  func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
   109  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   110  	defer cancel()
   111  	defer overrideNewCredsFuncs()()
   112  	for bundleTyp, tc := range map[string]credentials.Bundle{
   113  		"defaultCredsWithOptions": NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{}),
   114  		"defaultCreds":            NewDefaultCredentials(),
   115  		"computeCreds":            NewComputeEngineCredentials(),
   116  	} {
   117  		tests := []struct {
   118  			name    string
   119  			ctx     context.Context
   120  			wantTyp string
   121  		}{
   122  			{
   123  				name:    "no cluster name",
   124  				ctx:     ctx,
   125  				wantTyp: "tls",
   126  			},
   127  			{
   128  				name: "with non-CFE cluster name",
   129  				ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
   130  					Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes,
   131  				}),
   132  				// non-CFE backends should use alts.
   133  				wantTyp: "alts",
   134  			},
   135  			{
   136  				name: "with CFE cluster name",
   137  				ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
   138  					Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "google_cfe_bigtable.googleapis.com").Attributes,
   139  				}),
   140  				// CFE should use tls.
   141  				wantTyp: "tls",
   142  			},
   143  			{
   144  				name: "with xdstp CFE cluster name",
   145  				ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
   146  					Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://traffic-director-c2p.xds.googleapis.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
   147  				}),
   148  				// CFE should use tls.
   149  				wantTyp: "tls",
   150  			},
   151  			{
   152  				name: "with xdstp non-CFE cluster name",
   153  				ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
   154  					Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://other.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
   155  				}),
   156  				// non-CFE should use atls.
   157  				wantTyp: "alts",
   158  			},
   159  		}
   160  		for _, tt := range tests {
   161  			t.Run(bundleTyp+" "+tt.name, func(t *testing.T) {
   162  				_, info, err := tc.TransportCredentials().ClientHandshake(tt.ctx, "", nil)
   163  				if err != nil {
   164  					t.Fatalf("ClientHandshake failed: %v", err)
   165  				}
   166  				if gotType := info.AuthType(); gotType != tt.wantTyp {
   167  					t.Fatalf("unexpected authtype: %v, want: %v", gotType, tt.wantTyp)
   168  				}
   169  
   170  				_, infoServer, err := tc.TransportCredentials().ServerHandshake(nil)
   171  				if err != nil {
   172  					t.Fatalf("ClientHandshake failed: %v", err)
   173  				}
   174  				// ServerHandshake should always do TLS.
   175  				if gotType := infoServer.AuthType(); gotType != "tls" {
   176  					t.Fatalf("unexpected server authtype: %v, want: %v", gotType, "tls")
   177  				}
   178  			})
   179  		}
   180  	}
   181  }
   182  
   183  func TestDefaultCredentialsWithOptions(t *testing.T) {
   184  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   185  	defer cancel()
   186  	md1 := map[string]string{"foo": "tls"}
   187  	md2 := map[string]string{"foo": "alts"}
   188  	tests := []struct {
   189  		desc             string
   190  		defaultCredsOpts DefaultCredentialsOptions
   191  		authInfo         credentials.AuthInfo
   192  		wantedMetadata   map[string]string
   193  	}{
   194  		{
   195  			desc: "no ALTSPerRPCCreds with tls channel",
   196  			defaultCredsOpts: DefaultCredentialsOptions{
   197  				PerRPCCreds: &testPerRPCCreds{
   198  					md: md1,
   199  				},
   200  			},
   201  			authInfo:       &testAuthInfo{typ: "tls"},
   202  			wantedMetadata: md1,
   203  		},
   204  		{
   205  			desc: "no ALTSPerRPCCreds with alts channel",
   206  			defaultCredsOpts: DefaultCredentialsOptions{
   207  				PerRPCCreds: &testPerRPCCreds{
   208  					md: md1,
   209  				},
   210  			},
   211  			authInfo:       &testAuthInfo{typ: "alts"},
   212  			wantedMetadata: md1,
   213  		},
   214  		{
   215  			desc: "ALTSPerRPCCreds specified with tls channel",
   216  			defaultCredsOpts: DefaultCredentialsOptions{
   217  				PerRPCCreds: &testPerRPCCreds{
   218  					md: md1,
   219  				},
   220  				ALTSPerRPCCreds: &testPerRPCCreds{
   221  					md: md2,
   222  				},
   223  			},
   224  			authInfo:       &testAuthInfo{typ: "tls"},
   225  			wantedMetadata: md1,
   226  		},
   227  		{
   228  			desc: "ALTSPerRPCCreds specified with alts channel",
   229  			defaultCredsOpts: DefaultCredentialsOptions{
   230  				PerRPCCreds: &testPerRPCCreds{
   231  					md: md1,
   232  				},
   233  				ALTSPerRPCCreds: &testPerRPCCreds{
   234  					md: md2,
   235  				},
   236  			},
   237  			authInfo:       &testAuthInfo{typ: "alts"},
   238  			wantedMetadata: md2,
   239  		},
   240  		{
   241  			desc: "ALTSPerRPCCreds specified with unknown channel",
   242  			defaultCredsOpts: DefaultCredentialsOptions{
   243  				PerRPCCreds: &testPerRPCCreds{
   244  					md: md1,
   245  				},
   246  				ALTSPerRPCCreds: &testPerRPCCreds{
   247  					md: md2,
   248  				},
   249  			},
   250  			authInfo:       &testAuthInfo{typ: "foo"},
   251  			wantedMetadata: md1,
   252  		},
   253  	}
   254  	for _, tc := range tests {
   255  		t.Run(tc.desc, func(t *testing.T) {
   256  			bundle := NewDefaultCredentialsWithOptions(tc.defaultCredsOpts)
   257  			ri := credentials.RequestInfo{AuthInfo: tc.authInfo}
   258  			ctx := icredentials.NewRequestInfoContext(ctx, ri)
   259  			got, err := bundle.PerRPCCredentials().GetRequestMetadata(ctx, "uri")
   260  			if err != nil {
   261  				t.Fatalf("Bundle's PerRPCCredentials().GetRequestMetadata() unexpected error = %v", err)
   262  			}
   263  			if diff := cmp.Diff(got, tc.wantedMetadata); diff != "" {
   264  				t.Errorf("Unexpected request metadata from bundle's PerRPCCredentials. Diff (-got +want):\n%v", diff)
   265  			}
   266  		})
   267  	}
   268  }