github.com/grafana/pyroscope@v1.18.0/pkg/util/connectgrpc/connectgrpc.go (about)

     1  package connectgrpc
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"context"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"strings"
    12  
    13  	"connectrpc.com/connect"
    14  	"google.golang.org/protobuf/proto"
    15  
    16  	"github.com/grafana/pyroscope/pkg/tenant"
    17  	"github.com/grafana/pyroscope/pkg/util/httpgrpc"
    18  )
    19  
    20  type UnaryHandler[Req any, Res any] func(context.Context, *connect.Request[Req]) (*connect.Response[Res], error)
    21  
    22  func HandleUnary[Req any, Res any](ctx context.Context, req *httpgrpc.HTTPRequest, u UnaryHandler[Req, Res]) (*httpgrpc.HTTPResponse, error) {
    23  	connectReq, err := decodeRequest[Req](req)
    24  	if err != nil {
    25  		return nil, err
    26  	}
    27  	connectResp, err := u(ctx, connectReq)
    28  	if err != nil {
    29  		if errors.Is(err, tenant.ErrNoTenantID) {
    30  			err = connect.NewError(connect.CodeUnauthenticated, err)
    31  		}
    32  		var connectErr *connect.Error
    33  		if errors.As(err, &connectErr) {
    34  			return &httpgrpc.HTTPResponse{
    35  				Code:    CodeToHTTP(connectErr.Code()),
    36  				Body:    []byte(connectErr.Message()),
    37  				Headers: connectHeaderToHTTPGRPCHeader(connectErr.Meta()),
    38  			}, nil
    39  		}
    40  
    41  		return nil, err
    42  	}
    43  	return encodeResponse(connectResp)
    44  }
    45  
    46  type GRPCRoundTripper interface {
    47  	RoundTripGRPC(ctx context.Context, req *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error)
    48  }
    49  
    50  type GRPCHandler interface {
    51  	Handle(ctx context.Context, req *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error)
    52  }
    53  
    54  func RoundTripUnary[Req any, Res any](ctx context.Context, rt GRPCRoundTripper, in *connect.Request[Req]) (*connect.Response[Res], error) {
    55  	req, err := encodeRequest(ctx, in)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  	res, err := rt.RoundTripGRPC(ctx, req)
    60  	if err != nil {
    61  		return nil, err
    62  	}
    63  	if res.Code/100 != 2 {
    64  		err := connect.NewError(HTTPToCode(res.Code), errors.New(string(res.Body)))
    65  		for _, h := range res.Headers {
    66  			for _, v := range h.Values {
    67  				err.Meta().Add(h.Key, v)
    68  			}
    69  		}
    70  		return nil, err
    71  	}
    72  	return decodeResponse[Res](res)
    73  }
    74  
    75  func CloneRequest[Req any](base *connect.Request[Req], msg *Req) *connect.Request[Req] {
    76  	r := *base
    77  	r.Msg = msg
    78  	return &r
    79  }
    80  
    81  func encodeResponse[Req any](resp *connect.Response[Req]) (*httpgrpc.HTTPResponse, error) {
    82  	out := &httpgrpc.HTTPResponse{
    83  		Headers: connectHeaderToHTTPGRPCHeader(resp.Header()),
    84  		Code:    http.StatusOK,
    85  	}
    86  	var err error
    87  	out.Body, err = proto.Marshal(resp.Any().(proto.Message))
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	return out, nil
    92  }
    93  
    94  func connectHeaderToHTTPGRPCHeader(header http.Header) []*httpgrpc.Header {
    95  	result := make([]*httpgrpc.Header, 0, len(header))
    96  	for k, v := range header {
    97  		result = append(result, &httpgrpc.Header{
    98  			Key:    k,
    99  			Values: v,
   100  		})
   101  	}
   102  	return result
   103  }
   104  
   105  func httpgrpcHeaderToConnectHeader(header []*httpgrpc.Header) http.Header {
   106  	result := make(http.Header, len(header))
   107  	for _, h := range header {
   108  		result[h.Key] = h.Values
   109  	}
   110  	return result
   111  }
   112  
   113  func decodeRequest[Req any](req *httpgrpc.HTTPRequest) (*connect.Request[Req], error) {
   114  	result := &connect.Request[Req]{
   115  		Msg: new(Req),
   116  	}
   117  	err := proto.Unmarshal(req.Body, result.Any().(proto.Message))
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	return result, nil
   122  }
   123  
   124  type connectURLCtxKey struct{}
   125  
   126  func WithProcedure(ctx context.Context, u string) context.Context {
   127  	return context.WithValue(ctx, connectURLCtxKey{}, u)
   128  }
   129  
   130  func ProcedureFromContext(ctx context.Context) string {
   131  	s, _ := ctx.Value(connectURLCtxKey{}).(string)
   132  	return s
   133  }
   134  
   135  func encodeRequest[Req any](ctx context.Context, req *connect.Request[Req]) (*httpgrpc.HTTPRequest, error) {
   136  	url := ProcedureFromContext(ctx)
   137  	if url == "" {
   138  		if url = req.Spec().Procedure; url == "" {
   139  			return nil, errors.New("cannot encode a request with empty procedure")
   140  		}
   141  	}
   142  	// The original Content-* headers could be invalidated,
   143  	// e.g. initial Content-Type could be 'application/json'.
   144  	h := removeContentHeaders(req.Header().Clone())
   145  	h.Set("Content-Type", "application/proto")
   146  	out := &httpgrpc.HTTPRequest{
   147  		Method:  http.MethodPost,
   148  		Url:     url,
   149  		Headers: connectHeaderToHTTPGRPCHeader(h),
   150  	}
   151  	var err error
   152  	msg := req.Any()
   153  	out.Body, err = proto.Marshal(msg.(proto.Message))
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  	return out, nil
   158  }
   159  
   160  func removeContentHeaders(h http.Header) http.Header {
   161  	for k := range h {
   162  		if strings.HasPrefix(strings.ToLower(k), "content-") {
   163  			h.Del(k)
   164  		}
   165  	}
   166  	return h
   167  }
   168  
   169  // filterHeader filters headers, which would expose details about the implementation details of the connectgrpc implementation
   170  func filterHeader(name string) bool {
   171  	if strings.ToLower(name) == "content-type" {
   172  		return true
   173  	}
   174  	if strings.ToLower(name) == "accept-encoding" {
   175  		return true
   176  	}
   177  	if strings.ToLower(name) == "content-encoding" {
   178  		return true
   179  	}
   180  	return false
   181  }
   182  
   183  func decodeResponse[Resp any](r *httpgrpc.HTTPResponse) (*connect.Response[Resp], error) {
   184  	if err := decompressResponse(r); err != nil {
   185  		return nil, err
   186  	}
   187  	resp := &connect.Response[Resp]{Msg: new(Resp)}
   188  	for _, h := range r.Headers {
   189  		if filterHeader(h.Key) {
   190  			continue
   191  		}
   192  
   193  		for _, v := range h.Values {
   194  			resp.Header().Add(h.Key, v)
   195  		}
   196  	}
   197  	if err := proto.Unmarshal(r.Body, resp.Any().(proto.Message)); err != nil {
   198  		return nil, err
   199  	}
   200  	return resp, nil
   201  }
   202  
   203  func decompressResponse(r *httpgrpc.HTTPResponse) error {
   204  	// We use gziphandler to compress responses of some methods,
   205  	// therefore decompression is very likely to be required.
   206  	// The handling is pretty much the same as in http.Transport,
   207  	// which only supports gzip Content-Encoding.
   208  	for _, h := range r.Headers {
   209  		if h.Key == "Content-Encoding" {
   210  			for _, v := range h.Values {
   211  				switch {
   212  				default:
   213  					return fmt.Errorf("unsupported Content-Encoding: %s", v)
   214  				case v == "":
   215  				case strings.EqualFold(v, "gzip"):
   216  					// bytes.Buffer implements flate.Reader, therefore
   217  					// a gzip reader does not allocate a buffer.
   218  					g, err := gzip.NewReader(bytes.NewBuffer(r.Body))
   219  					if err != nil {
   220  						return err
   221  					}
   222  					r.Body, err = io.ReadAll(g)
   223  					return err
   224  				}
   225  			}
   226  			return nil
   227  		}
   228  	}
   229  	return nil
   230  }
   231  
   232  func CodeToHTTP(code connect.Code) int32 {
   233  	// Return literals rather than named constants from the HTTP package to make
   234  	// it easier to compare this function to the Connect specification.
   235  	switch code {
   236  	case connect.CodeCanceled:
   237  		return 499
   238  	case connect.CodeUnknown:
   239  		return 500
   240  	case connect.CodeInvalidArgument:
   241  		return 400
   242  	case connect.CodeDeadlineExceeded:
   243  		return 504
   244  	case connect.CodeNotFound:
   245  		return 404
   246  	case connect.CodeAlreadyExists:
   247  		return 409
   248  	case connect.CodePermissionDenied:
   249  		return 403
   250  	case connect.CodeResourceExhausted:
   251  		return 429
   252  	case connect.CodeFailedPrecondition:
   253  		return 412
   254  	case connect.CodeAborted:
   255  		return 409
   256  	case connect.CodeOutOfRange:
   257  		return 400
   258  	case connect.CodeUnimplemented:
   259  		return 404
   260  	case connect.CodeInternal:
   261  		return 500
   262  	case connect.CodeUnavailable:
   263  		return 503
   264  	case connect.CodeDataLoss:
   265  		return 500
   266  	case connect.CodeUnauthenticated:
   267  		return 401
   268  	default:
   269  		return 500 // same as CodeUnknown
   270  	}
   271  }
   272  
   273  func HTTPToCode(httpCode int32) connect.Code {
   274  	// As above, literals are easier to compare to the specificaton (vs named
   275  	// constants).
   276  	switch httpCode {
   277  	case 400:
   278  		return connect.CodeInvalidArgument
   279  	case 401:
   280  		return connect.CodeUnauthenticated
   281  	case 403:
   282  		return connect.CodePermissionDenied
   283  	case 404:
   284  		return connect.CodeUnimplemented
   285  	case 412:
   286  		return connect.CodeFailedPrecondition
   287  	case 413:
   288  		return connect.CodeInvalidArgument
   289  	case 429:
   290  		return connect.CodeResourceExhausted
   291  	case 431:
   292  		return connect.CodeResourceExhausted
   293  	case 499:
   294  		return connect.CodeCanceled
   295  	case 502, 503:
   296  		return connect.CodeUnavailable
   297  	case 504:
   298  		return connect.CodeDeadlineExceeded
   299  	default:
   300  		return connect.CodeUnknown
   301  	}
   302  }
   303  
   304  type responseWriter struct {
   305  	header http.Header
   306  	resp   httpgrpc.HTTPResponse
   307  }
   308  
   309  func (r *responseWriter) Header() http.Header {
   310  	return r.header
   311  }
   312  
   313  func (r *responseWriter) Write(data []byte) (int, error) {
   314  	r.resp.Body = append(r.resp.Body, data...)
   315  	return len(data), nil
   316  }
   317  
   318  func (r *responseWriter) WriteHeader(statusCode int) {
   319  	r.resp.Code = int32(statusCode)
   320  }
   321  
   322  func (r *responseWriter) HTTPResponse() *httpgrpc.HTTPResponse {
   323  	r.resp.Headers = connectHeaderToHTTPGRPCHeader(r.header)
   324  	return &r.resp
   325  }
   326  
   327  // NewHandler converts a Connect handler into a HTTPGRPC handler
   328  type grpcHandler struct {
   329  	next http.Handler
   330  }
   331  
   332  func NewHandler(h http.Handler) GRPCHandler {
   333  	return &grpcHandler{next: h}
   334  }
   335  
   336  func newResponseWriter() *responseWriter {
   337  	rw := &responseWriter{header: http.Header{}}
   338  	rw.resp.Code = 200
   339  	return rw
   340  }
   341  
   342  func (q *grpcHandler) Handle(ctx context.Context, req *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error) {
   343  	stdReq, err := http.NewRequestWithContext(ctx, req.Method, req.Url, bytes.NewReader(req.Body))
   344  	if err != nil {
   345  		return nil, err
   346  	}
   347  	stdReq.Header = httpgrpcHeaderToConnectHeader(req.Headers)
   348  
   349  	rw := newResponseWriter()
   350  	q.next.ServeHTTP(rw, stdReq)
   351  
   352  	return rw.HTTPResponse(), nil
   353  }
   354  
   355  type httpgrpcClient struct {
   356  	transport GRPCRoundTripper
   357  }
   358  
   359  func NewClient(transport GRPCRoundTripper) connect.HTTPClient {
   360  	return &httpgrpcClient{transport: transport}
   361  }
   362  
   363  func (g *httpgrpcClient) Do(req *http.Request) (*http.Response, error) {
   364  	body, err := io.ReadAll(req.Body)
   365  	if err != nil {
   366  		return nil, err
   367  	}
   368  
   369  	resp, err := g.transport.RoundTripGRPC(req.Context(), &httpgrpc.HTTPRequest{
   370  		Url:     req.URL.String(),
   371  		Headers: connectHeaderToHTTPGRPCHeader(req.Header),
   372  		Method:  req.Method,
   373  		Body:    body,
   374  	})
   375  	if err != nil {
   376  		return nil, fmt.Errorf("grpc roundtripper error: %w", err)
   377  	}
   378  
   379  	return &http.Response{
   380  		Body:          io.NopCloser(bytes.NewReader(resp.Body)),
   381  		ContentLength: int64(len(resp.Body)),
   382  		StatusCode:    int(resp.Code),
   383  		Header:        httpgrpcHeaderToConnectHeader(resp.Headers),
   384  	}, nil
   385  }