github.com/abemedia/go-don@v0.2.2-0.20240329015135-be88e32bb73b/handler.go (about)

     1  package don
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"net/http"
     7  
     8  	"github.com/abemedia/go-don/encoding"
     9  	"github.com/abemedia/go-don/internal/byteconv"
    10  	"github.com/abemedia/httprouter"
    11  	"github.com/valyala/fasthttp"
    12  )
    13  
    14  // StatusCoder allows you to customise the HTTP response code.
    15  type StatusCoder interface {
    16  	StatusCode() int
    17  }
    18  
    19  // Headerer allows you to customise the HTTP headers.
    20  type Headerer interface {
    21  	Header() http.Header
    22  }
    23  
    24  // Handle is the type for your handlers.
    25  type Handle[T, O any] func(ctx context.Context, request T) (O, error)
    26  
    27  // H wraps your handler function with the Go generics magic.
    28  func H[T, O any](handle Handle[T, O]) httprouter.Handle {
    29  	pool := newRequestPool(*new(T))
    30  	decodeRequest := newRequestDecoder(*new(T))
    31  	isNil := newNilCheck(*new(O))
    32  
    33  	return func(ctx *fasthttp.RequestCtx, p httprouter.Params) {
    34  		contentType := getMediaType(ctx.Request.Header.Peek(fasthttp.HeaderAccept))
    35  
    36  		enc := encoding.GetEncoder(contentType)
    37  		if enc == nil {
    38  			handleError(ctx, ErrNotAcceptable)
    39  			return
    40  		}
    41  
    42  		var res any
    43  
    44  		req := pool.Get()
    45  		err := decodeRequest(req, ctx, p)
    46  		if err != nil {
    47  			res = Error(err, getStatusCode(err, http.StatusBadRequest))
    48  		} else {
    49  			res, err = handle(ctx, *req)
    50  			if err != nil {
    51  				res = Error(err, 0)
    52  			}
    53  		}
    54  		pool.Put(req)
    55  
    56  		ctx.SetContentType(contentType + "; charset=utf-8")
    57  
    58  		if h, ok := res.(Headerer); ok {
    59  			for k, v := range h.Header() {
    60  				ctx.Response.Header.Set(k, v[0])
    61  			}
    62  		}
    63  
    64  		if sc, ok := res.(StatusCoder); ok {
    65  			ctx.SetStatusCode(sc.StatusCode())
    66  		}
    67  
    68  		if err == nil && isNil(res) {
    69  			res = nil
    70  			ctx.Response.Header.SetContentLength(-3)
    71  		}
    72  
    73  		if err = enc(ctx, res); err != nil {
    74  			handleError(ctx, err)
    75  		}
    76  	}
    77  }
    78  
    79  func handleError(ctx *fasthttp.RequestCtx, err error) {
    80  	code := getStatusCode(err, http.StatusInternalServerError)
    81  	if code < http.StatusInternalServerError {
    82  		ctx.Error(err.Error(), code)
    83  		return
    84  	}
    85  	ctx.Error(fasthttp.StatusMessage(code), code)
    86  	ctx.Logger().Printf("%v", err)
    87  }
    88  
    89  func getMediaType(b []byte) string {
    90  	index := bytes.IndexRune(b, ';')
    91  	if index > 0 {
    92  		b = b[:index]
    93  	}
    94  
    95  	return byteconv.Btoa(bytes.TrimSpace(b))
    96  }
    97  
    98  func getStatusCode(i any, fallback int) int {
    99  	if sc, ok := i.(StatusCoder); ok {
   100  		return sc.StatusCode()
   101  	}
   102  	return fallback
   103  }