google.golang.org/grpc@v1.72.2/internal/testutils/blocking_context_dialer.go (about)

     1  /*
     2   *
     3   * Copyright 2024 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package testutils
    20  
    21  import (
    22  	"context"
    23  	"net"
    24  	"sync"
    25  
    26  	"google.golang.org/grpc/grpclog"
    27  )
    28  
    29  var logger = grpclog.Component("testutils")
    30  
    31  // BlockingDialer is a dialer that waits for Resume() to be called before
    32  // dialing.
    33  type BlockingDialer struct {
    34  	// mu protects holds.
    35  	mu sync.Mutex
    36  	// holds maps network addresses to a list of holds for that address.
    37  	holds map[string][]*Hold
    38  }
    39  
    40  // NewBlockingDialer returns a dialer that waits for Resume() to be called
    41  // before dialing.
    42  func NewBlockingDialer() *BlockingDialer {
    43  	return &BlockingDialer{
    44  		holds: make(map[string][]*Hold),
    45  	}
    46  }
    47  
    48  // DialContext implements a context dialer for use with grpc.WithContextDialer
    49  // dial option for a BlockingDialer.
    50  func (d *BlockingDialer) DialContext(ctx context.Context, addr string) (net.Conn, error) {
    51  	d.mu.Lock()
    52  	holds := d.holds[addr]
    53  	if len(holds) == 0 {
    54  		// No hold for this addr.
    55  		d.mu.Unlock()
    56  		return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
    57  	}
    58  	hold := holds[0]
    59  	d.holds[addr] = holds[1:]
    60  	d.mu.Unlock()
    61  
    62  	logger.Infof("Hold %p: Intercepted connection attempt to addr %q", hold, addr)
    63  	close(hold.waitCh)
    64  	select {
    65  	case err := <-hold.blockCh:
    66  		if err != nil {
    67  			return nil, err
    68  		}
    69  		return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
    70  	case <-ctx.Done():
    71  		logger.Infof("Hold %p: Connection attempt to addr %q timed out", hold, addr)
    72  		return nil, ctx.Err()
    73  	}
    74  }
    75  
    76  // Hold is a handle to a single connection attempt. It can be used to block,
    77  // fail and succeed connection attempts.
    78  type Hold struct {
    79  	// dialer is the dialer that created this hold.
    80  	dialer *BlockingDialer
    81  	// waitCh is closed when a connection attempt is received.
    82  	waitCh chan struct{}
    83  	// blockCh receives the value to return from DialContext for this connection
    84  	// attempt (nil on resume, an error on fail). It receives at most 1 value.
    85  	blockCh chan error
    86  	// addr is the address that this hold is for.
    87  	addr string
    88  }
    89  
    90  // Hold blocks the dialer when a connection attempt is made to the given addr.
    91  // A hold is valid for exactly one connection attempt. Multiple holds for an
    92  // addr can be added, and they will apply in the order that the connections are
    93  // attempted.
    94  func (d *BlockingDialer) Hold(addr string) *Hold {
    95  	d.mu.Lock()
    96  	defer d.mu.Unlock()
    97  
    98  	h := Hold{dialer: d, blockCh: make(chan error, 1), waitCh: make(chan struct{}), addr: addr}
    99  	d.holds[addr] = append(d.holds[addr], &h)
   100  	return &h
   101  }
   102  
   103  // Wait blocks until there is a connection attempt on this Hold, or the context
   104  // expires. Return false if the context has expired, true otherwise.
   105  func (h *Hold) Wait(ctx context.Context) bool {
   106  	logger.Infof("Hold %p: Waiting for a connection attempt to addr %q", h, h.addr)
   107  	select {
   108  	case <-ctx.Done():
   109  		return false
   110  	case <-h.waitCh:
   111  		return true
   112  	}
   113  }
   114  
   115  // Resume unblocks the dialer for the given addr. Either Resume or Fail must be
   116  // called at most once on a hold. Otherwise, Resume panics.
   117  func (h *Hold) Resume() {
   118  	logger.Infof("Hold %p: Resuming connection attempt to addr %q", h, h.addr)
   119  	h.blockCh <- nil
   120  	close(h.blockCh)
   121  }
   122  
   123  // Fail fails the connection attempt. Either Resume or Fail must be
   124  // called at most once on a hold. Otherwise, Resume panics.
   125  func (h *Hold) Fail(err error) {
   126  	logger.Infof("Hold %p: Failing connection attempt to addr %q", h, h.addr)
   127  	h.blockCh <- err
   128  	close(h.blockCh)
   129  }
   130  
   131  // IsStarted returns true if this hold has received a connection attempt.
   132  func (h *Hold) IsStarted() bool {
   133  	select {
   134  	case <-h.waitCh:
   135  		return true
   136  	default:
   137  		return false
   138  	}
   139  }