github.com/xmidt-org/webpa-common@v1.11.9/xhttp/context_test.go (about) 1 package xhttp 2 3 import ( 4 "context" 5 "errors" 6 "net/http" 7 "net/http/httptest" 8 "testing" 9 10 gokithttp "github.com/go-kit/kit/transport/http" 11 "github.com/stretchr/testify/assert" 12 "github.com/stretchr/testify/require" 13 ) 14 15 func testGetErrorEncoderDefault(t *testing.T) { 16 assert := assert.New(t) 17 assert.NotNil(GetErrorEncoder(context.Background())) 18 } 19 20 func testGetErrorEncoderCustom(t *testing.T) { 21 var ( 22 assert = assert.New(t) 23 require = require.New(t) 24 25 expectedCalled = false 26 expected gokithttp.ErrorEncoder = func(_ context.Context, _ error, _ http.ResponseWriter) { 27 expectedCalled = true 28 } 29 30 actual = GetErrorEncoder( 31 context.WithValue(context.Background(), errorEncoderKey{}, expected), 32 ) 33 ) 34 35 require.NotNil(actual) 36 actual(context.Background(), errors.New("expected"), httptest.NewRecorder()) 37 assert.True(expectedCalled) 38 } 39 40 func TestGetErrorEncoder(t *testing.T) { 41 t.Run("Default", testGetErrorEncoderDefault) 42 t.Run("Custom", testGetErrorEncoderCustom) 43 } 44 45 func testWithErrorEncoderDefault(t *testing.T) { 46 var ( 47 assert = assert.New(t) 48 ctx = WithErrorEncoder(context.Background(), nil) 49 ) 50 51 assert.Equal(context.Background(), ctx) 52 } 53 54 func testWithErrorEncoderCustom(t *testing.T) { 55 var ( 56 assert = assert.New(t) 57 require = require.New(t) 58 59 expectedCalled = false 60 expected gokithttp.ErrorEncoder = func(_ context.Context, _ error, _ http.ResponseWriter) { 61 expectedCalled = true 62 } 63 64 ctx = WithErrorEncoder(context.Background(), expected) 65 ) 66 67 require.NotNil(ctx) 68 actual, ok := ctx.Value(errorEncoderKey{}).(gokithttp.ErrorEncoder) 69 require.True(ok) 70 require.NotNil(actual) 71 72 actual(context.Background(), errors.New("expected"), httptest.NewRecorder()) 73 assert.True(expectedCalled) 74 } 75 76 func TestWithErrorEncoder(t *testing.T) { 77 t.Run("Default", testWithErrorEncoderDefault) 78 t.Run("Custom", testWithErrorEncoderCustom) 79 } 80 81 func testGetClientDefault(t *testing.T) { 82 assert := assert.New(t) 83 assert.Equal(http.DefaultClient, GetClient(context.Background())) 84 } 85 86 func testGetClientCustom(t *testing.T) { 87 var ( 88 assert = assert.New(t) 89 90 expected = new(http.Client) 91 actual = GetClient( 92 context.WithValue(context.Background(), httpClientKey{}, expected), 93 ) 94 ) 95 96 assert.Equal(expected, actual) 97 } 98 99 func TestGetClient(t *testing.T) { 100 t.Run("Default", testGetClientDefault) 101 t.Run("Custom", testGetClientCustom) 102 } 103 104 func testWithClientDefault(t *testing.T) { 105 var ( 106 assert = assert.New(t) 107 ctx = WithClient(context.Background(), nil) 108 ) 109 110 assert.Equal(context.Background(), ctx) 111 } 112 113 func testWithClientCustom(t *testing.T) { 114 var ( 115 assert = assert.New(t) 116 require = require.New(t) 117 118 expected = new(http.Client) 119 ctx = WithClient(context.Background(), expected) 120 ) 121 122 require.NotNil(ctx) 123 actual, ok := ctx.Value(httpClientKey{}).(Client) 124 require.True(ok) 125 require.NotNil(actual) 126 127 assert.Equal(expected, actual) 128 } 129 130 func TestWithClient(t *testing.T) { 131 t.Run("Default", testWithClientDefault) 132 t.Run("Custom", testWithClientCustom) 133 }