github.com/gramework/gramework@v1.8.1-0.20231027140105-82555c9057f5/router_determineHandler.go (about)

     1  // Copyright 2017-present Kirill Danshin and Gramework contributors
     2  // Copyright 2019-present Highload LTD (UK CN: 11893420)
     3  //
     4  // Licensed under the Apache License, Version 2.0 (the "License");
     5  // you may not use this file except in compliance with the License.
     6  // You may obtain a copy of the License at
     7  //
     8  //     http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  
    11  package gramework
    12  
    13  import (
    14  	"errors"
    15  	"reflect"
    16  	"strings"
    17  
    18  	"github.com/valyala/fasthttp"
    19  )
    20  
    21  type reqHandlerDefault interface {
    22  	Handler(*Context)
    23  }
    24  
    25  type reqHandlerWithError interface {
    26  	Handler(*Context) error
    27  }
    28  
    29  type reqHandlerWithEfaceError interface {
    30  	Handler(*Context) (interface{}, error)
    31  }
    32  
    33  type reqHandlerWithEface interface {
    34  	Handler(*Context) interface{}
    35  }
    36  
    37  type reqHandlerNoCtx interface {
    38  	Handler()
    39  }
    40  
    41  type reqHandlerWithErrorNoCtx interface {
    42  	Handler() error
    43  }
    44  
    45  type reqHandlerWithEfaceErrorNoCtx interface {
    46  	Handler() (interface{}, error)
    47  }
    48  
    49  type reqHandlerWithEfaceNoCtx interface {
    50  	Handler() interface{}
    51  }
    52  
    53  func (r *Router) determineHandler(handler interface{}) func(*Context) {
    54  	// copy handler, we don't want to mutate our arguments
    55  	rawHandler := handler
    56  
    57  	// prepare handler in case if it one of our supported interfaces
    58  	switch h := handler.(type) {
    59  	case reqHandlerDefault:
    60  		rawHandler = h.Handler
    61  	case reqHandlerWithError:
    62  		rawHandler = h.Handler
    63  	case reqHandlerWithEfaceError:
    64  		rawHandler = h.Handler
    65  	case reqHandlerWithEface:
    66  		rawHandler = h.Handler
    67  	case reqHandlerNoCtx:
    68  		rawHandler = h.Handler
    69  	case reqHandlerWithErrorNoCtx:
    70  		rawHandler = h.Handler
    71  	case reqHandlerWithEfaceErrorNoCtx:
    72  		rawHandler = h.Handler
    73  	case reqHandlerWithEfaceNoCtx:
    74  		rawHandler = h.Handler
    75  	}
    76  
    77  	// finally, process the handler
    78  	switch h := rawHandler.(type) {
    79  	case HTML:
    80  		return r.getHTMLServer(h)
    81  	case JSON:
    82  		return r.getJSONServer(h)
    83  	case func(*Context):
    84  		return h
    85  	case RequestHandler:
    86  		return h
    87  	case func(*Context) error:
    88  		return r.getErrorHandler(h)
    89  	case func(*fasthttp.RequestCtx):
    90  		return r.getGrameHandler(h)
    91  	case func(*fasthttp.RequestCtx) error:
    92  		return r.getGrameErrorHandler(h)
    93  	case func() interface{}:
    94  		return r.getEfaceEncoder(h)
    95  	case func() (interface{}, error):
    96  		return r.getEfaceErrEncoder(h)
    97  	case func(*Context) interface{}:
    98  		return r.getEfaceCtxEncoder(h)
    99  	case func(*Context) (interface{}, error):
   100  		return r.getEfaceCtxErrEncoder(h)
   101  	case string:
   102  		return r.getStringServer(h)
   103  	case []byte:
   104  		return r.getBytesServer(h)
   105  	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
   106  		return r.getFmtDHandler(h)
   107  	case float32, float64:
   108  		return r.getFmtFHandler(h)
   109  	case func():
   110  		return r.getGrameDumbHandler(h)
   111  	case func() error:
   112  		return r.getGrameDumbErrorHandler(h)
   113  	case func() string:
   114  		return r.getEFuncStrHandler(h)
   115  	case func() map[string]interface{}:
   116  		return r.getHandlerEncoder(h)
   117  	case func(*Context) map[string]interface{}:
   118  		return r.getCtxHandlerEncoder(h)
   119  	case func() (map[string]interface{}, error):
   120  		return r.getHandlerEncoderErr(h)
   121  	case func(*Context) (map[string]interface{}, error):
   122  		return r.getCtxHandlerEncoderErr(h)
   123  	default:
   124  		rv := reflect.ValueOf(h)
   125  		if rv.Kind() == reflect.Func {
   126  			handler, err := r.getCachedReflectHandler(h)
   127  			if err != nil {
   128  				r.app.internalLog.WithError(err).Fatal("Unsupported reflect handler signature")
   129  			}
   130  
   131  			return handler
   132  		}
   133  		r.app.internalLog.Warnf("Unknown handler type: %T, serving fmt.Sprintf(%%v)", h)
   134  		return r.getFmtVHandler(h)
   135  	}
   136  }
   137  
   138  type reflectDecodedBodyRecv struct {
   139  	idx int
   140  	t   reflect.Type
   141  }
   142  
   143  func (r *Router) getCachedReflectHandler(h interface{}) (func(*Context), error) {
   144  	funcT := reflect.TypeOf(h)
   145  	if funcT.IsVariadic() {
   146  		return nil, errors.New("could not process variadic reflect handler")
   147  	}
   148  
   149  	results := funcT.NumOut()
   150  	if results > 2 {
   151  		return nil, errors.New("reflect handler output should be one of (any), (any, error), (error) or ()")
   152  	}
   153  
   154  	params := funcT.NumIn()
   155  	decodedBodyRecv := []reflectDecodedBodyRecv{}
   156  	ctxRecv := -1
   157  
   158  	checkForErrorAt := -1
   159  	encodeDataAt := -1
   160  
   161  	for i := 0; i < params; i++ {
   162  		p := funcT.In(i)
   163  		if strings.Contains(p.String(), "*gramework.Context") {
   164  			ctxRecv = i
   165  			continue
   166  		}
   167  		decodedBodyRecv = append(decodedBodyRecv, reflectDecodedBodyRecv{
   168  			idx: i,
   169  			t:   p,
   170  		})
   171  	}
   172  
   173  	for i := 0; i < results; i++ {
   174  		r := funcT.Out(i)
   175  		println(r.String())
   176  
   177  		if r.String() == "error" {
   178  			if i == 0 && results > 1 {
   179  				return nil, errors.New("reflect handler output should be one of (any), (any, error), (error) or ()")
   180  			}
   181  
   182  			checkForErrorAt = i
   183  			continue
   184  		}
   185  
   186  		if encodeDataAt >= 0 {
   187  			return nil, errors.New("reflect handler output should be one of (any), (any, error), (error) or ()")
   188  		}
   189  
   190  		encodeDataAt = i
   191  	}
   192  
   193  	funcV := reflect.ValueOf(h)
   194  
   195  	handler := func(ctx *Context) {
   196  		callParams := make([]reflect.Value, params)
   197  		if len(decodedBodyRecv) > 0 {
   198  			unsupportedBodyType := true
   199  			for i := range decodedBodyRecv {
   200  				decoded := reflect.New(decodedBodyRecv[i].t).Interface()
   201  				if jsonErr := ctx.UnJSON(decoded); jsonErr == nil {
   202  					unsupportedBodyType = false
   203  					decodedV := reflect.ValueOf(decoded)
   204  
   205  					callParams[decodedBodyRecv[i].idx] = decodedV.Elem()
   206  				} else {
   207  					callParams[decodedBodyRecv[i].idx] = reflect.Zero(decodedBodyRecv[i].t)
   208  				}
   209  			}
   210  
   211  			if unsupportedBodyType {
   212  				ctx.SetStatusCode(500)
   213  				ctx.Logger.Error("unsupported body type")
   214  				return
   215  			}
   216  		}
   217  		if ctxRecv >= 0 {
   218  			callParams[ctxRecv] = reflect.ValueOf(ctx)
   219  		}
   220  
   221  		res := funcV.Call(callParams)
   222  		shouldProcessErr := false
   223  		shouldProcessReturn := false
   224  		var err error
   225  		if checkForErrorAt >= 0 && !res[checkForErrorAt].IsNil() {
   226  			resErr, ok := res[checkForErrorAt].Interface().(error)
   227  			if ok {
   228  				err = resErr
   229  			} else {
   230  				err = errUnknown
   231  			}
   232  			shouldProcessErr = true
   233  		}
   234  
   235  		var v interface{}
   236  		if encodeDataAt >= 0 {
   237  			v = res[encodeDataAt].Interface()
   238  			shouldProcessReturn = true
   239  		}
   240  		if shouldProcessErr {
   241  			if err != nil {
   242  				ctx.jsonErrorLog(err)
   243  				return
   244  			}
   245  		}
   246  		if shouldProcessReturn {
   247  			if v == nil { // err == nil here
   248  				ctx.SetStatusCode(fasthttp.StatusNoContent)
   249  				return
   250  			}
   251  			if err = ctx.JSON(v); err != nil {
   252  				ctx.jsonErrorLog(err)
   253  			}
   254  		}
   255  	}
   256  
   257  	return handler, nil
   258  }
   259  
   260  var errUnknown = errors.New("Unknown Server Error")