github.com/m3db/m3@v1.5.0/src/dbnode/network/server/httpjson/handlers.go (about)

     1  // Copyright (c) 2016 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package httpjson
    22  
    23  import (
    24  	"bytes"
    25  	"encoding/json"
    26  	"errors"
    27  	"fmt"
    28  	"net/http"
    29  	"reflect"
    30  	"strconv"
    31  	"strings"
    32  
    33  	xerrors "github.com/m3db/m3/src/x/errors"
    34  	"github.com/m3db/m3/src/x/headers"
    35  
    36  	apachethrift "github.com/apache/thrift/lib/go/thrift"
    37  	"github.com/uber/tchannel-go/thrift"
    38  )
    39  
    40  var (
    41  	errRequestMustBeGet  = xerrors.NewInvalidParamsError(errors.New("request without request params must be GET"))
    42  	errRequestMustBePost = xerrors.NewInvalidParamsError(errors.New("request with request params must be POST"))
    43  )
    44  
    45  // Error is an HTTP JSON error that also sets a return status code.
    46  type Error interface {
    47  	error
    48  
    49  	StatusCode() int
    50  }
    51  
    52  type errorType struct {
    53  	error
    54  	statusCode int
    55  }
    56  
    57  // NewError creates a new HTTP JSON error which has a specified status code.
    58  func NewError(err error, statusCode int) Error {
    59  	e := errorType{error: err}
    60  	e.statusCode = statusCode
    61  	return e
    62  }
    63  
    64  // StatusCode returns the HTTP status code that matches the error.
    65  func (e errorType) StatusCode() int {
    66  	return e.statusCode
    67  }
    68  
    69  type respSuccess struct {
    70  }
    71  
    72  type respErrorResult struct {
    73  	Error respError `json:"error"`
    74  }
    75  
    76  type respError struct {
    77  	Message string      `json:"message"`
    78  	Data    interface{} `json:"data"`
    79  }
    80  
    81  // RegisterHandlers will register handlers on the HTTP serve mux for a given service and options
    82  func RegisterHandlers(mux *http.ServeMux, service interface{}, opts ServerOptions) error {
    83  	v := reflect.ValueOf(service)
    84  	t := v.Type()
    85  	contextFn := opts.ContextFn()
    86  	postResponseFn := opts.PostResponseFn()
    87  	for i := 0; i < t.NumMethod(); i++ {
    88  		method := t.Method(i)
    89  
    90  		// Ensure this method is of either:
    91  		// - methodName(RequestObject) error
    92  		// - methodName(RequestObject) (ResultObject, error)
    93  		// - methodName() error
    94  		// - methodName() (ResultObject, error)
    95  		if !(method.Type.NumIn() == 2 || method.Type.NumIn() == 3) ||
    96  			!(method.Type.NumOut() == 1 || method.Type.NumOut() == 2) {
    97  			continue
    98  		}
    99  
   100  		var reqIn reflect.Type
   101  		obj := method.Type.In(0)
   102  		context := method.Type.In(1)
   103  		if method.Type.NumIn() == 3 {
   104  			reqIn = method.Type.In(2)
   105  		}
   106  
   107  		var resultOut, resultErr reflect.Type
   108  		if method.Type.NumOut() == 1 {
   109  			resultErr = method.Type.Out(0)
   110  		} else {
   111  			resultOut = method.Type.Out(0)
   112  			resultErr = method.Type.Out(1)
   113  		}
   114  
   115  		if obj != t {
   116  			continue
   117  		}
   118  
   119  		contextInterfaceType := reflect.TypeOf((*thrift.Context)(nil)).Elem()
   120  		if context.Kind() != reflect.Interface || !context.Implements(contextInterfaceType) {
   121  			continue
   122  		}
   123  
   124  		if method.Type.NumIn() == 3 {
   125  			if reqIn.Kind() != reflect.Ptr || reqIn.Elem().Kind() != reflect.Struct {
   126  				continue
   127  			}
   128  		}
   129  
   130  		if method.Type.NumOut() == 2 {
   131  			if resultOut.Kind() != reflect.Ptr || resultOut.Elem().Kind() != reflect.Struct {
   132  				continue
   133  			}
   134  		}
   135  
   136  		errInterfaceType := reflect.TypeOf((*error)(nil)).Elem()
   137  		if resultErr.Kind() != reflect.Interface || !resultErr.Implements(errInterfaceType) {
   138  			continue
   139  		}
   140  
   141  		name := strings.ToLower(method.Name)
   142  		mux.HandleFunc(fmt.Sprintf("/%s", name), func(w http.ResponseWriter, r *http.Request) {
   143  			w.Header().Set("Content-Type", "application/json")
   144  
   145  			// Always close the request body
   146  			defer r.Body.Close()
   147  
   148  			httpMethod := strings.ToUpper(r.Method)
   149  			if reqIn == nil && httpMethod != "GET" {
   150  				writeError(w, errRequestMustBeGet)
   151  				return
   152  			}
   153  			if reqIn != nil && httpMethod != "POST" {
   154  				writeError(w, errRequestMustBePost)
   155  				return
   156  			}
   157  
   158  			httpHeaders := make(map[string]string)
   159  			for key, values := range r.Header {
   160  				if len(values) > 0 {
   161  					httpHeaders[key] = values[0]
   162  				}
   163  			}
   164  
   165  			var in interface{}
   166  			if reqIn != nil {
   167  				in = reflect.New(reqIn.Elem()).Interface()
   168  				decoder := json.NewDecoder(r.Body)
   169  				disableDisallowUnknownFields, err := strconv.ParseBool(
   170  					r.Header.Get(headers.JSONDisableDisallowUnknownFields))
   171  				if err != nil || !disableDisallowUnknownFields {
   172  					decoder.DisallowUnknownFields()
   173  				}
   174  				if err := decoder.Decode(in); err != nil {
   175  					err := fmt.Errorf("invalid request body: %v", err)
   176  					writeError(w, xerrors.NewInvalidParamsError(err))
   177  					return
   178  				}
   179  			}
   180  
   181  			// Prepare the call context
   182  			callContext, _ := thrift.NewContext(opts.RequestTimeout())
   183  			if contextFn != nil {
   184  				// Allow derivation of context if context fn is set
   185  				callContext = contextFn(callContext, method.Name, httpHeaders)
   186  			}
   187  			// Always set headers finally
   188  			callContext = thrift.WithHeaders(callContext, httpHeaders)
   189  
   190  			var (
   191  				svc = reflect.ValueOf(service)
   192  				ctx = reflect.ValueOf(callContext)
   193  				ret []reflect.Value
   194  			)
   195  			if reqIn != nil {
   196  				ret = method.Func.Call([]reflect.Value{svc, ctx, reflect.ValueOf(in)})
   197  			} else {
   198  				ret = method.Func.Call([]reflect.Value{svc, ctx})
   199  			}
   200  
   201  			if method.Type.NumOut() == 1 {
   202  				// Ensure we always call the post response fn if set
   203  				if postResponseFn != nil {
   204  					defer postResponseFn(callContext, method.Name, nil)
   205  				}
   206  
   207  				// Deal with error case
   208  				if !ret[0].IsNil() {
   209  					writeError(w, ret[0].Interface())
   210  					return
   211  				}
   212  				json.NewEncoder(w).Encode(&respSuccess{})
   213  				return
   214  			}
   215  
   216  			// Ensure we always call the post response fn if set
   217  			if postResponseFn != nil {
   218  				defer func() {
   219  					var response apachethrift.TStruct
   220  					if result, ok := ret[0].Interface().(apachethrift.TStruct); ok {
   221  						response = result
   222  					}
   223  					postResponseFn(callContext, method.Name, response)
   224  				}()
   225  			}
   226  
   227  			// Deal with error case
   228  			if !ret[1].IsNil() {
   229  				writeError(w, ret[1].Interface())
   230  				return
   231  			}
   232  
   233  			buff := bytes.NewBuffer(nil)
   234  			if err := json.NewEncoder(buff).Encode(ret[0].Interface()); err != nil {
   235  				writeError(w, fmt.Errorf("failed to encode response body: %v", err))
   236  				return
   237  			}
   238  
   239  			w.WriteHeader(http.StatusOK)
   240  			w.Write(buff.Bytes())
   241  		})
   242  	}
   243  	return nil
   244  }
   245  
   246  func writeError(w http.ResponseWriter, errValue interface{}) {
   247  	result := respErrorResult{respError{}}
   248  	if value, ok := errValue.(error); ok {
   249  		result.Error.Message = value.Error()
   250  	} else if value, ok := errValue.(fmt.Stringer); ok {
   251  		result.Error.Message = value.String()
   252  	}
   253  	result.Error.Data = errValue
   254  
   255  	buff := bytes.NewBuffer(nil)
   256  	if err := json.NewEncoder(buff).Encode(&result); err != nil {
   257  		// Not a JSON returnable error
   258  		w.WriteHeader(http.StatusInternalServerError)
   259  		result.Error.Message = fmt.Sprintf("%v", errValue)
   260  		result.Error.Data = nil
   261  		json.NewEncoder(w).Encode(&result)
   262  		return
   263  	}
   264  
   265  	switch v := errValue.(type) {
   266  	case Error:
   267  		w.WriteHeader(v.StatusCode())
   268  	case error:
   269  		if xerrors.IsInvalidParams(v) {
   270  			w.WriteHeader(http.StatusBadRequest)
   271  		} else {
   272  			w.WriteHeader(http.StatusInternalServerError)
   273  		}
   274  	default:
   275  		w.WriteHeader(http.StatusInternalServerError)
   276  	}
   277  
   278  	w.Write(buff.Bytes())
   279  }