github.com/GFW-knocker/wireguard@v1.0.1/ipc/namedpipe/namedpipe_test.go (about)

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Copyright 2015 Microsoft
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  //go:build windows
     7  
     8  package namedpipe_test
     9  
    10  import (
    11  	"bufio"
    12  	"bytes"
    13  	"context"
    14  	"errors"
    15  	"io"
    16  	"net"
    17  	"os"
    18  	"sync"
    19  	"syscall"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/GFW-knocker/wireguard/ipc/namedpipe"
    24  	"golang.org/x/sys/windows"
    25  )
    26  
    27  func randomPipePath() string {
    28  	guid, err := windows.GenerateGUID()
    29  	if err != nil {
    30  		panic(err)
    31  	}
    32  	return `\\.\PIPE\go-namedpipe-test-` + guid.String()
    33  }
    34  
    35  func TestPingPong(t *testing.T) {
    36  	const (
    37  		ping = 42
    38  		pong = 24
    39  	)
    40  	pipePath := randomPipePath()
    41  	listener, err := namedpipe.Listen(pipePath)
    42  	if err != nil {
    43  		t.Fatalf("unable to listen on pipe: %v", err)
    44  	}
    45  	defer listener.Close()
    46  	go func() {
    47  		incoming, err := listener.Accept()
    48  		if err != nil {
    49  			t.Fatalf("unable to accept pipe connection: %v", err)
    50  		}
    51  		defer incoming.Close()
    52  		var data [1]byte
    53  		_, err = incoming.Read(data[:])
    54  		if err != nil {
    55  			t.Fatalf("unable to read ping from pipe: %v", err)
    56  		}
    57  		if data[0] != ping {
    58  			t.Fatalf("expected ping, got %d", data[0])
    59  		}
    60  		data[0] = pong
    61  		_, err = incoming.Write(data[:])
    62  		if err != nil {
    63  			t.Fatalf("unable to write pong to pipe: %v", err)
    64  		}
    65  	}()
    66  	client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
    67  	if err != nil {
    68  		t.Fatalf("unable to dial pipe: %v", err)
    69  	}
    70  	defer client.Close()
    71  	client.SetDeadline(time.Now().Add(time.Second * 5))
    72  	var data [1]byte
    73  	data[0] = ping
    74  	_, err = client.Write(data[:])
    75  	if err != nil {
    76  		t.Fatalf("unable to write ping to pipe: %v", err)
    77  	}
    78  	_, err = client.Read(data[:])
    79  	if err != nil {
    80  		t.Fatalf("unable to read pong from pipe: %v", err)
    81  	}
    82  	if data[0] != pong {
    83  		t.Fatalf("expected pong, got %d", data[0])
    84  	}
    85  }
    86  
    87  func TestDialUnknownFailsImmediately(t *testing.T) {
    88  	_, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
    89  	if !errors.Is(err, syscall.ENOENT) {
    90  		t.Fatalf("expected ENOENT got %v", err)
    91  	}
    92  }
    93  
    94  func TestDialListenerTimesOut(t *testing.T) {
    95  	pipePath := randomPipePath()
    96  	l, err := namedpipe.Listen(pipePath)
    97  	if err != nil {
    98  		t.Fatal(err)
    99  	}
   100  	defer l.Close()
   101  	pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
   102  	if err == nil {
   103  		pipe.Close()
   104  	}
   105  	if err != os.ErrDeadlineExceeded {
   106  		t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   107  	}
   108  }
   109  
   110  func TestDialContextListenerTimesOut(t *testing.T) {
   111  	pipePath := randomPipePath()
   112  	l, err := namedpipe.Listen(pipePath)
   113  	if err != nil {
   114  		t.Fatal(err)
   115  	}
   116  	defer l.Close()
   117  	d := 10 * time.Millisecond
   118  	ctx, _ := context.WithTimeout(context.Background(), d)
   119  	pipe, err := namedpipe.DialContext(ctx, pipePath)
   120  	if err == nil {
   121  		pipe.Close()
   122  	}
   123  	if err != context.DeadlineExceeded {
   124  		t.Fatalf("expected context.DeadlineExceeded, got %v", err)
   125  	}
   126  }
   127  
   128  func TestDialListenerGetsCancelled(t *testing.T) {
   129  	pipePath := randomPipePath()
   130  	ctx, cancel := context.WithCancel(context.Background())
   131  	l, err := namedpipe.Listen(pipePath)
   132  	if err != nil {
   133  		t.Fatal(err)
   134  	}
   135  	defer l.Close()
   136  	ch := make(chan error)
   137  	go func(ctx context.Context, ch chan error) {
   138  		_, err := namedpipe.DialContext(ctx, pipePath)
   139  		ch <- err
   140  	}(ctx, ch)
   141  	time.Sleep(time.Millisecond * 30)
   142  	cancel()
   143  	err = <-ch
   144  	if err != context.Canceled {
   145  		t.Fatalf("expected context.Canceled, got %v", err)
   146  	}
   147  }
   148  
   149  func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
   150  	if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil {
   151  		t.Skip("dacls on named pipes are broken on wine")
   152  	}
   153  	pipePath := randomPipePath()
   154  	sd, _ := windows.SecurityDescriptorFromString("D:")
   155  	l, err := (&namedpipe.ListenConfig{
   156  		SecurityDescriptor: sd,
   157  	}).Listen(pipePath)
   158  	if err != nil {
   159  		t.Fatal(err)
   160  	}
   161  	defer l.Close()
   162  	pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   163  	if err == nil {
   164  		pipe.Close()
   165  	}
   166  	if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
   167  		t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
   168  	}
   169  }
   170  
   171  func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) {
   172  	pipePath := randomPipePath()
   173  	if cfg == nil {
   174  		cfg = &namedpipe.ListenConfig{}
   175  	}
   176  	l, err := cfg.Listen(pipePath)
   177  	if err != nil {
   178  		return
   179  	}
   180  	defer l.Close()
   181  
   182  	type response struct {
   183  		c   net.Conn
   184  		err error
   185  	}
   186  	ch := make(chan response)
   187  	go func() {
   188  		c, err := l.Accept()
   189  		ch <- response{c, err}
   190  	}()
   191  
   192  	c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   193  	if err != nil {
   194  		return
   195  	}
   196  
   197  	r := <-ch
   198  	if err = r.err; err != nil {
   199  		c.Close()
   200  		return
   201  	}
   202  
   203  	client = c
   204  	server = r.c
   205  	return
   206  }
   207  
   208  func TestReadTimeout(t *testing.T) {
   209  	c, s, err := getConnection(nil)
   210  	if err != nil {
   211  		t.Fatal(err)
   212  	}
   213  	defer c.Close()
   214  	defer s.Close()
   215  
   216  	c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
   217  
   218  	buf := make([]byte, 10)
   219  	_, err = c.Read(buf)
   220  	if err != os.ErrDeadlineExceeded {
   221  		t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   222  	}
   223  }
   224  
   225  func server(l net.Listener, ch chan int) {
   226  	c, err := l.Accept()
   227  	if err != nil {
   228  		panic(err)
   229  	}
   230  	rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
   231  	s, err := rw.ReadString('\n')
   232  	if err != nil {
   233  		panic(err)
   234  	}
   235  	_, err = rw.WriteString("got " + s)
   236  	if err != nil {
   237  		panic(err)
   238  	}
   239  	err = rw.Flush()
   240  	if err != nil {
   241  		panic(err)
   242  	}
   243  	c.Close()
   244  	ch <- 1
   245  }
   246  
   247  func TestFullListenDialReadWrite(t *testing.T) {
   248  	pipePath := randomPipePath()
   249  	l, err := namedpipe.Listen(pipePath)
   250  	if err != nil {
   251  		t.Fatal(err)
   252  	}
   253  	defer l.Close()
   254  
   255  	ch := make(chan int)
   256  	go server(l, ch)
   257  
   258  	c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   259  	if err != nil {
   260  		t.Fatal(err)
   261  	}
   262  	defer c.Close()
   263  
   264  	rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
   265  	_, err = rw.WriteString("hello world\n")
   266  	if err != nil {
   267  		t.Fatal(err)
   268  	}
   269  	err = rw.Flush()
   270  	if err != nil {
   271  		t.Fatal(err)
   272  	}
   273  
   274  	s, err := rw.ReadString('\n')
   275  	if err != nil {
   276  		t.Fatal(err)
   277  	}
   278  	ms := "got hello world\n"
   279  	if s != ms {
   280  		t.Errorf("expected '%s', got '%s'", ms, s)
   281  	}
   282  
   283  	<-ch
   284  }
   285  
   286  func TestCloseAbortsListen(t *testing.T) {
   287  	pipePath := randomPipePath()
   288  	l, err := namedpipe.Listen(pipePath)
   289  	if err != nil {
   290  		t.Fatal(err)
   291  	}
   292  
   293  	ch := make(chan error)
   294  	go func() {
   295  		_, err := l.Accept()
   296  		ch <- err
   297  	}()
   298  
   299  	time.Sleep(30 * time.Millisecond)
   300  	l.Close()
   301  
   302  	err = <-ch
   303  	if err != net.ErrClosed {
   304  		t.Fatalf("expected net.ErrClosed, got %v", err)
   305  	}
   306  }
   307  
   308  func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) {
   309  	b := make([]byte, 10)
   310  	w.Close()
   311  	n, err := r.Read(b)
   312  	if n > 0 {
   313  		t.Errorf("unexpected byte count %d", n)
   314  	}
   315  	if err != io.EOF {
   316  		t.Errorf("expected EOF: %v", err)
   317  	}
   318  }
   319  
   320  func TestCloseClientEOFServer(t *testing.T) {
   321  	c, s, err := getConnection(nil)
   322  	if err != nil {
   323  		t.Fatal(err)
   324  	}
   325  	defer c.Close()
   326  	defer s.Close()
   327  	ensureEOFOnClose(t, c, s)
   328  }
   329  
   330  func TestCloseServerEOFClient(t *testing.T) {
   331  	c, s, err := getConnection(nil)
   332  	if err != nil {
   333  		t.Fatal(err)
   334  	}
   335  	defer c.Close()
   336  	defer s.Close()
   337  	ensureEOFOnClose(t, s, c)
   338  }
   339  
   340  func TestCloseWriteEOF(t *testing.T) {
   341  	cfg := &namedpipe.ListenConfig{
   342  		MessageMode: true,
   343  	}
   344  	c, s, err := getConnection(cfg)
   345  	if err != nil {
   346  		t.Fatal(err)
   347  	}
   348  	defer c.Close()
   349  	defer s.Close()
   350  
   351  	type closeWriter interface {
   352  		CloseWrite() error
   353  	}
   354  
   355  	err = c.(closeWriter).CloseWrite()
   356  	if err != nil {
   357  		t.Fatal(err)
   358  	}
   359  
   360  	b := make([]byte, 10)
   361  	_, err = s.Read(b)
   362  	if err != io.EOF {
   363  		t.Fatal(err)
   364  	}
   365  }
   366  
   367  func TestAcceptAfterCloseFails(t *testing.T) {
   368  	pipePath := randomPipePath()
   369  	l, err := namedpipe.Listen(pipePath)
   370  	if err != nil {
   371  		t.Fatal(err)
   372  	}
   373  	l.Close()
   374  	_, err = l.Accept()
   375  	if err != net.ErrClosed {
   376  		t.Fatalf("expected net.ErrClosed, got %v", err)
   377  	}
   378  }
   379  
   380  func TestDialTimesOutByDefault(t *testing.T) {
   381  	pipePath := randomPipePath()
   382  	l, err := namedpipe.Listen(pipePath)
   383  	if err != nil {
   384  		t.Fatal(err)
   385  	}
   386  	defer l.Close()
   387  	pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
   388  	if err == nil {
   389  		pipe.Close()
   390  	}
   391  	if err != os.ErrDeadlineExceeded {
   392  		t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   393  	}
   394  }
   395  
   396  func TestTimeoutPendingRead(t *testing.T) {
   397  	pipePath := randomPipePath()
   398  	l, err := namedpipe.Listen(pipePath)
   399  	if err != nil {
   400  		t.Fatal(err)
   401  	}
   402  	defer l.Close()
   403  
   404  	serverDone := make(chan struct{})
   405  
   406  	go func() {
   407  		s, err := l.Accept()
   408  		if err != nil {
   409  			t.Fatal(err)
   410  		}
   411  		time.Sleep(1 * time.Second)
   412  		s.Close()
   413  		close(serverDone)
   414  	}()
   415  
   416  	client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   417  	if err != nil {
   418  		t.Fatal(err)
   419  	}
   420  	defer client.Close()
   421  
   422  	clientErr := make(chan error)
   423  	go func() {
   424  		buf := make([]byte, 10)
   425  		_, err = client.Read(buf)
   426  		clientErr <- err
   427  	}()
   428  
   429  	time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline
   430  	client.SetReadDeadline(time.Unix(1, 0))
   431  
   432  	select {
   433  	case err = <-clientErr:
   434  		if err != os.ErrDeadlineExceeded {
   435  			t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   436  		}
   437  	case <-time.After(100 * time.Millisecond):
   438  		t.Fatalf("timed out while waiting for read to cancel")
   439  		<-clientErr
   440  	}
   441  	<-serverDone
   442  }
   443  
   444  func TestTimeoutPendingWrite(t *testing.T) {
   445  	pipePath := randomPipePath()
   446  	l, err := namedpipe.Listen(pipePath)
   447  	if err != nil {
   448  		t.Fatal(err)
   449  	}
   450  	defer l.Close()
   451  
   452  	serverDone := make(chan struct{})
   453  
   454  	go func() {
   455  		s, err := l.Accept()
   456  		if err != nil {
   457  			t.Fatal(err)
   458  		}
   459  		time.Sleep(1 * time.Second)
   460  		s.Close()
   461  		close(serverDone)
   462  	}()
   463  
   464  	client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   465  	if err != nil {
   466  		t.Fatal(err)
   467  	}
   468  	defer client.Close()
   469  
   470  	clientErr := make(chan error)
   471  	go func() {
   472  		_, err = client.Write([]byte("this should timeout"))
   473  		clientErr <- err
   474  	}()
   475  
   476  	time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline
   477  	client.SetWriteDeadline(time.Unix(1, 0))
   478  
   479  	select {
   480  	case err = <-clientErr:
   481  		if err != os.ErrDeadlineExceeded {
   482  			t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
   483  		}
   484  	case <-time.After(100 * time.Millisecond):
   485  		t.Fatalf("timed out while waiting for write to cancel")
   486  		<-clientErr
   487  	}
   488  	<-serverDone
   489  }
   490  
   491  type CloseWriter interface {
   492  	CloseWrite() error
   493  }
   494  
   495  func TestEchoWithMessaging(t *testing.T) {
   496  	pipePath := randomPipePath()
   497  	l, err := (&namedpipe.ListenConfig{
   498  		MessageMode:      true,  // Use message mode so that CloseWrite() is supported
   499  		InputBufferSize:  65536, // Use 64KB buffers to improve performance
   500  		OutputBufferSize: 65536,
   501  	}).Listen(pipePath)
   502  	if err != nil {
   503  		t.Fatal(err)
   504  	}
   505  	defer l.Close()
   506  
   507  	listenerDone := make(chan bool)
   508  	clientDone := make(chan bool)
   509  	go func() {
   510  		// server echo
   511  		conn, err := l.Accept()
   512  		if err != nil {
   513  			t.Fatal(err)
   514  		}
   515  		defer conn.Close()
   516  
   517  		time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
   518  		_, err = io.Copy(conn, conn)
   519  		if err != nil {
   520  			t.Fatal(err)
   521  		}
   522  		conn.(CloseWriter).CloseWrite()
   523  		close(listenerDone)
   524  	}()
   525  	client, err := namedpipe.DialTimeout(pipePath, time.Second)
   526  	if err != nil {
   527  		t.Fatal(err)
   528  	}
   529  	defer client.Close()
   530  
   531  	go func() {
   532  		// client read back
   533  		bytes := make([]byte, 2)
   534  		n, e := client.Read(bytes)
   535  		if e != nil {
   536  			t.Fatal(e)
   537  		}
   538  		if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
   539  			t.Fatalf("expected 2 bytes, got %v", n)
   540  		}
   541  		close(clientDone)
   542  	}()
   543  
   544  	payload := make([]byte, 2)
   545  	payload[0] = 0
   546  	payload[1] = 1
   547  
   548  	n, err := client.Write(payload)
   549  	if err != nil {
   550  		t.Fatal(err)
   551  	}
   552  	if n != 2 {
   553  		t.Fatalf("expected 2 bytes, got %v", n)
   554  	}
   555  	client.(CloseWriter).CloseWrite()
   556  	<-listenerDone
   557  	<-clientDone
   558  }
   559  
   560  func TestConnectRace(t *testing.T) {
   561  	pipePath := randomPipePath()
   562  	l, err := namedpipe.Listen(pipePath)
   563  	if err != nil {
   564  		t.Fatal(err)
   565  	}
   566  	defer l.Close()
   567  	go func() {
   568  		for {
   569  			s, err := l.Accept()
   570  			if err == net.ErrClosed {
   571  				return
   572  			}
   573  
   574  			if err != nil {
   575  				t.Fatal(err)
   576  			}
   577  			s.Close()
   578  		}
   579  	}()
   580  
   581  	for i := 0; i < 1000; i++ {
   582  		c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   583  		if err != nil {
   584  			t.Fatal(err)
   585  		}
   586  		c.Close()
   587  	}
   588  }
   589  
   590  func TestMessageReadMode(t *testing.T) {
   591  	if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
   592  		t.Skipf("Skipping on Windows %d", maj)
   593  	}
   594  	var wg sync.WaitGroup
   595  	defer wg.Wait()
   596  	pipePath := randomPipePath()
   597  	l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
   598  	if err != nil {
   599  		t.Fatal(err)
   600  	}
   601  	defer l.Close()
   602  
   603  	msg := ([]byte)("hello world")
   604  
   605  	wg.Add(1)
   606  	go func() {
   607  		defer wg.Done()
   608  		s, err := l.Accept()
   609  		if err != nil {
   610  			t.Fatal(err)
   611  		}
   612  		_, err = s.Write(msg)
   613  		if err != nil {
   614  			t.Fatal(err)
   615  		}
   616  		s.Close()
   617  	}()
   618  
   619  	c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   620  	if err != nil {
   621  		t.Fatal(err)
   622  	}
   623  	defer c.Close()
   624  
   625  	mode := uint32(windows.PIPE_READMODE_MESSAGE)
   626  	err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil)
   627  	if err != nil {
   628  		t.Fatal(err)
   629  	}
   630  
   631  	ch := make([]byte, 1)
   632  	var vmsg []byte
   633  	for {
   634  		n, err := c.Read(ch)
   635  		if err == io.EOF {
   636  			break
   637  		}
   638  		if err != nil {
   639  			t.Fatal(err)
   640  		}
   641  		if n != 1 {
   642  			t.Fatalf("expected 1, got %d", n)
   643  		}
   644  		vmsg = append(vmsg, ch[0])
   645  	}
   646  	if !bytes.Equal(msg, vmsg) {
   647  		t.Fatalf("expected %s, got %s", msg, vmsg)
   648  	}
   649  }
   650  
   651  func TestListenConnectRace(t *testing.T) {
   652  	if testing.Short() {
   653  		t.Skip("Skipping long race test")
   654  	}
   655  	pipePath := randomPipePath()
   656  	for i := 0; i < 50 && !t.Failed(); i++ {
   657  		var wg sync.WaitGroup
   658  		wg.Add(1)
   659  		go func() {
   660  			c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
   661  			if err == nil {
   662  				c.Close()
   663  			}
   664  			wg.Done()
   665  		}()
   666  		s, err := namedpipe.Listen(pipePath)
   667  		if err != nil {
   668  			t.Error(i, err)
   669  		} else {
   670  			s.Close()
   671  		}
   672  		wg.Wait()
   673  	}
   674  }