github.com/maenmax/kairep@v0.0.0-20210218001208-55bf3df36788/src/golang.org/x/crypto/ssh/test/forward_unix_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build darwin dragonfly freebsd linux netbsd openbsd
     6  
     7  package test
     8  
     9  import (
    10  	"bytes"
    11  	"io"
    12  	"io/ioutil"
    13  	"math/rand"
    14  	"net"
    15  	"testing"
    16  	"time"
    17  )
    18  
    19  func TestPortForward(t *testing.T) {
    20  	server := newServer(t)
    21  	defer server.Shutdown()
    22  	conn := server.Dial(clientConfig())
    23  	defer conn.Close()
    24  
    25  	sshListener, err := conn.Listen("tcp", "localhost:0")
    26  	if err != nil {
    27  		t.Fatal(err)
    28  	}
    29  
    30  	go func() {
    31  		sshConn, err := sshListener.Accept()
    32  		if err != nil {
    33  			t.Fatalf("listen.Accept failed: %v", err)
    34  		}
    35  
    36  		_, err = io.Copy(sshConn, sshConn)
    37  		if err != nil && err != io.EOF {
    38  			t.Fatalf("ssh client copy: %v", err)
    39  		}
    40  		sshConn.Close()
    41  	}()
    42  
    43  	forwardedAddr := sshListener.Addr().String()
    44  	tcpConn, err := net.Dial("tcp", forwardedAddr)
    45  	if err != nil {
    46  		t.Fatalf("TCP dial failed: %v", err)
    47  	}
    48  
    49  	readChan := make(chan []byte)
    50  	go func() {
    51  		data, _ := ioutil.ReadAll(tcpConn)
    52  		readChan <- data
    53  	}()
    54  
    55  	// Invent some data.
    56  	data := make([]byte, 100*1000)
    57  	for i := range data {
    58  		data[i] = byte(i % 255)
    59  	}
    60  
    61  	var sent []byte
    62  	for len(sent) < 1000*1000 {
    63  		// Send random sized chunks
    64  		m := rand.Intn(len(data))
    65  		n, err := tcpConn.Write(data[:m])
    66  		if err != nil {
    67  			break
    68  		}
    69  		sent = append(sent, data[:n]...)
    70  	}
    71  	if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil {
    72  		t.Errorf("tcpConn.CloseWrite: %v", err)
    73  	}
    74  
    75  	read := <-readChan
    76  
    77  	if len(sent) != len(read) {
    78  		t.Fatalf("got %d bytes, want %d", len(read), len(sent))
    79  	}
    80  	if bytes.Compare(sent, read) != 0 {
    81  		t.Fatalf("read back data does not match")
    82  	}
    83  
    84  	if err := sshListener.Close(); err != nil {
    85  		t.Fatalf("sshListener.Close: %v", err)
    86  	}
    87  
    88  	// Check that the forward disappeared.
    89  	tcpConn, err = net.Dial("tcp", forwardedAddr)
    90  	if err == nil {
    91  		tcpConn.Close()
    92  		t.Errorf("still listening to %s after closing", forwardedAddr)
    93  	}
    94  }
    95  
    96  func TestAcceptClose(t *testing.T) {
    97  	server := newServer(t)
    98  	defer server.Shutdown()
    99  	conn := server.Dial(clientConfig())
   100  
   101  	sshListener, err := conn.Listen("tcp", "localhost:0")
   102  	if err != nil {
   103  		t.Fatal(err)
   104  	}
   105  
   106  	quit := make(chan error, 1)
   107  	go func() {
   108  		for {
   109  			c, err := sshListener.Accept()
   110  			if err != nil {
   111  				quit <- err
   112  				break
   113  			}
   114  			c.Close()
   115  		}
   116  	}()
   117  	sshListener.Close()
   118  
   119  	select {
   120  	case <-time.After(1 * time.Second):
   121  		t.Errorf("timeout: listener did not close.")
   122  	case err := <-quit:
   123  		t.Logf("quit as expected (error %v)", err)
   124  	}
   125  }
   126  
   127  // Check that listeners exit if the underlying client transport dies.
   128  func TestPortForwardConnectionClose(t *testing.T) {
   129  	server := newServer(t)
   130  	defer server.Shutdown()
   131  	conn := server.Dial(clientConfig())
   132  
   133  	sshListener, err := conn.Listen("tcp", "localhost:0")
   134  	if err != nil {
   135  		t.Fatal(err)
   136  	}
   137  
   138  	quit := make(chan error, 1)
   139  	go func() {
   140  		for {
   141  			c, err := sshListener.Accept()
   142  			if err != nil {
   143  				quit <- err
   144  				break
   145  			}
   146  			c.Close()
   147  		}
   148  	}()
   149  
   150  	// It would be even nicer if we closed the server side, but it
   151  	// is more involved as the fd for that side is dup()ed.
   152  	server.clientConn.Close()
   153  
   154  	select {
   155  	case <-time.After(1 * time.Second):
   156  		t.Errorf("timeout: listener did not close.")
   157  	case err := <-quit:
   158  		t.Logf("quit as expected (error %v)", err)
   159  	}
   160  }