github.com/Serizao/go-winio@v0.0.0-20230906082528-f02f7f4ad6e8/pipe_test.go (about) 1 //go:build windows 2 // +build windows 3 4 package winio 5 6 import ( 7 "bufio" 8 "bytes" 9 "context" 10 "errors" 11 "io" 12 "net" 13 "sync" 14 "testing" 15 "time" 16 "unsafe" 17 18 "golang.org/x/sys/windows" 19 ) 20 21 var testPipeName = `\\.\pipe\winiotestpipe` 22 23 var aLongTimeAgo = time.Unix(1, 0) 24 25 func TestDialUnknownFailsImmediately(t *testing.T) { 26 _, err := DialPipe(testPipeName, nil) 27 if !errors.Is(err, windows.ERROR_FILE_NOT_FOUND) { 28 t.Fatalf("expected ERROR_FILE_NOT_FOUND got %v", err) 29 } 30 } 31 32 func TestDialListenerTimesOut(t *testing.T) { 33 l, err := ListenPipe(testPipeName, nil) 34 if err != nil { 35 t.Fatal(err) 36 } 37 defer l.Close() 38 var d = 10 * time.Millisecond 39 _, err = DialPipe(testPipeName, &d) 40 if !errors.Is(err, ErrTimeout) { 41 t.Fatalf("expected ErrTimeout, got %v", err) 42 } 43 } 44 45 func TestDialContextListenerTimesOut(t *testing.T) { 46 l, err := ListenPipe(testPipeName, nil) 47 if err != nil { 48 t.Fatal(err) 49 } 50 defer l.Close() 51 var d = 10 * time.Millisecond 52 ctx, cancel := context.WithTimeout(context.Background(), d) 53 defer cancel() 54 _, err = DialPipeContext(ctx, testPipeName) 55 if !errors.Is(err, context.DeadlineExceeded) { 56 t.Fatalf("expected context.DeadlineExceeded, got %v", err) 57 } 58 } 59 60 func TestDialListenerGetsCancelled(t *testing.T) { 61 ctx, cancel := context.WithCancel(context.Background()) 62 l, err := ListenPipe(testPipeName, nil) 63 if err != nil { 64 t.Fatal(err) 65 } 66 ch := make(chan error) 67 defer l.Close() 68 go func(ctx context.Context, ch chan error) { 69 _, err := DialPipeContext(ctx, testPipeName) 70 ch <- err 71 }(ctx, ch) 72 time.Sleep(time.Millisecond * 30) 73 cancel() 74 err = <-ch 75 if !errors.Is(err, context.Canceled) { 76 t.Fatalf("expected context.Canceled, got %v", err) 77 } 78 } 79 80 func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { 81 c := PipeConfig{ 82 SecurityDescriptor: "D:P(A;;0x1200FF;;;WD)", 83 } 84 l, err := ListenPipe(testPipeName, &c) 85 if err != nil { 86 t.Fatal(err) 87 } 88 defer l.Close() 89 _, err = DialPipe(testPipeName, nil) 90 if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { 91 t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) 92 } 93 } 94 95 func getConnection(cfg *PipeConfig) (client net.Conn, server net.Conn, err error) { 96 l, err := ListenPipe(testPipeName, cfg) 97 if err != nil { 98 return nil, nil, err 99 } 100 defer l.Close() 101 102 type response struct { 103 c net.Conn 104 err error 105 } 106 ch := make(chan response) 107 go func() { 108 c, err := l.Accept() 109 ch <- response{c, err} 110 }() 111 112 c, err := DialPipe(testPipeName, nil) 113 if err != nil { 114 return client, server, err 115 } 116 117 r := <-ch 118 if err = r.err; err != nil { 119 c.Close() 120 return nil, nil, err 121 } 122 123 return c, r.c, nil 124 } 125 126 func TestReadTimeout(t *testing.T) { 127 c, s, err := getConnection(nil) 128 if err != nil { 129 t.Fatal(err) 130 } 131 defer c.Close() 132 defer s.Close() 133 134 _ = c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) 135 136 buf := make([]byte, 10) 137 _, err = c.Read(buf) 138 if !errors.Is(err, ErrTimeout) { 139 t.Fatalf("expected ErrTimeout, got %v", err) 140 } 141 } 142 143 func server(l net.Listener, ch chan int) { 144 c, err := l.Accept() 145 if err != nil { 146 panic(err) 147 } 148 rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) 149 s, err := rw.ReadString('\n') 150 if err != nil { 151 panic(err) 152 } 153 _, err = rw.WriteString("got " + s) 154 if err != nil { 155 panic(err) 156 } 157 err = rw.Flush() 158 if err != nil { 159 panic(err) 160 } 161 c.Close() 162 ch <- 1 163 } 164 165 func TestFullListenDialReadWrite(t *testing.T) { 166 l, err := ListenPipe(testPipeName, nil) 167 if err != nil { 168 t.Fatal(err) 169 } 170 defer l.Close() 171 172 ch := make(chan int) 173 go server(l, ch) 174 175 c, err := DialPipe(testPipeName, nil) 176 if err != nil { 177 t.Fatal(err) 178 } 179 defer c.Close() 180 181 rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) 182 _, err = rw.WriteString("hello world\n") 183 if err != nil { 184 t.Fatal(err) 185 } 186 err = rw.Flush() 187 if err != nil { 188 t.Fatal(err) 189 } 190 191 s, err := rw.ReadString('\n') 192 if err != nil { 193 t.Fatal(err) 194 } 195 ms := "got hello world\n" 196 if s != ms { 197 t.Errorf("expected '%s', got '%s'", ms, s) 198 } 199 200 <-ch 201 } 202 203 func TestCloseAbortsListen(t *testing.T) { 204 l, err := ListenPipe(testPipeName, nil) 205 if err != nil { 206 t.Fatal(err) 207 } 208 209 ch := make(chan error) 210 go func() { 211 _, err := l.Accept() 212 ch <- err 213 }() 214 215 time.Sleep(30 * time.Millisecond) 216 l.Close() 217 218 err = <-ch 219 if !errors.Is(err, ErrPipeListenerClosed) { 220 t.Fatalf("expected ErrPipeListenerClosed, got %v", err) 221 } 222 } 223 224 func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) { 225 t.Helper() 226 227 b := make([]byte, 10) 228 w.Close() 229 n, err := r.Read(b) 230 if n > 0 { 231 t.Errorf("unexpected byte count %d", n) 232 } 233 if err != io.EOF { 234 t.Errorf("expected EOF: %v", err) 235 } 236 } 237 238 func TestCloseClientEOFServer(t *testing.T) { 239 c, s, err := getConnection(nil) 240 if err != nil { 241 t.Fatal(err) 242 } 243 defer c.Close() 244 defer s.Close() 245 ensureEOFOnClose(t, c, s) 246 } 247 248 func TestCloseServerEOFClient(t *testing.T) { 249 c, s, err := getConnection(nil) 250 if err != nil { 251 t.Fatal(err) 252 } 253 defer c.Close() 254 defer s.Close() 255 ensureEOFOnClose(t, s, c) 256 } 257 258 func TestCloseWriteEOF(t *testing.T) { 259 cfg := &PipeConfig{ 260 MessageMode: true, 261 } 262 c, s, err := getConnection(cfg) 263 if err != nil { 264 t.Fatal(err) 265 } 266 defer c.Close() 267 defer s.Close() 268 269 type closeWriter interface { 270 CloseWrite() error 271 } 272 273 err = c.(closeWriter).CloseWrite() 274 if err != nil { 275 t.Fatal(err) 276 } 277 278 b := make([]byte, 10) 279 _, err = s.Read(b) 280 if !errors.Is(err, io.EOF) { 281 t.Fatal(err) 282 } 283 } 284 285 func TestAcceptAfterCloseFails(t *testing.T) { 286 l, err := ListenPipe(testPipeName, nil) 287 if err != nil { 288 t.Fatal(err) 289 } 290 l.Close() 291 _, err = l.Accept() 292 if !errors.Is(err, ErrPipeListenerClosed) { 293 t.Fatalf("expected ErrPipeListenerClosed, got %v", err) 294 } 295 } 296 297 func TestDialTimesOutByDefault(t *testing.T) { 298 l, err := ListenPipe(testPipeName, nil) 299 if err != nil { 300 t.Fatal(err) 301 } 302 defer l.Close() 303 _, err = DialPipe(testPipeName, nil) 304 if !errors.Is(err, ErrTimeout) { 305 t.Fatalf("expected ErrTimeout, got %v", err) 306 } 307 } 308 309 func TestTimeoutPendingRead(t *testing.T) { 310 l, err := ListenPipe(testPipeName, nil) 311 if err != nil { 312 t.Fatal(err) 313 } 314 defer l.Close() 315 316 serverDone := make(chan struct{}) 317 318 go func() { 319 s, err := l.Accept() 320 if err != nil { 321 t.Error(err) 322 return 323 } 324 time.Sleep(1 * time.Second) 325 s.Close() 326 close(serverDone) 327 }() 328 329 client, err := DialPipe(testPipeName, nil) 330 if err != nil { 331 t.Fatal(err) 332 } 333 defer client.Close() 334 335 clientErr := make(chan error) 336 go func() { 337 buf := make([]byte, 10) 338 _, err = client.Read(buf) 339 clientErr <- err 340 }() 341 342 time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline 343 _ = client.SetReadDeadline(aLongTimeAgo) 344 345 select { 346 case err = <-clientErr: 347 if !errors.Is(err, ErrTimeout) { 348 t.Fatalf("expected ErrTimeout, got %v", err) 349 } 350 case <-time.After(100 * time.Millisecond): 351 t.Fatalf("timed out while waiting for read to cancel") 352 <-clientErr 353 } 354 <-serverDone 355 } 356 357 func TestTimeoutPendingWrite(t *testing.T) { 358 l, err := ListenPipe(testPipeName, nil) 359 if err != nil { 360 t.Fatal(err) 361 } 362 defer l.Close() 363 364 serverDone := make(chan struct{}) 365 366 go func() { 367 s, err := l.Accept() 368 if err != nil { 369 t.Error(err) 370 return 371 } 372 time.Sleep(1 * time.Second) 373 s.Close() 374 close(serverDone) 375 }() 376 377 client, err := DialPipe(testPipeName, nil) 378 if err != nil { 379 t.Fatal(err) 380 } 381 defer client.Close() 382 383 clientErr := make(chan error) 384 go func() { 385 _, err = client.Write([]byte("this should timeout")) 386 clientErr <- err 387 }() 388 389 time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline 390 _ = client.SetWriteDeadline(aLongTimeAgo) 391 392 select { 393 case err = <-clientErr: 394 if !errors.Is(err, ErrTimeout) { 395 t.Fatalf("expected ErrTimeout, got %v", err) 396 } 397 case <-time.After(100 * time.Millisecond): 398 t.Fatalf("timed out while waiting for write to cancel") 399 <-clientErr 400 } 401 <-serverDone 402 } 403 404 func TestDisconnectPipe(t *testing.T) { 405 l, err := ListenPipe(testPipeName, nil) 406 if err != nil { 407 t.Fatal(err) 408 } 409 defer l.Close() 410 411 const testData = "foo" 412 serverDone := make(chan struct{}) 413 414 go func() { 415 s, err := l.Accept() 416 if err != nil { 417 t.Error(err) 418 return 419 } 420 defer func() { 421 s.Close() 422 close(serverDone) 423 }() 424 425 if _, err := s.Write([]byte(testData)); err != nil { 426 t.Error(err) 427 return 428 } 429 430 if err := s.(PipeConn).Flush(); err != nil { 431 t.Error(err) 432 return 433 } 434 435 if err := s.(PipeConn).Disconnect(); err != nil { 436 t.Error(err) 437 return 438 } 439 }() 440 441 client, err := DialPipe(testPipeName, nil) 442 if err != nil { 443 t.Fatal(err) 444 } 445 defer client.Close() 446 447 buf := make([]byte, len(testData)) 448 if _, err = client.Read(buf); err != nil { 449 t.Fatal(err) 450 } 451 452 dataRead := string(buf) 453 if dataRead != testData { 454 t.Fatalf("incorrect data read %q", dataRead) 455 } 456 457 if _, err = client.Read(buf); err == nil { 458 t.Fatal("read should fail") 459 } 460 461 <-serverDone 462 } 463 464 type CloseWriter interface { 465 CloseWrite() error 466 } 467 468 func TestEchoWithMessaging(t *testing.T) { 469 c := PipeConfig{ 470 MessageMode: true, // Use message mode so that CloseWrite() is supported 471 InputBufferSize: 65536, // Use 64KB buffers to improve performance 472 OutputBufferSize: 65536, 473 } 474 l, err := ListenPipe(testPipeName, &c) 475 if err != nil { 476 t.Fatal(err) 477 } 478 defer l.Close() 479 480 listenerDone := make(chan bool) 481 clientDone := make(chan bool) 482 go func() { 483 // server echo 484 conn, e := l.Accept() 485 if e != nil { 486 t.Error(err) 487 return 488 } 489 defer conn.Close() 490 491 time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent 492 _, _ = io.Copy(conn, conn) 493 _ = conn.(CloseWriter).CloseWrite() 494 close(listenerDone) 495 }() 496 timeout := 1 * time.Second 497 client, err := DialPipe(testPipeName, &timeout) 498 if err != nil { 499 t.Fatal(err) 500 } 501 defer client.Close() 502 503 go func() { 504 // client read back 505 bytes := make([]byte, 2) 506 n, e := client.Read(bytes) 507 if e != nil { 508 t.Error(err) 509 return 510 } 511 if n != 2 { 512 t.Errorf("expected 2 bytes, got %v", n) 513 return 514 } 515 close(clientDone) 516 }() 517 518 payload := make([]byte, 2) 519 payload[0] = 0 520 payload[1] = 1 521 522 n, err := client.Write(payload) 523 if err != nil { 524 t.Fatal(err) 525 } 526 if n != 2 { 527 t.Fatalf("expected 2 bytes, got %v", n) 528 } 529 _ = client.(CloseWriter).CloseWrite() 530 <-listenerDone 531 <-clientDone 532 } 533 534 func TestConnectRace(t *testing.T) { 535 l, err := ListenPipe(testPipeName, nil) 536 if err != nil { 537 t.Fatal(err) 538 } 539 defer l.Close() 540 go func() { 541 for { 542 s, err := l.Accept() 543 if errors.Is(err, ErrPipeListenerClosed) { 544 return 545 } 546 547 if err != nil { 548 t.Error(err) 549 return 550 } 551 s.Close() 552 } 553 }() 554 555 for i := 0; i < 1000; i++ { 556 c, err := DialPipe(testPipeName, nil) 557 if err != nil { 558 t.Fatal(err) 559 } 560 c.Close() 561 } 562 } 563 564 func TestMessageReadMode(t *testing.T) { 565 var wg sync.WaitGroup 566 defer wg.Wait() 567 568 l, err := ListenPipe(testPipeName, &PipeConfig{MessageMode: true}) 569 if err != nil { 570 t.Fatal(err) 571 } 572 defer l.Close() 573 574 msg := ([]byte)("hello world") 575 576 wg.Add(1) 577 go func() { 578 defer wg.Done() 579 s, err := l.Accept() 580 if err != nil { 581 t.Error(err) 582 return 583 } 584 _, err = s.Write(msg) 585 if err != nil { 586 t.Error(err) 587 return 588 } 589 s.Close() 590 }() 591 592 c, err := DialPipe(testPipeName, nil) 593 if err != nil { 594 t.Fatal(err) 595 } 596 defer c.Close() 597 598 setNamedPipeHandleState := windows.NewLazyDLL("kernel32.dll").NewProc("SetNamedPipeHandleState") 599 600 p := c.(*win32MessageBytePipe) 601 mode := uint32(windows.PIPE_READMODE_MESSAGE) 602 if s, _, err := setNamedPipeHandleState.Call(uintptr(p.handle), uintptr(unsafe.Pointer(&mode)), 0, 0); s == 0 { 603 t.Fatal(err) 604 } 605 606 ch := make([]byte, 1) 607 var vmsg []byte 608 for { 609 n, err := c.Read(ch) 610 if err == io.EOF { //nolint:errorlint 611 break 612 } 613 if err != nil { 614 t.Fatal(err) 615 } 616 if n != 1 { 617 t.Fatal("expected 1: ", n) 618 } 619 vmsg = append(vmsg, ch[0]) 620 } 621 if !bytes.Equal(msg, vmsg) { 622 t.Fatalf("expected %s: %s", msg, vmsg) 623 } 624 } 625 626 func TestListenConnectRace(t *testing.T) { 627 for i := 0; i < 50 && !t.Failed(); i++ { 628 var wg sync.WaitGroup 629 wg.Add(1) 630 go func() { 631 c, err := DialPipe(testPipeName, nil) 632 if err == nil { 633 c.Close() 634 } 635 wg.Done() 636 }() 637 s, err := ListenPipe(testPipeName, nil) 638 if err != nil { 639 t.Error(i, err) 640 } else { 641 s.Close() 642 } 643 wg.Wait() 644 } 645 }