github.com/glycerine/xcryptossh@v7.0.4+incompatible/test/forward_unix_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build darwin dragonfly freebsd linux netbsd openbsd
     6  
     7  package test
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"io"
    13  	"io/ioutil"
    14  	"math/rand"
    15  	"net"
    16  	"testing"
    17  	"time"
    18  
    19  	ssh "github.com/glycerine/xcryptossh"
    20  )
    21  
    22  type closeWriter interface {
    23  	CloseWrite() error
    24  }
    25  
    26  func testPortForward(t *testing.T, n, listenAddr string) {
    27  	ctx, cancelctx := context.WithCancel(context.Background())
    28  	defer cancelctx()
    29  	halt := ssh.NewHalter()
    30  	defer halt.ReqStop.Close()
    31  
    32  	server := newServer(t)
    33  	defer server.Shutdown()
    34  	conn := server.Dial(ctx, clientConfig(halt))
    35  	defer conn.Close()
    36  
    37  	sshListener, err := conn.Listen(n, listenAddr)
    38  	if err != nil {
    39  		t.Fatal(err)
    40  	}
    41  
    42  	go func() {
    43  		sshConn, err := sshListener.Accept()
    44  		if err != nil {
    45  			t.Fatalf("listen.Accept failed: %v", err)
    46  		}
    47  
    48  		_, err = io.Copy(sshConn, sshConn)
    49  		if err != nil && err != io.EOF {
    50  			t.Fatalf("ssh client copy: %v", err)
    51  		}
    52  		sshConn.Close()
    53  	}()
    54  
    55  	forwardedAddr := sshListener.Addr().String()
    56  	netConn, err := net.Dial(n, forwardedAddr)
    57  	if err != nil {
    58  		t.Fatalf("net dial failed: %v", err)
    59  	}
    60  
    61  	readChan := make(chan []byte)
    62  	go func() {
    63  		data, _ := ioutil.ReadAll(netConn)
    64  		readChan <- data
    65  	}()
    66  
    67  	// Invent some data.
    68  	data := make([]byte, 100*1000)
    69  	for i := range data {
    70  		data[i] = byte(i % 255)
    71  	}
    72  
    73  	var sent []byte
    74  	for len(sent) < 1000*1000 {
    75  		// Send random sized chunks
    76  		m := rand.Intn(len(data))
    77  		n, err := netConn.Write(data[:m])
    78  		if err != nil {
    79  			break
    80  		}
    81  		sent = append(sent, data[:n]...)
    82  	}
    83  	if err := netConn.(closeWriter).CloseWrite(); err != nil {
    84  		t.Errorf("netConn.CloseWrite: %v", err)
    85  	}
    86  
    87  	read := <-readChan
    88  
    89  	if len(sent) != len(read) {
    90  		t.Fatalf("got %d bytes, want %d", len(read), len(sent))
    91  	}
    92  	if bytes.Compare(sent, read) != 0 {
    93  		t.Fatalf("read back data does not match")
    94  	}
    95  
    96  	if err := sshListener.Close(); err != nil {
    97  		t.Fatalf("sshListener.Close: %v", err)
    98  	}
    99  
   100  	// Check that the forward disappeared.
   101  	netConn, err = net.Dial(n, forwardedAddr)
   102  	if err == nil {
   103  		netConn.Close()
   104  		t.Errorf("still listening to %s after closing", forwardedAddr)
   105  	}
   106  }
   107  
   108  func TestPortForwardTCP(t *testing.T) {
   109  	testPortForward(t, "tcp", "localhost:0")
   110  }
   111  
   112  func TestPortForwardUnix(t *testing.T) {
   113  	addr, cleanup := newTempSocket(t)
   114  	defer cleanup()
   115  	testPortForward(t, "unix", addr)
   116  }
   117  
   118  func testAcceptClose(t *testing.T, n, listenAddr string) {
   119  	ctx, cancelctx := context.WithCancel(context.Background())
   120  	defer cancelctx()
   121  	halt := ssh.NewHalter()
   122  	defer halt.ReqStop.Close()
   123  
   124  	server := newServer(t)
   125  	defer server.Shutdown()
   126  	conn := server.Dial(ctx, clientConfig(halt))
   127  
   128  	sshListener, err := conn.Listen(n, listenAddr)
   129  	if err != nil {
   130  		t.Fatal(err)
   131  	}
   132  
   133  	quit := make(chan error, 1)
   134  	go func() {
   135  		for {
   136  			c, err := sshListener.Accept()
   137  			if err != nil {
   138  				quit <- err
   139  				break
   140  			}
   141  			c.Close()
   142  		}
   143  	}()
   144  	sshListener.Close()
   145  
   146  	select {
   147  	case <-time.After(1 * time.Second):
   148  		t.Errorf("timeout: listener did not close.")
   149  	case err := <-quit:
   150  		t.Logf("quit as expected (error %v)", err)
   151  	}
   152  }
   153  
   154  func TestAcceptCloseTCP(t *testing.T) {
   155  	testAcceptClose(t, "tcp", "localhost:0")
   156  }
   157  
   158  func TestAcceptCloseUnix(t *testing.T) {
   159  	addr, cleanup := newTempSocket(t)
   160  	defer cleanup()
   161  	testAcceptClose(t, "unix", addr)
   162  }
   163  
   164  // Check that listeners exit if the underlying client transport dies.
   165  func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
   166  	ctx, cancelctx := context.WithCancel(context.Background())
   167  	defer cancelctx()
   168  	halt := ssh.NewHalter()
   169  	defer halt.ReqStop.Close()
   170  
   171  	server := newServer(t)
   172  	defer server.Shutdown()
   173  	conn := server.Dial(ctx, clientConfig(halt))
   174  
   175  	sshListener, err := conn.Listen(n, listenAddr)
   176  	if err != nil {
   177  		t.Fatal(err)
   178  	}
   179  
   180  	quit := make(chan error, 1)
   181  	go func() {
   182  		for {
   183  			c, err := sshListener.Accept()
   184  			if err != nil {
   185  				quit <- err
   186  				break
   187  			}
   188  			c.Close()
   189  		}
   190  	}()
   191  
   192  	// It would be even nicer if we closed the server side, but it
   193  	// is more involved as the fd for that side is dup()ed.
   194  	server.clientConn.Close()
   195  
   196  	select {
   197  	case <-time.After(1 * time.Second):
   198  		t.Errorf("timeout: listener did not close.")
   199  	case err := <-quit:
   200  		t.Logf("quit as expected (error %v)", err)
   201  	}
   202  }
   203  
   204  func TestPortForwardConnectionCloseTCP(t *testing.T) {
   205  	testPortForwardConnectionClose(t, "tcp", "localhost:0")
   206  }
   207  
   208  func TestPortForwardConnectionCloseUnix(t *testing.T) {
   209  	addr, cleanup := newTempSocket(t)
   210  	defer cleanup()
   211  	testPortForwardConnectionClose(t, "unix", addr)
   212  }