github.com/grafana/pyroscope@v1.18.0/pkg/util/connectgrpc/connectgrpc.go (about) 1 package connectgrpc 2 3 import ( 4 "bytes" 5 "compress/gzip" 6 "context" 7 "errors" 8 "fmt" 9 "io" 10 "net/http" 11 "strings" 12 13 "connectrpc.com/connect" 14 "google.golang.org/protobuf/proto" 15 16 "github.com/grafana/pyroscope/pkg/tenant" 17 "github.com/grafana/pyroscope/pkg/util/httpgrpc" 18 ) 19 20 type UnaryHandler[Req any, Res any] func(context.Context, *connect.Request[Req]) (*connect.Response[Res], error) 21 22 func HandleUnary[Req any, Res any](ctx context.Context, req *httpgrpc.HTTPRequest, u UnaryHandler[Req, Res]) (*httpgrpc.HTTPResponse, error) { 23 connectReq, err := decodeRequest[Req](req) 24 if err != nil { 25 return nil, err 26 } 27 connectResp, err := u(ctx, connectReq) 28 if err != nil { 29 if errors.Is(err, tenant.ErrNoTenantID) { 30 err = connect.NewError(connect.CodeUnauthenticated, err) 31 } 32 var connectErr *connect.Error 33 if errors.As(err, &connectErr) { 34 return &httpgrpc.HTTPResponse{ 35 Code: CodeToHTTP(connectErr.Code()), 36 Body: []byte(connectErr.Message()), 37 Headers: connectHeaderToHTTPGRPCHeader(connectErr.Meta()), 38 }, nil 39 } 40 41 return nil, err 42 } 43 return encodeResponse(connectResp) 44 } 45 46 type GRPCRoundTripper interface { 47 RoundTripGRPC(ctx context.Context, req *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error) 48 } 49 50 type GRPCHandler interface { 51 Handle(ctx context.Context, req *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error) 52 } 53 54 func RoundTripUnary[Req any, Res any](ctx context.Context, rt GRPCRoundTripper, in *connect.Request[Req]) (*connect.Response[Res], error) { 55 req, err := encodeRequest(ctx, in) 56 if err != nil { 57 return nil, err 58 } 59 res, err := rt.RoundTripGRPC(ctx, req) 60 if err != nil { 61 return nil, err 62 } 63 if res.Code/100 != 2 { 64 err := connect.NewError(HTTPToCode(res.Code), errors.New(string(res.Body))) 65 for _, h := range res.Headers { 66 for _, v := range h.Values { 67 err.Meta().Add(h.Key, v) 68 } 69 } 70 return nil, err 71 } 72 return decodeResponse[Res](res) 73 } 74 75 func CloneRequest[Req any](base *connect.Request[Req], msg *Req) *connect.Request[Req] { 76 r := *base 77 r.Msg = msg 78 return &r 79 } 80 81 func encodeResponse[Req any](resp *connect.Response[Req]) (*httpgrpc.HTTPResponse, error) { 82 out := &httpgrpc.HTTPResponse{ 83 Headers: connectHeaderToHTTPGRPCHeader(resp.Header()), 84 Code: http.StatusOK, 85 } 86 var err error 87 out.Body, err = proto.Marshal(resp.Any().(proto.Message)) 88 if err != nil { 89 return nil, err 90 } 91 return out, nil 92 } 93 94 func connectHeaderToHTTPGRPCHeader(header http.Header) []*httpgrpc.Header { 95 result := make([]*httpgrpc.Header, 0, len(header)) 96 for k, v := range header { 97 result = append(result, &httpgrpc.Header{ 98 Key: k, 99 Values: v, 100 }) 101 } 102 return result 103 } 104 105 func httpgrpcHeaderToConnectHeader(header []*httpgrpc.Header) http.Header { 106 result := make(http.Header, len(header)) 107 for _, h := range header { 108 result[h.Key] = h.Values 109 } 110 return result 111 } 112 113 func decodeRequest[Req any](req *httpgrpc.HTTPRequest) (*connect.Request[Req], error) { 114 result := &connect.Request[Req]{ 115 Msg: new(Req), 116 } 117 err := proto.Unmarshal(req.Body, result.Any().(proto.Message)) 118 if err != nil { 119 return nil, err 120 } 121 return result, nil 122 } 123 124 type connectURLCtxKey struct{} 125 126 func WithProcedure(ctx context.Context, u string) context.Context { 127 return context.WithValue(ctx, connectURLCtxKey{}, u) 128 } 129 130 func ProcedureFromContext(ctx context.Context) string { 131 s, _ := ctx.Value(connectURLCtxKey{}).(string) 132 return s 133 } 134 135 func encodeRequest[Req any](ctx context.Context, req *connect.Request[Req]) (*httpgrpc.HTTPRequest, error) { 136 url := ProcedureFromContext(ctx) 137 if url == "" { 138 if url = req.Spec().Procedure; url == "" { 139 return nil, errors.New("cannot encode a request with empty procedure") 140 } 141 } 142 // The original Content-* headers could be invalidated, 143 // e.g. initial Content-Type could be 'application/json'. 144 h := removeContentHeaders(req.Header().Clone()) 145 h.Set("Content-Type", "application/proto") 146 out := &httpgrpc.HTTPRequest{ 147 Method: http.MethodPost, 148 Url: url, 149 Headers: connectHeaderToHTTPGRPCHeader(h), 150 } 151 var err error 152 msg := req.Any() 153 out.Body, err = proto.Marshal(msg.(proto.Message)) 154 if err != nil { 155 return nil, err 156 } 157 return out, nil 158 } 159 160 func removeContentHeaders(h http.Header) http.Header { 161 for k := range h { 162 if strings.HasPrefix(strings.ToLower(k), "content-") { 163 h.Del(k) 164 } 165 } 166 return h 167 } 168 169 // filterHeader filters headers, which would expose details about the implementation details of the connectgrpc implementation 170 func filterHeader(name string) bool { 171 if strings.ToLower(name) == "content-type" { 172 return true 173 } 174 if strings.ToLower(name) == "accept-encoding" { 175 return true 176 } 177 if strings.ToLower(name) == "content-encoding" { 178 return true 179 } 180 return false 181 } 182 183 func decodeResponse[Resp any](r *httpgrpc.HTTPResponse) (*connect.Response[Resp], error) { 184 if err := decompressResponse(r); err != nil { 185 return nil, err 186 } 187 resp := &connect.Response[Resp]{Msg: new(Resp)} 188 for _, h := range r.Headers { 189 if filterHeader(h.Key) { 190 continue 191 } 192 193 for _, v := range h.Values { 194 resp.Header().Add(h.Key, v) 195 } 196 } 197 if err := proto.Unmarshal(r.Body, resp.Any().(proto.Message)); err != nil { 198 return nil, err 199 } 200 return resp, nil 201 } 202 203 func decompressResponse(r *httpgrpc.HTTPResponse) error { 204 // We use gziphandler to compress responses of some methods, 205 // therefore decompression is very likely to be required. 206 // The handling is pretty much the same as in http.Transport, 207 // which only supports gzip Content-Encoding. 208 for _, h := range r.Headers { 209 if h.Key == "Content-Encoding" { 210 for _, v := range h.Values { 211 switch { 212 default: 213 return fmt.Errorf("unsupported Content-Encoding: %s", v) 214 case v == "": 215 case strings.EqualFold(v, "gzip"): 216 // bytes.Buffer implements flate.Reader, therefore 217 // a gzip reader does not allocate a buffer. 218 g, err := gzip.NewReader(bytes.NewBuffer(r.Body)) 219 if err != nil { 220 return err 221 } 222 r.Body, err = io.ReadAll(g) 223 return err 224 } 225 } 226 return nil 227 } 228 } 229 return nil 230 } 231 232 func CodeToHTTP(code connect.Code) int32 { 233 // Return literals rather than named constants from the HTTP package to make 234 // it easier to compare this function to the Connect specification. 235 switch code { 236 case connect.CodeCanceled: 237 return 499 238 case connect.CodeUnknown: 239 return 500 240 case connect.CodeInvalidArgument: 241 return 400 242 case connect.CodeDeadlineExceeded: 243 return 504 244 case connect.CodeNotFound: 245 return 404 246 case connect.CodeAlreadyExists: 247 return 409 248 case connect.CodePermissionDenied: 249 return 403 250 case connect.CodeResourceExhausted: 251 return 429 252 case connect.CodeFailedPrecondition: 253 return 412 254 case connect.CodeAborted: 255 return 409 256 case connect.CodeOutOfRange: 257 return 400 258 case connect.CodeUnimplemented: 259 return 404 260 case connect.CodeInternal: 261 return 500 262 case connect.CodeUnavailable: 263 return 503 264 case connect.CodeDataLoss: 265 return 500 266 case connect.CodeUnauthenticated: 267 return 401 268 default: 269 return 500 // same as CodeUnknown 270 } 271 } 272 273 func HTTPToCode(httpCode int32) connect.Code { 274 // As above, literals are easier to compare to the specificaton (vs named 275 // constants). 276 switch httpCode { 277 case 400: 278 return connect.CodeInvalidArgument 279 case 401: 280 return connect.CodeUnauthenticated 281 case 403: 282 return connect.CodePermissionDenied 283 case 404: 284 return connect.CodeUnimplemented 285 case 412: 286 return connect.CodeFailedPrecondition 287 case 413: 288 return connect.CodeInvalidArgument 289 case 429: 290 return connect.CodeResourceExhausted 291 case 431: 292 return connect.CodeResourceExhausted 293 case 499: 294 return connect.CodeCanceled 295 case 502, 503: 296 return connect.CodeUnavailable 297 case 504: 298 return connect.CodeDeadlineExceeded 299 default: 300 return connect.CodeUnknown 301 } 302 } 303 304 type responseWriter struct { 305 header http.Header 306 resp httpgrpc.HTTPResponse 307 } 308 309 func (r *responseWriter) Header() http.Header { 310 return r.header 311 } 312 313 func (r *responseWriter) Write(data []byte) (int, error) { 314 r.resp.Body = append(r.resp.Body, data...) 315 return len(data), nil 316 } 317 318 func (r *responseWriter) WriteHeader(statusCode int) { 319 r.resp.Code = int32(statusCode) 320 } 321 322 func (r *responseWriter) HTTPResponse() *httpgrpc.HTTPResponse { 323 r.resp.Headers = connectHeaderToHTTPGRPCHeader(r.header) 324 return &r.resp 325 } 326 327 // NewHandler converts a Connect handler into a HTTPGRPC handler 328 type grpcHandler struct { 329 next http.Handler 330 } 331 332 func NewHandler(h http.Handler) GRPCHandler { 333 return &grpcHandler{next: h} 334 } 335 336 func newResponseWriter() *responseWriter { 337 rw := &responseWriter{header: http.Header{}} 338 rw.resp.Code = 200 339 return rw 340 } 341 342 func (q *grpcHandler) Handle(ctx context.Context, req *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error) { 343 stdReq, err := http.NewRequestWithContext(ctx, req.Method, req.Url, bytes.NewReader(req.Body)) 344 if err != nil { 345 return nil, err 346 } 347 stdReq.Header = httpgrpcHeaderToConnectHeader(req.Headers) 348 349 rw := newResponseWriter() 350 q.next.ServeHTTP(rw, stdReq) 351 352 return rw.HTTPResponse(), nil 353 } 354 355 type httpgrpcClient struct { 356 transport GRPCRoundTripper 357 } 358 359 func NewClient(transport GRPCRoundTripper) connect.HTTPClient { 360 return &httpgrpcClient{transport: transport} 361 } 362 363 func (g *httpgrpcClient) Do(req *http.Request) (*http.Response, error) { 364 body, err := io.ReadAll(req.Body) 365 if err != nil { 366 return nil, err 367 } 368 369 resp, err := g.transport.RoundTripGRPC(req.Context(), &httpgrpc.HTTPRequest{ 370 Url: req.URL.String(), 371 Headers: connectHeaderToHTTPGRPCHeader(req.Header), 372 Method: req.Method, 373 Body: body, 374 }) 375 if err != nil { 376 return nil, fmt.Errorf("grpc roundtripper error: %w", err) 377 } 378 379 return &http.Response{ 380 Body: io.NopCloser(bytes.NewReader(resp.Body)), 381 ContentLength: int64(len(resp.Body)), 382 StatusCode: int(resp.Code), 383 Header: httpgrpcHeaderToConnectHeader(resp.Headers), 384 }, nil 385 }