github.com/Big-big-orange/protoreflect@v0.0.0-20240408141420-285cedfdf6a4/dynamic/json.go (about)

     1  package dynamic
     2  
     3  // JSON marshalling and unmarshalling for dynamic messages
     4  
     5  import (
     6  	"bytes"
     7  	"encoding/base64"
     8  	"encoding/json"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"math"
    13  	"reflect"
    14  	"sort"
    15  	"strconv"
    16  	"strings"
    17  
    18  	"github.com/golang/protobuf/jsonpb"
    19  	"github.com/golang/protobuf/proto"
    20  	"google.golang.org/protobuf/types/descriptorpb"
    21  	// link in the well-known-types that have a special JSON format
    22  	_ "google.golang.org/protobuf/types/known/anypb"
    23  	_ "google.golang.org/protobuf/types/known/durationpb"
    24  	_ "google.golang.org/protobuf/types/known/emptypb"
    25  	_ "google.golang.org/protobuf/types/known/structpb"
    26  	_ "google.golang.org/protobuf/types/known/timestamppb"
    27  	_ "google.golang.org/protobuf/types/known/wrapperspb"
    28  
    29  	"github.com/Big-big-orange/protoreflect/desc"
    30  )
    31  
    32  var wellKnownTypeNames = map[string]struct{}{
    33  	"google.protobuf.Any":       {},
    34  	"google.protobuf.Empty":     {},
    35  	"google.protobuf.Duration":  {},
    36  	"google.protobuf.Timestamp": {},
    37  	// struct.proto
    38  	"google.protobuf.Struct":    {},
    39  	"google.protobuf.Value":     {},
    40  	"google.protobuf.ListValue": {},
    41  	// wrappers.proto
    42  	"google.protobuf.DoubleValue": {},
    43  	"google.protobuf.FloatValue":  {},
    44  	"google.protobuf.Int64Value":  {},
    45  	"google.protobuf.UInt64Value": {},
    46  	"google.protobuf.Int32Value":  {},
    47  	"google.protobuf.UInt32Value": {},
    48  	"google.protobuf.BoolValue":   {},
    49  	"google.protobuf.StringValue": {},
    50  	"google.protobuf.BytesValue":  {},
    51  }
    52  
    53  // MarshalJSON serializes this message to bytes in JSON format, returning an
    54  // error if the operation fails. The resulting bytes will be a valid UTF8
    55  // string.
    56  //
    57  // This method uses a compact form: no newlines, and spaces between fields and
    58  // between field identifiers and values are elided.
    59  //
    60  // This method is convenient shorthand for invoking MarshalJSONPB with a default
    61  // (zero value) marshaler:
    62  //
    63  //	m.MarshalJSONPB(&jsonpb.Marshaler{})
    64  //
    65  // So enums are serialized using enum value name strings, and values that are
    66  // not present (including those with default/zero value for messages defined in
    67  // "proto3" syntax) are omitted.
    68  func (m *Message) MarshalJSON() ([]byte, error) {
    69  	return m.MarshalJSONPB(&jsonpb.Marshaler{})
    70  }
    71  
    72  // MarshalJSONIndent serializes this message to bytes in JSON format, returning
    73  // an error if the operation fails. The resulting bytes will be a valid UTF8
    74  // string.
    75  //
    76  // This method uses a "pretty-printed" form, with each field on its own line and
    77  // spaces between field identifiers and values. Indentation of two spaces is
    78  // used.
    79  //
    80  // This method is convenient shorthand for invoking MarshalJSONPB with a default
    81  // (zero value) marshaler:
    82  //
    83  //	m.MarshalJSONPB(&jsonpb.Marshaler{Indent: "  "})
    84  //
    85  // So enums are serialized using enum value name strings, and values that are
    86  // not present (including those with default/zero value for messages defined in
    87  // "proto3" syntax) are omitted.
    88  func (m *Message) MarshalJSONIndent() ([]byte, error) {
    89  	return m.MarshalJSONPB(&jsonpb.Marshaler{Indent: "  "})
    90  }
    91  
    92  // MarshalJSONPB serializes this message to bytes in JSON format, returning an
    93  // error if the operation fails. The resulting bytes will be a valid UTF8
    94  // string. The given marshaler is used to convey options used during marshaling.
    95  //
    96  // If this message contains nested messages that are generated message types (as
    97  // opposed to dynamic messages), the given marshaler is used to marshal it.
    98  //
    99  // When marshaling any nested messages, any jsonpb.AnyResolver configured in the
   100  // given marshaler is augmented with knowledge of message types known to this
   101  // message's descriptor (and its enclosing file and set of transitive
   102  // dependencies).
   103  func (m *Message) MarshalJSONPB(opts *jsonpb.Marshaler) ([]byte, error) {
   104  	var b indentBuffer
   105  	b.indent = opts.Indent
   106  	if len(opts.Indent) == 0 {
   107  		b.indentCount = -1
   108  	}
   109  	b.comma = true
   110  	if err := m.marshalJSON(&b, opts); err != nil {
   111  		return nil, err
   112  	}
   113  	return b.Bytes(), nil
   114  }
   115  
   116  func (m *Message) marshalJSON(b *indentBuffer, opts *jsonpb.Marshaler) error {
   117  	if m == nil {
   118  		_, err := b.WriteString("null")
   119  		return err
   120  	}
   121  	if r, changed := wrapResolver(opts.AnyResolver, m.mf, m.md.GetFile()); changed {
   122  		newOpts := *opts
   123  		newOpts.AnyResolver = r
   124  		opts = &newOpts
   125  	}
   126  
   127  	if ok, err := marshalWellKnownType(m, b, opts); ok {
   128  		return err
   129  	}
   130  
   131  	err := b.WriteByte('{')
   132  	if err != nil {
   133  		return err
   134  	}
   135  	err = b.start()
   136  	if err != nil {
   137  		return err
   138  	}
   139  
   140  	var tags []int
   141  	if opts.EmitDefaults {
   142  		tags = m.allKnownFieldTags()
   143  	} else {
   144  		tags = m.knownFieldTags()
   145  	}
   146  
   147  	first := true
   148  
   149  	for _, tag := range tags {
   150  		itag := int32(tag)
   151  		fd := m.FindFieldDescriptor(itag)
   152  
   153  		v, ok := m.values[itag]
   154  		if !ok {
   155  			if fd.GetOneOf() != nil {
   156  				// don't print defaults for fields in a oneof
   157  				continue
   158  			}
   159  			v = fd.GetDefaultValue()
   160  		}
   161  
   162  		err := b.maybeNext(&first)
   163  		if err != nil {
   164  			return err
   165  		}
   166  		err = marshalKnownFieldJSON(b, fd, v, opts)
   167  		if err != nil {
   168  			return err
   169  		}
   170  	}
   171  
   172  	err = b.end()
   173  	if err != nil {
   174  		return err
   175  	}
   176  	err = b.WriteByte('}')
   177  	if err != nil {
   178  		return err
   179  	}
   180  
   181  	return nil
   182  }
   183  
   184  func marshalWellKnownType(m *Message, b *indentBuffer, opts *jsonpb.Marshaler) (bool, error) {
   185  	fqn := m.md.GetFullyQualifiedName()
   186  	if _, ok := wellKnownTypeNames[fqn]; !ok {
   187  		return false, nil
   188  	}
   189  
   190  	msgType := proto.MessageType(fqn)
   191  	if msgType == nil {
   192  		// wtf?
   193  		panic(fmt.Sprintf("could not find registered message type for %q", fqn))
   194  	}
   195  
   196  	// convert dynamic message to well-known type and let jsonpb marshal it
   197  	msg := reflect.New(msgType.Elem()).Interface().(proto.Message)
   198  	if err := m.MergeInto(msg); err != nil {
   199  		return true, err
   200  	}
   201  	return true, opts.Marshal(b, msg)
   202  }
   203  
   204  func marshalKnownFieldJSON(b *indentBuffer, fd *desc.FieldDescriptor, v interface{}, opts *jsonpb.Marshaler) error {
   205  	var jsonName string
   206  	if opts.OrigName {
   207  		jsonName = fd.GetName()
   208  	} else {
   209  		jsonName = fd.AsFieldDescriptorProto().GetJsonName()
   210  		if jsonName == "" {
   211  			jsonName = fd.GetName()
   212  		}
   213  	}
   214  	if fd.IsExtension() {
   215  		var scope string
   216  		switch parent := fd.GetParent().(type) {
   217  		case *desc.FileDescriptor:
   218  			scope = parent.GetPackage()
   219  		default:
   220  			scope = parent.GetFullyQualifiedName()
   221  		}
   222  		if scope == "" {
   223  			jsonName = fmt.Sprintf("[%s]", jsonName)
   224  		} else {
   225  			jsonName = fmt.Sprintf("[%s.%s]", scope, jsonName)
   226  		}
   227  	}
   228  	err := writeJsonString(b, jsonName)
   229  	if err != nil {
   230  		return err
   231  	}
   232  	err = b.sep()
   233  	if err != nil {
   234  		return err
   235  	}
   236  
   237  	if isNil(v) {
   238  		_, err := b.WriteString("null")
   239  		return err
   240  	}
   241  
   242  	if fd.IsMap() {
   243  		err = b.WriteByte('{')
   244  		if err != nil {
   245  			return err
   246  		}
   247  		err = b.start()
   248  		if err != nil {
   249  			return err
   250  		}
   251  
   252  		md := fd.GetMessageType()
   253  		vfd := md.FindFieldByNumber(2)
   254  
   255  		mp := v.(map[interface{}]interface{})
   256  		keys := make([]interface{}, 0, len(mp))
   257  		for k := range mp {
   258  			keys = append(keys, k)
   259  		}
   260  		sort.Sort(sortable(keys))
   261  		first := true
   262  		for _, mk := range keys {
   263  			mv := mp[mk]
   264  			err := b.maybeNext(&first)
   265  			if err != nil {
   266  				return err
   267  			}
   268  
   269  			err = marshalKnownFieldMapEntryJSON(b, mk, vfd, mv, opts)
   270  			if err != nil {
   271  				return err
   272  			}
   273  		}
   274  
   275  		err = b.end()
   276  		if err != nil {
   277  			return err
   278  		}
   279  		return b.WriteByte('}')
   280  
   281  	} else if fd.IsRepeated() {
   282  		err = b.WriteByte('[')
   283  		if err != nil {
   284  			return err
   285  		}
   286  		err = b.start()
   287  		if err != nil {
   288  			return err
   289  		}
   290  
   291  		sl := v.([]interface{})
   292  		first := true
   293  		for _, slv := range sl {
   294  			err := b.maybeNext(&first)
   295  			if err != nil {
   296  				return err
   297  			}
   298  			err = marshalKnownFieldValueJSON(b, fd, slv, opts)
   299  			if err != nil {
   300  				return err
   301  			}
   302  		}
   303  
   304  		err = b.end()
   305  		if err != nil {
   306  			return err
   307  		}
   308  		return b.WriteByte(']')
   309  
   310  	} else {
   311  		return marshalKnownFieldValueJSON(b, fd, v, opts)
   312  	}
   313  }
   314  
   315  // sortable is used to sort map keys. Values will be integers (int32, int64, uint32, and uint64),
   316  // bools, or strings.
   317  type sortable []interface{}
   318  
   319  func (s sortable) Len() int {
   320  	return len(s)
   321  }
   322  
   323  func (s sortable) Less(i, j int) bool {
   324  	vi := s[i]
   325  	vj := s[j]
   326  	switch reflect.TypeOf(vi).Kind() {
   327  	case reflect.Int32:
   328  		return vi.(int32) < vj.(int32)
   329  	case reflect.Int64:
   330  		return vi.(int64) < vj.(int64)
   331  	case reflect.Uint32:
   332  		return vi.(uint32) < vj.(uint32)
   333  	case reflect.Uint64:
   334  		return vi.(uint64) < vj.(uint64)
   335  	case reflect.String:
   336  		return vi.(string) < vj.(string)
   337  	case reflect.Bool:
   338  		return !vi.(bool) && vj.(bool)
   339  	default:
   340  		panic(fmt.Sprintf("cannot compare keys of type %v", reflect.TypeOf(vi)))
   341  	}
   342  }
   343  
   344  func (s sortable) Swap(i, j int) {
   345  	s[i], s[j] = s[j], s[i]
   346  }
   347  
   348  func isNil(v interface{}) bool {
   349  	if v == nil {
   350  		return true
   351  	}
   352  	rv := reflect.ValueOf(v)
   353  	return rv.Kind() == reflect.Ptr && rv.IsNil()
   354  }
   355  
   356  func marshalKnownFieldMapEntryJSON(b *indentBuffer, mk interface{}, vfd *desc.FieldDescriptor, mv interface{}, opts *jsonpb.Marshaler) error {
   357  	rk := reflect.ValueOf(mk)
   358  	var strkey string
   359  	switch rk.Kind() {
   360  	case reflect.Bool:
   361  		strkey = strconv.FormatBool(rk.Bool())
   362  	case reflect.Int32, reflect.Int64:
   363  		strkey = strconv.FormatInt(rk.Int(), 10)
   364  	case reflect.Uint32, reflect.Uint64:
   365  		strkey = strconv.FormatUint(rk.Uint(), 10)
   366  	case reflect.String:
   367  		strkey = rk.String()
   368  	default:
   369  		return fmt.Errorf("invalid map key value: %v (%v)", mk, rk.Type())
   370  	}
   371  	err := writeJsonString(b, strkey)
   372  	if err != nil {
   373  		return err
   374  	}
   375  	err = b.sep()
   376  	if err != nil {
   377  		return err
   378  	}
   379  	return marshalKnownFieldValueJSON(b, vfd, mv, opts)
   380  }
   381  
   382  func marshalKnownFieldValueJSON(b *indentBuffer, fd *desc.FieldDescriptor, v interface{}, opts *jsonpb.Marshaler) error {
   383  	rv := reflect.ValueOf(v)
   384  	switch rv.Kind() {
   385  	case reflect.Int64:
   386  		value := rv.Int()
   387  		// if value in [-2^53+1, 2^53-1], use as integer
   388  		if value > -9007199254740992 && value < 9007199254740992 {
   389  			_, err := b.WriteString(strconv.FormatInt(value, 10))
   390  			return err
   391  		}
   392  		return writeJsonString(b, strconv.FormatInt(value, 10))
   393  	case reflect.Int32:
   394  		ed := fd.GetEnumType()
   395  		if !opts.EnumsAsInts && ed != nil {
   396  			n := int32(rv.Int())
   397  			vd := ed.FindValueByNumber(n)
   398  			if vd == nil {
   399  				_, err := b.WriteString(strconv.FormatInt(rv.Int(), 10))
   400  				return err
   401  			} else {
   402  				return writeJsonString(b, vd.GetName())
   403  			}
   404  		} else {
   405  			_, err := b.WriteString(strconv.FormatInt(rv.Int(), 10))
   406  			return err
   407  		}
   408  	case reflect.Uint64:
   409  		value := rv.Uint()
   410  		//if value in [0, 2^53-1], use as integer
   411  		if value < 9007199254740992 {
   412  			_, err := b.WriteString(strconv.FormatUint(value, 10))
   413  			return err
   414  		}
   415  		return writeJsonString(b, strconv.FormatUint(value, 10))
   416  	case reflect.Uint32:
   417  		_, err := b.WriteString(strconv.FormatUint(rv.Uint(), 10))
   418  		return err
   419  	case reflect.Float32, reflect.Float64:
   420  		f := rv.Float()
   421  		var str string
   422  		if math.IsNaN(f) {
   423  			str = `"NaN"`
   424  		} else if math.IsInf(f, 1) {
   425  			str = `"Infinity"`
   426  		} else if math.IsInf(f, -1) {
   427  			str = `"-Infinity"`
   428  		} else {
   429  			var bits int
   430  			if rv.Kind() == reflect.Float32 {
   431  				bits = 32
   432  			} else {
   433  				bits = 64
   434  			}
   435  			str = strconv.FormatFloat(rv.Float(), 'g', -1, bits)
   436  		}
   437  		_, err := b.WriteString(str)
   438  		return err
   439  	case reflect.Bool:
   440  		_, err := b.WriteString(strconv.FormatBool(rv.Bool()))
   441  		return err
   442  	case reflect.Slice:
   443  		bstr := base64.StdEncoding.EncodeToString(rv.Bytes())
   444  		return writeJsonString(b, bstr)
   445  	case reflect.String:
   446  		return writeJsonString(b, rv.String())
   447  	default:
   448  		// must be a message
   449  		if isNil(v) {
   450  			_, err := b.WriteString("null")
   451  			return err
   452  		}
   453  
   454  		if dm, ok := v.(*Message); ok {
   455  			return dm.marshalJSON(b, opts)
   456  		}
   457  
   458  		var err error
   459  		if b.indentCount <= 0 || len(b.indent) == 0 {
   460  			err = opts.Marshal(b, v.(proto.Message))
   461  		} else {
   462  			str, err := opts.MarshalToString(v.(proto.Message))
   463  			if err != nil {
   464  				return err
   465  			}
   466  			indent := strings.Repeat(b.indent, b.indentCount)
   467  			pos := 0
   468  			// add indention prefix to each line
   469  			for pos < len(str) {
   470  				start := pos
   471  				nextPos := strings.Index(str[pos:], "\n")
   472  				if nextPos == -1 {
   473  					nextPos = len(str)
   474  				} else {
   475  					nextPos = pos + nextPos + 1 // include newline
   476  				}
   477  				line := str[start:nextPos]
   478  				if pos > 0 {
   479  					_, err = b.WriteString(indent)
   480  					if err != nil {
   481  						return err
   482  					}
   483  				}
   484  				_, err = b.WriteString(line)
   485  				if err != nil {
   486  					return err
   487  				}
   488  				pos = nextPos
   489  			}
   490  		}
   491  		return err
   492  	}
   493  }
   494  
   495  func writeJsonString(b *indentBuffer, s string) error {
   496  	if sbytes, err := json.Marshal(s); err != nil {
   497  		return err
   498  	} else {
   499  		_, err := b.Write(sbytes)
   500  		return err
   501  	}
   502  }
   503  
   504  // UnmarshalJSON de-serializes the message that is present, in JSON format, in
   505  // the given bytes into this message. It first resets the current message. It
   506  // returns an error if the given bytes do not contain a valid encoding of this
   507  // message type in JSON format.
   508  //
   509  // This method is shorthand for invoking UnmarshalJSONPB with a default (zero
   510  // value) unmarshaler:
   511  //
   512  //	m.UnmarshalMergeJSONPB(&jsonpb.Unmarshaler{}, js)
   513  //
   514  // So unknown fields will result in an error, and no provided jsonpb.AnyResolver
   515  // will be used when parsing google.protobuf.Any messages.
   516  func (m *Message) UnmarshalJSON(js []byte) error {
   517  	return m.UnmarshalJSONPB(&jsonpb.Unmarshaler{}, js)
   518  }
   519  
   520  // UnmarshalMergeJSON de-serializes the message that is present, in JSON format,
   521  // in the given bytes into this message. Unlike UnmarshalJSON, it does not first
   522  // reset the message, instead merging the data in the given bytes into the
   523  // existing data in this message.
   524  func (m *Message) UnmarshalMergeJSON(js []byte) error {
   525  	return m.UnmarshalMergeJSONPB(&jsonpb.Unmarshaler{}, js)
   526  }
   527  
   528  // UnmarshalJSONPB de-serializes the message that is present, in JSON format, in
   529  // the given bytes into this message. The given unmarshaler conveys options used
   530  // when parsing the JSON. This function first resets the current message. It
   531  // returns an error if the given bytes do not contain a valid encoding of this
   532  // message type in JSON format.
   533  //
   534  // The decoding is lenient:
   535  //  1. The JSON can refer to fields either by their JSON name or by their
   536  //     declared name.
   537  //  2. The JSON can use either numeric values or string names for enum values.
   538  //
   539  // When instantiating nested messages, if this message's associated factory
   540  // returns a generated message type (as opposed to a dynamic message), the given
   541  // unmarshaler is used to unmarshal it.
   542  //
   543  // When unmarshaling any nested messages, any jsonpb.AnyResolver configured in
   544  // the given unmarshaler is augmented with knowledge of message types known to
   545  // this message's descriptor (and its enclosing file and set of transitive
   546  // dependencies).
   547  func (m *Message) UnmarshalJSONPB(opts *jsonpb.Unmarshaler, js []byte) error {
   548  	m.Reset()
   549  	if err := m.UnmarshalMergeJSONPB(opts, js); err != nil {
   550  		return err
   551  	}
   552  	return m.Validate()
   553  }
   554  
   555  // UnmarshalMergeJSONPB de-serializes the message that is present, in JSON
   556  // format, in the given bytes into this message. The given unmarshaler conveys
   557  // options used when parsing the JSON. Unlike UnmarshalJSONPB, it does not first
   558  // reset the message, instead merging the data in the given bytes into the
   559  // existing data in this message.
   560  func (m *Message) UnmarshalMergeJSONPB(opts *jsonpb.Unmarshaler, js []byte) error {
   561  	r := newJsReader(js)
   562  	err := m.unmarshalJson(r, opts)
   563  	if err != nil {
   564  		return err
   565  	}
   566  	if t, err := r.poll(); err != io.EOF {
   567  		b, _ := ioutil.ReadAll(r.unread())
   568  		s := fmt.Sprintf("%v%s", t, string(b))
   569  		return fmt.Errorf("superfluous data found after JSON object: %q", s)
   570  	}
   571  	return nil
   572  }
   573  
   574  func unmarshalWellKnownType(m *Message, r *jsReader, opts *jsonpb.Unmarshaler) (bool, error) {
   575  	fqn := m.md.GetFullyQualifiedName()
   576  	if _, ok := wellKnownTypeNames[fqn]; !ok {
   577  		return false, nil
   578  	}
   579  
   580  	msgType := proto.MessageType(fqn)
   581  	if msgType == nil {
   582  		// wtf?
   583  		panic(fmt.Sprintf("could not find registered message type for %q", fqn))
   584  	}
   585  
   586  	// extract json value from r
   587  	var js json.RawMessage
   588  	if err := json.NewDecoder(r.unread()).Decode(&js); err != nil {
   589  		return true, err
   590  	}
   591  	if err := r.skip(); err != nil {
   592  		return true, err
   593  	}
   594  
   595  	// unmarshal into well-known type and then convert to dynamic message
   596  	msg := reflect.New(msgType.Elem()).Interface().(proto.Message)
   597  	if err := opts.Unmarshal(bytes.NewReader(js), msg); err != nil {
   598  		return true, err
   599  	}
   600  	return true, m.MergeFrom(msg)
   601  }
   602  
   603  func (m *Message) unmarshalJson(r *jsReader, opts *jsonpb.Unmarshaler) error {
   604  	if r, changed := wrapResolver(opts.AnyResolver, m.mf, m.md.GetFile()); changed {
   605  		newOpts := *opts
   606  		newOpts.AnyResolver = r
   607  		opts = &newOpts
   608  	}
   609  
   610  	if ok, err := unmarshalWellKnownType(m, r, opts); ok {
   611  		return err
   612  	}
   613  
   614  	t, err := r.peek()
   615  	if err != nil {
   616  		return err
   617  	}
   618  	if t == nil {
   619  		// if json is simply "null" we do nothing
   620  		r.poll()
   621  		return nil
   622  	}
   623  
   624  	if err := r.beginObject(); err != nil {
   625  		return err
   626  	}
   627  
   628  	for r.hasNext() {
   629  		f, err := r.nextObjectKey()
   630  		if err != nil {
   631  			return err
   632  		}
   633  		fd := m.FindFieldDescriptorByJSONName(f)
   634  		if fd == nil {
   635  			if opts.AllowUnknownFields {
   636  				r.skip()
   637  				continue
   638  			}
   639  			return fmt.Errorf("message type %s has no known field named %s", m.md.GetFullyQualifiedName(), f)
   640  		}
   641  		v, err := unmarshalJsField(fd, r, m.mf, opts)
   642  		if err != nil {
   643  			return err
   644  		}
   645  		if v != nil {
   646  			if err := mergeField(m, fd, v); err != nil {
   647  				return err
   648  			}
   649  		} else if fd.GetOneOf() != nil {
   650  			// preserve explicit null for oneof fields (this is a little odd but
   651  			// mimics the behavior of jsonpb with oneofs in generated message types)
   652  			if fd.GetMessageType() != nil {
   653  				typ := m.mf.GetKnownTypeRegistry().GetKnownType(fd.GetMessageType().GetFullyQualifiedName())
   654  				if typ != nil {
   655  					// typed nil
   656  					if typ.Kind() != reflect.Ptr {
   657  						typ = reflect.PtrTo(typ)
   658  					}
   659  					v = reflect.Zero(typ).Interface()
   660  				} else {
   661  					// can't use nil dynamic message, so we just use empty one instead
   662  					v = m.mf.NewDynamicMessage(fd.GetMessageType())
   663  				}
   664  				if err := m.setField(fd, v); err != nil {
   665  					return err
   666  				}
   667  			} else {
   668  				// not a message... explicit null makes no sense
   669  				return fmt.Errorf("message type %s cannot set field %s to null: it is not a message type", m.md.GetFullyQualifiedName(), f)
   670  			}
   671  		} else {
   672  			m.clearField(fd)
   673  		}
   674  	}
   675  
   676  	if err := r.endObject(); err != nil {
   677  		return err
   678  	}
   679  
   680  	return nil
   681  }
   682  
   683  func isWellKnownValue(fd *desc.FieldDescriptor) bool {
   684  	return !fd.IsRepeated() && fd.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE &&
   685  		fd.GetMessageType().GetFullyQualifiedName() == "google.protobuf.Value"
   686  }
   687  
   688  func isWellKnownListValue(fd *desc.FieldDescriptor) bool {
   689  	// we look for ListValue; but we also look for Value, which can be assigned a ListValue
   690  	return !fd.IsRepeated() && fd.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE &&
   691  		(fd.GetMessageType().GetFullyQualifiedName() == "google.protobuf.ListValue" ||
   692  			fd.GetMessageType().GetFullyQualifiedName() == "google.protobuf.Value")
   693  }
   694  
   695  func unmarshalJsField(fd *desc.FieldDescriptor, r *jsReader, mf *MessageFactory, opts *jsonpb.Unmarshaler) (interface{}, error) {
   696  	t, err := r.peek()
   697  	if err != nil {
   698  		return nil, err
   699  	}
   700  	if t == nil && !isWellKnownValue(fd) {
   701  		// if value is null, just return nil
   702  		// (unless field is google.protobuf.Value, in which case
   703  		// we fall through to parse it as an instance where its
   704  		// underlying value is set to a NullValue)
   705  		r.poll()
   706  		return nil, nil
   707  	}
   708  
   709  	if t == json.Delim('{') && fd.IsMap() {
   710  		entryType := fd.GetMessageType()
   711  		keyType := entryType.FindFieldByNumber(1)
   712  		valueType := entryType.FindFieldByNumber(2)
   713  		mp := map[interface{}]interface{}{}
   714  
   715  		// TODO: if there are just two map keys "key" and "value" and they have the right type of values,
   716  		// treat this JSON object as a single map entry message. (In keeping with support of map fields as
   717  		// if they were normal repeated field of entry messages as well as supporting a transition from
   718  		// optional to repeated...)
   719  
   720  		if err := r.beginObject(); err != nil {
   721  			return nil, err
   722  		}
   723  		for r.hasNext() {
   724  			kk, err := unmarshalJsFieldElement(keyType, r, mf, opts, false)
   725  			if err != nil {
   726  				return nil, err
   727  			}
   728  			vv, err := unmarshalJsFieldElement(valueType, r, mf, opts, true)
   729  			if err != nil {
   730  				return nil, err
   731  			}
   732  			mp[kk] = vv
   733  		}
   734  		if err := r.endObject(); err != nil {
   735  			return nil, err
   736  		}
   737  
   738  		return mp, nil
   739  	} else if t == json.Delim('[') && !isWellKnownListValue(fd) {
   740  		// We support parsing an array, even if field is not repeated, to mimic support in proto
   741  		// binary wire format that supports changing an optional field to repeated and vice versa.
   742  		// If the field is not repeated, we only keep the last value in the array.
   743  
   744  		if err := r.beginArray(); err != nil {
   745  			return nil, err
   746  		}
   747  		var sl []interface{}
   748  		var v interface{}
   749  		for r.hasNext() {
   750  			var err error
   751  			v, err = unmarshalJsFieldElement(fd, r, mf, opts, false)
   752  			if err != nil {
   753  				return nil, err
   754  			}
   755  			if fd.IsRepeated() && v != nil {
   756  				sl = append(sl, v)
   757  			}
   758  		}
   759  		if err := r.endArray(); err != nil {
   760  			return nil, err
   761  		}
   762  		if fd.IsMap() {
   763  			mp := map[interface{}]interface{}{}
   764  			for _, m := range sl {
   765  				msg := m.(*Message)
   766  				kk, err := msg.TryGetFieldByNumber(1)
   767  				if err != nil {
   768  					return nil, err
   769  				}
   770  				vv, err := msg.TryGetFieldByNumber(2)
   771  				if err != nil {
   772  					return nil, err
   773  				}
   774  				mp[kk] = vv
   775  			}
   776  			return mp, nil
   777  		} else if fd.IsRepeated() {
   778  			return sl, nil
   779  		} else {
   780  			return v, nil
   781  		}
   782  	} else {
   783  		// We support parsing a singular value, even if field is repeated, to mimic support in proto
   784  		// binary wire format that supports changing an optional field to repeated and vice versa.
   785  		// If the field is repeated, we store value as singleton slice of that one value.
   786  
   787  		v, err := unmarshalJsFieldElement(fd, r, mf, opts, false)
   788  		if err != nil {
   789  			return nil, err
   790  		}
   791  		if v == nil {
   792  			return nil, nil
   793  		}
   794  		if fd.IsRepeated() {
   795  			return []interface{}{v}, nil
   796  		} else {
   797  			return v, nil
   798  		}
   799  	}
   800  }
   801  
   802  func unmarshalJsFieldElement(fd *desc.FieldDescriptor, r *jsReader, mf *MessageFactory, opts *jsonpb.Unmarshaler, allowNilMessage bool) (interface{}, error) {
   803  	t, err := r.peek()
   804  	if err != nil {
   805  		return nil, err
   806  	}
   807  
   808  	switch fd.GetType() {
   809  	case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE,
   810  		descriptorpb.FieldDescriptorProto_TYPE_GROUP:
   811  
   812  		if t == nil && allowNilMessage {
   813  			// if json is simply "null" return a nil pointer
   814  			r.poll()
   815  			return nilMessage(fd.GetMessageType()), nil
   816  		}
   817  
   818  		m := mf.NewMessage(fd.GetMessageType())
   819  		if dm, ok := m.(*Message); ok {
   820  			if err := dm.unmarshalJson(r, opts); err != nil {
   821  				return nil, err
   822  			}
   823  		} else {
   824  			var msg json.RawMessage
   825  			if err := json.NewDecoder(r.unread()).Decode(&msg); err != nil {
   826  				return nil, err
   827  			}
   828  			if err := r.skip(); err != nil {
   829  				return nil, err
   830  			}
   831  			if err := opts.Unmarshal(bytes.NewReader([]byte(msg)), m); err != nil {
   832  				return nil, err
   833  			}
   834  		}
   835  		return m, nil
   836  
   837  	case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
   838  		if e, err := r.nextNumber(); err != nil {
   839  			return nil, err
   840  		} else {
   841  			// value could be string or number
   842  			if i, err := e.Int64(); err != nil {
   843  				// number cannot be parsed, so see if it's an enum value name
   844  				vd := fd.GetEnumType().FindValueByName(string(e))
   845  				if vd != nil {
   846  					return vd.GetNumber(), nil
   847  				} else {
   848  					return nil, fmt.Errorf("enum %q does not have value named %q", fd.GetEnumType().GetFullyQualifiedName(), e)
   849  				}
   850  			} else if i > math.MaxInt32 || i < math.MinInt32 {
   851  				return nil, NumericOverflowError
   852  			} else {
   853  				return int32(i), err
   854  			}
   855  		}
   856  
   857  	case descriptorpb.FieldDescriptorProto_TYPE_INT32,
   858  		descriptorpb.FieldDescriptorProto_TYPE_SINT32,
   859  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED32:
   860  		if i, err := r.nextInt(); err != nil {
   861  			return nil, err
   862  		} else if i > math.MaxInt32 || i < math.MinInt32 {
   863  			return nil, NumericOverflowError
   864  		} else {
   865  			return int32(i), err
   866  		}
   867  
   868  	case descriptorpb.FieldDescriptorProto_TYPE_INT64,
   869  		descriptorpb.FieldDescriptorProto_TYPE_SINT64,
   870  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED64:
   871  		return r.nextInt()
   872  
   873  	case descriptorpb.FieldDescriptorProto_TYPE_UINT32,
   874  		descriptorpb.FieldDescriptorProto_TYPE_FIXED32:
   875  		if i, err := r.nextUint(); err != nil {
   876  			return nil, err
   877  		} else if i > math.MaxUint32 {
   878  			return nil, NumericOverflowError
   879  		} else {
   880  			return uint32(i), err
   881  		}
   882  
   883  	case descriptorpb.FieldDescriptorProto_TYPE_UINT64,
   884  		descriptorpb.FieldDescriptorProto_TYPE_FIXED64:
   885  		return r.nextUint()
   886  
   887  	case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
   888  		if str, ok := t.(string); ok {
   889  			if str == "true" {
   890  				r.poll() // consume token
   891  				return true, err
   892  			} else if str == "false" {
   893  				r.poll() // consume token
   894  				return false, err
   895  			}
   896  		}
   897  		return r.nextBool()
   898  
   899  	case descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
   900  		if f, err := r.nextFloat(); err != nil {
   901  			return nil, err
   902  		} else {
   903  			return float32(f), nil
   904  		}
   905  
   906  	case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
   907  		return r.nextFloat()
   908  
   909  	case descriptorpb.FieldDescriptorProto_TYPE_BYTES:
   910  		return r.nextBytes()
   911  
   912  	case descriptorpb.FieldDescriptorProto_TYPE_STRING:
   913  		return r.nextString()
   914  
   915  	default:
   916  		return nil, fmt.Errorf("unknown field type: %v", fd.GetType())
   917  	}
   918  }
   919  
