github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/resty/middleware.go (about)

     1  // Copyright (c) 2015-2021 Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
     2  // resty source code and usage is governed by a MIT style
     3  // license that can be found in the LICENSE file.
     4  
     5  package resty
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	"mime/multipart"
    14  	"net/http"
    15  	"net/url"
    16  	"os"
    17  	"path/filepath"
    18  	"reflect"
    19  	"strings"
    20  	"time"
    21  )
    22  
// debugRequestLogKey is the Request.values key under which requestLogger
// stores the formatted request debug dump; responseLogger reads it back,
// appends the response section, and emits the combined log.
const debugRequestLogKey = "__restyDebugRequestLog"
    24  
    25  //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
    26  // Request Middleware(s)
    27  //_______________________________________________________________________
    28  
    29  func parseRequestURL(c *Client, r *Request) error {
    30  	// GitHub #103 Path Params
    31  	if len(r.PathParams) > 0 {
    32  		for p, v := range r.PathParams {
    33  			r.URL = strings.Replace(r.URL, "{"+p+"}", url.PathEscape(v), -1)
    34  		}
    35  	}
    36  	if len(c.PathParams) > 0 {
    37  		for p, v := range c.PathParams {
    38  			r.URL = strings.Replace(r.URL, "{"+p+"}", url.PathEscape(v), -1)
    39  		}
    40  	}
    41  
    42  	// Parsing request URL
    43  	reqURL, err := url.Parse(r.URL)
    44  	if err != nil {
    45  		return err
    46  	}
    47  
    48  	// If Request.URL is relative path then added c.BaseURL into
    49  	// the request URL otherwise Request.URL will be used as-is
    50  	if !reqURL.IsAbs() {
    51  		r.URL = reqURL.String()
    52  		if len(r.URL) > 0 && r.URL[0] != '/' {
    53  			r.URL = "/" + r.URL
    54  		}
    55  
    56  		reqURL, err = url.Parse(c.BaseURL + r.URL)
    57  		if err != nil {
    58  			return err
    59  		}
    60  	}
    61  
    62  	// GH #407 && #318
    63  	if reqURL.Scheme == "" && len(c.scheme) > 0 {
    64  		reqURL.Scheme = c.scheme
    65  	}
    66  
    67  	// Adding Query Param
    68  	query := make(url.Values)
    69  	for k, v := range c.QueryParam {
    70  		for _, iv := range v {
    71  			query.Add(k, iv)
    72  		}
    73  	}
    74  
    75  	for k, v := range r.QueryParam {
    76  		// remove query param from client level by key
    77  		// since overrides happens for that key in the request
    78  		query.Del(k)
    79  
    80  		for _, iv := range v {
    81  			query.Add(k, iv)
    82  		}
    83  	}
    84  
    85  	// GitHub #123 Preserve query string order partially.
    86  	// Since not feasible in `SetQuery*` resty methods, because
    87  	// standard package `url.Encode(...)` sorts the query params
    88  	// alphabetically
    89  	if len(query) > 0 {
    90  		if IsStringEmpty(reqURL.RawQuery) {
    91  			reqURL.RawQuery = query.Encode()
    92  		} else {
    93  			reqURL.RawQuery = reqURL.RawQuery + "&" + query.Encode()
    94  		}
    95  	}
    96  
    97  	r.URL = reqURL.String()
    98  
    99  	return nil
   100  }
   101  
   102  func parseRequestHeader(c *Client, r *Request) error {
   103  	hdr := make(http.Header)
   104  	for k := range c.Header {
   105  		hdr[k] = append(hdr[k], c.Header[k]...)
   106  	}
   107  
   108  	for k := range r.Header {
   109  		hdr.Del(k)
   110  		hdr[k] = append(hdr[k], r.Header[k]...)
   111  	}
   112  
   113  	if IsStringEmpty(hdr.Get(hdrUserAgentKey)) {
   114  		hdr.Set(hdrUserAgentKey, hdrUserAgentValue)
   115  	}
   116  
   117  	ct := hdr.Get(hdrContentTypeKey)
   118  	if IsStringEmpty(hdr.Get(hdrAcceptKey)) && !IsStringEmpty(ct) &&
   119  		(IsJSONType(ct) || IsXMLType(ct)) {
   120  		hdr.Set(hdrAcceptKey, hdr.Get(hdrContentTypeKey))
   121  	}
   122  
   123  	r.Header = hdr
   124  
   125  	return nil
   126  }
   127  
   128  func parseRequestBody(c *Client, r *Request) (err error) {
   129  	if isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
   130  		if r.isMultiPart && !(r.Method == http.MethodPatch) {
   131  			// Handling Multipart
   132  			if err = handleMultipart(c, r); err != nil {
   133  				return
   134  			}
   135  		} else if len(c.FormData) > 0 || len(r.FormData) > 0 {
   136  			// Handling Form Data
   137  			handleFormData(c, r)
   138  		} else {
   139  			// Handling Request body
   140  			if r.Body != nil {
   141  				handleContentType(c, r)
   142  				if err = handleRequestBody(c, r); err != nil {
   143  					return
   144  				}
   145  			}
   146  		}
   147  	}
   148  
   149  	// by default resty won't set content length, you can if you want to :)
   150  	if (c.setContentLength || r.setContentLength) && r.bodyBuf != nil {
   151  		r.Header.Set(hdrContentLengthKey, fmt.Sprintf("%d", r.bodyBuf.Len()))
   152  	}
   153  
   154  	return
   155  }
   156  
   157  func createHTTPRequest(c *Client, r *Request) (err error) {
   158  	if r.bodyBuf == nil {
   159  		if reader, ok := r.Body.(io.Reader); ok {
   160  			r.RawRequest, err = http.NewRequest(r.Method, r.URL, reader)
   161  		} else if c.setContentLength || r.setContentLength {
   162  			r.RawRequest, err = http.NewRequest(r.Method, r.URL, http.NoBody)
   163  		} else {
   164  			r.RawRequest, err = http.NewRequest(r.Method, r.URL, nil)
   165  		}
   166  	} else {
   167  		r.RawRequest, err = http.NewRequest(r.Method, r.URL, r.bodyBuf)
   168  	}
   169  
   170  	if err != nil {
   171  		return
   172  	}
   173  
   174  	// Assign close connection option
   175  	r.RawRequest.Close = c.closeConnection
   176  
   177  	// Add headers into http request
   178  	r.RawRequest.Header = r.Header
   179  
   180  	// Add cookies from client instance into http request
   181  	for _, cookie := range c.Cookies {
   182  		r.RawRequest.AddCookie(cookie)
   183  	}
   184  
   185  	// Add cookies from request instance into http request
   186  	for _, cookie := range r.Cookies {
   187  		r.RawRequest.AddCookie(cookie)
   188  	}
   189  
   190  	// Enable trace
   191  	if c.trace || r.trace {
   192  		r.clientTrace = &clientTrace{}
   193  		r.ctx = r.clientTrace.createContext(r.Context())
   194  	}
   195  
   196  	// Use context if it was specified
   197  	if r.ctx != nil {
   198  		r.RawRequest = r.RawRequest.WithContext(r.ctx)
   199  	}
   200  
   201  	bodyCopy, err := getBodyCopy(r)
   202  	if err != nil {
   203  		return err
   204  	}
   205  
   206  	// assign get body func for the underlying raw request instance
   207  	r.RawRequest.GetBody = func() (io.ReadCloser, error) {
   208  		if bodyCopy != nil {
   209  			return ioutil.NopCloser(bytes.NewReader(bodyCopy.Bytes())), nil
   210  		}
   211  		return nil, nil
   212  	}
   213  
   214  	return
   215  }
   216  
   217  func addCredentials(c *Client, r *Request) error {
   218  	// Basic Auth
   219  	userInfo := r.UserInfo
   220  	if userInfo == nil {
   221  		userInfo = c.UserInfo
   222  	}
   223  	isBasicAuth := userInfo != nil
   224  	if isBasicAuth {
   225  		r.RawRequest.SetBasicAuth(userInfo.Username, userInfo.Password)
   226  	}
   227  
   228  	if !c.DisableWarn {
   229  		if isBasicAuth && !strings.HasPrefix(r.URL, "https") {
   230  			c.log.Warnf("Using Basic Auth in HTTP mode is not secure, use HTTPS")
   231  		}
   232  	}
   233  
   234  	// Set the Authorization Header Scheme
   235  	authScheme := "Bearer"
   236  	if !IsStringEmpty(r.AuthScheme) {
   237  		authScheme = r.AuthScheme
   238  	} else if !IsStringEmpty(c.AuthScheme) {
   239  		authScheme = c.AuthScheme
   240  	}
   241  
   242  	// Build the Token Auth header
   243  	token := r.Token
   244  	if IsStringEmpty(token) {
   245  		token = c.Token
   246  	}
   247  
   248  	if !IsStringEmpty(token) {
   249  		r.RawRequest.Header.Set(c.HeaderAuthorizationKey, authScheme+" "+token)
   250  	}
   251  
   252  	return nil
   253  }
   254  
   255  func requestLogger(c *Client, r *Request) error {
   256  	if !c.Debug {
   257  		return nil
   258  	}
   259  
   260  	rr := r.RawRequest
   261  	rl := &RequestLog{Header: copyHeaders(rr.Header), Body: r.fmtBodyString(c.debugBodySizeLimit)}
   262  	if c.requestLog != nil {
   263  		if err := c.requestLog(rl); err != nil {
   264  			return err
   265  		}
   266  	}
   267  	// fmt.Sprintf("COOKIES:\n%s\n", composeCookies(c.GetClient().Jar, *rr.URL)) +
   268  
   269  	reqLog := "\n==============================================================================\n" +
   270  		"~~~ REQUEST ~~~\n" +
   271  		fmt.Sprintf("%s  %s  %s\n", r.Method, rr.URL.RequestURI(), rr.Proto) +
   272  		fmt.Sprintf("HOST   : %s\n", rr.URL.Host) +
   273  		fmt.Sprintf("HEADERS:\n%s\n", composeHeaders(c, r, rl.Header)) +
   274  		fmt.Sprintf("BODY   :\n%v\n", rl.Body) +
   275  		"------------------------------------------------------------------------------\n"
   276  
   277  	r.initValuesMap()
   278  	r.values[debugRequestLogKey] = reqLog
   279  
   280  	return nil
   281  }
   282  
   283  //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
   284  // Response Middleware(s)
   285  //_______________________________________________________________________
   286  
   287  func responseLogger(c *Client, res *Response) error {
   288  	if !c.Debug {
   289  		return nil
   290  	}
   291  
   292  	rl := &ResponseLog{Header: copyHeaders(res.Header()), Body: res.fmtBodyString(c.debugBodySizeLimit)}
   293  	if c.responseLog != nil {
   294  		if err := c.responseLog(rl); err != nil {
   295  			return err
   296  		}
   297  	}
   298  
   299  	debugLog := res.Request.values[debugRequestLogKey].(string)
   300  	debugLog += "~~~ RESPONSE ~~~\n" +
   301  		fmt.Sprintf("STATUS       : %s\n", res.Status()) +
   302  		fmt.Sprintf("PROTO        : %s\n", res.RawResponse.Proto) +
   303  		fmt.Sprintf("RECEIVED AT  : %v\n", res.ReceivedAt().Format(time.RFC3339Nano)) +
   304  		fmt.Sprintf("TIME DURATION: %v\n", res.Time()) +
   305  		"HEADERS      :\n" +
   306  		composeHeaders(c, res.Request, rl.Header) + "\n"
   307  	if res.Request.isSaveResponse {
   308  		debugLog += "BODY         :\n***** RESPONSE WRITTEN INTO FILE *****\n"
   309  	} else {
   310  		debugLog += fmt.Sprintf("BODY         :\n%v\n", rl.Body)
   311  	}
   312  	debugLog += "==============================================================================\n"
   313  
   314  	c.log.Debugf("%s", debugLog)
   315  
   316  	return nil
   317  }
   318  
   319  func parseResponseBody(c *Client, res *Response) (err error) {
   320  	if res.StatusCode() == http.StatusNoContent {
   321  		return
   322  	}
   323  	// Handles only JSON or XML content type
   324  	ct := firstNonEmpty(res.Request.forceContentType, res.Header().Get(hdrContentTypeKey), res.Request.fallbackContentType)
   325  	if IsJSONType(ct) || IsXMLType(ct) {
   326  		// HTTP status code > 199 and < 300, considered as Result
   327  		if res.IsSuccess() {
   328  			res.Request.Error = nil
   329  			if res.Request.Result != nil {
   330  				err = Unmarshalc(c, ct, res.body, res.Request.Result)
   331  				return
   332  			}
   333  		}
   334  
   335  		// HTTP status code > 399, considered as Error
   336  		if res.IsError() {
   337  			// global error interface
   338  			if res.Request.Error == nil && c.Error != nil {
   339  				res.Request.Error = reflect.New(c.Error).Interface()
   340  			}
   341  
   342  			if res.Request.Error != nil {
   343  				err = Unmarshalc(c, ct, res.body, res.Request.Error)
   344  			}
   345  		}
   346  	}
   347  
   348  	return
   349  }
   350  
   351  func handleMultipart(c *Client, r *Request) (err error) {
   352  	r.bodyBuf = acquireBuffer()
   353  	w := multipart.NewWriter(r.bodyBuf)
   354  
   355  	for k, v := range c.FormData {
   356  		for _, iv := range v {
   357  			if err = w.WriteField(k, iv); err != nil {
   358  				return err
   359  			}
   360  		}
   361  	}
   362  
   363  	for k, v := range r.FormData {
   364  		for _, iv := range v {
   365  			if strings.HasPrefix(k, "@") { // file
   366  				err = addFile(w, k[1:], iv)
   367  				if err != nil {
   368  					return
   369  				}
   370  			} else { // form value
   371  				if err = w.WriteField(k, iv); err != nil {
   372  					return err
   373  				}
   374  			}
   375  		}
   376  	}
   377  
   378  	// #21 - adding io.Reader support
   379  	if len(r.multipartFiles) > 0 {
   380  		for _, f := range r.multipartFiles {
   381  			err = addFileReader(w, f)
   382  			if err != nil {
   383  				return
   384  			}
   385  		}
   386  	}
   387  
   388  	// GitHub #130 adding multipart field support with content type
   389  	if len(r.multipartFields) > 0 {
   390  		for _, mf := range r.multipartFields {
   391  			if err = addMultipartFormField(w, mf); err != nil {
   392  				return
   393  			}
   394  		}
   395  	}
   396  
   397  	r.Header.Set(hdrContentTypeKey, w.FormDataContentType())
   398  	err = w.Close()
   399  
   400  	return
   401  }
   402  
   403  func handleFormData(c *Client, r *Request) {
   404  	formData := url.Values{}
   405  
   406  	for k, v := range c.FormData {
   407  		for _, iv := range v {
   408  			formData.Add(k, iv)
   409  		}
   410  	}
   411  
   412  	for k, v := range r.FormData {
   413  		// remove form data field from client level by key
   414  		// since overrides happens for that key in the request
   415  		formData.Del(k)
   416  
   417  		for _, iv := range v {
   418  			formData.Add(k, iv)
   419  		}
   420  	}
   421  
   422  	r.bodyBuf = bytes.NewBuffer([]byte(formData.Encode()))
   423  	r.Header.Set(hdrContentTypeKey, formContentType)
   424  	r.isFormData = true
   425  }
   426  
   427  func handleContentType(_ *Client, r *Request) {
   428  	contentType := r.Header.Get(hdrContentTypeKey)
   429  	if IsStringEmpty(contentType) {
   430  		contentType = DetectContentType(r.Body)
   431  		r.Header.Set(hdrContentTypeKey, contentType)
   432  	}
   433  }
   434  
   435  func handleRequestBody(c *Client, r *Request) (err error) {
   436  	var bodyBytes []byte
   437  	contentType := r.Header.Get(hdrContentTypeKey)
   438  	kind := kindOf(r.Body)
   439  	r.bodyBuf = nil
   440  
   441  	if reader, ok := r.Body.(io.Reader); ok {
   442  		if c.setContentLength || r.setContentLength { // keep backward compatibility
   443  			r.bodyBuf = acquireBuffer()
   444  			_, err = r.bodyBuf.ReadFrom(reader)
   445  			r.Body = nil
   446  		} else {
   447  			// Otherwise buffer less processing for `io.Reader`, sounds good.
   448  			return
   449  		}
   450  	} else if b, ok := r.Body.([]byte); ok {
   451  		bodyBytes = b
   452  	} else if s, ok := r.Body.(string); ok {
   453  		bodyBytes = []byte(s)
   454  	} else if IsJSONType(contentType) && (kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) {
   455  		r.bodyBuf, err = jsonMarshal(c, r, r.Body)
   456  		if err != nil {
   457  			return
   458  		}
   459  	} else if IsXMLType(contentType) && (kind == reflect.Struct) {
   460  		bodyBytes, err = c.XMLMarshal(r.Body)
   461  		if err != nil {
   462  			return
   463  		}
   464  	}
   465  
   466  	if bodyBytes == nil && r.bodyBuf == nil {
   467  		err = errors.New("unsupported 'Body' type/value")
   468  	}
   469  
   470  	// if any errors during body bytes handling, return it
   471  	if err != nil {
   472  		return
   473  	}
   474  
   475  	// []byte into Buffer
   476  	if bodyBytes != nil && r.bodyBuf == nil {
   477  		r.bodyBuf = acquireBuffer()
   478  		_, _ = r.bodyBuf.Write(bodyBytes)
   479  	}
   480  
   481  	return
   482  }
   483  
   484  func saveResponseIntoFile(c *Client, res *Response) error {
   485  	if res.Request.isSaveResponse {
   486  		file := ""
   487  
   488  		if len(c.outputDirectory) > 0 && !filepath.IsAbs(res.Request.outputFile) {
   489  			file += c.outputDirectory + string(filepath.Separator)
   490  		}
   491  
   492  		file = filepath.Clean(file + res.Request.outputFile)
   493  		if err := createDirectory(filepath.Dir(file)); err != nil {
   494  			return err
   495  		}
   496  
   497  		outFile, err := os.Create(file)
   498  		if err != nil {
   499  			return err
   500  		}
   501  		defer closeq(outFile)
   502  
   503  		// io.Copy reads maximum 32kb size, it is perfect for large file download too
   504  		defer closeq(res.RawResponse.Body)
   505  
   506  		written, err := io.Copy(outFile, res.RawResponse.Body)
   507  		if err != nil {
   508  			return err
   509  		}
   510  
   511  		res.size = written
   512  	}
   513  
   514  	return nil
   515  }
   516  
   517  func getBodyCopy(r *Request) (*bytes.Buffer, error) {
   518  	// If r.bodyBuf present, return the copy
   519  	if r.bodyBuf != nil {
   520  		return bytes.NewBuffer(r.bodyBuf.Bytes()), nil
   521  	}
   522  
   523  	// Maybe body is `io.Reader`.
   524  	// Note: Resty user have to watchout for large body size of `io.Reader`
   525  	if r.RawRequest.Body != nil {
   526  		b, err := io.ReadAll(r.RawRequest.Body)
   527  		if err != nil {
   528  			return nil, err
   529  		}
   530  
   531  		// Restore the Body
   532  		closeq(r.RawRequest.Body)
   533  		r.RawRequest.Body = ioutil.NopCloser(bytes.NewBuffer(b))
   534  
   535  		// Return the Body bytes
   536  		return bytes.NewBuffer(b), nil
   537  	}
   538  	return nil, nil
   539  }