github.com/devops-filetransfer/sshego@v7.0.4+incompatible/ud_test.go (about)

     1  package sshego
     2  
import (
	"context"
	"fmt"
	"io"
	"net"
	"os"
	"testing"

	cv "github.com/glycerine/goconvey/convey"
)
    12  
    13  // ud_test.go: unix domain socket test.
    14  
    15  func Test401UnixDomainSocketListening(t *testing.T) {
    16  
    17  	cv.Convey("Instead of -listen and -remote only forwarding via connections, if given a path instead of a port it should listen on a unix domain socket.", t, func() {
    18  
    19  		// generate a random payload for the client to send to the server.
    20  		payloadByteCount := 50
    21  		confirmationPayload := RandomString(payloadByteCount)
    22  		confirmationReply := RandomString(payloadByteCount)
    23  
    24  		serverDone := make(chan bool)
    25  
    26  		udpath := startBackgroundTestUnixDomainServer(
    27  			serverDone,
    28  			payloadByteCount,
    29  			confirmationPayload,
    30  			confirmationReply)
    31  		defer os.Remove(udpath)
    32  
    33  		s := MakeTestSshClientAndServer(true)
    34  		defer TempDirCleanup(s.SrvCfg.Origdir, s.SrvCfg.Tempdir)
    35  
    36  		dest := udpath
    37  
    38  		// below over SSH should be equivalent of the following
    39  		// non-encrypted ping/pong.
    40  
    41  		if false {
    42  			udUnencPingPong(udpath, confirmationPayload, confirmationReply, payloadByteCount)
    43  		}
    44  		if true {
    45  			dc := DialConfig{
    46  				ClientKnownHostsPath: s.CliCfg.ClientKnownHostsPath,
    47  				Mylogin:              s.Mylogin,
    48  				RsaPath:              s.RsaPath,
    49  				TotpUrl:              s.Totp,
    50  				Pw:                   s.Pw,
    51  				Sshdhost:             s.SrvCfg.EmbeddedSSHd.Host,
    52  				Sshdport:             s.SrvCfg.EmbeddedSSHd.Port,
    53  				DownstreamHostPort:   dest,
    54  				TofuAddIfNotKnown:    true,
    55  			}
    56  			ctx := context.Background()
    57  
    58  			// first time we add the server key
    59  			channelToTcpServer, _, _, err := dc.Dial(ctx, nil, false)
    60  			cv.So(err.Error(), cv.ShouldContainSubstring, "Re-run without -new")
    61  
    62  			// second time we connect based on that server key
    63  			dc.TofuAddIfNotKnown = false
    64  			channelToTcpServer, _, _, err = dc.Dial(ctx, nil, false)
    65  			cv.So(err, cv.ShouldBeNil)
    66  
    67  			VerifyClientServerExchangeAcrossSshd(channelToTcpServer, confirmationPayload, confirmationReply, payloadByteCount)
    68  			channelToTcpServer.Close()
    69  		}
    70  		// tcp-server should have exited because it got the expected
    71  		// message and replied with the agreed upon reply and then exited.
    72  		<-serverDone
    73  
    74  		// done with testing, cleanup
    75  		s.SrvCfg.Esshd.Stop()
    76  		<-s.SrvCfg.Esshd.Halt.DoneChan()
    77  		cv.So(true, cv.ShouldEqual, true) // we should get here.
    78  	})
    79  }
    80  
    81  func udUnencPingPong(dest, confirmationPayload, confirmationReply string, payloadByteCount int) {
    82  	conn, err := net.Dial("unix", dest)
    83  	panicOn(err)
    84  	m, err := conn.Write([]byte(confirmationPayload))
    85  	panicOn(err)
    86  	if m != payloadByteCount {
    87  		panic("too short a write!")
    88  	}
    89  
    90  	// check reply
    91  	rep := make([]byte, payloadByteCount)
    92  	m, err = conn.Read(rep)
    93  	panicOn(err)
    94  	if m != payloadByteCount {
    95  		panic("too short a reply!")
    96  	}
    97  	srep := string(rep)
    98  	if srep != confirmationReply {
    99  		panic(fmt.Errorf("saw '%s' but expected '%s'", srep, confirmationReply))
   100  	}
   101  	pp("reply success! server back to -> client: we got the expected srep reply '%s'", srep)
   102  	conn.Close()
   103  }
   104  
   105  func startBackgroundTestUnixDomainServer(serverDone chan bool, payloadByteCount int, confirmationPayload string, confirmationReply string) (udpath string) {
   106  
   107  	udpath = "/tmp/ud_test.sock." + RandomString(20)
   108  	lsn, err := net.Listen("unix", udpath)
   109  	panicOn(err)
   110  
   111  	go func() {
   112  		udServerConn, err := lsn.Accept()
   113  		panicOn(err)
   114  
   115  		b := make([]byte, payloadByteCount)
   116  		n, err := udServerConn.Read(b)
   117  		panicOn(err)
   118  		if n != payloadByteCount {
   119  			panic(fmt.Errorf("read too short! got %v but expected %v", n, payloadByteCount))
   120  		}
   121  		saw := string(b)
   122  
   123  		if saw != confirmationPayload {
   124  			panic(fmt.Errorf("expected '%s', but saw '%s'", confirmationPayload, saw))
   125  		}
   126  
   127  		pp("client -> server success! server got expected confirmation payload of '%s'", saw)
   128  
   129  		// reply back
   130  		n, err = udServerConn.Write([]byte(confirmationReply))
   131  		panicOn(err)
   132  		if n != payloadByteCount {
   133  			panic(fmt.Errorf("write too short! got %v but expected %v", n, payloadByteCount))
   134  		}
   135  		//udServerConn.Close()
   136  		close(serverDone)
   137  	}()
   138  
   139  	return udpath
   140  }