github.com/pkg/sftp@v1.13.6/request-server.go (about)

     1  package sftp
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"path"
     8  	"path/filepath"
     9  	"strconv"
    10  	"sync"
    11  )
    12  
    13  var maxTxPacket uint32 = 1 << 15
    14  
// Handlers contains the 4 SFTP server request handlers.
//
// The RequestServer passes the whole set to every Request it dispatches
// (see packetWorker). FileList is additionally type-asserted against
// RealPathFileLister and legacyRealPathFileLister to serve realpath
// requests; when it implements neither, paths are resolved against the
// server's start directory instead.
type Handlers struct {
	FileGet  FileReader
	FilePut  FileWriter
	FileCmd  FileCmder
	FileList FileLister
}
    22  
// RequestServer abstracts the sftp protocol with an http request-like protocol
//
// Create one with NewRequestServer, run it with Serve, and stop it with Close,
// which closes the underlying connection.
type RequestServer struct {
	// Handlers is handed to every Request the server dispatches.
	Handlers Handlers

	// Embedded connection state; provides conn and recvPacket.
	*serverConn
	// pktMgr assigns order IDs, feeds worker channels (workerChan) and
	// queues responses (readyPacket).
	pktMgr *packetManager

	// startDirectory is the base for resolving relative client paths;
	// "/" unless overridden with WithStartDirectory.
	startDirectory string

	// mu guards handleCount and openRequests.
	mu           sync.RWMutex
	handleCount  int
	openRequests map[string]*Request
}
    36  
