github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/testutils/net_test.go (about) 1 // Copyright 2017 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package testutils 12 13 import ( 14 "bufio" 15 "fmt" 16 "io" 17 "net" 18 "testing" 19 "time" 20 21 "github.com/cockroachdb/cockroach/pkg/util" 22 "github.com/cockroachdb/cockroach/pkg/util/grpcutil" 23 "github.com/cockroachdb/cockroach/pkg/util/leaktest" 24 "github.com/cockroachdb/cockroach/pkg/util/netutil" 25 "github.com/cockroachdb/errors" 26 ) 27 28 // RunEchoServer runs a network server that accepts one connection from ln and 29 // echos the data sent on it. 30 // 31 // If serverSideCh != nil, every slice of data received by the server is also 32 // sent on this channel before being echoed back on the connection it came on. 33 // Useful to observe what the server has received when this server is used with 34 // partitioned connections. 35 func RunEchoServer(ln net.Listener, serverSideCh chan<- []byte) error { 36 conn, err := ln.Accept() 37 if err != nil { 38 if grpcutil.IsClosedConnection(err) { 39 return nil 40 } 41 return err 42 } 43 if _, err := copyWithSideChan(conn, conn, serverSideCh); err != nil { 44 return err 45 } 46 return nil 47 } 48 49 // copyWithSideChan is like io.Copy(), but also takes a channel on which data 50 // read from src is sent before being written to dst. 51 func copyWithSideChan(dst io.Writer, src io.Reader, ch chan<- []byte) (written int64, err error) { 52 buf := make([]byte, 32*1024) 53 for { 54 nr, er := src.Read(buf) 55 if nr > 0 { 56 if ch != nil { 57 ch <- buf[:nr] 58 } 59 60 nw, ew := dst.Write(buf[0:nr]) 61 if nw > 0 { 62 written += int64(nw) 63 } 64 if ew != nil { 65 err = ew 66 break 67 } 68 if nr != nw { 69 err = io.ErrShortWrite 70 break 71 } 72 } 73 if er != nil { 74 if er != io.EOF { 75 err = er 76 } 77 break 78 } 79 } 80 return written, err 81 } 82 83 func TestPartitionableConnBasic(t *testing.T) { 84 defer leaktest.AfterTest(t)() 85 addr := util.TestAddr 86 ln, err := net.Listen(addr.Network(), addr.String()) 87 if err != nil { 88 t.Fatal(err) 89 } 90 go func() { 91 if err := RunEchoServer(ln, nil); err != nil { 92 t.Error(err) 93 } 94 }() 95 defer func() { 96 netutil.FatalIfUnexpected(ln.Close()) 97 }() 98 99 serverConn, err := net.Dial("tcp", ln.Addr().String()) 100 if err != nil { 101 t.Fatal(err) 102 } 103 104 pConn := NewPartitionableConn(serverConn) 105 defer pConn.Close() 106 107 exp := "let's see if this value comes back\n" 108 fmt.Fprint(pConn, exp) 109 got, err := bufio.NewReader(pConn).ReadString('\n') 110 if err != nil { 111 t.Fatal(err) 112 } 113 if got != exp { 114 t.Fatalf("expecting: %q , got %q", exp, got) 115 } 116 } 117 118 func TestPartitionableConnPartitionC2S(t *testing.T) { 119 defer leaktest.AfterTest(t)() 120 121 addr := util.TestAddr 122 ln, err := net.Listen(addr.Network(), addr.String()) 123 if err != nil { 124 t.Fatal(err) 125 } 126 serverSideCh := make(chan []byte) 127 go func() { 128 if err := RunEchoServer(ln, serverSideCh); err != nil { 129 t.Error(err) 130 } 131 }() 132 defer func() { 133 netutil.FatalIfUnexpected(ln.Close()) 134 }() 135 136 serverConn, err := net.Dial("tcp", ln.Addr().String()) 137 if err != nil { 138 t.Fatal(err) 139 } 140 141 pConn := NewPartitionableConn(serverConn) 142 defer pConn.Close() 143 144 // Partition the client->server connection. Afterwards, we're going to send 145 // something and assert that the server doesn't get it (within a timeout) by 146 // snooping on the server's side channel. Then we'll resolve the partition and 147 // expect that the server gets the message that was pending and echoes it 148 // back. 149 150 pConn.PartitionC2S() 151 152 // Client sends data. 153 exp := "let's see when this value comes back\n" 154 fmt.Fprint(pConn, exp) 155 156 // In the background, the client waits on a read. 157 clientDoneCh := make(chan error) 158 go func() { 159 clientDoneCh <- func() error { 160 got, err := bufio.NewReader(pConn).ReadString('\n') 161 if err != nil { 162 return err 163 } 164 if got != exp { 165 return errors.Errorf("expecting: %q , got %q", exp, got) 166 } 167 return nil 168 }() 169 }() 170 171 timerDoneCh := make(chan error) 172 time.AfterFunc(3*time.Millisecond, func() { 173 var err error 174 select { 175 case err = <-clientDoneCh: 176 err = errors.Errorf("unexpected reply while partitioned: %v", err) 177 case buf := <-serverSideCh: 178 err = errors.Errorf("server was not supposed to have received data while partitioned: %q", buf) 179 default: 180 } 181 timerDoneCh <- err 182 }) 183 184 if err := <-timerDoneCh; err != nil { 185 t.Fatal(err) 186 } 187 188 // Now unpartition and expect the pending data to be sent and a reply to be 189 // received. 190 191 pConn.UnpartitionC2S() 192 193 // Expect the server to receive the data. 194 <-serverSideCh 195 196 if err := <-clientDoneCh; err != nil { 197 t.Fatal(err) 198 } 199 } 200 201 func TestPartitionableConnPartitionS2C(t *testing.T) { 202 defer leaktest.AfterTest(t)() 203 204 addr := util.TestAddr 205 ln, err := net.Listen(addr.Network(), addr.String()) 206 if err != nil { 207 t.Fatal(err) 208 } 209 serverSideCh := make(chan []byte) 210 go func() { 211 if err := RunEchoServer(ln, serverSideCh); err != nil { 212 t.Error(err) 213 } 214 }() 215 defer func() { 216 netutil.FatalIfUnexpected(ln.Close()) 217 }() 218 219 serverConn, err := net.Dial("tcp", ln.Addr().String()) 220 if err != nil { 221 t.Fatal(err) 222 } 223 224 pConn := NewPartitionableConn(serverConn) 225 defer pConn.Close() 226 227 // We're going to partition the server->client connection. Then we'll send 228 // some data and assert that the server gets it (by snooping on the server's 229 // side-channel). Then we'll assert that the client doesn't get the reply 230 // (with a timeout). Then we resolve the partition and assert that the client 231 // gets the reply. 232 233 pConn.PartitionS2C() 234 235 // Client sends data. 236 exp := "let's see when this value comes back\n" 237 fmt.Fprint(pConn, exp) 238 239 if s := <-serverSideCh; string(s) != exp { 240 t.Fatalf("expected server to receive %q, got %q", exp, s) 241 } 242 243 // In the background, the client waits on a read. 244 clientDoneCh := make(chan error) 245 go func() { 246 clientDoneCh <- func() error { 247 got, err := bufio.NewReader(pConn).ReadString('\n') 248 if err != nil { 249 return err 250 } 251 if got != exp { 252 return errors.Errorf("expecting: %q , got %q", exp, got) 253 } 254 return nil 255 }() 256 }() 257 258 // Check that the client does not get the server's response. 259 time.AfterFunc(3*time.Millisecond, func() { 260 select { 261 case err := <-clientDoneCh: 262 t.Errorf("unexpected reply while partitioned: %v", err) 263 default: 264 } 265 }) 266 267 // Now unpartition and expect the pending data to be sent and a reply to be 268 // received. 269 270 pConn.UnpartitionS2C() 271 272 if err := <-clientDoneCh; err != nil { 273 t.Fatal(err) 274 } 275 } 276 277 // Test that, while partitioned, a sender doesn't block while the internal 278 // buffer is not full. 279 func TestPartitionableConnBuffering(t *testing.T) { 280 defer leaktest.AfterTest(t)() 281 282 addr := util.TestAddr 283 ln, err := net.Listen(addr.Network(), addr.String()) 284 if err != nil { 285 t.Fatal(err) 286 } 287 288 // In the background, the server reads everything. 289 exp := 5 * (bufferSize / 10) 290 serverDoneCh := make(chan error) 291 go func() { 292 serverDoneCh <- func() error { 293 conn, err := ln.Accept() 294 if err != nil { 295 return err 296 } 297 received := 0 298 for { 299 data := make([]byte, 1024*1024) 300 nr, err := conn.Read(data) 301 if err != nil { 302 if err == io.EOF { 303 break 304 } 305 return err 306 } 307 received += nr 308 } 309 if received != exp { 310 return errors.Errorf("server expecting: %d , got %d", exp, received) 311 } 312 return nil 313 }() 314 }() 315 316 serverConn, err := net.Dial("tcp", ln.Addr().String()) 317 if err != nil { 318 t.Fatal(err) 319 } 320 321 pConn := NewPartitionableConn(serverConn) 322 defer pConn.Close() 323 324 pConn.PartitionC2S() 325 defer pConn.Finish() 326 327 // Send chunks such that they don't add up to the buffer size exactly. 328 data := make([]byte, bufferSize/10) 329 for i := 0; i < 5; i++ { 330 nw, err := pConn.Write(data) 331 if err != nil { 332 t.Fatal(err) 333 } 334 if nw != len(data) { 335 t.Fatal("unexpected partial write; PartitionableConn always writes fully") 336 } 337 } 338 pConn.UnpartitionC2S() 339 pConn.Close() 340 341 if err := <-serverDoneCh; err != nil { 342 t.Fatal(err) 343 } 344 } 345 346 // Test that, while partitioned, a party can close the connection and the other 347 // party will not observe this until after the partition is lifted. 348 func TestPartitionableConnCloseDeliveredAfterPartition(t *testing.T) { 349 defer leaktest.AfterTest(t)() 350 351 addr := util.TestAddr 352 ln, err := net.Listen(addr.Network(), addr.String()) 353 if err != nil { 354 t.Fatal(err) 355 } 356 357 // In the background, the server reads everything. 358 serverDoneCh := make(chan error) 359 go func() { 360 serverDoneCh <- func() error { 361 conn, err := ln.Accept() 362 if err != nil { 363 return err 364 } 365 received := 0 366 for { 367 data := make([]byte, 1<<20 /* 1 MiB */) 368 nr, err := conn.Read(data) 369 if err != nil { 370 if err == io.EOF { 371 return nil 372 } 373 return err 374 } 375 received += nr 376 } 377 }() 378 }() 379 380 serverConn, err := net.Dial("tcp", ln.Addr().String()) 381 if err != nil { 382 t.Fatal(err) 383 } 384 385 pConn := NewPartitionableConn(serverConn) 386 defer pConn.Close() 387 388 pConn.PartitionC2S() 389 defer pConn.Finish() 390 391 pConn.Close() 392 393 timerDoneCh := make(chan error) 394 time.AfterFunc(3*time.Millisecond, func() { 395 var err error 396 select { 397 case err = <-serverDoneCh: 398 err = errors.Wrap(err, "server was not supposed to see the closing while partitioned") 399 default: 400 } 401 timerDoneCh <- err 402 }) 403 404 if err := <-timerDoneCh; err != nil { 405 t.Fatal(err) 406 } 407 408 pConn.UnpartitionC2S() 409 410 if err := <-serverDoneCh; err != nil { 411 t.Fatal(err) 412 } 413 }