// Source: storj.io/minio@v0.0.0-20230509071714-0cbc90f649b1/pkg/rpc/server.go

// Copyright 2009 The Go Authors. All rights reserved.
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Copyright 2020 MinIO, Inc. All rights reserved.
// forked from https://github.com/gorilla/rpc/v2
// modified to be used with MinIO under Apache
// 2.0 license that can be found in the LICENSE file.

package rpc

import (
	"fmt"
	"net/http"
	"reflect"
	"strings"
)

// nilErrorValue is the reflect.Value of a nil error (the zero value of the
// error interface type). ServeHTTP uses it as the initial "no error" result
// before the validator or the service method has been called.
var nilErrorValue = reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())

// ----------------------------------------------------------------------------
// Codec
// ----------------------------------------------------------------------------

// Codec creates a CodecRequest to process each request.
type Codec interface {
	// NewRequest returns the CodecRequest used to handle the given HTTP request.
	NewRequest(*http.Request) CodecRequest
}

// CodecRequest decodes a request and encodes a response using a specific
// serialization scheme.
type CodecRequest interface {
	// Method reads the request and returns the RPC method name.
	Method() (string, error)
	// ReadRequest reads the request, filling the RPC method args.
	ReadRequest(interface{}) error
	// WriteResponse writes the response using the RPC method reply.
	WriteResponse(http.ResponseWriter, interface{})
	// WriteError writes an error produced by the server, using the given
	// HTTP status code.
	WriteError(w http.ResponseWriter, status int, err error)
}

// ----------------------------------------------------------------------------
// Server
// ----------------------------------------------------------------------------