// A RequestServerOption is a function which applies configuration to a RequestServer.
//
// NewRequestServer applies options in order after setting its defaults, so
// when two options touch the same field the later one wins.
type RequestServerOption func(*RequestServer)
    39  
    40  // WithRSAllocator enable the allocator.
    41  // After processing a packet we keep in memory the allocated slices
    42  // and we reuse them for new packets.
    43  // The allocator is experimental
    44  func WithRSAllocator() RequestServerOption {
    45  	return func(rs *RequestServer) {
    46  		alloc := newAllocator()
    47  		rs.pktMgr.alloc = alloc
    48  		rs.conn.alloc = alloc
    49  	}
    50  }
    51  
    52  // WithStartDirectory sets a start directory to use as base for relative paths.
    53  // If unset the default is "/"
    54  func WithStartDirectory(startDirectory string) RequestServerOption {
    55  	return func(rs *RequestServer) {
    56  		rs.startDirectory = cleanPath(startDirectory)
    57  	}
    58  }
    59  
    60  // NewRequestServer creates/allocates/returns new RequestServer.
    61  // Normally there will be one server per user-session.
    62  func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
    63  	svrConn := &serverConn{
    64  		conn: conn{
    65  			Reader:      rwc,
    66  			WriteCloser: rwc,
    67  		},
    68  	}
    69  	rs := &RequestServer{
    70  		Handlers: h,
    71  
    72  		serverConn: svrConn,
    73  		pktMgr:     newPktMgr(svrConn),
    74  
    75  		startDirectory: "/",
    76  
    77  		openRequests: make(map[string]*Request),
    78  	}
    79  
    80  	for _, o := range options {
    81  		o(rs)
    82  	}
    83  	return rs
    84  }
    85  
    86  // New Open packet/Request
    87  func (rs *RequestServer) nextRequest(r *Request) string {
    88  	rs.mu.Lock()
    89  	defer rs.mu.Unlock()
    90  
    91  	rs.handleCount++
    92  
    93  	r.handle = strconv.Itoa(rs.handleCount)
    94  	rs.openRequests[r.handle] = r
    95  
    96  	return r.handle
    97  }
    98  
    99  // Returns Request from openRequests, bool is false if it is missing.
   100  //
   101  // The Requests in openRequests work essentially as open file descriptors that
   102  // you can do different things with. What you are doing with it are denoted by
   103  // the first packet of that type (read/write/etc).
   104  func (rs *RequestServer) getRequest(handle string) (*Request, bool) {
   105  	rs.mu.RLock()
   106  	defer rs.mu.RUnlock()
   107  
   108  	r, ok := rs.openRequests[handle]
   109  	return r, ok
   110  }
   111  
   112  // Close the Request and clear from openRequests map
   113  func (rs *RequestServer) closeRequest(handle string) error {
   114  	rs.mu.Lock()
   115  	defer rs.mu.Unlock()
   116  
   117  	if r, ok := rs.openRequests[handle]; ok {
   118  		delete(rs.openRequests, handle)
   119  		return r.close()
   120  	}
   121  
   122  	return EBADF
   123  }
   124  
   125  // Close the read/write/closer to trigger exiting the main server loop
   126  func (rs *RequestServer) Close() error { return rs.conn.Close() }
   127  
   128  func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
   129  	defer close(pktChan) // shuts down sftpServerWorkers
   130  
   131  	var err error
   132  	var pkt requestPacket
   133  	var pktType uint8
   134  	var pktBytes []byte
   135  
   136  	for {
   137  		pktType, pktBytes, err = rs.serverConn.recvPacket(rs.pktMgr.getNextOrderID())
   138  		if err != nil {
   139  			// we don't care about releasing allocated pages here, the server will quit and the allocator freed
   140  			return err
   141  		}
   142  
   143  		pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
   144  		if err != nil {
   145  			switch {
   146  			case errors.Is(err, errUnknownExtendedPacket):
   147  				// do nothing
   148  			default:
   149  				debug("makePacket err: %v", err)
   150  				rs.conn.Close() // shuts down recvPacket
   151  				return err
   152  			}
   153  		}
   154  
   155  		pktChan <- rs.pktMgr.newOrderedRequest(pkt)
   156  	}
   157  }
   158  
   159  // Serve requests for user session
   160  func (rs *RequestServer) Serve() error {
   161  	defer func() {
   162  		if rs.pktMgr.alloc != nil {
   163  			rs.pktMgr.alloc.Free()
   164  		}
   165  	}()
   166  
   167  	ctx, cancel := context.WithCancel(context.Background())
   168  	defer cancel()
   169  
   170  	var wg sync.WaitGroup
   171  	runWorker := func(ch chan orderedRequest) {
   172  		wg.Add(1)
   173  		go func() {
   174  			defer wg.Done()
   175  			if err := rs.packetWorker(ctx, ch); err != nil {
   176  				rs.conn.Close() // shuts down recvPacket
   177  			}
   178  		}()
   179  	}
   180  	pktChan := rs.pktMgr.workerChan(runWorker)
   181  
   182  	err := rs.serveLoop(pktChan)
   183  
   184  	wg.Wait() // wait for all workers to exit
   185  
   186  	rs.mu.Lock()
   187  	defer rs.mu.Unlock()
   188  
   189  	// make sure all open requests are properly closed
   190  	// (eg. possible on dropped connections, client crashes, etc.)
   191  	for handle, req := range rs.openRequests {
   192  		if err == io.EOF {
   193  			err = io.ErrUnexpectedEOF
   194  		}
   195  		req.transferError(err)
   196  
   197  		delete(rs.openRequests, handle)
   198  		req.close()
   199  	}
   200  
   201  	return err
   202  }
   203  
