gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/test/healthcheck_test.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package test
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"net"
    26  	"sync"
    27  	"testing"
    28  	"time"
    29  
    30  	grpc "gitee.com/ks-custle/core-gm/grpc"
    31  	"gitee.com/ks-custle/core-gm/grpc/codes"
    32  	"gitee.com/ks-custle/core-gm/grpc/connectivity"
    33  	_ "gitee.com/ks-custle/core-gm/grpc/health"
    34  	healthgrpc "gitee.com/ks-custle/core-gm/grpc/health/grpc_health_v1"
    35  	healthpb "gitee.com/ks-custle/core-gm/grpc/health/grpc_health_v1"
    36  	"gitee.com/ks-custle/core-gm/grpc/internal"
    37  	"gitee.com/ks-custle/core-gm/grpc/internal/channelz"
    38  	"gitee.com/ks-custle/core-gm/grpc/internal/grpctest"
    39  	"gitee.com/ks-custle/core-gm/grpc/resolver"
    40  	"gitee.com/ks-custle/core-gm/grpc/resolver/manual"
    41  	"gitee.com/ks-custle/core-gm/grpc/status"
    42  	testpb "gitee.com/ks-custle/core-gm/grpc/test/grpc_testing"
    43  )
    44  
    45  var testHealthCheckFunc = internal.HealthCheckFunc
    46  
    47  func newTestHealthServer() *testHealthServer {
    48  	return newTestHealthServerWithWatchFunc(defaultWatchFunc)
    49  }
    50  
    51  func newTestHealthServerWithWatchFunc(f func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error) *testHealthServer {
    52  	return &testHealthServer{
    53  		watchFunc: f,
    54  		update:    make(chan struct{}, 1),
    55  		status:    make(map[string]healthpb.HealthCheckResponse_ServingStatus),
    56  	}
    57  }
    58  
    59  // defaultWatchFunc will send a HealthCheckResponse to the client whenever SetServingStatus is called.
    60  func defaultWatchFunc(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
    61  	if in.Service != "foo" {
    62  		return status.Error(codes.FailedPrecondition,
    63  			"the defaultWatchFunc only handles request with service name to be \"foo\"")
    64  	}
    65  	var done bool
    66  	for {
    67  		select {
    68  		case <-stream.Context().Done():
    69  			done = true
    70  		case <-s.update:
    71  		}
    72  		if done {
    73  			break
    74  		}
    75  		s.mu.Lock()
    76  		resp := &healthpb.HealthCheckResponse{
    77  			Status: s.status[in.Service],
    78  		}
    79  		s.mu.Unlock()
    80  		stream.SendMsg(resp)
    81  	}
    82  	return nil
    83  }
    84  
    85  type testHealthServer struct {
    86  	healthpb.UnimplementedHealthServer
    87  	watchFunc func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error
    88  	mu        sync.Mutex
    89  	status    map[string]healthpb.HealthCheckResponse_ServingStatus
    90  	update    chan struct{}
    91  }
    92  
    93  func (s *testHealthServer) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
    94  	return &healthpb.HealthCheckResponse{
    95  		Status: healthpb.HealthCheckResponse_SERVING,
    96  	}, nil
    97  }
    98  
    99  func (s *testHealthServer) Watch(in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
   100  	return s.watchFunc(s, in, stream)
   101  }
   102  
   103  // SetServingStatus is called when need to reset the serving status of a service
   104  // or insert a new service entry into the statusMap.
   105  func (s *testHealthServer) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) {
   106  	s.mu.Lock()
   107  	s.status[service] = status
   108  	select {
   109  	case <-s.update:
   110  	default:
   111  	}
   112  	s.update <- struct{}{}
   113  	s.mu.Unlock()
   114  }
   115  
   116  func setupHealthCheckWrapper() (hcEnterChan chan struct{}, hcExitChan chan struct{}, wrapper internal.HealthChecker) {
   117  	hcEnterChan = make(chan struct{})
   118  	hcExitChan = make(chan struct{})
   119  	wrapper = func(ctx context.Context, newStream func(string) (interface{}, error), update func(connectivity.State, error), service string) error {
   120  		close(hcEnterChan)
   121  		defer close(hcExitChan)
   122  		return testHealthCheckFunc(ctx, newStream, update, service)
   123  	}
   124  	return
   125  }
   126  
   127  type svrConfig struct {
   128  	specialWatchFunc func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error
   129  }
   130  
   131  func setupServer(sc *svrConfig) (s *grpc.Server, lis net.Listener, ts *testHealthServer, deferFunc func(), err error) {
   132  	s = grpc.NewServer()
   133  	lis, err = net.Listen("tcp", "localhost:0")
   134  	if err != nil {
   135  		return nil, nil, nil, func() {}, fmt.Errorf("failed to listen due to err %v", err)
   136  	}
   137  	if sc.specialWatchFunc != nil {
   138  		ts = newTestHealthServerWithWatchFunc(sc.specialWatchFunc)
   139  	} else {
   140  		ts = newTestHealthServer()
   141  	}
   142  	healthgrpc.RegisterHealthServer(s, ts)
   143  	testpb.RegisterTestServiceServer(s, &testServer{})
   144  	go s.Serve(lis)
   145  	return s, lis, ts, s.Stop, nil
   146  }
   147  
   148  type clientConfig struct {
   149  	balancerName               string
   150  	testHealthCheckFuncWrapper internal.HealthChecker
   151  	extraDialOption            []grpc.DialOption
   152  }
   153  
   154  func setupClient(c *clientConfig) (cc *grpc.ClientConn, r *manual.Resolver, deferFunc func(), err error) {
   155  	r = manual.NewBuilderWithScheme("whatever")
   156  	var opts []grpc.DialOption
   157  	opts = append(opts, grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(c.balancerName))
   158  	if c.testHealthCheckFuncWrapper != nil {
   159  		opts = append(opts, internal.WithHealthCheckFunc.(func(internal.HealthChecker) grpc.DialOption)(c.testHealthCheckFuncWrapper))
   160  	}
   161  	opts = append(opts, c.extraDialOption...)
   162  	cc, err = grpc.Dial(r.Scheme()+":///test.server", opts...)
   163  	if err != nil {
   164  
   165  		return nil, nil, nil, fmt.Errorf("dial failed due to err: %v", err)
   166  	}
   167  	return cc, r, func() { cc.Close() }, nil
   168  }
   169  
   170  func (s) TestHealthCheckWatchStateChange(t *testing.T) {
   171  	_, lis, ts, deferFunc, err := setupServer(&svrConfig{})
   172  	defer deferFunc()
   173  	if err != nil {
   174  		t.Fatal(err)
   175  	}
   176  
   177  	// The table below shows the expected series of addrConn connectivity transitions when server
   178  	// updates its health status. As there's only one addrConn corresponds with the ClientConn in this
   179  	// test, we use ClientConn's connectivity state as the addrConn connectivity state.
   180  	//+------------------------------+-------------------------------------------+
   181  	//| Health Check Returned Status | Expected addrConn Connectivity Transition |
   182  	//+------------------------------+-------------------------------------------+
   183  	//| NOT_SERVING                  | ->TRANSIENT FAILURE                       |
   184  	//| SERVING                      | ->READY                                   |
   185  	//| SERVICE_UNKNOWN              | ->TRANSIENT FAILURE                       |
   186  	//| SERVING                      | ->READY                                   |
   187  	//| UNKNOWN                      | ->TRANSIENT FAILURE                       |
   188  	//+------------------------------+-------------------------------------------+
   189  	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_NOT_SERVING)
   190  
   191  	cc, r, deferFunc, err := setupClient(&clientConfig{balancerName: "round_robin"})
   192  	if err != nil {
   193  		t.Fatal(err)
   194  	}
   195  	defer deferFunc()
   196  
   197  	r.UpdateState(resolver.State{
   198  		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
   199  		ServiceConfig: parseCfg(r, `{
   200  	"healthCheckConfig": {
   201  		"serviceName": "foo"
   202  	}
   203  }`)})
   204  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   205  	defer cancel()
   206  	if ok := cc.WaitForStateChange(ctx, connectivity.Idle); !ok {
   207  		t.Fatal("ClientConn is still in IDLE state when the context times out.")
   208  	}
   209  	if ok := cc.WaitForStateChange(ctx, connectivity.Connecting); !ok {
   210  		t.Fatal("ClientConn is still in CONNECTING state when the context times out.")
   211  	}
   212  	if s := cc.GetState(); s != connectivity.TransientFailure {
   213  		t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s)
   214  	}
   215  
   216  	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
   217  	if ok := cc.WaitForStateChange(ctx, connectivity.TransientFailure); !ok {
   218  		t.Fatal("ClientConn is still in TRANSIENT FAILURE state when the context times out.")
   219  	}
   220  	if s := cc.GetState(); s != connectivity.Ready {
   221  		t.Fatalf("ClientConn is in %v state, want READY", s)
   222  	}
   223  
   224  	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVICE_UNKNOWN)
   225  	if ok := cc.WaitForStateChange(ctx, connectivity.Ready); !ok {
   226  		t.Fatal("ClientConn is still in READY state when the context times out.")
   227  	}
   228  	if s := cc.GetState(); s != connectivity.TransientFailure {
   229  		t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s)
   230  	}
   231  
   232  	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
   233  	if ok := cc.WaitForStateChange(ctx, connectivity.TransientFailure); !ok {
   234  		t.Fatal("ClientConn is still in TRANSIENT FAILURE state when the context times out.")
   235  	}
   236  	if s := cc.GetState(); s != connectivity.Ready {
   237  		t.Fatalf("ClientConn is in %v state, want READY", s)
   238  	}
   239  
   240  	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_UNKNOWN)
   241  	if ok := cc.WaitForStateChange(ctx, connectivity.Ready); !ok {
   242  		t.Fatal("ClientConn is still in READY state when the context times out.")
   243  	}
   244  	if s := cc.GetState(); s != connectivity.TransientFailure {
   245  		t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s)
   246  	}
   247  }
   248  
   249  // If Watch returns Unimplemented, then the ClientConn should go into READY state.
   250  func (s) TestHealthCheckHealthServerNotRegistered(t *testing.T) {
   251  	grpctest.TLogger.ExpectError("Subchannel health check is unimplemented at server side, thus health check is disabled")
   252  	s := grpc.NewServer()
   253  	lis, err := net.Listen("tcp", "localhost:0")
   254  	if err != nil {
   255  		t.Fatalf("failed to listen due to err: %v", err)
   256  	}
   257  	go s.Serve(lis)
   258  	defer s.Stop()
   259  
   260  	cc, r, deferFunc, err := setupClient(&clientConfig{balancerName: "round_robin"})
   261  	if err != nil {
   262  		t.Fatal(err)
   263  	}
   264  	defer deferFunc()
   265  
   266  	r.UpdateState(resolver.State{
   267  		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
   268  		ServiceConfig: parseCfg(r, `{
   269  	"healthCheckConfig": {
   270  		"serviceName": "foo"
   271  	}
   272  }`)})
   273  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   274  	defer cancel()
   275  
   276  	if ok := cc.WaitForStateChange(ctx, connectivity.Idle); !ok {
   277  		t.Fatal("ClientConn is still in IDLE state when the context times out.")
   278  	}
   279  	if ok := cc.WaitForStateChange(ctx, connectivity.Connecting); !ok {
   280  		t.Fatal("ClientConn is still in CONNECTING state when the context times out.")
   281  	}
   282  	if s := cc.GetState(); s != connectivity.Ready {
   283  		t.Fatalf("ClientConn is in %v state, want READY", s)
   284  	}
   285  }
   286  
   287  // In the case of a goaway received, the health check stream should be terminated and health check
   288  // function should exit.
   289  func (s) TestHealthCheckWithGoAway(t *testing.T) {
   290  	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
   291  
   292  	s, lis, ts, deferFunc, err := setupServer(&svrConfig{})
   293  	defer deferFunc()
   294  	if err != nil {
   295  		t.Fatal(err)
   296  	}
   297  
   298  	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
   299  
   300  	cc, r, deferFunc, err := setupClient(&clientConfig{
   301  		balancerName:               "round_robin",
   302  		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
   303  	})
   304  	if err != nil {
   305  		t.Fatal(err)
   306  	}
   307  	defer deferFunc()
   308  
   309  	tc := testpb.NewTestServiceClient(cc)
   310  	r.UpdateState(resolver.State{
   311  		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
   312  		ServiceConfig: parseCfg(r, `{
   313  	"healthCheckConfig": {
   314  		"serviceName": "foo"
   315  	}
   316  }`)})
   317  
   318  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   319  	defer cancel()
   320  
   321  	// make some rpcs to make sure connection is working.
   322  	if err := verifyResultWithDelay(func() (bool, error) {
   323  		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   324  			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
   325  		}
   326  		return true, nil
   327  	}); err != nil {
   328  		t.Fatal(err)
   329  	}
   330  
   331  	// the stream rpc will persist through goaway event.
   332  	stream, err := tc.FullDuplexCall(ctx, grpc.WaitForReady(true))
   333  	if err != nil {
   334  		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
   335  	}
   336  	respParam := []*testpb.ResponseParameters{{Size: 1}}
   337  	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
   338  	if err != nil {
   339  		t.Fatal(err)
   340  	}
   341  	req := &testpb.StreamingOutputCallRequest{
   342  		ResponseParameters: respParam,
   343  		Payload:            payload,
   344  	}
   345  	if err := stream.Send(req); err != nil {
   346  		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
   347  	}
   348  	if _, err := stream.Recv(); err != nil {
   349  		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
   350  	}
   351  
   352  	select {
   353  	case <-hcExitChan:
   354  		t.Fatal("Health check function has exited, which is not expected.")
   355  	default:
   356  	}
   357  
   358  	// server sends GoAway
   359  	go s.GracefulStop()
   360  
   361  	select {
   362  	case <-hcExitChan:
   363  	case <-time.After(5 * time.Second):
   364  		select {
   365  		case <-hcEnterChan:
   366  		default:
   367  			t.Fatal("Health check function has not entered after 5s.")
   368  		}
   369  		t.Fatal("Health check function has not exited after 5s.")
   370  	}
   371  
   372  	// The existing RPC should be still good to proceed.
   373  	if err := stream.Send(req); err != nil {
   374  		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
   375  	}
   376  	if _, err := stream.Recv(); err != nil {
   377  		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
   378  	}
   379  }
   380  
   381  func (s) TestHealthCheckWithConnClose(t *testing.T) {
   382  	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
   383  
   384  	s, lis, ts, deferFunc, err := setupServer(&svrConfig{})
   385  	defer deferFunc()
   386  	if err != nil {
   387  		t.Fatal(err)
   388  	}
   389  
   390  	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
   391  
   392  	cc, r, deferFunc, err := setupClient(&clientConfig{
   393  		balancerName:               "round_robin",
   394  		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
   395  	})
   396  	if err != nil {
   397  		t.Fatal(err)
   398  	}
   399  	defer deferFunc()
   400  
   401  	tc := testpb.NewTestServiceClient(cc)
   402  
   403  	r.UpdateState(resolver.State{
   404  		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
   405  		ServiceConfig: parseCfg(r, `{
   406  	"healthCheckConfig": {
   407  		"serviceName": "foo"
   408  	}
   409  }`)})
   410  
   411  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   412  	defer cancel()
   413  	// make some rpcs to make sure connection is working.
   414  	if err := verifyResultWithDelay(func() (bool, error) {
   415  		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   416  			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
   417  		}
   418  		return true, nil
   419  	}); err != nil {
   420  		t.Fatal(err)
   421  	}
   422  
   423  	select {
   424  	case <-hcExitChan:
   425  		t.Fatal("Health check function has exited, which is not expected.")
   426  	default:
   427  	}
   428  	// server closes the connection
   429  	s.Stop()
   430  
   431  	select {
   432  	case <-hcExitChan:
   433  	case <-time.After(5 * time.Second):
   434  		select {
   435  		case <-hcEnterChan:
   436  		default:
   437  			t.Fatal("Health check function has not entered after 5s.")
   438  		}
   439  		t.Fatal("Health check function has not exited after 5s.")
   440  	}
   441  }
   442  
   443  // addrConn drain happens when addrConn gets torn down due to its address being no longer in the
   444  // address list returned by the resolver.
   445  func (s) TestHealthCheckWithAddrConnDrain(t *testing.T) {
   446  	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
   447  
   448  	_, lis, ts, deferFunc, err := setupServer(&svrConfig{})
   449  	defer deferFunc()
   450  	if err != nil {
   451  		t.Fatal(err)
   452  	}
   453  
   454  	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
   455  
   456  	cc, r, deferFunc, err := setupClient(&clientConfig{
   457  		balancerName:               "round_robin",
   458  		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
   459  	})
   460  	if err != nil {
   461  		t.Fatal(err)
   462  	}
   463  	defer deferFunc()
   464  
   465  	tc := testpb.NewTestServiceClient(cc)
   466  	sc := parseCfg(r, `{
   467  	"healthCheckConfig": {
   468  		"serviceName": "foo"
   469  	}
   470  }`)
   471  	r.UpdateState(resolver.State{
   472  		Addresses:     []resolver.Address{{Addr: lis.Addr().String()}},
   473  		ServiceConfig: sc,
   474  	})
   475  
   476  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   477  	defer cancel()
   478  	// make some rpcs to make sure connection is working.
   479  	if err := verifyResultWithDelay(func() (bool, error) {
   480  		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   481  			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
   482  		}
   483  		return true, nil
   484  	}); err != nil {
   485  		t.Fatal(err)
   486  	}
   487  
   488  	// the stream rpc will persist through goaway event.
   489  	stream, err := tc.FullDuplexCall(ctx, grpc.WaitForReady(true))
   490  	if err != nil {
   491  		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
   492  	}
   493  	respParam := []*testpb.ResponseParameters{{Size: 1}}
   494  	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
   495  	if err != nil {
   496  		t.Fatal(err)
   497  	}
   498  	req := &testpb.StreamingOutputCallRequest{
   499  		ResponseParameters: respParam,
   500  		Payload:            payload,
   501  	}
   502  	if err := stream.Send(req); err != nil {
   503  		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
   504  	}
   505  	if _, err := stream.Recv(); err != nil {
   506  		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
   507  	}
   508  
   509  	select {
   510  	case <-hcExitChan:
   511  		t.Fatal("Health check function has exited, which is not expected.")
   512  	default:
   513  	}
   514  	// trigger teardown of the ac
   515  	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: "fake address"}}, ServiceConfig: sc})
   516  
   517  	select {
   518  	case <-hcExitChan:
   519  	case <-time.After(5 * time.Second):
   520  		select {
   521  		case <-hcEnterChan:
   522  		default:
   523  			t.Fatal("Health check function has not entered after 5s.")
   524  		}
   525  		t.Fatal("Health check function has not exited after 5s.")
   526  	}
   527  
   528  	// The existing RPC should be still good to proceed.
   529  	if err := stream.Send(req); err != nil {
   530  		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
   531  	}
   532  	if _, err := stream.Recv(); err != nil {
   533  		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
   534  	}
   535  }
   536  
   537  // ClientConn close will lead to its addrConns being torn down.
   538  func (s) TestHealthCheckWithClientConnClose(t *testing.T) {
   539  	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
   540  
   541  	_, lis, ts, deferFunc, err := setupServer(&svrConfig{})
   542  	defer deferFunc()
   543  	if err != nil {
   544  		t.Fatal(err)
   545  	}
   546  
   547  	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
   548  
   549  	cc, r, deferFunc, err := setupClient(&clientConfig{
   550  		balancerName:               "round_robin",
   551  		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
   552  	})
   553  	if err != nil {
   554  		t.Fatal(err)
   555  	}
   556  	defer deferFunc()
   557  
   558  	tc := testpb.NewTestServiceClient(cc)
   559  	r.UpdateState(resolver.State{
   560  		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
   561  		ServiceConfig: parseCfg(r, `{
   562  	"healthCheckConfig": {
   563  		"serviceName": "foo"
   564  	}
   565  }`)})
   566  
   567  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   568  	defer cancel()
   569  	// make some rpcs to make sure connection is working.
   570  	if err := verifyResultWithDelay(func() (bool, error) {
   571  		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   572  			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
   573  		}
   574  		return true, nil
   575  	}); err != nil {
   576  		t.Fatal(err)
   577  	}
   578  
   579  	select {
   580  	case <-hcExitChan:
   581  		t.Fatal("Health check function has exited, which is not expected.")
   582  	default:
   583  	}
   584  
   585  	// trigger addrConn teardown
   586  	cc.Close()
   587  
   588  	select {
   589  	case <-hcExitChan:
   590  	case <-time.After(5 * time.Second):
   591  		select {
   592  		case <-hcEnterChan:
   593  		default:
   594  			t.Fatal("Health check function has not entered after 5s.")
   595  		}
   596  		t.Fatal("Health check function has not exited after 5s.")
   597  	}
   598  }
   599  
   600  // This test is to test the logic in the createTransport after the health check function returns which
   601  // closes the skipReset channel(since it has not been closed inside health check func) to unblock
   602  // onGoAway/onClose goroutine.
   603  func (s) TestHealthCheckWithoutSetConnectivityStateCalledAddrConnShutDown(t *testing.T) {
   604  	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
   605  
   606  	_, lis, ts, deferFunc, err := setupServer(&svrConfig{
   607  		specialWatchFunc: func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
   608  			if in.Service != "delay" {
   609  				return status.Error(codes.FailedPrecondition,
   610  					"this special Watch function only handles request with service name to be \"delay\"")
   611  			}
   612  			// Do nothing to mock a delay of health check response from server side.
   613  			// This case is to help with the test that covers the condition that setConnectivityState is not
   614  			// called inside HealthCheckFunc before the func returns.
   615  			select {
   616  			case <-stream.Context().Done():
   617  			case <-time.After(5 * time.Second):
   618  			}
   619  			return nil
   620  		},
   621  	})
   622  	defer deferFunc()
   623  	if err != nil {
   624  		t.Fatal(err)
   625  	}
   626  
   627  	ts.SetServingStatus("delay", healthpb.HealthCheckResponse_SERVING)
   628  
   629  	_, r, deferFunc, err := setupClient(&clientConfig{
   630  		balancerName:               "round_robin",
   631  		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
   632  	})
   633  	if err != nil {
   634  		t.Fatal(err)
   635  	}
   636  	defer deferFunc()
   637  
   638  	// The serviceName "delay" is specially handled at server side, where response will not be sent
   639  	// back to client immediately upon receiving the request (client should receive no response until
   640  	// test ends).
   641  	sc := parseCfg(r, `{
   642  	"healthCheckConfig": {
   643  		"serviceName": "delay"
   644  	}
   645  }`)
   646  	r.UpdateState(resolver.State{
   647  		Addresses:     []resolver.Address{{Addr: lis.Addr().String()}},
   648  		ServiceConfig: sc,
   649  	})
   650  
   651  	select {
   652  	case <-hcExitChan:
   653  		t.Fatal("Health check function has exited, which is not expected.")
   654  	default:
   655  	}
   656  
   657  	select {
   658  	case <-hcEnterChan:
   659  	case <-time.After(5 * time.Second):
   660  		t.Fatal("Health check function has not been invoked after 5s.")
   661  	}
   662  	// trigger teardown of the ac, ac in SHUTDOWN state
   663  	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: "fake address"}}, ServiceConfig: sc})
   664  
   665  	// The health check func should exit without calling the setConnectivityState func, as server hasn't sent
   666  	// any response.
   667  	select {
   668  	case <-hcExitChan:
   669  	case <-time.After(5 * time.Second):
   670  		t.Fatal("Health check function has not exited after 5s.")
   671  	}
   672  	// The deferred leakcheck will check whether there's leaked goroutine, which is an indication
   673  	// whether we closes the skipReset channel to unblock onGoAway/onClose goroutine.
   674  }
   675  
   676  // This test is to test the logic in the createTransport after the health check function returns which
   677  // closes the allowedToReset channel(since it has not been closed inside health check func) to unblock
   678  // onGoAway/onClose goroutine.
   679  func (s) TestHealthCheckWithoutSetConnectivityStateCalled(t *testing.T) {
   680  	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
   681  
   682  	s, lis, ts, deferFunc, err := setupServer(&svrConfig{
   683  		specialWatchFunc: func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
   684  			if in.Service != "delay" {
   685  				return status.Error(codes.FailedPrecondition,
   686  					"this special Watch function only handles request with service name to be \"delay\"")
   687  			}
   688  			// Do nothing to mock a delay of health check response from server side.
   689  			// This case is to help with the test that covers the condition that setConnectivityState is not
   690  			// called inside HealthCheckFunc before the func returns.
   691  			select {
   692  			case <-stream.Context().Done():
   693  			case <-time.After(5 * time.Second):
   694  			}
   695  			return nil
   696  		},
   697  	})
   698  	defer deferFunc()
   699  	if err != nil {
   700  		t.Fatal(err)
   701  	}
   702  
   703  	ts.SetServingStatus("delay", healthpb.HealthCheckResponse_SERVING)
   704  
   705  	_, r, deferFunc, err := setupClient(&clientConfig{
   706  		balancerName:               "round_robin",
   707  		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
   708  	})
   709  	if err != nil {
   710  		t.Fatal(err)
   711  	}
   712  	defer deferFunc()
   713  
   714  	// The serviceName "delay" is specially handled at server side, where response will not be sent
   715  	// back to client immediately upon receiving the request (client should receive no response until
   716  	// test ends).
   717  	r.UpdateState(resolver.State{
   718  		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
   719  		ServiceConfig: parseCfg(r, `{
   720  	"healthCheckConfig": {
   721  		"serviceName": "delay"
   722  	}
   723  }`)})
   724  
   725  	select {
   726  	case <-hcExitChan:
   727  		t.Fatal("Health check function has exited, which is not expected.")
   728  	default:
   729  	}
   730  
   731  	select {
   732  	case <-hcEnterChan:
   733  	case <-time.After(5 * time.Second):
   734  		t.Fatal("Health check function has not been invoked after 5s.")
   735  	}
   736  	// trigger transport being closed
   737  	s.Stop()
   738  
   739  	// The health check func should exit without calling the setConnectivityState func, as server hasn't sent
   740  	// any response.
   741  	select {
   742  	case <-hcExitChan:
   743  	case <-time.After(5 * time.Second):
   744  		t.Fatal("Health check function has not exited after 5s.")
   745  	}
   746  	// The deferred leakcheck will check whether there's leaked goroutine, which is an indication
   747  	// whether we closes the allowedToReset channel to unblock onGoAway/onClose goroutine.
   748  }
   749  
   750  func testHealthCheckDisableWithDialOption(t *testing.T, addr string) {
   751  	hcEnterChan, _, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
   752  
   753  	cc, r, deferFunc, err := setupClient(&clientConfig{
   754  		balancerName:               "round_robin",
   755  		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
   756  		extraDialOption:            []grpc.DialOption{grpc.WithDisableHealthCheck()},
   757  	})
   758  	if err != nil {
   759  		t.Fatal(err)
   760  	}
   761  	defer deferFunc()
   762  
   763  	tc := testpb.NewTestServiceClient(cc)
   764  
   765  	r.UpdateState(resolver.State{
   766  		Addresses: []resolver.Address{{Addr: addr}},
   767  		ServiceConfig: parseCfg(r, `{
   768  	"healthCheckConfig": {
   769  		"serviceName": "foo"
   770  	}
   771  }`)})
   772  
   773  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   774  	defer cancel()
   775  	// send some rpcs to make sure transport has been created and is ready for use.
   776  	if err := verifyResultWithDelay(func() (bool, error) {
   777  		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   778  			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
   779  		}
   780  		return true, nil
   781  	}); err != nil {
   782  		t.Fatal(err)
   783  	}
   784  
   785  	select {
   786  	case <-hcEnterChan:
   787  		t.Fatal("Health check function has exited, which is not expected.")
   788  	default:
   789  	}
   790  }
   791  
   792  func testHealthCheckDisableWithBalancer(t *testing.T, addr string) {
   793  	hcEnterChan, _, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
   794  
   795  	cc, r, deferFunc, err := setupClient(&clientConfig{
   796  		balancerName:               "pick_first",
   797  		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
   798  	})
   799  	if err != nil {
   800  		t.Fatal(err)
   801  	}
   802  	defer deferFunc()
   803  
   804  	tc := testpb.NewTestServiceClient(cc)
   805  
   806  	r.UpdateState(resolver.State{
   807  		Addresses: []resolver.Address{{Addr: addr}},
   808  		ServiceConfig: parseCfg(r, `{
   809  	"healthCheckConfig": {
   810  		"serviceName": "foo"
   811  	}
   812  }`)})
   813  
   814  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   815  	defer cancel()
   816  	// send some rpcs to make sure transport has been created and is ready for use.
   817  	if err := verifyResultWithDelay(func() (bool, error) {
   818  		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   819  			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
   820  		}
   821  		return true, nil
   822  	}); err != nil {
   823  		t.Fatal(err)
   824  	}
   825  
   826  	select {
   827  	case <-hcEnterChan:
   828  		t.Fatal("Health check function has started, which is not expected.")
   829  	default:
   830  	}
   831  }
   832  
   833  func testHealthCheckDisableWithServiceConfig(t *testing.T, addr string) {
   834  	hcEnterChan, _, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
   835  
   836  	cc, r, deferFunc, err := setupClient(&clientConfig{
   837  		balancerName:               "round_robin",
   838  		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
   839  	})
   840  	if err != nil {
   841  		t.Fatal(err)
   842  	}
   843  	defer deferFunc()
   844  
   845  	tc := testpb.NewTestServiceClient(cc)
   846  
   847  	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: addr}}})
   848  
   849  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   850  	defer cancel()
   851  	// send some rpcs to make sure transport has been created and is ready for use.
   852  	if err := verifyResultWithDelay(func() (bool, error) {
   853  		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   854  			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
   855  		}
   856  		return true, nil
   857  	}); err != nil {
   858  		t.Fatal(err)
   859  	}
   860  
   861  	select {
   862  	case <-hcEnterChan:
   863  		t.Fatal("Health check function has started, which is not expected.")
   864  	default:
   865  	}
   866  }
   867  
   868  func (s) TestHealthCheckDisable(t *testing.T) {
   869  	_, lis, ts, deferFunc, err := setupServer(&svrConfig{})
   870  	defer deferFunc()
   871  	if err != nil {
   872  		t.Fatal(err)
   873  	}
   874  	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
   875  
   876  	// test client side disabling configuration.
   877  	testHealthCheckDisableWithDialOption(t, lis.Addr().String())
   878  	testHealthCheckDisableWithBalancer(t, lis.Addr().String())
   879  	testHealthCheckDisableWithServiceConfig(t, lis.Addr().String())
   880  }
   881  
   882  func (s) TestHealthCheckChannelzCountingCallSuccess(t *testing.T) {
   883  	_, lis, _, deferFunc, err := setupServer(&svrConfig{
   884  		specialWatchFunc: func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
   885  			if in.Service != "channelzSuccess" {
   886  				return status.Error(codes.FailedPrecondition,
   887  					"this special Watch function only handles request with service name to be \"channelzSuccess\"")
   888  			}
   889  			return status.Error(codes.OK, "fake success")
   890  		},
   891  	})
   892  	defer deferFunc()
   893  	if err != nil {
   894  		t.Fatal(err)
   895  	}
   896  
   897  	_, r, deferFunc, err := setupClient(&clientConfig{balancerName: "round_robin"})
   898  	if err != nil {
   899  		t.Fatal(err)
   900  	}
   901  	defer deferFunc()
   902  
   903  	r.UpdateState(resolver.State{
   904  		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
   905  		ServiceConfig: parseCfg(r, `{
   906  	"healthCheckConfig": {
   907  		"serviceName": "channelzSuccess"
   908  	}
   909  }`)})
   910  
   911  	if err := verifyResultWithDelay(func() (bool, error) {
   912  		cm, _ := channelz.GetTopChannels(0, 0)
   913  		if len(cm) == 0 {
   914  			return false, errors.New("channelz.GetTopChannels return 0 top channel")
   915  		}
   916  		if len(cm[0].SubChans) == 0 {
   917  			return false, errors.New("there is 0 subchannel")
   918  		}
   919  		var id int64
   920  		for k := range cm[0].SubChans {
   921  			id = k
   922  			break
   923  		}
   924  		scm := channelz.GetSubChannel(id)
   925  		if scm == nil || scm.ChannelData == nil {
   926  			return false, errors.New("nil subchannel metric or nil subchannel metric ChannelData returned")
   927  		}
   928  		// exponential backoff retry may result in more than one health check call.
   929  		if scm.ChannelData.CallsStarted > 0 && scm.ChannelData.CallsSucceeded > 0 && scm.ChannelData.CallsFailed == 0 {
   930  			return true, nil
   931  		}
   932  		return false, fmt.Errorf("got %d CallsStarted, %d CallsSucceeded, want >0 >0", scm.ChannelData.CallsStarted, scm.ChannelData.CallsSucceeded)
   933  	}); err != nil {
   934  		t.Fatal(err)
   935  	}
   936  }
   937  
   938  func (s) TestHealthCheckChannelzCountingCallFailure(t *testing.T) {
   939  	_, lis, _, deferFunc, err := setupServer(&svrConfig{
   940  		specialWatchFunc: func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
   941  			if in.Service != "channelzFailure" {
   942  				return status.Error(codes.FailedPrecondition,
   943  					"this special Watch function only handles request with service name to be \"channelzFailure\"")
   944  			}
   945  			return status.Error(codes.Internal, "fake failure")
   946  		},
   947  	})
   948  	if err != nil {
   949  		t.Fatal(err)
   950  	}
   951  	defer deferFunc()
   952  
   953  	_, r, deferFunc, err := setupClient(&clientConfig{balancerName: "round_robin"})
   954  	if err != nil {
   955  		t.Fatal(err)
   956  	}
   957  	defer deferFunc()
   958  
   959  	r.UpdateState(resolver.State{
   960  		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
   961  		ServiceConfig: parseCfg(r, `{
   962  	"healthCheckConfig": {
   963  		"serviceName": "channelzFailure"
   964  	}
   965  }`)})
   966  
   967  	if err := verifyResultWithDelay(func() (bool, error) {
   968  		cm, _ := channelz.GetTopChannels(0, 0)
   969  		if len(cm) == 0 {
   970  			return false, errors.New("channelz.GetTopChannels return 0 top channel")
   971  		}
   972  		if len(cm[0].SubChans) == 0 {
   973  			return false, errors.New("there is 0 subchannel")
   974  		}
   975  		var id int64
   976  		for k := range cm[0].SubChans {
   977  			id = k
   978  			break
   979  		}
   980  		scm := channelz.GetSubChannel(id)
   981  		if scm == nil || scm.ChannelData == nil {
   982  			return false, errors.New("nil subchannel metric or nil subchannel metric ChannelData returned")
   983  		}
   984  		// exponential backoff retry may result in more than one health check call.
   985  		if scm.ChannelData.CallsStarted > 0 && scm.ChannelData.CallsFailed > 0 && scm.ChannelData.CallsSucceeded == 0 {
   986  			return true, nil
   987  		}
   988  		return false, fmt.Errorf("got %d CallsStarted, %d CallsFailed, want >0, >0", scm.ChannelData.CallsStarted, scm.ChannelData.CallsFailed)
   989  	}); err != nil {
   990  		t.Fatal(err)
   991  	}
   992  }