github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/credentials/alts/alts_test.go (about)

     1  //go:build linux || windows
     2  // +build linux windows
     3  
     4  /*
     5   *
     6   * Copyright 2018 gRPC authors.
     7   *
     8   * Licensed under the Apache License, Version 2.0 (the "License");
     9   * you may not use this file except in compliance with the License.
    10   * You may obtain a copy of the License at
    11   *
    12   *     http://www.apache.org/licenses/LICENSE-2.0
    13   *
    14   * Unless required by applicable law or agreed to in writing, software
    15   * distributed under the License is distributed on an "AS IS" BASIS,
    16   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    17   * See the License for the specific language governing permissions and
    18   * limitations under the License.
    19   *
    20   */
    21  
    22  package alts
    23  
    24  import (
    25  	"reflect"
    26  	"testing"
    27  
    28  	"github.com/golang/protobuf/proto"
    29  	altspb "github.com/hxx258456/ccgo/grpc/credentials/alts/internal/proto/grpc_gcp"
    30  	"github.com/hxx258456/ccgo/grpc/internal/grpctest"
    31  )
    32  
    33  type s struct {
    34  	grpctest.Tester
    35  }
    36  
    37  func Test(t *testing.T) {
    38  	grpctest.RunSubTests(t, s{})
    39  }
    40  
    41  func (s) TestInfoServerName(t *testing.T) {
    42  	// This is not testing any handshaker functionality, so it's fine to only
    43  	// use NewServerCreds and not NewClientCreds.
    44  	alts := NewServerCreds(DefaultServerOptions())
    45  	if got, want := alts.Info().ServerName, ""; got != want {
    46  		t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
    47  	}
    48  }
    49  
    50  func (s) TestOverrideServerName(t *testing.T) {
    51  	wantServerName := "server.name"
    52  	// This is not testing any handshaker functionality, so it's fine to only
    53  	// use NewServerCreds and not NewClientCreds.
    54  	c := NewServerCreds(DefaultServerOptions())
    55  	c.OverrideServerName(wantServerName)
    56  	if got, want := c.Info().ServerName, wantServerName; got != want {
    57  		t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
    58  	}
    59  }
    60  
    61  func (s) TestCloneClient(t *testing.T) {
    62  	wantServerName := "server.name"
    63  	opt := DefaultClientOptions()
    64  	opt.TargetServiceAccounts = []string{"not", "empty"}
    65  	c := NewClientCreds(opt)
    66  	c.OverrideServerName(wantServerName)
    67  	cc := c.Clone()
    68  	if got, want := cc.Info().ServerName, wantServerName; got != want {
    69  		t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
    70  	}
    71  	cc.OverrideServerName("")
    72  	if got, want := c.Info().ServerName, wantServerName; got != want {
    73  		t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
    74  	}
    75  	if got, want := cc.Info().ServerName, ""; got != want {
    76  		t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
    77  	}
    78  
    79  	ct := c.(*altsTC)
    80  	cct := cc.(*altsTC)
    81  
    82  	if ct.side != cct.side {
    83  		t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
    84  	}
    85  	if ct.hsAddress != cct.hsAddress {
    86  		t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
    87  	}
    88  	if !reflect.DeepEqual(ct.accounts, cct.accounts) {
    89  		t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
    90  	}
    91  }
    92  
    93  func (s) TestCloneServer(t *testing.T) {
    94  	wantServerName := "server.name"
    95  	c := NewServerCreds(DefaultServerOptions())
    96  	c.OverrideServerName(wantServerName)
    97  	cc := c.Clone()
    98  	if got, want := cc.Info().ServerName, wantServerName; got != want {
    99  		t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
   100  	}
   101  	cc.OverrideServerName("")
   102  	if got, want := c.Info().ServerName, wantServerName; got != want {
   103  		t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
   104  	}
   105  	if got, want := cc.Info().ServerName, ""; got != want {
   106  		t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
   107  	}
   108  
   109  	ct := c.(*altsTC)
   110  	cct := cc.(*altsTC)
   111  
   112  	if ct.side != cct.side {
   113  		t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
   114  	}
   115  	if ct.hsAddress != cct.hsAddress {
   116  		t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
   117  	}
   118  	if !reflect.DeepEqual(ct.accounts, cct.accounts) {
   119  		t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
   120  	}
   121  }
   122  
   123  func (s) TestInfo(t *testing.T) {
   124  	// This is not testing any handshaker functionality, so it's fine to only
   125  	// use NewServerCreds and not NewClientCreds.
   126  	c := NewServerCreds(DefaultServerOptions())
   127  	info := c.Info()
   128  	if got, want := info.ProtocolVersion, ""; got != want {
   129  		t.Errorf("info.ProtocolVersion=%v, want %v", got, want)
   130  	}
   131  	if got, want := info.SecurityProtocol, "alts"; got != want {
   132  		t.Errorf("info.SecurityProtocol=%v, want %v", got, want)
   133  	}
   134  	if got, want := info.SecurityVersion, "1.0"; got != want {
   135  		t.Errorf("info.SecurityVersion=%v, want %v", got, want)
   136  	}
   137  	if got, want := info.ServerName, ""; got != want {
   138  		t.Errorf("info.ServerName=%v, want %v", got, want)
   139  	}
   140  }
   141  
   142  func (s) TestCompareRPCVersions(t *testing.T) {
   143  	for _, tc := range []struct {
   144  		v1     *altspb.RpcProtocolVersions_Version
   145  		v2     *altspb.RpcProtocolVersions_Version
   146  		output int
   147  	}{
   148  		{
   149  			version(3, 2),
   150  			version(2, 1),
   151  			1,
   152  		},
   153  		{
   154  			version(3, 2),
   155  			version(3, 1),
   156  			1,
   157  		},
   158  		{
   159  			version(2, 1),
   160  			version(3, 2),
   161  			-1,
   162  		},
   163  		{
   164  			version(3, 1),
   165  			version(3, 2),
   166  			-1,
   167  		},
   168  		{
   169  			version(3, 2),
   170  			version(3, 2),
   171  			0,
   172  		},
   173  	} {
   174  		if got, want := compareRPCVersions(tc.v1, tc.v2), tc.output; got != want {
   175  			t.Errorf("compareRPCVersions(%v, %v)=%v, want %v", tc.v1, tc.v2, got, want)
   176  		}
   177  	}
   178  }
   179  
   180  func (s) TestCheckRPCVersions(t *testing.T) {
   181  	for _, tc := range []struct {
   182  		desc             string
   183  		local            *altspb.RpcProtocolVersions
   184  		peer             *altspb.RpcProtocolVersions
   185  		output           bool
   186  		maxCommonVersion *altspb.RpcProtocolVersions_Version
   187  	}{
   188  		{
   189  			"local.max > peer.max and local.min > peer.min",
   190  			versions(2, 1, 3, 2),
   191  			versions(1, 2, 2, 1),
   192  			true,
   193  			version(2, 1),
   194  		},
   195  		{
   196  			"local.max > peer.max and local.min < peer.min",
   197  			versions(1, 2, 3, 2),
   198  			versions(2, 1, 2, 1),
   199  			true,
   200  			version(2, 1),
   201  		},
   202  		{
   203  			"local.max > peer.max and local.min = peer.min",
   204  			versions(2, 1, 3, 2),
   205  			versions(2, 1, 2, 1),
   206  			true,
   207  			version(2, 1),
   208  		},
   209  		{
   210  			"local.max < peer.max and local.min > peer.min",
   211  			versions(2, 1, 2, 1),
   212  			versions(1, 2, 3, 2),
   213  			true,
   214  			version(2, 1),
   215  		},
   216  		{
   217  			"local.max = peer.max and local.min > peer.min",
   218  			versions(2, 1, 2, 1),
   219  			versions(1, 2, 2, 1),
   220  			true,
   221  			version(2, 1),
   222  		},
   223  		{
   224  			"local.max < peer.max and local.min < peer.min",
   225  			versions(1, 2, 2, 1),
   226  			versions(2, 1, 3, 2),
   227  			true,
   228  			version(2, 1),
   229  		},
   230  		{
   231  			"local.max < peer.max and local.min = peer.min",
   232  			versions(1, 2, 2, 1),
   233  			versions(1, 2, 3, 2),
   234  			true,
   235  			version(2, 1),
   236  		},
   237  		{
   238  			"local.max = peer.max and local.min < peer.min",
   239  			versions(1, 2, 2, 1),
   240  			versions(2, 1, 2, 1),
   241  			true,
   242  			version(2, 1),
   243  		},
   244  		{
   245  			"all equal",
   246  			versions(2, 1, 2, 1),
   247  			versions(2, 1, 2, 1),
   248  			true,
   249  			version(2, 1),
   250  		},
   251  		{
   252  			"max is smaller than min",
   253  			versions(2, 1, 1, 2),
   254  			versions(2, 1, 1, 2),
   255  			false,
   256  			nil,
   257  		},
   258  		{
   259  			"no overlap, local > peer",
   260  			versions(4, 3, 6, 5),
   261  			versions(1, 0, 2, 1),
   262  			false,
   263  			nil,
   264  		},
   265  		{
   266  			"no overlap, local < peer",
   267  			versions(1, 0, 2, 1),
   268  			versions(4, 3, 6, 5),
   269  			false,
   270  			nil,
   271  		},
   272  		{
   273  			"no overlap, max < min",
   274  			versions(6, 5, 4, 3),
   275  			versions(2, 1, 1, 0),
   276  			false,
   277  			nil,
   278  		},
   279  	} {
   280  		output, maxCommonVersion := checkRPCVersions(tc.local, tc.peer)
   281  		if got, want := output, tc.output; got != want {
   282  			t.Errorf("%v: checkRPCVersions(%v, %v)=(%v, _), want (%v, _)", tc.desc, tc.local, tc.peer, got, want)
   283  		}
   284  		if got, want := maxCommonVersion, tc.maxCommonVersion; !proto.Equal(got, want) {
   285  			t.Errorf("%v: checkRPCVersions(%v, %v)=(_, %v), want (_, %v)", tc.desc, tc.local, tc.peer, got, want)
   286  		}
   287  	}
   288  }
   289  
   290  func version(major, minor uint32) *altspb.RpcProtocolVersions_Version {
   291  	return &altspb.RpcProtocolVersions_Version{
   292  		Major: major,
   293  		Minor: minor,
   294  	}
   295  }
   296  
   297  func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocolVersions {
   298  	return &altspb.RpcProtocolVersions{
   299  		MinRpcVersion: version(minMajor, minMinor),
   300  		MaxRpcVersion: version(maxMajor, maxMinor),
   301  	}
   302  }