github.com/clubpay/ronykit/kit@v0.14.4-0.20240515065620-d0dace45cbc7/stub/stub_rest.go (about)

     1  package stub
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"time"
     9  
    10  	"github.com/clubpay/ronykit/kit"
    11  	"github.com/clubpay/ronykit/kit/utils"
    12  	"github.com/clubpay/ronykit/kit/utils/reflector"
    13  	"github.com/valyala/fasthttp"
    14  )
    15  
    16  type RESTResponseHandler func(ctx context.Context, r RESTResponse) *Error
    17  
    18  type RESTResponse interface {
    19  	StatusCode() int
    20  	GetBody() []byte
    21  	GetHeader(key string) string
    22  }
    23  
    24  type RESTPreflightHandler func(r *fasthttp.Request)
    25  
    26  type RESTCtx struct {
    27  	cfg            restConfig
    28  	err            *Error
    29  	handlers       map[int]RESTResponseHandler
    30  	defaultHandler RESTResponseHandler
    31  	r              *reflector.Reflector
    32  	dumpReq        io.Writer
    33  	dumpRes        io.Writer
    34  	timeout        time.Duration
    35  	codec          kit.MessageCodec
    36  
    37  	// fasthttp entities
    38  	c    *fasthttp.Client
    39  	uri  *fasthttp.URI
    40  	args *fasthttp.Args
    41  	req  *fasthttp.Request
    42  	res  *fasthttp.Response
    43  }
    44  
    45  func (hc *RESTCtx) SetMethod(method string) *RESTCtx {
    46  	hc.req.Header.SetMethod(method)
    47  
    48  	return hc
    49  }
    50  
    51  func (hc *RESTCtx) SetPath(path string) *RESTCtx {
    52  	hc.uri.SetPath(path)
    53  
    54  	return hc
    55  }
    56  
    57  func (hc *RESTCtx) GET(path string) *RESTCtx {
    58  	hc.SetMethod(http.MethodGet)
    59  	hc.SetPath(path)
    60  
    61  	return hc
    62  }
    63  
    64  func (hc *RESTCtx) POST(path string) *RESTCtx {
    65  	hc.SetMethod(http.MethodPost)
    66  	hc.SetPath(path)
    67  
    68  	return hc
    69  }
    70  
    71  func (hc *RESTCtx) PUT(path string) *RESTCtx {
    72  	hc.SetMethod(http.MethodPut)
    73  	hc.SetPath(path)
    74  
    75  	return hc
    76  }
    77  
    78  func (hc *RESTCtx) PATCH(path string) *RESTCtx {
    79  	hc.SetMethod(http.MethodPatch)
    80  	hc.SetPath(path)
    81  
    82  	return hc
    83  }
    84  
    85  func (hc *RESTCtx) OPTIONS(path string) *RESTCtx {
    86  	hc.SetMethod(http.MethodOptions)
    87  	hc.SetPath(path)
    88  
    89  	return hc
    90  }
    91  
    92  func (hc *RESTCtx) SetQuery(key, value string) *RESTCtx {
    93  	hc.args.Set(key, value)
    94  
    95  	return hc
    96  }
    97  
    98  func (hc *RESTCtx) AppendQuery(key, value string) *RESTCtx {
    99  	hc.args.Add(key, value)
   100  
   101  	return hc
   102  }
   103  
   104  func (hc *RESTCtx) SetQueryMap(kv map[string]string) *RESTCtx {
   105  	for k, v := range kv {
   106  		hc.args.Set(k, v)
   107  	}
   108  
   109  	return hc
   110  }
   111  
   112  func (hc *RESTCtx) SetHeader(key, value string) *RESTCtx {
   113  	hc.req.Header.Set(key, value)
   114  
   115  	return hc
   116  }
   117  
   118  func (hc *RESTCtx) SetHeaderMap(kv map[string]string) *RESTCtx {
   119  	for k, v := range kv {
   120  		hc.req.Header.Set(k, v)
   121  	}
   122  
   123  	return hc
   124  }
   125  
   126  func (hc *RESTCtx) SetBody(body []byte) *RESTCtx {
   127  	hc.req.SetBody(body)
   128  
   129  	return hc
   130  }
   131  
   132  // SetBodyErr is a helper method, which is useful when we want to pass the marshaler function
   133  // directly without checking the error, before passing it to the SetBody method.
   134  // example:
   135  //
   136  //	restCtx.SetBodyErr(json.Marshal(m))
   137  //
   138  // Is equivalent to:
   139  //
   140  //	b, err := json.Marshal(m)
   141  //	if err != nil {
   142  //		// handle err
   143  //	}
   144  //	restCtx.SetBody(b)
   145  func (hc *RESTCtx) SetBodyErr(body []byte, err error) *RESTCtx {
   146  	if err != nil {
   147  		hc.err = WrapError(err)
   148  
   149  		return hc
   150  	}
   151  
   152  	return hc.SetBody(body)
   153  }
   154  
   155  func (hc *RESTCtx) Run(ctx context.Context) *RESTCtx {
   156  	if hc.err != nil {
   157  		return hc
   158  	}
   159  
   160  	// prepare the request
   161  	hc.uri.SetQueryString(hc.args.String())
   162  	hc.req.SetURI(hc.uri)
   163  	for k, v := range hc.cfg.hdr {
   164  		hc.req.Header.Set(k, v)
   165  	}
   166  
   167  	if tp := hc.cfg.tp; tp != nil {
   168  		tp.Inject(ctx, restTraceCarrier{r: &hc.req.Header})
   169  	}
   170  
   171  	// run preflights
   172  	for _, pre := range hc.cfg.preflights {
   173  		pre(hc.req)
   174  	}
   175  
   176  	// execute the request
   177  	hc.err = WrapError(hc.c.DoTimeout(hc.req, hc.res, hc.timeout))
   178  
   179  	if hc.dumpReq != nil {
   180  		_, _ = hc.req.WriteTo(hc.dumpReq)
   181  	}
   182  	if hc.dumpRes != nil {
   183  		_, _ = hc.res.WriteTo(hc.dumpRes)
   184  	}
   185  
   186  	// run the response handler if is set
   187  	statusCode := hc.res.StatusCode()
   188  	if hc.err == nil {
   189  		if h, ok := hc.handlers[statusCode]; ok {
   190  			hc.err = h(ctx, hc)
   191  		} else if hc.defaultHandler != nil {
   192  			hc.err = hc.defaultHandler(ctx, hc)
   193  		}
   194  	}
   195  
   196  	return hc
   197  }
   198  
   199  // Err returns the error if any occurred during the execution.
   200  func (hc *RESTCtx) Err() *Error {
   201  	if hc.err == nil {
   202  		return nil
   203  	}
   204  
   205  	return hc.err
   206  }
   207  
   208  // Error returns the error if any occurred during the execution.
   209  func (hc *RESTCtx) Error() error {
   210  	if hc.err == nil {
   211  		return nil
   212  	}
   213  
   214  	return hc.err
   215  }
   216  
   217  // StatusCode returns the status code of the response
   218  func (hc *RESTCtx) StatusCode() int { return hc.res.StatusCode() }
   219  
   220  // GetHeader returns the header value for the key in the response
   221  func (hc *RESTCtx) GetHeader(key string) string {
   222  	return string(hc.res.Header.Peek(key))
   223  }
   224  
   225  // GetBody returns the body, but please note that the returned slice is only valid until
   226  // Release is called. If you need to use the body after releasing RESTCtx then
   227  // use CopyBody method.
   228  func (hc *RESTCtx) GetBody() []byte {
   229  	if hc.err != nil {
   230  		return nil
   231  	}
   232  
   233  	return hc.res.Body()
   234  }
   235  
   236  // ReadResponseBody reads the response body to the provided writer.
   237  // It MUST be called after Run or AutoRun.
   238  func (hc *RESTCtx) ReadResponseBody(w io.Writer) *RESTCtx {
   239  	if hc.err != nil {
   240  		return hc
   241  	}
   242  
   243  	if _, err := w.Write(hc.res.Body()); err != nil {
   244  		hc.err = WrapError(err)
   245  	}
   246  
   247  	return hc
   248  }
   249  
   250  // CopyBody copies the body to `dst`. It creates a new slice and returns it if dst is nil.
   251  func (hc *RESTCtx) CopyBody(dst []byte) []byte {
   252  	if hc.err != nil {
   253  		return nil
   254  	}
   255  
   256  	dst = append(dst[:0], hc.res.Body()...)
   257  
   258  	return dst
   259  }
   260  
   261  // Release frees the allocated internal resources to be re-used.
   262  // You MUST NOT refer to any method of this object after calling this method, if
   263  // you call any method after Release has been called, the result is unpredictable.
   264  func (hc *RESTCtx) Release() {
   265  	fasthttp.ReleaseArgs(hc.args)
   266  	fasthttp.ReleaseURI(hc.uri)
   267  	fasthttp.ReleaseRequest(hc.req)
   268  	fasthttp.ReleaseResponse(hc.res)
   269  }
   270  
   271  func (hc *RESTCtx) SetResponseHandler(statusCode int, h RESTResponseHandler) *RESTCtx {
   272  	hc.handlers[statusCode] = h
   273  
   274  	return hc
   275  }
   276  
   277  func (hc *RESTCtx) SetOKHandler(h RESTResponseHandler) *RESTCtx {
   278  	hc.handlers[http.StatusOK] = h
   279  	hc.handlers[http.StatusCreated] = h
   280  	hc.handlers[http.StatusAccepted] = h
   281  
   282  	return hc
   283  }
   284  
   285  func (hc *RESTCtx) DefaultResponseHandler(h RESTResponseHandler) *RESTCtx {
   286  	hc.defaultHandler = h
   287  
   288  	return hc
   289  }
   290  
   291  func (hc *RESTCtx) DumpResponse() string {
   292  	return hc.res.String()
   293  }
   294  
   295  // DumpResponseTo accepts a writer and will write the response dump to it when Run is
   296  // executed.
   297  // Example:
   298  //
   299  //	httpCtx := s.REST().
   300  //								DumpRequestTo(os.Stdout).
   301  //								DumpResponseTo(os.Stdout).
   302  //								GET("https//google.com").
   303  //								Run(ctx)
   304  //	defer httpCtx.Release()
   305  //
   306  // **YOU MUST NOT USE httpCtx after httpCtx.Release() is called.**
   307  func (hc *RESTCtx) DumpResponseTo(w io.Writer) *RESTCtx {
   308  	hc.dumpRes = w
   309  
   310  	return hc
   311  }
   312  
   313  func (hc *RESTCtx) DumpRequest() string {
   314  	if hc.err != nil {
   315  		return hc.err.Error()
   316  	}
   317  
   318  	return hc.req.String()
   319  }
   320  
   321  // DumpRequestTo accepts a writer and will write the request dump to it when Run is
   322  // executed.
   323  //
   324  // Please refer to DumpResponseTo
   325  func (hc *RESTCtx) DumpRequestTo(w io.Writer) *RESTCtx {
   326  	hc.dumpReq = w
   327  
   328  	return hc
   329  }
   330  
   331  // AutoRun is a helper method, which fills the request based on the input arguments.
   332  // It checks the route which is a path pattern, and fills the dynamic url params based on
   333  // the `m`'s `tag` keys.
   334  // Example:
   335  //
   336  //	type Request struct {
   337  //			ID int64 `json:"id"`
   338  //			Name string `json:"name"`
   339  //	}
   340  //
   341  // AutoRun(
   342  //
   343  //		context.Background(),
   344  //	  "/something/:id/:name",
   345  //	  kit.JSON,
   346  //	  &Request{ID: 10, Name: "customName"},
   347  //
   348  // )
   349  //
   350  // Is equivalent to:
   351  //
   352  // SetPath("/something/10/customName").
   353  // Run(context.Background())
   354  func (hc *RESTCtx) AutoRun(
   355  	ctx context.Context, route string, enc kit.Encoding, m kit.Message,
   356  ) *RESTCtx {
   357  	switch enc.Tag() {
   358  	case kit.JSON.Tag():
   359  		hc.SetHeader("Content-Type", "application/json")
   360  	case kit.Proto.Tag():
   361  		hc.SetHeader("Content-Type", "application/protobuf")
   362  	}
   363  
   364  	ref := hc.r.Load(m, enc.Tag())
   365  	fields, ok := ref.ByTag(enc.Tag())
   366  	if !ok {
   367  		fields = ref.Obj()
   368  	}
   369  
   370  	usedParams := map[string]struct{}{}
   371  	path := fillParams(
   372  		route,
   373  		func(key string) string {
   374  			usedParams[key] = struct{}{}
   375  
   376  			v := fields.Get(m, key)
   377  			if v == nil {
   378  				return "_"
   379  			}
   380  
   381  			return fmt.Sprintf("%v", v)
   382  		},
   383  	)
   384  	hc.SetPath(path)
   385  
   386  	switch utils.B2S(hc.req.Header.Method()) {
   387  	case http.MethodGet:
   388  		fields.WalkFields(
   389  			func(key string, f reflector.FieldInfo) {
   390  				_, ok := usedParams[key]
   391  				if ok {
   392  					return
   393  				}
   394  
   395  				v := fields.Get(m, key)
   396  				if v == nil {
   397  					return
   398  				}
   399  
   400  				hc.SetQuery(key, fmt.Sprintf("%v", v))
   401  			},
   402  		)
   403  	default:
   404  		var reqBody []byte
   405  		switch enc {
   406  		default:
   407  			reqBody, _ = hc.codec.Marshal(m)
   408  		}
   409  		hc.SetBody(reqBody)
   410  	}
   411  
   412  	return hc.Run(ctx)
   413  }
   414  
   415  type restTraceCarrier struct {
   416  	r *fasthttp.RequestHeader
   417  }
   418  
   419  func (t restTraceCarrier) Get(key string) string {
   420  	return string(t.r.Peek(key))
   421  }
   422  
   423  func (t restTraceCarrier) Set(key string, value string) {
   424  	t.r.Set(key, value)
   425  }