github.com/xmidt-org/webpa-common@v1.11.9/xhttp/context_test.go (about)

     1  package xhttp
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"testing"
     9  
    10  	gokithttp "github.com/go-kit/kit/transport/http"
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  func testGetErrorEncoderDefault(t *testing.T) {
    16  	assert := assert.New(t)
    17  	assert.NotNil(GetErrorEncoder(context.Background()))
    18  }
    19  
    20  func testGetErrorEncoderCustom(t *testing.T) {
    21  	var (
    22  		assert  = assert.New(t)
    23  		require = require.New(t)
    24  
    25  		expectedCalled                        = false
    26  		expected       gokithttp.ErrorEncoder = func(_ context.Context, _ error, _ http.ResponseWriter) {
    27  			expectedCalled = true
    28  		}
    29  
    30  		actual = GetErrorEncoder(
    31  			context.WithValue(context.Background(), errorEncoderKey{}, expected),
    32  		)
    33  	)
    34  
    35  	require.NotNil(actual)
    36  	actual(context.Background(), errors.New("expected"), httptest.NewRecorder())
    37  	assert.True(expectedCalled)
    38  }
    39  
    40  func TestGetErrorEncoder(t *testing.T) {
    41  	t.Run("Default", testGetErrorEncoderDefault)
    42  	t.Run("Custom", testGetErrorEncoderCustom)
    43  }
    44  
    45  func testWithErrorEncoderDefault(t *testing.T) {
    46  	var (
    47  		assert = assert.New(t)
    48  		ctx    = WithErrorEncoder(context.Background(), nil)
    49  	)
    50  
    51  	assert.Equal(context.Background(), ctx)
    52  }
    53  
    54  func testWithErrorEncoderCustom(t *testing.T) {
    55  	var (
    56  		assert  = assert.New(t)
    57  		require = require.New(t)
    58  
    59  		expectedCalled                        = false
    60  		expected       gokithttp.ErrorEncoder = func(_ context.Context, _ error, _ http.ResponseWriter) {
    61  			expectedCalled = true
    62  		}
    63  
    64  		ctx = WithErrorEncoder(context.Background(), expected)
    65  	)
    66  
    67  	require.NotNil(ctx)
    68  	actual, ok := ctx.Value(errorEncoderKey{}).(gokithttp.ErrorEncoder)
    69  	require.True(ok)
    70  	require.NotNil(actual)
    71  
    72  	actual(context.Background(), errors.New("expected"), httptest.NewRecorder())
    73  	assert.True(expectedCalled)
    74  }
    75  
    76  func TestWithErrorEncoder(t *testing.T) {
    77  	t.Run("Default", testWithErrorEncoderDefault)
    78  	t.Run("Custom", testWithErrorEncoderCustom)
    79  }
    80  
    81  func testGetClientDefault(t *testing.T) {
    82  	assert := assert.New(t)
    83  	assert.Equal(http.DefaultClient, GetClient(context.Background()))
    84  }
    85  
    86  func testGetClientCustom(t *testing.T) {
    87  	var (
    88  		assert = assert.New(t)
    89  
    90  		expected = new(http.Client)
    91  		actual   = GetClient(
    92  			context.WithValue(context.Background(), httpClientKey{}, expected),
    93  		)
    94  	)
    95  
    96  	assert.Equal(expected, actual)
    97  }
    98  
    99  func TestGetClient(t *testing.T) {
   100  	t.Run("Default", testGetClientDefault)
   101  	t.Run("Custom", testGetClientCustom)
   102  }
   103  
   104  func testWithClientDefault(t *testing.T) {
   105  	var (
   106  		assert = assert.New(t)
   107  		ctx    = WithClient(context.Background(), nil)
   108  	)
   109  
   110  	assert.Equal(context.Background(), ctx)
   111  }
   112  
   113  func testWithClientCustom(t *testing.T) {
   114  	var (
   115  		assert  = assert.New(t)
   116  		require = require.New(t)
   117  
   118  		expected = new(http.Client)
   119  		ctx      = WithClient(context.Background(), expected)
   120  	)
   121  
   122  	require.NotNil(ctx)
   123  	actual, ok := ctx.Value(httpClientKey{}).(Client)
   124  	require.True(ok)
   125  	require.NotNil(actual)
   126  
   127  	assert.Equal(expected, actual)
   128  }
   129  
   130  func TestWithClient(t *testing.T) {
   131  	t.Run("Default", testWithClientDefault)
   132  	t.Run("Custom", testWithClientCustom)
   133  }