go-hep.org/x/hep@v0.38.1/xrootd/server.go (about)

     1  // Copyright ©2018 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package xrootd // import "go-hep.org/x/hep/xrootd"
     6  
     7  import (
     8  	"context"
     9  	"crypto/rand"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"reflect"
    15  	"sync"
    16  
    17  	"go-hep.org/x/hep/xrootd/internal/xrdenc"
    18  	"go-hep.org/x/hep/xrootd/xrdproto"
    19  	"go-hep.org/x/hep/xrootd/xrdproto/dirlist"
    20  	"go-hep.org/x/hep/xrootd/xrdproto/handshake"
    21  	"go-hep.org/x/hep/xrootd/xrdproto/login"
    22  	"go-hep.org/x/hep/xrootd/xrdproto/mkdir"
    23  	"go-hep.org/x/hep/xrootd/xrdproto/mv"
    24  	"go-hep.org/x/hep/xrootd/xrdproto/open"
    25  	"go-hep.org/x/hep/xrootd/xrdproto/ping"
    26  	"go-hep.org/x/hep/xrootd/xrdproto/protocol"
    27  	"go-hep.org/x/hep/xrootd/xrdproto/read"
    28  	"go-hep.org/x/hep/xrootd/xrdproto/rm"
    29  	"go-hep.org/x/hep/xrootd/xrdproto/rmdir"
    30  	"go-hep.org/x/hep/xrootd/xrdproto/stat"
    31  	xrdsync "go-hep.org/x/hep/xrootd/xrdproto/sync"
    32  	"go-hep.org/x/hep/xrootd/xrdproto/truncate"
    33  	"go-hep.org/x/hep/xrootd/xrdproto/write"
    34  	"go-hep.org/x/hep/xrootd/xrdproto/xrdclose"
    35  )
    36  
    37  // ErrServerClosed is returned by the Server's Serve method after a call to Shutdown.
    38  var ErrServerClosed = errors.New("xrootd: server closed")
    39  
    40  // ErrorHandler is the function which handles occurred error (e.g. logs it).
    41  type ErrorHandler func(error)
    42  
    43  // Server implements the XRootD server following protocol from http://xrootd.org.
    44  // The Server uses a Handler to handle incoming requests.
    45  // To listen for incoming connections, Serve method must be called.
    46  // It is possible to configure to listen on several ports simultaneously
    47  // by calling Serve with different net.Listeners.
    48  type Server struct {
    49  	handler      Handler
    50  	errorHandler ErrorHandler
    51  
    52  	mu        sync.Mutex
    53  	listeners []net.Listener
    54  
    55  	closedMu sync.RWMutex
    56  	closed   bool
    57  
    58  	connMu     sync.Mutex
    59  	activeConn map[net.Conn]struct{}
    60  }
    61  
    62  // NewServer creates a XRootD server which uses specified handler to handle requests
    63  // and errorHandler to handle errors. If the errorHandler is nil,
    64  // then a default error handler is used that does nothing.
    65  func NewServer(handler Handler, errorHandler ErrorHandler) *Server {
    66  	if errorHandler == nil {
    67  		errorHandler = func(error) {}
    68  	}
    69  	return &Server{
    70  		handler:      handler,
    71  		errorHandler: errorHandler,
    72  		activeConn:   make(map[net.Conn]struct{}),
    73  	}
    74  }
    75  
    76  // Shutdown stops Server and closes all listeners and active connections.
    77  // Shutdown returns the first non nil error while closing listeners and connections.
    78  func (s *Server) Shutdown(ctx context.Context) error {
    79  	var err error
    80  
    81  	s.closedMu.Lock()
    82  	s.closed = true
    83  	s.closedMu.Unlock()
    84  
    85  	s.mu.Lock()
    86  	defer s.mu.Unlock()
    87  	for i := range s.listeners {
    88  		if cerr := s.listeners[i].Close(); cerr != nil && err == nil {
    89  			err = cerr
    90  		}
    91  	}
    92  
    93  	// TODO: wait for active requests to be processed as long as ctx is not done.
    94  	s.connMu.Lock()
    95  	defer s.connMu.Unlock()
    96  	for conn := range s.activeConn {
    97  		if cerr := conn.Close(); cerr != nil && err == nil {
    98  			err = cerr
    99  		}
   100  	}
   101  	return err
   102  }
   103  
   104  // Serve accepts incoming connections on the Listener l, creating a
   105  // new service goroutine for each. The service goroutines read requests and
   106  // then call s.handler to handle them.
   107  func (s *Server) Serve(l net.Listener) error {
   108  	s.mu.Lock()
   109  	s.listeners = append(s.listeners, l)
   110  	s.mu.Unlock()
   111  	for {
   112  		conn, err := l.Accept()
   113  		if err != nil {
   114  			s.closedMu.RLock()
   115  			defer s.closedMu.RUnlock()
   116  			if s.closed {
   117  				return ErrServerClosed
   118  			}
   119  			return err
   120  		}
   121  
   122  		s.connMu.Lock()
   123  		s.activeConn[conn] = struct{}{}
   124  		s.connMu.Unlock()
   125  
   126  		go s.handleConnection(conn)
   127  	}
   128  }
   129  
   130  // handleConnection handles the client connection.
   131  // handleConnection reads the handshake and checks it correctness.
   132  // In case of success, main loop is started that reads requests and
   133  // handles them. Otherwise, connection is aborted.
   134  func (s *Server) handleConnection(conn net.Conn) {
   135  	defer conn.Close()
   136  	defer func() {
   137  		s.connMu.Lock()
   138  		delete(s.activeConn, conn)
   139  		s.connMu.Unlock()
   140  	}()
   141  
   142  	var sessionID [16]byte
   143  	if _, err := rand.Read(sessionID[:]); err != nil {
   144  		s.errorHandler(fmt.Errorf("could not read session ID: %w", err))
   145  	}
   146  	defer func() {
   147  		if err := s.handler.CloseSession(sessionID); err != nil {
   148  			s.errorHandler(fmt.Errorf("could not close session ID %q: %w", sessionID, err))
   149  		}
   150  	}()
   151  
   152  	if err := s.handleHandshake(conn); err != nil {
   153  		s.errorHandler(fmt.Errorf("could not handle handshake: %w", err))
   154  		// Abort the connection if the handshake was malformed.
   155  		return
   156  	}
   157  
   158  	for {
   159  		// We are using conn for read access only in that place
   160  		// and only once at time for each conn, so no additional
   161  		// serialization is needed.
   162  		reqData, err := xrdproto.ReadRequest(conn)
   163  		if err == io.EOF || err == io.ErrClosedPipe {
   164  			// Client closed the connection.
   165  			return
   166  		}
   167  		if err != nil {
   168  			s.closedMu.RLock()
   169  			defer s.closedMu.RUnlock()
   170  			// TODO: wait for active requests to be processed while closing.
   171  			if !s.closed {
   172  				s.errorHandler(fmt.Errorf("could not close connection: %w", err))
   173  			}
   174  			// Abort the connection if an error occurred during
   175  			// the reading phase because we can't recover from it.
   176  			return
   177  		}
   178  
   179  		// Performing a request may take some time so we are running it
   180  		// in the separate goroutine. We follow the XRootD protocol and
   181  		// write results back with StreamID provided in the request,
   182  		// so Client will match the responses to the corresponding request calls.
   183  		go func(req []byte) {
   184  			var (
   185  				reqHeader xrdproto.RequestHeader
   186  				resp      xrdproto.Marshaler
   187  				status    xrdproto.ResponseStatus
   188  			)
   189  
   190  			rBuffer := xrdenc.NewRBuffer(req)
   191  			if err := reqHeader.UnmarshalXrd(rBuffer); err != nil {
   192  				resp, status = newUnmarshalingErrorResponse(err)
   193  			} else {
   194  				resp, status = s.handleRequest(sessionID, reqHeader.RequestID, rBuffer)
   195  			}
   196  
   197  			if err := xrdproto.WriteResponse(conn, reqHeader.StreamID, status, resp); err != nil {
   198  				s.closedMu.RLock()
   199  				defer s.closedMu.RUnlock()
   200  				// TODO: wait for active requests to be processed while closing.
   201  				if !s.closed {
   202  					s.errorHandler(fmt.Errorf("could not close connection: %w", err))
   203  				}
   204  				// Abort the connection if an error occurred during
   205  				// the writing phase because we can't recover from it.
   206  				return
   207  			}
   208  		}(reqData)
   209  	}
   210  }
   211  
   212  func (s *Server) handleHandshake(conn net.Conn) error {
   213  	data := make([]byte, handshake.RequestLength)
   214  	if _, err := io.ReadFull(conn, data); err != nil {
   215  		return err
   216  	}
   217  
   218  	var req handshake.Request
   219  	rBuffer := xrdenc.NewRBuffer(data)
   220  	err := req.UnmarshalXrd(rBuffer)
   221  	if err != nil {
   222  		return err
   223  	}
   224  
   225  	correctHandshake := handshake.NewRequest()
   226  	if !reflect.DeepEqual(req, correctHandshake) {
   227  		return fmt.Errorf("xrootd: connection %v: wrong handshake\ngot = %v\nwant = %v", conn.RemoteAddr(), req, correctHandshake)
   228  	}
   229  
   230  	resp, status := s.handler.Handshake()
   231  	return xrdproto.WriteResponse(conn, xrdproto.StreamID{0, 0}, status, resp)
   232  }
   233  
   234  func newUnmarshalingErrorResponse(err error) (xrdproto.Marshaler, xrdproto.ResponseStatus) {
   235  	response := xrdproto.ServerError{
   236  		Code:    xrdproto.InvalidRequest,
   237  		Message: fmt.Errorf("xrootd: an error occurred while parsing the request: %w", err).Error(),
   238  	}
   239  	return response, xrdproto.Error
   240  }
   241  
   242  func (s *Server) handleRequest(sessionID [16]byte, requestID uint16, rBuffer *xrdenc.RBuffer) (xrdproto.Marshaler, xrdproto.ResponseStatus) {
   243  	switch requestID {
   244  	case login.RequestID:
   245  		var request login.Request
   246  		err := request.UnmarshalXrd(rBuffer)
   247  		if err != nil {
   248  			return newUnmarshalingErrorResponse(err)
   249  		}
   250  		return s.handler.Login(sessionID, &request)
   251  	case protocol.RequestID:
   252  		var request protocol.Request
   253  		err := request.UnmarshalXrd(rBuffer)
   254  		if err != nil {
   255  			return newUnmarshalingErrorResponse(err)
   256  		}
   257  		return s.handler.Protocol(sessionID, &request)
   258  	case dirlist.RequestID:
   259  		var request dirlist.Request
   260  		err := request.UnmarshalXrd(rBuffer)
   261  		if err != nil {
   262  			return newUnmarshalingErrorResponse(err)
   263  		}
   264  		return s.handler.Dirlist(sessionID, &request)
   265  	case open.RequestID:
   266  		var request open.Request
   267  		err := request.UnmarshalXrd(rBuffer)
   268  		if err != nil {
   269  			return newUnmarshalingErrorResponse(err)
   270  		}
   271  		return s.handler.Open(sessionID, &request)
   272  	case xrdclose.RequestID:
   273  		var request xrdclose.Request
   274  		err := request.UnmarshalXrd(rBuffer)
   275  		if err != nil {
   276  			return newUnmarshalingErrorResponse(err)
   277  		}
   278  		return s.handler.Close(sessionID, &request)
   279  	case read.RequestID:
   280  		var request read.Request
   281  		err := request.UnmarshalXrd(rBuffer)
   282  		if err != nil {
   283  			return newUnmarshalingErrorResponse(err)
   284  		}
   285  		return s.handler.Read(sessionID, &request)
   286  	case write.RequestID:
   287  		var request write.Request
   288  		err := request.UnmarshalXrd(rBuffer)
   289  		if err != nil {
   290  			return newUnmarshalingErrorResponse(err)
   291  		}
   292  		return s.handler.Write(sessionID, &request)
   293  	case stat.RequestID:
   294  		var request stat.Request
   295  		err := request.UnmarshalXrd(rBuffer)
   296  		if err != nil {
   297  			return newUnmarshalingErrorResponse(err)
   298  		}
   299  		return s.handler.Stat(sessionID, &request)
   300  	case xrdsync.RequestID:
   301  		var request xrdsync.Request
   302  		err := request.UnmarshalXrd(rBuffer)
   303  		if err != nil {
   304  			return newUnmarshalingErrorResponse(err)
   305  		}
   306  		return s.handler.Sync(sessionID, &request)
   307  	case truncate.RequestID:
   308  		var request truncate.Request
   309  		err := request.UnmarshalXrd(rBuffer)
   310  		if err != nil {
   311  			return newUnmarshalingErrorResponse(err)
   312  		}
   313  		return s.handler.Truncate(sessionID, &request)
   314  	case mv.RequestID:
   315  		var request mv.Request
   316  		err := request.UnmarshalXrd(rBuffer)
   317  		if err != nil {
   318  			return newUnmarshalingErrorResponse(err)
   319  		}
   320  		return s.handler.Rename(sessionID, &request)
   321  	case mkdir.RequestID:
   322  		var request mkdir.Request
   323  		err := request.UnmarshalXrd(rBuffer)
   324  		if err != nil {
   325  			return newUnmarshalingErrorResponse(err)
   326  		}
   327  		return s.handler.Mkdir(sessionID, &request)
   328  	case ping.RequestID:
   329  		var request ping.Request
   330  		err := request.UnmarshalXrd(rBuffer)
   331  		if err != nil {
   332  			return newUnmarshalingErrorResponse(err)
   333  		}
   334  		return s.handler.Ping(sessionID, &request)
   335  	case rm.RequestID:
   336  		var request rm.Request
   337  		err := request.UnmarshalXrd(rBuffer)
   338  		if err != nil {
   339  			return newUnmarshalingErrorResponse(err)
   340  		}
   341  		return s.handler.Remove(sessionID, &request)
   342  	case rmdir.RequestID:
   343  		var request rmdir.Request
   344  		err := request.UnmarshalXrd(rBuffer)
   345  		if err != nil {
   346  			return newUnmarshalingErrorResponse(err)
   347  		}
   348  		return s.handler.RemoveDir(sessionID, &request)
   349  	default:
   350  		response := xrdproto.ServerError{
   351  			Code:    xrdproto.InvalidRequest,
   352  			Message: fmt.Sprintf("Unknown request id: %d", requestID),
   353  		}
   354  		return response, xrdproto.Error
   355  	}
   356  }