// Copyright (c) 2016 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package httpjson

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"reflect"
	"strconv"
	"strings"

	xerrors "github.com/m3db/m3/src/x/errors"
	"github.com/m3db/m3/src/x/headers"

	apachethrift "github.com/apache/thrift/lib/go/thrift"
	"github.com/uber/tchannel-go/thrift"
)

// Sentinel errors returned to clients that call a handler with the wrong HTTP
// method. Both are wrapped as invalid-params errors, which writeError reports
// with HTTP 400.
var (
	// errRequestMustBeGet is returned when a method that takes no request
	// object is invoked with an HTTP method other than GET.
	errRequestMustBeGet = xerrors.NewInvalidParamsError(errors.New("request without request params must be GET"))
	// errRequestMustBePost is returned when a method that takes a request
	// object is invoked with an HTTP method other than POST.
	errRequestMustBePost = xerrors.NewInvalidParamsError(errors.New("request with request params must be POST"))
)

// Error is an HTTP JSON error that also sets a return status code. When a
// handler fails with an Error, writeError responds with StatusCode() instead
// of the default 400/500 mapping.
46 type Error interface { 47 error 48 49 StatusCode() int 50 } 51 52 type errorType struct { 53 error 54 statusCode int 55 } 56 57 // NewError creates a new HTTP JSON error which has a specified status code. 58 func NewError(err error, statusCode int) Error { 59 e := errorType{error: err} 60 e.statusCode = statusCode 61 return e 62 } 63 64 // StatusCode returns the HTTP status code that matches the error. 65 func (e errorType) StatusCode() int { 66 return e.statusCode 67 } 68 69 type respSuccess struct { 70 } 71 72 type respErrorResult struct { 73 Error respError `json:"error"` 74 } 75 76 type respError struct { 77 Message string `json:"message"` 78 Data interface{} `json:"data"` 79 } 80 81 // RegisterHandlers will register handlers on the HTTP serve mux for a given service and options 82 func RegisterHandlers(mux *http.ServeMux, service interface{}, opts ServerOptions) error { 83 v := reflect.ValueOf(service) 84 t := v.Type() 85 contextFn := opts.ContextFn() 86 postResponseFn := opts.PostResponseFn() 87 for i := 0; i < t.NumMethod(); i++ { 88 method := t.Method(i) 89 90 // Ensure this method is of either: 91 // - methodName(RequestObject) error 92 // - methodName(RequestObject) (ResultObject, error) 93 // - methodName() error 94 // - methodName() (ResultObject, error) 95 if !(method.Type.NumIn() == 2 || method.Type.NumIn() == 3) || 96 !(method.Type.NumOut() == 1 || method.Type.NumOut() == 2) { 97 continue 98 } 99 100 var reqIn reflect.Type 101 obj := method.Type.In(0) 102 context := method.Type.In(1) 103 if method.Type.NumIn() == 3 { 104 reqIn = method.Type.In(2) 105 } 106 107 var resultOut, resultErr reflect.Type 108 if method.Type.NumOut() == 1 { 109 resultErr = method.Type.Out(0) 110 } else { 111 resultOut = method.Type.Out(0) 112 resultErr = method.Type.Out(1) 113 } 114 115 if obj != t { 116 continue 117 } 118 119 contextInterfaceType := reflect.TypeOf((*thrift.Context)(nil)).Elem() 120 if context.Kind() != reflect.Interface || 
!context.Implements(contextInterfaceType) { 121 continue 122 } 123 124 if method.Type.NumIn() == 3 { 125 if reqIn.Kind() != reflect.Ptr || reqIn.Elem().Kind() != reflect.Struct { 126 continue 127 } 128 } 129 130 if method.Type.NumOut() == 2 { 131 if resultOut.Kind() != reflect.Ptr || resultOut.Elem().Kind() != reflect.Struct { 132 continue 133 } 134 } 135 136 errInterfaceType := reflect.TypeOf((*error)(nil)).Elem() 137 if resultErr.Kind() != reflect.Interface || !resultErr.Implements(errInterfaceType) { 138 continue 139 } 140 141 name := strings.ToLower(method.Name) 142 mux.HandleFunc(fmt.Sprintf("/%s", name), func(w http.ResponseWriter, r *http.Request) { 143 w.Header().Set("Content-Type", "application/json") 144 145 // Always close the request body 146 defer r.Body.Close() 147 148 httpMethod := strings.ToUpper(r.Method) 149 if reqIn == nil && httpMethod != "GET" { 150 writeError(w, errRequestMustBeGet) 151 return 152 } 153 if reqIn != nil && httpMethod != "POST" { 154 writeError(w, errRequestMustBePost) 155 return 156 } 157 158 httpHeaders := make(map[string]string) 159 for key, values := range r.Header { 160 if len(values) > 0 { 161 httpHeaders[key] = values[0] 162 } 163 } 164 165 var in interface{} 166 if reqIn != nil { 167 in = reflect.New(reqIn.Elem()).Interface() 168 decoder := json.NewDecoder(r.Body) 169 disableDisallowUnknownFields, err := strconv.ParseBool( 170 r.Header.Get(headers.JSONDisableDisallowUnknownFields)) 171 if err != nil || !disableDisallowUnknownFields { 172 decoder.DisallowUnknownFields() 173 } 174 if err := decoder.Decode(in); err != nil { 175 err := fmt.Errorf("invalid request body: %v", err) 176 writeError(w, xerrors.NewInvalidParamsError(err)) 177 return 178 } 179 } 180 181 // Prepare the call context 182 callContext, _ := thrift.NewContext(opts.RequestTimeout()) 183 if contextFn != nil { 184 // Allow derivation of context if context fn is set 185 callContext = contextFn(callContext, method.Name, httpHeaders) 186 } 187 // Always set 
headers finally 188 callContext = thrift.WithHeaders(callContext, httpHeaders) 189 190 var ( 191 svc = reflect.ValueOf(service) 192 ctx = reflect.ValueOf(callContext) 193 ret []reflect.Value 194 ) 195 if reqIn != nil { 196 ret = method.Func.Call([]reflect.Value{svc, ctx, reflect.ValueOf(in)}) 197 } else { 198 ret = method.Func.Call([]reflect.Value{svc, ctx}) 199 } 200 201 if method.Type.NumOut() == 1 { 202 // Ensure we always call the post response fn if set 203 if postResponseFn != nil { 204 defer postResponseFn(callContext, method.Name, nil) 205 } 206 207 // Deal with error case 208 if !ret[0].IsNil() { 209 writeError(w, ret[0].Interface()) 210 return 211 } 212 json.NewEncoder(w).Encode(&respSuccess{}) 213 return 214 } 215 216 // Ensure we always call the post response fn if set 217 if postResponseFn != nil { 218 defer func() { 219 var response apachethrift.TStruct 220 if result, ok := ret[0].Interface().(apachethrift.TStruct); ok { 221 response = result 222 } 223 postResponseFn(callContext, method.Name, response) 224 }() 225 } 226 227 // Deal with error case 228 if !ret[1].IsNil() { 229 writeError(w, ret[1].Interface()) 230 return 231 } 232 233 buff := bytes.NewBuffer(nil) 234 if err := json.NewEncoder(buff).Encode(ret[0].Interface()); err != nil { 235 writeError(w, fmt.Errorf("failed to encode response body: %v", err)) 236 return 237 } 238 239 w.WriteHeader(http.StatusOK) 240 w.Write(buff.Bytes()) 241 }) 242 } 243 return nil 244 } 245 246 func writeError(w http.ResponseWriter, errValue interface{}) { 247 result := respErrorResult{respError{}} 248 if value, ok := errValue.(error); ok { 249 result.Error.Message = value.Error() 250 } else if value, ok := errValue.(fmt.Stringer); ok { 251 result.Error.Message = value.String() 252 } 253 result.Error.Data = errValue 254 255 buff := bytes.NewBuffer(nil) 256 if err := json.NewEncoder(buff).Encode(&result); err != nil { 257 // Not a JSON returnable error 258 w.WriteHeader(http.StatusInternalServerError) 259 
result.Error.Message = fmt.Sprintf("%v", errValue) 260 result.Error.Data = nil 261 json.NewEncoder(w).Encode(&result) 262 return 263 } 264 265 switch v := errValue.(type) { 266 case Error: 267 w.WriteHeader(v.StatusCode()) 268 case error: 269 if xerrors.IsInvalidParams(v) { 270 w.WriteHeader(http.StatusBadRequest) 271 } else { 272 w.WriteHeader(http.StatusInternalServerError) 273 } 274 default: 275 w.WriteHeader(http.StatusInternalServerError) 276 } 277 278 w.Write(buff.Bytes()) 279 }