github.com/glycerine/xcryptossh@v7.0.4+incompatible/test/forward_unix_test.go (about) 1 // Copyright 2012 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // +build darwin dragonfly freebsd linux netbsd openbsd 6 7 package test 8 9 import ( 10 "bytes" 11 "context" 12 "io" 13 "io/ioutil" 14 "math/rand" 15 "net" 16 "testing" 17 "time" 18 19 ssh "github.com/glycerine/xcryptossh" 20 ) 21 22 type closeWriter interface { 23 CloseWrite() error 24 } 25 26 func testPortForward(t *testing.T, n, listenAddr string) { 27 ctx, cancelctx := context.WithCancel(context.Background()) 28 defer cancelctx() 29 halt := ssh.NewHalter() 30 defer halt.ReqStop.Close() 31 32 server := newServer(t) 33 defer server.Shutdown() 34 conn := server.Dial(ctx, clientConfig(halt)) 35 defer conn.Close() 36 37 sshListener, err := conn.Listen(n, listenAddr) 38 if err != nil { 39 t.Fatal(err) 40 } 41 42 go func() { 43 sshConn, err := sshListener.Accept() 44 if err != nil { 45 t.Fatalf("listen.Accept failed: %v", err) 46 } 47 48 _, err = io.Copy(sshConn, sshConn) 49 if err != nil && err != io.EOF { 50 t.Fatalf("ssh client copy: %v", err) 51 } 52 sshConn.Close() 53 }() 54 55 forwardedAddr := sshListener.Addr().String() 56 netConn, err := net.Dial(n, forwardedAddr) 57 if err != nil { 58 t.Fatalf("net dial failed: %v", err) 59 } 60 61 readChan := make(chan []byte) 62 go func() { 63 data, _ := ioutil.ReadAll(netConn) 64 readChan <- data 65 }() 66 67 // Invent some data. 68 data := make([]byte, 100*1000) 69 for i := range data { 70 data[i] = byte(i % 255) 71 } 72 73 var sent []byte 74 for len(sent) < 1000*1000 { 75 // Send random sized chunks 76 m := rand.Intn(len(data)) 77 n, err := netConn.Write(data[:m]) 78 if err != nil { 79 break 80 } 81 sent = append(sent, data[:n]...) 82 } 83 if err := netConn.(closeWriter).CloseWrite(); err != nil { 84 t.Errorf("netConn.CloseWrite: %v", err) 85 } 86 87 read := <-readChan 88 89 if len(sent) != len(read) { 90 t.Fatalf("got %d bytes, want %d", len(read), len(sent)) 91 } 92 if bytes.Compare(sent, read) != 0 { 93 t.Fatalf("read back data does not match") 94 } 95 96 if err := sshListener.Close(); err != nil { 97 t.Fatalf("sshListener.Close: %v", err) 98 } 99 100 // Check that the forward disappeared. 101 netConn, err = net.Dial(n, forwardedAddr) 102 if err == nil { 103 netConn.Close() 104 t.Errorf("still listening to %s after closing", forwardedAddr) 105 } 106 } 107 108 func TestPortForwardTCP(t *testing.T) { 109 testPortForward(t, "tcp", "localhost:0") 110 } 111 112 func TestPortForwardUnix(t *testing.T) { 113 addr, cleanup := newTempSocket(t) 114 defer cleanup() 115 testPortForward(t, "unix", addr) 116 } 117 118 func testAcceptClose(t *testing.T, n, listenAddr string) { 119 ctx, cancelctx := context.WithCancel(context.Background()) 120 defer cancelctx() 121 halt := ssh.NewHalter() 122 defer halt.ReqStop.Close() 123 124 server := newServer(t) 125 defer server.Shutdown() 126 conn := server.Dial(ctx, clientConfig(halt)) 127 128 sshListener, err := conn.Listen(n, listenAddr) 129 if err != nil { 130 t.Fatal(err) 131 } 132 133 quit := make(chan error, 1) 134 go func() { 135 for { 136 c, err := sshListener.Accept() 137 if err != nil { 138 quit <- err 139 break 140 } 141 c.Close() 142 } 143 }() 144 sshListener.Close() 145 146 select { 147 case <-time.After(1 * time.Second): 148 t.Errorf("timeout: listener did not close.") 149 case err := <-quit: 150 t.Logf("quit as expected (error %v)", err) 151 } 152 } 153 154 func TestAcceptCloseTCP(t *testing.T) { 155 testAcceptClose(t, "tcp", "localhost:0") 156 } 157 158 func TestAcceptCloseUnix(t *testing.T) { 159 addr, cleanup := newTempSocket(t) 160 defer cleanup() 161 testAcceptClose(t, "unix", addr) 162 } 163 164 // Check that listeners exit if the underlying client transport dies. 165 func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) { 166 ctx, cancelctx := context.WithCancel(context.Background()) 167 defer cancelctx() 168 halt := ssh.NewHalter() 169 defer halt.ReqStop.Close() 170 171 server := newServer(t) 172 defer server.Shutdown() 173 conn := server.Dial(ctx, clientConfig(halt)) 174 175 sshListener, err := conn.Listen(n, listenAddr) 176 if err != nil { 177 t.Fatal(err) 178 } 179 180 quit := make(chan error, 1) 181 go func() { 182 for { 183 c, err := sshListener.Accept() 184 if err != nil { 185 quit <- err 186 break 187 } 188 c.Close() 189 } 190 }() 191 192 // It would be even nicer if we closed the server side, but it 193 // is more involved as the fd for that side is dup()ed. 194 server.clientConn.Close() 195 196 select { 197 case <-time.After(1 * time.Second): 198 t.Errorf("timeout: listener did not close.") 199 case err := <-quit: 200 t.Logf("quit as expected (error %v)", err) 201 } 202 } 203 204 func TestPortForwardConnectionCloseTCP(t *testing.T) { 205 testPortForwardConnectionClose(t, "tcp", "localhost:0") 206 } 207 208 func TestPortForwardConnectionCloseUnix(t *testing.T) { 209 addr, cleanup := newTempSocket(t) 210 defer cleanup() 211 testPortForwardConnectionClose(t, "unix", addr) 212 }