github.com/sandwichdev/go-internals@v0.0.0-20210605002614-12311ac6b2c5/poll/fd_windows_test.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package poll_test
     6  
     7  import (
     8  	"fmt"
     9  	"os"
    10  	"sync"
    11  	"syscall"
    12  	"testing"
    13  
    14  	"github.com/SandwichDev/go-internals/poll"
    15  )
    16  
    17  type loggedFD struct {
    18  	Net string
    19  	FD  *poll.FD
    20  	Err error
    21  }
    22  
    23  var (
    24  	logMu     sync.Mutex
    25  	loggedFDs map[syscall.Handle]*loggedFD
    26  )
    27  
    28  func logFD(net string, fd *poll.FD, err error) {
    29  	logMu.Lock()
    30  	defer logMu.Unlock()
    31  
    32  	loggedFDs[fd.Sysfd] = &loggedFD{
    33  		Net: net,
    34  		FD:  fd,
    35  		Err: err,
    36  	}
    37  }
    38  
    39  func init() {
    40  	loggedFDs = make(map[syscall.Handle]*loggedFD)
    41  	*poll.LogInitFD = logFD
    42  }
    43  
    44  func findLoggedFD(h syscall.Handle) (lfd *loggedFD, found bool) {
    45  	logMu.Lock()
    46  	defer logMu.Unlock()
    47  
    48  	lfd, found = loggedFDs[h]
    49  	return lfd, found
    50  }
    51  
    52  // checkFileIsNotPartOfNetpoll verifies that f is not managed by netpoll.
    53  // It returns error, if check fails.
    54  func checkFileIsNotPartOfNetpoll(f *os.File) error {
    55  	lfd, found := findLoggedFD(syscall.Handle(f.Fd()))
    56  	if !found {
    57  		return fmt.Errorf("%v fd=%v: is not found in the log", f.Name(), f.Fd())
    58  	}
    59  	if lfd.FD.IsPartOfNetpoll() {
    60  		return fmt.Errorf("%v fd=%v: is part of netpoll, but should not be (logged: net=%v err=%v)", f.Name(), f.Fd(), lfd.Net, lfd.Err)
    61  	}
    62  	return nil
    63  }
    64  
    65  func TestFileFdsAreInitialised(t *testing.T) {
    66  	exe, err := os.Executable()
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	f, err := os.Open(exe)
    71  	if err != nil {
    72  		t.Fatal(err)
    73  	}
    74  	defer f.Close()
    75  
    76  	err = checkFileIsNotPartOfNetpoll(f)
    77  	if err != nil {
    78  		t.Fatal(err)
    79  	}
    80  }
    81  
    82  func TestSerialFdsAreInitialised(t *testing.T) {
    83  	for _, name := range []string{"COM1", "COM2", "COM3", "COM4"} {
    84  		t.Run(name, func(t *testing.T) {
    85  			h, err := syscall.CreateFile(syscall.StringToUTF16Ptr(name),
    86  				syscall.GENERIC_READ|syscall.GENERIC_WRITE,
    87  				0,
    88  				nil,
    89  				syscall.OPEN_EXISTING,
    90  				syscall.FILE_ATTRIBUTE_NORMAL|syscall.FILE_FLAG_OVERLAPPED,
    91  				0)
    92  			if err != nil {
    93  				if errno, ok := err.(syscall.Errno); ok {
    94  					switch errno {
    95  					case syscall.ERROR_FILE_NOT_FOUND,
    96  						syscall.ERROR_ACCESS_DENIED:
    97  						t.Log("Skipping: ", err)
    98  						return
    99  					}
   100  				}
   101  				t.Fatal(err)
   102  			}
   103  			f := os.NewFile(uintptr(h), name)
   104  			defer f.Close()
   105  
   106  			err = checkFileIsNotPartOfNetpoll(f)
   107  			if err != nil {
   108  				t.Fatal(err)
   109  			}
   110  		})
   111  	}
   112  }