github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/ipc/winpipe/winpipe_test.go (about)

     1  // +build windows
     2  
     3  /* SPDX-License-Identifier: MIT
     4   *
     5   * Copyright (C) 2005 Microsoft
     6   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     7   */
     8  
     9  package winpipe_test
    10  
    11  import (
    12  	"bufio"
    13  	"bytes"
    14  	"context"
    15  	"errors"
    16  	"io"
    17  	"net"
    18  	"os"
    19  	"sync"
    20  	"syscall"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/tailscale/wireguard-go/ipc/winpipe"
    25  	"golang.org/x/sys/windows"
    26  )
    27  
    28  func randomPipePath() string {
    29  	guid, err := windows.GenerateGUID()
    30  	if err != nil {
    31  		panic(err)
    32  	}
    33  	return `\\.\PIPE\go-winpipe-test-` + guid.String()
    34  }
    35  
    36  func TestPingPong(t *testing.T) {
    37  	const (
    38  		ping = 42
    39  		pong = 24
    40  	)
    41  	pipePath := randomPipePath()
    42  	listener, err := winpipe.Listen(pipePath, nil)
    43  	if err != nil {
    44  		t.Fatalf("unable to listen on pipe: %v", err)
    45  	}
    46  	defer listener.Close()
    47  	go func() {
    48  		incoming, err := listener.Accept()
    49  		if err != nil {
    50  			t.Fatalf("unable to accept pipe connection: %v", err)
    51  		}
    52  		defer incoming.Close()
    53  		var data [1]byte
    54  		_, err = incoming.Read(data[:])
    55  		if err != nil {
    56  			t.Fatalf("unable to read ping from pipe: %v", err)
    57  		}
    58  		if data[0] != ping {
    59  			t.Fatalf("expected ping, got %d", data[0])
    60  		}
    61  		data[0] = pong
    62  		_, err = incoming.Write(data[:])
    63  		if err != nil {
    64  			t.Fatalf("unable to write pong to pipe: %v", err)
    65  		}
    66  	}()
    67  	client, err := winpipe.Dial(pipePath, nil, nil)
    68  	if err != nil {
    69  		t.Fatalf("unable to dial pipe: %v", err)
    70  	}
    71  	defer client.Close()
    72  	var data [1]byte
    73  	data[0] = ping
    74  	_, err = client.Write(data[:])
    75  	if err != nil {
    76  		t.Fatalf("unable to write ping to pipe: %v", err)
    77  	}
    78  	_, err = client.Read(data[:])
    79  	if err != nil {
    80  		t.Fatalf("unable to read pong from pipe: %v", err)
    81  	}
    82  	if data[0] != pong {
    83  		t.Fatalf("expected pong, got %d", data[0])
    84  	}
    85  }
    86  
    87  func TestDialUnknownFailsImmediately(t *testing.T) {
    88  	_, err := winpipe.Dial(randomPipePath(), nil, nil)
    89  	if !errors.Is(err, syscall.ENOENT) {
    90  		t.Fatalf("expected ENOENT got %v", err)
    91  	}
    92  }
    93  
    94  func TestDialListenerTimesOut(t *testing.T) {
    95  	pipePath := randomPipePath()
    96  	l, err := winpipe.Listen(pipePath, nil)
    97  	if err != nil {
    98  		t.Fatal(err)
    99  	}
   100  	defer l.Close()
   101  	d := 10 * time.Millisecond
   102  	_, err = winpipe.Dial(pipePath, &d, nil)
   103  	if err != os.ErrDeadlineExceeded {
   104  		t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   105  	}
   106  }
   107  
   108  func TestDialContextListenerTimesOut(t *testing.T) {
   109  	pipePath := randomPipePath()
   110  	l, err := winpipe.Listen(pipePath, nil)
   111  	if err != nil {
   112  		t.Fatal(err)
   113  	}
   114  	defer l.Close()
   115  	d := 10 * time.Millisecond
   116  	ctx, _ := context.WithTimeout(context.Background(), d)
   117  	_, err = winpipe.DialContext(ctx, pipePath, nil)
   118  	if err != context.DeadlineExceeded {
   119  		t.Fatalf("expected context.DeadlineExceeded, got %v", err)
   120  	}
   121  }
   122  
   123  func TestDialListenerGetsCancelled(t *testing.T) {
   124  	pipePath := randomPipePath()
   125  	ctx, cancel := context.WithCancel(context.Background())
   126  	l, err := winpipe.Listen(pipePath, nil)
   127  	if err != nil {
   128  		t.Fatal(err)
   129  	}
   130  	ch := make(chan error)
   131  	defer l.Close()
   132  	go func(ctx context.Context, ch chan error) {
   133  		_, err := winpipe.DialContext(ctx, pipePath, nil)
   134  		ch <- err
   135  	}(ctx, ch)
   136  	time.Sleep(time.Millisecond * 30)
   137  	cancel()
   138  	err = <-ch
   139  	if err != context.Canceled {
   140  		t.Fatalf("expected context.Canceled, got %v", err)
   141  	}
   142  }
   143  
   144  func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
   145  	if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil {
   146  		t.Skip("dacls on named pipes are broken on wine")
   147  	}
   148  	pipePath := randomPipePath()
   149  	sd, _ := windows.SecurityDescriptorFromString("D:")
   150  	c := winpipe.ListenConfig{
   151  		SecurityDescriptor: sd,
   152  	}
   153  	l, err := winpipe.Listen(pipePath, &c)
   154  	if err != nil {
   155  		t.Fatal(err)
   156  	}
   157  	defer l.Close()
   158  	_, err = winpipe.Dial(pipePath, nil, nil)
   159  	if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
   160  		t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
   161  	}
   162  }
   163  
   164  func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, err error) {
   165  	pipePath := randomPipePath()
   166  	l, err := winpipe.Listen(pipePath, cfg)
   167  	if err != nil {
   168  		return
   169  	}
   170  	defer l.Close()
   171  
   172  	type response struct {
   173  		c   net.Conn
   174  		err error
   175  	}
   176  	ch := make(chan response)
   177  	go func() {
   178  		c, err := l.Accept()
   179  		ch <- response{c, err}
   180  	}()
   181  
   182  	c, err := winpipe.Dial(pipePath, nil, nil)
   183  	if err != nil {
   184  		return
   185  	}
   186  
   187  	r := <-ch
   188  	if err = r.err; err != nil {
   189  		c.Close()
   190  		return
   191  	}
   192  
   193  	client = c
   194  	server = r.c
   195  	return
   196  }
   197  
   198  func TestReadTimeout(t *testing.T) {
   199  	c, s, err := getConnection(nil)
   200  	if err != nil {
   201  		t.Fatal(err)
   202  	}
   203  	defer c.Close()
   204  	defer s.Close()
   205  
   206  	c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
   207  
   208  	buf := make([]byte, 10)
   209  	_, err = c.Read(buf)
   210  	if err != os.ErrDeadlineExceeded {
   211  		t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   212  	}
   213  }
   214  
   215  func server(l net.Listener, ch chan int) {
   216  	c, err := l.Accept()
   217  	if err != nil {
   218  		panic(err)
   219  	}
   220  	rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
   221  	s, err := rw.ReadString('\n')
   222  	if err != nil {
   223  		panic(err)
   224  	}
   225  	_, err = rw.WriteString("got " + s)
   226  	if err != nil {
   227  		panic(err)
   228  	}
   229  	err = rw.Flush()
   230  	if err != nil {
   231  		panic(err)
   232  	}
   233  	c.Close()
   234  	ch <- 1
   235  }
   236  
   237  func TestFullListenDialReadWrite(t *testing.T) {
   238  	pipePath := randomPipePath()
   239  	l, err := winpipe.Listen(pipePath, nil)
   240  	if err != nil {
   241  		t.Fatal(err)
   242  	}
   243  	defer l.Close()
   244  
   245  	ch := make(chan int)
   246  	go server(l, ch)
   247  
   248  	c, err := winpipe.Dial(pipePath, nil, nil)
   249  	if err != nil {
   250  		t.Fatal(err)
   251  	}
   252  	defer c.Close()
   253  
   254  	rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
   255  	_, err = rw.WriteString("hello world\n")
   256  	if err != nil {
   257  		t.Fatal(err)
   258  	}
   259  	err = rw.Flush()
   260  	if err != nil {
   261  		t.Fatal(err)
   262  	}
   263  
   264  	s, err := rw.ReadString('\n')
   265  	if err != nil {
   266  		t.Fatal(err)
   267  	}
   268  	ms := "got hello world\n"
   269  	if s != ms {
   270  		t.Errorf("expected '%s', got '%s'", ms, s)
   271  	}
   272  
   273  	<-ch
   274  }
   275  
   276  func TestCloseAbortsListen(t *testing.T) {
   277  	pipePath := randomPipePath()
   278  	l, err := winpipe.Listen(pipePath, nil)
   279  	if err != nil {
   280  		t.Fatal(err)
   281  	}
   282  
   283  	ch := make(chan error)
   284  	go func() {
   285  		_, err := l.Accept()
   286  		ch <- err
   287  	}()
   288  
   289  	time.Sleep(30 * time.Millisecond)
   290  	l.Close()
   291  
   292  	err = <-ch
   293  	if err != net.ErrClosed {
   294  		t.Fatalf("expected net.ErrClosed, got %v", err)
   295  	}
   296  }
   297  
   298  func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) {
   299  	b := make([]byte, 10)
   300  	w.Close()
   301  	n, err := r.Read(b)
   302  	if n > 0 {
   303  		t.Errorf("unexpected byte count %d", n)
   304  	}
   305  	if err != io.EOF {
   306  		t.Errorf("expected EOF: %v", err)
   307  	}
   308  }
   309  
   310  func TestCloseClientEOFServer(t *testing.T) {
   311  	c, s, err := getConnection(nil)
   312  	if err != nil {
   313  		t.Fatal(err)
   314  	}
   315  	defer c.Close()
   316  	defer s.Close()
   317  	ensureEOFOnClose(t, c, s)
   318  }
   319  
   320  func TestCloseServerEOFClient(t *testing.T) {
   321  	c, s, err := getConnection(nil)
   322  	if err != nil {
   323  		t.Fatal(err)
   324  	}
   325  	defer c.Close()
   326  	defer s.Close()
   327  	ensureEOFOnClose(t, s, c)
   328  }
   329  
   330  func TestCloseWriteEOF(t *testing.T) {
   331  	cfg := &winpipe.ListenConfig{
   332  		MessageMode: true,
   333  	}
   334  	c, s, err := getConnection(cfg)
   335  	if err != nil {
   336  		t.Fatal(err)
   337  	}
   338  	defer c.Close()
   339  	defer s.Close()
   340  
   341  	type closeWriter interface {
   342  		CloseWrite() error
   343  	}
   344  
   345  	err = c.(closeWriter).CloseWrite()
   346  	if err != nil {
   347  		t.Fatal(err)
   348  	}
   349  
   350  	b := make([]byte, 10)
   351  	_, err = s.Read(b)
   352  	if err != io.EOF {
   353  		t.Fatal(err)
   354  	}
   355  }
   356  
   357  func TestAcceptAfterCloseFails(t *testing.T) {
   358  	pipePath := randomPipePath()
   359  	l, err := winpipe.Listen(pipePath, nil)
   360  	if err != nil {
   361  		t.Fatal(err)
   362  	}
   363  	l.Close()
   364  	_, err = l.Accept()
   365  	if err != net.ErrClosed {
   366  		t.Fatalf("expected net.ErrClosed, got %v", err)
   367  	}
   368  }
   369  
   370  func TestDialTimesOutByDefault(t *testing.T) {
   371  	pipePath := randomPipePath()
   372  	l, err := winpipe.Listen(pipePath, nil)
   373  	if err != nil {
   374  		t.Fatal(err)
   375  	}
   376  	defer l.Close()
   377  	_, err = winpipe.Dial(pipePath, nil, nil)
   378  	if err != os.ErrDeadlineExceeded {
   379  		t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   380  	}
   381  }
   382  
   383  func TestTimeoutPendingRead(t *testing.T) {
   384  	pipePath := randomPipePath()
   385  	l, err := winpipe.Listen(pipePath, nil)
   386  	if err != nil {
   387  		t.Fatal(err)
   388  	}
   389  	defer l.Close()
   390  
   391  	serverDone := make(chan struct{})
   392  
   393  	go func() {
   394  		s, err := l.Accept()
   395  		if err != nil {
   396  			t.Fatal(err)
   397  		}
   398  		time.Sleep(1 * time.Second)
   399  		s.Close()
   400  		close(serverDone)
   401  	}()
   402  
   403  	client, err := winpipe.Dial(pipePath, nil, nil)
   404  	if err != nil {
   405  		t.Fatal(err)
   406  	}
   407  	defer client.Close()
   408  
   409  	clientErr := make(chan error)
   410  	go func() {
   411  		buf := make([]byte, 10)
   412  		_, err = client.Read(buf)
   413  		clientErr <- err
   414  	}()
   415  
   416  	time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline
   417  	client.SetReadDeadline(time.Unix(1, 0))
   418  
   419  	select {
   420  	case err = <-clientErr:
   421  		if err != os.ErrDeadlineExceeded {
   422  			t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   423  		}
   424  	case <-time.After(100 * time.Millisecond):
   425  		t.Fatalf("timed out while waiting for read to cancel")
   426  		<-clientErr
   427  	}
   428  	<-serverDone
   429  }
   430  
   431  func TestTimeoutPendingWrite(t *testing.T) {
   432  	pipePath := randomPipePath()
   433  	l, err := winpipe.Listen(pipePath, nil)
   434  	if err != nil {
   435  		t.Fatal(err)
   436  	}
   437  	defer l.Close()
   438  
   439  	serverDone := make(chan struct{})
   440  
   441  	go func() {
   442  		s, err := l.Accept()
   443  		if err != nil {
   444  			t.Fatal(err)
   445  		}
   446  		time.Sleep(1 * time.Second)
   447  		s.Close()
   448  		close(serverDone)
   449  	}()
   450  
   451  	client, err := winpipe.Dial(pipePath, nil, nil)
   452  	if err != nil {
   453  		t.Fatal(err)
   454  	}
   455  	defer client.Close()
   456  
   457  	clientErr := make(chan error)
   458  	go func() {
   459  		_, err = client.Write([]byte("this should timeout"))
   460  		clientErr <- err
   461  	}()
   462  
   463  	time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline
   464  	client.SetWriteDeadline(time.Unix(1, 0))
   465  
   466  	select {
   467  	case err = <-clientErr:
   468  		if err != os.ErrDeadlineExceeded {
   469  			t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   470  		}
   471  	case <-time.After(100 * time.Millisecond):
   472  		t.Fatalf("timed out while waiting for write to cancel")
   473  		<-clientErr
   474  	}
   475  	<-serverDone
   476  }
   477  
   478  type CloseWriter interface {
   479  	CloseWrite() error
   480  }
   481  
   482  func TestEchoWithMessaging(t *testing.T) {
   483  	c := winpipe.ListenConfig{
   484  		MessageMode:      true,  // Use message mode so that CloseWrite() is supported
   485  		InputBufferSize:  65536, // Use 64KB buffers to improve performance
   486  		OutputBufferSize: 65536,
   487  	}
   488  	pipePath := randomPipePath()
   489  	l, err := winpipe.Listen(pipePath, &c)
   490  	if err != nil {
   491  		t.Fatal(err)
   492  	}
   493  	defer l.Close()
   494  
   495  	listenerDone := make(chan bool)
   496  	clientDone := make(chan bool)
   497  	go func() {
   498  		// server echo
   499  		conn, e := l.Accept()
   500  		if e != nil {
   501  			t.Fatal(e)
   502  		}
   503  		defer conn.Close()
   504  
   505  		time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
   506  		io.Copy(conn, conn)
   507  		conn.(CloseWriter).CloseWrite()
   508  		close(listenerDone)
   509  	}()
   510  	timeout := 1 * time.Second
   511  	client, err := winpipe.Dial(pipePath, &timeout, nil)
   512  	if err != nil {
   513  		t.Fatal(err)
   514  	}
   515  	defer client.Close()
   516  
   517  	go func() {
   518  		// client read back
   519  		bytes := make([]byte, 2)
   520  		n, e := client.Read(bytes)
   521  		if e != nil {
   522  			t.Fatal(e)
   523  		}
   524  		if n != 2 {
   525  			t.Fatalf("expected 2 bytes, got %v", n)
   526  		}
   527  		close(clientDone)
   528  	}()
   529  
   530  	payload := make([]byte, 2)
   531  	payload[0] = 0
   532  	payload[1] = 1
   533  
   534  	n, err := client.Write(payload)
   535  	if err != nil {
   536  		t.Fatal(err)
   537  	}
   538  	if n != 2 {
   539  		t.Fatalf("expected 2 bytes, got %v", n)
   540  	}
   541  	client.(CloseWriter).CloseWrite()
   542  	<-listenerDone
   543  	<-clientDone
   544  }
   545  
   546  func TestConnectRace(t *testing.T) {
   547  	pipePath := randomPipePath()
   548  	l, err := winpipe.Listen(pipePath, nil)
   549  	if err != nil {
   550  		t.Fatal(err)
   551  	}
   552  	defer l.Close()
   553  	go func() {
   554  		for {
   555  			s, err := l.Accept()
   556  			if err == net.ErrClosed {
   557  				return
   558  			}
   559  
   560  			if err != nil {
   561  				t.Fatal(err)
   562  			}
   563  			s.Close()
   564  		}
   565  	}()
   566  
   567  	for i := 0; i < 1000; i++ {
   568  		c, err := winpipe.Dial(pipePath, nil, nil)
   569  		if err != nil {
   570  			t.Fatal(err)
   571  		}
   572  		c.Close()
   573  	}
   574  }
   575  
   576  func TestMessageReadMode(t *testing.T) {
   577  	if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
   578  		t.Skipf("Skipping on Windows %d", maj)
   579  	}
   580  	var wg sync.WaitGroup
   581  	defer wg.Wait()
   582  	pipePath := randomPipePath()
   583  	l, err := winpipe.Listen(pipePath, &winpipe.ListenConfig{MessageMode: true})
   584  	if err != nil {
   585  		t.Fatal(err)
   586  	}
   587  	defer l.Close()
   588  
   589  	msg := ([]byte)("hello world")
   590  
   591  	wg.Add(1)
   592  	go func() {
   593  		defer wg.Done()
   594  		s, err := l.Accept()
   595  		if err != nil {
   596  			t.Fatal(err)
   597  		}
   598  		_, err = s.Write(msg)
   599  		if err != nil {
   600  			t.Fatal(err)
   601  		}
   602  		s.Close()
   603  	}()
   604  
   605  	c, err := winpipe.Dial(pipePath, nil, nil)
   606  	if err != nil {
   607  		t.Fatal(err)
   608  	}
   609  	defer c.Close()
   610  
   611  	mode := uint32(windows.PIPE_READMODE_MESSAGE)
   612  	err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil)
   613  	if err != nil {
   614  		t.Fatal(err)
   615  	}
   616  
   617  	ch := make([]byte, 1)
   618  	var vmsg []byte
   619  	for {
   620  		n, err := c.Read(ch)
   621  		if err == io.EOF {
   622  			break
   623  		}
   624  		if err != nil {
   625  			t.Fatal(err)
   626  		}
   627  		if n != 1 {
   628  			t.Fatalf("expected 1, got %d", n)
   629  		}
   630  		vmsg = append(vmsg, ch[0])
   631  	}
   632  	if !bytes.Equal(msg, vmsg) {
   633  		t.Fatalf("expected %s, got %s", msg, vmsg)
   634  	}
   635  }
   636  
   637  func TestListenConnectRace(t *testing.T) {
   638  	if testing.Short() {
   639  		t.Skip("Skipping long race test")
   640  	}
   641  	pipePath := randomPipePath()
   642  	for i := 0; i < 50 && !t.Failed(); i++ {
   643  		var wg sync.WaitGroup
   644  		wg.Add(1)
   645  		go func() {
   646  			c, err := winpipe.Dial(pipePath, nil, nil)
   647  			if err == nil {
   648  				c.Close()
   649  			}
   650  			wg.Done()
   651  		}()
   652  		s, err := winpipe.Listen(pipePath, nil)
   653  		if err != nil {
   654  			t.Error(i, err)
   655  		} else {
   656  			s.Close()
   657  		}
   658  		wg.Wait()
   659  	}
   660  }