github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/ipc/namedpipe/namedpipe_test.go (about)

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Copyright 2015 Microsoft
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  //go:build windows
     7  // +build windows
     8  
     9  package namedpipe_test
    10  
    11  import (
    12  	"bufio"
    13  	"bytes"
    14  	"context"
    15  	"errors"
    16  	"io"
    17  	"net"
    18  	"os"
    19  	"sync"
    20  	"syscall"
    21  	"testing"
    22  	"time"
    23  
    24  	"golang.org/x/sys/windows"
    25  	"github.com/liloew/wireguard-go/ipc/namedpipe"
    26  )
    27  
    28  func randomPipePath() string {
    29  	guid, err := windows.GenerateGUID()
    30  	if err != nil {
    31  		panic(err)
    32  	}
    33  	return `\\.\PIPE\go-namedpipe-test-` + guid.String()
    34  }
    35  
    36  func TestPingPong(t *testing.T) {
    37  	const (
    38  		ping = 42
    39  		pong = 24
    40  	)
    41  	pipePath := randomPipePath()
    42  	listener, err := namedpipe.Listen(pipePath)
    43  	if err != nil {
    44  		t.Fatalf("unable to listen on pipe: %v", err)
    45  	}
    46  	defer listener.Close()
    47  	go func() {
    48  		incoming, err := listener.Accept()
    49  		if err != nil {
    50  			t.Fatalf("unable to accept pipe connection: %v", err)
    51  		}
    52  		defer incoming.Close()
    53  		var data [1]byte
    54  		_, err = incoming.Read(data[:])
    55  		if err != nil {
    56  			t.Fatalf("unable to read ping from pipe: %v", err)
    57  		}
    58  		if data[0] != ping {
    59  			t.Fatalf("expected ping, got %d", data[0])
    60  		}
    61  		data[0] = pong
    62  		_, err = incoming.Write(data[:])
    63  		if err != nil {
    64  			t.Fatalf("unable to write pong to pipe: %v", err)
    65  		}
    66  	}()
    67  	client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
    68  	if err != nil {
    69  		t.Fatalf("unable to dial pipe: %v", err)
    70  	}
    71  	defer client.Close()
    72  	client.SetDeadline(time.Now().Add(time.Second * 5))
    73  	var data [1]byte
    74  	data[0] = ping
    75  	_, err = client.Write(data[:])
    76  	if err != nil {
    77  		t.Fatalf("unable to write ping to pipe: %v", err)
    78  	}
    79  	_, err = client.Read(data[:])
    80  	if err != nil {
    81  		t.Fatalf("unable to read pong from pipe: %v", err)
    82  	}
    83  	if data[0] != pong {
    84  		t.Fatalf("expected pong, got %d", data[0])
    85  	}
    86  }
    87  
    88  func TestDialUnknownFailsImmediately(t *testing.T) {
    89  	_, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
    90  	if !errors.Is(err, syscall.ENOENT) {
    91  		t.Fatalf("expected ENOENT got %v", err)
    92  	}
    93  }
    94  
    95  func TestDialListenerTimesOut(t *testing.T) {
    96  	pipePath := randomPipePath()
    97  	l, err := namedpipe.Listen(pipePath)
    98  	if err != nil {
    99  		t.Fatal(err)
   100  	}
   101  	defer l.Close()
   102  	pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
   103  	if err == nil {
   104  		pipe.Close()
   105  	}
   106  	if err != os.ErrDeadlineExceeded {
   107  		t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   108  	}
   109  }
   110  
   111  func TestDialContextListenerTimesOut(t *testing.T) {
   112  	pipePath := randomPipePath()
   113  	l, err := namedpipe.Listen(pipePath)
   114  	if err != nil {
   115  		t.Fatal(err)
   116  	}
   117  	defer l.Close()
   118  	d := 10 * time.Millisecond
   119  	ctx, _ := context.WithTimeout(context.Background(), d)
   120  	pipe, err := namedpipe.DialContext(ctx, pipePath)
   121  	if err == nil {
   122  		pipe.Close()
   123  	}
   124  	if err != context.DeadlineExceeded {
   125  		t.Fatalf("expected context.DeadlineExceeded, got %v", err)
   126  	}
   127  }
   128  
   129  func TestDialListenerGetsCancelled(t *testing.T) {
   130  	pipePath := randomPipePath()
   131  	ctx, cancel := context.WithCancel(context.Background())
   132  	l, err := namedpipe.Listen(pipePath)
   133  	if err != nil {
   134  		t.Fatal(err)
   135  	}
   136  	defer l.Close()
   137  	ch := make(chan error)
   138  	go func(ctx context.Context, ch chan error) {
   139  		_, err := namedpipe.DialContext(ctx, pipePath)
   140  		ch <- err
   141  	}(ctx, ch)
   142  	time.Sleep(time.Millisecond * 30)
   143  	cancel()
   144  	err = <-ch
   145  	if err != context.Canceled {
   146  		t.Fatalf("expected context.Canceled, got %v", err)
   147  	}
   148  }
   149  
   150  func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
   151  	if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil {
   152  		t.Skip("dacls on named pipes are broken on wine")
   153  	}
   154  	pipePath := randomPipePath()
   155  	sd, _ := windows.SecurityDescriptorFromString("D:")
   156  	l, err := (&namedpipe.ListenConfig{
   157  		SecurityDescriptor: sd,
   158  	}).Listen(pipePath)
   159  	if err != nil {
   160  		t.Fatal(err)
   161  	}
   162  	defer l.Close()
   163  	pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   164  	if err == nil {
   165  		pipe.Close()
   166  	}
   167  	if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
   168  		t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
   169  	}
   170  }
   171  
   172  func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) {
   173  	pipePath := randomPipePath()
   174  	if cfg == nil {
   175  		cfg = &namedpipe.ListenConfig{}
   176  	}
   177  	l, err := cfg.Listen(pipePath)
   178  	if err != nil {
   179  		return
   180  	}
   181  	defer l.Close()
   182  
   183  	type response struct {
   184  		c   net.Conn
   185  		err error
   186  	}
   187  	ch := make(chan response)
   188  	go func() {
   189  		c, err := l.Accept()
   190  		ch <- response{c, err}
   191  	}()
   192  
   193  	c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   194  	if err != nil {
   195  		return
   196  	}
   197  
   198  	r := <-ch
   199  	if err = r.err; err != nil {
   200  		c.Close()
   201  		return
   202  	}
   203  
   204  	client = c
   205  	server = r.c
   206  	return
   207  }
   208  
   209  func TestReadTimeout(t *testing.T) {
   210  	c, s, err := getConnection(nil)
   211  	if err != nil {
   212  		t.Fatal(err)
   213  	}
   214  	defer c.Close()
   215  	defer s.Close()
   216  
   217  	c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
   218  
   219  	buf := make([]byte, 10)
   220  	_, err = c.Read(buf)
   221  	if err != os.ErrDeadlineExceeded {
   222  		t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   223  	}
   224  }
   225  
   226  func server(l net.Listener, ch chan int) {
   227  	c, err := l.Accept()
   228  	if err != nil {
   229  		panic(err)
   230  	}
   231  	rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
   232  	s, err := rw.ReadString('\n')
   233  	if err != nil {
   234  		panic(err)
   235  	}
   236  	_, err = rw.WriteString("got " + s)
   237  	if err != nil {
   238  		panic(err)
   239  	}
   240  	err = rw.Flush()
   241  	if err != nil {
   242  		panic(err)
   243  	}
   244  	c.Close()
   245  	ch <- 1
   246  }
   247  
   248  func TestFullListenDialReadWrite(t *testing.T) {
   249  	pipePath := randomPipePath()
   250  	l, err := namedpipe.Listen(pipePath)
   251  	if err != nil {
   252  		t.Fatal(err)
   253  	}
   254  	defer l.Close()
   255  
   256  	ch := make(chan int)
   257  	go server(l, ch)
   258  
   259  	c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   260  	if err != nil {
   261  		t.Fatal(err)
   262  	}
   263  	defer c.Close()
   264  
   265  	rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
   266  	_, err = rw.WriteString("hello world\n")
   267  	if err != nil {
   268  		t.Fatal(err)
   269  	}
   270  	err = rw.Flush()
   271  	if err != nil {
   272  		t.Fatal(err)
   273  	}
   274  
   275  	s, err := rw.ReadString('\n')
   276  	if err != nil {
   277  		t.Fatal(err)
   278  	}
   279  	ms := "got hello world\n"
   280  	if s != ms {
   281  		t.Errorf("expected '%s', got '%s'", ms, s)
   282  	}
   283  
   284  	<-ch
   285  }
   286  
   287  func TestCloseAbortsListen(t *testing.T) {
   288  	pipePath := randomPipePath()
   289  	l, err := namedpipe.Listen(pipePath)
   290  	if err != nil {
   291  		t.Fatal(err)
   292  	}
   293  
   294  	ch := make(chan error)
   295  	go func() {
   296  		_, err := l.Accept()
   297  		ch <- err
   298  	}()
   299  
   300  	time.Sleep(30 * time.Millisecond)
   301  	l.Close()
   302  
   303  	err = <-ch
   304  	if err != net.ErrClosed {
   305  		t.Fatalf("expected net.ErrClosed, got %v", err)
   306  	}
   307  }
   308  
   309  func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) {
   310  	b := make([]byte, 10)
   311  	w.Close()
   312  	n, err := r.Read(b)
   313  	if n > 0 {
   314  		t.Errorf("unexpected byte count %d", n)
   315  	}
   316  	if err != io.EOF {
   317  		t.Errorf("expected EOF: %v", err)
   318  	}
   319  }
   320  
   321  func TestCloseClientEOFServer(t *testing.T) {
   322  	c, s, err := getConnection(nil)
   323  	if err != nil {
   324  		t.Fatal(err)
   325  	}
   326  	defer c.Close()
   327  	defer s.Close()
   328  	ensureEOFOnClose(t, c, s)
   329  }
   330  
   331  func TestCloseServerEOFClient(t *testing.T) {
   332  	c, s, err := getConnection(nil)
   333  	if err != nil {
   334  		t.Fatal(err)
   335  	}
   336  	defer c.Close()
   337  	defer s.Close()
   338  	ensureEOFOnClose(t, s, c)
   339  }
   340  
   341  func TestCloseWriteEOF(t *testing.T) {
   342  	cfg := &namedpipe.ListenConfig{
   343  		MessageMode: true,
   344  	}
   345  	c, s, err := getConnection(cfg)
   346  	if err != nil {
   347  		t.Fatal(err)
   348  	}
   349  	defer c.Close()
   350  	defer s.Close()
   351  
   352  	type closeWriter interface {
   353  		CloseWrite() error
   354  	}
   355  
   356  	err = c.(closeWriter).CloseWrite()
   357  	if err != nil {
   358  		t.Fatal(err)
   359  	}
   360  
   361  	b := make([]byte, 10)
   362  	_, err = s.Read(b)
   363  	if err != io.EOF {
   364  		t.Fatal(err)
   365  	}
   366  }
   367  
   368  func TestAcceptAfterCloseFails(t *testing.T) {
   369  	pipePath := randomPipePath()
   370  	l, err := namedpipe.Listen(pipePath)
   371  	if err != nil {
   372  		t.Fatal(err)
   373  	}
   374  	l.Close()
   375  	_, err = l.Accept()
   376  	if err != net.ErrClosed {
   377  		t.Fatalf("expected net.ErrClosed, got %v", err)
   378  	}
   379  }
   380  
   381  func TestDialTimesOutByDefault(t *testing.T) {
   382  	pipePath := randomPipePath()
   383  	l, err := namedpipe.Listen(pipePath)
   384  	if err != nil {
   385  		t.Fatal(err)
   386  	}
   387  	defer l.Close()
   388  	pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
   389  	if err == nil {
   390  		pipe.Close()
   391  	}
   392  	if err != os.ErrDeadlineExceeded {
   393  		t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   394  	}
   395  }
   396  
   397  func TestTimeoutPendingRead(t *testing.T) {
   398  	pipePath := randomPipePath()
   399  	l, err := namedpipe.Listen(pipePath)
   400  	if err != nil {
   401  		t.Fatal(err)
   402  	}
   403  	defer l.Close()
   404  
   405  	serverDone := make(chan struct{})
   406  
   407  	go func() {
   408  		s, err := l.Accept()
   409  		if err != nil {
   410  			t.Fatal(err)
   411  		}
   412  		time.Sleep(1 * time.Second)
   413  		s.Close()
   414  		close(serverDone)
   415  	}()
   416  
   417  	client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   418  	if err != nil {
   419  		t.Fatal(err)
   420  	}
   421  	defer client.Close()
   422  
   423  	clientErr := make(chan error)
   424  	go func() {
   425  		buf := make([]byte, 10)
   426  		_, err = client.Read(buf)
   427  		clientErr <- err
   428  	}()
   429  
   430  	time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline
   431  	client.SetReadDeadline(time.Unix(1, 0))
   432  
   433  	select {
   434  	case err = <-clientErr:
   435  		if err != os.ErrDeadlineExceeded {
   436  			t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   437  		}
   438  	case <-time.After(100 * time.Millisecond):
   439  		t.Fatalf("timed out while waiting for read to cancel")
   440  		<-clientErr
   441  	}
   442  	<-serverDone
   443  }
   444  
   445  func TestTimeoutPendingWrite(t *testing.T) {
   446  	pipePath := randomPipePath()
   447  	l, err := namedpipe.Listen(pipePath)
   448  	if err != nil {
   449  		t.Fatal(err)
   450  	}
   451  	defer l.Close()
   452  
   453  	serverDone := make(chan struct{})
   454  
   455  	go func() {
   456  		s, err := l.Accept()
   457  		if err != nil {
   458  			t.Fatal(err)
   459  		}
   460  		time.Sleep(1 * time.Second)
   461  		s.Close()
   462  		close(serverDone)
   463  	}()
   464  
   465  	client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   466  	if err != nil {
   467  		t.Fatal(err)
   468  	}
   469  	defer client.Close()
   470  
   471  	clientErr := make(chan error)
   472  	go func() {
   473  		_, err = client.Write([]byte("this should timeout"))
   474  		clientErr <- err
   475  	}()
   476  
   477  	time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline
   478  	client.SetWriteDeadline(time.Unix(1, 0))
   479  
   480  	select {
   481  	case err = <-clientErr:
   482  		if err != os.ErrDeadlineExceeded {
   483  			t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   484  		}
   485  	case <-time.After(100 * time.Millisecond):
   486  		t.Fatalf("timed out while waiting for write to cancel")
   487  		<-clientErr
   488  	}
   489  	<-serverDone
   490  }
   491  
   492  type CloseWriter interface {
   493  	CloseWrite() error
   494  }
   495  
   496  func TestEchoWithMessaging(t *testing.T) {
   497  	pipePath := randomPipePath()
   498  	l, err := (&namedpipe.ListenConfig{
   499  		MessageMode:      true,  // Use message mode so that CloseWrite() is supported
   500  		InputBufferSize:  65536, // Use 64KB buffers to improve performance
   501  		OutputBufferSize: 65536,
   502  	}).Listen(pipePath)
   503  	if err != nil {
   504  		t.Fatal(err)
   505  	}
   506  	defer l.Close()
   507  
   508  	listenerDone := make(chan bool)
   509  	clientDone := make(chan bool)
   510  	go func() {
   511  		// server echo
   512  		conn, err := l.Accept()
   513  		if err != nil {
   514  			t.Fatal(err)
   515  		}
   516  		defer conn.Close()
   517  
   518  		time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
   519  		_, err = io.Copy(conn, conn)
   520  		if err != nil {
   521  			t.Fatal(err)
   522  		}
   523  		conn.(CloseWriter).CloseWrite()
   524  		close(listenerDone)
   525  	}()
   526  	client, err := namedpipe.DialTimeout(pipePath, time.Second)
   527  	if err != nil {
   528  		t.Fatal(err)
   529  	}
   530  	defer client.Close()
   531  
   532  	go func() {
   533  		// client read back
   534  		bytes := make([]byte, 2)
   535  		n, e := client.Read(bytes)
   536  		if e != nil {
   537  			t.Fatal(e)
   538  		}
   539  		if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
   540  			t.Fatalf("expected 2 bytes, got %v", n)
   541  		}
   542  		close(clientDone)
   543  	}()
   544  
   545  	payload := make([]byte, 2)
   546  	payload[0] = 0
   547  	payload[1] = 1
   548  
   549  	n, err := client.Write(payload)
   550  	if err != nil {
   551  		t.Fatal(err)
   552  	}
   553  	if n != 2 {
   554  		t.Fatalf("expected 2 bytes, got %v", n)
   555  	}
   556  	client.(CloseWriter).CloseWrite()
   557  	<-listenerDone
   558  	<-clientDone
   559  }
   560  
   561  func TestConnectRace(t *testing.T) {
   562  	pipePath := randomPipePath()
   563  	l, err := namedpipe.Listen(pipePath)
   564  	if err != nil {
   565  		t.Fatal(err)
   566  	}
   567  	defer l.Close()
   568  	go func() {
   569  		for {
   570  			s, err := l.Accept()
   571  			if err == net.ErrClosed {
   572  				return
   573  			}
   574  
   575  			if err != nil {
   576  				t.Fatal(err)
   577  			}
   578  			s.Close()
   579  		}
   580  	}()
   581  
   582  	for i := 0; i < 1000; i++ {
   583  		c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   584  		if err != nil {
   585  			t.Fatal(err)
   586  		}
   587  		c.Close()
   588  	}
   589  }
   590  
   591  func TestMessageReadMode(t *testing.T) {
   592  	if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
   593  		t.Skipf("Skipping on Windows %d", maj)
   594  	}
   595  	var wg sync.WaitGroup
   596  	defer wg.Wait()
   597  	pipePath := randomPipePath()
   598  	l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
   599  	if err != nil {
   600  		t.Fatal(err)
   601  	}
   602  	defer l.Close()
   603  
   604  	msg := ([]byte)("hello world")
   605  
   606  	wg.Add(1)
   607  	go func() {
   608  		defer wg.Done()
   609  		s, err := l.Accept()
   610  		if err != nil {
   611  			t.Fatal(err)
   612  		}
   613  		_, err = s.Write(msg)
   614  		if err != nil {
   615  			t.Fatal(err)
   616  		}
   617  		s.Close()
   618  	}()
   619  
   620  	c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   621  	if err != nil {
   622  		t.Fatal(err)
   623  	}
   624  	defer c.Close()
   625  
   626  	mode := uint32(windows.PIPE_READMODE_MESSAGE)
   627  	err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil)
   628  	if err != nil {
   629  		t.Fatal(err)
   630  	}
   631  
   632  	ch := make([]byte, 1)
   633  	var vmsg []byte
   634  	for {
   635  		n, err := c.Read(ch)
   636  		if err == io.EOF {
   637  			break
   638  		}
   639  		if err != nil {
   640  			t.Fatal(err)
   641  		}
   642  		if n != 1 {
   643  			t.Fatalf("expected 1, got %d", n)
   644  		}
   645  		vmsg = append(vmsg, ch[0])
   646  	}
   647  	if !bytes.Equal(msg, vmsg) {
   648  		t.Fatalf("expected %s, got %s", msg, vmsg)
   649  	}
   650  }
   651  
   652  func TestListenConnectRace(t *testing.T) {
   653  	if testing.Short() {
   654  		t.Skip("Skipping long race test")
   655  	}
   656  	pipePath := randomPipePath()
   657  	for i := 0; i < 50 && !t.Failed(); i++ {
   658  		var wg sync.WaitGroup
   659  		wg.Add(1)
   660  		go func() {
   661  			c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   662  			if err == nil {
   663  				c.Close()
   664  			}
   665  			wg.Done()
   666  		}()
   667  		s, err := namedpipe.Listen(pipePath)
   668  		if err != nil {
   669  			t.Error(i, err)
   670  		} else {
   671  			s.Close()
   672  		}
   673  		wg.Wait()
   674  	}
   675  }