// jsReader is a pull-style JSON token reader with a single token of lookahead.
// It tokenizes an in-memory byte buffer using an encoding/json Decoder.
type jsReader struct {
	// reader is the source the decoder reads from; unread copies it to replay
	// the bytes the decoder has not pulled yet.
	reader  *bytes.Reader
	// dec is the tokenizer that reads from reader.
	dec     *json.Decoder
	// current is the token stashed by peek; only meaningful while peeked is true.
	current json.Token
	// peeked reports whether current holds a token that poll has not returned yet.
	peeked  bool
}
   926  
   927  func newJsReader(b []byte) *jsReader {
   928  	reader := bytes.NewReader(b)
   929  	dec := json.NewDecoder(reader)
   930  	dec.UseNumber()
   931  	return &jsReader{reader: reader, dec: dec}
   932  }
   933  
   934  func (r *jsReader) unread() io.Reader {
   935  	bufs := make([]io.Reader, 3)
   936  	var peeked []byte
   937  	if r.peeked {
   938  		if _, ok := r.current.(json.Delim); ok {
   939  			peeked = []byte(fmt.Sprintf("%v", r.current))
   940  		} else {
   941  			peeked, _ = json.Marshal(r.current)
   942  		}
   943  	}
   944  	readerCopy := *r.reader
   945  	decCopy := *r.dec
   946  
   947  	bufs[0] = bytes.NewReader(peeked)
   948  	bufs[1] = decCopy.Buffered()
   949  	bufs[2] = &readerCopy
   950  	return &concatReader{bufs: bufs}
   951  }
   952  
   953  func (r *jsReader) hasNext() bool {
   954  	return r.dec.More()
   955  }
   956  
   957  func (r *jsReader) peek() (json.Token, error) {
   958  	if r.peeked {
   959  		return r.current, nil
   960  	}
   961  	t, err := r.dec.Token()
   962  	if err != nil {
   963  		return nil, err
   964  	}
   965  	r.peeked = true
   966  	r.current = t
   967  	return t, nil
   968  }
   969  
   970  func (r *jsReader) poll() (json.Token, error) {
   971  	if r.peeked {
   972  		ret := r.current
   973  		r.current = nil
   974  		r.peeked = false
   975  		return ret, nil
   976  	}
   977  	return r.dec.Token()
   978  }
   979  
   980  func (r *jsReader) beginObject() error {
   981  	_, err := r.expect(func(t json.Token) bool { return t == json.Delim('{') }, nil, "start of JSON object: '{'")
   982  	return err
   983  }
   984  
   985  func (r *jsReader) endObject() error {
   986  	_, err := r.expect(func(t json.Token) bool { return t == json.Delim('}') }, nil, "end of JSON object: '}'")
   987  	return err
   988  }
   989  
   990  func (r *jsReader) beginArray() error {
   991  	_, err := r.expect(func(t json.Token) bool { return t == json.Delim('[') }, nil, "start of array: '['")
   992  	return err
   993  }
   994  
   995  func (r *jsReader) endArray() error {
   996  	_, err := r.expect(func(t json.Token) bool { return t == json.Delim(']') }, nil, "end of array: ']'")
   997  	return err
   998  }
   999  
  1000  func (r *jsReader) nextObjectKey() (string, error) {
  1001  	return r.nextString()
  1002  }
  1003  
  1004  func (r *jsReader) nextString() (string, error) {
  1005  	t, err := r.expect(func(t json.Token) bool { _, ok := t.(string); return ok }, "", "string")
  1006  	if err != nil {
  1007  		return "", err
  1008  	}
  1009  	return t.(string), nil
  1010  }
  1011  
  1012  func (r *jsReader) nextBytes() ([]byte, error) {
  1013  	str, err := r.nextString()
  1014  	if err != nil {
  1015  		return nil, err
  1016  	}
  1017  	return base64.StdEncoding.DecodeString(str)
  1018  }
  1019  
  1020  func (r *jsReader) nextBool() (bool, error) {
  1021  	t, err := r.expect(func(t json.Token) bool { _, ok := t.(bool); return ok }, false, "boolean")
  1022  	if err != nil {
  1023  		return false, err
  1024  	}
  1025  	return t.(bool), nil
  1026  }
  1027  
  1028  func (r *jsReader) nextInt() (int64, error) {
  1029  	n, err := r.nextNumber()
  1030  	if err != nil {
  1031  		return 0, err
  1032  	}
  1033  	return n.Int64()
  1034  }
  1035  
  1036  func (r *jsReader) nextUint() (uint64, error) {
  1037  	n, err := r.nextNumber()
  1038  	if err != nil {
  1039  		return 0, err
  1040  	}
  1041  	return strconv.ParseUint(string(n), 10, 64)
  1042  }
  1043  
  1044  func (r *jsReader) nextFloat() (float64, error) {
  1045  	n, err := r.nextNumber()
  1046  	if err != nil {
  1047  		return 0, err
  1048  	}
  1049  	return n.Float64()
  1050  }
  1051  
  1052  func (r *jsReader) nextNumber() (json.Number, error) {
  1053  	t, err := r.expect(func(t json.Token) bool { return reflect.TypeOf(t).Kind() == reflect.String }, "0", "number")
  1054  	if err != nil {
  1055  		return "", err
  1056  	}
  1057  	switch t := t.(type) {
  1058  	case json.Number:
  1059  		return t, nil
  1060  	case string:
  1061  		return json.Number(t), nil
  1062  	}
  1063  	return "", fmt.Errorf("expecting a number but got %v", t)
  1064  }
  1065  
  1066  func (r *jsReader) skip() error {
  1067  	t, err := r.poll()
  1068  	if err != nil {
  1069  		return err
  1070  	}
  1071  	if t == json.Delim('[') {
  1072  		if err := r.skipArray(); err != nil {
  1073  			return err
  1074  		}
  1075  	} else if t == json.Delim('{') {
  1076  		if err := r.skipObject(); err != nil {
  1077  			return err
  1078  		}
  1079  	}
  1080  	return nil
  1081  }
  1082  
  1083  func (r *jsReader) skipArray() error {
  1084  	for r.hasNext() {
  1085  		if err := r.skip(); err != nil {
  1086  			return err
  1087  		}
  1088  	}
  1089  	if err := r.endArray(); err != nil {
  1090  		return err
  1091  	}
  1092  	return nil
  1093  }
  1094  
  1095  func (r *jsReader) skipObject() error {
  1096  	for r.hasNext() {
  1097  		// skip object key
  1098  		if err := r.skip(); err != nil {
  1099  			return err
  1100  		}
  1101  		// and value
  1102  		if err := r.skip(); err != nil {
  1103  			return err
  1104  		}
  1105  	}
  1106  	if err := r.endObject(); err != nil {
  1107  		return err
  1108  	}
  1109  	return nil
  1110  }
  1111  
  1112  func (r *jsReader) expect(predicate func(json.Token) bool, ifNil interface{}, expected string) (interface{}, error) {
  1113  	t, err := r.poll()
  1114  	if err != nil {
  1115  		return nil, err
  1116  	}
  1117  	if t == nil && ifNil != nil {
  1118  		return ifNil, nil
  1119  	}
  1120  	if !predicate(t) {
  1121  		return t, fmt.Errorf("bad input: expecting %s ; instead got %v", expected, t)
  1122  	}
  1123  	return t, nil
  1124  }
  1125  
  1126  type concatReader struct {
  1127  	bufs []io.Reader
  1128  	curr int
  1129  }
  1130  
  1131  func (r *concatReader) Read(p []byte) (n int, err error) {
  1132  	for {
  1133  		if r.curr >= len(r.bufs) {
  1134  			err = io.EOF
  1135  			return
  1136  		}
  1137  		var c int
  1138  		c, err = r.bufs[r.curr].Read(p)
  1139  		n += c
  1140  		if err != io.EOF {
  1141  			return
  1142  		}
  1143  		r.curr++
  1144  		p = p[c:]
  1145  	}
  1146  }
  1147  
  1148  // AnyResolver returns a jsonpb.AnyResolver that uses the given file descriptors
  1149  // to resolve message names. It uses the given factory, which may be nil, to
  1150  // instantiate messages. The messages that it returns when resolving a type name
  1151  // may often be dynamic messages.
  1152  func AnyResolver(mf *MessageFactory, files ...*desc.FileDescriptor) jsonpb.AnyResolver {
  1153  	return &anyResolver{mf: mf, files: files}
  1154  }
  1155  
