storj.io/minio@v0.0.0-20230509071714-0cbc90f649b1/pkg/rpc/server.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Copyright 2012 The Gorilla Authors. All rights reserved.
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  // Copyright 2020 MinIO, Inc. All rights reserved.
     7  // forked from https://github.com/gorilla/rpc/v2
     8  // modified to be used with MinIO under Apache
     9  // 2.0 license that can be found in the LICENSE file.
    10  
    11  package rpc
    12  
    13  import (
    14  	"fmt"
    15  	"net/http"
    16  	"reflect"
    17  	"strings"
    18  )
    19  
    20  var nilErrorValue = reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())
    21  
    22  // ----------------------------------------------------------------------------
    23  // Codec
    24  // ----------------------------------------------------------------------------
    25  
    26  // Codec creates a CodecRequest to process each request.
    27  type Codec interface {
    28  	NewRequest(*http.Request) CodecRequest
    29  }
    30  
    31  // CodecRequest decodes a request and encodes a response using a specific
    32  // serialization scheme.
    33  type CodecRequest interface {
    34  	// Reads the request and returns the RPC method name.
    35  	Method() (string, error)
    36  	// Reads the request filling the RPC method args.
    37  	ReadRequest(interface{}) error
    38  	// Writes the response using the RPC method reply.
    39  	WriteResponse(http.ResponseWriter, interface{})
    40  	// Writes an error produced by the server.
    41  	WriteError(w http.ResponseWriter, status int, err error)
    42  }
    43  
    44  // ----------------------------------------------------------------------------
    45  // Server
    46  // ----------------------------------------------------------------------------
    47  
    48  // NewServer returns a new RPC server.
    49  func NewServer() *Server {
    50  	return &Server{
    51  		codecs:   make(map[string]Codec),
    52  		services: new(serviceMap),
    53  	}
    54  }
    55  
// RequestInfo carries per-request data to the hooks registered on a Server:
// the intercept, before, validate and after functions.
//
// How ServeHTTP fills it in:
//   - Method: the method name returned by the codec, or "Unknown.<HTTP method>"
//     when the request was rejected before a method name was resolved.
//   - Request: the incoming request. If the intercept function returned a
//     replacement request, later hooks see that replacement.
//   - Args, ResponseWriter, Error and StatusCode are only set on the value
//     handed to the after function. Args is set only once the arguments
//     were decoded successfully.
//   - StatusCode is the HTTP status written to the client.
//   - Error is the failure reported for the request, if any.
type RequestInfo struct {
	Args           reflect.Value
	Method         string
	Error          error
	ResponseWriter http.ResponseWriter
	Request        *http.Request
	StatusCode     int
}
    65  