// packetWorker drains pktChan. It dispatches each request packet to the
// matching handler and queues the response with the packet manager under the
// packet's order ID. It returns once pktChan is closed (serveLoop closes it
// on exit); as written it always returns nil.
//
// The order of the type-switch cases matters. Concrete packet types are
// matched first; the interface cases hasHandle and hasPath then act as
// catch-alls for every remaining handle-based or path-based packet.
func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedRequest) error {
	for pkt := range pktChan {
		orderID := pkt.orderID()
		// Unwrap recognised extended requests so the switch below can match
		// the concrete extension type. Unrecognised ones keep the generic
		// wrapper and presumably land in the default (unsupported) case.
		if epkt, ok := pkt.requestPacket.(*sshFxpExtendedPacket); ok {
			if epkt.SpecificPacket != nil {
				pkt.requestPacket = epkt.SpecificPacket
			}
		}

		var rpkt responsePacket
		switch pkt := pkt.requestPacket.(type) {
		case *sshFxInitPacket:
			rpkt = &sshFxVersionPacket{Version: sftpProtocolVersion, Extensions: sftpExtensions}
		case *sshFxpClosePacket:
			handle := pkt.getHandle()
			rpkt = statusFromError(pkt.ID, rs.closeRequest(handle))
		case *sshFxpRealpathPacket:
			// Prefer a handler-provided RealPath (the error-returning variant
			// first, then the legacy one); otherwise resolve the path against
			// the start directory.
			var realPath string
			var err error

			switch pather := rs.Handlers.FileList.(type) {
			case RealPathFileLister:
				realPath, err = pather.RealPath(pkt.getPath())
			case legacyRealPathFileLister:
				realPath = pather.RealPath(pkt.getPath())
			default:
				realPath = cleanPathWithBase(rs.startDirectory, pkt.getPath())
			}
			if err != nil {
				rpkt = statusFromError(pkt.ID, err)
			} else {
				rpkt = cleanPacketPath(pkt, realPath)
			}
		case *sshFxpOpendirPacket:
			// Register first so the request has a handle for the reply.
			request := requestFromPacket(ctx, pkt, rs.startDirectory)
			handle := rs.nextRequest(request)
			rpkt = request.opendir(rs.Handlers, pkt)
			if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
				// if we return an error we have to remove the handle from the active ones
				// (the close error is ignored; rpkt already reports the failure)
				rs.closeRequest(handle)
			}
		case *sshFxpOpenPacket:
			// Same register-then-undo-on-failure pattern as Opendir.
			request := requestFromPacket(ctx, pkt, rs.startDirectory)
			handle := rs.nextRequest(request)
			rpkt = request.open(rs.Handlers, pkt)
			if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
				// if we return an error we have to remove the handle from the active ones
				rs.closeRequest(handle)
			}
		case *sshFxpFstatPacket:
			// Fstat is served as a path-based Stat on the path the handle was
			// opened with.
			// NOTE(review): this Request literal (like those below for
			// Fsetstat, PosixRename and StatVFS) sets only its path fields.
			// Unlike requestFromPacket it is not given ctx; confirm handlers
			// don't need cancellation here.
			handle := pkt.getHandle()
			request, ok := rs.getRequest(handle)
			if !ok {
				rpkt = statusFromError(pkt.ID, EBADF)
			} else {
				request = &Request{
					Method:   "Stat",
					Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
				}
				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
			}
		case *sshFxpFsetstatPacket:
			// Like Fstat: forwarded as a path-based Setstat.
			handle := pkt.getHandle()
			request, ok := rs.getRequest(handle)
			if !ok {
				rpkt = statusFromError(pkt.ID, EBADF)
			} else {
				request = &Request{
					Method:   "Setstat",
					Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
				}
				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
			}
		case *sshFxpExtendedPacketPosixRename:
			request := &Request{
				Method:   "PosixRename",
				Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath),
				Target:   cleanPathWithBase(rs.startDirectory, pkt.Newpath),
			}
			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
		case *sshFxpExtendedPacketStatVFS:
			request := &Request{
				Method:   "StatVFS",
				Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path),
			}
			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
		case hasHandle:
			// Every other handle-based packet runs against the Request that
			// was opened under that handle.
			handle := pkt.getHandle()
			request, ok := rs.getRequest(handle)
			if !ok {
				rpkt = statusFromError(pkt.id(), EBADF)
			} else {
				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
			}
		case hasPath:
			// Remaining path-based packets get a one-shot Request that is
			// closed as soon as the call returns.
			request := requestFromPacket(ctx, pkt, rs.startDirectory)
			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
			request.close()
		default:
			rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
		}

		rs.pktMgr.readyPacket(
			rs.pktMgr.newOrderedResponse(rpkt, orderID))
	}
	return nil
}
   311  
   312  // clean and return name packet for file
   313  func cleanPacketPath(pkt *sshFxpRealpathPacket, realPath string) responsePacket {
   314  	return &sshFxpNamePacket{
   315  		ID: pkt.id(),
   316  		NameAttrs: []*sshFxpNameAttr{
   317  			{
   318  				Name:     realPath,
   319  				LongName: realPath,
   320  				Attrs:    emptyFileStat,
   321  			},
   322  		},
   323  	}
   324  }
   325  
   326  // Makes sure we have a clean POSIX (/) absolute path to work with
   327  func cleanPath(p string) string {
   328  	return cleanPathWithBase("/", p)
   329  }
   330  
   331  func cleanPathWithBase(base, p string) string {
   332  	p = filepath.ToSlash(filepath.Clean(p))
   333  	if !path.IsAbs(p) {
   334  		return path.Join(base, p)
   335  	}
   336  	return p
   337  }