github.com/glycerine/xcryptossh@v7.0.4+incompatible/test/dial_unix_test.go (about) 1 // Copyright 2012 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // +build !windows 6 7 package test 8 9 // direct-tcpip and direct-streamlocal functional tests 10 11 import ( 12 "context" 13 "fmt" 14 "io" 15 "io/ioutil" 16 "net" 17 "strings" 18 "testing" 19 20 ssh "github.com/glycerine/xcryptossh" 21 ) 22 23 type dialTester interface { 24 TestServerConn(t *testing.T, c net.Conn) 25 TestClientConn(t *testing.T, c net.Conn) 26 } 27 28 func testDial(t *testing.T, n, listenAddr string, x dialTester) { 29 ctx, cancelctx := context.WithCancel(context.Background()) 30 defer cancelctx() 31 halt := ssh.NewHalter() 32 defer halt.ReqStop.Close() 33 34 server := newServer(t) 35 defer server.Shutdown() 36 sshConn := server.Dial(ctx, clientConfig(halt)) 37 defer sshConn.Close() 38 39 l, err := net.Listen(n, listenAddr) 40 if err != nil { 41 t.Fatalf("Listen: %v", err) 42 } 43 defer l.Close() 44 45 testData := fmt.Sprintf("hello from %s, %s", n, listenAddr) 46 go func() { 47 for { 48 c, err := l.Accept() 49 if err != nil { 50 break 51 } 52 x.TestServerConn(t, c) 53 54 io.WriteString(c, testData) 55 c.Close() 56 } 57 }() 58 59 conn, err := sshConn.Dial(n, l.Addr().String()) 60 if err != nil { 61 t.Fatalf("Dial: %v", err) 62 } 63 x.TestClientConn(t, conn) 64 defer conn.Close() 65 b, err := ioutil.ReadAll(conn) 66 if err != nil { 67 t.Fatalf("ReadAll: %v", err) 68 } 69 t.Logf("got %q", string(b)) 70 if string(b) != testData { 71 t.Fatalf("expected %q, got %q", testData, string(b)) 72 } 73 } 74 75 type tcpDialTester struct { 76 listenAddr string 77 } 78 79 func (x *tcpDialTester) TestServerConn(t *testing.T, c net.Conn) { 80 host := strings.Split(x.listenAddr, ":")[0] 81 prefix := host + ":" 82 if !strings.HasPrefix(c.LocalAddr().String(), prefix) { 83 t.Fatalf("expected to start with %q, got %q", prefix, c.LocalAddr().String()) 84 } 85 if !strings.HasPrefix(c.RemoteAddr().String(), prefix) { 86 t.Fatalf("expected to start with %q, got %q", prefix, c.RemoteAddr().String()) 87 } 88 } 89 90 func (x *tcpDialTester) TestClientConn(t *testing.T, c net.Conn) { 91 // we use zero addresses. see *Client.Dial. 92 if c.LocalAddr().String() != "0.0.0.0:0" { 93 t.Fatalf("expected \"0.0.0.0:0\", got %q", c.LocalAddr().String()) 94 } 95 if c.RemoteAddr().String() != "0.0.0.0:0" { 96 t.Fatalf("expected \"0.0.0.0:0\", got %q", c.RemoteAddr().String()) 97 } 98 } 99 100 func TestDialTCP(t *testing.T) { 101 x := &tcpDialTester{ 102 listenAddr: "127.0.0.1:0", 103 } 104 testDial(t, "tcp", x.listenAddr, x) 105 } 106 107 type unixDialTester struct { 108 listenAddr string 109 } 110 111 func (x *unixDialTester) TestServerConn(t *testing.T, c net.Conn) { 112 if c.LocalAddr().String() != x.listenAddr { 113 t.Fatalf("expected %q, got %q", x.listenAddr, c.LocalAddr().String()) 114 } 115 if c.RemoteAddr().String() != "@" { 116 t.Fatalf("expected \"@\", got %q", c.RemoteAddr().String()) 117 } 118 } 119 120 func (x *unixDialTester) TestClientConn(t *testing.T, c net.Conn) { 121 if c.RemoteAddr().String() != x.listenAddr { 122 t.Fatalf("expected %q, got %q", x.listenAddr, c.RemoteAddr().String()) 123 } 124 if c.LocalAddr().String() != "@" { 125 t.Fatalf("expected \"@\", got %q", c.LocalAddr().String()) 126 } 127 } 128 129 func TestDialUnix(t *testing.T) { 130 addr, cleanup := newTempSocket(t) 131 defer cleanup() 132 x := &unixDialTester{ 133 listenAddr: addr, 134 } 135 testDial(t, "unix", x.listenAddr, x) 136 }