golang.org/toolchain@v0.0.1-go1.9rc2.windows-amd64/src/vendor/golang_org/x/net/nettest/conntest_test.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build go1.8
     6  
     7  package nettest
     8  
     9  import (
    10  	"fmt"
    11  	"io/ioutil"
    12  	"net"
    13  	"os"
    14  	"runtime"
    15  	"testing"
    16  )
    17  
    18  // testUnixAddr uses ioutil.TempFile to get a name that is unique.
    19  // It also uses /tmp directory in case it is prohibited to create UNIX
    20  // sockets in TMPDIR.
    21  func testUnixAddr() string {
    22  	f, err := ioutil.TempFile("", "go-nettest")
    23  	if err != nil {
    24  		panic(err)
    25  	}
    26  	addr := f.Name()
    27  	f.Close()
    28  	os.Remove(addr)
    29  	return addr
    30  }
    31  
    32  // testableNetwork reports whether network is testable on the current
    33  // platform configuration.
    34  // This is based on logic from standard library's net/platform_test.go.
    35  func testableNetwork(network string) bool {
    36  	switch network {
    37  	case "unix":
    38  		switch runtime.GOOS {
    39  		case "android", "nacl", "plan9", "windows":
    40  			return false
    41  		}
    42  		if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") {
    43  			return false
    44  		}
    45  	case "unixpacket":
    46  		switch runtime.GOOS {
    47  		case "android", "darwin", "nacl", "plan9", "windows", "freebsd":
    48  			return false
    49  		}
    50  	}
    51  	return true
    52  }
    53  
    54  func newLocalListener(network string) (net.Listener, error) {
    55  	switch network {
    56  	case "tcp":
    57  		ln, err := net.Listen("tcp", "127.0.0.1:0")
    58  		if err != nil {
    59  			ln, err = net.Listen("tcp6", "[::1]:0")
    60  		}
    61  		return ln, err
    62  	case "unix", "unixpacket":
    63  		return net.Listen(network, testUnixAddr())
    64  	}
    65  	return nil, fmt.Errorf("%s is not supported", network)
    66  }
    67  
    68  func TestTestConn(t *testing.T) {
    69  	tests := []struct{ name, network string }{
    70  		{"TCP", "tcp"},
    71  		{"UnixPipe", "unix"},
    72  		{"UnixPacketPipe", "unixpacket"},
    73  	}
    74  
    75  	for _, tt := range tests {
    76  		t.Run(tt.name, func(t *testing.T) {
    77  			if !testableNetwork(tt.network) {
    78  				t.Skipf("not supported on %s", runtime.GOOS)
    79  			}
    80  
    81  			mp := func() (c1, c2 net.Conn, stop func(), err error) {
    82  				ln, err := newLocalListener(tt.network)
    83  				if err != nil {
    84  					return nil, nil, nil, err
    85  				}
    86  
    87  				// Start a connection between two endpoints.
    88  				var err1, err2 error
    89  				done := make(chan bool)
    90  				go func() {
    91  					c2, err2 = ln.Accept()
    92  					close(done)
    93  				}()
    94  				c1, err1 = net.Dial(ln.Addr().Network(), ln.Addr().String())
    95  				<-done
    96  
    97  				stop = func() {
    98  					if err1 == nil {
    99  						c1.Close()
   100  					}
   101  					if err2 == nil {
   102  						c2.Close()
   103  					}
   104  					ln.Close()
   105  					switch tt.network {
   106  					case "unix", "unixpacket":
   107  						os.Remove(ln.Addr().String())
   108  					}
   109  				}
   110  
   111  				switch {
   112  				case err1 != nil:
   113  					stop()
   114  					return nil, nil, nil, err1
   115  				case err2 != nil:
   116  					stop()
   117  					return nil, nil, nil, err2
   118  				default:
   119  					return c1, c2, stop, nil
   120  				}
   121  			}
   122  
   123  			TestConn(t, mp)
   124  		})
   125  	}
   126  }