github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/protoutil/marshaler.go (about) 1 // Copyright 2016 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package protoutil 12 13 import ( 14 "io" 15 16 "github.com/cockroachdb/errors" 17 "github.com/gogo/protobuf/proto" 18 gwruntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" 19 ) 20 21 var _ gwruntime.Marshaler = (*ProtoPb)(nil) 22 23 // ProtoPb is a gwruntime.Marshaler that uses github.com/gogo/protobuf/proto. 24 type ProtoPb struct{} 25 26 // ContentType implements gwruntime.Marshaler. 27 func (*ProtoPb) ContentType(_ interface{}) string { 28 // NB: This is the same as httputil.ProtoContentType which we can't use due 29 // to an import cycle. 30 const ProtoContentType = "application/x-protobuf" 31 return ProtoContentType 32 } 33 34 // Marshal implements gwruntime.Marshaler. 35 func (*ProtoPb) Marshal(v interface{}) ([]byte, error) { 36 // NB: we use proto.Message here because grpc-gateway passes us protos that 37 // we don't control and thus don't implement protoutil.Message. 38 if p, ok := v.(proto.Message); ok { 39 return proto.Marshal(p) 40 } 41 return nil, errors.Errorf("unexpected type %T does not implement %s", v, typeProtoMessage) 42 } 43 44 // Unmarshal implements gwruntime.Marshaler. 45 func (*ProtoPb) Unmarshal(data []byte, v interface{}) error { 46 // NB: we use proto.Message here because grpc-gateway passes us protos that 47 // we don't control and thus don't implement protoutil.Message. 48 if p, ok := v.(proto.Message); ok { 49 return proto.Unmarshal(data, p) 50 } 51 return errors.Errorf("unexpected type %T does not implement %s", v, typeProtoMessage) 52 } 53 54 // NewDecoder implements gwruntime.Marshaler. 55 func (*ProtoPb) NewDecoder(r io.Reader) gwruntime.Decoder { 56 return gwruntime.DecoderFunc(func(v interface{}) error { 57 // NB: we use proto.Message here because grpc-gateway passes us protos that 58 // we don't control and thus don't implement protoutil.Message. 59 if p, ok := v.(proto.Message); ok { 60 bytes, err := io.ReadAll(r) 61 if err == nil { 62 err = proto.Unmarshal(bytes, p) 63 } 64 return err 65 } 66 return errors.Errorf("unexpected type %T does not implement %s", v, typeProtoMessage) 67 }) 68 } 69 70 // NewEncoder implements gwruntime.Marshaler. 71 func (*ProtoPb) NewEncoder(w io.Writer) gwruntime.Encoder { 72 return gwruntime.EncoderFunc(func(v interface{}) error { 73 // NB: we use proto.Message here because grpc-gateway passes us protos that 74 // we don't control and thus don't implement protoutil.Message. 75 if p, ok := v.(proto.Message); ok { 76 bytes, err := proto.Marshal(p) 77 if err == nil { 78 _, err = w.Write(bytes) 79 } 80 return err 81 } 82 return errors.Errorf("unexpected type %T does not implement %s", v, typeProtoMessage) 83 }) 84 } 85 86 var _ gwruntime.Delimited = (*ProtoPb)(nil) 87 88 // Delimiter implements gwruntime.Delimited. 89 func (*ProtoPb) Delimiter() []byte { 90 return nil 91 }