github.com/pkg/sftp@v1.13.6/request_test.go (about)

     1  package sftp
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"os"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  )
    12  
    // testHandler is the single stub wired into every Handlers slot by the
// request tests. It serves canned data, records writes, and can be told to
// fail every operation.
type testHandler struct {
	// filecontents is what Fileread exposes as the file body.
	filecontents []byte
	// output is the sink Filewrite hands back to the request.
	output io.WriterAt
	// err, when non-nil, is returned by each handler method instead of
	// doing any work (intended to be a file-related error).
	err error
}
    18  
    19  func (t *testHandler) Fileread(r *Request) (io.ReaderAt, error) {
    20  	if t.err != nil {
    21  		return nil, t.err
    22  	}
    23  	_ = r.WithContext(r.Context()) // initialize context for deadlock testing
    24  	return bytes.NewReader(t.filecontents), nil
    25  }
    26  
    27  func (t *testHandler) Filewrite(r *Request) (io.WriterAt, error) {
    28  	if t.err != nil {
    29  		return nil, t.err
    30  	}
    31  	_ = r.WithContext(r.Context()) // initialize context for deadlock testing
    32  	return io.WriterAt(t.output), nil
    33  }
    34  
    35  func (t *testHandler) Filecmd(r *Request) error {
    36  	_ = r.WithContext(r.Context()) // initialize context for deadlock testing
    37  	return t.err
    38  }
    39  
    40  func (t *testHandler) Filelist(r *Request) (ListerAt, error) {
    41  	if t.err != nil {
    42  		return nil, t.err
    43  	}
    44  	_ = r.WithContext(r.Context()) // initialize context for deadlock testing
    45  	f, err := os.Open(r.Filepath)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	fi, err := f.Stat()
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	return listerat([]os.FileInfo{fi}), nil
    54  }
    55  
    // fakefile is a fixed-size in-memory write target for Put tests.
// Its length must equal len(filecontents); it is spelled from the same
// literal below so the two stay visibly in sync.
type fakefile [len("file-data.")]byte

// filecontents is the dummy payload the test handler serves for reads.
var filecontents = []byte("file-data.")
    60  
    61  // XXX need new for creating test requests that supports Open-ing
    62  func testRequest(method string) *Request {
    63  	var flags uint32
    64  	switch method {
    65  	case "Get":
    66  		flags = flags | sshFxfRead
    67  	case "Put":
    68  		flags = flags | sshFxfWrite
    69  	}
    70  	request := &Request{
    71  		Filepath: "./request_test.go",
    72  		Method:   method,
    73  		Attrs:    []byte("foo"),
    74  		Flags:    flags,
    75  		Target:   "foo",
    76  	}
    77  	return request
    78  }
    79  
    80  func (ff *fakefile) WriteAt(p []byte, off int64) (int, error) {
    81  	n := copy(ff[off:], p)
    82  	return n, nil
    83  }
    84  
    85  func (ff fakefile) string() string {
    86  	b := make([]byte, len(ff))
    87  	copy(b, ff[:])
    88  	return string(b)
    89  }
    90  
    91  func newTestHandlers() Handlers {
    92  	handler := &testHandler{
    93  		filecontents: filecontents,
    94  		output:       &fakefile{},
    95  		err:          nil,
    96  	}
    97  	return Handlers{
    98  		FileGet:  handler,
    99  		FilePut:  handler,
   100  		FileCmd:  handler,
   101  		FileList: handler,
   102  	}
   103  }
   104  
   105  func (h Handlers) getOutString() string {
   106  	handler := h.FilePut.(*testHandler)
   107  	return handler.output.(*fakefile).string()
   108  }
   109  
   110  var errTest = errors.New("test error")
   111  
   112  func (h *Handlers) returnError(err error) {
   113  	handler := h.FilePut.(*testHandler)
   114  	handler.err = err
   115  }
   116  
   117  func getStatusMsg(p interface{}) string {
   118  	pkt := p.(*sshFxpStatusPacket)
   119  	return pkt.StatusError.msg
   120  }
   121  func checkOkStatus(t *testing.T, p interface{}) {
   122  	pkt := p.(*sshFxpStatusPacket)
   123  	assert.Equal(t, pkt.StatusError.Code, uint32(sshFxOk),
   124  		"sshFxpStatusPacket not OK\n", pkt.StatusError.msg)
   125  }
   126  
   // fakePacket is a stand-in request packet for the tests. It carries only
// a request ID and a handle.
type fakePacket struct {
	myid   uint32
	handle string
}

// id reports the packet's request ID.
func (p fakePacket) id() uint32 { return p.myid }

// getHandle reports the handle the packet refers to.
func (p fakePacket) getHandle() string { return p.handle }

