google.golang.org/grpc@v1.72.2/internal/testutils/blocking_context_dialer.go (about) 1 /* 2 * 3 * Copyright 2024 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package testutils 20 21 import ( 22 "context" 23 "net" 24 "sync" 25 26 "google.golang.org/grpc/grpclog" 27 ) 28 29 var logger = grpclog.Component("testutils") 30 31 // BlockingDialer is a dialer that waits for Resume() to be called before 32 // dialing. 33 type BlockingDialer struct { 34 // mu protects holds. 35 mu sync.Mutex 36 // holds maps network addresses to a list of holds for that address. 37 holds map[string][]*Hold 38 } 39 40 // NewBlockingDialer returns a dialer that waits for Resume() to be called 41 // before dialing. 42 func NewBlockingDialer() *BlockingDialer { 43 return &BlockingDialer{ 44 holds: make(map[string][]*Hold), 45 } 46 } 47 48 // DialContext implements a context dialer for use with grpc.WithContextDialer 49 // dial option for a BlockingDialer. 50 func (d *BlockingDialer) DialContext(ctx context.Context, addr string) (net.Conn, error) { 51 d.mu.Lock() 52 holds := d.holds[addr] 53 if len(holds) == 0 { 54 // No hold for this addr. 55 d.mu.Unlock() 56 return (&net.Dialer{}).DialContext(ctx, "tcp", addr) 57 } 58 hold := holds[0] 59 d.holds[addr] = holds[1:] 60 d.mu.Unlock() 61 62 logger.Infof("Hold %p: Intercepted connection attempt to addr %q", hold, addr) 63 close(hold.waitCh) 64 select { 65 case err := <-hold.blockCh: 66 if err != nil { 67 return nil, err 68 } 69 return (&net.Dialer{}).DialContext(ctx, "tcp", addr) 70 case <-ctx.Done(): 71 logger.Infof("Hold %p: Connection attempt to addr %q timed out", hold, addr) 72 return nil, ctx.Err() 73 } 74 } 75 76 // Hold is a handle to a single connection attempt. It can be used to block, 77 // fail and succeed connection attempts. 78 type Hold struct { 79 // dialer is the dialer that created this hold. 80 dialer *BlockingDialer 81 // waitCh is closed when a connection attempt is received. 82 waitCh chan struct{} 83 // blockCh receives the value to return from DialContext for this connection 84 // attempt (nil on resume, an error on fail). It receives at most 1 value. 85 blockCh chan error 86 // addr is the address that this hold is for. 87 addr string 88 } 89 90 // Hold blocks the dialer when a connection attempt is made to the given addr. 91 // A hold is valid for exactly one connection attempt. Multiple holds for an 92 // addr can be added, and they will apply in the order that the connections are 93 // attempted. 94 func (d *BlockingDialer) Hold(addr string) *Hold { 95 d.mu.Lock() 96 defer d.mu.Unlock() 97 98 h := Hold{dialer: d, blockCh: make(chan error, 1), waitCh: make(chan struct{}), addr: addr} 99 d.holds[addr] = append(d.holds[addr], &h) 100 return &h 101 } 102 103 // Wait blocks until there is a connection attempt on this Hold, or the context 104 // expires. Return false if the context has expired, true otherwise. 105 func (h *Hold) Wait(ctx context.Context) bool { 106 logger.Infof("Hold %p: Waiting for a connection attempt to addr %q", h, h.addr) 107 select { 108 case <-ctx.Done(): 109 return false 110 case <-h.waitCh: 111 return true 112 } 113 } 114 115 // Resume unblocks the dialer for the given addr. Either Resume or Fail must be 116 // called at most once on a hold. Otherwise, Resume panics. 117 func (h *Hold) Resume() { 118 logger.Infof("Hold %p: Resuming connection attempt to addr %q", h, h.addr) 119 h.blockCh <- nil 120 close(h.blockCh) 121 } 122 123 // Fail fails the connection attempt. Either Resume or Fail must be 124 // called at most once on a hold. Otherwise, Resume panics. 125 func (h *Hold) Fail(err error) { 126 logger.Infof("Hold %p: Failing connection attempt to addr %q", h, h.addr) 127 h.blockCh <- err 128 close(h.blockCh) 129 } 130 131 // IsStarted returns true if this hold has received a connection attempt. 132 func (h *Hold) IsStarted() bool { 133 select { 134 case <-h.waitCh: 135 return true 136 default: 137 return false 138 } 139 }