github.com/machinefi/w3bstream@v1.6.5-rc9.0.20240426031326-b8c7c4876e72/pkg/depends/kit/httptransport/req_tsfm.go (about)

     1  package httptransport
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"io"
     7  	"mime"
     8  	"net/http"
     9  	"net/textproto"
    10  	neturl "net/url"
    11  	"reflect"
    12  	"sort"
    13  
    14  	"github.com/julienschmidt/httprouter"
    15  	pkgerr "github.com/pkg/errors"
    16  
    17  	"github.com/machinefi/w3bstream/pkg/depends/kit/httptransport/httpx"
    18  	"github.com/machinefi/w3bstream/pkg/depends/kit/httptransport/transformer"
    19  	"github.com/machinefi/w3bstream/pkg/depends/kit/statusx"
    20  	"github.com/machinefi/w3bstream/pkg/depends/kit/validator"
    21  	vldterr "github.com/machinefi/w3bstream/pkg/depends/kit/validator/errors"
    22  	"github.com/machinefi/w3bstream/pkg/depends/x/contextx"
    23  	"github.com/machinefi/w3bstream/pkg/depends/x/reflectx"
    24  )
    25  
// RequestTsfm converts between a request struct of one fixed type and an HTTP
// request, driven by the per-field parameter descriptions in Params.
type RequestTsfm struct {
	// Type is the struct type this transformer handles; encoding and decoding
	// reject values whose (dereferenced) type differs from it.
	Type   reflect.Type
	// Params lists the request parameters to encode/decode.
	// NOTE(review): the map looks keyed by parameter location ("path",
	// "query", ...) — confirm against the code that populates it.
	Params map[string][]transformer.ReqParam
}
    30  
    31  func (t *RequestTsfm) NewRequest(method, url string, v interface{}) (*http.Request, error) {
    32  	return t.NewReqWithContext(context.Background(), method, url, v)
    33  }
    34  
    35  func (t *RequestTsfm) NewReqWithContext(ctx context.Context, method, url string, v interface{}) (*http.Request, error) {
    36  	if v == nil {
    37  		return http.NewRequestWithContext(ctx, method, url, nil)
    38  	}
    39  
    40  	typ := reflectx.DeRef(reflect.TypeOf(v))
    41  	if t.Type != typ {
    42  		return nil, pkgerr.Errorf(
    43  			"unmatched request transformer, need %s but got %s", t.Type, typ,
    44  		)
    45  	}
    46  
    47  	var (
    48  		errs    = vldterr.NewErrorSet()
    49  		params  = httprouter.Params{}
    50  		query   = neturl.Values{}
    51  		header  = http.Header{}
    52  		cookies = neturl.Values{}
    53  		body    = bytes.NewBuffer(nil)
    54  	)
    55  
    56  	rv, ok := v.(reflect.Value)
    57  	if !ok {
    58  		rv = reflect.ValueOf(v)
    59  	}
    60  	rv = reflectx.Indirect(rv)
    61  
    62  	for _, parameters := range t.Params {
    63  		for i := range parameters {
    64  			p := parameters[i]
    65  
    66  			fv := p.FieldValue(rv)
    67  			if !fv.IsValid() {
    68  				continue
    69  			}
    70  
    71  			if p.In == "body" {
    72  				if err := p.Tsf.EncodeTo(
    73  					ctx,
    74  					transformer.WriterWithHeader(body, header),
    75  					fv,
    76  				); err != nil {
    77  					errs.AddErr(err, p.Name)
    78  				}
    79  				continue
    80  			}
    81  
    82  			writers := transformer.NewStringBuilders()
    83  			if err := transformer.NewSuper(p.Tsf, &p.Option.CommonOption).
    84  				EncodeTo(ctx, writers, fv); err != nil {
    85  				errs.AddErr(err, p.Name)
    86  				continue
    87  			}
    88  
    89  			values := writers.StringSlice()
    90  			switch p.In {
    91  			case "path":
    92  				params = append(params, httprouter.Param{Key: p.Name, Value: values[0]})
    93  			case "query":
    94  				query[p.Name] = values
    95  			case "header":
    96  				header[textproto.CanonicalMIMEHeaderKey(p.Name)] = values
    97  			case "cookie":
    98  				cookies[p.Name] = values
    99  			}
   100  		}
   101  	}
   102  
   103  	req, err := http.NewRequestWithContext(ctx, method, url, nil)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	if len(params) > 0 {
   109  		req = req.WithContext(
   110  			contextx.WithValue(req.Context(), httprouter.ParamsKey, params),
   111  		)
   112  		req.URL.Path = transformer.NewPathnamePattern(req.URL.Path).Stringify(params)
   113  	}
   114  
   115  	if len(query) > 0 {
   116  		if method == http.MethodGet && ShouldQueryInBodyForGet(ctx) {
   117  			header.Set(
   118  				httpx.HeaderContentType,
   119  				mime.FormatMediaType(
   120  					httpx.MIME_FORM_URLENCODED,
   121  					map[string]string{"param": "value"},
   122  				),
   123  			)
   124  			body = bytes.NewBufferString(query.Encode())
   125  		} else {
   126  			req.URL.RawQuery = query.Encode()
   127  		}
   128  	}
   129  
   130  	req.Header = header
   131  
   132  	if n := len(cookies); n > 0 {
   133  		names := make([]string, n)
   134  		i := 0
   135  		for name := range cookies {
   136  			names[i] = name
   137  			i++
   138  		}
   139  		sort.Strings(names)
   140  
   141  		for _, name := range names {
   142  			values := cookies[name]
   143  			for _, value := range values {
   144  				req.AddCookie(&http.Cookie{Name: name, Value: value})
   145  			}
   146  		}
   147  	}
   148  
   149  	if n := int64(body.Len()); n != 0 {
   150  		req.ContentLength = n
   151  		rc := io.NopCloser(body)
   152  		req.Body = rc
   153  		req.GetBody = func() (io.ReadCloser, error) { return rc, nil }
   154  	}
   155  
   156  	return req, nil
   157  }
   158  
   159  func (t *RequestTsfm) DecodeAndValidate(ctx context.Context, ri httpx.RequestInfo, v interface{}) error {
   160  	if err := t.DecodeFromRequestInfo(ctx, ri, v); err != nil {
   161  		return err
   162  	}
   163  	return t.validate(v)
   164  }
   165  
   166  func (t *RequestTsfm) DecodeFromRequestInfo(ctx context.Context, ri httpx.RequestInfo, v interface{}) error {
   167  	if with, ok := v.(httpx.WithFromRequestInfo); ok {
   168  		if err := with.FromRequestInfo(ri); err != nil {
   169  			if est := err.(interface {
   170  				ToFieldErrors() statusx.ErrorFields
   171  			}); ok {
   172  				if errorFields := est.ToFieldErrors(); len(errorFields) > 0 {
   173  					return (&badRequest{errorFields: errorFields}).Err()
   174  				}
   175  			}
   176  			return err
   177  		}
   178  		return nil
   179  	}
   180  
   181  	rv, ok := v.(reflect.Value)
   182  	if !ok {
   183  		rv = reflect.ValueOf(v)
   184  	}
   185  
   186  	if rv.Kind() != reflect.Ptr {
   187  		return pkgerr.Errorf("decode target must be an ptr value")
   188  	}
   189  
   190  	rv = reflectx.Indirect(rv)
   191  
   192  	if tpe := rv.Type(); tpe != t.Type {
   193  		return pkgerr.Errorf(
   194  			"unmatched request transformer, need %s but got %s",
   195  			t.Type, tpe,
   196  		)
   197  	}
   198  
   199  	errs := vldterr.NewErrorSet()
   200  
   201  	for in := range t.Params {
   202  		parameters := t.Params[in]
   203  
   204  		for i := range parameters {
   205  			param := parameters[i]
   206  
   207  			if param.In == "body" {
   208  				body := ri.Body()
   209  				if err := param.Tsf.DecodeFrom(
   210  					ctx,
   211  					body,
   212  					param.FieldValue(rv).Addr(),
   213  					textproto.MIMEHeader(ri.Header()),
   214  				); err != nil && err != io.EOF {
   215  					errs.AddErr(err, vldterr.Location(param.In))
   216  				}
   217  				body.Close()
   218  				continue
   219  			}
   220  
   221  			var values []string
   222  
   223  			if param.In == "meta" {
   224  				params := OperatorFactoryFromContext(ctx).Params
   225  				if params != nil {
   226  					values = params[param.Name]
   227  				}
   228  			} else {
   229  				values = ri.Values(param.In, param.Name)
   230  			}
   231  
   232  			if len(values) > 0 {
   233  				if err := transformer.NewSuper(
   234  					param.Tsf,
   235  					&param.Option.CommonOption,
   236  				).DecodeFrom(
   237  					ctx,
   238  					transformer.NewStringReaders(values),
   239  					param.FieldValue(rv).Addr(),
   240  				); err != nil {
   241  					errs.AddErr(err, vldterr.Location(param.In), param.Name)
   242  				}
   243  			}
   244  		}
   245  	}
   246  
   247  	if errs.Err() == nil {
   248  		return nil
   249  	}
   250  
   251  	return (&badRequest{errorFields: errs.ToErrorFields()}).Err()
   252  }
   253  
   254  func (t *RequestTsfm) validate(v interface{}) error {
   255  	if self, ok := v.(validator.CanValidate); ok {
   256  		if err := self.Validate(); err != nil {
   257  			if est := err.(interface {
   258  				ToFieldErrors() statusx.ErrorFields
   259  			}); ok {
   260  				if errorFields := est.ToFieldErrors(); len(errorFields) > 0 {
   261  					return (&badRequest{errorFields: errorFields}).Err()
   262  				}
   263  			}
   264  			return err
   265  		}
   266  		return nil
   267  	}
   268  
   269  	rv, ok := v.(reflect.Value)
   270  	if !ok {
   271  		rv = reflect.ValueOf(v)
   272  	}
   273  
   274  	errSet := vldterr.NewErrorSet()
   275  
   276  	for in := range t.Params {
   277  		parameters := t.Params[in]
   278  
   279  		for i := range parameters {
   280  			param := parameters[i]
   281  
   282  			if param.Validator != nil {
   283  				if err := param.Validator.Validate(param.FieldValue(rv)); err != nil {
   284  					if param.In == "body" {
   285  						errSet.AddErr(err, vldterr.Location(param.In))
   286  					} else {
   287  						errSet.AddErr(err, vldterr.Location(param.In), param.Name)
   288  					}
   289  				}
   290  			}
   291  		}
   292  	}
   293  
   294  	br := &badRequest{errorFields: errSet.ToErrorFields()}
   295  
   296  	// TODO deprecated
   297  	if postValidator, ok := rv.Interface().(PostValidator); ok {
   298  		postValidator.PostValidate(br)
   299  	}
   300  
   301  	if errSet.Err() == nil {
   302  		return nil
   303  	}
   304  
   305  	return br.Err()
   306  }
   307  
// PostValidator is an optional hook a request value may implement. For values
// that do not implement validator.CanValidate, validate calls PostValidate
// after the per-parameter validators have run, passing the bad-request error
// built from their failures.
//
// NOTE: the call site marks this hook with a "TODO deprecated" comment.
type PostValidator interface {
	PostValidate(badReqErr BadRequestError)
}