github.com/xmidt-org/webpa-common@v1.11.9/xhttp/rewind_test.go (about)

     1  package xhttp
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"strings"
    11  	"testing"
    12  
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/mock"
    15  	"github.com/stretchr/testify/require"
    16  )
    17  
    18  func TestNopCloser(t *testing.T) {
    19  	var (
    20  		assert        = assert.New(t)
    21  		require       = require.New(t)
    22  		expectedBytes = []byte{9, 12, 74, 125, 22}
    23  
    24  		reader = bytes.NewReader(expectedBytes)
    25  	)
    26  
    27  	rsc := NopCloser(reader)
    28  	require.NotNil(rsc)
    29  	actualBytes, err := ioutil.ReadAll(rsc)
    30  	assert.Equal(expectedBytes, actualBytes)
    31  	assert.NoError(err)
    32  	assert.NoError(rsc.Close())
    33  
    34  	rsc2 := NopCloser(rsc)
    35  	require.NotNil(rsc2)
    36  	assert.True(rsc == rsc2)
    37  	assert.NoError(rsc2.Close())
    38  
    39  	_, err = reader.Seek(0, 0)
    40  	assert.NoError(err)
    41  
    42  	actualBytes, err = ioutil.ReadAll(rsc2)
    43  	assert.Equal(expectedBytes, actualBytes)
    44  	assert.NoError(err)
    45  }
    46  
    47  func testNewRewindReadSeeker(t *testing.T) {
    48  	var (
    49  		assert        = assert.New(t)
    50  		require       = require.New(t)
    51  		expectedBytes = []byte{9, 234, 12, 93, 41}
    52  
    53  		reader = bytes.NewReader(expectedBytes)
    54  	)
    55  
    56  	body, getBody, err := NewRewind(reader)
    57  	assert.NoError(err)
    58  	require.NotNil(body)
    59  	require.NotNil(getBody)
    60  
    61  	actualBytes, err := ioutil.ReadAll(body)
    62  	assert.Equal(expectedBytes, actualBytes)
    63  	assert.NoError(err)
    64  
    65  	body2, err := getBody()
    66  	assert.NoError(err)
    67  	require.NotNil(body2)
    68  	assert.True(body == body2)
    69  
    70  	actualBytes, err = ioutil.ReadAll(body2)
    71  	assert.Equal(expectedBytes, actualBytes)
    72  	assert.NoError(err)
    73  }
    74  
    75  func testNewRewindReadError(t *testing.T) {
    76  	var (
    77  		assert        = assert.New(t)
    78  		expectedError = errors.New("expected")
    79  
    80  		reader = new(mockReader)
    81  	)
    82  
    83  	reader.On("Read", mock.MatchedBy(func([]byte) bool { return true })).Return(0, expectedError).Once()
    84  	body, getBody, err := NewRewind(reader)
    85  	assert.Nil(body)
    86  	assert.Nil(getBody)
    87  	assert.Error(err)
    88  
    89  	reader.AssertExpectations(t)
    90  }
    91  
    92  func testNewRewindBuffer(t *testing.T) {
    93  	var (
    94  		assert        = assert.New(t)
    95  		require       = require.New(t)
    96  		expectedBytes = []byte{9, 234, 12, 93, 41}
    97  
    98  		buffer = bytes.NewBuffer(expectedBytes)
    99  	)
   100  
   101  	body, getBody, err := NewRewind(buffer)
   102  	assert.NoError(err)
   103  	require.NotNil(body)
   104  	require.NotNil(getBody)
   105  
   106  	actualBytes, err := ioutil.ReadAll(body)
   107  	assert.Equal(expectedBytes, actualBytes)
   108  	assert.NoError(err)
   109  
   110  	body2, err := getBody()
   111  	assert.NoError(err)
   112  	require.NotNil(body2)
   113  	assert.True(body == body2)
   114  
   115  	actualBytes, err = ioutil.ReadAll(body2)
   116  	assert.Equal(expectedBytes, actualBytes)
   117  	assert.NoError(err)
   118  }
   119  
   120  func TestNewRewind(t *testing.T) {
   121  	t.Run("ReadSeeker", testNewRewindReadSeeker)
   122  	t.Run("ReadError", testNewRewindReadError)
   123  	t.Run("Buffer", testNewRewindBuffer)
   124  }
   125  
   126  func TestNewRewindBytes(t *testing.T) {
   127  	var (
   128  		assert        = assert.New(t)
   129  		require       = require.New(t)
   130  		expectedBytes = []byte{7, 234, 12, 9, 100}
   131  	)
   132  
   133  	body, getBody := NewRewindBytes(expectedBytes)
   134  	require.NotNil(body)
   135  	require.NotNil(getBody)
   136  
   137  	actualBytes, err := ioutil.ReadAll(body)
   138  	assert.Equal(expectedBytes, actualBytes)
   139  	assert.NoError(err)
   140  
   141  	body2, err := getBody()
   142  	assert.NoError(err)
   143  	require.NotNil(body2)
   144  	assert.True(body == body2)
   145  
   146  	actualBytes, err = ioutil.ReadAll(body2)
   147  	assert.Equal(expectedBytes, actualBytes)
   148  	assert.NoError(err)
   149  }
   150  
   151  func testEnsureRewindableNoBody(t *testing.T) {
   152  	var (
   153  		assert = assert.New(t)
   154  		r      = new(http.Request)
   155  	)
   156  
   157  	assert.NoError(EnsureRewindable(r))
   158  	assert.Nil(r.Body)
   159  	assert.Nil(r.GetBody)
   160  }
   161  
   162  func testEnsureRewindableGetBody(t *testing.T) {
   163  	var (
   164  		assert  = assert.New(t)
   165  		require = require.New(t)
   166  
   167  		getBodyCalled = false
   168  		getBody       = func() (io.ReadCloser, error) {
   169  			getBodyCalled = true
   170  			return nil, nil
   171  		}
   172  
   173  		r = &http.Request{
   174  			GetBody: getBody,
   175  		}
   176  	)
   177  
   178  	assert.NoError(EnsureRewindable(r))
   179  	assert.Nil(r.Body)
   180  	require.NotNil(r.GetBody)
   181  	r.GetBody()
   182  	assert.True(getBodyCalled)
   183  }
   184  
   185  func testEnsureRewindableBodyNotRewindable(t *testing.T) {
   186  	var (
   187  		assert           = assert.New(t)
   188  		require          = require.New(t)
   189  		expectedContents = []byte{6, 253, 12, 34}
   190  
   191  		r = &http.Request{
   192  			Body: ioutil.NopCloser(bytes.NewReader(expectedContents)),
   193  		}
   194  	)
   195  
   196  	assert.NoError(EnsureRewindable(r))
   197  
   198  	require.NotNil(r.Body)
   199  	actualContents, err := ioutil.ReadAll(r.Body)
   200  	assert.Equal(expectedContents, actualContents)
   201  	assert.NoError(err)
   202  
   203  	require.NotNil(r.GetBody)
   204  	actualBuffer, err := r.GetBody()
   205  	require.NoError(err)
   206  	require.NotNil(actualBuffer)
   207  	actualContents, err = ioutil.ReadAll(actualBuffer)
   208  	assert.Equal(expectedContents, actualContents)
   209  	assert.NoError(err)
   210  }
   211  
   212  func testEnsureRewindableReadError(t *testing.T) {
   213  	var (
   214  		assert        = assert.New(t)
   215  		contents      = new(mockReader)
   216  		expectedBody  = ioutil.NopCloser(contents)
   217  		expectedError = errors.New("expected")
   218  
   219  		r = &http.Request{
   220  			Body: expectedBody,
   221  		}
   222  	)
   223  
   224  	contents.On("Read", mock.MatchedBy(func([]byte) bool { return true })).Return(0, expectedError).Once()
   225  	assert.Equal(expectedError, EnsureRewindable(r))
   226  	assert.Nil(r.GetBody)
   227  	assert.True(expectedBody == r.Body)
   228  
   229  	contents.AssertExpectations(t)
   230  }
   231  
   232  func TestEnsureRewindable(t *testing.T) {
   233  	t.Run("NoBody", testEnsureRewindableNoBody)
   234  	t.Run("GetBody", testEnsureRewindableGetBody)
   235  	t.Run("BodyNotRewindable", testEnsureRewindableBodyNotRewindable)
   236  	t.Run("ReadError", testEnsureRewindableReadError)
   237  }
   238  
   239  func testRewindGetBodyError(t *testing.T) {
   240  	var (
   241  		assert        = assert.New(t)
   242  		expectedError = errors.New("expected")
   243  
   244  		getBody = func() (io.ReadCloser, error) {
   245  			return nil, expectedError
   246  		}
   247  
   248  		r = &http.Request{
   249  			GetBody: getBody,
   250  		}
   251  	)
   252  
   253  	assert.Equal(expectedError, Rewind(r))
   254  }
   255  
   256  func testRewindGetBodySuccess(t *testing.T) {
   257  	var (
   258  		assert        = assert.New(t)
   259  		require       = require.New(t)
   260  		expectedBytes = []byte{1, 7, 8, 5, 1, 16, 177}
   261  
   262  		getBody = func() (io.ReadCloser, error) {
   263  			return ioutil.NopCloser(bytes.NewReader(expectedBytes)), nil
   264  		}
   265  
   266  		r = &http.Request{
   267  			GetBody: getBody,
   268  		}
   269  	)
   270  
   271  	assert.NoError(Rewind(r))
   272  	require.NotNil(r.Body)
   273  
   274  	actualBytes, err := ioutil.ReadAll(r.Body)
   275  	assert.Equal(expectedBytes, actualBytes)
   276  	assert.NoError(err)
   277  }
   278  
   279  func testRewindNoBody(t *testing.T) {
   280  	var (
   281  		assert = assert.New(t)
   282  		r      = new(http.Request)
   283  	)
   284  
   285  	assert.NoError(Rewind(r))
   286  	assert.Nil(r.Body)
   287  	assert.Nil(r.GetBody)
   288  }
   289  
   290  func testRewindCantRewind(t *testing.T) {
   291  	var (
   292  		assert = assert.New(t)
   293  		r      = httptest.NewRequest("POST", "/", strings.NewReader("hi there"))
   294  	)
   295  
   296  	assert.Error(Rewind(r))
   297  	assert.NotNil(r.Body)
   298  	assert.Nil(r.GetBody)
   299  }
   300  
   301  func TestRewind(t *testing.T) {
   302  	t.Run("GetBody", func(t *testing.T) {
   303  		t.Run("Error", testRewindGetBodyError)
   304  		t.Run("Success", testRewindGetBodySuccess)
   305  	})
   306  
   307  	t.Run("NoBody", testRewindNoBody)
   308  	t.Run("CantRewind", testRewindCantRewind)
   309  }