github.com/Serizao/go-winio@v0.0.0-20230906082528-f02f7f4ad6e8/pipe_test.go (about)

     1  //go:build windows
     2  // +build windows
     3  
     4  package winio
     5  
     6  import (
     7  	"bufio"
     8  	"bytes"
     9  	"context"
    10  	"errors"
    11  	"io"
    12  	"net"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  	"unsafe"
    17  
    18  	"golang.org/x/sys/windows"
    19  )
    20  
    21  var testPipeName = `\\.\pipe\winiotestpipe`
    22  
    23  var aLongTimeAgo = time.Unix(1, 0)
    24  
    25  func TestDialUnknownFailsImmediately(t *testing.T) {
    26  	_, err := DialPipe(testPipeName, nil)
    27  	if !errors.Is(err, windows.ERROR_FILE_NOT_FOUND) {
    28  		t.Fatalf("expected ERROR_FILE_NOT_FOUND got %v", err)
    29  	}
    30  }
    31  
    32  func TestDialListenerTimesOut(t *testing.T) {
    33  	l, err := ListenPipe(testPipeName, nil)
    34  	if err != nil {
    35  		t.Fatal(err)
    36  	}
    37  	defer l.Close()
    38  	var d = 10 * time.Millisecond
    39  	_, err = DialPipe(testPipeName, &d)
    40  	if !errors.Is(err, ErrTimeout) {
    41  		t.Fatalf("expected ErrTimeout, got %v", err)
    42  	}
    43  }
    44  
    45  func TestDialContextListenerTimesOut(t *testing.T) {
    46  	l, err := ListenPipe(testPipeName, nil)
    47  	if err != nil {
    48  		t.Fatal(err)
    49  	}
    50  	defer l.Close()
    51  	var d = 10 * time.Millisecond
    52  	ctx, cancel := context.WithTimeout(context.Background(), d)
    53  	defer cancel()
    54  	_, err = DialPipeContext(ctx, testPipeName)
    55  	if !errors.Is(err, context.DeadlineExceeded) {
    56  		t.Fatalf("expected context.DeadlineExceeded, got %v", err)
    57  	}
    58  }
    59  
    60  func TestDialListenerGetsCancelled(t *testing.T) {
    61  	ctx, cancel := context.WithCancel(context.Background())
    62  	l, err := ListenPipe(testPipeName, nil)
    63  	if err != nil {
    64  		t.Fatal(err)
    65  	}
    66  	ch := make(chan error)
    67  	defer l.Close()
    68  	go func(ctx context.Context, ch chan error) {
    69  		_, err := DialPipeContext(ctx, testPipeName)
    70  		ch <- err
    71  	}(ctx, ch)
    72  	time.Sleep(time.Millisecond * 30)
    73  	cancel()
    74  	err = <-ch
    75  	if !errors.Is(err, context.Canceled) {
    76  		t.Fatalf("expected context.Canceled, got %v", err)
    77  	}
    78  }
    79  
    80  func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
    81  	c := PipeConfig{
    82  		SecurityDescriptor: "D:P(A;;0x1200FF;;;WD)",
    83  	}
    84  	l, err := ListenPipe(testPipeName, &c)
    85  	if err != nil {
    86  		t.Fatal(err)
    87  	}
    88  	defer l.Close()
    89  	_, err = DialPipe(testPipeName, nil)
    90  	if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
    91  		t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
    92  	}
    93  }
    94  
    95  func getConnection(cfg *PipeConfig) (client net.Conn, server net.Conn, err error) {
    96  	l, err := ListenPipe(testPipeName, cfg)
    97  	if err != nil {
    98  		return nil, nil, err
    99  	}
   100  	defer l.Close()
   101  
   102  	type response struct {
   103  		c   net.Conn
   104  		err error
   105  	}
   106  	ch := make(chan response)
   107  	go func() {
   108  		c, err := l.Accept()
   109  		ch <- response{c, err}
   110  	}()
   111  
   112  	c, err := DialPipe(testPipeName, nil)
   113  	if err != nil {
   114  		return client, server, err
   115  	}
   116  
   117  	r := <-ch
   118  	if err = r.err; err != nil {
   119  		c.Close()
   120  		return nil, nil, err
   121  	}
   122  
   123  	return c, r.c, nil
   124  }
   125  
   126  func TestReadTimeout(t *testing.T) {
   127  	c, s, err := getConnection(nil)
   128  	if err != nil {
   129  		t.Fatal(err)
   130  	}
   131  	defer c.Close()
   132  	defer s.Close()
   133  
   134  	_ = c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
   135  
   136  	buf := make([]byte, 10)
   137  	_, err = c.Read(buf)
   138  	if !errors.Is(err, ErrTimeout) {
   139  		t.Fatalf("expected ErrTimeout, got %v", err)
   140  	}
   141  }
   142  
   143  func server(l net.Listener, ch chan int) {
   144  	c, err := l.Accept()
   145  	if err != nil {
   146  		panic(err)
   147  	}
   148  	rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
   149  	s, err := rw.ReadString('\n')
   150  	if err != nil {
   151  		panic(err)
   152  	}
   153  	_, err = rw.WriteString("got " + s)
   154  	if err != nil {
   155  		panic(err)
   156  	}
   157  	err = rw.Flush()
   158  	if err != nil {
   159  		panic(err)
   160  	}
   161  	c.Close()
   162  	ch <- 1
   163  }
   164  
   165  func TestFullListenDialReadWrite(t *testing.T) {
   166  	l, err := ListenPipe(testPipeName, nil)
   167  	if err != nil {
   168  		t.Fatal(err)
   169  	}
   170  	defer l.Close()
   171  
   172  	ch := make(chan int)
   173  	go server(l, ch)
   174  
   175  	c, err := DialPipe(testPipeName, nil)
   176  	if err != nil {
   177  		t.Fatal(err)
   178  	}
   179  	defer c.Close()
   180  
   181  	rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
   182  	_, err = rw.WriteString("hello world\n")
   183  	if err != nil {
   184  		t.Fatal(err)
   185  	}
   186  	err = rw.Flush()
   187  	if err != nil {
   188  		t.Fatal(err)
   189  	}
   190  
   191  	s, err := rw.ReadString('\n')
   192  	if err != nil {
   193  		t.Fatal(err)
   194  	}
   195  	ms := "got hello world\n"
   196  	if s != ms {
   197  		t.Errorf("expected '%s', got '%s'", ms, s)
   198  	}
   199  
   200  	<-ch
   201  }
   202  
   203  func TestCloseAbortsListen(t *testing.T) {
   204  	l, err := ListenPipe(testPipeName, nil)
   205  	if err != nil {
   206  		t.Fatal(err)
   207  	}
   208  
   209  	ch := make(chan error)
   210  	go func() {
   211  		_, err := l.Accept()
   212  		ch <- err
   213  	}()
   214  
   215  	time.Sleep(30 * time.Millisecond)
   216  	l.Close()
   217  
   218  	err = <-ch
   219  	if !errors.Is(err, ErrPipeListenerClosed) {
   220  		t.Fatalf("expected ErrPipeListenerClosed, got %v", err)
   221  	}
   222  }
   223  
   224  func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) {
   225  	t.Helper()
   226  
   227  	b := make([]byte, 10)
   228  	w.Close()
   229  	n, err := r.Read(b)
   230  	if n > 0 {
   231  		t.Errorf("unexpected byte count %d", n)
   232  	}
   233  	if err != io.EOF {
   234  		t.Errorf("expected EOF: %v", err)
   235  	}
   236  }
   237  
   238  func TestCloseClientEOFServer(t *testing.T) {
   239  	c, s, err := getConnection(nil)
   240  	if err != nil {
   241  		t.Fatal(err)
   242  	}
   243  	defer c.Close()
   244  	defer s.Close()
   245  	ensureEOFOnClose(t, c, s)
   246  }
   247  
   248  func TestCloseServerEOFClient(t *testing.T) {
   249  	c, s, err := getConnection(nil)
   250  	if err != nil {
   251  		t.Fatal(err)
   252  	}
   253  	defer c.Close()
   254  	defer s.Close()
   255  	ensureEOFOnClose(t, s, c)
   256  }
   257  
   258  func TestCloseWriteEOF(t *testing.T) {
   259  	cfg := &PipeConfig{
   260  		MessageMode: true,
   261  	}
   262  	c, s, err := getConnection(cfg)
   263  	if err != nil {
   264  		t.Fatal(err)
   265  	}
   266  	defer c.Close()
   267  	defer s.Close()
   268  
   269  	type closeWriter interface {
   270  		CloseWrite() error
   271  	}
   272  
   273  	err = c.(closeWriter).CloseWrite()
   274  	if err != nil {
   275  		t.Fatal(err)
   276  	}
   277  
   278  	b := make([]byte, 10)
   279  	_, err = s.Read(b)
   280  	if !errors.Is(err, io.EOF) {
   281  		t.Fatal(err)
   282  	}
   283  }
   284  
   285  func TestAcceptAfterCloseFails(t *testing.T) {
   286  	l, err := ListenPipe(testPipeName, nil)
   287  	if err != nil {
   288  		t.Fatal(err)
   289  	}
   290  	l.Close()
   291  	_, err = l.Accept()
   292  	if !errors.Is(err, ErrPipeListenerClosed) {
   293  		t.Fatalf("expected ErrPipeListenerClosed, got %v", err)
   294  	}
   295  }
   296  
   297  func TestDialTimesOutByDefault(t *testing.T) {
   298  	l, err := ListenPipe(testPipeName, nil)
   299  	if err != nil {
   300  		t.Fatal(err)
   301  	}
   302  	defer l.Close()
   303  	_, err = DialPipe(testPipeName, nil)
   304  	if !errors.Is(err, ErrTimeout) {
   305  		t.Fatalf("expected ErrTimeout, got %v", err)
   306  	}
   307  }
   308  
   309  func TestTimeoutPendingRead(t *testing.T) {
   310  	l, err := ListenPipe(testPipeName, nil)
   311  	if err != nil {
   312  		t.Fatal(err)
   313  	}
   314  	defer l.Close()
   315  
   316  	serverDone := make(chan struct{})
   317  
   318  	go func() {
   319  		s, err := l.Accept()
   320  		if err != nil {
   321  			t.Error(err)
   322  			return
   323  		}
   324  		time.Sleep(1 * time.Second)
   325  		s.Close()
   326  		close(serverDone)
   327  	}()
   328  
   329  	client, err := DialPipe(testPipeName, nil)
   330  	if err != nil {
   331  		t.Fatal(err)
   332  	}
   333  	defer client.Close()
   334  
   335  	clientErr := make(chan error)
   336  	go func() {
   337  		buf := make([]byte, 10)
   338  		_, err = client.Read(buf)
   339  		clientErr <- err
   340  	}()
   341  
   342  	time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline
   343  	_ = client.SetReadDeadline(aLongTimeAgo)
   344  
   345  	select {
   346  	case err = <-clientErr:
   347  		if !errors.Is(err, ErrTimeout) {
   348  			t.Fatalf("expected ErrTimeout, got %v", err)
   349  		}
   350  	case <-time.After(100 * time.Millisecond):
   351  		t.Fatalf("timed out while waiting for read to cancel")
   352  		<-clientErr
   353  	}
   354  	<-serverDone
   355  }
   356  
   357  func TestTimeoutPendingWrite(t *testing.T) {
   358  	l, err := ListenPipe(testPipeName, nil)
   359  	if err != nil {
   360  		t.Fatal(err)
   361  	}
   362  	defer l.Close()
   363  
   364  	serverDone := make(chan struct{})
   365  
   366  	go func() {
   367  		s, err := l.Accept()
   368  		if err != nil {
   369  			t.Error(err)
   370  			return
   371  		}
   372  		time.Sleep(1 * time.Second)
   373  		s.Close()
   374  		close(serverDone)
   375  	}()
   376  
   377  	client, err := DialPipe(testPipeName, nil)
   378  	if err != nil {
   379  		t.Fatal(err)
   380  	}
   381  	defer client.Close()
   382  
   383  	clientErr := make(chan error)
   384  	go func() {
   385  		_, err = client.Write([]byte("this should timeout"))
   386  		clientErr <- err
   387  	}()
   388  
   389  	time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline
   390  	_ = client.SetWriteDeadline(aLongTimeAgo)
   391  
   392  	select {
   393  	case err = <-clientErr:
   394  		if !errors.Is(err, ErrTimeout) {
   395  			t.Fatalf("expected ErrTimeout, got %v", err)
   396  		}
   397  	case <-time.After(100 * time.Millisecond):
   398  		t.Fatalf("timed out while waiting for write to cancel")
   399  		<-clientErr
   400  	}
   401  	<-serverDone
   402  }
   403  
   404  func TestDisconnectPipe(t *testing.T) {
   405  	l, err := ListenPipe(testPipeName, nil)
   406  	if err != nil {
   407  		t.Fatal(err)
   408  	}
   409  	defer l.Close()
   410  
   411  	const testData = "foo"
   412  	serverDone := make(chan struct{})
   413  
   414  	go func() {
   415  		s, err := l.Accept()
   416  		if err != nil {
   417  			t.Error(err)
   418  			return
   419  		}
   420  		defer func() {
   421  			s.Close()
   422  			close(serverDone)
   423  		}()
   424  
   425  		if _, err := s.Write([]byte(testData)); err != nil {
   426  			t.Error(err)
   427  			return
   428  		}
   429  
   430  		if err := s.(PipeConn).Flush(); err != nil {
   431  			t.Error(err)
   432  			return
   433  		}
   434  
   435  		if err := s.(PipeConn).Disconnect(); err != nil {
   436  			t.Error(err)
   437  			return
   438  		}
   439  	}()
   440  
   441  	client, err := DialPipe(testPipeName, nil)
   442  	if err != nil {
   443  		t.Fatal(err)
   444  	}
   445  	defer client.Close()
   446  
   447  	buf := make([]byte, len(testData))
   448  	if _, err = client.Read(buf); err != nil {
   449  		t.Fatal(err)
   450  	}
   451  
   452  	dataRead := string(buf)
   453  	if dataRead != testData {
   454  		t.Fatalf("incorrect data read %q", dataRead)
   455  	}
   456  
   457  	if _, err = client.Read(buf); err == nil {
   458  		t.Fatal("read should fail")
   459  	}
   460  
   461  	<-serverDone
   462  }
   463  
   464  type CloseWriter interface {
   465  	CloseWrite() error
   466  }
   467  
   468  func TestEchoWithMessaging(t *testing.T) {
   469  	c := PipeConfig{
   470  		MessageMode:      true,  // Use message mode so that CloseWrite() is supported
   471  		InputBufferSize:  65536, // Use 64KB buffers to improve performance
   472  		OutputBufferSize: 65536,
   473  	}
   474  	l, err := ListenPipe(testPipeName, &c)
   475  	if err != nil {
   476  		t.Fatal(err)
   477  	}
   478  	defer l.Close()
   479  
   480  	listenerDone := make(chan bool)
   481  	clientDone := make(chan bool)
   482  	go func() {
   483  		// server echo
   484  		conn, e := l.Accept()
   485  		if e != nil {
   486  			t.Error(err)
   487  			return
   488  		}
   489  		defer conn.Close()
   490  
   491  		time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
   492  		_, _ = io.Copy(conn, conn)
   493  		_ = conn.(CloseWriter).CloseWrite()
   494  		close(listenerDone)
   495  	}()
   496  	timeout := 1 * time.Second
   497  	client, err := DialPipe(testPipeName, &timeout)
   498  	if err != nil {
   499  		t.Fatal(err)
   500  	}
   501  	defer client.Close()
   502  
   503  	go func() {
   504  		// client read back
   505  		bytes := make([]byte, 2)
   506  		n, e := client.Read(bytes)
   507  		if e != nil {
   508  			t.Error(err)
   509  			return
   510  		}
   511  		if n != 2 {
   512  			t.Errorf("expected 2 bytes, got %v", n)
   513  			return
   514  		}
   515  		close(clientDone)
   516  	}()
   517  
   518  	payload := make([]byte, 2)
   519  	payload[0] = 0
   520  	payload[1] = 1
   521  
   522  	n, err := client.Write(payload)
   523  	if err != nil {
   524  		t.Fatal(err)
   525  	}
   526  	if n != 2 {
   527  		t.Fatalf("expected 2 bytes, got %v", n)
   528  	}
   529  	_ = client.(CloseWriter).CloseWrite()
   530  	<-listenerDone
   531  	<-clientDone
   532  }
   533  
   534  func TestConnectRace(t *testing.T) {
   535  	l, err := ListenPipe(testPipeName, nil)
   536  	if err != nil {
   537  		t.Fatal(err)
   538  	}
   539  	defer l.Close()
   540  	go func() {
   541  		for {
   542  			s, err := l.Accept()
   543  			if errors.Is(err, ErrPipeListenerClosed) {
   544  				return
   545  			}
   546  
   547  			if err != nil {
   548  				t.Error(err)
   549  				return
   550  			}
   551  			s.Close()
   552  		}
   553  	}()
   554  
   555  	for i := 0; i < 1000; i++ {
   556  		c, err := DialPipe(testPipeName, nil)
   557  		if err != nil {
   558  			t.Fatal(err)
   559  		}
   560  		c.Close()
   561  	}
   562  }
   563  
   564  func TestMessageReadMode(t *testing.T) {
   565  	var wg sync.WaitGroup
   566  	defer wg.Wait()
   567  
   568  	l, err := ListenPipe(testPipeName, &PipeConfig{MessageMode: true})
   569  	if err != nil {
   570  		t.Fatal(err)
   571  	}
   572  	defer l.Close()
   573  
   574  	msg := ([]byte)("hello world")
   575  
   576  	wg.Add(1)
   577  	go func() {
   578  		defer wg.Done()
   579  		s, err := l.Accept()
   580  		if err != nil {
   581  			t.Error(err)
   582  			return
   583  		}
   584  		_, err = s.Write(msg)
   585  		if err != nil {
   586  			t.Error(err)
   587  			return
   588  		}
   589  		s.Close()
   590  	}()
   591  
   592  	c, err := DialPipe(testPipeName, nil)
   593  	if err != nil {
   594  		t.Fatal(err)
   595  	}
   596  	defer c.Close()
   597  
   598  	setNamedPipeHandleState := windows.NewLazyDLL("kernel32.dll").NewProc("SetNamedPipeHandleState")
   599  
   600  	p := c.(*win32MessageBytePipe)
   601  	mode := uint32(windows.PIPE_READMODE_MESSAGE)
   602  	if s, _, err := setNamedPipeHandleState.Call(uintptr(p.handle), uintptr(unsafe.Pointer(&mode)), 0, 0); s == 0 {
   603  		t.Fatal(err)
   604  	}
   605  
   606  	ch := make([]byte, 1)
   607  	var vmsg []byte
   608  	for {
   609  		n, err := c.Read(ch)
   610  		if err == io.EOF { //nolint:errorlint
   611  			break
   612  		}
   613  		if err != nil {
   614  			t.Fatal(err)
   615  		}
   616  		if n != 1 {
   617  			t.Fatal("expected 1: ", n)
   618  		}
   619  		vmsg = append(vmsg, ch[0])
   620  	}
   621  	if !bytes.Equal(msg, vmsg) {
   622  		t.Fatalf("expected %s: %s", msg, vmsg)
   623  	}
   624  }
   625  
   626  func TestListenConnectRace(t *testing.T) {
   627  	for i := 0; i < 50 && !t.Failed(); i++ {
   628  		var wg sync.WaitGroup
   629  		wg.Add(1)
   630  		go func() {
   631  			c, err := DialPipe(testPipeName, nil)
   632  			if err == nil {
   633  				c.Close()
   634  			}
   635  			wg.Done()
   636  		}()
   637  		s, err := ListenPipe(testPipeName, nil)
   638  		if err != nil {
   639  			t.Error(i, err)
   640  		} else {
   641  			s.Close()
   642  		}
   643  		wg.Wait()
   644  	}
   645  }