github.com/xmidt-org/webpa-common@v1.11.9/xhttp/xcontext/contextaware_test.go (about)

     1  package xcontext
     2  
     3  import (
     4  	"context"
     5  	"github.com/justinas/alice"
     6  	"github.com/stretchr/testify/assert"
     7  	"github.com/stretchr/testify/require"
     8  	"github.com/xmidt-org/webpa-common/logging"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"testing"
    12  )
    13  
    14  func TestSetContext(t *testing.T) {
    15  	assert := assert.New(t)
    16  
    17  	server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
    18  		writer, request = WithContext(writer, request, request.Context())
    19  
    20  		assert.Panics(func() {
    21  			SetContext(writer, nil)
    22  		})
    23  		writer = SetContext(writer, context.WithValue(writer.(ContextAware).Context(), "key", "value"))
    24  		assert.Equal("value", writer.(ContextAware).Context().Value("key"))
    25  		writer.WriteHeader(200)
    26  		writer.Write([]byte("Hello World"))
    27  
    28  	}))
    29  	defer server.Close()
    30  
    31  	r, err := http.NewRequest("GET", server.URL, nil)
    32  	assert.NoError(err)
    33  	r = r.WithContext(logging.WithLogger(r.Context(), logging.New(nil)))
    34  	response, err := (&http.Client{}).Do(r)
    35  	assert.NoError(err)
    36  	assert.NotNil(response)
    37  }
    38  
    39  func TestSingleHandler(t *testing.T) {
    40  	assert := assert.New(t)
    41  	require := require.New(t)
    42  
    43  	server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
    44  		writer, request = WithContext(writer, request, request.Context())
    45  		require.NotNil(writer)
    46  		require.NotNil(request)
    47  
    48  		writer.WriteHeader(200)
    49  		writer.Write([]byte("Hello World"))
    50  
    51  	}))
    52  	defer server.Close()
    53  
    54  	r, err := http.NewRequest("GET", server.URL, nil)
    55  	assert.NoError(err)
    56  	r = r.WithContext(logging.WithLogger(r.Context(), logging.New(nil)))
    57  	response, err := (&http.Client{}).Do(r)
    58  	assert.NoError(err)
    59  	assert.NotNil(response)
    60  }
    61  
    62  func TestChain(t *testing.T) {
    63  	assert := assert.New(t)
    64  	require := require.New(t)
    65  
    66  	body := "Hello World"
    67  	bodyKey := "body"
    68  
    69  	handler := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
    70  		writer, request = WithContext(writer, request, request.Context())
    71  		require.NotNil(writer)
    72  		require.NotNil(request)
    73  
    74  		writer.WriteHeader(200)
    75  		writer.Write([]byte("Hello World"))
    76  
    77  		if writer, ok := writer.(ContextAware); ok {
    78  			writer.SetContext(context.WithValue(writer.Context(), bodyKey, body))
    79  		} else {
    80  			assert.Fail("Writer must be ContextAware")
    81  		}
    82  	})
    83  
    84  	chain := alice.New(func(next http.Handler) http.Handler {
    85  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    86  			ctx := Context(w, r)
    87  			w, r = WithContext(w, r, ctx)
    88  			next.ServeHTTP(w, r)
    89  			if writer, ok := w.(ContextAware); ok {
    90  				assert.Equal(body, writer.Context().Value(bodyKey))
    91  			} else {
    92  				assert.Fail("Writer must be ContextAware")
    93  			}
    94  		})
    95  	})
    96  
    97  	server := httptest.NewServer(chain.Then(handler))
    98  	defer server.Close()
    99  
   100  	r, err := http.NewRequest("GET", server.URL, nil)
   101  	assert.NoError(err)
   102  	r = r.WithContext(logging.WithLogger(r.Context(), logging.New(nil)))
   103  	response, err := (&http.Client{}).Do(r)
   104  	assert.NoError(err)
   105  	assert.NotNil(response)
   106  }