google.golang.org/grpc@v1.72.2/internal/testutils/blocking_context_dialer_test.go (about)

     1  /*
     2   *
     3   * Copyright 2024 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package testutils
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"testing"
    25  	"time"
    26  )
    27  
    28  const (
    29  	testTimeout      = 5 * time.Second
    30  	testShortTimeout = 10 * time.Millisecond
    31  )
    32  
    33  func (s) TestBlockingDialer_NoHold(t *testing.T) {
    34  	lis, err := LocalTCPListener()
    35  	if err != nil {
    36  		t.Fatalf("Failed to listen: %v", err)
    37  	}
    38  	defer lis.Close()
    39  
    40  	d := NewBlockingDialer()
    41  
    42  	// This should not block.
    43  	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
    44  	defer cancel()
    45  	conn, err := d.DialContext(ctx, lis.Addr().String())
    46  	if err != nil {
    47  		t.Fatalf("Failed to dial: %v", err)
    48  	}
    49  	conn.Close()
    50  }
    51  
    52  func (s) TestBlockingDialer_HoldWaitResume(t *testing.T) {
    53  	lis, err := LocalTCPListener()
    54  	if err != nil {
    55  		t.Fatalf("Failed to listen: %v", err)
    56  	}
    57  	defer lis.Close()
    58  
    59  	d := NewBlockingDialer()
    60  	h := d.Hold(lis.Addr().String())
    61  
    62  	if h.IsStarted() {
    63  		t.Fatalf("hold.IsStarted() = true, want false")
    64  	}
    65  
    66  	done := make(chan struct{})
    67  	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
    68  	defer cancel()
    69  	go func() {
    70  		defer close(done)
    71  		conn, err := d.DialContext(ctx, lis.Addr().String())
    72  		if err != nil {
    73  			t.Errorf("BlockingDialer.DialContext() got error: %v, want success", err)
    74  			return
    75  		}
    76  
    77  		if !h.IsStarted() {
    78  			t.Errorf("hold.IsStarted() = false, want true")
    79  		}
    80  		conn.Close()
    81  	}()
    82  
    83  	// This should block until the goroutine above is scheduled.
    84  	if !h.Wait(ctx) {
    85  		t.Fatalf("Timeout while waiting for a connection attempt to %q", h.addr)
    86  	}
    87  
    88  	if !h.IsStarted() {
    89  		t.Errorf("hold.IsStarted() = false, want true")
    90  	}
    91  
    92  	select {
    93  	case <-done:
    94  		t.Fatalf("Expected dialer to be blocked.")
    95  	case <-time.After(testShortTimeout):
    96  	}
    97  
    98  	h.Resume() // Unblock the above goroutine.
    99  
   100  	select {
   101  	case <-done:
   102  	case <-ctx.Done():
   103  		t.Errorf("Timeout waiting for connection attempt to resume.")
   104  	}
   105  }
   106  
   107  func (s) TestBlockingDialer_HoldWaitFail(t *testing.T) {
   108  	lis, err := LocalTCPListener()
   109  	if err != nil {
   110  		t.Fatalf("Failed to listen: %v", err)
   111  	}
   112  	defer lis.Close()
   113  
   114  	d := NewBlockingDialer()
   115  	h := d.Hold(lis.Addr().String())
   116  
   117  	wantErr := errors.New("test error")
   118  
   119  	dialError := make(chan error)
   120  	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
   121  	defer cancel()
   122  	go func() {
   123  		_, err := d.DialContext(ctx, lis.Addr().String())
   124  		dialError <- err
   125  	}()
   126  
   127  	if !h.Wait(ctx) {
   128  		t.Fatal("Timeout while waiting for a connection attempt to " + h.addr)
   129  	}
   130  	select {
   131  	case err = <-dialError:
   132  		t.Errorf("DialContext got unblocked with err %v. Want DialContext to still be blocked after Wait()", err)
   133  	case <-time.After(testShortTimeout):
   134  	}
   135  
   136  	h.Fail(wantErr)
   137  
   138  	select {
   139  	case err = <-dialError:
   140  		if !errors.Is(err, wantErr) {
   141  			t.Errorf("BlockingDialer.DialContext() after Fail(): got error %v, want %v", err, wantErr)
   142  		}
   143  	case <-ctx.Done():
   144  		t.Errorf("Timeout waiting for connection attempt to fail.")
   145  	}
   146  }
   147  
   148  func (s) TestBlockingDialer_ContextCanceled(t *testing.T) {
   149  	lis, err := LocalTCPListener()
   150  	if err != nil {
   151  		t.Fatalf("Failed to listen: %v", err)
   152  	}
   153  	defer lis.Close()
   154  
   155  	d := NewBlockingDialer()
   156  	h := d.Hold(lis.Addr().String())
   157  
   158  	dialErr := make(chan error)
   159  	testCtx, cancel := context.WithTimeout(context.Background(), testTimeout)
   160  	defer cancel()
   161  
   162  	ctx, cancel := context.WithCancel(testCtx)
   163  	defer cancel()
   164  	go func() {
   165  		_, err := d.DialContext(ctx, lis.Addr().String())
   166  		dialErr <- err
   167  	}()
   168  	if !h.Wait(testCtx) {
   169  		t.Errorf("Timeout while waiting for a connection attempt to %q", h.addr)
   170  	}
   171  
   172  	cancel()
   173  
   174  	select {
   175  	case err = <-dialErr:
   176  		if !errors.Is(err, context.Canceled) {
   177  			t.Errorf("BlockingDialer.DialContext() after context cancel: got error %v, want %v", err, context.Canceled)
   178  		}
   179  	case <-testCtx.Done():
   180  		t.Errorf("Timeout while waiting for Wait to return.")
   181  	}
   182  
   183  	h.Resume() // noop, just make sure nothing bad happen.
   184  }
   185  
   186  func (s) TestBlockingDialer_CancelWait(t *testing.T) {
   187  	lis, err := LocalTCPListener()
   188  	if err != nil {
   189  		t.Fatalf("Failed to listen: %v", err)
   190  	}
   191  	defer lis.Close()
   192  
   193  	d := NewBlockingDialer()
   194  	h := d.Hold(lis.Addr().String())
   195  
   196  	testCtx, cancel := context.WithTimeout(context.Background(), testTimeout)
   197  	defer cancel()
   198  
   199  	ctx, cancel := context.WithCancel(testCtx)
   200  	cancel()
   201  	done := make(chan struct{})
   202  	go func() {
   203  		if h.Wait(ctx) {
   204  			t.Errorf("Expected cancel to return false when context expires")
   205  		}
   206  		done <- struct{}{}
   207  	}()
   208  
   209  	select {
   210  	case <-done:
   211  	case <-testCtx.Done():
   212  		t.Errorf("Timeout while waiting for Wait to return.")
   213  	}
   214  }