github.com/grafana/pyroscope@v1.18.0/pkg/util/connectgrpc/connectgrpc_test.go (about)

     1  package connectgrpc
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"strings"
     9  	"testing"
    10  
    11  	"connectrpc.com/connect"
    12  	"github.com/gorilla/mux"
    13  	"github.com/samber/lo"
    14  	"github.com/stretchr/testify/require"
    15  	"golang.org/x/net/http2"
    16  	"golang.org/x/net/http2/h2c"
    17  
    18  	"github.com/grafana/pyroscope/api/gen/proto/go/querier/v1/querierv1connect"
    19  	typesv1 "github.com/grafana/pyroscope/api/gen/proto/go/types/v1"
    20  	"github.com/grafana/pyroscope/pkg/util/httpgrpc"
    21  )
    22  
// fakeQuerier is a test double for the querier service handler. It embeds
// querierv1connect.UnimplementedQuerierServiceHandler and defines only
// LabelValues itself: that method stores the request it receives in req and
// replies with whatever was placed in resp.
type fakeQuerier struct {
	querierv1connect.UnimplementedQuerierServiceHandler
	req  *connect.Request[typesv1.LabelValuesRequest]
	resp *connect.Response[typesv1.LabelValuesResponse]
}
    28  
    29  func (f *fakeQuerier) LabelValues(_ context.Context, req *connect.Request[typesv1.LabelValuesRequest]) (*connect.Response[typesv1.LabelValuesResponse], error) {
    30  	f.req = req
    31  	return f.resp, nil
    32  }
    33  
// mockRoundTripper is a test double for the gRPC round tripper handed to
// RoundTripUnary. Its RoundTripGRPC method stores the request it is given in
// req and returns resp unchanged.
type mockRoundTripper struct {
	req  *httpgrpc.HTTPRequest
	resp *httpgrpc.HTTPResponse
}
    38  
    39  func (m *mockRoundTripper) RoundTripGRPC(_ context.Context, req *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error) {
    40  	m.req = req
    41  	return m.resp, nil
    42  }
    43  
    44  func headerToSlice(t testing.TB, header http.Header) []string {
    45  	buf := new(bytes.Buffer)
    46  	excludeHeaders := map[string]bool{"Content-Length": true, "Date": true}
    47  	require.NoError(t, header.WriteSubset(buf, excludeHeaders))
    48  	sl := strings.Split(strings.ReplaceAll(buf.String(), "\r\n", "\n"), "\n")
    49  	if len(sl) > 0 && sl[len(sl)-1] == "" {
    50  		sl = sl[:len(sl)-1]
    51  	}
    52  	return sl
    53  }
    54  
    55  func Test_RoundTripUnary(t *testing.T) {
    56  	request := func(t *testing.T) *connect.Request[typesv1.LabelValuesRequest] {
    57  		server := httptest.NewUnstartedServer(nil)
    58  		mux := mux.NewRouter()
    59  		server.Config.Handler = h2c.NewHandler(mux, &http2.Server{})
    60  
    61  		server.Start()
    62  		defer server.Close()
    63  		f := &fakeQuerier{resp: &connect.Response[typesv1.LabelValuesResponse]{
    64  			Msg: &typesv1.LabelValuesResponse{Names: []string{"foo", "bar"}},
    65  		}}
    66  		querierv1connect.RegisterQuerierServiceHandler(mux, f)
    67  
    68  		client := querierv1connect.NewQuerierServiceClient(http.DefaultClient, server.URL)
    69  		req := &typesv1.LabelValuesRequest{
    70  			Name: "foo",
    71  		}
    72  		_, err := client.LabelValues(context.Background(), connect.NewRequest(req))
    73  		require.NoError(t, err)
    74  		return f.req
    75  	}
    76  
    77  	t.Run("HTTP request can trip GRPC", func(t *testing.T) {
    78  		req := request(t)
    79  		m := &mockRoundTripper{resp: &httpgrpc.HTTPResponse{
    80  			Code: 200,
    81  			Headers: []*httpgrpc.Header{
    82  				{Key: "Content-Type", Values: []string{"application/proto"}},
    83  				{Key: "X-My-App", Values: []string{"foobar"}},
    84  			},
    85  		}}
    86  
    87  		resp, err := RoundTripUnary[typesv1.LabelValuesRequest, typesv1.LabelValuesResponse](context.Background(), m, req)
    88  		require.NoError(t, err)
    89  		require.Equal(t, "POST", m.req.Method)
    90  		require.Equal(t, "/querier.v1.QuerierService/LabelValues", m.req.Url)
    91  		actualHeaders := lo.Map(m.req.Headers, func(h *httpgrpc.Header, index int) string {
    92  			return h.Key + ": " + strings.Join(h.Values, ",")
    93  		})
    94  		require.Contains(t, actualHeaders, "Content-Type: application/proto")
    95  		require.Contains(t, actualHeaders, "Connect-Protocol-Version: 1")
    96  		require.Contains(t, actualHeaders, "Accept-Encoding: gzip")
    97  
    98  		decoded, err := decodeRequest[typesv1.LabelValuesRequest](m.req)
    99  		require.NoError(t, err)
   100  		require.Equal(t, req.Msg.Name, decoded.Msg.Name)
   101  
   102  		// ensure no headers leak
   103  		require.Equal(t, []string{"X-My-App: foobar"}, headerToSlice(t, resp.Header()))
   104  
   105  	})
   106  
   107  	t.Run("HTTP request URL can be overridden", func(t *testing.T) {
   108  		req := request(t)
   109  		m := &mockRoundTripper{resp: &httpgrpc.HTTPResponse{Code: 200}}
   110  		const url = "TestURL"
   111  		ctx := WithProcedure(context.Background(), url)
   112  		_, err := RoundTripUnary[typesv1.LabelValuesRequest, typesv1.LabelValuesResponse](ctx, m, req)
   113  		require.NoError(t, err)
   114  		require.Equal(t, url, m.req.Url)
   115  	})
   116  }