// anyResolver is the jsonpb.AnyResolver implementation produced by AnyResolver
// and wrapResolver. It resolves a type URL by first delegating to another
// resolver (if set) and then searching a set of file descriptors.
type anyResolver struct {
	// mf instantiates messages for descriptors found in files; may be nil.
	mf      *MessageFactory
	// files are searched, along with their transitive dependencies, for the
	// requested message name.
	files   []*desc.FileDescriptor
	// ignored holds files that are skipped during the search; wrapResolver
	// fills it with the files the wrapped resolver already covers.
	ignored map[*desc.FileDescriptor]struct{}
	// other is an optional resolver that is consulted first; may be nil.
	other   jsonpb.AnyResolver
}
  1162  
  1163  func wrapResolver(r jsonpb.AnyResolver, mf *MessageFactory, f *desc.FileDescriptor) (jsonpb.AnyResolver, bool) {
  1164  	if r, ok := r.(*anyResolver); ok {
  1165  		if _, ok := r.ignored[f]; ok {
  1166  			// if the current resolver is ignoring this file, it's because another
  1167  			// (upstream) resolver is already handling it, so nothing to do
  1168  			return r, false
  1169  		}
  1170  		for _, file := range r.files {
  1171  			if file == f {
  1172  				// no need to wrap!
  1173  				return r, false
  1174  			}
  1175  		}
  1176  		// ignore files that will be checked by the resolver we're wrapping
  1177  		// (we'll just delegate and let it search those files)
  1178  		ignored := map[*desc.FileDescriptor]struct{}{}
  1179  		for i := range r.ignored {
  1180  			ignored[i] = struct{}{}
  1181  		}
  1182  		ignore(r.files, ignored)
  1183  		return &anyResolver{mf: mf, files: []*desc.FileDescriptor{f}, ignored: ignored, other: r}, true
  1184  	}
  1185  	return &anyResolver{mf: mf, files: []*desc.FileDescriptor{f}, other: r}, true
  1186  }
  1187  
  1188  func ignore(files []*desc.FileDescriptor, ignored map[*desc.FileDescriptor]struct{}) {
  1189  	for _, f := range files {
  1190  		if _, ok := ignored[f]; ok {
  1191  			continue
  1192  		}
  1193  		ignored[f] = struct{}{}
  1194  		ignore(f.GetDependencies(), ignored)
  1195  	}
  1196  }
  1197  
  1198  func (r *anyResolver) Resolve(typeUrl string) (proto.Message, error) {
  1199  	mname := typeUrl
  1200  	if slash := strings.LastIndex(mname, "/"); slash >= 0 {
  1201  		mname = mname[slash+1:]
  1202  	}
  1203  
  1204  	// see if the user-specified resolver is able to do the job
  1205  	if r.other != nil {
  1206  		msg, err := r.other.Resolve(typeUrl)
  1207  		if err == nil {
  1208  			return msg, nil
  1209  		}
  1210  	}
  1211  
  1212  	// try to find the message in our known set of files
  1213  	checked := map[*desc.FileDescriptor]struct{}{}
  1214  	for _, f := range r.files {
  1215  		md := r.findMessage(f, mname, checked)
  1216  		if md != nil {
  1217  			return r.mf.NewMessage(md), nil
  1218  		}
  1219  	}
  1220  	// failing that, see if the message factory knows about this type
  1221  	var ktr *KnownTypeRegistry
  1222  	if r.mf != nil {
  1223  		ktr = r.mf.ktr
  1224  	} else {
  1225  		ktr = (*KnownTypeRegistry)(nil)
  1226  	}
  1227  	m := ktr.CreateIfKnown(mname)
  1228  	if m != nil {
  1229  		return m, nil
  1230  	}
  1231  
  1232  	// no other resolver to fallback to? mimic default behavior
  1233  	mt := proto.MessageType(mname)
  1234  	if mt == nil {
  1235  		return nil, fmt.Errorf("unknown message type %q", mname)
  1236  	}
  1237  	return reflect.New(mt.Elem()).Interface().(proto.Message), nil
  1238  }
  1239  
  1240  func (r *anyResolver) findMessage(fd *desc.FileDescriptor, msgName string, checked map[*desc.FileDescriptor]struct{}) *desc.MessageDescriptor {
  1241  	// if this is an ignored descriptor, skip
  1242  	if _, ok := r.ignored[fd]; ok {
  1243  		return nil
  1244  	}
  1245  
  1246  	// bail if we've already checked this file
  1247  	if _, ok := checked[fd]; ok {
  1248  		return nil
  1249  	}
  1250  	checked[fd] = struct{}{}
  1251  
  1252  	// see if this file has the message
  1253  	md := fd.FindMessage(msgName)
  1254  	if md != nil {
  1255  		return md
  1256  	}
  1257  
  1258  	// if not, recursively search the file's imports
  1259  	for _, dep := range fd.GetDependencies() {
  1260  		md = r.findMessage(dep, msgName, checked)
  1261  		if md != nil {
  1262  			return md
  1263  		}
  1264  	}
  1265  	return nil
  1266  }
  1267  
// Compile-time check that *anyResolver satisfies jsonpb.AnyResolver.
var _ jsonpb.AnyResolver = (*anyResolver)(nil)