github.com/abemedia/go-don@v0.2.2-0.20240329015135-be88e32bb73b/handler.go (about) 1 package don 2 3 import ( 4 "bytes" 5 "context" 6 "net/http" 7 8 "github.com/abemedia/go-don/encoding" 9 "github.com/abemedia/go-don/internal/byteconv" 10 "github.com/abemedia/httprouter" 11 "github.com/valyala/fasthttp" 12 ) 13 14 // StatusCoder allows you to customise the HTTP response code. 15 type StatusCoder interface { 16 StatusCode() int 17 } 18 19 // Headerer allows you to customise the HTTP headers. 20 type Headerer interface { 21 Header() http.Header 22 } 23 24 // Handle is the type for your handlers. 25 type Handle[T, O any] func(ctx context.Context, request T) (O, error) 26 27 // H wraps your handler function with the Go generics magic. 28 func H[T, O any](handle Handle[T, O]) httprouter.Handle { 29 pool := newRequestPool(*new(T)) 30 decodeRequest := newRequestDecoder(*new(T)) 31 isNil := newNilCheck(*new(O)) 32 33 return func(ctx *fasthttp.RequestCtx, p httprouter.Params) { 34 contentType := getMediaType(ctx.Request.Header.Peek(fasthttp.HeaderAccept)) 35 36 enc := encoding.GetEncoder(contentType) 37 if enc == nil { 38 handleError(ctx, ErrNotAcceptable) 39 return 40 } 41 42 var res any 43 44 req := pool.Get() 45 err := decodeRequest(req, ctx, p) 46 if err != nil { 47 res = Error(err, getStatusCode(err, http.StatusBadRequest)) 48 } else { 49 res, err = handle(ctx, *req) 50 if err != nil { 51 res = Error(err, 0) 52 } 53 } 54 pool.Put(req) 55 56 ctx.SetContentType(contentType + "; charset=utf-8") 57 58 if h, ok := res.(Headerer); ok { 59 for k, v := range h.Header() { 60 ctx.Response.Header.Set(k, v[0]) 61 } 62 } 63 64 if sc, ok := res.(StatusCoder); ok { 65 ctx.SetStatusCode(sc.StatusCode()) 66 } 67 68 if err == nil && isNil(res) { 69 res = nil 70 ctx.Response.Header.SetContentLength(-3) 71 } 72 73 if err = enc(ctx, res); err != nil { 74 handleError(ctx, err) 75 } 76 } 77 } 78 79 func handleError(ctx *fasthttp.RequestCtx, err error) { 80 code := getStatusCode(err, http.StatusInternalServerError) 81 if code < http.StatusInternalServerError { 82 ctx.Error(err.Error(), code) 83 return 84 } 85 ctx.Error(fasthttp.StatusMessage(code), code) 86 ctx.Logger().Printf("%v", err) 87 } 88 89 func getMediaType(b []byte) string { 90 index := bytes.IndexRune(b, ';') 91 if index > 0 { 92 b = b[:index] 93 } 94 95 return byteconv.Btoa(bytes.TrimSpace(b)) 96 } 97 98 func getStatusCode(i any, fallback int) int { 99 if sc, ok := i.(StatusCoder); ok { 100 return sc.StatusCode() 101 } 102 return fallback 103 }