github.com/xmidt-org/webpa-common@v1.11.9/xhttp/rewind_test.go (about) 1 package xhttp 2 3 import ( 4 "bytes" 5 "errors" 6 "io" 7 "io/ioutil" 8 "net/http" 9 "net/http/httptest" 10 "strings" 11 "testing" 12 13 "github.com/stretchr/testify/assert" 14 "github.com/stretchr/testify/mock" 15 "github.com/stretchr/testify/require" 16 ) 17 18 func TestNopCloser(t *testing.T) { 19 var ( 20 assert = assert.New(t) 21 require = require.New(t) 22 expectedBytes = []byte{9, 12, 74, 125, 22} 23 24 reader = bytes.NewReader(expectedBytes) 25 ) 26 27 rsc := NopCloser(reader) 28 require.NotNil(rsc) 29 actualBytes, err := ioutil.ReadAll(rsc) 30 assert.Equal(expectedBytes, actualBytes) 31 assert.NoError(err) 32 assert.NoError(rsc.Close()) 33 34 rsc2 := NopCloser(rsc) 35 require.NotNil(rsc2) 36 assert.True(rsc == rsc2) 37 assert.NoError(rsc2.Close()) 38 39 _, err = reader.Seek(0, 0) 40 assert.NoError(err) 41 42 actualBytes, err = ioutil.ReadAll(rsc2) 43 assert.Equal(expectedBytes, actualBytes) 44 assert.NoError(err) 45 } 46 47 func testNewRewindReadSeeker(t *testing.T) { 48 var ( 49 assert = assert.New(t) 50 require = require.New(t) 51 expectedBytes = []byte{9, 234, 12, 93, 41} 52 53 reader = bytes.NewReader(expectedBytes) 54 ) 55 56 body, getBody, err := NewRewind(reader) 57 assert.NoError(err) 58 require.NotNil(body) 59 require.NotNil(getBody) 60 61 actualBytes, err := ioutil.ReadAll(body) 62 assert.Equal(expectedBytes, actualBytes) 63 assert.NoError(err) 64 65 body2, err := getBody() 66 assert.NoError(err) 67 require.NotNil(body2) 68 assert.True(body == body2) 69 70 actualBytes, err = ioutil.ReadAll(body2) 71 assert.Equal(expectedBytes, actualBytes) 72 assert.NoError(err) 73 } 74 75 func testNewRewindReadError(t *testing.T) { 76 var ( 77 assert = assert.New(t) 78 expectedError = errors.New("expected") 79 80 reader = new(mockReader) 81 ) 82 83 reader.On("Read", mock.MatchedBy(func([]byte) bool { return true })).Return(0, expectedError).Once() 84 body, getBody, err := NewRewind(reader) 85 assert.Nil(body) 86 assert.Nil(getBody) 87 assert.Error(err) 88 89 reader.AssertExpectations(t) 90 } 91 92 func testNewRewindBuffer(t *testing.T) { 93 var ( 94 assert = assert.New(t) 95 require = require.New(t) 96 expectedBytes = []byte{9, 234, 12, 93, 41} 97 98 buffer = bytes.NewBuffer(expectedBytes) 99 ) 100 101 body, getBody, err := NewRewind(buffer) 102 assert.NoError(err) 103 require.NotNil(body) 104 require.NotNil(getBody) 105 106 actualBytes, err := ioutil.ReadAll(body) 107 assert.Equal(expectedBytes, actualBytes) 108 assert.NoError(err) 109 110 body2, err := getBody() 111 assert.NoError(err) 112 require.NotNil(body2) 113 assert.True(body == body2) 114 115 actualBytes, err = ioutil.ReadAll(body2) 116 assert.Equal(expectedBytes, actualBytes) 117 assert.NoError(err) 118 } 119 120 func TestNewRewind(t *testing.T) { 121 t.Run("ReadSeeker", testNewRewindReadSeeker) 122 t.Run("ReadError", testNewRewindReadError) 123 t.Run("Buffer", testNewRewindBuffer) 124 } 125 126 func TestNewRewindBytes(t *testing.T) { 127 var ( 128 assert = assert.New(t) 129 require = require.New(t) 130 expectedBytes = []byte{7, 234, 12, 9, 100} 131 ) 132 133 body, getBody := NewRewindBytes(expectedBytes) 134 require.NotNil(body) 135 require.NotNil(getBody) 136 137 actualBytes, err := ioutil.ReadAll(body) 138 assert.Equal(expectedBytes, actualBytes) 139 assert.NoError(err) 140 141 body2, err := getBody() 142 assert.NoError(err) 143 require.NotNil(body2) 144 assert.True(body == body2) 145 146 actualBytes, err = ioutil.ReadAll(body2) 147 assert.Equal(expectedBytes, actualBytes) 148 assert.NoError(err) 149 } 150 151 func testEnsureRewindableNoBody(t *testing.T) { 152 var ( 153 assert = assert.New(t) 154 r = new(http.Request) 155 ) 156 157 assert.NoError(EnsureRewindable(r)) 158 assert.Nil(r.Body) 159 assert.Nil(r.GetBody) 160 } 161 162 func testEnsureRewindableGetBody(t *testing.T) { 163 var ( 164 assert = assert.New(t) 165 require = require.New(t) 166 167 getBodyCalled = false 168 getBody = func() (io.ReadCloser, error) { 169 getBodyCalled = true 170 return nil, nil 171 } 172 173 r = &http.Request{ 174 GetBody: getBody, 175 } 176 ) 177 178 assert.NoError(EnsureRewindable(r)) 179 assert.Nil(r.Body) 180 require.NotNil(r.GetBody) 181 r.GetBody() 182 assert.True(getBodyCalled) 183 } 184 185 func testEnsureRewindableBodyNotRewindable(t *testing.T) { 186 var ( 187 assert = assert.New(t) 188 require = require.New(t) 189 expectedContents = []byte{6, 253, 12, 34} 190 191 r = &http.Request{ 192 Body: ioutil.NopCloser(bytes.NewReader(expectedContents)), 193 } 194 ) 195 196 assert.NoError(EnsureRewindable(r)) 197 198 require.NotNil(r.Body) 199 actualContents, err := ioutil.ReadAll(r.Body) 200 assert.Equal(expectedContents, actualContents) 201 assert.NoError(err) 202 203 require.NotNil(r.GetBody) 204 actualBuffer, err := r.GetBody() 205 require.NoError(err) 206 require.NotNil(actualBuffer) 207 actualContents, err = ioutil.ReadAll(actualBuffer) 208 assert.Equal(expectedContents, actualContents) 209 assert.NoError(err) 210 } 211 212 func testEnsureRewindableReadError(t *testing.T) { 213 var ( 214 assert = assert.New(t) 215 contents = new(mockReader) 216 expectedBody = ioutil.NopCloser(contents) 217 expectedError = errors.New("expected") 218 219 r = &http.Request{ 220 Body: expectedBody, 221 } 222 ) 223 224 contents.On("Read", mock.MatchedBy(func([]byte) bool { return true })).Return(0, expectedError).Once() 225 assert.Equal(expectedError, EnsureRewindable(r)) 226 assert.Nil(r.GetBody) 227 assert.True(expectedBody == r.Body) 228 229 contents.AssertExpectations(t) 230 } 231 232 func TestEnsureRewindable(t *testing.T) { 233 t.Run("NoBody", testEnsureRewindableNoBody) 234 t.Run("GetBody", testEnsureRewindableGetBody) 235 t.Run("BodyNotRewindable", testEnsureRewindableBodyNotRewindable) 236 t.Run("ReadError", testEnsureRewindableReadError) 237 } 238 239 func testRewindGetBodyError(t *testing.T) { 240 var ( 241 assert = assert.New(t) 242 expectedError = errors.New("expected") 243 244 getBody = func() (io.ReadCloser, error) { 245 return nil, expectedError 246 } 247 248 r = &http.Request{ 249 GetBody: getBody, 250 } 251 ) 252 253 assert.Equal(expectedError, Rewind(r)) 254 } 255 256 func testRewindGetBodySuccess(t *testing.T) { 257 var ( 258 assert = assert.New(t) 259 require = require.New(t) 260 expectedBytes = []byte{1, 7, 8, 5, 1, 16, 177} 261 262 getBody = func() (io.ReadCloser, error) { 263 return ioutil.NopCloser(bytes.NewReader(expectedBytes)), nil 264 } 265 266 r = &http.Request{ 267 GetBody: getBody, 268 } 269 ) 270 271 assert.NoError(Rewind(r)) 272 require.NotNil(r.Body) 273 274 actualBytes, err := ioutil.ReadAll(r.Body) 275 assert.Equal(expectedBytes, actualBytes) 276 assert.NoError(err) 277 } 278 279 func testRewindNoBody(t *testing.T) { 280 var ( 281 assert = assert.New(t) 282 r = new(http.Request) 283 ) 284 285 assert.NoError(Rewind(r)) 286 assert.Nil(r.Body) 287 assert.Nil(r.GetBody) 288 } 289 290 func testRewindCantRewind(t *testing.T) { 291 var ( 292 assert = assert.New(t) 293 r = httptest.NewRequest("POST", "/", strings.NewReader("hi there")) 294 ) 295 296 assert.Error(Rewind(r)) 297 assert.NotNil(r.Body) 298 assert.Nil(r.GetBody) 299 } 300 301 func TestRewind(t *testing.T) { 302 t.Run("GetBody", func(t *testing.T) { 303 t.Run("Error", testRewindGetBodyError) 304 t.Run("Success", testRewindGetBodySuccess) 305 }) 306 307 t.Run("NoBody", testRewindNoBody) 308 t.Run("CantRewind", testRewindCantRewind) 309 }