github.com/alwitt/goutils@v0.6.4/rest_test.go (about)

     1  package goutils_test
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"math/rand"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/alwitt/goutils"
    16  	"github.com/apex/log"
    17  	"github.com/google/uuid"
    18  	"github.com/gorilla/mux"
    19  	"github.com/stretchr/testify/assert"
    20  )
    21  
    22  func TestRestAPIHandlerRequestIDInjection(t *testing.T) {
    23  	assert := assert.New(t)
    24  	log.SetLevel(log.DebugLevel)
    25  
    26  	// Case 0: no user request ID header defined
    27  	uutNoUserRequestIDHeader := goutils.RestAPIHandler{
    28  		Component: goutils.Component{
    29  			LogTags: log.Fields{"entity": "unit-tester"},
    30  			LogTagModifiers: []goutils.LogMetadataModifier{
    31  				goutils.ModifyLogMetadataByRestRequestParam,
    32  			},
    33  		},
    34  		CallRequestIDHeaderField: nil,
    35  		LogLevel:                 goutils.HTTPLogLevelDEBUG,
    36  	}
    37  	{
    38  		rid := uuid.New().String()
    39  		req, err := http.NewRequest("GET", "/testing", nil)
    40  		assert.Nil(err)
    41  		req.Header.Add("Request-ID", rid)
    42  
    43  		dummyHandler := func(w http.ResponseWriter, r *http.Request) {
    44  			callContext := r.Context()
    45  			assert.NotNil(callContext.Value(goutils.RestRequestParamKey{}))
    46  			v, ok := callContext.Value(goutils.RestRequestParamKey{}).(goutils.RestRequestParam)
    47  			assert.True(ok)
    48  			assert.NotEqual(rid, v.ID)
    49  			assert.Equal("GET", v.Method)
    50  			assert.Equal("/testing", v.URI)
    51  		}
    52  
    53  		router := mux.NewRouter()
    54  		respRecorder := httptest.NewRecorder()
    55  		router.HandleFunc("/testing", uutNoUserRequestIDHeader.LoggingMiddleware(dummyHandler))
    56  		router.ServeHTTP(respRecorder, req)
    57  
    58  		assert.Equal(http.StatusOK, respRecorder.Code)
    59  		assert.Equal("", (respRecorder.Header().Get("Request-ID")))
    60  	}
    61  
    62  	// Case 1: user request ID header defined
    63  	testReqIDHeader := uuid.New().String()
    64  	uutWithUserRequestIDHeader := goutils.RestAPIHandler{
    65  		Component: goutils.Component{
    66  			LogTags: log.Fields{"entity": "unit-tester"},
    67  			LogTagModifiers: []goutils.LogMetadataModifier{
    68  				goutils.ModifyLogMetadataByRestRequestParam,
    69  			},
    70  		},
    71  		CallRequestIDHeaderField: &testReqIDHeader,
    72  		LogLevel:                 goutils.HTTPLogLevelINFO,
    73  	}
    74  	{
    75  		rid := uuid.New().String()
    76  		req, err := http.NewRequest("DELETE", "/testing2", nil)
    77  		assert.Nil(err)
    78  		req.Header.Add(testReqIDHeader, rid)
    79  
    80  		dummyHandler := func(w http.ResponseWriter, r *http.Request) {
    81  			callContext := r.Context()
    82  			assert.NotNil(callContext.Value(goutils.RestRequestParamKey{}))
    83  			v, ok := callContext.Value(goutils.RestRequestParamKey{}).(goutils.RestRequestParam)
    84  			assert.True(ok)
    85  			assert.Equal(rid, v.ID)
    86  			assert.Equal("DELETE", v.Method)
    87  			assert.Equal("/testing2", v.URI)
    88  		}
    89  
    90  		router := mux.NewRouter()
    91  		respRecorder := httptest.NewRecorder()
    92  		router.HandleFunc("/testing2", uutWithUserRequestIDHeader.LoggingMiddleware(dummyHandler))
    93  		router.ServeHTTP(respRecorder, req)
    94  
    95  		assert.Equal(http.StatusOK, respRecorder.Code)
    96  		assert.Equal(rid, (respRecorder.Header().Get(testReqIDHeader)))
    97  	}
    98  }
    99  
   100  func TestRestAPIHandlerRequestLogging(t *testing.T) {
   101  	assert := assert.New(t)
   102  	log.SetLevel(log.DebugLevel)
   103  
   104  	uut := goutils.RestAPIHandler{
   105  		Component: goutils.Component{
   106  			LogTags: log.Fields{"entity": "unit-tester"},
   107  			LogTagModifiers: []goutils.LogMetadataModifier{
   108  				goutils.ModifyLogMetadataByRestRequestParam,
   109  			},
   110  		},
   111  		DoNotLogHeaders: map[string]bool{"Not-Allowed": true},
   112  		LogLevel:        goutils.HTTPLogLevelDEBUG,
   113  	}
   114  	{
   115  		value1 := uuid.New().String()
   116  		value2 := uuid.New().String()
   117  		req, err := http.NewRequest("GET", "/testing", nil)
   118  		assert.Nil(err)
   119  		req.Header.Add("Allowed", value1)
   120  		req.Header.Add("Not-Allowed", value2)
   121  
   122  		dummyHandler := func(w http.ResponseWriter, r *http.Request) {
   123  			callContext := r.Context()
   124  			assert.NotNil(callContext.Value(goutils.RestRequestParamKey{}))
   125  			v, ok := callContext.Value(goutils.RestRequestParamKey{}).(goutils.RestRequestParam)
   126  			assert.True(ok)
   127  			assert.Equal("GET", v.Method)
   128  			assert.Equal("/testing", v.URI)
   129  			assert.Equal(value1, v.RequestHeaders.Get("Allowed"))
   130  			assert.Equal("", v.RequestHeaders.Get("Not-Allowed"))
   131  		}
   132  
   133  		router := mux.NewRouter()
   134  		respRecorder := httptest.NewRecorder()
   135  		router.HandleFunc("/testing", uut.LoggingMiddleware(dummyHandler))
   136  		router.ServeHTTP(respRecorder, req)
   137  
   138  		assert.Equal(http.StatusOK, respRecorder.Code)
   139  	}
   140  }
   141  
   142  func TestRestAPIHandlerProcessStreamingEndpoints(t *testing.T) {
   143  	assert := assert.New(t)
   144  	log.SetLevel(log.DebugLevel)
   145  
   146  	testReqIDHeader := uuid.New().String()
   147  	uut := goutils.RestAPIHandler{
   148  		Component: goutils.Component{
   149  			LogTags: log.Fields{"entity": "unit-tester"},
   150  			LogTagModifiers: []goutils.LogMetadataModifier{
   151  				goutils.ModifyLogMetadataByRestRequestParam,
   152  			},
   153  		},
   154  		CallRequestIDHeaderField: &testReqIDHeader,
   155  		LogLevel:                 goutils.HTTPLogLevelDEBUG,
   156  	}
   157  
   158  	type testMessage struct {
   159  		Timestamp time.Time
   160  		Msg       string
   161  	}
   162  
   163  	testMsgTX := make(chan testMessage, 1)
   164  	testMsgRX := make(chan testMessage, 1)
   165  
   166  	wg := sync.WaitGroup{}
   167  	defer wg.Wait()
   168  	utCtxt, ctxtCancel := context.WithCancel(context.Background())
   169  	defer ctxtCancel()
   170  
   171  	// Define streaming data handler
   172  	testHandler := func(w http.ResponseWriter, r *http.Request) {
   173  		flusher, ok := w.(http.Flusher)
   174  		assert.True(ok)
   175  		w.Header().Set("Content-Type", "text/event-stream")
   176  		w.Header().Set("Cache-Control", "no-cache")
   177  		w.Header().Set("Connection", "keep-alive")
   178  		w.Header().Set("Access-Control-Allow-Origin", "*")
   179  
   180  		log.Debug("Starting stream response handler")
   181  		complete := false
   182  		for !complete {
   183  			select {
   184  			case <-utCtxt.Done():
   185  				complete = true
   186  			case msg, ok := <-testMsgTX:
   187  				assert.True(ok)
   188  				t, err := json.Marshal(&msg)
   189  				assert.Nil(err)
   190  				fmt.Fprintf(w, "%s\n", t)
   191  				flusher.Flush()
   192  				log.Debugf("Sent %s\n", t)
   193  			}
   194  		}
   195  		log.Debug("Stoping stream response handler")
   196  	}
   197  
   198  	router := mux.NewRouter()
   199  	router.HandleFunc("/testing", uut.LoggingMiddleware(testHandler))
   200  
   201  	// Define HTTP server
   202  	testServerPort := rand.Intn(30000) + 32769
   203  	testServerListen := fmt.Sprintf("127.0.0.1:%d", testServerPort)
   204  	testServer := &http.Server{
   205  		Addr:    testServerListen,
   206  		Handler: router,
   207  	}
   208  	// Start the HTTP server
   209  	log.Debugf("Starting test server on %s", testServerListen)
   210  	wg.Add(1)
   211  	go func() {
   212  		defer wg.Done()
   213  		if err := testServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
   214  			assert.Nil(err)
   215  		}
   216  		log.Debugf("Stopped test server on %s", testServerListen)
   217  	}()
   218  	defer func() {
   219  		// Helper function to shutdown the server
   220  		ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
   221  		defer cancel()
   222  		if err := testServer.Shutdown(ctx); err != nil {
   223  			assert.Nil(err)
   224  		}
   225  	}()
   226  
   227  	// Define test HTTP client
   228  	testClient := http.Client{}
   229  	req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/testing", testServerListen), nil)
   230  	assert.Nil(err)
   231  	testRID := uuid.New().String()
   232  	req.Header.Add(testReqIDHeader, testRID)
   233  
   234  	// Make the request in another thread
   235  	wg.Add(1)
   236  	go func() {
   237  		defer wg.Done()
   238  		var resp *http.Response
   239  		var err error
   240  		for i := 0; i < 3; i++ {
   241  			log.Debug("Connecting to test server")
   242  			resp, err = testClient.Do(req)
   243  			if err == nil {
   244  				break
   245  			}
   246  			time.Sleep(time.Millisecond * 25)
   247  		}
   248  		log.Debugf("Connected to test server http://%s/testing", testServerListen)
   249  		assert.Nil(err)
   250  		assert.Equal(http.StatusOK, resp.StatusCode)
   251  		assert.Equal(testRID, resp.Header.Get(testReqIDHeader))
   252  		// Process the resp stream
   253  		scanner := bufio.NewScanner(resp.Body)
   254  		scanner.Split(bufio.ScanLines)
   255  		log.Debug("Scanning SSE stream")
   256  		for scanner.Scan() {
   257  			received := scanner.Text()
   258  			log.Debugf("Received: %s", received)
   259  			var parsed testMessage
   260  			assert.Nil(json.Unmarshal([]byte(received), &parsed))
   261  			testMsgRX <- parsed
   262  		}
   263  		log.Debug("Stopped scanner")
   264  	}()
   265  
   266  	// Send message multiple times
   267  	for i := 0; i < 4; i++ {
   268  		newMsg := testMessage{Timestamp: time.Now(), Msg: uuid.New().String()}
   269  		testMsgTX <- newMsg
   270  		ctxt, lclCancel := context.WithTimeout(utCtxt, time.Millisecond*100)
   271  		defer lclCancel()
   272  		select {
   273  		case <-ctxt.Done():
   274  			assert.Nil(ctxt.Err())
   275  		case rx, ok := <-testMsgRX:
   276  			assert.True(ok)
   277  			assert.Equal(newMsg.Timestamp.UnixMicro(), rx.Timestamp.UnixMicro())
   278  			assert.Equal(newMsg.Msg, rx.Msg)
   279  		}
   280  	}
   281  
   282  	// Allow for clean shutdown
   283  	ctxtCancel()
   284  }