// NewServer returns a new RPC server.
49 func NewServer() *Server { 50 return &Server{ 51 codecs: make(map[string]Codec), 52 services: new(serviceMap), 53 } 54 } 55 56 // RequestInfo contains all the information we pass to before/after functions 57 type RequestInfo struct { 58 Args reflect.Value 59 Method string 60 Error error 61 ResponseWriter http.ResponseWriter 62 Request *http.Request 63 StatusCode int 64 } 65 66 // Server serves registered RPC services using registered codecs. 67 type Server struct { 68 codecs map[string]Codec 69 services *serviceMap 70 interceptFunc func(i *RequestInfo) *http.Request 71 beforeFunc func(i *RequestInfo) 72 afterFunc func(i *RequestInfo) 73 validateFunc reflect.Value 74 } 75 76 // RegisterCodec adds a new codec to the server. 77 // 78 // Codecs are defined to process a given serialization scheme, e.g., JSON or 79 // XML. A codec is chosen based on the "Content-Type" header from the request, 80 // excluding the charset definition. 81 func (s *Server) RegisterCodec(codec Codec, contentType string) { 82 s.codecs[strings.ToLower(contentType)] = codec 83 } 84 85 // RegisterInterceptFunc registers the specified function as the function 86 // that will be called before every request. The function is allowed to intercept 87 // the request e.g. add values to the context. 88 // 89 // Note: Only one function can be registered, subsequent calls to this 90 // method will overwrite all the previous functions. 91 func (s *Server) RegisterInterceptFunc(f func(i *RequestInfo) *http.Request) { 92 s.interceptFunc = f 93 } 94 95 // RegisterBeforeFunc registers the specified function as the function 96 // that will be called before every request. 97 // 98 // Note: Only one function can be registered, subsequent calls to this 99 // method will overwrite all the previous functions. 
100 func (s *Server) RegisterBeforeFunc(f func(i *RequestInfo)) { 101 s.beforeFunc = f 102 } 103 104 // RegisterValidateRequestFunc registers the specified function as the function 105 // that will be called after the BeforeFunc (if registered) and before invoking 106 // the actual Service method. If this function returns a non-nil error, the method 107 // won't be invoked and this error will be considered as the method result. 108 // The first argument is information about the request, useful for accessing to http.Request.Context() 109 // The second argument of this function is the already-unmarshalled *args parameter of the method. 110 func (s *Server) RegisterValidateRequestFunc(f func(r *RequestInfo, i interface{}) error) { 111 s.validateFunc = reflect.ValueOf(f) 112 } 113 114 // RegisterAfterFunc registers the specified function as the function 115 // that will be called after every request 116 // 117 // Note: Only one function can be registered, subsequent calls to this 118 // method will overwrite all the previous functions. 119 func (s *Server) RegisterAfterFunc(f func(i *RequestInfo)) { 120 s.afterFunc = f 121 } 122 123 // RegisterService adds a new service to the server. 124 // 125 // The name parameter is optional: if empty it will be inferred from 126 // the receiver type name. 127 // 128 // Methods from the receiver will be extracted if these rules are satisfied: 129 // 130 // - The receiver is exported (begins with an upper case letter) or local 131 // (defined in the package registering the service). 132 // - The method name is exported. 133 // - The method has three arguments: *http.Request, *args, *reply. 134 // - All three arguments are pointers. 135 // - The second and third arguments are exported or local. 136 // - The method has return type error. 137 // 138 // All other methods are ignored. 
139 func (s *Server) RegisterService(receiver interface{}, name string) error { 140 return s.services.register(receiver, name) 141 } 142 143 // HasMethod returns true if the given method is registered. 144 // 145 // The method uses a dotted notation as in "Service.Method". 146 func (s *Server) HasMethod(method string) bool { 147 if _, _, err := s.services.get(method); err == nil { 148 return true 149 } 150 return false 151 } 152 153 // ServeHTTP 154 func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { 155 if r.Method != "POST" { 156 err := fmt.Errorf("rpc: POST method required, received %s", r.Method) 157 WriteError(w, http.StatusMethodNotAllowed, err.Error()) 158 // Call the registered After Function 159 if s.afterFunc != nil { 160 s.afterFunc(&RequestInfo{ 161 ResponseWriter: w, 162 Request: r, 163 Method: "Unknown." + r.Method, 164 StatusCode: http.StatusMethodNotAllowed, 165 }) 166 } 167 return 168 } 169 contentType := r.Header.Get("Content-Type") 170 idx := strings.Index(contentType, ";") 171 if idx != -1 { 172 contentType = contentType[:idx] 173 } 174 var codec Codec 175 if contentType == "" && len(s.codecs) == 1 { 176 // If Content-Type is not set and only one codec has been registered, 177 // then default to that codec. 178 for _, c := range s.codecs { 179 codec = c 180 } 181 } else if codec = s.codecs[strings.ToLower(contentType)]; codec == nil { 182 err := fmt.Errorf("rpc: unrecognized Content-Type: %s", contentType) 183 WriteError(w, http.StatusUnsupportedMediaType, err.Error()) 184 // Call the registered After Function 185 if s.afterFunc != nil { 186 s.afterFunc(&RequestInfo{ 187 ResponseWriter: w, 188 Request: r, 189 Method: "Unknown." + r.Method, 190 Error: err, 191 StatusCode: http.StatusUnsupportedMediaType, 192 }) 193 } 194 return 195 } 196 // Create a new codec request. 197 codecReq := codec.NewRequest(r) 198 // Get service method to be called. 
199 method, errMethod := codecReq.Method() 200 if errMethod != nil { 201 codecReq.WriteError(w, http.StatusBadRequest, errMethod) 202 if s.afterFunc != nil { 203 s.afterFunc(&RequestInfo{ 204 ResponseWriter: w, 205 Request: r, 206 Method: "Unknown." + r.Method, 207 Error: errMethod, 208 StatusCode: http.StatusBadRequest, 209 }) 210 } 211 return 212 } 213 serviceSpec, methodSpec, errGet := s.services.get(method) 214 if errGet != nil { 215 codecReq.WriteError(w, http.StatusBadRequest, errGet) 216 if s.afterFunc != nil { 217 s.afterFunc(&RequestInfo{ 218 ResponseWriter: w, 219 Request: r, 220 Method: method, 221 Error: errGet, 222 StatusCode: http.StatusBadRequest, 223 }) 224 } 225 return 226 } 227 // Decode the args. 228 args := reflect.New(methodSpec.argsType) 229 if errRead := codecReq.ReadRequest(args.Interface()); errRead != nil { 230 codecReq.WriteError(w, http.StatusBadRequest, errRead) 231 if s.afterFunc != nil { 232 s.afterFunc(&RequestInfo{ 233 ResponseWriter: w, 234 Request: r, 235 Method: method, 236 Error: errRead, 237 StatusCode: http.StatusBadRequest, 238 }) 239 } 240 return 241 } 242 243 // Call the registered Intercept Function 244 if s.interceptFunc != nil { 245 req := s.interceptFunc(&RequestInfo{ 246 Request: r, 247 Method: method, 248 }) 249 if req != nil { 250 r = req 251 } 252 } 253 254 requestInfo := &RequestInfo{ 255 Request: r, 256 Method: method, 257 } 258 259 // Call the registered Before Function 260 if s.beforeFunc != nil { 261 s.beforeFunc(requestInfo) 262 } 263 264 // Prepare the reply, we need it even if validation fails 265 reply := reflect.New(methodSpec.replyType) 266 errValue := []reflect.Value{nilErrorValue} 267 268 // Call the registered Validator Function 269 if s.validateFunc.IsValid() { 270 errValue = s.validateFunc.Call([]reflect.Value{reflect.ValueOf(requestInfo), args}) 271 } 272 273 // If still no errors after validation, call the method 274 if errValue[0].IsNil() { 275 errValue = 
methodSpec.method.Func.Call([]reflect.Value{ 276 serviceSpec.rcvr, 277 reflect.ValueOf(r), 278 args, 279 reply, 280 }) 281 } 282 283 // Extract the result to error if needed. 284 var errResult error 285 statusCode := http.StatusOK 286 errInter := errValue[0].Interface() 287 if errInter != nil { 288 statusCode = http.StatusBadRequest 289 errResult = errInter.(error) 290 } 291 292 // Prevents Internet Explorer from MIME-sniffing a response away 293 // from the declared content-type 294 w.Header().Set("x-content-type-options", "nosniff") 295 296 // Encode the response. 297 if errResult == nil { 298 codecReq.WriteResponse(w, reply.Interface()) 299 } else { 300 codecReq.WriteError(w, statusCode, errResult) 301 } 302 303 // Call the registered After Function 304 if s.afterFunc != nil { 305 s.afterFunc(&RequestInfo{ 306 Args: args, 307 ResponseWriter: w, 308 Request: r, 309 Method: method, 310 Error: errResult, 311 StatusCode: statusCode, 312 }) 313 } 314 } 315 316 func WriteError(w http.ResponseWriter, status int, msg string) { 317 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 318 w.WriteHeader(status) 319 fmt.Fprint(w, msg) 320 }