github.com/pkg/sftp@v1.13.6/request.go (about)

     1  package sftp
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"os"
     9  	"strings"
    10  	"sync"
    11  	"syscall"
    12  )
    13  
// MaxFilelist is the max number of files to return in a readdir batch.
// filelist allocates a buffer of this many os.FileInfo entries for every
// ListAt call it makes.
var MaxFilelist int64 = 100
    16  
// state encapsulates the reader/writer/readdir from handlers.
// Every accessor method on state takes mu before reading or writing
// any of the other fields.
type state struct {
	mu sync.RWMutex

	// handler-provided objects; each is stored by its own setter
	writerAt         io.WriterAt
	readerAt         io.ReaderAt
	writerAtReaderAt WriterAtReaderAt
	listerAt         ListerAt
	// offset handed to the next ListAt call; read by lsNext, advanced by lsInc
	lsoffset         int64
}
    27  
    28  // copy returns a shallow copy the state.
    29  // This is broken out to specific fields,
    30  // because we have to copy around the mutex in state.
    31  func (s *state) copy() state {
    32  	s.mu.RLock()
    33  	defer s.mu.RUnlock()
    34  
    35  	return state{
    36  		writerAt:         s.writerAt,
    37  		readerAt:         s.readerAt,
    38  		writerAtReaderAt: s.writerAtReaderAt,
    39  		listerAt:         s.listerAt,
    40  		lsoffset:         s.lsoffset,
    41  	}
    42  }
    43  
    44  func (s *state) setReaderAt(rd io.ReaderAt) {
    45  	s.mu.Lock()
    46  	defer s.mu.Unlock()
    47  
    48  	s.readerAt = rd
    49  }
    50  
    51  func (s *state) getReaderAt() io.ReaderAt {
    52  	s.mu.RLock()
    53  	defer s.mu.RUnlock()
    54  
    55  	return s.readerAt
    56  }
    57  
    58  func (s *state) setWriterAt(rd io.WriterAt) {
    59  	s.mu.Lock()
    60  	defer s.mu.Unlock()
    61  
    62  	s.writerAt = rd
    63  }
    64  
    65  func (s *state) getWriterAt() io.WriterAt {
    66  	s.mu.RLock()
    67  	defer s.mu.RUnlock()
    68  
    69  	return s.writerAt
    70  }
    71  
    72  func (s *state) setWriterAtReaderAt(rw WriterAtReaderAt) {
    73  	s.mu.Lock()
    74  	defer s.mu.Unlock()
    75  
    76  	s.writerAtReaderAt = rw
    77  }
    78  
    79  func (s *state) getWriterAtReaderAt() WriterAtReaderAt {
    80  	s.mu.RLock()
    81  	defer s.mu.RUnlock()
    82  
    83  	return s.writerAtReaderAt
    84  }
    85  
    86  func (s *state) getAllReaderWriters() (io.ReaderAt, io.WriterAt, WriterAtReaderAt) {
    87  	s.mu.RLock()
    88  	defer s.mu.RUnlock()
    89  
    90  	return s.readerAt, s.writerAt, s.writerAtReaderAt
    91  }
    92  
    93  // Returns current offset for file list
    94  func (s *state) lsNext() int64 {
    95  	s.mu.RLock()
    96  	defer s.mu.RUnlock()
    97  
    98  	return s.lsoffset
    99  }
   100  
   101  // Increases next offset
   102  func (s *state) lsInc(offset int64) {
   103  	s.mu.Lock()
   104  	defer s.mu.Unlock()
   105  
   106  	s.lsoffset += offset
   107  }
   108  
   109  // manage file read/write state
   110  func (s *state) setListerAt(la ListerAt) {
   111  	s.mu.Lock()
   112  	defer s.mu.Unlock()
   113  
   114  	s.listerAt = la
   115  }
   116  
   117  func (s *state) getListerAt() ListerAt {
   118  	s.mu.RLock()
   119  	defer s.mu.RUnlock()
   120  
   121  	return s.listerAt
   122  }
   123  
