github.com/abemedia/go-don@v0.2.2-0.20240329015135-be88e32bb73b/request.go (about)

     1  package don
     2  
     3  import (
     4  	"reflect"
     5  
     6  	"github.com/abemedia/go-don/decoder"
     7  	"github.com/abemedia/go-don/encoding"
     8  	"github.com/abemedia/httprouter"
     9  	"github.com/valyala/fasthttp"
    10  )
    11  
    12  type requestDecoder[V any] func(v *V, ctx *fasthttp.RequestCtx, p httprouter.Params) error
    13  
    14  func newRequestDecoder[V any](v V) requestDecoder[V] {
    15  	path, _ := decoder.NewCached(v, "path")
    16  	query, _ := decoder.NewCached(v, "query")
    17  	header, _ := decoder.NewCached(v, "header")
    18  
    19  	if path == nil && query == nil && header == nil {
    20  		return decodeBody[V]()
    21  	}
    22  
    23  	return decodeRequest(path, query, header)
    24  }
    25  
    26  func decodeRequest[V any](path, query, header *decoder.CachedDecoder[V]) requestDecoder[V] {
    27  	body := decodeBody[V]()
    28  	return func(v *V, ctx *fasthttp.RequestCtx, p httprouter.Params) error {
    29  		if err := body(v, ctx, nil); err != nil {
    30  			return err
    31  		}
    32  
    33  		val := reflect.ValueOf(v).Elem()
    34  
    35  		if path != nil && len(p) > 0 {
    36  			if err := path.DecodeValue((decoder.Params)(p), val); err != nil {
    37  				return ErrNotFound
    38  			}
    39  		}
    40  
    41  		if query != nil {
    42  			if q := ctx.Request.URI().QueryArgs(); q.Len() > 0 {
    43  				if err := query.DecodeValue((*decoder.Args)(q), val); err != nil {
    44  					return err
    45  				}
    46  			}
    47  		}
    48  
    49  		if header != nil {
    50  			if err := header.DecodeValue((*decoder.Header)(&ctx.Request.Header), val); err != nil {
    51  				return err
    52  			}
    53  		}
    54  
    55  		return nil
    56  	}
    57  }
    58  
    59  func decodeBody[V any]() requestDecoder[V] {
    60  	return func(v *V, ctx *fasthttp.RequestCtx, _ httprouter.Params) error {
    61  		if ctx.Request.Header.ContentLength() == 0 || ctx.IsGet() || ctx.IsHead() {
    62  			return nil
    63  		}
    64  
    65  		dec := encoding.GetDecoder(getMediaType(ctx.Request.Header.ContentType()))
    66  		if dec == nil {
    67  			return ErrUnsupportedMediaType
    68  		}
    69  
    70  		return dec(ctx, v)
    71  	}
    72  }