// UnmarshalBinary ignores its input and always succeeds; a fakePacket has
// nothing to decode.
func (fakePacket) UnmarshalBinary([]byte) error { return nil }
   141  
   142  // XXX can't just set method to Get, need to use Open to setup Get/Put
   143  func TestRequestGet(t *testing.T) {
   144  	handlers := newTestHandlers()
   145  	request := testRequest("Get")
   146  	pkt := fakePacket{myid: 1}
   147  	request.open(handlers, pkt)
   148  	// req.length is 5, so we test reads in 5 byte chunks
   149  	for i, txt := range []string{"file-", "data."} {
   150  		pkt := &sshFxpReadPacket{ID: uint32(i), Handle: "a",
   151  			Offset: uint64(i * 5), Len: 5}
   152  		rpkt := request.call(handlers, pkt, nil, 0)
   153  		dpkt := rpkt.(*sshFxpDataPacket)
   154  		assert.Equal(t, dpkt.id(), uint32(i))
   155  		assert.Equal(t, string(dpkt.Data), txt)
   156  	}
   157  }
   158  
   159  func TestRequestCustomError(t *testing.T) {
   160  	handlers := newTestHandlers()
   161  	request := testRequest("Stat")
   162  	pkt := fakePacket{myid: 1}
   163  	cmdErr := errors.New("stat not supported")
   164  	handlers.returnError(cmdErr)
   165  	rpkt := request.call(handlers, pkt, nil, 0)
   166  	assert.Equal(t, rpkt, statusFromError(pkt.myid, cmdErr))
   167  }
   168  
   169  // XXX can't just set method to Get, need to use Open to setup Get/Put
   170  func TestRequestPut(t *testing.T) {
   171  	handlers := newTestHandlers()
   172  	request := testRequest("Put")
   173  	request.state.writerAt, _ = handlers.FilePut.Filewrite(request)
   174  	pkt := &sshFxpWritePacket{ID: 0, Handle: "a", Offset: 0, Length: 5,
   175  		Data: []byte("file-")}
   176  	rpkt := request.call(handlers, pkt, nil, 0)
   177  	checkOkStatus(t, rpkt)
   178  	pkt = &sshFxpWritePacket{ID: 1, Handle: "a", Offset: 5, Length: 5,
   179  		Data: []byte("data.")}
   180  	rpkt = request.call(handlers, pkt, nil, 0)
   181  	checkOkStatus(t, rpkt)
   182  	assert.Equal(t, "file-data.", handlers.getOutString())
   183  }
   184  
   185  func TestRequestCmdr(t *testing.T) {
   186  	handlers := newTestHandlers()
   187  	request := testRequest("Mkdir")
   188  	pkt := fakePacket{myid: 1}
   189  	rpkt := request.call(handlers, pkt, nil, 0)
   190  	checkOkStatus(t, rpkt)
   191  
   192  	handlers.returnError(errTest)
   193  	rpkt = request.call(handlers, pkt, nil, 0)
   194  	assert.Equal(t, rpkt, statusFromError(pkt.myid, errTest))
   195  }
   196  
   197  func TestRequestInfoStat(t *testing.T) {
   198  	handlers := newTestHandlers()
   199  	request := testRequest("Stat")
   200  	pkt := fakePacket{myid: 1}
   201  	rpkt := request.call(handlers, pkt, nil, 0)
   202  	spkt, ok := rpkt.(*sshFxpStatResponse)
   203  	assert.True(t, ok)
   204  	assert.Equal(t, spkt.info.Name(), "request_test.go")
   205  }
   206  
   207  func TestRequestInfoList(t *testing.T) {
   208  	handlers := newTestHandlers()
   209  	request := testRequest("List")
   210  	request.handle = "1"
   211  	pkt := fakePacket{myid: 1}
   212  	rpkt := request.opendir(handlers, pkt)
   213  	hpkt, ok := rpkt.(*sshFxpHandlePacket)
   214  	if assert.True(t, ok) {
   215  		assert.Equal(t, hpkt.Handle, "1")
   216  	}
   217  	pkt = fakePacket{myid: 2}
   218  	request.call(handlers, pkt, nil, 0)
   219  }
   220  func TestRequestInfoReadlink(t *testing.T) {
   221  	handlers := newTestHandlers()
   222  	request := testRequest("Readlink")
   223  	pkt := fakePacket{myid: 1}
   224  	rpkt := request.call(handlers, pkt, nil, 0)
   225  	npkt, ok := rpkt.(*sshFxpNamePacket)
   226  	if assert.True(t, ok) {
   227  		assert.IsType(t, &sshFxpNameAttr{}, npkt.NameAttrs[0])
   228  		assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go")
   229  	}
   230  }
   231  
   232  func TestOpendirHandleReuse(t *testing.T) {
   233  	handlers := newTestHandlers()
   234  	request := testRequest("Stat")
   235  	request.handle = "1"
   236  	pkt := fakePacket{myid: 1}
   237  	rpkt := request.call(handlers, pkt, nil, 0)
   238  	assert.IsType(t, &sshFxpStatResponse{}, rpkt)
   239  
   240  	request.Method = "List"
   241  	pkt = fakePacket{myid: 2}
   242  	rpkt = request.opendir(handlers, pkt)
   243  	if assert.IsType(t, &sshFxpHandlePacket{}, rpkt) {
   244  		hpkt := rpkt.(*sshFxpHandlePacket)
   245  		assert.Equal(t, hpkt.Handle, "1")
   246  	}
   247  	rpkt = request.call(handlers, pkt, nil, 0)
   248  	assert.IsType(t, &sshFxpNamePacket{}, rpkt)
   249  }