github.com/glycerine/xcryptossh@v7.0.4+incompatible/test/dial_unix_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build !windows
     6  
     7  package test
     8  
     9  // direct-tcpip and direct-streamlocal functional tests
    10  
    11  import (
    12  	"context"
    13  	"fmt"
    14  	"io"
    15  	"io/ioutil"
    16  	"net"
    17  	"strings"
    18  	"testing"
    19  
    20  	ssh "github.com/glycerine/xcryptossh"
    21  )
    22  
    23  type dialTester interface {
    24  	TestServerConn(t *testing.T, c net.Conn)
    25  	TestClientConn(t *testing.T, c net.Conn)
    26  }
    27  
    28  func testDial(t *testing.T, n, listenAddr string, x dialTester) {
    29  	ctx, cancelctx := context.WithCancel(context.Background())
    30  	defer cancelctx()
    31  	halt := ssh.NewHalter()
    32  	defer halt.ReqStop.Close()
    33  
    34  	server := newServer(t)
    35  	defer server.Shutdown()
    36  	sshConn := server.Dial(ctx, clientConfig(halt))
    37  	defer sshConn.Close()
    38  
    39  	l, err := net.Listen(n, listenAddr)
    40  	if err != nil {
    41  		t.Fatalf("Listen: %v", err)
    42  	}
    43  	defer l.Close()
    44  
    45  	testData := fmt.Sprintf("hello from %s, %s", n, listenAddr)
    46  	go func() {
    47  		for {
    48  			c, err := l.Accept()
    49  			if err != nil {
    50  				break
    51  			}
    52  			x.TestServerConn(t, c)
    53  
    54  			io.WriteString(c, testData)
    55  			c.Close()
    56  		}
    57  	}()
    58  
    59  	conn, err := sshConn.Dial(n, l.Addr().String())
    60  	if err != nil {
    61  		t.Fatalf("Dial: %v", err)
    62  	}
    63  	x.TestClientConn(t, conn)
    64  	defer conn.Close()
    65  	b, err := ioutil.ReadAll(conn)
    66  	if err != nil {
    67  		t.Fatalf("ReadAll: %v", err)
    68  	}
    69  	t.Logf("got %q", string(b))
    70  	if string(b) != testData {
    71  		t.Fatalf("expected %q, got %q", testData, string(b))
    72  	}
    73  }
    74  
    75  type tcpDialTester struct {
    76  	listenAddr string
    77  }
    78  
    79  func (x *tcpDialTester) TestServerConn(t *testing.T, c net.Conn) {
    80  	host := strings.Split(x.listenAddr, ":")[0]
    81  	prefix := host + ":"
    82  	if !strings.HasPrefix(c.LocalAddr().String(), prefix) {
    83  		t.Fatalf("expected to start with %q, got %q", prefix, c.LocalAddr().String())
    84  	}
    85  	if !strings.HasPrefix(c.RemoteAddr().String(), prefix) {
    86  		t.Fatalf("expected to start with %q, got %q", prefix, c.RemoteAddr().String())
    87  	}
    88  }
    89  
    90  func (x *tcpDialTester) TestClientConn(t *testing.T, c net.Conn) {
    91  	// we use zero addresses. see *Client.Dial.
    92  	if c.LocalAddr().String() != "0.0.0.0:0" {
    93  		t.Fatalf("expected \"0.0.0.0:0\", got %q", c.LocalAddr().String())
    94  	}
    95  	if c.RemoteAddr().String() != "0.0.0.0:0" {
    96  		t.Fatalf("expected \"0.0.0.0:0\", got %q", c.RemoteAddr().String())
    97  	}
    98  }
    99  
   100  func TestDialTCP(t *testing.T) {
   101  	x := &tcpDialTester{
   102  		listenAddr: "127.0.0.1:0",
   103  	}
   104  	testDial(t, "tcp", x.listenAddr, x)
   105  }
   106  
   107  type unixDialTester struct {
   108  	listenAddr string
   109  }
   110  
   111  func (x *unixDialTester) TestServerConn(t *testing.T, c net.Conn) {
   112  	if c.LocalAddr().String() != x.listenAddr {
   113  		t.Fatalf("expected %q, got %q", x.listenAddr, c.LocalAddr().String())
   114  	}
   115  	if c.RemoteAddr().String() != "@" {
   116  		t.Fatalf("expected \"@\", got %q", c.RemoteAddr().String())
   117  	}
   118  }
   119  
   120  func (x *unixDialTester) TestClientConn(t *testing.T, c net.Conn) {
   121  	if c.RemoteAddr().String() != x.listenAddr {
   122  		t.Fatalf("expected %q, got %q", x.listenAddr, c.RemoteAddr().String())
   123  	}
   124  	if c.LocalAddr().String() != "@" {
   125  		t.Fatalf("expected \"@\", got %q", c.LocalAddr().String())
   126  	}
   127  }
   128  
   129  func TestDialUnix(t *testing.T) {
   130  	addr, cleanup := newTempSocket(t)
   131  	defer cleanup()
   132  	x := &unixDialTester{
   133  		listenAddr: addr,
   134  	}
   135  	testDial(t, "unix", x.listenAddr, x)
   136  }