github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/protoutil/jsonpb_marshal.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package protoutil
    12  
    13  import (
    14  	"bytes"
    15  	"encoding/json"
    16  	"fmt"
    17  	"io"
    18  	"reflect"
    19  
    20  	"github.com/cockroachdb/errors"
    21  	"github.com/gogo/protobuf/jsonpb"
    22  	"github.com/gogo/protobuf/proto"
    23  	gwruntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
    24  )
    25  
    26  var _ gwruntime.Marshaler = (*JSONPb)(nil)
    27  
    28  var typeProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem()
    29  
    30  // JSONPb is a gwruntime.Marshaler that uses github.com/gogo/protobuf/jsonpb.
    31  type JSONPb jsonpb.Marshaler
    32  
    33  // ContentType implements gwruntime.Marshaler.
    34  func (*JSONPb) ContentType(_ interface{}) string {
    35  	// NB: This is the same as httputil.JSONContentType which we can't use due to
    36  	// an import cycle.
    37  	const JSONContentType = "application/json"
    38  	return JSONContentType
    39  }
    40  
    41  // Marshal implements gwruntime.Marshaler.
    42  func (j *JSONPb) Marshal(v interface{}) ([]byte, error) {
    43  	return j.marshal(v)
    44  }
    45  
    46  // a lower-case version of marshal to allow for a call from
    47  // marshalNonProtoField without upsetting TestProtoMarshal().
    48  func (j *JSONPb) marshal(v interface{}) ([]byte, error) {
    49  	// NB: we use proto.Message here because grpc-gateway passes us protos that
    50  	// we don't control and thus don't implement protoutil.Message.
    51  	if pb, ok := v.(proto.Message); ok {
    52  		var buf bytes.Buffer
    53  		marshalFn := (*jsonpb.Marshaler)(j).Marshal
    54  		if err := marshalFn(&buf, pb); err != nil {
    55  			return nil, err
    56  		}
    57  		return buf.Bytes(), nil
    58  	}
    59  	return j.marshalNonProtoField(v)
    60  }
    61  
    62  // Cribbed verbatim from grpc-gateway.
    63  type protoEnum interface {
    64  	fmt.Stringer
    65  	EnumDescriptor() ([]byte, []int)
    66  }
    67  
    68  // Cribbed verbatim from grpc-gateway.
    69  func (j *JSONPb) marshalNonProtoField(v interface{}) ([]byte, error) {
    70  	rv := reflect.ValueOf(v)
    71  	for rv.Kind() == reflect.Ptr {
    72  		if rv.IsNil() {
    73  			return []byte("null"), nil
    74  		}
    75  		rv = rv.Elem()
    76  	}
    77  
    78  	if rv.Kind() == reflect.Map {
    79  		m := make(map[string]*json.RawMessage)
    80  		for _, k := range rv.MapKeys() {
    81  			buf, err := j.marshal(rv.MapIndex(k).Interface())
    82  			if err != nil {
    83  				return nil, err
    84  			}
    85  			m[fmt.Sprintf("%v", k.Interface())] = (*json.RawMessage)(&buf)
    86  		}
    87  		if j.Indent != "" {
    88  			return json.MarshalIndent(m, "", j.Indent)
    89  		}
    90  		return json.Marshal(m)
    91  	}
    92  	if enum, ok := rv.Interface().(protoEnum); ok && !j.EnumsAsInts {
    93  		return json.Marshal(enum.String())
    94  	}
    95  	return json.Marshal(rv.Interface())
    96  }
    97  
    98  // Unmarshal implements gwruntime.Marshaler.
    99  func (j *JSONPb) Unmarshal(data []byte, v interface{}) error {
   100  	// NB: we use proto.Message here because grpc-gateway passes us protos that
   101  	// we don't control and thus don't implement protoutil.Message.
   102  	if pb, ok := v.(proto.Message); ok {
   103  		return jsonpb.Unmarshal(bytes.NewReader(data), pb)
   104  	}
   105  	return errors.Errorf("unexpected type %T does not implement %s", v, typeProtoMessage)
   106  }
   107  
   108  // NewDecoder implements gwruntime.Marshaler.
   109  func (j *JSONPb) NewDecoder(r io.Reader) gwruntime.Decoder {
   110  	return gwruntime.DecoderFunc(func(v interface{}) error {
   111  		// NB: we use proto.Message here because grpc-gateway passes us protos that
   112  		// we don't control and thus don't implement protoutil.Message.
   113  		if pb, ok := v.(proto.Message); ok {
   114  			return jsonpb.Unmarshal(r, pb)
   115  		}
   116  		return errors.Errorf("unexpected type %T does not implement %s", v, typeProtoMessage)
   117  	})
   118  }
   119  
   120  // NewEncoder implements gwruntime.Marshaler.
   121  func (j *JSONPb) NewEncoder(w io.Writer) gwruntime.Encoder {
   122  	return gwruntime.EncoderFunc(func(v interface{}) error {
   123  		// NB: we use proto.Message here because grpc-gateway passes us protos that
   124  		// we don't control and thus don't implement protoutil.Message.
   125  		if pb, ok := v.(proto.Message); ok {
   126  			marshalFn := (*jsonpb.Marshaler)(j).Marshal
   127  			return marshalFn(w, pb)
   128  		}
   129  		return errors.Errorf("unexpected type %T does not implement %s", v, typeProtoMessage)
   130  	})
   131  }
   132  
   133  var _ gwruntime.Delimited = (*JSONPb)(nil)
   134  
   135  // Delimiter implements gwruntime.Delimited.
   136  func (*JSONPb) Delimiter() []byte {
   137  	return []byte("\n")
   138  }