code.gitea.io/gitea@v1.22.3/modules/web/handler.go (about)

     1  // Copyright 2023 The Gitea Authors. All rights reserved.
     2  // SPDX-License-Identifier: MIT
     3  
     4  package web
     5  
     6  import (
     7  	goctx "context"
     8  	"fmt"
     9  	"net/http"
    10  	"reflect"
    11  
    12  	"code.gitea.io/gitea/modules/log"
    13  	"code.gitea.io/gitea/modules/web/routing"
    14  	"code.gitea.io/gitea/modules/web/types"
    15  )
    16  
    17  var responseStatusProviders = map[reflect.Type]func(req *http.Request) types.ResponseStatusProvider{}
    18  
    19  func RegisterResponseStatusProvider[T any](fn func(req *http.Request) types.ResponseStatusProvider) {
    20  	responseStatusProviders[reflect.TypeOf((*T)(nil)).Elem()] = fn
    21  }
    22  
    23  // responseWriter is a wrapper of http.ResponseWriter, to check whether the response has been written
    24  type responseWriter struct {
    25  	respWriter http.ResponseWriter
    26  	status     int
    27  }
    28  
    29  var _ types.ResponseStatusProvider = (*responseWriter)(nil)
    30  
    31  func (r *responseWriter) WrittenStatus() int {
    32  	return r.status
    33  }
    34  
    35  func (r *responseWriter) Header() http.Header {
    36  	return r.respWriter.Header()
    37  }
    38  
    39  func (r *responseWriter) Write(bytes []byte) (int, error) {
    40  	if r.status == 0 {
    41  		r.status = http.StatusOK
    42  	}
    43  	return r.respWriter.Write(bytes)
    44  }
    45  
    46  func (r *responseWriter) WriteHeader(statusCode int) {
    47  	r.status = statusCode
    48  	r.respWriter.WriteHeader(statusCode)
    49  }
    50  
    51  var (
    52  	httpReqType    = reflect.TypeOf((*http.Request)(nil))
    53  	respWriterType = reflect.TypeOf((*http.ResponseWriter)(nil)).Elem()
    54  	cancelFuncType = reflect.TypeOf((*goctx.CancelFunc)(nil)).Elem()
    55  )
    56  
    57  // preCheckHandler checks whether the handler is valid, developers could get first-time feedback, all mistakes could be found at startup
    58  func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) {
    59  	hasStatusProvider := false
    60  	for _, argIn := range argsIn {
    61  		if _, hasStatusProvider = argIn.Interface().(types.ResponseStatusProvider); hasStatusProvider {
    62  			break
    63  		}
    64  	}
    65  	if !hasStatusProvider {
    66  		panic(fmt.Sprintf("handler should have at least one ResponseStatusProvider argument, but got %s", fn.Type()))
    67  	}
    68  	if fn.Type().NumOut() != 0 && fn.Type().NumIn() != 1 {
    69  		panic(fmt.Sprintf("handler should have no return value or only one argument, but got %s", fn.Type()))
    70  	}
    71  	if fn.Type().NumOut() == 1 && fn.Type().Out(0) != cancelFuncType {
    72  		panic(fmt.Sprintf("handler should return a cancel function, but got %s", fn.Type()))
    73  	}
    74  }
    75  
    76  func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value, fnInfo *routing.FuncInfo) []reflect.Value {
    77  	defer func() {
    78  		if err := recover(); err != nil {
    79  			log.Error("unable to prepare handler arguments for %s: %v", fnInfo.String(), err)
    80  			panic(err)
    81  		}
    82  	}()
    83  	isPreCheck := req == nil
    84  
    85  	argsIn := make([]reflect.Value, fn.Type().NumIn())
    86  	for i := 0; i < fn.Type().NumIn(); i++ {
    87  		argTyp := fn.Type().In(i)
    88  		switch argTyp {
    89  		case respWriterType:
    90  			argsIn[i] = reflect.ValueOf(resp)
    91  		case httpReqType:
    92  			argsIn[i] = reflect.ValueOf(req)
    93  		default:
    94  			if argFn, ok := responseStatusProviders[argTyp]; ok {
    95  				if isPreCheck {
    96  					argsIn[i] = reflect.ValueOf(&responseWriter{})
    97  				} else {
    98  					argsIn[i] = reflect.ValueOf(argFn(req))
    99  				}
   100  			} else {
   101  				panic(fmt.Sprintf("unsupported argument type: %s", argTyp))
   102  			}
   103  		}
   104  	}
   105  	return argsIn
   106  }
   107  
   108  func handleResponse(fn reflect.Value, ret []reflect.Value) goctx.CancelFunc {
   109  	if len(ret) == 1 {
   110  		if cancelFunc, ok := ret[0].Interface().(goctx.CancelFunc); ok {
   111  			return cancelFunc
   112  		}
   113  		panic(fmt.Sprintf("unsupported return type: %s", ret[0].Type()))
   114  	} else if len(ret) > 1 {
   115  		panic(fmt.Sprintf("unsupported return values: %s", fn.Type()))
   116  	}
   117  	return nil
   118  }
   119  
   120  func hasResponseBeenWritten(argsIn []reflect.Value) bool {
   121  	for _, argIn := range argsIn {
   122  		if statusProvider, ok := argIn.Interface().(types.ResponseStatusProvider); ok {
   123  			if statusProvider.WrittenStatus() != 0 {
   124  				return true
   125  			}
   126  		}
   127  	}
   128  	return false
   129  }
   130  
   131  func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) func(next http.Handler) http.Handler {
   132  	return func(next http.Handler) http.Handler {
   133  		h := hp(next) // this handle could be dynamically generated, so we can't use it for debug info
   134  		return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
   135  			routing.UpdateFuncInfo(req.Context(), funcInfo)
   136  			h.ServeHTTP(resp, req)
   137  		})
   138  	}
   139  }
   140  
   141  // toHandlerProvider converts a handler to a handler provider
   142  // A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware
   143  func toHandlerProvider(handler any) func(next http.Handler) http.Handler {
   144  	funcInfo := routing.GetFuncInfo(handler)
   145  	fn := reflect.ValueOf(handler)
   146  	if fn.Type().Kind() != reflect.Func {
   147  		panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type()))
   148  	}
   149  
   150  	if hp, ok := handler.(func(next http.Handler) http.Handler); ok {
   151  		return wrapHandlerProvider(hp, funcInfo)
   152  	} else if hp, ok := handler.(func(http.Handler) http.HandlerFunc); ok {
   153  		return wrapHandlerProvider(hp, funcInfo)
   154  	}
   155  
   156  	provider := func(next http.Handler) http.Handler {
   157  		return http.HandlerFunc(func(respOrig http.ResponseWriter, req *http.Request) {
   158  			// wrap the response writer to check whether the response has been written
   159  			resp := respOrig
   160  			if _, ok := resp.(types.ResponseStatusProvider); !ok {
   161  				resp = &responseWriter{respWriter: resp}
   162  			}
   163  
   164  			// prepare the arguments for the handler and do pre-check
   165  			argsIn := prepareHandleArgsIn(resp, req, fn, funcInfo)
   166  			if req == nil {
   167  				preCheckHandler(fn, argsIn)
   168  				return // it's doing pre-check, just return
   169  			}
   170  
   171  			routing.UpdateFuncInfo(req.Context(), funcInfo)
   172  			ret := fn.Call(argsIn)
   173  
   174  			// handle the return value, and defer the cancel function if there is one
   175  			cancelFunc := handleResponse(fn, ret)
   176  			if cancelFunc != nil {
   177  				defer cancelFunc()
   178  			}
   179  
   180  			// if the response has not been written, call the next handler
   181  			if next != nil && !hasResponseBeenWritten(argsIn) {
   182  				next.ServeHTTP(resp, req)
   183  			}
   184  		})
   185  	}
   186  
   187  	provider(nil).ServeHTTP(nil, nil) // do a pre-check to make sure all arguments and return values are supported
   188  	return provider
   189  }