// Request contains the data and state for the incoming service request.
type Request struct {
	// Get, Put, Setstat, Stat, Rename, Remove
	// Rmdir, Mkdir, List, Readlink, Link, Symlink
	//
	// Method is also set to "Open", "PosixRename", "StatVFS" and "Lstat"
	// by code paths that dispatch through call.
	Method   string
	Filepath string
	Flags    uint32
	Attrs    []byte // convert to sub-struct
	Target   string // for renames and sym-links
	// handle is echoed back to the client in the handle packets
	// built by open and opendir.
	handle   string

	// reader/writer/readdir from handlers
	state

	// context lasts duration of request
	ctx       context.Context
	// cancelCtx, when non-nil, is invoked by close.
	cancelCtx context.CancelFunc
}
   142  
   143  // NewRequest creates a new Request object.
   144  func NewRequest(method, path string) *Request {
   145  	return &Request{
   146  		Method:   method,
   147  		Filepath: cleanPath(path),
   148  	}
   149  }
   150  
   151  // copy returns a shallow copy of existing request.
   152  // This is broken out to specific fields,
   153  // because we have to copy around the mutex in state.
   154  func (r *Request) copy() *Request {
   155  	return &Request{
   156  		Method:   r.Method,
   157  		Filepath: r.Filepath,
   158  		Flags:    r.Flags,
   159  		Attrs:    r.Attrs,
   160  		Target:   r.Target,
   161  		handle:   r.handle,
   162  
   163  		state: r.state.copy(),
   164  
   165  		ctx:       r.ctx,
   166  		cancelCtx: r.cancelCtx,
   167  	}
   168  }
   169  
   170  // New Request initialized based on packet data
   171  func requestFromPacket(ctx context.Context, pkt hasPath, baseDir string) *Request {
   172  	request := &Request{
   173  		Method:   requestMethod(pkt),
   174  		Filepath: cleanPathWithBase(baseDir, pkt.getPath()),
   175  	}
   176  	request.ctx, request.cancelCtx = context.WithCancel(ctx)
   177  
   178  	switch p := pkt.(type) {
   179  	case *sshFxpOpenPacket:
   180  		request.Flags = p.Pflags
   181  	case *sshFxpSetstatPacket:
   182  		request.Flags = p.Flags
   183  		request.Attrs = p.Attrs.([]byte)
   184  	case *sshFxpRenamePacket:
   185  		request.Target = cleanPathWithBase(baseDir, p.Newpath)
   186  	case *sshFxpSymlinkPacket:
   187  		// NOTE: given a POSIX compliant signature: symlink(target, linkpath string)
   188  		// this makes Request.Target the linkpath, and Request.Filepath the target.
   189  		request.Target = cleanPathWithBase(baseDir, p.Linkpath)
   190  		request.Filepath = p.Targetpath
   191  	case *sshFxpExtendedPacketHardlink:
   192  		request.Target = cleanPathWithBase(baseDir, p.Newpath)
   193  	}
   194  	return request
   195  }
   196  
   197  // Context returns the request's context. To change the context,
   198  // use WithContext.
   199  //
   200  // The returned context is always non-nil; it defaults to the
   201  // background context.
   202  //
   203  // For incoming server requests, the context is canceled when the
   204  // request is complete or the client's connection closes.
   205  func (r *Request) Context() context.Context {
   206  	if r.ctx != nil {
   207  		return r.ctx
   208  	}
   209  	return context.Background()
   210  }
   211  
   212  // WithContext returns a copy of r with its context changed to ctx.
   213  // The provided ctx must be non-nil.
   214  func (r *Request) WithContext(ctx context.Context) *Request {
   215  	if ctx == nil {
   216  		panic("nil context")
   217  	}
   218  	r2 := r.copy()
   219  	r2.ctx = ctx
   220  	r2.cancelCtx = nil
   221  	return r2
   222  }
   223  
   224  // Close reader/writer if possible
   225  func (r *Request) close() error {
   226  	defer func() {
   227  		if r.cancelCtx != nil {
   228  			r.cancelCtx()
   229  		}
   230  	}()
   231  
   232  	rd, wr, rw := r.getAllReaderWriters()
   233  
   234  	var err error
   235  
   236  	// Close errors on a Writer are far more likely to be the important one.
   237  	// As they can be information that there was a loss of data.
   238  	if c, ok := wr.(io.Closer); ok {
   239  		if err2 := c.Close(); err == nil {
   240  			// update error if it is still nil
   241  			err = err2
   242  		}
   243  	}
   244  
   245  	if c, ok := rw.(io.Closer); ok {
   246  		if err2 := c.Close(); err == nil {
   247  			// update error if it is still nil
   248  			err = err2
   249  
   250  			r.setWriterAtReaderAt(nil)
   251  		}
   252  	}
   253  
   254  	if c, ok := rd.(io.Closer); ok {
   255  		if err2 := c.Close(); err == nil {
   256  			// update error if it is still nil
   257  			err = err2
   258  		}
   259  	}
   260  
   261  	return err
   262  }
   263  
   264  // Notify transfer error if any
   265  func (r *Request) transferError(err error) {
   266  	if err == nil {
   267  		return
   268  	}
   269  
   270  	rd, wr, rw := r.getAllReaderWriters()
   271  
   272  	if t, ok := wr.(TransferError); ok {
   273  		t.TransferError(err)
   274  	}
   275  
   276  	if t, ok := rw.(TransferError); ok {
   277  		t.TransferError(err)
   278  	}
   279  
   280  	if t, ok := rd.(TransferError); ok {
   281  		t.TransferError(err)
   282  	}
   283  }
   284  
   285  // called from worker to handle packet/request
   286  func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
   287  	switch r.Method {
   288  	case "Get":
   289  		return fileget(handlers.FileGet, r, pkt, alloc, orderID)
   290  	case "Put":
   291  		return fileput(handlers.FilePut, r, pkt, alloc, orderID)
   292  	case "Open":
   293  		return fileputget(handlers.FilePut, r, pkt, alloc, orderID)
   294  	case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS":
   295  		return filecmd(handlers.FileCmd, r, pkt)
   296  	case "List":
   297  		return filelist(handlers.FileList, r, pkt)
   298  	case "Stat", "Lstat":
   299  		return filestat(handlers.FileList, r, pkt)
   300  	case "Readlink":
   301  		if readlinkFileLister, ok := handlers.FileList.(ReadlinkFileLister); ok {
   302  			return readlink(readlinkFileLister, r, pkt)
   303  		}
   304  		return filestat(handlers.FileList, r, pkt)
   305  	default:
   306  		return statusFromError(pkt.id(), fmt.Errorf("unexpected method: %s", r.Method))
   307  	}
   308  }
   309  
   310  // Additional initialization for Open packets
   311  func (r *Request) open(h Handlers, pkt requestPacket) responsePacket {
   312  	flags := r.Pflags()
   313  
   314  	id := pkt.id()
   315  
   316  	switch {
   317  	case flags.Write, flags.Append, flags.Creat, flags.Trunc:
   318  		if flags.Read {
   319  			if openFileWriter, ok := h.FilePut.(OpenFileWriter); ok {
   320  				r.Method = "Open"
   321  				rw, err := openFileWriter.OpenFile(r)
   322  				if err != nil {
   323  					return statusFromError(id, err)
   324  				}
   325  
   326  				r.setWriterAtReaderAt(rw)
   327  
   328  				return &sshFxpHandlePacket{
   329  					ID:     id,
   330  					Handle: r.handle,
   331  				}
   332  			}
   333  		}
   334  
   335  		r.Method = "Put"
   336  		wr, err := h.FilePut.Filewrite(r)
   337  		if err != nil {
   338  			return statusFromError(id, err)
   339  		}
   340  
   341  		r.setWriterAt(wr)
   342  
   343  	case flags.Read:
   344  		r.Method = "Get"
   345  		rd, err := h.FileGet.Fileread(r)
   346  		if err != nil {
   347  			return statusFromError(id, err)
   348  		}
   349  
   350  		r.setReaderAt(rd)
   351  
   352  	default:
   353  		return statusFromError(id, errors.New("bad file flags"))
   354  	}
   355  
   356  	return &sshFxpHandlePacket{
   357  		ID:     id,
   358  		Handle: r.handle,
   359  	}
   360  }
   361  
   362  func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket {
   363  	r.Method = "List"
   364  	la, err := h.FileList.Filelist(r)
   365  	if err != nil {
   366  		return statusFromError(pkt.id(), wrapPathError(r.Filepath, err))
   367  	}
   368  
   369  	r.setListerAt(la)
   370  
   371  	return &sshFxpHandlePacket{
   372  		ID:     pkt.id(),
   373  		Handle: r.handle,
   374  	}
   375  }
   376  
   377  // wrap FileReader handler
   378  func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
   379  	rd := r.getReaderAt()
   380  	if rd == nil {
   381  		return statusFromError(pkt.id(), errors.New("unexpected read packet"))
   382  	}
   383  
   384  	data, offset, _ := packetData(pkt, alloc, orderID)
   385  
   386  	n, err := rd.ReadAt(data, offset)
   387  	// only return EOF error if no data left to read
   388  	if err != nil && (err != io.EOF || n == 0) {
   389  		return statusFromError(pkt.id(), err)
   390  	}
   391  
   392  	return &sshFxpDataPacket{
   393  		ID:     pkt.id(),
   394  		Length: uint32(n),
   395  		Data:   data[:n],
   396  	}
   397  }
   398  
   399  // wrap FileWriter handler
   400  func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
   401  	wr := r.getWriterAt()
   402  	if wr == nil {
   403  		return statusFromError(pkt.id(), errors.New("unexpected write packet"))
   404  	}
   405  
   406  	data, offset, _ := packetData(pkt, alloc, orderID)
   407  
   408  	_, err := wr.WriteAt(data, offset)
   409  	return statusFromError(pkt.id(), err)
   410  }
   411  
   412  // wrap OpenFileWriter handler
   413  func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
   414  	rw := r.getWriterAtReaderAt()
   415  	if rw == nil {
   416  		return statusFromError(pkt.id(), errors.New("unexpected write and read packet"))
   417  	}
   418  
   419  	switch p := pkt.(type) {
   420  	case *sshFxpReadPacket:
   421  		data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset)
   422  
   423  		n, err := rw.ReadAt(data, offset)
   424  		// only return EOF error if no data left to read
   425  		if err != nil && (err != io.EOF || n == 0) {
   426  			return statusFromError(pkt.id(), err)
   427  		}
   428  
   429  		return &sshFxpDataPacket{
   430  			ID:     pkt.id(),
   431  			Length: uint32(n),
   432  			Data:   data[:n],
   433  		}
   434  
   435  	case *sshFxpWritePacket:
   436  		data, offset := p.Data, int64(p.Offset)
   437  
   438  		_, err := rw.WriteAt(data, offset)
   439  		return statusFromError(pkt.id(), err)
   440  
   441  	default:
   442  		return statusFromError(pkt.id(), errors.New("unexpected packet type for read or write"))
   443  	}
   444  }
   445  
   446  // file data for additional read/write packets
   447  func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) {
   448  	switch p := p.(type) {
   449  	case *sshFxpReadPacket:
   450  		return p.getDataSlice(alloc, orderID), int64(p.Offset), p.Len
   451  	case *sshFxpWritePacket:
   452  		return p.Data, int64(p.Offset), p.Length
   453  	}
   454  	return
   455  }
   456  
   457  // wrap FileCmder handler
   458  func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket {
   459  	switch p := pkt.(type) {
   460  	case *sshFxpFsetstatPacket:
   461  		r.Flags = p.Flags
   462  		r.Attrs = p.Attrs.([]byte)
   463  	}
   464  
   465  	switch r.Method {
   466  	case "PosixRename":
   467  		if posixRenamer, ok := h.(PosixRenameFileCmder); ok {
   468  			err := posixRenamer.PosixRename(r)
   469  			return statusFromError(pkt.id(), err)
   470  		}
   471  
   472  		// PosixRenameFileCmder not implemented handle this request as a Rename
   473  		r.Method = "Rename"
   474  		err := h.Filecmd(r)
   475  		return statusFromError(pkt.id(), err)
   476  
   477  	case "StatVFS":
   478  		if statVFSCmdr, ok := h.(StatVFSFileCmder); ok {
   479  			stat, err := statVFSCmdr.StatVFS(r)
   480  			if err != nil {
   481  				return statusFromError(pkt.id(), err)
   482  			}
   483  			stat.ID = pkt.id()
   484  			return stat
   485  		}
   486  
   487  		return statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
   488  	}
   489  
   490  	err := h.Filecmd(r)
   491  	return statusFromError(pkt.id(), err)
   492  }
   493  
   494  // wrap FileLister handler
   495  func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket {
   496  	lister := r.getListerAt()
   497  	if lister == nil {
   498  		return statusFromError(pkt.id(), errors.New("unexpected dir packet"))
   499  	}
   500  
   501  	offset := r.lsNext()
   502  	finfo := make([]os.FileInfo, MaxFilelist)
   503  	n, err := lister.ListAt(finfo, offset)
   504  	r.lsInc(int64(n))
   505  	// ignore EOF as we only return it when there are no results
   506  	finfo = finfo[:n] // avoid need for nil tests below
   507  
   508  	switch r.Method {
   509  	case "List":
   510  		if err != nil && (err != io.EOF || n == 0) {
   511  			return statusFromError(pkt.id(), err)
   512  		}
   513  
   514  		nameAttrs := make([]*sshFxpNameAttr, 0, len(finfo))
   515  
   516  		// If the type conversion fails, we get untyped `nil`,
   517  		// which is handled by not looking up any names.
   518  		idLookup, _ := h.(NameLookupFileLister)
   519  
   520  		for _, fi := range finfo {
   521  			nameAttrs = append(nameAttrs, &sshFxpNameAttr{
   522  				Name:     fi.Name(),
   523  				LongName: runLs(idLookup, fi),
   524  				Attrs:    []interface{}{fi},
   525  			})
   526  		}
   527  
   528  		return &sshFxpNamePacket{
   529  			ID:        pkt.id(),
   530  			NameAttrs: nameAttrs,
   531  		}
   532  
   533  	default:
   534  		err = fmt.Errorf("unexpected method: %s", r.Method)
   535  		return statusFromError(pkt.id(), err)
   536  	}
   537  }
   538  
   539  func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket {
   540  	var lister ListerAt
   541  	var err error
   542  
   543  	if r.Method == "Lstat" {
   544  		if lstatFileLister, ok := h.(LstatFileLister); ok {
   545  			lister, err = lstatFileLister.Lstat(r)
   546  		} else {
   547  			// LstatFileLister not implemented handle this request as a Stat
   548  			r.Method = "Stat"
   549  			lister, err = h.Filelist(r)
   550  		}
   551  	} else {
   552  		lister, err = h.Filelist(r)
   553  	}
   554  	if err != nil {
   555  		return statusFromError(pkt.id(), err)
   556  	}
   557  	finfo := make([]os.FileInfo, 1)
   558  	n, err := lister.ListAt(finfo, 0)
   559  	finfo = finfo[:n] // avoid need for nil tests below
   560  
   561  	switch r.Method {
   562  	case "Stat", "Lstat":
   563  		if err != nil && err != io.EOF {
   564  			return statusFromError(pkt.id(), err)
   565  		}
   566  		if n == 0 {
   567  			err = &os.PathError{
   568  				Op:   strings.ToLower(r.Method),
   569  				Path: r.Filepath,
   570  				Err:  syscall.ENOENT,
   571  			}
   572  			return statusFromError(pkt.id(), err)
   573  		}
   574  		return &sshFxpStatResponse{
   575  			ID:   pkt.id(),
   576  			info: finfo[0],
   577  		}
   578  	case "Readlink":
   579  		if err != nil && err != io.EOF {
   580  			return statusFromError(pkt.id(), err)
   581  		}
   582  		if n == 0 {
   583  			err = &os.PathError{
   584  				Op:   "readlink",
   585  				Path: r.Filepath,
   586  				Err:  syscall.ENOENT,
   587  			}
   588  			return statusFromError(pkt.id(), err)
   589  		}
   590  		filename := finfo[0].Name()
   591  		return &sshFxpNamePacket{
   592  			ID: pkt.id(),
   593  			NameAttrs: []*sshFxpNameAttr{
   594  				{
   595  					Name:     filename,
   596  					LongName: filename,
   597  					Attrs:    emptyFileStat,
   598  				},
   599  			},
   600  		}
   601  	default:
   602  		err = fmt.Errorf("unexpected method: %s", r.Method)
   603  		return statusFromError(pkt.id(), err)
   604  	}
   605  }
   606  
   607  func readlink(readlinkFileLister ReadlinkFileLister, r *Request, pkt requestPacket) responsePacket {
   608  	resolved, err := readlinkFileLister.Readlink(r.Filepath)
   609  	if err != nil {
   610  		return statusFromError(pkt.id(), err)
   611  	}
   612  	return &sshFxpNamePacket{
   613  		ID: pkt.id(),
   614  		NameAttrs: []*sshFxpNameAttr{
   615  			{
   616  				Name:     resolved,
   617  				LongName: resolved,
   618  				Attrs:    emptyFileStat,
   619  			},
   620  		},
   621  	}
   622  }
   623  
   624  // init attributes of request object from packet data
   625  func requestMethod(p requestPacket) (method string) {
   626  	switch p.(type) {
   627  	case *sshFxpReadPacket, *sshFxpWritePacket, *sshFxpOpenPacket:
   628  		// set in open() above
   629  	case *sshFxpOpendirPacket, *sshFxpReaddirPacket:
   630  		// set in opendir() above
   631  	case *sshFxpSetstatPacket, *sshFxpFsetstatPacket:
   632  		method = "Setstat"
   633  	case *sshFxpRenamePacket:
   634  		method = "Rename"
   635  	case *sshFxpSymlinkPacket:
   636  		method = "Symlink"
   637  	case *sshFxpRemovePacket:
   638  		method = "Remove"
   639  	case *sshFxpStatPacket, *sshFxpFstatPacket:
   640  		method = "Stat"
   641  	case *sshFxpLstatPacket:
   642  		method = "Lstat"
   643  	case *sshFxpRmdirPacket:
   644  		method = "Rmdir"
   645  	case *sshFxpReadlinkPacket:
   646  		method = "Readlink"
   647  	case *sshFxpMkdirPacket:
   648  		method = "Mkdir"
   649  	case *sshFxpExtendedPacketHardlink:
   650  		method = "Link"
   651  	}
   652  	return method
   653  }