github.com/xmidt-org/webpa-common@v1.11.9/xhttp/rewind.go (about)

     1  package xhttp
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"io/ioutil"
     8  	"net/http"
     9  )
    10  
    11  var errNotRewindable = errors.New("That request is not rewindable")
    12  
    13  // ReadSeekerCloser combines the behavior of io.Reader, io.Seeker, and io.Closer.
    14  // This package uses this interface for basic optimizations.
    15  type ReadSeekerCloser interface {
    16  	io.ReadSeeker
    17  	io.Closer
    18  }
    19  
    20  type closeAdapter struct {
    21  	io.ReadSeeker
    22  }
    23  
    24  func (ca closeAdapter) Close() error {
    25  	return nil
    26  }
    27  
    28  // NopCloser is an analog of ioutil.NopCloser.  This function preserves io.Seeker semantics in
    29  // the returned instance.  Additionally, if rs already implements io.Closer, this function
    30  // returns rs as is.
    31  func NopCloser(rs io.ReadSeeker) ReadSeekerCloser {
    32  	if rsc, ok := rs.(ReadSeekerCloser); ok {
    33  		return rsc
    34  	}
    35  
    36  	return closeAdapter{rs}
    37  }
    38  
    39  // NewRewind extracts all remaining bytes from an io.Reader, then uses NewRewindableBytes
    40  // to produce a body and a get body function.  If any error occurred during reading, that error
    41  // is returned and the other return values will be nil.
    42  //
    43  // This function performs certain optimizations on the returned body and get body function.  If
    44  // r implements io.Seeker, then a get body function that simply invokes Seek(0, 0) is used.
    45  // Additionally, this function honors the case where r implements io.Closer, preserving its
    46  // Close() semantics.
    47  func NewRewind(r io.Reader) (io.ReadCloser, func() (io.ReadCloser, error), error) {
    48  	if rs, ok := r.(io.ReadSeeker); ok {
    49  		// no need to bother reading bytes
    50  		rsc := NopCloser(rs)
    51  
    52  		return rsc,
    53  			func() (io.ReadCloser, error) {
    54  				_, err := rsc.Seek(0, 0)
    55  				return rsc, err
    56  			}, nil
    57  	}
    58  
    59  	b, err := ioutil.ReadAll(r)
    60  	if err != nil {
    61  		return nil, nil, err
    62  	}
    63  
    64  	body, getBody := NewRewindBytes(b)
    65  	return body, getBody, nil
    66  }
    67  
    68  // NewRewindBytes produces both an io.ReadCloser that returns the given bytes
    69  // and a function that produces a new io.ReadCloser that returns those same bytes.
    70  // Both return values from this function are appropriate for http.Request.Body and
    71  // http.Request.GetBody, respectively.
    72  func NewRewindBytes(b []byte) (io.ReadCloser, func() (io.ReadCloser, error)) {
    73  	rsc := NopCloser(bytes.NewReader(b))
    74  	return rsc,
    75  		func() (io.ReadCloser, error) {
    76  			_, err := rsc.Seek(0, 0)
    77  			return rsc, err
    78  		}
    79  }
    80  
    81  // EnsureRewindable configures the given request's contents to be restreamed in the event
    82  // of a redirect or other arbitrary code that must resubmit a request.  If this function
    83  // is successful, Rewind can be used to rewind the request.
    84  //
    85  // If a GetBody function is already present on the request, this function does nothing
    86  // as the given request is already rewindable.  Additionally, if there is no Body on the request,
    87  // this function does nothing as there's no body to rewind.
    88  func EnsureRewindable(r *http.Request) error {
    89  	if r.GetBody != nil || r.Body == nil {
    90  		return nil
    91  	}
    92  
    93  	body, getBody, err := NewRewind(r.Body)
    94  	if err != nil {
    95  		return err
    96  	}
    97  
    98  	r.Body = body
    99  	r.GetBody = getBody
   100  	return nil
   101  }
   102  
   103  // Rewind prepares a request body to be replayed.  If a GetBody function is present,
   104  // that function is invoked.  An error is returned if this function could not rewind the request.
   105  func Rewind(r *http.Request) error {
   106  	if r.GetBody != nil {
   107  		b, err := r.GetBody()
   108  		if err != nil {
   109  			return err
   110  		}
   111  
   112  		r.Body = b
   113  		return nil
   114  	}
   115  
   116  	if r.Body == nil {
   117  		// this request has no body, so it is always "rewound"
   118  		return nil
   119  	}
   120  
   121  	return errNotRewindable
   122  }