github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/testutils/net.go

// Copyright 2017 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package testutils

import (
	"context"
	"io"
	"net"
	"sync"

	"github.com/cockroachdb/cockroach/pkg/util/log"
	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
	"github.com/cockroachdb/errors"
)

// bufferSize is the size of the buffer used by PartitionableConn. Writes to a
// partitioned connection will block after the buffer gets filled.
const bufferSize = 16 << 10 // 16 KB

// PartitionableConn is an implementation of net.Conn that allows the
// client->server and/or the server->client directions to be temporarily
// partitioned.
//
// A PartitionableConn wraps a provided net.Conn (the serverConn member) and
// forwards every read and write to it. It interposes an arbiter in front of it
// that's used to block reads/writes while the PartitionableConn is in the
// partitioned mode.
//
// While a direction is partitioned, data sent in that direction doesn't flow. A
// write while partitioned will block after an internal buffer gets filled. Data
// written to the conn after the partition has been established is not delivered
// to the remote party until the partition is lifted. At that time, all the
// buffered data is delivered. Since data is delivered asynchronously, data
// written before the partition is established may or may not be blocked by the
// partition; use application-level ACKs if that's important.
type PartitionableConn struct {
	// We embed a net.Conn so that we inherit the interface. Note that we override
	// Read() and Write().
	//
	// This embedded Conn is half of a net.Pipe(). The other half is clientConn.
	net.Conn

	clientConn net.Conn
	serverConn net.Conn

	mu struct {
		syncutil.Mutex

		// err, if set, is returned by any subsequent call to Read or Write.
		err error

		// Is either of the two directions (client-to-server, server-to-client)
		// currently partitioned?
		c2sPartitioned bool
		s2cPartitioned bool

		c2sBuffer buf
		s2cBuffer buf

		// Conds to be signaled when the corresponding partition is lifted.
		c2sWaiter *sync.Cond
		s2cWaiter *sync.Cond
	}
}

type buf struct {
	// A mutex used to synchronize access to all the fields. It will be set to the
	// parent PartitionableConn's mutex.
	*syncutil.Mutex

	data     []byte
	capacity int
	closed   bool
	// The error that was passed to Close(err). See Close() for more info.
	closedErr error
	name      string // A human-readable name, useful for debugging.

	// readerWait is signaled when the reader should wake up and check the
	// buffer's state: when new data is put in the buffer, when the buffer is
	// closed, and whenever the PartitionableConn wants to unblock all reads (i.e.
	// on partition).
	readerWait *sync.Cond

	// capacityWait is signaled when a blocked writer should wake up because data
	// is taken out of the buffer and there's now some capacity. It's also
	// signaled when the buffer is closed.
	capacityWait *sync.Cond
}

func makeBuf(name string, capacity int, mu *syncutil.Mutex) buf {
	return buf{
		Mutex:        mu,
		name:         name,
		capacity:     capacity,
		readerWait:   sync.NewCond(mu),
		capacityWait: sync.NewCond(mu),
	}
}

// Write adds data to the buffer. If there's zero free capacity, it will block
// until there's some capacity available or the buffer is closed. If there is
// some capacity, but not enough for all of data, it performs a partial write.
//
// The number of bytes written is returned.
func (b *buf) Write(data []byte) (int, error) {
	b.Lock()
	defer b.Unlock()
	for b.capacity == len(b.data) && !b.closed {
		// Block for capacity.
		b.capacityWait.Wait()
	}
	if b.closed {
		return 0, b.closedErr
	}
	available := b.capacity - len(b.data)
	toCopy := available
	if len(data) < available {
		toCopy = len(data)
	}
	b.data = append(b.data, data[:toCopy]...)
	b.wakeReaderLocked()
	return toCopy, nil
}

// errEAgain is returned by buf.readLocked() when the read was blocked at the
// time when buf.readerWait was signaled (in particular, after the
// PartitionableConn interrupted the read because of a partition). The caller is
// expected to try the read again after the partition is gone.
var errEAgain = errors.New("try read again")

// readLocked returns data from buf, up to "size" bytes. If there's no data in
// the buffer, it blocks until either some data becomes available or the buffer
// is closed.
func (b *buf) readLocked(size int) ([]byte, error) {
	if len(b.data) == 0 && !b.closed {
		b.readerWait.Wait()
		// We were unblocked either by data arriving, or by a partition, or by
		// another uninteresting reason. Return to the caller, in case it's because
		// of a partition.
		return nil, errEAgain
	}
	if b.closed && len(b.data) == 0 {
		return nil, b.closedErr
	}
	var ret []byte
	if len(b.data) < size {
		ret = b.data
		b.data = nil
	} else {
		ret = b.data[:size]
		b.data = b.data[size:]
	}
	b.capacityWait.Signal()
	return ret, nil
}

// Close closes the buffer. All reads and writes that are currently blocked will
// be woken and they'll all return err.
func (b *buf) Close(err error) {
	b.Lock()
	defer b.Unlock()
	b.closed = true
	b.closedErr = err
	b.readerWait.Signal()
	b.capacityWait.Signal()
}

// wakeReaderLocked wakes the reader in case it's blocked.
// See comments on readerWait.
//
// This needs to be called while holding the buffer's mutex.
func (b *buf) wakeReaderLocked() {
	b.readerWait.Signal()
}
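// exampleBufContract is an illustrative sketch (a hypothetical helper, not
// part of this package's API) walking through the buf contract: with a
// capacity of 4, a Write of 8 bytes performs a partial write rather than
// blocking, readLocked (called with the shared mutex held) drains buffered
// data, and after Close the buffer keeps returning the error it was closed
// with.
func exampleBufContract() {
	var mu syncutil.Mutex
	b := makeBuf("example", 4 /* capacity */, &mu)

	// Only 4 of the 8 bytes fit; since there was some free capacity, Write
	// returns a partial count instead of blocking.
	n, _ := b.Write([]byte("12345678"))
	_ = n // n == 4

	// readLocked must be called with the buffer's mutex held; it returns up to
	// "size" bytes of buffered data.
	b.Mutex.Lock()
	first, _ := b.readLocked(2) // "12"
	rest, _ := b.readLocked(4)  // "34"
	b.Mutex.Unlock()
	_, _ = first, rest

	// Close wakes any blocked reader or writer; subsequent calls return the
	// error that was passed to Close.
	b.Close(io.EOF)
	b.Mutex.Lock()
	_, err := b.readLocked(1) // err == io.EOF
	b.Mutex.Unlock()
	_ = err
}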
// NewPartitionableConn wraps serverConn in a PartitionableConn.
func NewPartitionableConn(serverConn net.Conn) *PartitionableConn {
	clientEnd, clientConn := net.Pipe()
	c := &PartitionableConn{
		Conn:       clientEnd,
		clientConn: clientConn,
		serverConn: serverConn,
	}
	c.mu.c2sWaiter = sync.NewCond(&c.mu.Mutex)
	c.mu.s2cWaiter = sync.NewCond(&c.mu.Mutex)
	c.mu.c2sBuffer = makeBuf("c2sBuf", bufferSize, &c.mu.Mutex)
	c.mu.s2cBuffer = makeBuf("s2cBuf", bufferSize, &c.mu.Mutex)

	// Start copying from client to server.
	go func() {
		err := c.copy(
			c.clientConn, // src
			c.serverConn, // dst
			&c.mu.c2sBuffer,
			func() { // waitForNoPartitionLocked
				for c.mu.c2sPartitioned {
					c.mu.c2sWaiter.Wait()
				}
			})
		c.mu.Lock()
		c.mu.err = err
		c.mu.Unlock()
		if err := c.clientConn.Close(); err != nil {
			log.Errorf(context.TODO(), "unexpected error closing internal pipe: %s", err)
		}
		if err := c.serverConn.Close(); err != nil {
			log.Errorf(context.TODO(), "error closing server conn: %s", err)
		}
	}()

	// Start copying from server to client.
	go func() {
		err := c.copy(
			c.serverConn, // src
			c.clientConn, // dst
			&c.mu.s2cBuffer,
			func() { // waitForNoPartitionLocked
				for c.mu.s2cPartitioned {
					c.mu.s2cWaiter.Wait()
				}
			})
		c.mu.Lock()
		c.mu.err = err
		c.mu.Unlock()
		if err := c.clientConn.Close(); err != nil {
			log.Fatalf(context.TODO(), "unexpected error closing internal pipe: %s", err)
		}
		if err := c.serverConn.Close(); err != nil {
			log.Errorf(context.TODO(), "error closing server conn: %s", err)
		}
	}()

	return c
}

// Finish removes any partitions that may exist so that blocked goroutines can
// finish.
// Finish() must be called if a connection may have been left in a partitioned
// state.
func (c *PartitionableConn) Finish() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.mu.c2sPartitioned = false
	c.mu.c2sWaiter.Signal()
	c.mu.s2cPartitioned = false
	c.mu.s2cWaiter.Signal()
}

// PartitionC2S partitions the client-to-server direction.
// If UnpartitionC2S() is not called, Finish() must be called.
func (c *PartitionableConn) PartitionC2S() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.mu.c2sPartitioned {
		panic("already partitioned")
	}
	c.mu.c2sPartitioned = true
	c.mu.c2sBuffer.wakeReaderLocked()
}

// UnpartitionC2S lifts an existing client-to-server partition.
func (c *PartitionableConn) UnpartitionC2S() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.mu.c2sPartitioned {
		panic("not partitioned")
	}
	c.mu.c2sPartitioned = false
	c.mu.c2sWaiter.Signal()
}

// PartitionS2C partitions the server-to-client direction.
// If UnpartitionS2C() is not called, Finish() must be called.
func (c *PartitionableConn) PartitionS2C() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.mu.s2cPartitioned {
		panic("already partitioned")
	}
	c.mu.s2cPartitioned = true
	c.mu.s2cBuffer.wakeReaderLocked()
}

// UnpartitionS2C lifts an existing server-to-client partition.
func (c *PartitionableConn) UnpartitionS2C() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.mu.s2cPartitioned {
		panic("not partitioned")
	}
	c.mu.s2cPartitioned = false
	c.mu.s2cWaiter.Signal()
}
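// examplePartitionLifecycle is an illustrative sketch (a hypothetical test
// helper, not part of this package's API) showing the intended call
// discipline: PartitionC2S stalls delivery, UnpartitionC2S releases the
// buffered data, and Finish clears any partitions that might still be in
// place. The PartitionableConn is wired to one end of a net.Pipe standing in
// for the server.
func examplePartitionLifecycle() error {
	serverEnd, peer := net.Pipe()
	c := NewPartitionableConn(serverEnd)
	defer func() { _ = c.Close() }() // shut down the internal copy goroutines
	defer c.Finish()                 // lift any partitions left in place

	// Partition the client->server direction: writes to c are buffered (up to
	// bufferSize) instead of being delivered to peer.
	c.PartitionC2S()
	if _, err := c.Write([]byte("hello")); err != nil {
		return err
	}

	// Read on the server side in the background; it only completes once the
	// partition is lifted.
	done := make(chan error, 1)
	go func() {
		readBuf := make([]byte, 5)
		_, err := io.ReadFull(peer, readBuf)
		done <- err
	}()

	// Lifting the partition delivers the buffered "hello" to peer.
	c.UnpartitionC2S()
	return <-done
}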
// Read is part of the net.Conn interface.
func (c *PartitionableConn) Read(b []byte) (n int, err error) {
	c.mu.Lock()
	err = c.mu.err
	c.mu.Unlock()
	if err != nil {
		return 0, err
	}

	// Forward to the embedded connection.
	return c.Conn.Read(b)
}

// Write is part of the net.Conn interface.
func (c *PartitionableConn) Write(b []byte) (n int, err error) {
	c.mu.Lock()
	err = c.mu.err
	c.mu.Unlock()
	if err != nil {
		return 0, err
	}

	// Forward to the embedded connection.
	return c.Conn.Write(b)
}

// readFrom copies data from src into the buffer until src.Read() returns an
// error (e.g. io.EOF). That error is returned.
//
// readFrom is written in the spirit of interface io.ReaderFrom, except it
// returns the io.EOF error, and also doesn't guarantee that every byte that has
// been read from src is put into the buffer (as the buffer allows concurrent
// access and buf.Write can return an error).
func (b *buf) readFrom(src io.Reader) error {
	data := make([]byte, 1024)
	for {
		nr, err := src.Read(data)
		if err != nil {
			return err
		}
		toSend := data[:nr]
		for {
			nw, ew := b.Write(toSend)
			if ew != nil {
				return ew
			}
			if nw == len(toSend) {
				break
			}
			toSend = toSend[nw:]
		}
	}
}

// copyFromBuffer copies data from the src buffer to dst until reading from src
// returns an error (e.g. io.EOF once the buffer has been closed and drained).
// That error is returned (i.e. the return value is always != nil). This is
// because the PartitionableConn wants to hold on to any error, including EOF.
//
// waitForNoPartitionLocked is a function to be called before consuming data
// from src, in order to make sure that we only consume data when we're not
// partitioned. It needs to be called under src.Mutex, as the check needs to be
// done atomically with consuming the buffer's data.
func (c *PartitionableConn) copyFromBuffer(
	src *buf, dst net.Conn, waitForNoPartitionLocked func(),
) error {
	for {
		// Don't read from the buffer while we're partitioned.
		src.Mutex.Lock()
		waitForNoPartitionLocked()
		data, err := src.readLocked(1024 * 1024)
		src.Mutex.Unlock()

		if len(data) > 0 {
			nw, ew := dst.Write(data)
			if ew != nil {
				err = ew
			}
			if len(data) != nw {
				err = io.ErrShortWrite
			}
		} else if err == nil {
			err = io.EOF
		} else if errors.Is(err, errEAgain) {
			continue
		}
		if err != nil {
			return err
		}
	}
}

// copy copies data from src to dst while we're not partitioned and stops doing
// so while partitioned.
//
// It runs two goroutines internally: one copying from src to an internal buffer
// and one copying from the buffer to dst. The 2nd one deals with partitions.
func (c *PartitionableConn) copy(
	src net.Conn, dst net.Conn, buf *buf, waitForNoPartitionLocked func(),
) error {
	tasks := make(chan error)
	go func() {
		err := buf.readFrom(src)
		buf.Close(err)
		tasks <- err
	}()
	go func() {
		err := c.copyFromBuffer(buf, dst, waitForNoPartitionLocked)
		buf.Close(err)
		tasks <- err
	}()
	err := <-tasks
	err2 := <-tasks
	if err == nil {
		err = err2
	}
	return err
}
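// dialPartitionable is an illustrative sketch (a hypothetical dialer, not part
// of this package's API) showing how a test might route its client traffic
// through a PartitionableConn so that either direction can later be stalled
// with PartitionC2S/PartitionS2C. The addr parameter is an assumption for
// illustration.
func dialPartitionable(addr string) (*PartitionableConn, error) {
	serverConn, err := net.Dial("tcp", addr)
	if err != nil {
		return nil, err
	}
	// The returned conn is used by the client in place of serverConn; all
	// traffic now flows through the PartitionableConn's internal buffers.
	return NewPartitionableConn(serverConn), nil
}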