github.com/grafana/pyroscope@v1.18.0/pkg/util/connectgrpc/connectgrpc_test.go (about) 1 package connectgrpc 2 3 import ( 4 "bytes" 5 "context" 6 "net/http" 7 "net/http/httptest" 8 "strings" 9 "testing" 10 11 "connectrpc.com/connect" 12 "github.com/gorilla/mux" 13 "github.com/samber/lo" 14 "github.com/stretchr/testify/require" 15 "golang.org/x/net/http2" 16 "golang.org/x/net/http2/h2c" 17 18 "github.com/grafana/pyroscope/api/gen/proto/go/querier/v1/querierv1connect" 19 typesv1 "github.com/grafana/pyroscope/api/gen/proto/go/types/v1" 20 "github.com/grafana/pyroscope/pkg/util/httpgrpc" 21 ) 22 23 type fakeQuerier struct { 24 querierv1connect.UnimplementedQuerierServiceHandler 25 req *connect.Request[typesv1.LabelValuesRequest] 26 resp *connect.Response[typesv1.LabelValuesResponse] 27 } 28 29 func (f *fakeQuerier) LabelValues(_ context.Context, req *connect.Request[typesv1.LabelValuesRequest]) (*connect.Response[typesv1.LabelValuesResponse], error) { 30 f.req = req 31 return f.resp, nil 32 } 33 34 type mockRoundTripper struct { 35 req *httpgrpc.HTTPRequest 36 resp *httpgrpc.HTTPResponse 37 } 38 39 func (m *mockRoundTripper) RoundTripGRPC(_ context.Context, req *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error) { 40 m.req = req 41 return m.resp, nil 42 } 43 44 func headerToSlice(t testing.TB, header http.Header) []string { 45 buf := new(bytes.Buffer) 46 excludeHeaders := map[string]bool{"Content-Length": true, "Date": true} 47 require.NoError(t, header.WriteSubset(buf, excludeHeaders)) 48 sl := strings.Split(strings.ReplaceAll(buf.String(), "\r\n", "\n"), "\n") 49 if len(sl) > 0 && sl[len(sl)-1] == "" { 50 sl = sl[:len(sl)-1] 51 } 52 return sl 53 } 54 55 func Test_RoundTripUnary(t *testing.T) { 56 request := func(t *testing.T) *connect.Request[typesv1.LabelValuesRequest] { 57 server := httptest.NewUnstartedServer(nil) 58 mux := mux.NewRouter() 59 server.Config.Handler = h2c.NewHandler(mux, &http2.Server{}) 60 61 server.Start() 62 defer server.Close() 63 f := &fakeQuerier{resp: &connect.Response[typesv1.LabelValuesResponse]{ 64 Msg: &typesv1.LabelValuesResponse{Names: []string{"foo", "bar"}}, 65 }} 66 querierv1connect.RegisterQuerierServiceHandler(mux, f) 67 68 client := querierv1connect.NewQuerierServiceClient(http.DefaultClient, server.URL) 69 req := &typesv1.LabelValuesRequest{ 70 Name: "foo", 71 } 72 _, err := client.LabelValues(context.Background(), connect.NewRequest(req)) 73 require.NoError(t, err) 74 return f.req 75 } 76 77 t.Run("HTTP request can trip GRPC", func(t *testing.T) { 78 req := request(t) 79 m := &mockRoundTripper{resp: &httpgrpc.HTTPResponse{ 80 Code: 200, 81 Headers: []*httpgrpc.Header{ 82 {Key: "Content-Type", Values: []string{"application/proto"}}, 83 {Key: "X-My-App", Values: []string{"foobar"}}, 84 }, 85 }} 86 87 resp, err := RoundTripUnary[typesv1.LabelValuesRequest, typesv1.LabelValuesResponse](context.Background(), m, req) 88 require.NoError(t, err) 89 require.Equal(t, "POST", m.req.Method) 90 require.Equal(t, "/querier.v1.QuerierService/LabelValues", m.req.Url) 91 actualHeaders := lo.Map(m.req.Headers, func(h *httpgrpc.Header, index int) string { 92 return h.Key + ": " + strings.Join(h.Values, ",") 93 }) 94 require.Contains(t, actualHeaders, "Content-Type: application/proto") 95 require.Contains(t, actualHeaders, "Connect-Protocol-Version: 1") 96 require.Contains(t, actualHeaders, "Accept-Encoding: gzip") 97 98 decoded, err := decodeRequest[typesv1.LabelValuesRequest](m.req) 99 require.NoError(t, err) 100 require.Equal(t, req.Msg.Name, decoded.Msg.Name) 101 102 // ensure no headers leak 103 require.Equal(t, []string{"X-My-App: foobar"}, headerToSlice(t, resp.Header())) 104 105 }) 106 107 t.Run("HTTP request URL can be overridden", func(t *testing.T) { 108 req := request(t) 109 m := &mockRoundTripper{resp: &httpgrpc.HTTPResponse{Code: 200}} 110 const url = "TestURL" 111 ctx := WithProcedure(context.Background(), url) 112 _, err := RoundTripUnary[typesv1.LabelValuesRequest, typesv1.LabelValuesResponse](ctx, m, req) 113 require.NoError(t, err) 114 require.Equal(t, url, m.req.Url) 115 }) 116 }