github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/ipc/namedpipe/namedpipe_test.go (about) 1 // Copyright 2021 The Go Authors. All rights reserved. 2 // Copyright 2015 Microsoft 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 //go:build windows 7 8 package namedpipe_test 9 10 import ( 11 "bufio" 12 "bytes" 13 "context" 14 "errors" 15 "io" 16 "net" 17 "os" 18 "sync" 19 "syscall" 20 "testing" 21 "time" 22 23 "golang.org/x/sys/windows" 24 "github.com/koomox/wireguard-go/ipc/namedpipe" 25 ) 26 27 func randomPipePath() string { 28 guid, err := windows.GenerateGUID() 29 if err != nil { 30 panic(err) 31 } 32 return `\\.\PIPE\go-namedpipe-test-` + guid.String() 33 } 34 35 func TestPingPong(t *testing.T) { 36 const ( 37 ping = 42 38 pong = 24 39 ) 40 pipePath := randomPipePath() 41 listener, err := namedpipe.Listen(pipePath) 42 if err != nil { 43 t.Fatalf("unable to listen on pipe: %v", err) 44 } 45 defer listener.Close() 46 go func() { 47 incoming, err := listener.Accept() 48 if err != nil { 49 t.Fatalf("unable to accept pipe connection: %v", err) 50 } 51 defer incoming.Close() 52 var data [1]byte 53 _, err = incoming.Read(data[:]) 54 if err != nil { 55 t.Fatalf("unable to read ping from pipe: %v", err) 56 } 57 if data[0] != ping { 58 t.Fatalf("expected ping, got %d", data[0]) 59 } 60 data[0] = pong 61 _, err = incoming.Write(data[:]) 62 if err != nil { 63 t.Fatalf("unable to write pong to pipe: %v", err) 64 } 65 }() 66 client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 67 if err != nil { 68 t.Fatalf("unable to dial pipe: %v", err) 69 } 70 defer client.Close() 71 client.SetDeadline(time.Now().Add(time.Second * 5)) 72 var data [1]byte 73 data[0] = ping 74 _, err = client.Write(data[:]) 75 if err != nil { 76 t.Fatalf("unable to write ping to pipe: %v", err) 77 } 78 _, err = client.Read(data[:]) 79 if err != nil { 80 t.Fatalf("unable to read pong from pipe: %v", err) 81 } 82 if data[0] != pong { 83 t.Fatalf("expected pong, got %d", data[0]) 84 } 85 } 86 87 func TestDialUnknownFailsImmediately(t *testing.T) { 88 _, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0)) 89 if !errors.Is(err, syscall.ENOENT) { 90 t.Fatalf("expected ENOENT got %v", err) 91 } 92 } 93 94 func TestDialListenerTimesOut(t *testing.T) { 95 pipePath := randomPipePath() 96 l, err := namedpipe.Listen(pipePath) 97 if err != nil { 98 t.Fatal(err) 99 } 100 defer l.Close() 101 pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond) 102 if err == nil { 103 pipe.Close() 104 } 105 if err != os.ErrDeadlineExceeded { 106 t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) 107 } 108 } 109 110 func TestDialContextListenerTimesOut(t *testing.T) { 111 pipePath := randomPipePath() 112 l, err := namedpipe.Listen(pipePath) 113 if err != nil { 114 t.Fatal(err) 115 } 116 defer l.Close() 117 d := 10 * time.Millisecond 118 ctx, _ := context.WithTimeout(context.Background(), d) 119 pipe, err := namedpipe.DialContext(ctx, pipePath) 120 if err == nil { 121 pipe.Close() 122 } 123 if err != context.DeadlineExceeded { 124 t.Fatalf("expected context.DeadlineExceeded, got %v", err) 125 } 126 } 127 128 func TestDialListenerGetsCancelled(t *testing.T) { 129 pipePath := randomPipePath() 130 ctx, cancel := context.WithCancel(context.Background()) 131 l, err := namedpipe.Listen(pipePath) 132 if err != nil { 133 t.Fatal(err) 134 } 135 defer l.Close() 136 ch := make(chan error) 137 go func(ctx context.Context, ch chan error) { 138 _, err := namedpipe.DialContext(ctx, pipePath) 139 ch <- err 140 }(ctx, ch) 141 time.Sleep(time.Millisecond * 30) 142 cancel() 143 err = <-ch 144 if err != context.Canceled { 145 t.Fatalf("expected context.Canceled, got %v", err) 146 } 147 } 148 149 func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { 150 if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil { 151 t.Skip("dacls on named pipes are broken on wine") 152 } 153 pipePath := randomPipePath() 154 sd, _ := windows.SecurityDescriptorFromString("D:") 155 l, err := (&namedpipe.ListenConfig{ 156 SecurityDescriptor: sd, 157 }).Listen(pipePath) 158 if err != nil { 159 t.Fatal(err) 160 } 161 defer l.Close() 162 pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 163 if err == nil { 164 pipe.Close() 165 } 166 if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { 167 t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) 168 } 169 } 170 171 func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) { 172 pipePath := randomPipePath() 173 if cfg == nil { 174 cfg = &namedpipe.ListenConfig{} 175 } 176 l, err := cfg.Listen(pipePath) 177 if err != nil { 178 return 179 } 180 defer l.Close() 181 182 type response struct { 183 c net.Conn 184 err error 185 } 186 ch := make(chan response) 187 go func() { 188 c, err := l.Accept() 189 ch <- response{c, err} 190 }() 191 192 c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 193 if err != nil { 194 return 195 } 196 197 r := <-ch 198 if err = r.err; err != nil { 199 c.Close() 200 return 201 } 202 203 client = c 204 server = r.c 205 return 206 } 207 208 func TestReadTimeout(t *testing.T) { 209 c, s, err := getConnection(nil) 210 if err != nil { 211 t.Fatal(err) 212 } 213 defer c.Close() 214 defer s.Close() 215 216 c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) 217 218 buf := make([]byte, 10) 219 _, err = c.Read(buf) 220 if err != os.ErrDeadlineExceeded { 221 t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) 222 } 223 } 224 225 func server(l net.Listener, ch chan int) { 226 c, err := l.Accept() 227 if err != nil { 228 panic(err) 229 } 230 rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) 231 s, err := rw.ReadString('\n') 232 if err != nil { 233 panic(err) 234 } 235 _, err = rw.WriteString("got " + s) 236 if err != nil { 237 panic(err) 238 } 239 err = rw.Flush() 240 if err != nil { 241 panic(err) 242 } 243 c.Close() 244 ch <- 1 245 } 246 247 func TestFullListenDialReadWrite(t *testing.T) { 248 pipePath := randomPipePath() 249 l, err := namedpipe.Listen(pipePath) 250 if err != nil { 251 t.Fatal(err) 252 } 253 defer l.Close() 254 255 ch := make(chan int) 256 go server(l, ch) 257 258 c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 259 if err != nil { 260 t.Fatal(err) 261 } 262 defer c.Close() 263 264 rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) 265 _, err = rw.WriteString("hello world\n") 266 if err != nil { 267 t.Fatal(err) 268 } 269 err = rw.Flush() 270 if err != nil { 271 t.Fatal(err) 272 } 273 274 s, err := rw.ReadString('\n') 275 if err != nil { 276 t.Fatal(err) 277 } 278 ms := "got hello world\n" 279 if s != ms { 280 t.Errorf("expected '%s', got '%s'", ms, s) 281 } 282 283 <-ch 284 } 285 286 func TestCloseAbortsListen(t *testing.T) { 287 pipePath := randomPipePath() 288 l, err := namedpipe.Listen(pipePath) 289 if err != nil { 290 t.Fatal(err) 291 } 292 293 ch := make(chan error) 294 go func() { 295 _, err := l.Accept() 296 ch <- err 297 }() 298 299 time.Sleep(30 * time.Millisecond) 300 l.Close() 301 302 err = <-ch 303 if err != net.ErrClosed { 304 t.Fatalf("expected net.ErrClosed, got %v", err) 305 } 306 } 307 308 func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) { 309 b := make([]byte, 10) 310 w.Close() 311 n, err := r.Read(b) 312 if n > 0 { 313 t.Errorf("unexpected byte count %d", n) 314 } 315 if err != io.EOF { 316 t.Errorf("expected EOF: %v", err) 317 } 318 } 319 320 func TestCloseClientEOFServer(t *testing.T) { 321 c, s, err := getConnection(nil) 322 if err != nil { 323 t.Fatal(err) 324 } 325 defer c.Close() 326 defer s.Close() 327 ensureEOFOnClose(t, c, s) 328 } 329 330 func TestCloseServerEOFClient(t *testing.T) { 331 c, s, err := getConnection(nil) 332 if err != nil { 333 t.Fatal(err) 334 } 335 defer c.Close() 336 defer s.Close() 337 ensureEOFOnClose(t, s, c) 338 } 339 340 func TestCloseWriteEOF(t *testing.T) { 341 cfg := &namedpipe.ListenConfig{ 342 MessageMode: true, 343 } 344 c, s, err := getConnection(cfg) 345 if err != nil { 346 t.Fatal(err) 347 } 348 defer c.Close() 349 defer s.Close() 350 351 type closeWriter interface { 352 CloseWrite() error 353 } 354 355 err = c.(closeWriter).CloseWrite() 356 if err != nil { 357 t.Fatal(err) 358 } 359 360 b := make([]byte, 10) 361 _, err = s.Read(b) 362 if err != io.EOF { 363 t.Fatal(err) 364 } 365 } 366 367 func TestAcceptAfterCloseFails(t *testing.T) { 368 pipePath := randomPipePath() 369 l, err := namedpipe.Listen(pipePath) 370 if err != nil { 371 t.Fatal(err) 372 } 373 l.Close() 374 _, err = l.Accept() 375 if err != net.ErrClosed { 376 t.Fatalf("expected net.ErrClosed, got %v", err) 377 } 378 } 379 380 func TestDialTimesOutByDefault(t *testing.T) { 381 pipePath := randomPipePath() 382 l, err := namedpipe.Listen(pipePath) 383 if err != nil { 384 t.Fatal(err) 385 } 386 defer l.Close() 387 pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds. 388 if err == nil { 389 pipe.Close() 390 } 391 if err != os.ErrDeadlineExceeded { 392 t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) 393 } 394 } 395 396 func TestTimeoutPendingRead(t *testing.T) { 397 pipePath := randomPipePath() 398 l, err := namedpipe.Listen(pipePath) 399 if err != nil { 400 t.Fatal(err) 401 } 402 defer l.Close() 403 404 serverDone := make(chan struct{}) 405 406 go func() { 407 s, err := l.Accept() 408 if err != nil { 409 t.Fatal(err) 410 } 411 time.Sleep(1 * time.Second) 412 s.Close() 413 close(serverDone) 414 }() 415 416 client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 417 if err != nil { 418 t.Fatal(err) 419 } 420 defer client.Close() 421 422 clientErr := make(chan error) 423 go func() { 424 buf := make([]byte, 10) 425 _, err = client.Read(buf) 426 clientErr <- err 427 }() 428 429 time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline 430 client.SetReadDeadline(time.Unix(1, 0)) 431 432 select { 433 case err = <-clientErr: 434 if err != os.ErrDeadlineExceeded { 435 t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) 436 } 437 case <-time.After(100 * time.Millisecond): 438 t.Fatalf("timed out while waiting for read to cancel") 439 <-clientErr 440 } 441 <-serverDone 442 } 443 444 func TestTimeoutPendingWrite(t *testing.T) { 445 pipePath := randomPipePath() 446 l, err := namedpipe.Listen(pipePath) 447 if err != nil { 448 t.Fatal(err) 449 } 450 defer l.Close() 451 452 serverDone := make(chan struct{}) 453 454 go func() { 455 s, err := l.Accept() 456 if err != nil { 457 t.Fatal(err) 458 } 459 time.Sleep(1 * time.Second) 460 s.Close() 461 close(serverDone) 462 }() 463 464 client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 465 if err != nil { 466 t.Fatal(err) 467 } 468 defer client.Close() 469 470 clientErr := make(chan error) 471 go func() { 472 _, err = client.Write([]byte("this should timeout")) 473 clientErr <- err 474 }() 475 476 time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline 477 client.SetWriteDeadline(time.Unix(1, 0)) 478 479 select { 480 case err = <-clientErr: 481 if err != os.ErrDeadlineExceeded { 482 t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) 483 } 484 case <-time.After(100 * time.Millisecond): 485 t.Fatalf("timed out while waiting for write to cancel") 486 <-clientErr 487 } 488 <-serverDone 489 } 490 491 type CloseWriter interface { 492 CloseWrite() error 493 } 494 495 func TestEchoWithMessaging(t *testing.T) { 496 pipePath := randomPipePath() 497 l, err := (&namedpipe.ListenConfig{ 498 MessageMode: true, // Use message mode so that CloseWrite() is supported 499 InputBufferSize: 65536, // Use 64KB buffers to improve performance 500 OutputBufferSize: 65536, 501 }).Listen(pipePath) 502 if err != nil { 503 t.Fatal(err) 504 } 505 defer l.Close() 506 507 listenerDone := make(chan bool) 508 clientDone := make(chan bool) 509 go func() { 510 // server echo 511 conn, err := l.Accept() 512 if err != nil { 513 t.Fatal(err) 514 } 515 defer conn.Close() 516 517 time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent 518 _, err = io.Copy(conn, conn) 519 if err != nil { 520 t.Fatal(err) 521 } 522 conn.(CloseWriter).CloseWrite() 523 close(listenerDone) 524 }() 525 client, err := namedpipe.DialTimeout(pipePath, time.Second) 526 if err != nil { 527 t.Fatal(err) 528 } 529 defer client.Close() 530 531 go func() { 532 // client read back 533 bytes := make([]byte, 2) 534 n, e := client.Read(bytes) 535 if e != nil { 536 t.Fatal(e) 537 } 538 if n != 2 || bytes[0] != 0 || bytes[1] != 1 { 539 t.Fatalf("expected 2 bytes, got %v", n) 540 } 541 close(clientDone) 542 }() 543 544 payload := make([]byte, 2) 545 payload[0] = 0 546 payload[1] = 1 547 548 n, err := client.Write(payload) 549 if err != nil { 550 t.Fatal(err) 551 } 552 if n != 2 { 553 t.Fatalf("expected 2 bytes, got %v", n) 554 } 555 client.(CloseWriter).CloseWrite() 556 <-listenerDone 557 <-clientDone 558 } 559 560 func TestConnectRace(t *testing.T) { 561 pipePath := randomPipePath() 562 l, err := namedpipe.Listen(pipePath) 563 if err != nil { 564 t.Fatal(err) 565 } 566 defer l.Close() 567 go func() { 568 for { 569 s, err := l.Accept() 570 if err == net.ErrClosed { 571 return 572 } 573 574 if err != nil { 575 t.Fatal(err) 576 } 577 s.Close() 578 } 579 }() 580 581 for i := 0; i < 1000; i++ { 582 c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 583 if err != nil { 584 t.Fatal(err) 585 } 586 c.Close() 587 } 588 } 589 590 func TestMessageReadMode(t *testing.T) { 591 if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { 592 t.Skipf("Skipping on Windows %d", maj) 593 } 594 var wg sync.WaitGroup 595 defer wg.Wait() 596 pipePath := randomPipePath() 597 l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath) 598 if err != nil { 599 t.Fatal(err) 600 } 601 defer l.Close() 602 603 msg := ([]byte)("hello world") 604 605 wg.Add(1) 606 go func() { 607 defer wg.Done() 608 s, err := l.Accept() 609 if err != nil { 610 t.Fatal(err) 611 } 612 _, err = s.Write(msg) 613 if err != nil { 614 t.Fatal(err) 615 } 616 s.Close() 617 }() 618 619 c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 620 if err != nil { 621 t.Fatal(err) 622 } 623 defer c.Close() 624 625 mode := uint32(windows.PIPE_READMODE_MESSAGE) 626 err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil) 627 if err != nil { 628 t.Fatal(err) 629 } 630 631 ch := make([]byte, 1) 632 var vmsg []byte 633 for { 634 n, err := c.Read(ch) 635 if err == io.EOF { 636 break 637 } 638 if err != nil { 639 t.Fatal(err) 640 } 641 if n != 1 { 642 t.Fatalf("expected 1, got %d", n) 643 } 644 vmsg = append(vmsg, ch[0]) 645 } 646 if !bytes.Equal(msg, vmsg) { 647 t.Fatalf("expected %s, got %s", msg, vmsg) 648 } 649 } 650 651 func TestListenConnectRace(t *testing.T) { 652 if testing.Short() { 653 t.Skip("Skipping long race test") 654 } 655 pipePath := randomPipePath() 656 for i := 0; i < 50 && !t.Failed(); i++ { 657 var wg sync.WaitGroup 658 wg.Add(1) 659 go func() { 660 c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) 661 if err == nil { 662 c.Close() 663 } 664 wg.Done() 665 }() 666 s, err := namedpipe.Listen(pipePath) 667 if err != nil { 668 t.Error(i, err) 669 } else { 670 s.Close() 671 } 672 wg.Wait() 673 } 674 }