// Server serves registered RPC services using registered codecs.
//
// Internal state:
//   - codecs maps a lower-cased content type, as passed to RegisterCodec,
//     to the Codec that handles it.
//   - services holds the receivers added with RegisterService and resolves
//     method names to them.
//   - interceptFunc, beforeFunc, afterFunc and validateFunc are the optional
//     hooks set by the Register*Func methods. validateFunc is kept as a
//     reflect.Value and invoked through reflection.
//
// NOTE(review): RegisterCodec writes the codecs map without any locking while
// ServeHTTP reads it, so codecs should be registered before serving starts.
type Server struct {
	codecs        map[string]Codec
	services      *serviceMap
	interceptFunc func(i *RequestInfo) *http.Request
	beforeFunc    func(i *RequestInfo)
	afterFunc     func(i *RequestInfo)
	validateFunc  reflect.Value
}
    75  
    76  // RegisterCodec adds a new codec to the server.
    77  //
    78  // Codecs are defined to process a given serialization scheme, e.g., JSON or
    79  // XML. A codec is chosen based on the "Content-Type" header from the request,
    80  // excluding the charset definition.
    81  func (s *Server) RegisterCodec(codec Codec, contentType string) {
    82  	s.codecs[strings.ToLower(contentType)] = codec
    83  }
    84  
    85  // RegisterInterceptFunc registers the specified function as the function
    86  // that will be called before every request. The function is allowed to intercept
    87  // the request e.g. add values to the context.
    88  //
    89  // Note: Only one function can be registered, subsequent calls to this
    90  // method will overwrite all the previous functions.
    91  func (s *Server) RegisterInterceptFunc(f func(i *RequestInfo) *http.Request) {
    92  	s.interceptFunc = f
    93  }
    94  
    95  // RegisterBeforeFunc registers the specified function as the function
    96  // that will be called before every request.
    97  //
    98  // Note: Only one function can be registered, subsequent calls to this
    99  // method will overwrite all the previous functions.
   100  func (s *Server) RegisterBeforeFunc(f func(i *RequestInfo)) {
   101  	s.beforeFunc = f
   102  }
   103  
   104  // RegisterValidateRequestFunc registers the specified function as the function
   105  // that will be called after the BeforeFunc (if registered) and before invoking
   106  // the actual Service method. If this function returns a non-nil error, the method
   107  // won't be invoked and this error will be considered as the method result.
   108  // The first argument is information about the request, useful for accessing to http.Request.Context()
   109  // The second argument of this function is the already-unmarshalled *args parameter of the method.
   110  func (s *Server) RegisterValidateRequestFunc(f func(r *RequestInfo, i interface{}) error) {
   111  	s.validateFunc = reflect.ValueOf(f)
   112  }
   113  
   114  // RegisterAfterFunc registers the specified function as the function
   115  // that will be called after every request
   116  //
   117  // Note: Only one function can be registered, subsequent calls to this
   118  // method will overwrite all the previous functions.
   119  func (s *Server) RegisterAfterFunc(f func(i *RequestInfo)) {
   120  	s.afterFunc = f
   121  }
   122  
   123  // RegisterService adds a new service to the server.
   124  //
   125  // The name parameter is optional: if empty it will be inferred from
   126  // the receiver type name.
   127  //
   128  // Methods from the receiver will be extracted if these rules are satisfied:
   129  //
   130  //    - The receiver is exported (begins with an upper case letter) or local
   131  //      (defined in the package registering the service).
   132  //    - The method name is exported.
   133  //    - The method has three arguments: *http.Request, *args, *reply.
   134  //    - All three arguments are pointers.
   135  //    - The second and third arguments are exported or local.
   136  //    - The method has return type error.
   137  //
   138  // All other methods are ignored.
   139  func (s *Server) RegisterService(receiver interface{}, name string) error {
   140  	return s.services.register(receiver, name)
   141  }
   142  
   143  // HasMethod returns true if the given method is registered.
   144  //
   145  // The method uses a dotted notation as in "Service.Method".
   146  func (s *Server) HasMethod(method string) bool {
   147  	if _, _, err := s.services.get(method); err == nil {
   148  		return true
   149  	}
   150  	return false
   151  }
   152  
   153  // ServeHTTP
   154  func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   155  	if r.Method != "POST" {
   156  		err := fmt.Errorf("rpc: POST method required, received %s", r.Method)
   157  		WriteError(w, http.StatusMethodNotAllowed, err.Error())
   158  		// Call the registered After Function
   159  		if s.afterFunc != nil {
   160  			s.afterFunc(&RequestInfo{
   161  				ResponseWriter: w,
   162  				Request:        r,
   163  				Method:         "Unknown." + r.Method,
   164  				StatusCode:     http.StatusMethodNotAllowed,
   165  			})
   166  		}
   167  		return
   168  	}
   169  	contentType := r.Header.Get("Content-Type")
   170  	idx := strings.Index(contentType, ";")
   171  	if idx != -1 {
   172  		contentType = contentType[:idx]
   173  	}
   174  	var codec Codec
   175  	if contentType == "" && len(s.codecs) == 1 {
   176  		// If Content-Type is not set and only one codec has been registered,
   177  		// then default to that codec.
   178  		for _, c := range s.codecs {
   179  			codec = c
   180  		}
   181  	} else if codec = s.codecs[strings.ToLower(contentType)]; codec == nil {
   182  		err := fmt.Errorf("rpc: unrecognized Content-Type: %s", contentType)
   183  		WriteError(w, http.StatusUnsupportedMediaType, err.Error())
   184  		// Call the registered After Function
   185  		if s.afterFunc != nil {
   186  			s.afterFunc(&RequestInfo{
   187  				ResponseWriter: w,
   188  				Request:        r,
   189  				Method:         "Unknown." + r.Method,
   190  				Error:          err,
   191  				StatusCode:     http.StatusUnsupportedMediaType,
   192  			})
   193  		}
   194  		return
   195  	}
   196  	// Create a new codec request.
   197  	codecReq := codec.NewRequest(r)
   198  	// Get service method to be called.
   199  	method, errMethod := codecReq.Method()
   200  	if errMethod != nil {
   201  		codecReq.WriteError(w, http.StatusBadRequest, errMethod)
   202  		if s.afterFunc != nil {
   203  			s.afterFunc(&RequestInfo{
   204  				ResponseWriter: w,
   205  				Request:        r,
   206  				Method:         "Unknown." + r.Method,
   207  				Error:          errMethod,
   208  				StatusCode:     http.StatusBadRequest,
   209  			})
   210  		}
   211  		return
   212  	}
   213  	serviceSpec, methodSpec, errGet := s.services.get(method)
   214  	if errGet != nil {
   215  		codecReq.WriteError(w, http.StatusBadRequest, errGet)
   216  		if s.afterFunc != nil {
   217  			s.afterFunc(&RequestInfo{
   218  				ResponseWriter: w,
   219  				Request:        r,
   220  				Method:         method,
   221  				Error:          errGet,
   222  				StatusCode:     http.StatusBadRequest,
   223  			})
   224  		}
   225  		return
   226  	}
   227  	// Decode the args.
   228  	args := reflect.New(methodSpec.argsType)
   229  	if errRead := codecReq.ReadRequest(args.Interface()); errRead != nil {
   230  		codecReq.WriteError(w, http.StatusBadRequest, errRead)
   231  		if s.afterFunc != nil {
   232  			s.afterFunc(&RequestInfo{
   233  				ResponseWriter: w,
   234  				Request:        r,
   235  				Method:         method,
   236  				Error:          errRead,
   237  				StatusCode:     http.StatusBadRequest,
   238  			})
   239  		}
   240  		return
   241  	}
   242  
   243  	// Call the registered Intercept Function
   244  	if s.interceptFunc != nil {
   245  		req := s.interceptFunc(&RequestInfo{
   246  			Request: r,
   247  			Method:  method,
   248  		})
   249  		if req != nil {
   250  			r = req
   251  		}
   252  	}
   253  
   254  	requestInfo := &RequestInfo{
   255  		Request: r,
   256  		Method:  method,
   257  	}
   258  
   259  	// Call the registered Before Function
   260  	if s.beforeFunc != nil {
   261  		s.beforeFunc(requestInfo)
   262  	}
   263  
   264  	// Prepare the reply, we need it even if validation fails
   265  	reply := reflect.New(methodSpec.replyType)
   266  	errValue := []reflect.Value{nilErrorValue}
   267  
   268  	// Call the registered Validator Function
   269  	if s.validateFunc.IsValid() {
   270  		errValue = s.validateFunc.Call([]reflect.Value{reflect.ValueOf(requestInfo), args})
   271  	}
   272  
   273  	// If still no errors after validation, call the method
   274  	if errValue[0].IsNil() {
   275  		errValue = methodSpec.method.Func.Call([]reflect.Value{
   276  			serviceSpec.rcvr,
   277  			reflect.ValueOf(r),
   278  			args,
   279  			reply,
   280  		})
   281  	}
   282  
   283  	// Extract the result to error if needed.
   284  	var errResult error
   285  	statusCode := http.StatusOK
   286  	errInter := errValue[0].Interface()
   287  	if errInter != nil {
   288  		statusCode = http.StatusBadRequest
   289  		errResult = errInter.(error)
   290  	}
   291  
   292  	// Prevents Internet Explorer from MIME-sniffing a response away
   293  	// from the declared content-type
   294  	w.Header().Set("x-content-type-options", "nosniff")
   295  
   296  	// Encode the response.
   297  	if errResult == nil {
   298  		codecReq.WriteResponse(w, reply.Interface())
   299  	} else {
   300  		codecReq.WriteError(w, statusCode, errResult)
   301  	}
   302  
   303  	// Call the registered After Function
   304  	if s.afterFunc != nil {
   305  		s.afterFunc(&RequestInfo{
   306  			Args:           args,
   307  			ResponseWriter: w,
   308  			Request:        r,
   309  			Method:         method,
   310  			Error:          errResult,
   311  			StatusCode:     statusCode,
   312  		})
   313  	}
   314  }
   315  
   316  func WriteError(w http.ResponseWriter, status int, msg string) {
   317  	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   318  	w.WriteHeader(status)
   319  	fmt.Fprint(w, msg)
   320  }