code.vegaprotocol.io/vega@v0.79.0/datanode/gateway/rest/jsonpb.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package rest
    17  
    18  import (
    19  	"bytes"
    20  	"encoding/json"
    21  	"fmt"
    22  	"io"
    23  	"reflect"
    24  
    25  	"github.com/golang/protobuf/jsonpb"
    26  	"github.com/golang/protobuf/proto"
    27  	"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
    28  )
    29  
// JSONPb is a runtime.Marshaler that marshals and unmarshals JSON. Values
// implementing proto.Message are handled by github.com/golang/protobuf/jsonpb;
// other values (non-message protobuf field values such as primitives, enums
// and maps) are handled with encoding/json plus protobuf-specific rules.
//
// The jsonpb.Marshaler options carried by this type (for example Indent and
// EnumsAsInts) control how values are rendered. Decoding always allows
// unknown fields.
type JSONPb jsonpb.Marshaler
    32  
    33  // ContentType always returns "application/json".
    34  func (*JSONPb) ContentType(interface{}) string {
    35  	return "application/json"
    36  }
    37  
    38  // Marshal marshals "v" into JSON.
    39  func (j *JSONPb) Marshal(v interface{}) ([]byte, error) {
    40  	if _, ok := v.(proto.Message); !ok {
    41  		return j.marshalNonProtoField(v)
    42  	}
    43  	var buf bytes.Buffer
    44  	if err := j.marshalTo(&buf, v); err != nil {
    45  		return nil, err
    46  	}
    47  	return buf.Bytes(), nil
    48  }
    49  
    50  func (j *JSONPb) marshalTo(w io.Writer, v interface{}) error {
    51  	p, ok := v.(proto.Message)
    52  	if !ok {
    53  		buf, err := j.marshalNonProtoField(v)
    54  		if err != nil {
    55  			return err
    56  		}
    57  		_, err = w.Write(buf)
    58  		return err
    59  	}
    60  	return (*jsonpb.Marshaler)(j).Marshal(w, p)
    61  }
    62  
    63  // marshalNonProto marshals a non-message field of a protobuf message.
    64  // This function does not correctly marshals arbitrary data structure into JSON,
    65  // but it is only capable of marshaling non-message field values of protobuf,
    66  // i.e. primitive types, enums; pointers to primitives or enums; maps from
    67  // integer/string types to primitives/enums/pointers to messages.
    68  func (j *JSONPb) marshalNonProtoField(v interface{}) ([]byte, error) {
    69  	rv := reflect.ValueOf(v)
    70  	for rv.Kind() == reflect.Ptr {
    71  		if rv.IsNil() {
    72  			return []byte("null"), nil
    73  		}
    74  		rv = rv.Elem()
    75  	}
    76  
    77  	if rv.Kind() == reflect.Map {
    78  		m := make(map[string]*json.RawMessage)
    79  		for _, k := range rv.MapKeys() {
    80  			buf, err := j.Marshal(rv.MapIndex(k).Interface())
    81  			if err != nil {
    82  				return nil, err
    83  			}
    84  			m[fmt.Sprintf("%v", k.Interface())] = (*json.RawMessage)(&buf)
    85  		}
    86  		if j.Indent != "" {
    87  			return json.MarshalIndent(m, "", j.Indent)
    88  		}
    89  		return json.Marshal(m)
    90  	}
    91  	if enum, ok := rv.Interface().(protoEnum); ok && !j.EnumsAsInts {
    92  		return json.Marshal(enum.String())
    93  	}
    94  	return json.Marshal(rv.Interface())
    95  }
    96  
    97  // Unmarshal unmarshals JSON "data" into "v".
    98  // Currently it can only marshal proto.Message.
    99  func (j *JSONPb) Unmarshal(data []byte, v interface{}) error {
   100  	return unmarshalJSONPb(data, v)
   101  }
   102  
   103  // NewDecoder returns a runtime.Decoder which reads JSON stream from "r".
   104  func (j *JSONPb) NewDecoder(r io.Reader) runtime.Decoder {
   105  	d := json.NewDecoder(r)
   106  	return runtime.DecoderFunc(func(v interface{}) error { return decodeJSONPb(d, v) })
   107  }
   108  
   109  // NewEncoder returns an Encoder which writes JSON stream into "w".
   110  func (j *JSONPb) NewEncoder(w io.Writer) runtime.Encoder {
   111  	return runtime.EncoderFunc(func(v interface{}) error { return j.marshalTo(w, v) })
   112  }
   113  
   114  func unmarshalJSONPb(data []byte, v interface{}) error {
   115  	d := json.NewDecoder(bytes.NewReader(data))
   116  	return decodeJSONPb(d, v)
   117  }
   118  
   119  func decodeJSONPb(d *json.Decoder, v interface{}) error {
   120  	p, ok := v.(proto.Message)
   121  	if !ok {
   122  		return decodeNonProtoField(d, v)
   123  	}
   124  	unmarshaler := &jsonpb.Unmarshaler{AllowUnknownFields: true}
   125  	return unmarshaler.UnmarshalNext(d, p)
   126  }
   127  
   128  func decodeNonProtoField(d *json.Decoder, v interface{}) error {
   129  	rv := reflect.ValueOf(v)
   130  	if rv.Kind() != reflect.Ptr {
   131  		return fmt.Errorf("%T is not a pointer", v)
   132  	}
   133  	for rv.Kind() == reflect.Ptr {
   134  		if rv.IsNil() {
   135  			rv.Set(reflect.New(rv.Type().Elem()))
   136  		}
   137  		if rv.Type().ConvertibleTo(typeProtoMessage) {
   138  			unmarshaler := &jsonpb.Unmarshaler{AllowUnknownFields: true}
   139  			return unmarshaler.UnmarshalNext(d, rv.Interface().(proto.Message))
   140  		}
   141  		rv = rv.Elem()
   142  	}
   143  	if rv.Kind() == reflect.Map {
   144  		if rv.IsNil() {
   145  			rv.Set(reflect.MakeMap(rv.Type()))
   146  		}
   147  		conv, ok := convFromType[rv.Type().Key().Kind()]
   148  		if !ok {
   149  			return fmt.Errorf("unsupported type of map field key: %v", rv.Type().Key())
   150  		}
   151  		m := make(map[string]*json.RawMessage)
   152  		if err := d.Decode(&m); err != nil {
   153  			return err
   154  		}
   155  		for k, v := range m {
   156  			result := conv.Call([]reflect.Value{reflect.ValueOf(k)})
   157  			if err := result[1].Interface(); err != nil {
   158  				return err.(error)
   159  			}
   160  			bk := result[0]
   161  			bv := reflect.New(rv.Type().Elem())
   162  			if err := unmarshalJSONPb(*v, bv.Interface()); err != nil {
   163  				return err
   164  			}
   165  			rv.SetMapIndex(bk, bv.Elem())
   166  		}
   167  		return nil
   168  	}
   169  	if _, ok := rv.Interface().(protoEnum); ok {
   170  		var repr interface{}
   171  		if err := d.Decode(&repr); err != nil {
   172  			return err
   173  		}
   174  		switch repr := repr.(type) {
   175  		case string:
   176  			return fmt.Errorf("unmarshaling of symbolic enum %s not supported: %T", repr, rv.Interface())
   177  		case float64:
   178  			rv.Set(reflect.ValueOf(int32(repr)).Convert(rv.Type()))
   179  			return nil
   180  		default:
   181  			return fmt.Errorf("cannot assign %#v into Go type %T", repr, rv.Interface())
   182  		}
   183  	}
   184  	return d.Decode(v)
   185  }
   186  
   187  type protoEnum interface {
   188  	fmt.Stringer
   189  	EnumDescriptor() ([]byte, []int)
   190  }
   191  
   192  var typeProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem()
   193  
   194  // Delimiter for newline encoded JSON streams.
   195  func (j *JSONPb) Delimiter() []byte {
   196  	return []byte("\n")
   197  }
   198  
   199  var convFromType = map[reflect.Kind]reflect.Value{
   200  	reflect.String:  reflect.ValueOf(runtime.String),
   201  	reflect.Bool:    reflect.ValueOf(runtime.Bool),
   202  	reflect.Float64: reflect.ValueOf(runtime.Float64),
   203  	reflect.Float32: reflect.ValueOf(runtime.Float32),
   204  	reflect.Int64:   reflect.ValueOf(runtime.Int64),
   205  	reflect.Int32:   reflect.ValueOf(runtime.Int32),
   206  	reflect.Uint64:  reflect.ValueOf(runtime.Uint64),
   207  	reflect.Uint32:  reflect.ValueOf(runtime.Uint32),
   208  	reflect.Slice:   reflect.ValueOf(runtime.Bytes),
   209  }