github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/protoutil/marshaler.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package protoutil
    12  
    13  import (
    14  	"io"
    15  
    16  	"github.com/cockroachdb/errors"
    17  	"github.com/gogo/protobuf/proto"
    18  	gwruntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
    19  )
    20  
    21  var _ gwruntime.Marshaler = (*ProtoPb)(nil)
    22  
    23  // ProtoPb is a gwruntime.Marshaler that uses github.com/gogo/protobuf/proto.
    24  type ProtoPb struct{}
    25  
    26  // ContentType implements gwruntime.Marshaler.
    27  func (*ProtoPb) ContentType(_ interface{}) string {
    28  	// NB: This is the same as httputil.ProtoContentType which we can't use due
    29  	// to an import cycle.
    30  	const ProtoContentType = "application/x-protobuf"
    31  	return ProtoContentType
    32  }
    33  
    34  // Marshal implements gwruntime.Marshaler.
    35  func (*ProtoPb) Marshal(v interface{}) ([]byte, error) {
    36  	// NB: we use proto.Message here because grpc-gateway passes us protos that
    37  	// we don't control and thus don't implement protoutil.Message.
    38  	if p, ok := v.(proto.Message); ok {
    39  		return proto.Marshal(p)
    40  	}
    41  	return nil, errors.Errorf("unexpected type %T does not implement %s", v, typeProtoMessage)
    42  }
    43  
    44  // Unmarshal implements gwruntime.Marshaler.
    45  func (*ProtoPb) Unmarshal(data []byte, v interface{}) error {
    46  	// NB: we use proto.Message here because grpc-gateway passes us protos that
    47  	// we don't control and thus don't implement protoutil.Message.
    48  	if p, ok := v.(proto.Message); ok {
    49  		return proto.Unmarshal(data, p)
    50  	}
    51  	return errors.Errorf("unexpected type %T does not implement %s", v, typeProtoMessage)
    52  }
    53  
    54  // NewDecoder implements gwruntime.Marshaler.
    55  func (*ProtoPb) NewDecoder(r io.Reader) gwruntime.Decoder {
    56  	return gwruntime.DecoderFunc(func(v interface{}) error {
    57  		// NB: we use proto.Message here because grpc-gateway passes us protos that
    58  		// we don't control and thus don't implement protoutil.Message.
    59  		if p, ok := v.(proto.Message); ok {
    60  			bytes, err := io.ReadAll(r)
    61  			if err == nil {
    62  				err = proto.Unmarshal(bytes, p)
    63  			}
    64  			return err
    65  		}
    66  		return errors.Errorf("unexpected type %T does not implement %s", v, typeProtoMessage)
    67  	})
    68  }
    69  
    70  // NewEncoder implements gwruntime.Marshaler.
    71  func (*ProtoPb) NewEncoder(w io.Writer) gwruntime.Encoder {
    72  	return gwruntime.EncoderFunc(func(v interface{}) error {
    73  		// NB: we use proto.Message here because grpc-gateway passes us protos that
    74  		// we don't control and thus don't implement protoutil.Message.
    75  		if p, ok := v.(proto.Message); ok {
    76  			bytes, err := proto.Marshal(p)
    77  			if err == nil {
    78  				_, err = w.Write(bytes)
    79  			}
    80  			return err
    81  		}
    82  		return errors.Errorf("unexpected type %T does not implement %s", v, typeProtoMessage)
    83  	})
    84  }
    85  
    86  var _ gwruntime.Delimited = (*ProtoPb)(nil)
    87  
    88  // Delimiter implements gwruntime.Delimited.
    89  func (*ProtoPb) Delimiter() []byte {
    90  	return nil
    91  }