github.com/pkg/sftp@v1.13.6/server_integration_test.go (about)

     1  package sftp
     2  
     3  // sftp server integration tests
     4  // enable with -integration
// example invocation (darwin): gofmt -w `find . -name \*.go` && (cd server_standalone/ ; go build -tags debug) && go test -tags debug github.com/pkg/sftp -integration -v -sftp /usr/libexec/sftp-server -run ServerCompareSubsystems
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/ecdsa"
    10  	"crypto/elliptic"
    11  	crand "crypto/rand"
    12  	"crypto/x509"
    13  	"encoding/hex"
    14  	"encoding/pem"
    15  	"flag"
    16  	"fmt"
    17  	"io/ioutil"
    18  	"math/rand"
    19  	"net"
    20  	"os"
    21  	"os/exec"
    22  	"path/filepath"
    23  	"regexp"
    24  	"runtime"
    25  	"strconv"
    26  	"strings"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/kr/fs"
    31  	"github.com/stretchr/testify/assert"
    32  	"golang.org/x/crypto/ssh"
    33  )
    34  
    35  func TestMain(m *testing.M) {
    36  	sftpClientLocation, _ := exec.LookPath("sftp")
    37  	testSftpClientBin = flag.String("sftp_client", sftpClientLocation, "location of the sftp client binary")
    38  
    39  	lookSFTPServer := []string{
    40  		"/usr/libexec/sftp-server",
    41  		"/usr/lib/openssh/sftp-server",
    42  		"/usr/lib/ssh/sftp-server",
    43  		"C:\\Program Files\\Git\\usr\\lib\\ssh\\sftp-server.exe",
    44  	}
    45  	sftpServer, _ := exec.LookPath("sftp-server")
    46  	if len(sftpServer) == 0 {
    47  		for _, location := range lookSFTPServer {
    48  			if _, err := os.Stat(location); err == nil {
    49  				sftpServer = location
    50  				break
    51  			}
    52  		}
    53  	}
    54  	testSftp = flag.String("sftp", sftpServer, "location of the sftp server binary")
    55  	flag.Parse()
    56  
    57  	os.Exit(m.Run())
    58  }
    59  
    60  func skipIfWindows(t testing.TB) {
    61  	if runtime.GOOS == "windows" {
    62  		t.Skip("skipping test on windows")
    63  	}
    64  }
    65  
    66  func skipIfPlan9(t testing.TB) {
    67  	if runtime.GOOS == "plan9" {
    68  		t.Skip("skipping test on plan9")
    69  	}
    70  }
    71  
    72  var testServerImpl = flag.Bool("testserver", false, "perform integration tests against sftp package server instance")
    73  var testIntegration = flag.Bool("integration", false, "perform integration tests against sftp server process")
    74  var testAllocator = flag.Bool("allocator", false, "perform tests using the allocator")
    75  var testSftp *string
    76  
    77  var testSftpClientBin *string
    78  var sshServerDebugStream = ioutil.Discard
    79  var sftpServerDebugStream = ioutil.Discard
    80  var sftpClientDebugStream = ioutil.Discard
    81  
    82  const (
    83  	GolangSFTP  = true
    84  	OpenSSHSFTP = false
    85  )
    86  
    87  var (
    88  	hostPrivateKeySigner ssh.Signer
    89  	privKey              = []byte(`
    90  -----BEGIN RSA PRIVATE KEY-----
    91  MIIEowIBAAKCAQEArhp7SqFnXVZAgWREL9Ogs+miy4IU/m0vmdkoK6M97G9NX/Pj
    92  wf8I/3/ynxmcArbt8Rc4JgkjT2uxx/NqR0yN42N1PjO5Czu0dms1PSqcKIJdeUBV
    93  7gdrKSm9Co4d2vwfQp5mg47eG4w63pz7Drk9+VIyi9YiYH4bve7WnGDswn4ycvYZ
    94  slV5kKnjlfCdPig+g5P7yQYud0cDWVwyA0+kxvL6H3Ip+Fu8rLDZn4/P1WlFAIuc
    95  PAf4uEKDGGmC2URowi5eesYR7f6GN/HnBs2776laNlAVXZUmYTUfOGagwLsEkx8x
    96  XdNqntfbs2MOOoK+myJrNtcB9pCrM0H6um19uQIDAQABAoIBABkWr9WdVKvalgkP
    97  TdQmhu3mKRNyd1wCl+1voZ5IM9Ayac/98UAvZDiNU4Uhx52MhtVLJ0gz4Oa8+i16
    98  IkKMAZZW6ro/8dZwkBzQbieWUFJ2Fso2PyvB3etcnGU8/Yhk9IxBDzy+BbuqhYE2
    99  1ebVQtz+v1HvVZzaD11bYYm/Xd7Y28QREVfFen30Q/v3dv7dOteDE/RgDS8Czz7w
   100  jMW32Q8JL5grz7zPkMK39BLXsTcSYcaasT2ParROhGJZDmbgd3l33zKCVc1zcj9B
   101  SA47QljGd09Tys958WWHgtj2o7bp9v1Ufs4LnyKgzrB80WX1ovaSQKvd5THTLchO
   102  kLIhUAECgYEA2doGXy9wMBmTn/hjiVvggR1aKiBwUpnB87Hn5xCMgoECVhFZlT6l
   103  WmZe7R2klbtG1aYlw+y+uzHhoVDAJW9AUSV8qoDUwbRXvBVlp+In5wIqJ+VjfivK
   104  zgIfzomL5NvDz37cvPmzqIeySTowEfbQyq7CUQSoDtE9H97E2wWZhDkCgYEAzJdJ
   105  k+NSFoTkHhfD3L0xCDHpRV3gvaOeew8524fVtVUq53X8m91ng4AX1r74dCUYwwiF
   106  gqTtSSJfx2iH1xKnNq28M9uKg7wOrCKrRqNPnYUO3LehZEC7rwUr26z4iJDHjjoB
   107  uBcS7nw0LJ+0Zeg1IF+aIdZGV3MrAKnrzWPixYECgYBsffX6ZWebrMEmQ89eUtFF
   108  u9ZxcGI/4K8ErC7vlgBD5ffB4TYZ627xzFWuBLs4jmHCeNIJ9tct5rOVYN+wRO1k
   109  /CRPzYUnSqb+1jEgILL6istvvv+DkE+ZtNkeRMXUndWwel94BWsBnUKe0UmrSJ3G
   110  sq23J3iCmJW2T3z+DpXbkQKBgQCK+LUVDNPE0i42NsRnm+fDfkvLP7Kafpr3Umdl
   111  tMY474o+QYn+wg0/aPJIf9463rwMNyyhirBX/k57IIktUdFdtfPicd2MEGETElWv
   112  nN1GzYxD50Rs2f/jKisZhEwqT9YNyV9DkgDdGGdEbJNYqbv0qpwDIg8T9foe8E1p
   113  bdErgQKBgAt290I3L316cdxIQTkJh1DlScN/unFffITwu127WMr28Jt3mq3cZpuM
   114  Aecey/eEKCj+Rlas5NDYKsB18QIuAw+qqWyq0LAKLiAvP1965Rkc4PLScl3MgJtO
   115  QYa37FK0p8NcDeUuF86zXBVutwS5nJLchHhKfd590ks57OROtm29
   116  -----END RSA PRIVATE KEY-----
   117  `)
   118  )
   119  
   120  func init() {
   121  	var err error
   122  	hostPrivateKeySigner, err = ssh.ParsePrivateKey(privKey)
   123  	if err != nil {
   124  		panic(err)
   125  	}
   126  }
   127  
   128  func keyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
   129  	permissions := &ssh.Permissions{
   130  		CriticalOptions: map[string]string{},
   131  		Extensions:      map[string]string{},
   132  	}
   133  	return permissions, nil
   134  }
   135  
   136  func pwAuth(conn ssh.ConnMetadata, pw []byte) (*ssh.Permissions, error) {
   137  	permissions := &ssh.Permissions{
   138  		CriticalOptions: map[string]string{},
   139  		Extensions:      map[string]string{},
   140  	}
   141  	return permissions, nil
   142  }
   143  
   144  func basicServerConfig() *ssh.ServerConfig {
   145  	config := ssh.ServerConfig{
   146  		Config: ssh.Config{
   147  			MACs: []string{"hmac-sha1"},
   148  		},
   149  		PasswordCallback:  pwAuth,
   150  		PublicKeyCallback: keyAuth,
   151  	}
   152  	config.AddHostKey(hostPrivateKeySigner)
   153  	return &config
   154  }
   155  
   156  type sshServer struct {
   157  	useSubsystem bool
   158  	conn         net.Conn
   159  	config       *ssh.ServerConfig
   160  	sshConn      *ssh.ServerConn
   161  	newChans     <-chan ssh.NewChannel
   162  	newReqs      <-chan *ssh.Request
   163  }
   164  
   165  func sshServerFromConn(conn net.Conn, useSubsystem bool, config *ssh.ServerConfig) (*sshServer, error) {
   166  	// From a standard TCP connection to an encrypted SSH connection
   167  	sshConn, newChans, newReqs, err := ssh.NewServerConn(conn, config)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  
   172  	svr := &sshServer{useSubsystem, conn, config, sshConn, newChans, newReqs}
   173  	svr.listenChannels()
   174  	return svr, nil
   175  }
   176  
   177  func (svr *sshServer) Wait() error {
   178  	return svr.sshConn.Wait()
   179  }
   180  
   181  func (svr *sshServer) Close() error {
   182  	return svr.sshConn.Close()
   183  }
   184  
   185  func (svr *sshServer) listenChannels() {
   186  	go func() {
   187  		for chanReq := range svr.newChans {
   188  			go svr.handleChanReq(chanReq)
   189  		}
   190  	}()
   191  	go func() {
   192  		for req := range svr.newReqs {
   193  			go svr.handleReq(req)
   194  		}
   195  	}()
   196  }
   197  
   198  func (svr *sshServer) handleReq(req *ssh.Request) {
   199  	switch req.Type {
   200  	default:
   201  		rejectRequest(req)
   202  	}
   203  }
   204  
   205  type sshChannelServer struct {
   206  	svr     *sshServer
   207  	chanReq ssh.NewChannel
   208  	ch      ssh.Channel
   209  	newReqs <-chan *ssh.Request
   210  }
   211  
   212  type sshSessionChannelServer struct {
   213  	*sshChannelServer
   214  	env []string
   215  }
   216  
   217  func (svr *sshServer) handleChanReq(chanReq ssh.NewChannel) {
   218  	fmt.Fprintf(sshServerDebugStream, "channel request: %v, extra: '%v'\n", chanReq.ChannelType(), hex.EncodeToString(chanReq.ExtraData()))
   219  	switch chanReq.ChannelType() {
   220  	case "session":
   221  		if ch, reqs, err := chanReq.Accept(); err != nil {
   222  			fmt.Fprintf(sshServerDebugStream, "fail to accept channel request: %v\n", err)
   223  			chanReq.Reject(ssh.ResourceShortage, "channel accept failure")
   224  		} else {
   225  			chsvr := &sshSessionChannelServer{
   226  				sshChannelServer: &sshChannelServer{svr, chanReq, ch, reqs},
   227  				env:              append([]string{}, os.Environ()...),
   228  			}
   229  			chsvr.handle()
   230  		}
   231  	default:
   232  		chanReq.Reject(ssh.UnknownChannelType, "channel type is not a session")
   233  	}
   234  }
   235  
   236  func (chsvr *sshSessionChannelServer) handle() {
   237  	// should maybe do something here...
   238  	go chsvr.handleReqs()
   239  }
   240  
   241  func (chsvr *sshSessionChannelServer) handleReqs() {
   242  	for req := range chsvr.newReqs {
   243  		chsvr.handleReq(req)
   244  	}
   245  	fmt.Fprintf(sshServerDebugStream, "ssh server session channel complete\n")
   246  }
   247  
   248  func (chsvr *sshSessionChannelServer) handleReq(req *ssh.Request) {
   249  	switch req.Type {
   250  	case "env":
   251  		chsvr.handleEnv(req)
   252  	case "subsystem":
   253  		chsvr.handleSubsystem(req)
   254  	default:
   255  		rejectRequest(req)
   256  	}
   257  }
   258  
   259  func rejectRequest(req *ssh.Request) error {
   260  	fmt.Fprintf(sshServerDebugStream, "ssh rejecting request, type: %s\n", req.Type)
   261  	err := req.Reply(false, []byte{})
   262  	if err != nil {
   263  		fmt.Fprintf(sshServerDebugStream, "ssh request reply had error: %v\n", err)
   264  	}
   265  	return err
   266  }
   267  
   268  func rejectRequestUnmarshalError(req *ssh.Request, s interface{}, err error) error {
   269  	fmt.Fprintf(sshServerDebugStream, "ssh request unmarshaling error, type '%T': %v\n", s, err)
   270  	rejectRequest(req)
   271  	return err
   272  }
   273  
   274  // env request form:
   275  type sshEnvRequest struct {
   276  	Envvar string
   277  	Value  string
   278  }
   279  
   280  func (chsvr *sshSessionChannelServer) handleEnv(req *ssh.Request) error {
   281  	envReq := &sshEnvRequest{}
   282  	if err := ssh.Unmarshal(req.Payload, envReq); err != nil {
   283  		return rejectRequestUnmarshalError(req, envReq, err)
   284  	}
   285  	req.Reply(true, nil)
   286  
   287  	found := false
   288  	for i, envstr := range chsvr.env {
   289  		if strings.HasPrefix(envstr, envReq.Envvar+"=") {
   290  			found = true
   291  			chsvr.env[i] = envReq.Envvar + "=" + envReq.Value
   292  		}
   293  	}
   294  	if !found {
   295  		chsvr.env = append(chsvr.env, envReq.Envvar+"="+envReq.Value)
   296  	}
   297  
   298  	return nil
   299  }
   300  
   301  // Payload: int: command size, string: command
   302  type sshSubsystemRequest struct {
   303  	Name string
   304  }
   305  
   306  type sshSubsystemExitStatus struct {
   307  	Status uint32
   308  }
   309  
   310  func (chsvr *sshSessionChannelServer) handleSubsystem(req *ssh.Request) error {
   311  	defer func() {
   312  		err1 := chsvr.ch.CloseWrite()
   313  		err2 := chsvr.ch.Close()
   314  		fmt.Fprintf(sshServerDebugStream, "ssh server subsystem request complete, err: %v %v\n", err1, err2)
   315  	}()
   316  
   317  	subsystemReq := &sshSubsystemRequest{}
   318  	if err := ssh.Unmarshal(req.Payload, subsystemReq); err != nil {
   319  		return rejectRequestUnmarshalError(req, subsystemReq, err)
   320  	}
   321  
   322  	// reply to the ssh client
   323  
   324  	// no idea if this is actually correct spec-wise.
   325  	// just enough for an sftp server to start.
   326  	if subsystemReq.Name != "sftp" {
   327  		return req.Reply(false, nil)
   328  	}
   329  
   330  	req.Reply(true, nil)
   331  
   332  	if !chsvr.svr.useSubsystem {
   333  		// use the openssh sftp server backend; this is to test the ssh code, not the sftp code,
   334  		// or is used for comparison between our sftp subsystem and the openssh sftp subsystem
   335  		cmd := exec.Command(*testSftp, "-e", "-l", "DEBUG") // log to stderr
   336  		cmd.Stdin = chsvr.ch
   337  		cmd.Stdout = chsvr.ch
   338  		cmd.Stderr = sftpServerDebugStream
   339  		if err := cmd.Start(); err != nil {
   340  			return err
   341  		}
   342  		return cmd.Wait()
   343  	}
   344  
   345  	sftpServer, err := NewServer(
   346  		chsvr.ch,
   347  		WithDebug(sftpServerDebugStream),
   348  	)
   349  	if err != nil {
   350  		return err
   351  	}
   352  
   353  	// wait for the session to close
   354  	runErr := sftpServer.Serve()
   355  	exitStatus := uint32(1)
   356  	if runErr == nil {
   357  		exitStatus = uint32(0)
   358  	}
   359  
   360  	_, exitStatusErr := chsvr.ch.SendRequest("exit-status", false, ssh.Marshal(sshSubsystemExitStatus{exitStatus}))
   361  	return exitStatusErr
   362  }
   363  
   364  // starts an ssh server to test. returns: host string and port
   365  func testServer(t *testing.T, useSubsystem bool, readonly bool) (func(), string, int) {
   366  	t.Helper()
   367  
   368  	if !*testIntegration {
   369  		t.Skip("skipping integration test")
   370  	}
   371  
   372  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   373  	if err != nil {
   374  		t.Fatal(err)
   375  	}
   376  
   377  	host, portStr, err := net.SplitHostPort(listener.Addr().String())
   378  	if err != nil {
   379  		t.Fatal(err)
   380  	}
   381  	port, err := strconv.Atoi(portStr)
   382  	if err != nil {
   383  		t.Fatal(err)
   384  	}
   385  
   386  	shutdown := make(chan struct{})
   387  
   388  	go func() {
   389  		for {
   390  			conn, err := listener.Accept()
   391  			if err != nil {
   392  				select {
   393  				case <-shutdown:
   394  				default:
   395  					t.Error("ssh server socket closed:", err)
   396  				}
   397  				return
   398  			}
   399  
   400  			go func() {
   401  				defer conn.Close()
   402  
   403  				sshSvr, err := sshServerFromConn(conn, useSubsystem, basicServerConfig())
   404  				if err != nil {
   405  					t.Error(err)
   406  					return
   407  				}
   408  
   409  				_ = sshSvr.Wait()
   410  			}()
   411  		}
   412  	}()
   413  
   414  	return func() { close(shutdown); listener.Close() }, host, port
   415  }
   416  
   417  func makeDummyKey() (string, error) {
   418  	priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
   419  	if err != nil {
   420  		return "", fmt.Errorf("cannot generate key: %w", err)
   421  	}
   422  	der, err := x509.MarshalECPrivateKey(priv)
   423  	if err != nil {
   424  		return "", fmt.Errorf("cannot marshal key: %w", err)
   425  	}
   426  	block := &pem.Block{Type: "EC PRIVATE KEY", Bytes: der}
   427  	f, err := ioutil.TempFile("", "sftp-test-key-")
   428  	if err != nil {
   429  		return "", fmt.Errorf("cannot create temp file: %w", err)
   430  	}
   431  	defer func() {
   432  		if f != nil {
   433  			_ = f.Close()
   434  			_ = os.Remove(f.Name())
   435  		}
   436  	}()
   437  	if err := pem.Encode(f, block); err != nil {
   438  		return "", fmt.Errorf("cannot write key: %w", err)
   439  	}
   440  	if err := f.Close(); err != nil {
   441  		return "", fmt.Errorf("error closing key file: %w", err)
   442  	}
   443  	path := f.Name()
   444  	f = nil
   445  	return path, nil
   446  }
   447  
   448  type execError struct {
   449  	path   string
   450  	stderr string
   451  	err    error
   452  }
   453  
   454  func (e *execError) Error() string {
   455  	return fmt.Sprintf("%s: %v: %s", e.path, e.err, e.stderr)
   456  }
   457  
   458  func (e *execError) Unwrap() error {
   459  	return e.err
   460  }
   461  
   462  func (e *execError) Cause() error {
   463  	return e.err
   464  }
   465  
   466  func runSftpClient(t *testing.T, script string, path string, host string, port int) (string, error) {
   467  	// if sftp client binary is unavailable, skip test
   468  	if _, err := os.Stat(*testSftpClientBin); err != nil {
   469  		t.Skip("sftp client binary unavailable")
   470  	}
   471  
   472  	// make a dummy key so we don't rely on ssh-agent
   473  	dummyKey, err := makeDummyKey()
   474  	if err != nil {
   475  		return "", err
   476  	}
   477  	defer os.Remove(dummyKey)
   478  
   479  	cmd := exec.Command(
   480  		*testSftpClientBin,
   481  		// "-vvvv",
   482  		"-b", "-",
   483  		"-o", "StrictHostKeyChecking=no",
   484  		"-o", "LogLevel=ERROR",
   485  		"-o", "UserKnownHostsFile /dev/null",
   486  		// do not trigger ssh-agent prompting
   487  		"-o", "IdentityFile="+dummyKey,
   488  		"-o", "IdentitiesOnly=yes",
   489  		"-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path),
   490  	)
   491  
   492  	cmd.Stdin = strings.NewReader(script)
   493  
   494  	stdout := new(bytes.Buffer)
   495  	cmd.Stdout = stdout
   496  
   497  	stderr := new(bytes.Buffer)
   498  	cmd.Stderr = stderr
   499  
   500  	if err := cmd.Start(); err != nil {
   501  		return "", err
   502  	}
   503  
   504  	if err := cmd.Wait(); err != nil {
   505  		return stdout.String(), &execError{
   506  			path:   cmd.Path,
   507  			stderr: stderr.String(),
   508  			err:    err,
   509  		}
   510  	}
   511  
   512  	return stdout.String(), nil
   513  }
   514  
   515  // assert.Eventually seems to have a data rate on macOS with go 1.14 so replace it with this simpler function
   516  func waitForCondition(t *testing.T, condition func() bool) {
   517  	start := time.Now()
   518  	tick := 10 * time.Millisecond
   519  	waitFor := 100 * time.Millisecond
   520  	for !condition() {
   521  		time.Sleep(tick)
   522  		if time.Since(start) > waitFor {
   523  			break
   524  		}
   525  	}
   526  	assert.True(t, condition())
   527  }
   528  
   529  func checkAllocatorBeforeServerClose(t *testing.T, alloc *allocator) {
   530  	if alloc != nil {
   531  		// before closing the server we are, generally, waiting for new packets in recvPacket and we have a page allocated.
   532  		// Sometime the sendPacket returns some milliseconds after the client receives the response, and so we have 2
   533  		// allocated pages here, so wait some milliseconds. To avoid crashes we must be sure to not release the pages
   534  		// too soon.
   535  		waitForCondition(t, func() bool { return alloc.countUsedPages() <= 1 })
   536  	}
   537  }
   538  
   539  func checkAllocatorAfterServerClose(t *testing.T, alloc *allocator) {
   540  	if alloc != nil {
   541  		// wait for the server cleanup
   542  		waitForCondition(t, func() bool { return alloc.countUsedPages() == 0 })
   543  		waitForCondition(t, func() bool { return alloc.countAvailablePages() == 0 })
   544  	}
   545  }
   546  
   547  func TestServerCompareSubsystems(t *testing.T) {
   548  	if runtime.GOOS == "windows" {
   549  		// TODO (puellanivis): not sure how to fix this, the OpenSSH SFTP implementation closes immediately.
   550  		t.Skip()
   551  	}
   552  
   553  	shutdownGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
   554  	defer shutdownGo()
   555  
   556  	shutdownOp, hostOp, portOp := testServer(t, OpenSSHSFTP, READONLY)
   557  	defer shutdownOp()
   558  
   559  	script := `
   560  ls /
   561  ls -l /
   562  ls /dev/
   563  ls -l /dev/
   564  ls -l /etc/
   565  ls -l /bin/
   566  ls -l /usr/bin/
   567  `
   568  	outputGo, err := runSftpClient(t, script, "/", hostGo, portGo)
   569  	if err != nil {
   570  		t.Fatal(err)
   571  	}
   572  
   573  	outputOp, err := runSftpClient(t, script, "/", hostOp, portOp)
   574  	if err != nil {
   575  		t.Fatal(err)
   576  	}
   577  
   578  	newlineRegex := regexp.MustCompile(`\r*\n`)
   579  	spaceRegex := regexp.MustCompile(`\s+`)
   580  	outputGoLines := newlineRegex.Split(outputGo, -1)
   581  	outputOpLines := newlineRegex.Split(outputOp, -1)
   582  
   583  	if len(outputGoLines) != len(outputOpLines) {
   584  		t.Fatalf("output line count differs, go = %d, openssh = %d", len(outputGoLines), len(outputOpLines))
   585  	}
   586  
   587  	for i, goLine := range outputGoLines {
   588  		opLine := outputOpLines[i]
   589  		bad := false
   590  		if goLine != opLine {
   591  			goWords := spaceRegex.Split(goLine, -1)
   592  			opWords := spaceRegex.Split(opLine, -1)
   593  			// some fields are allowed to be different..
   594  			// words[2] and [3] as these are users & groups
   595  			// words[1] as the link count for directories like proc is unstable
   596  			// during testing as processes are created/destroyed.
   597  			// words[7] as timestamp on dirs can very for things like /tmp
   598  			for j, goWord := range goWords {
   599  				if j >= len(opWords) {
   600  					bad = true
   601  					break
   602  				}
   603  				opWord := opWords[j]
   604  				if goWord != opWord && j != 1 && j != 2 && j != 3 && j != 7 {
   605  					bad = true
   606  				}
   607  			}
   608  		}
   609  
   610  		if bad {
   611  			t.Errorf("outputs differ\n     go: %q\nopenssh: %q\n", goLine, opLine)
   612  		}
   613  	}
   614  }
   615  
   616  var rng = rand.New(rand.NewSource(time.Now().Unix()))
   617  
   618  func randData(length int) []byte {
   619  	data := make([]byte, length)
   620  	for i := 0; i < length; i++ {
   621  		data[i] = byte(rng.Uint32())
   622  	}
   623  	return data
   624  }
   625  
   626  func randName() string {
   627  	return "sftp." + hex.EncodeToString(randData(16))
   628  }
   629  
   630  func TestServerMkdirRmdir(t *testing.T) {
   631  	shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
   632  	defer shutdown()
   633  
   634  	tmpDir := "/tmp/" + randName()
   635  	defer os.RemoveAll(tmpDir)
   636  
   637  	// mkdir remote
   638  	if _, err := runSftpClient(t, "mkdir "+tmpDir, "/", hostGo, portGo); err != nil {
   639  		t.Fatal(err)
   640  	}
   641  
   642  	// directory should now exist
   643  	if _, err := os.Stat(tmpDir); err != nil {
   644  		t.Fatal(err)
   645  	}
   646  
   647  	// now remove the directory
   648  	if _, err := runSftpClient(t, "rmdir "+tmpDir, "/", hostGo, portGo); err != nil {
   649  		t.Fatal(err)
   650  	}
   651  
   652  	if _, err := os.Stat(tmpDir); err == nil {
   653  		t.Fatal("should have error after deleting the directory")
   654  	}
   655  }
   656  
   657  func TestServerLink(t *testing.T) {
   658  	skipIfWindows(t) // No hard links on windows.
   659  	shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
   660  	defer shutdown()
   661  
   662  	tmpFileLocalData := randData(999)
   663  
   664  	linkdest := "/tmp/" + randName()
   665  	defer os.RemoveAll(linkdest)
   666  	if err := ioutil.WriteFile(linkdest, tmpFileLocalData, 0644); err != nil {
   667  		t.Fatal(err)
   668  	}
   669  
   670  	link := "/tmp/" + randName()
   671  	defer os.RemoveAll(link)
   672  
   673  	// now create a hard link within the new directory
   674  	if output, err := runSftpClient(t, fmt.Sprintf("ln %s %s", linkdest, link), "/", hostGo, portGo); err != nil {
   675  		t.Fatalf("failed: %v %v", err, string(output))
   676  	}
   677  
   678  	// file should now exist and be the same size as linkdest
   679  	if stat, err := os.Lstat(link); err != nil {
   680  		t.Fatal(err)
   681  	} else if int(stat.Size()) != len(tmpFileLocalData) {
   682  		t.Fatalf("wrong size: %v", len(tmpFileLocalData))
   683  	}
   684  }
   685  
   686  func TestServerSymlink(t *testing.T) {
   687  	skipIfWindows(t) // No symlinks on windows.
   688  	shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
   689  	defer shutdown()
   690  
   691  	link := "/tmp/" + randName()
   692  	defer os.RemoveAll(link)
   693  
   694  	// now create a symbolic link within the new directory
   695  	if output, err := runSftpClient(t, "symlink /bin/sh "+link, "/", hostGo, portGo); err != nil {
   696  		t.Fatalf("failed: %v %v", err, string(output))
   697  	}
   698  
   699  	// symlink should now exist
   700  	if stat, err := os.Lstat(link); err != nil {
   701  		t.Fatal(err)
   702  	} else if (stat.Mode() & os.ModeSymlink) != os.ModeSymlink {
   703  		t.Fatalf("is not a symlink: %v", stat.Mode())
   704  	}
   705  }
   706  
   707  func TestServerPut(t *testing.T) {
   708  	shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
   709  	defer shutdown()
   710  
   711  	tmpFileLocal := "/tmp/" + randName()
   712  	tmpFileRemote := "/tmp/" + randName()
   713  	defer os.RemoveAll(tmpFileLocal)
   714  	defer os.RemoveAll(tmpFileRemote)
   715  
   716  	t.Logf("put: local %v remote %v", tmpFileLocal, tmpFileRemote)
   717  
   718  	// create a file with random contents. This will be the local file pushed to the server
   719  	tmpFileLocalData := randData(10 * 1024 * 1024)
   720  	if err := ioutil.WriteFile(tmpFileLocal, tmpFileLocalData, 0644); err != nil {
   721  		t.Fatal(err)
   722  	}
   723  
   724  	// sftp the file to the server
   725  	if output, err := runSftpClient(t, "put "+tmpFileLocal+" "+tmpFileRemote, "/", hostGo, portGo); err != nil {
   726  		t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
   727  	}
   728  
   729  	// tmpFile2 should now exist, with the same contents
   730  	if tmpFileRemoteData, err := ioutil.ReadFile(tmpFileRemote); err != nil {
   731  		t.Fatal(err)
   732  	} else if string(tmpFileLocalData) != string(tmpFileRemoteData) {
   733  		t.Fatal("contents of file incorrect after put")
   734  	}
   735  }
   736  
   737  func TestServerResume(t *testing.T) {
   738  	shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
   739  	defer shutdown()
   740  
   741  	tmpFileLocal := "/tmp/" + randName()
   742  	tmpFileRemote := "/tmp/" + randName()
   743  	defer os.RemoveAll(tmpFileLocal)
   744  	defer os.RemoveAll(tmpFileRemote)
   745  
   746  	t.Logf("put: local %v remote %v", tmpFileLocal, tmpFileRemote)
   747  
   748  	// create a local file with random contents to be pushed to the server
   749  	tmpFileLocalData := randData(2 * 1024 * 1024)
   750  	// only write half the data to simulate a split upload
   751  	half := 1024 * 1024
   752  	err := ioutil.WriteFile(tmpFileLocal, tmpFileLocalData[:half], 0644)
   753  	if err != nil {
   754  		t.Fatal(err)
   755  	}
   756  
   757  	// sftp the first half of the file to the server
   758  	output, err := runSftpClient(t, "put "+tmpFileLocal+" "+tmpFileRemote,
   759  		"/", hostGo, portGo)
   760  	if err != nil {
   761  		t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
   762  	}
   763  
   764  	// write the full file out
   765  	err = ioutil.WriteFile(tmpFileLocal, tmpFileLocalData, 0644)
   766  	if err != nil {
   767  		t.Fatal(err)
   768  	}
   769  	// re-sftp the full file with the append flag set
   770  	output, err = runSftpClient(t, "put -a "+tmpFileLocal+" "+tmpFileRemote,
   771  		"/", hostGo, portGo)
   772  	if err != nil {
   773  		t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
   774  	}
   775  
   776  	// tmpFileRemote should now exist, with the same contents
   777  	if tmpFileRemoteData, err := ioutil.ReadFile(tmpFileRemote); err != nil {
   778  		t.Fatal(err)
   779  	} else if string(tmpFileLocalData) != string(tmpFileRemoteData) {
   780  		t.Fatal("contents of file incorrect after put")
   781  	}
   782  }
   783  
   784  func TestServerGet(t *testing.T) {
   785  	shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
   786  	defer shutdown()
   787  
   788  	tmpFileLocal := "/tmp/" + randName()
   789  	tmpFileRemote := "/tmp/" + randName()
   790  	defer os.RemoveAll(tmpFileLocal)
   791  	defer os.RemoveAll(tmpFileRemote)
   792  
   793  	t.Logf("get: local %v remote %v", tmpFileLocal, tmpFileRemote)
   794  
   795  	// create a file with random contents. This will be the remote file pulled from the server
   796  	tmpFileRemoteData := randData(10 * 1024 * 1024)
   797  	if err := ioutil.WriteFile(tmpFileRemote, tmpFileRemoteData, 0644); err != nil {
   798  		t.Fatal(err)
   799  	}
   800  
   801  	// sftp the file to the server
   802  	if output, err := runSftpClient(t, "get "+tmpFileRemote+" "+tmpFileLocal, "/", hostGo, portGo); err != nil {
   803  		t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
   804  	}
   805  
   806  	// tmpFile2 should now exist, with the same contents
   807  	if tmpFileLocalData, err := ioutil.ReadFile(tmpFileLocal); err != nil {
   808  		t.Fatal(err)
   809  	} else if string(tmpFileLocalData) != string(tmpFileRemoteData) {
   810  		t.Fatal("contents of file incorrect after put")
   811  	}
   812  }
   813  
   814  func compareDirectoriesRecursive(t *testing.T, aroot, broot string) {
   815  	walker := fs.Walk(aroot)
   816  	for walker.Step() {
   817  		if err := walker.Err(); err != nil {
   818  			t.Fatal(err)
   819  		}
   820  		// find paths
   821  		aPath := walker.Path()
   822  		aRel, err := filepath.Rel(aroot, aPath)
   823  		if err != nil {
   824  			t.Fatalf("could not find relative path for %v: %v", aPath, err)
   825  		}
   826  		bPath := filepath.Join(broot, aRel)
   827  
   828  		if aRel == "." {
   829  			continue
   830  		}
   831  
   832  		//t.Logf("comparing: %v a: %v b %v", aRel, aPath, bPath)
   833  
   834  		// if a is a link, the sftp recursive copy won't have copied it. ignore
   835  		aLink, err := os.Lstat(aPath)
   836  		if err != nil {
   837  			t.Fatalf("could not lstat %v: %v", aPath, err)
   838  		}
   839  		if aLink.Mode()&os.ModeSymlink != 0 {
   840  			continue
   841  		}
   842  
   843  		// stat the files
   844  		aFile, err := os.Stat(aPath)
   845  		if err != nil {
   846  			t.Fatalf("could not stat %v: %v", aPath, err)
   847  		}
   848  		bFile, err := os.Stat(bPath)
   849  		if err != nil {
   850  			t.Fatalf("could not stat %v: %v", bPath, err)
   851  		}
   852  
   853  		// compare stats, with some leniency for the timestamp
   854  		if aFile.Mode() != bFile.Mode() {
   855  			t.Fatalf("modes different for %v: %v vs %v", aRel, aFile.Mode(), bFile.Mode())
   856  		}
   857  		if !aFile.IsDir() {
   858  			if aFile.Size() != bFile.Size() {
   859  				t.Fatalf("sizes different for %v: %v vs %v", aRel, aFile.Size(), bFile.Size())
   860  			}
   861  		}
   862  		timeDiff := aFile.ModTime().Sub(bFile.ModTime())
   863  		if timeDiff > time.Second || timeDiff < -time.Second {
   864  			t.Fatalf("mtimes different for %v: %v vs %v", aRel, aFile.ModTime(), bFile.ModTime())
   865  		}
   866  
   867  		// compare contents
   868  		if !aFile.IsDir() {
   869  			if aContents, err := ioutil.ReadFile(aPath); err != nil {
   870  				t.Fatal(err)
   871  			} else if bContents, err := ioutil.ReadFile(bPath); err != nil {
   872  				t.Fatal(err)
   873  			} else if string(aContents) != string(bContents) {
   874  				t.Fatalf("contents different for %v", aRel)
   875  			}
   876  		}
   877  	}
   878  }
   879  
   880  func TestServerPutRecursive(t *testing.T) {
   881  	shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
   882  	defer shutdown()
   883  
   884  	dirLocal, err := os.Getwd()
   885  	if err != nil {
   886  		t.Fatal(err)
   887  	}
   888  	tmpDirRemote := "/tmp/" + randName()
   889  	defer os.RemoveAll(tmpDirRemote)
   890  
   891  	t.Logf("put recursive: local %v remote %v", dirLocal, tmpDirRemote)
   892  
   893  	// On windows, the client copies the contents of the directory, not the directory itself.
   894  	winFix := ""
   895  	if runtime.GOOS == "windows" {
   896  		winFix = "/" + filepath.Base(dirLocal)
   897  	} //*/
   898  
   899  	// push this directory (source code etc) recursively to the server
   900  	if output, err := runSftpClient(t, "mkdir "+tmpDirRemote+"\r\nput -R -p "+dirLocal+" "+tmpDirRemote+winFix, "/", hostGo, portGo); err != nil {
   901  		t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
   902  	}
   903  
   904  	compareDirectoriesRecursive(t, dirLocal, filepath.Join(tmpDirRemote, filepath.Base(dirLocal)))
   905  }
   906  
   907  func TestServerGetRecursive(t *testing.T) {
   908  	shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY)
   909  	defer shutdown()
   910  
   911  	dirRemote, err := os.Getwd()
   912  	if err != nil {
   913  		t.Fatal(err)
   914  	}
   915  	tmpDirLocal := "/tmp/" + randName()
   916  	defer os.RemoveAll(tmpDirLocal)
   917  
   918  	t.Logf("get recursive: local %v remote %v", tmpDirLocal, dirRemote)
   919  
   920  	// On windows, the client copies the contents of the directory, not the directory itself.
   921  	winFix := ""
   922  	if runtime.GOOS == "windows" {
   923  		winFix = "/" + filepath.Base(dirRemote)
   924  	}
   925  
   926  	// pull this directory (source code etc) recursively from the server
   927  	if output, err := runSftpClient(t, "lmkdir "+tmpDirLocal+"\r\nget -R -p "+dirRemote+" "+tmpDirLocal+winFix, "/", hostGo, portGo); err != nil {
   928  		t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output)
   929  	}
   930  
   931  	compareDirectoriesRecursive(t, dirRemote, filepath.Join(tmpDirLocal, filepath.Base(dirRemote)))
   932  }