github.com/gofiber/fiber/v2@v2.47.0/middleware/adaptor/adaptor.go (about)

     1  package adaptor
     2  
     3  import (
     4  	"io"
     5  	"net"
     6  	"net/http"
     7  	"reflect"
     8  	"unsafe"
     9  
    10  	"github.com/gofiber/fiber/v2"
    11  	"github.com/gofiber/fiber/v2/utils"
    12  	"github.com/valyala/fasthttp"
    13  	"github.com/valyala/fasthttp/fasthttpadaptor"
    14  )
    15  
    16  // HTTPHandlerFunc wraps net/http handler func to fiber handler
    17  func HTTPHandlerFunc(h http.HandlerFunc) fiber.Handler {
    18  	return HTTPHandler(h)
    19  }
    20  
    21  // HTTPHandler wraps net/http handler to fiber handler
    22  func HTTPHandler(h http.Handler) fiber.Handler {
    23  	return func(c *fiber.Ctx) error {
    24  		handler := fasthttpadaptor.NewFastHTTPHandler(h)
    25  		handler(c.Context())
    26  		return nil
    27  	}
    28  }
    29  
    30  // ConvertRequest converts a fiber.Ctx to an http.Request.
    31  // forServer should be set to true when the http.Request is going to be passed to a http.Handler.
    32  func ConvertRequest(c *fiber.Ctx, forServer bool) (*http.Request, error) {
    33  	var req http.Request
    34  	if err := fasthttpadaptor.ConvertRequest(c.Context(), &req, forServer); err != nil {
    35  		return nil, err //nolint:wrapcheck // This must not be wrapped
    36  	}
    37  	return &req, nil
    38  }
    39  
    40  // CopyContextToFiberContext copies the values of context.Context to a fasthttp.RequestCtx
    41  func CopyContextToFiberContext(context interface{}, requestContext *fasthttp.RequestCtx) {
    42  	contextValues := reflect.ValueOf(context).Elem()
    43  	contextKeys := reflect.TypeOf(context).Elem()
    44  	if contextKeys.Kind() == reflect.Struct {
    45  		var lastKey interface{}
    46  		for i := 0; i < contextValues.NumField(); i++ {
    47  			reflectValue := contextValues.Field(i)
    48  			/* #nosec */
    49  			reflectValue = reflect.NewAt(reflectValue.Type(), unsafe.Pointer(reflectValue.UnsafeAddr())).Elem()
    50  
    51  			reflectField := contextKeys.Field(i)
    52  
    53  			if reflectField.Name == "noCopy" {
    54  				break
    55  			} else if reflectField.Name == "Context" {
    56  				CopyContextToFiberContext(reflectValue.Interface(), requestContext)
    57  			} else if reflectField.Name == "key" {
    58  				lastKey = reflectValue.Interface()
    59  			} else if lastKey != nil && reflectField.Name == "val" {
    60  				requestContext.SetUserValue(lastKey, reflectValue.Interface())
    61  			} else {
    62  				lastKey = nil
    63  			}
    64  		}
    65  	}
    66  }
    67  
    68  // HTTPMiddleware wraps net/http middleware to fiber middleware
    69  func HTTPMiddleware(mw func(http.Handler) http.Handler) fiber.Handler {
    70  	return func(c *fiber.Ctx) error {
    71  		var next bool
    72  		nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    73  			next = true
    74  			// Convert again in case request may modify by middleware
    75  			c.Request().Header.SetMethod(r.Method)
    76  			c.Request().SetRequestURI(r.RequestURI)
    77  			c.Request().SetHost(r.Host)
    78  			for key, val := range r.Header {
    79  				for _, v := range val {
    80  					c.Request().Header.Set(key, v)
    81  				}
    82  			}
    83  			CopyContextToFiberContext(r.Context(), c.Context())
    84  		})
    85  
    86  		if err := HTTPHandler(mw(nextHandler))(c); err != nil {
    87  			return err
    88  		}
    89  
    90  		if next {
    91  			return c.Next()
    92  		}
    93  		return nil
    94  	}
    95  }
    96  
    97  // FiberHandler wraps fiber handler to net/http handler
    98  func FiberHandler(h fiber.Handler) http.Handler {
    99  	return FiberHandlerFunc(h)
   100  }
   101  
   102  // FiberHandlerFunc wraps fiber handler to net/http handler func
   103  func FiberHandlerFunc(h fiber.Handler) http.HandlerFunc {
   104  	return handlerFunc(fiber.New(), h)
   105  }
   106  
   107  // FiberApp wraps fiber app to net/http handler func
   108  func FiberApp(app *fiber.App) http.HandlerFunc {
   109  	return handlerFunc(app)
   110  }
   111  
   112  func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc {
   113  	return func(w http.ResponseWriter, r *http.Request) {
   114  		// New fasthttp request
   115  		req := fasthttp.AcquireRequest()
   116  		defer fasthttp.ReleaseRequest(req)
   117  		// Convert net/http -> fasthttp request
   118  		if r.Body != nil {
   119  			body, err := io.ReadAll(r.Body)
   120  			if err != nil {
   121  				http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError)
   122  				return
   123  			}
   124  
   125  			req.Header.SetContentLength(len(body))
   126  			_, err = req.BodyWriter().Write(body)
   127  			if err != nil {
   128  				http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError)
   129  				return
   130  			}
   131  		}
   132  		req.Header.SetMethod(r.Method)
   133  		req.SetRequestURI(r.RequestURI)
   134  		req.SetHost(r.Host)
   135  		for key, val := range r.Header {
   136  			for _, v := range val {
   137  				req.Header.Set(key, v)
   138  			}
   139  		}
   140  		if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil && err.(*net.AddrError).Err == "missing port in address" { //nolint:errorlint, forcetypeassert // overlinting
   141  			r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80")
   142  		}
   143  		remoteAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
   144  		if err != nil {
   145  			http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError)
   146  			return
   147  		}
   148  
   149  		// New fasthttp Ctx
   150  		var fctx fasthttp.RequestCtx
   151  		fctx.Init(req, remoteAddr, nil)
   152  		if len(h) > 0 {
   153  			// New fiber Ctx
   154  			ctx := app.AcquireCtx(&fctx)
   155  			defer app.ReleaseCtx(ctx)
   156  			// Execute fiber Ctx
   157  			err := h[0](ctx)
   158  			if err != nil {
   159  				_ = app.Config().ErrorHandler(ctx, err) //nolint:errcheck // not needed
   160  			}
   161  		} else {
   162  			// Execute fasthttp Ctx though app.Handler
   163  			app.Handler()(&fctx)
   164  		}
   165  
   166  		// Convert fasthttp Ctx > net/http
   167  		fctx.Response.Header.VisitAll(func(k, v []byte) {
   168  			w.Header().Add(string(k), string(v))
   169  		})
   170  		w.WriteHeader(fctx.Response.StatusCode())
   171  		_, _ = w.Write(fctx.Response.Body()) //nolint:errcheck // not needed
   172  	}
   173  }