github.com/rclone/rclone@v1.66.1-0.20240517100346-7b89735ae726/lib/rest/rest.go (about)

     1  // Package rest implements a simple REST wrapper
     2  //
     3  // All methods are safe for concurrent calling.
     4  package rest
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"encoding/json"
    10  	"encoding/xml"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"mime/multipart"
    15  	"net/http"
    16  	"net/url"
    17  	"sync"
    18  
    19  	"github.com/rclone/rclone/fs"
    20  	"github.com/rclone/rclone/lib/readers"
    21  )
    22  
    23  // Client contains the info to sustain the API
    24  type Client struct {
    25  	mu           sync.RWMutex
    26  	c            *http.Client
    27  	rootURL      string
    28  	errorHandler func(resp *http.Response) error
    29  	headers      map[string]string
    30  	signer       SignerFn
    31  }
    32  
    33  // NewClient takes an oauth http.Client and makes a new api instance
    34  func NewClient(c *http.Client) *Client {
    35  	api := &Client{
    36  		c:            c,
    37  		errorHandler: defaultErrorHandler,
    38  		headers:      make(map[string]string),
    39  	}
    40  	return api
    41  }
    42  
    43  // ReadBody reads resp.Body into result, closing the body
    44  func ReadBody(resp *http.Response) (result []byte, err error) {
    45  	defer fs.CheckClose(resp.Body, &err)
    46  	return io.ReadAll(resp.Body)
    47  }
    48  
    49  // defaultErrorHandler doesn't attempt to parse the http body, just
    50  // returns it in the error message closing resp.Body
    51  func defaultErrorHandler(resp *http.Response) (err error) {
    52  	body, err := ReadBody(resp)
    53  	if err != nil {
    54  		return fmt.Errorf("error reading error out of body: %w", err)
    55  	}
    56  	return fmt.Errorf("HTTP error %v (%v) returned body: %q", resp.StatusCode, resp.Status, body)
    57  }
    58  
    59  // SetErrorHandler sets the handler to decode an error response when
    60  // the HTTP status code is not 2xx.  The handler should close resp.Body.
    61  func (api *Client) SetErrorHandler(fn func(resp *http.Response) error) *Client {
    62  	api.mu.Lock()
    63  	defer api.mu.Unlock()
    64  	api.errorHandler = fn
    65  	return api
    66  }
    67  
    68  // SetRoot sets the default RootURL.  You can override this on a per
    69  // call basis using the RootURL field in Opts.
    70  func (api *Client) SetRoot(RootURL string) *Client {
    71  	api.mu.Lock()
    72  	defer api.mu.Unlock()
    73  	api.rootURL = RootURL
    74  	return api
    75  }
    76  
    77  // SetHeader sets a header for all requests
    78  // Start the key with "*" for don't canonicalise
    79  func (api *Client) SetHeader(key, value string) *Client {
    80  	api.mu.Lock()
    81  	defer api.mu.Unlock()
    82  	api.headers[key] = value
    83  	return api
    84  }
    85  
    86  // RemoveHeader unsets a header for all requests
    87  func (api *Client) RemoveHeader(key string) *Client {
    88  	api.mu.Lock()
    89  	defer api.mu.Unlock()
    90  	delete(api.headers, key)
    91  	return api
    92  }
    93  
    94  // SignerFn is used to sign an outgoing request
    95  type SignerFn func(*http.Request) error
    96  
    97  // SetSigner sets a signer for all requests
    98  func (api *Client) SetSigner(signer SignerFn) *Client {
    99  	api.mu.Lock()
   100  	defer api.mu.Unlock()
   101  	api.signer = signer
   102  	return api
   103  }
   104  
   105  // SetUserPass creates an Authorization header for all requests with
   106  // the UserName and Password passed in
   107  func (api *Client) SetUserPass(UserName, Password string) *Client {
   108  	req, _ := http.NewRequest("GET", "http://example.com", nil)
   109  	req.SetBasicAuth(UserName, Password)
   110  	api.SetHeader("Authorization", req.Header.Get("Authorization"))
   111  	return api
   112  }
   113  
   114  // SetCookie creates a Cookies Header for all requests with the supplied
   115  // cookies passed in.
   116  // All cookies have to be supplied at once, all cookies will be overwritten
   117  // on a new call to the method
   118  func (api *Client) SetCookie(cks ...*http.Cookie) *Client {
   119  	req, _ := http.NewRequest("GET", "http://example.com", nil)
   120  	for _, ck := range cks {
   121  		req.AddCookie(ck)
   122  	}
   123  	api.SetHeader("Cookie", req.Header.Get("Cookie"))
   124  	return api
   125  }
   126  
   127  // Opts contains parameters for Call, CallJSON, etc.
   128  type Opts struct {
   129  	Method                string // GET, POST, etc.
   130  	Path                  string // relative to RootURL
   131  	RootURL               string // override RootURL passed into SetRoot()
   132  	Body                  io.Reader
   133  	GetBody               func() (io.ReadCloser, error) // body builder, needed to enable low-level HTTP/2 retries
   134  	NoResponse            bool                          // set to close Body
   135  	ContentType           string
   136  	ContentLength         *int64
   137  	ContentRange          string
   138  	ExtraHeaders          map[string]string // extra headers, start them with "*" for don't canonicalise
   139  	UserName              string            // username for Basic Auth
   140  	Password              string            // password for Basic Auth
   141  	Options               []fs.OpenOption
   142  	IgnoreStatus          bool         // if set then we don't check error status or parse error body
   143  	MultipartParams       url.Values   // if set do multipart form upload with attached file
   144  	MultipartMetadataName string       // ..this is used for the name of the metadata form part if set
   145  	MultipartContentName  string       // ..name of the parameter which is the attached file
   146  	MultipartFileName     string       // ..name of the file for the attached file
   147  	Parameters            url.Values   // any parameters for the final URL
   148  	TransferEncoding      []string     // transfer encoding, set to "identity" to disable chunked encoding
   149  	Trailer               *http.Header // set the request trailer
   150  	Close                 bool         // set to close the connection after this transaction
   151  	NoRedirect            bool         // if this is set then the client won't follow redirects
   152  	// On Redirects, call this function - see the http.Client docs: https://pkg.go.dev/net/http#Client
   153  	CheckRedirect func(req *http.Request, via []*http.Request) error
   154  }
   155  
   156  // Copy creates a copy of the options
   157  func (o *Opts) Copy() *Opts {
   158  	newOpts := *o
   159  	return &newOpts
   160  }
   161  
   162  const drainLimit = 10 * 1024 * 1024
   163  
   164  // drainAndClose discards up to drainLimit bytes from r and closes
   165  // it. Any errors from the Read or Close are returned.
   166  func drainAndClose(r io.ReadCloser) (err error) {
   167  	_, readErr := io.CopyN(io.Discard, r, drainLimit)
   168  	if readErr == io.EOF {
   169  		readErr = nil
   170  	}
   171  	err = r.Close()
   172  	if readErr != nil {
   173  		return readErr
   174  	}
   175  	return err
   176  }
   177  
   178  // checkDrainAndClose is a utility function used to check the return
   179  // from drainAndClose in a defer statement.
   180  func checkDrainAndClose(r io.ReadCloser, err *error) {
   181  	cerr := drainAndClose(r)
   182  	if *err == nil {
   183  		*err = cerr
   184  	}
   185  }
   186  
   187  // DecodeJSON decodes resp.Body into result
   188  func DecodeJSON(resp *http.Response, result interface{}) (err error) {
   189  	defer checkDrainAndClose(resp.Body, &err)
   190  	decoder := json.NewDecoder(resp.Body)
   191  	return decoder.Decode(result)
   192  }
   193  
   194  // DecodeXML decodes resp.Body into result
   195  func DecodeXML(resp *http.Response, result interface{}) (err error) {
   196  	defer checkDrainAndClose(resp.Body, &err)
   197  	decoder := xml.NewDecoder(resp.Body)
   198  	// MEGAcmd has included escaped HTML entities in its XML output, so we have to be able to
   199  	// decode them.
   200  	decoder.Strict = false
   201  	decoder.Entity = xml.HTMLEntity
   202  	return decoder.Decode(result)
   203  }
   204  
   205  // ClientWithNoRedirects makes a new http client which won't follow redirects
   206  func ClientWithNoRedirects(c *http.Client) *http.Client {
   207  	clientCopy := *c
   208  	clientCopy.CheckRedirect = func(req *http.Request, via []*http.Request) error {
   209  		return http.ErrUseLastResponse
   210  	}
   211  	return &clientCopy
   212  }
   213  
   214  // Do calls the internal http.Client.Do method
   215  func (api *Client) Do(req *http.Request) (*http.Response, error) {
   216  	return api.c.Do(req)
   217  }
   218  
   219  // Call makes the call and returns the http.Response
   220  //
   221  // if err == nil then resp.Body will need to be closed unless
   222  // opt.NoResponse is set
   223  //
   224  // if err != nil then resp.Body will have been closed
   225  //
   226  // it will return resp if at all possible, even if err is set
   227  func (api *Client) Call(ctx context.Context, opts *Opts) (resp *http.Response, err error) {
   228  	api.mu.RLock()
   229  	defer api.mu.RUnlock()
   230  	if opts == nil {
   231  		return nil, errors.New("call() called with nil opts")
   232  	}
   233  	url := api.rootURL
   234  	if opts.RootURL != "" {
   235  		url = opts.RootURL
   236  	}
   237  	if url == "" {
   238  		return nil, errors.New("RootURL not set")
   239  	}
   240  	url += opts.Path
   241  	if opts.Parameters != nil && len(opts.Parameters) > 0 {
   242  		url += "?" + opts.Parameters.Encode()
   243  	}
   244  	body := readers.NoCloser(opts.Body)
   245  	// If length is set and zero then nil out the body to stop use
   246  	// use of chunked encoding and insert a "Content-Length: 0"
   247  	// header.
   248  	//
   249  	// If we don't do this we get "Content-Length" headers for all
   250  	// files except 0 length files.
   251  	if opts.ContentLength != nil && *opts.ContentLength == 0 {
   252  		body = nil
   253  	}
   254  	req, err := http.NewRequestWithContext(ctx, opts.Method, url, body)
   255  	if err != nil {
   256  		return
   257  	}
   258  	headers := make(map[string]string)
   259  	// Set default headers
   260  	for k, v := range api.headers {
   261  		headers[k] = v
   262  	}
   263  	if opts.ContentType != "" {
   264  		headers["Content-Type"] = opts.ContentType
   265  	}
   266  	if opts.ContentLength != nil {
   267  		req.ContentLength = *opts.ContentLength
   268  	}
   269  	if opts.ContentRange != "" {
   270  		headers["Content-Range"] = opts.ContentRange
   271  	}
   272  	if len(opts.TransferEncoding) != 0 {
   273  		req.TransferEncoding = opts.TransferEncoding
   274  	}
   275  	if opts.GetBody != nil {
   276  		req.GetBody = opts.GetBody
   277  	}
   278  	if opts.Trailer != nil {
   279  		req.Trailer = *opts.Trailer
   280  	}
   281  	if opts.Close {
   282  		req.Close = true
   283  	}
   284  	// Set any extra headers
   285  	for k, v := range opts.ExtraHeaders {
   286  		headers[k] = v
   287  	}
   288  	// add any options to the headers
   289  	fs.OpenOptionAddHeaders(opts.Options, headers)
   290  	// Now set the headers
   291  	for k, v := range headers {
   292  		if k != "" && v != "" {
   293  			if k[0] == '*' {
   294  				// Add non-canonical version if header starts with *
   295  				k = k[1:]
   296  				req.Header[k] = append(req.Header[k], v)
   297  			} else {
   298  				req.Header.Add(k, v)
   299  			}
   300  		}
   301  	}
   302  
   303  	if opts.UserName != "" || opts.Password != "" {
   304  		req.SetBasicAuth(opts.UserName, opts.Password)
   305  	}
   306  	var c *http.Client
   307  	if opts.NoRedirect {
   308  		c = ClientWithNoRedirects(api.c)
   309  	} else if opts.CheckRedirect != nil {
   310  		clientCopy := *api.c
   311  		clientCopy.CheckRedirect = opts.CheckRedirect
   312  		c = &clientCopy
   313  	} else {
   314  		c = api.c
   315  	}
   316  	if api.signer != nil {
   317  		api.mu.RUnlock()
   318  		err = api.signer(req)
   319  		api.mu.RLock()
   320  		if err != nil {
   321  			return nil, fmt.Errorf("signer failed: %w", err)
   322  		}
   323  	}
   324  	api.mu.RUnlock()
   325  	resp, err = c.Do(req)
   326  	api.mu.RLock()
   327  	if err != nil {
   328  		return nil, err
   329  	}
   330  	if !opts.IgnoreStatus {
   331  		if resp.StatusCode < 200 || resp.StatusCode > 299 {
   332  			err = api.errorHandler(resp)
   333  			if err.Error() == "" {
   334  				// replace empty errors with something
   335  				err = fmt.Errorf("http error %d: %v", resp.StatusCode, resp.Status)
   336  			}
   337  			return resp, err
   338  		}
   339  	}
   340  	if opts.NoResponse {
   341  		return resp, drainAndClose(resp.Body)
   342  	}
   343  	return resp, nil
   344  }
   345  
   346  // MultipartUpload creates an io.Reader which produces an encoded a
   347  // multipart form upload from the params passed in and the  passed in
   348  //
   349  // in - the body of the file (may be nil)
   350  // params - the form parameters
   351  // fileName - is the name of the attached file
   352  // contentName - the name of the parameter for the file
   353  //
   354  // the int64 returned is the overhead in addition to the file contents, in case Content-Length is required
   355  //
   356  // NB This doesn't allow setting the content type of the attachment
   357  func MultipartUpload(ctx context.Context, in io.Reader, params url.Values, contentName, fileName string) (io.ReadCloser, string, int64, error) {
   358  	bodyReader, bodyWriter := io.Pipe()
   359  	writer := multipart.NewWriter(bodyWriter)
   360  	contentType := writer.FormDataContentType()
   361  
   362  	// Create a Multipart Writer as base for calculating the Content-Length
   363  	buf := &bytes.Buffer{}
   364  	dummyMultipartWriter := multipart.NewWriter(buf)
   365  	err := dummyMultipartWriter.SetBoundary(writer.Boundary())
   366  	if err != nil {
   367  		return nil, "", 0, err
   368  	}
   369  
   370  	for key, vals := range params {
   371  		for _, val := range vals {
   372  			err := dummyMultipartWriter.WriteField(key, val)
   373  			if err != nil {
   374  				return nil, "", 0, err
   375  			}
   376  		}
   377  	}
   378  	if in != nil {
   379  		_, err = dummyMultipartWriter.CreateFormFile(contentName, fileName)
   380  		if err != nil {
   381  			return nil, "", 0, err
   382  		}
   383  	}
   384  
   385  	err = dummyMultipartWriter.Close()
   386  	if err != nil {
   387  		return nil, "", 0, err
   388  	}
   389  
   390  	multipartLength := int64(buf.Len())
   391  
   392  	// Make sure we close the pipe writer to release the reader on context cancel
   393  	quit := make(chan struct{})
   394  	go func() {
   395  		select {
   396  		case <-quit:
   397  			break
   398  		case <-ctx.Done():
   399  			_ = bodyWriter.CloseWithError(ctx.Err())
   400  		}
   401  	}()
   402  
   403  	// Pump the data in the background
   404  	go func() {
   405  		defer close(quit)
   406  
   407  		var err error
   408  
   409  		for key, vals := range params {
   410  			for _, val := range vals {
   411  				err = writer.WriteField(key, val)
   412  				if err != nil {
   413  					_ = bodyWriter.CloseWithError(fmt.Errorf("create metadata part: %w", err))
   414  					return
   415  				}
   416  			}
   417  		}
   418  
   419  		if in != nil {
   420  			part, err := writer.CreateFormFile(contentName, fileName)
   421  			if err != nil {
   422  				_ = bodyWriter.CloseWithError(fmt.Errorf("failed to create form file: %w", err))
   423  				return
   424  			}
   425  
   426  			_, err = io.Copy(part, in)
   427  			if err != nil {
   428  				_ = bodyWriter.CloseWithError(fmt.Errorf("failed to copy data: %w", err))
   429  				return
   430  			}
   431  		}
   432  
   433  		err = writer.Close()
   434  		if err != nil {
   435  			_ = bodyWriter.CloseWithError(fmt.Errorf("failed to close form: %w", err))
   436  			return
   437  		}
   438  
   439  		_ = bodyWriter.Close()
   440  	}()
   441  
   442  	return bodyReader, contentType, multipartLength, nil
   443  }
   444  
   445  // CallJSON runs Call and decodes the body as a JSON object into response (if not nil)
   446  //
   447  // If request is not nil then it will be JSON encoded as the body of the request.
   448  //
   449  // If response is not nil then the response will be JSON decoded into
   450  // it and resp.Body will be closed.
   451  //
   452  // If response is nil then the resp.Body will be closed only if
   453  // opts.NoResponse is set.
   454  //
   455  // If (opts.MultipartParams or opts.MultipartContentName) and
   456  // opts.Body are set then CallJSON will do a multipart upload with a
   457  // file attached.  opts.MultipartContentName is the name of the
   458  // parameter and opts.MultipartFileName is the name of the file.  If
   459  // MultipartContentName is set, and request != nil is supplied, then
   460  // the request will be marshalled into JSON and added to the form with
   461  // parameter name MultipartMetadataName.
   462  //
   463  // It will return resp if at all possible, even if err is set
   464  func (api *Client) CallJSON(ctx context.Context, opts *Opts, request interface{}, response interface{}) (resp *http.Response, err error) {
   465  	return api.callCodec(ctx, opts, request, response, json.Marshal, DecodeJSON, "application/json")
   466  }
   467  
   468  // CallXML runs Call and decodes the body as an XML object into response (if not nil)
   469  //
   470  // If request is not nil then it will be XML encoded as the body of the request.
   471  //
   472  // If response is not nil then the response will be XML decoded into
   473  // it and resp.Body will be closed.
   474  //
   475  // If response is nil then the resp.Body will be closed only if
   476  // opts.NoResponse is set.
   477  //
   478  // See CallJSON for a description of MultipartParams and related opts.
   479  //
   480  // It will return resp if at all possible, even if err is set
   481  func (api *Client) CallXML(ctx context.Context, opts *Opts, request interface{}, response interface{}) (resp *http.Response, err error) {
   482  	return api.callCodec(ctx, opts, request, response, xml.Marshal, DecodeXML, "application/xml")
   483  }
   484  
   485  type marshalFn func(v interface{}) ([]byte, error)
   486  type decodeFn func(resp *http.Response, result interface{}) (err error)
   487  
   488  func (api *Client) callCodec(ctx context.Context, opts *Opts, request interface{}, response interface{}, marshal marshalFn, decode decodeFn, contentType string) (resp *http.Response, err error) {
   489  	var requestBody []byte
   490  	// Marshal the request if given
   491  	if request != nil {
   492  		requestBody, err = marshal(request)
   493  		if err != nil {
   494  			return nil, err
   495  		}
   496  		// Set the body up as a marshalled object if no body passed in
   497  		if opts.Body == nil {
   498  			opts = opts.Copy()
   499  			opts.ContentType = contentType
   500  			opts.Body = bytes.NewBuffer(requestBody)
   501  		}
   502  	}
   503  	if opts.MultipartParams != nil || opts.MultipartContentName != "" {
   504  		params := opts.MultipartParams
   505  		if params == nil {
   506  			params = url.Values{}
   507  		}
   508  		if opts.MultipartMetadataName != "" {
   509  			params.Add(opts.MultipartMetadataName, string(requestBody))
   510  		}
   511  		opts = opts.Copy()
   512  
   513  		var overhead int64
   514  		opts.Body, opts.ContentType, overhead, err = MultipartUpload(ctx, opts.Body, params, opts.MultipartContentName, opts.MultipartFileName)
   515  		if err != nil {
   516  			return nil, err
   517  		}
   518  		if opts.ContentLength != nil {
   519  			*opts.ContentLength += overhead
   520  		}
   521  	}
   522  	resp, err = api.Call(ctx, opts)
   523  	if err != nil {
   524  		return resp, err
   525  	}
   526  	// if opts.NoResponse is set, resp.Body will have been closed by Call()
   527  	if response == nil || opts.NoResponse {
   528  		return resp, nil
   529  	}
   530  	err = decode(resp, response)
   531  	return resp, err
   532  }