github.com/grafana/pyroscope@v1.18.0/pkg/api/register_options_test.go (about)

     1  package api
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"testing"
     8  
     9  	"github.com/go-kit/log"
    10  	"github.com/gorilla/mux"
    11  	"github.com/grafana/dskit/middleware"
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  type contextKey uint8
    17  
    18  const (
    19  	contextKeyTest contextKey = iota
    20  )
    21  
    22  func newTestMiddleware(name string) middleware.Interface {
    23  	return middleware.Func(func(next http.Handler) http.Handler {
    24  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    25  			ctx := r.Context()
    26  			middlewares, ok := ctx.Value(contextKeyTest).([]string)
    27  			if !ok {
    28  				middlewares = []string{}
    29  			}
    30  			middlewares = append(middlewares, name)
    31  			ctx = context.WithValue(ctx, contextKeyTest, middlewares)
    32  			next.ServeHTTP(w, r.WithContext(ctx))
    33  		})
    34  	})
    35  
    36  }
    37  
    38  func Test_registerRoute(t *testing.T) {
    39  	router := mux.NewRouter()
    40  	registerRoute(
    41  		log.NewNopLogger(),
    42  		router,
    43  		"/test",
    44  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    45  			middlewares := r.Context().Value(contextKeyTest).([]string)
    46  			assert.Equal(t, []string{"outer", "middle", "inner"}, middlewares)
    47  
    48  			w.WriteHeader(http.StatusOK)
    49  		}),
    50  		func(r *registerParams) {
    51  			r.middlewares = append(r.middlewares, registerMiddleware{newTestMiddleware("outer"), "outer"})
    52  		},
    53  		func(r *registerParams) {
    54  			r.middlewares = append(r.middlewares, registerMiddleware{newTestMiddleware("middle"), "middle"})
    55  		},
    56  		func(r *registerParams) {
    57  			r.middlewares = append(r.middlewares, registerMiddleware{newTestMiddleware("inner"), "inner"})
    58  		},
    59  	)
    60  
    61  	testServer := httptest.NewServer(router)
    62  	defer testServer.Close()
    63  
    64  	req, err := http.NewRequest("GET", testServer.URL+"/test", nil)
    65  	require.NoError(t, err)
    66  
    67  	resp, err := testServer.Client().Do(req)
    68  	require.NoError(t, err)
    69  	assert.Equal(t, http.StatusOK, resp.StatusCode)
    70  }