github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/ipc/namedpipe/namedpipe_test.go (about) 1 // Copyright 2021 The Go Authors. All rights reserved. 2 // Copyright 2015 Microsoft 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 //go:build windows 7 // +build windows 8 9 package namedpipe_test 10 11 import ( 12 "bufio" 13 "bytes" 14 "context" 15 "errors" 16 "io" 17 "net" 18 "os" 19 "sync" 20 "syscall" 21 "testing" 22 "time" 23 24 "golang.org/x/sys/windows" 25 "github.com/liloew/wireguard-go/ipc/namedpipe" 26 ) 27 28 func randomPipePath() string { 29 guid, err := windows.GenerateGUID() 30 if err != nil { 31 panic(err) 32 } 33 return `\\.\PIPE\go-namedpipe-test-` + guid.String() 34 } 35 36 func TestPingPong(t *testing.T) { 37 const ( 38 ping = 42 39 pong = 24 40 ) 41 pipePath := randomPipePath() 42 listener, err := namedpipe.Listen(pipePath) 43 if err != nil { 44 t.Fatalf("unable to listen on pipe: %v", err) 45 } 46 defer listener.Close() 47 go func() { 48 incoming, err := listener.Accept() 49 if err != nil { 50 t.Fatalf("unable to accept pipe connection: %v", err) 51 } 52 defer incoming.Close() 53 var data [1]byte 54 _, err = incoming.Read(data[:]) 55 if err != nil { 56 t.Fatalf("unable to read ping from pipe: %v", err) 57 } 58 if data[0] != ping { 59 t.Fatalf("expected ping, got %d", data[0]) 60 } 61 data[0] = pong 62 _, err = incoming.Write(data[:]) 63 if err != nil { 64 t.Fatalf("unable to write pong to pipe: %v", err) 65 } 66 }() 67 client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 68 if err != nil { 69 t.Fatalf("unable to dial pipe: %v", err) 70 } 71 defer client.Close() 72 client.SetDeadline(time.Now().Add(time.Second * 5)) 73 var data [1]byte 74 data[0] = ping 75 _, err = client.Write(data[:]) 76 if err != nil { 77 t.Fatalf("unable to write ping to pipe: %v", err) 78 } 79 _, err = client.Read(data[:]) 80 if err != nil { 81 t.Fatalf("unable to read pong from pipe: %v", err) 82 } 83 if data[0] != pong { 84 t.Fatalf("expected pong, got %d", data[0]) 85 } 86 } 87 88 func TestDialUnknownFailsImmediately(t *testing.T) { 89 _, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0)) 90 if !errors.Is(err, syscall.ENOENT) { 91 t.Fatalf("expected ENOENT got %v", err) 92 } 93 } 94 95 func TestDialListenerTimesOut(t *testing.T) { 96 pipePath := randomPipePath() 97 l, err := namedpipe.Listen(pipePath) 98 if err != nil { 99 t.Fatal(err) 100 } 101 defer l.Close() 102 pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond) 103 if err == nil { 104 pipe.Close() 105 } 106 if err != os.ErrDeadlineExceeded { 107 t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) 108 } 109 } 110 111 func TestDialContextListenerTimesOut(t *testing.T) { 112 pipePath := randomPipePath() 113 l, err := namedpipe.Listen(pipePath) 114 if err != nil { 115 t.Fatal(err) 116 } 117 defer l.Close() 118 d := 10 * time.Millisecond 119 ctx, _ := context.WithTimeout(context.Background(), d) 120 pipe, err := namedpipe.DialContext(ctx, pipePath) 121 if err == nil { 122 pipe.Close() 123 } 124 if err != context.DeadlineExceeded { 125 t.Fatalf("expected context.DeadlineExceeded, got %v", err) 126 } 127 } 128 129 func TestDialListenerGetsCancelled(t *testing.T) { 130 pipePath := randomPipePath() 131 ctx, cancel := context.WithCancel(context.Background()) 132 l, err := namedpipe.Listen(pipePath) 133 if err != nil { 134 t.Fatal(err) 135 } 136 defer l.Close() 137 ch := make(chan error) 138 go func(ctx context.Context, ch chan error) { 139 _, err := namedpipe.DialContext(ctx, pipePath) 140 ch <- err 141 }(ctx, ch) 142 time.Sleep(time.Millisecond * 30) 143 cancel() 144 err = <-ch 145 if err != context.Canceled { 146 t.Fatalf("expected context.Canceled, got %v", err) 147 } 148 } 149 150 func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { 151 if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil { 152 t.Skip("dacls on named pipes are broken on wine") 153 } 154 pipePath := randomPipePath() 155 sd, _ := windows.SecurityDescriptorFromString("D:") 156 l, err := (&namedpipe.ListenConfig{ 157 SecurityDescriptor: sd, 158 }).Listen(pipePath) 159 if err != nil { 160 t.Fatal(err) 161 } 162 defer l.Close() 163 pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 164 if err == nil { 165 pipe.Close() 166 } 167 if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { 168 t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) 169 } 170 } 171 172 func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) { 173 pipePath := randomPipePath() 174 if cfg == nil { 175 cfg = &namedpipe.ListenConfig{} 176 } 177 l, err := cfg.Listen(pipePath) 178 if err != nil { 179 return 180 } 181 defer l.Close() 182 183 type response struct { 184 c net.Conn 185 err error 186 } 187 ch := make(chan response) 188 go func() { 189 c, err := l.Accept() 190 ch <- response{c, err} 191 }() 192 193 c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 194 if err != nil { 195 return 196 } 197 198 r := <-ch 199 if err = r.err; err != nil { 200 c.Close() 201 return 202 } 203 204 client = c 205 server = r.c 206 return 207 } 208 209 func TestReadTimeout(t *testing.T) { 210 c, s, err := getConnection(nil) 211 if err != nil { 212 t.Fatal(err) 213 } 214 defer c.Close() 215 defer s.Close() 216 217 c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) 218 219 buf := make([]byte, 10) 220 _, err = c.Read(buf) 221 if err != os.ErrDeadlineExceeded { 222 t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) 223 } 224 } 225 226 func server(l net.Listener, ch chan int) { 227 c, err := l.Accept() 228 if err != nil { 229 panic(err) 230 } 231 rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) 232 s, err := rw.ReadString('\n') 233 if err != nil { 234 panic(err) 235 } 236 _, err = rw.WriteString("got " + s) 237 if err != nil { 238 panic(err) 239 } 240 err = rw.Flush() 241 if err != nil { 242 panic(err) 243 } 244 c.Close() 245 ch <- 1 246 } 247 248 func TestFullListenDialReadWrite(t *testing.T) { 249 pipePath := randomPipePath() 250 l, err := namedpipe.Listen(pipePath) 251 if err != nil { 252 t.Fatal(err) 253 } 254 defer l.Close() 255 256 ch := make(chan int) 257 go server(l, ch) 258 259 c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 260 if err != nil { 261 t.Fatal(err) 262 } 263 defer c.Close() 264 265 rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) 266 _, err = rw.WriteString("hello world\n") 267 if err != nil { 268 t.Fatal(err) 269 } 270 err = rw.Flush() 271 if err != nil { 272 t.Fatal(err) 273 } 274 275 s, err := rw.ReadString('\n') 276 if err != nil { 277 t.Fatal(err) 278 } 279 ms := "got hello world\n" 280 if s != ms { 281 t.Errorf("expected '%s', got '%s'", ms, s) 282 } 283 284 <-ch 285 } 286 287 func TestCloseAbortsListen(t *testing.T) { 288 pipePath := randomPipePath() 289 l, err := namedpipe.Listen(pipePath) 290 if err != nil { 291 t.Fatal(err) 292 } 293 294 ch := make(chan error) 295 go func() { 296 _, err := l.Accept() 297 ch <- err 298 }() 299 300 time.Sleep(30 * time.Millisecond) 301 l.Close() 302 303 err = <-ch 304 if err != net.ErrClosed { 305 t.Fatalf("expected net.ErrClosed, got %v", err) 306 } 307 } 308 309 func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) { 310 b := make([]byte, 10) 311 w.Close() 312 n, err := r.Read(b) 313 if n > 0 { 314 t.Errorf("unexpected byte count %d", n) 315 } 316 if err != io.EOF { 317 t.Errorf("expected EOF: %v", err) 318 } 319 } 320 321 func TestCloseClientEOFServer(t *testing.T) { 322 c, s, err := getConnection(nil) 323 if err != nil { 324 t.Fatal(err) 325 } 326 defer c.Close() 327 defer s.Close() 328 ensureEOFOnClose(t, c, s) 329 } 330 331 func TestCloseServerEOFClient(t *testing.T) { 332 c, s, err := getConnection(nil) 333 if err != nil { 334 t.Fatal(err) 335 } 336 defer c.Close() 337 defer s.Close() 338 ensureEOFOnClose(t, s, c) 339 } 340 341 func TestCloseWriteEOF(t *testing.T) { 342 cfg := &namedpipe.ListenConfig{ 343 MessageMode: true, 344 } 345 c, s, err := getConnection(cfg) 346 if err != nil { 347 t.Fatal(err) 348 } 349 defer c.Close() 350 defer s.Close() 351 352 type closeWriter interface { 353 CloseWrite() error 354 } 355 356 err = c.(closeWriter).CloseWrite() 357 if err != nil { 358 t.Fatal(err) 359 } 360 361 b := make([]byte, 10) 362 _, err = s.Read(b) 363 if err != io.EOF { 364 t.Fatal(err) 365 } 366 } 367 368 func TestAcceptAfterCloseFails(t *testing.T) { 369 pipePath := randomPipePath() 370 l, err := namedpipe.Listen(pipePath) 371 if err != nil { 372 t.Fatal(err) 373 } 374 l.Close() 375 _, err = l.Accept() 376 if err != net.ErrClosed { 377 t.Fatalf("expected net.ErrClosed, got %v", err) 378 } 379 } 380 381 func TestDialTimesOutByDefault(t *testing.T) { 382 pipePath := randomPipePath() 383 l, err := namedpipe.Listen(pipePath) 384 if err != nil { 385 t.Fatal(err) 386 } 387 defer l.Close() 388 pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds. 389 if err == nil { 390 pipe.Close() 391 } 392 if err != os.ErrDeadlineExceeded { 393 t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) 394 } 395 } 396 397 func TestTimeoutPendingRead(t *testing.T) { 398 pipePath := randomPipePath() 399 l, err := namedpipe.Listen(pipePath) 400 if err != nil { 401 t.Fatal(err) 402 } 403 defer l.Close() 404 405 serverDone := make(chan struct{}) 406 407 go func() { 408 s, err := l.Accept() 409 if err != nil { 410 t.Fatal(err) 411 } 412 time.Sleep(1 * time.Second) 413 s.Close() 414 close(serverDone) 415 }() 416 417 client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 418 if err != nil { 419 t.Fatal(err) 420 } 421 defer client.Close() 422 423 clientErr := make(chan error) 424 go func() { 425 buf := make([]byte, 10) 426 _, err = client.Read(buf) 427 clientErr <- err 428 }() 429 430 time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline 431 client.SetReadDeadline(time.Unix(1, 0)) 432 433 select { 434 case err = <-clientErr: 435 if err != os.ErrDeadlineExceeded { 436 t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) 437 } 438 case <-time.After(100 * time.Millisecond): 439 t.Fatalf("timed out while waiting for read to cancel") 440 <-clientErr 441 } 442 <-serverDone 443 } 444 445 func TestTimeoutPendingWrite(t *testing.T) { 446 pipePath := randomPipePath() 447 l, err := namedpipe.Listen(pipePath) 448 if err != nil { 449 t.Fatal(err) 450 } 451 defer l.Close() 452 453 serverDone := make(chan struct{}) 454 455 go func() { 456 s, err := l.Accept() 457 if err != nil { 458 t.Fatal(err) 459 } 460 time.Sleep(1 * time.Second) 461 s.Close() 462 close(serverDone) 463 }() 464 465 client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 466 if err != nil { 467 t.Fatal(err) 468 } 469 defer client.Close() 470 471 clientErr := make(chan error) 472 go func() { 473 _, err = client.Write([]byte("this should timeout")) 474 clientErr <- err 475 }() 476 477 time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline 478 client.SetWriteDeadline(time.Unix(1, 0)) 479 480 select { 481 case err = <-clientErr: 482 if err != os.ErrDeadlineExceeded { 483 t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) 484 } 485 case <-time.After(100 * time.Millisecond): 486 t.Fatalf("timed out while waiting for write to cancel") 487 <-clientErr 488 } 489 <-serverDone 490 } 491 492 type CloseWriter interface { 493 CloseWrite() error 494 } 495 496 func TestEchoWithMessaging(t *testing.T) { 497 pipePath := randomPipePath() 498 l, err := (&namedpipe.ListenConfig{ 499 MessageMode: true, // Use message mode so that CloseWrite() is supported 500 InputBufferSize: 65536, // Use 64KB buffers to improve performance 501 OutputBufferSize: 65536, 502 }).Listen(pipePath) 503 if err != nil { 504 t.Fatal(err) 505 } 506 defer l.Close() 507 508 listenerDone := make(chan bool) 509 clientDone := make(chan bool) 510 go func() { 511 // server echo 512 conn, err := l.Accept() 513 if err != nil { 514 t.Fatal(err) 515 } 516 defer conn.Close() 517 518 time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent 519 _, err = io.Copy(conn, conn) 520 if err != nil { 521 t.Fatal(err) 522 } 523 conn.(CloseWriter).CloseWrite() 524 close(listenerDone) 525 }() 526 client, err := namedpipe.DialTimeout(pipePath, time.Second) 527 if err != nil { 528 t.Fatal(err) 529 } 530 defer client.Close() 531 532 go func() { 533 // client read back 534 bytes := make([]byte, 2) 535 n, e := client.Read(bytes) 536 if e != nil { 537 t.Fatal(e) 538 } 539 if n != 2 || bytes[0] != 0 || bytes[1] != 1 { 540 t.Fatalf("expected 2 bytes, got %v", n) 541 } 542 close(clientDone) 543 }() 544 545 payload := make([]byte, 2) 546 payload[0] = 0 547 payload[1] = 1 548 549 n, err := client.Write(payload) 550 if err != nil { 551 t.Fatal(err) 552 } 553 if n != 2 { 554 t.Fatalf("expected 2 bytes, got %v", n) 555 } 556 client.(CloseWriter).CloseWrite() 557 <-listenerDone 558 <-clientDone 559 } 560 561 func TestConnectRace(t *testing.T) { 562 pipePath := randomPipePath() 563 l, err := namedpipe.Listen(pipePath) 564 if err != nil { 565 t.Fatal(err) 566 } 567 defer l.Close() 568 go func() { 569 for { 570 s, err := l.Accept() 571 if err == net.ErrClosed { 572 return 573 } 574 575 if err != nil { 576 t.Fatal(err) 577 } 578 s.Close() 579 } 580 }() 581 582 for i := 0; i < 1000; i++ { 583 c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 584 if err != nil { 585 t.Fatal(err) 586 } 587 c.Close() 588 } 589 } 590 591 func TestMessageReadMode(t *testing.T) { 592 if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { 593 t.Skipf("Skipping on Windows %d", maj) 594 } 595 var wg sync.WaitGroup 596 defer wg.Wait() 597 pipePath := randomPipePath() 598 l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath) 599 if err != nil { 600 t.Fatal(err) 601 } 602 defer l.Close() 603 604 msg := ([]byte)("hello world") 605 606 wg.Add(1) 607 go func() { 608 defer wg.Done() 609 s, err := l.Accept() 610 if err != nil { 611 t.Fatal(err) 612 } 613 _, err = s.Write(msg) 614 if err != nil { 615 t.Fatal(err) 616 } 617 s.Close() 618 }() 619 620 c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 621 if err != nil { 622 t.Fatal(err) 623 } 624 defer c.Close() 625 626 mode := uint32(windows.PIPE_READMODE_MESSAGE) 627 err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil) 628 if err != nil { 629 t.Fatal(err) 630 } 631 632 ch := make([]byte, 1) 633 var vmsg []byte 634 for { 635 n, err := c.Read(ch) 636 if err == io.EOF { 637 break 638 } 639 if err != nil { 640 t.Fatal(err) 641 } 642 if n != 1 { 643 t.Fatalf("expected 1, got %d", n) 644 } 645 vmsg = append(vmsg, ch[0]) 646 } 647 if !bytes.Equal(msg, vmsg) { 648 t.Fatalf("expected %s, got %s", msg, vmsg) 649 } 650 } 651 652 func TestListenConnectRace(t *testing.T) { 653 if testing.Short() { 654 t.Skip("Skipping long race test") 655 } 656 pipePath := randomPipePath() 657 for i := 0; i < 50 && !t.Failed(); i++ { 658 var wg sync.WaitGroup 659 wg.Add(1) 660 go func() { 661 c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 662 if err == nil { 663 c.Close() 664 } 665 wg.Done() 666 }() 667 s, err := namedpipe.Listen(pipePath) 668 if err != nil { 669 t.Error(i, err) 670 } else { 671 s.Close() 672 } 673 wg.Wait() 674 } 675 }