github.com/bakjos/protoreflect@v1.9.2/dynamic/json.go (about)

     1  package dynamic
     2  
     3  // JSON marshalling and unmarshalling for dynamic messages
     4  
     5  import (
     6  	"bytes"
     7  	"encoding/base64"
     8  	"encoding/json"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"math"
    13  	"reflect"
    14  	"sort"
    15  	"strconv"
    16  	"strings"
    17  
    18  	"github.com/golang/protobuf/jsonpb"
    19  	"github.com/golang/protobuf/proto"
    20  	"github.com/golang/protobuf/protoc-gen-go/descriptor"
    21  
    22  	// link in the well-known-types that have a special JSON format
    23  	_ "github.com/golang/protobuf/ptypes/any"
    24  	_ "github.com/golang/protobuf/ptypes/duration"
    25  	_ "github.com/golang/protobuf/ptypes/empty"
    26  	_ "github.com/golang/protobuf/ptypes/struct"
    27  	_ "github.com/golang/protobuf/ptypes/timestamp"
    28  	_ "github.com/golang/protobuf/ptypes/wrappers"
    29  
    30  	"github.com/bakjos/protoreflect/desc"
    31  )
    32  
    33  var wellKnownTypeNames = map[string]struct{}{
    34  	"google.protobuf.Any":       {},
    35  	"google.protobuf.Empty":     {},
    36  	"google.protobuf.Duration":  {},
    37  	"google.protobuf.Timestamp": {},
    38  	// struct.proto
    39  	"google.protobuf.Struct":    {},
    40  	"google.protobuf.Value":     {},
    41  	"google.protobuf.ListValue": {},
    42  	// wrappers.proto
    43  	"google.protobuf.DoubleValue": {},
    44  	"google.protobuf.FloatValue":  {},
    45  	"google.protobuf.Int64Value":  {},
    46  	"google.protobuf.UInt64Value": {},
    47  	"google.protobuf.Int32Value":  {},
    48  	"google.protobuf.UInt32Value": {},
    49  	"google.protobuf.BoolValue":   {},
    50  	"google.protobuf.StringValue": {},
    51  	"google.protobuf.BytesValue":  {},
    52  }
    53  
    54  // MarshalJSON serializes this message to bytes in JSON format, returning an
    55  // error if the operation fails. The resulting bytes will be a valid UTF8
    56  // string.
    57  //
    58  // This method uses a compact form: no newlines, and spaces between fields and
    59  // between field identifiers and values are elided.
    60  //
    61  // This method is convenient shorthand for invoking MarshalJSONPB with a default
    62  // (zero value) marshaler:
    63  //
    64  //    m.MarshalJSONPB(&jsonpb.Marshaler{})
    65  //
    66  // So enums are serialized using enum value name strings, and values that are
    67  // not present (including those with default/zero value for messages defined in
    68  // "proto3" syntax) are omitted.
    69  func (m *Message) MarshalJSON() ([]byte, error) {
    70  	return m.MarshalJSONPB(&jsonpb.Marshaler{})
    71  }
    72  
    73  // MarshalJSONIndent serializes this message to bytes in JSON format, returning
    74  // an error if the operation fails. The resulting bytes will be a valid UTF8
    75  // string.
    76  //
    77  // This method uses a "pretty-printed" form, with each field on its own line and
    78  // spaces between field identifiers and values. Indentation of two spaces is
    79  // used.
    80  //
    81  // This method is convenient shorthand for invoking MarshalJSONPB with a default
    82  // (zero value) marshaler:
    83  //
    84  //    m.MarshalJSONPB(&jsonpb.Marshaler{Indent: "  "})
    85  //
    86  // So enums are serialized using enum value name strings, and values that are
    87  // not present (including those with default/zero value for messages defined in
    88  // "proto3" syntax) are omitted.
    89  func (m *Message) MarshalJSONIndent() ([]byte, error) {
    90  	return m.MarshalJSONPB(&jsonpb.Marshaler{Indent: "  "})
    91  }
    92  
    93  // MarshalJSONPB serializes this message to bytes in JSON format, returning an
    94  // error if the operation fails. The resulting bytes will be a valid UTF8
    95  // string. The given marshaler is used to convey options used during marshaling.
    96  //
    97  // If this message contains nested messages that are generated message types (as
    98  // opposed to dynamic messages), the given marshaler is used to marshal it.
    99  //
   100  // When marshaling any nested messages, any jsonpb.AnyResolver configured in the
   101  // given marshaler is augmented with knowledge of message types known to this
   102  // message's descriptor (and its enclosing file and set of transitive
   103  // dependencies).
   104  func (m *Message) MarshalJSONPB(opts *jsonpb.Marshaler) ([]byte, error) {
   105  	var b indentBuffer
   106  	b.indent = opts.Indent
   107  	if len(opts.Indent) == 0 {
   108  		b.indentCount = -1
   109  	}
   110  	b.comma = true
   111  	if err := m.marshalJSON(&b, opts); err != nil {
   112  		return nil, err
   113  	}
   114  	return b.Bytes(), nil
   115  }
   116  
   117  func (m *Message) marshalJSON(b *indentBuffer, opts *jsonpb.Marshaler) error {
   118  	if m == nil {
   119  		_, err := b.WriteString("null")
   120  		return err
   121  	}
   122  	if r, changed := wrapResolver(opts.AnyResolver, m.mf, m.md.GetFile()); changed {
   123  		newOpts := *opts
   124  		newOpts.AnyResolver = r
   125  		opts = &newOpts
   126  	}
   127  
   128  	if ok, err := marshalWellKnownType(m, b, opts); ok {
   129  		return err
   130  	}
   131  
   132  	err := b.WriteByte('{')
   133  	if err != nil {
   134  		return err
   135  	}
   136  	err = b.start()
   137  	if err != nil {
   138  		return err
   139  	}
   140  
   141  	var tags []int
   142  	if opts.EmitDefaults {
   143  		tags = m.allKnownFieldTags()
   144  	} else {
   145  		tags = m.knownFieldTags()
   146  	}
   147  
   148  	first := true
   149  
   150  	for _, tag := range tags {
   151  		itag := int32(tag)
   152  		fd := m.FindFieldDescriptor(itag)
   153  
   154  		v, ok := m.values[itag]
   155  		if !ok {
   156  			if fd.GetOneOf() != nil {
   157  				// don't print defaults for fields in a oneof
   158  				continue
   159  			}
   160  			v = fd.GetDefaultValue()
   161  		}
   162  
   163  		err := b.maybeNext(&first)
   164  		if err != nil {
   165  			return err
   166  		}
   167  		err = marshalKnownFieldJSON(b, fd, v, opts)
   168  		if err != nil {
   169  			return err
   170  		}
   171  	}
   172  
   173  	err = b.end()
   174  	if err != nil {
   175  		return err
   176  	}
   177  	err = b.WriteByte('}')
   178  	if err != nil {
   179  		return err
   180  	}
   181  
   182  	return nil
   183  }
   184  
   185  func marshalWellKnownType(m *Message, b *indentBuffer, opts *jsonpb.Marshaler) (bool, error) {
   186  	fqn := m.md.GetFullyQualifiedName()
   187  	if _, ok := wellKnownTypeNames[fqn]; !ok {
   188  		return false, nil
   189  	}
   190  
   191  	msgType := proto.MessageType(fqn)
   192  	if msgType == nil {
   193  		// wtf?
   194  		panic(fmt.Sprintf("could not find registered message type for %q", fqn))
   195  	}
   196  
   197  	// convert dynamic message to well-known type and let jsonpb marshal it
   198  	msg := reflect.New(msgType.Elem()).Interface().(proto.Message)
   199  	if err := m.MergeInto(msg); err != nil {
   200  		return true, err
   201  	}
   202  	return true, opts.Marshal(b, msg)
   203  }
   204  
   205  func marshalKnownFieldJSON(b *indentBuffer, fd *desc.FieldDescriptor, v interface{}, opts *jsonpb.Marshaler) error {
   206  	var jsonName string
   207  	if opts.OrigName {
   208  		jsonName = fd.GetName()
   209  	} else {
   210  		jsonName = fd.AsFieldDescriptorProto().GetJsonName()
   211  		if jsonName == "" {
   212  			jsonName = fd.GetName()
   213  		}
   214  	}
   215  	if fd.IsExtension() {
   216  		var scope string
   217  		switch parent := fd.GetParent().(type) {
   218  		case *desc.FileDescriptor:
   219  			scope = parent.GetPackage()
   220  		default:
   221  			scope = parent.GetFullyQualifiedName()
   222  		}
   223  		if scope == "" {
   224  			jsonName = fmt.Sprintf("[%s]", jsonName)
   225  		} else {
   226  			jsonName = fmt.Sprintf("[%s.%s]", scope, jsonName)
   227  		}
   228  	}
   229  	err := writeJsonString(b, jsonName)
   230  	if err != nil {
   231  		return err
   232  	}
   233  	err = b.sep()
   234  	if err != nil {
   235  		return err
   236  	}
   237  
   238  	if isNil(v) {
   239  		_, err := b.WriteString("null")
   240  		return err
   241  	}
   242  
   243  	if fd.IsMap() {
   244  		err = b.WriteByte('{')
   245  		if err != nil {
   246  			return err
   247  		}
   248  		err = b.start()
   249  		if err != nil {
   250  			return err
   251  		}
   252  
   253  		md := fd.GetMessageType()
   254  		vfd := md.FindFieldByNumber(2)
   255  
   256  		mp := v.(map[interface{}]interface{})
   257  		keys := make([]interface{}, 0, len(mp))
   258  		for k := range mp {
   259  			keys = append(keys, k)
   260  		}
   261  		sort.Sort(sortable(keys))
   262  		first := true
   263  		for _, mk := range keys {
   264  			mv := mp[mk]
   265  			err := b.maybeNext(&first)
   266  			if err != nil {
   267  				return err
   268  			}
   269  
   270  			err = marshalKnownFieldMapEntryJSON(b, mk, vfd, mv, opts)
   271  			if err != nil {
   272  				return err
   273  			}
   274  		}
   275  
   276  		err = b.end()
   277  		if err != nil {
   278  			return err
   279  		}
   280  		return b.WriteByte('}')
   281  
   282  	} else if fd.IsRepeated() {
   283  		err = b.WriteByte('[')
   284  		if err != nil {
   285  			return err
   286  		}
   287  		err = b.start()
   288  		if err != nil {
   289  			return err
   290  		}
   291  
   292  		sl := v.([]interface{})
   293  		first := true
   294  		for _, slv := range sl {
   295  			err := b.maybeNext(&first)
   296  			if err != nil {
   297  				return err
   298  			}
   299  			err = marshalKnownFieldValueJSON(b, fd, slv, opts)
   300  			if err != nil {
   301  				return err
   302  			}
   303  		}
   304  
   305  		err = b.end()
   306  		if err != nil {
   307  			return err
   308  		}
   309  		return b.WriteByte(']')
   310  
   311  	} else {
   312  		return marshalKnownFieldValueJSON(b, fd, v, opts)
   313  	}
   314  }
   315  
   316  // sortable is used to sort map keys. Values will be integers (int32, int64, uint32, and uint64),
   317  // bools, or strings.
   318  type sortable []interface{}
   319  
   320  func (s sortable) Len() int {
   321  	return len(s)
   322  }
   323  
   324  func (s sortable) Less(i, j int) bool {
   325  	vi := s[i]
   326  	vj := s[j]
   327  	switch reflect.TypeOf(vi).Kind() {
   328  	case reflect.Int32:
   329  		return vi.(int32) < vj.(int32)
   330  	case reflect.Int64:
   331  		return vi.(int64) < vj.(int64)
   332  	case reflect.Uint32:
   333  		return vi.(uint32) < vj.(uint32)
   334  	case reflect.Uint64:
   335  		return vi.(uint64) < vj.(uint64)
   336  	case reflect.String:
   337  		return vi.(string) < vj.(string)
   338  	case reflect.Bool:
   339  		return !vi.(bool) && vj.(bool)
   340  	default:
   341  		panic(fmt.Sprintf("cannot compare keys of type %v", reflect.TypeOf(vi)))
   342  	}
   343  }
   344  
   345  func (s sortable) Swap(i, j int) {
   346  	s[i], s[j] = s[j], s[i]
   347  }
   348  
   349  func isNil(v interface{}) bool {
   350  	if v == nil {
   351  		return true
   352  	}
   353  	rv := reflect.ValueOf(v)
   354  	return rv.Kind() == reflect.Ptr && rv.IsNil()
   355  }
   356  
   357  func marshalKnownFieldMapEntryJSON(b *indentBuffer, mk interface{}, vfd *desc.FieldDescriptor, mv interface{}, opts *jsonpb.Marshaler) error {
   358  	rk := reflect.ValueOf(mk)
   359  	var strkey string
   360  	switch rk.Kind() {
   361  	case reflect.Bool:
   362  		strkey = strconv.FormatBool(rk.Bool())
   363  	case reflect.Int32, reflect.Int64:
   364  		strkey = strconv.FormatInt(rk.Int(), 10)
   365  	case reflect.Uint32, reflect.Uint64:
   366  		strkey = strconv.FormatUint(rk.Uint(), 10)
   367  	case reflect.String:
   368  		strkey = rk.String()
   369  	default:
   370  		return fmt.Errorf("invalid map key value: %v (%v)", mk, rk.Type())
   371  	}
   372  	err := writeString(b, strkey)
   373  	if err != nil {
   374  		return err
   375  	}
   376  	err = b.sep()
   377  	if err != nil {
   378  		return err
   379  	}
   380  	return marshalKnownFieldValueJSON(b, vfd, mv, opts)
   381  }
   382  
   383  func marshalKnownFieldValueJSON(b *indentBuffer, fd *desc.FieldDescriptor, v interface{}, opts *jsonpb.Marshaler) error {
   384  	rv := reflect.ValueOf(v)
   385  	switch rv.Kind() {
   386  	case reflect.Int64:
   387  		return writeJsonString(b, strconv.FormatInt(rv.Int(), 10))
   388  	case reflect.Int32:
   389  		ed := fd.GetEnumType()
   390  		if !opts.EnumsAsInts && ed != nil {
   391  			n := int32(rv.Int())
   392  			vd := ed.FindValueByNumber(n)
   393  			if vd == nil {
   394  				_, err := b.WriteString(strconv.FormatInt(rv.Int(), 10))
   395  				return err
   396  			} else {
   397  				return writeJsonString(b, vd.GetName())
   398  			}
   399  		} else {
   400  			_, err := b.WriteString(strconv.FormatInt(rv.Int(), 10))
   401  			return err
   402  		}
   403  	case reflect.Uint64:
   404  		return writeJsonString(b, strconv.FormatUint(rv.Uint(), 10))
   405  	case reflect.Uint32:
   406  		_, err := b.WriteString(strconv.FormatUint(rv.Uint(), 10))
   407  		return err
   408  	case reflect.Float32, reflect.Float64:
   409  		f := rv.Float()
   410  		var str string
   411  		if math.IsNaN(f) {
   412  			str = `"NaN"`
   413  		} else if math.IsInf(f, 1) {
   414  			str = `"Infinity"`
   415  		} else if math.IsInf(f, -1) {
   416  			str = `"-Infinity"`
   417  		} else {
   418  			var bits int
   419  			if rv.Kind() == reflect.Float32 {
   420  				bits = 32
   421  			} else {
   422  				bits = 64
   423  			}
   424  			str = strconv.FormatFloat(rv.Float(), 'g', -1, bits)
   425  		}
   426  		_, err := b.WriteString(str)
   427  		return err
   428  	case reflect.Bool:
   429  		_, err := b.WriteString(strconv.FormatBool(rv.Bool()))
   430  		return err
   431  	case reflect.Slice:
   432  		bstr := base64.StdEncoding.EncodeToString(rv.Bytes())
   433  		return writeJsonString(b, bstr)
   434  	case reflect.String:
   435  		return writeJsonString(b, rv.String())
   436  	default:
   437  		// must be a message
   438  		if isNil(v) {
   439  			_, err := b.WriteString("null")
   440  			return err
   441  		}
   442  
   443  		if dm, ok := v.(*Message); ok {
   444  			return dm.marshalJSON(b, opts)
   445  		}
   446  
   447  		var err error
   448  		if b.indentCount <= 0 || len(b.indent) == 0 {
   449  			err = opts.Marshal(b, v.(proto.Message))
   450  		} else {
   451  			str, err := opts.MarshalToString(v.(proto.Message))
   452  			if err != nil {
   453  				return err
   454  			}
   455  			indent := strings.Repeat(b.indent, b.indentCount)
   456  			pos := 0
   457  			// add indention prefix to each line
   458  			for pos < len(str) {
   459  				start := pos
   460  				nextPos := strings.Index(str[pos:], "\n")
   461  				if nextPos == -1 {
   462  					nextPos = len(str)
   463  				} else {
   464  					nextPos = pos + nextPos + 1 // include newline
   465  				}
   466  				line := str[start:nextPos]
   467  				if pos > 0 {
   468  					_, err = b.WriteString(indent)
   469  					if err != nil {
   470  						return err
   471  					}
   472  				}
   473  				_, err = b.WriteString(line)
   474  				if err != nil {
   475  					return err
   476  				}
   477  				pos = nextPos
   478  			}
   479  		}
   480  		return err
   481  	}
   482  }
   483  
   484  func writeJsonString(b *indentBuffer, s string) error {
   485  	if sbytes, err := json.Marshal(s); err != nil {
   486  		return err
   487  	} else {
   488  		_, err := b.Write(sbytes)
   489  		return err
   490  	}
   491  }
   492  
   493  // UnmarshalJSON de-serializes the message that is present, in JSON format, in
   494  // the given bytes into this message. It first resets the current message. It
   495  // returns an error if the given bytes do not contain a valid encoding of this
   496  // message type in JSON format.
   497  //
   498  // This method is shorthand for invoking UnmarshalJSONPB with a default (zero
   499  // value) unmarshaler:
   500  //
   501  //    m.UnmarshalMergeJSONPB(&jsonpb.Unmarshaler{}, js)
   502  //
   503  // So unknown fields will result in an error, and no provided jsonpb.AnyResolver
   504  // will be used when parsing google.protobuf.Any messages.
   505  func (m *Message) UnmarshalJSON(js []byte) error {
   506  	return m.UnmarshalJSONPB(&jsonpb.Unmarshaler{}, js)
   507  }
   508  
   509  // UnmarshalMergeJSON de-serializes the message that is present, in JSON format,
   510  // in the given bytes into this message. Unlike UnmarshalJSON, it does not first
   511  // reset the message, instead merging the data in the given bytes into the
   512  // existing data in this message.
   513  func (m *Message) UnmarshalMergeJSON(js []byte) error {
   514  	return m.UnmarshalMergeJSONPB(&jsonpb.Unmarshaler{}, js)
   515  }
   516  
   517  // UnmarshalJSONPB de-serializes the message that is present, in JSON format, in
   518  // the given bytes into this message. The given unmarshaler conveys options used
   519  // when parsing the JSON. This function first resets the current message. It
   520  // returns an error if the given bytes do not contain a valid encoding of this
   521  // message type in JSON format.
   522  //
   523  // The decoding is lenient:
   524  //  1. The JSON can refer to fields either by their JSON name or by their
   525  //     declared name.
   526  //  2. The JSON can use either numeric values or string names for enum values.
   527  //
   528  // When instantiating nested messages, if this message's associated factory
   529  // returns a generated message type (as opposed to a dynamic message), the given
   530  // unmarshaler is used to unmarshal it.
   531  //
   532  // When unmarshaling any nested messages, any jsonpb.AnyResolver configured in
   533  // the given unmarshaler is augmented with knowledge of message types known to
   534  // this message's descriptor (and its enclosing file and set of transitive
   535  // dependencies).
   536  func (m *Message) UnmarshalJSONPB(opts *jsonpb.Unmarshaler, js []byte) error {
   537  	m.Reset()
   538  	if err := m.UnmarshalMergeJSONPB(opts, js); err != nil {
   539  		return err
   540  	}
   541  	return m.Validate()
   542  }
   543  
   544  // UnmarshalMergeJSONPB de-serializes the message that is present, in JSON
   545  // format, in the given bytes into this message. The given unmarshaler conveys
   546  // options used when parsing the JSON. Unlike UnmarshalJSONPB, it does not first
   547  // reset the message, instead merging the data in the given bytes into the
   548  // existing data in this message.
   549  func (m *Message) UnmarshalMergeJSONPB(opts *jsonpb.Unmarshaler, js []byte) error {
   550  	r := newJsReader(js)
   551  	err := m.unmarshalJson(r, opts)
   552  	if err != nil {
   553  		return err
   554  	}
   555  	if t, err := r.poll(); err != io.EOF {
   556  		b, _ := ioutil.ReadAll(r.unread())
   557  		s := fmt.Sprintf("%v%s", t, string(b))
   558  		return fmt.Errorf("superfluous data found after JSON object: %q", s)
   559  	}
   560  	return nil
   561  }
   562  
   563  func unmarshalWellKnownType(m *Message, r *jsReader, opts *jsonpb.Unmarshaler) (bool, error) {
   564  	fqn := m.md.GetFullyQualifiedName()
   565  	if _, ok := wellKnownTypeNames[fqn]; !ok {
   566  		return false, nil
   567  	}
   568  
   569  	msgType := proto.MessageType(fqn)
   570  	if msgType == nil {
   571  		// wtf?
   572  		panic(fmt.Sprintf("could not find registered message type for %q", fqn))
   573  	}
   574  
   575  	// extract json value from r
   576  	var js json.RawMessage
   577  	if err := json.NewDecoder(r.unread()).Decode(&js); err != nil {
   578  		return true, err
   579  	}
   580  	if err := r.skip(); err != nil {
   581  		return true, err
   582  	}
   583  
   584  	// unmarshal into well-known type and then convert to dynamic message
   585  	msg := reflect.New(msgType.Elem()).Interface().(proto.Message)
   586  	if err := opts.Unmarshal(bytes.NewReader(js), msg); err != nil {
   587  		return true, err
   588  	}
   589  	return true, m.MergeFrom(msg)
   590  }
   591  
   592  func (m *Message) unmarshalJson(r *jsReader, opts *jsonpb.Unmarshaler) error {
   593  	if r, changed := wrapResolver(opts.AnyResolver, m.mf, m.md.GetFile()); changed {
   594  		newOpts := *opts
   595  		newOpts.AnyResolver = r
   596  		opts = &newOpts
   597  	}
   598  
   599  	if ok, err := unmarshalWellKnownType(m, r, opts); ok {
   600  		return err
   601  	}
   602  
   603  	t, err := r.peek()
   604  	if err != nil {
   605  		return err
   606  	}
   607  	if t == nil {
   608  		// if json is simply "null" we do nothing
   609  		r.poll()
   610  		return nil
   611  	}
   612  
   613  	if err := r.beginObject(); err != nil {
   614  		return err
   615  	}
   616  
   617  	for r.hasNext() {
   618  		f, err := r.nextObjectKey()
   619  		if err != nil {
   620  			return err
   621  		}
   622  		fd := m.FindFieldDescriptorByJSONName(f)
   623  		if fd == nil {
   624  			if opts.AllowUnknownFields {
   625  				r.skip()
   626  				continue
   627  			}
   628  			return fmt.Errorf("message type %s has no known field named %s", m.md.GetFullyQualifiedName(), f)
   629  		}
   630  		v, err := unmarshalJsField(fd, r, m.mf, opts)
   631  		if err != nil {
   632  			return err
   633  		}
   634  		if v != nil {
   635  			if err := mergeField(m, fd, v); err != nil {
   636  				return err
   637  			}
   638  		} else if fd.GetOneOf() != nil {
   639  			// preserve explicit null for oneof fields (this is a little odd but
   640  			// mimics the behavior of jsonpb with oneofs in generated message types)
   641  			if fd.GetMessageType() != nil {
   642  				typ := m.mf.GetKnownTypeRegistry().GetKnownType(fd.GetMessageType().GetFullyQualifiedName())
   643  				if typ != nil {
   644  					// typed nil
   645  					if typ.Kind() != reflect.Ptr {
   646  						typ = reflect.PtrTo(typ)
   647  					}
   648  					v = reflect.Zero(typ).Interface()
   649  				} else {
   650  					// can't use nil dynamic message, so we just use empty one instead
   651  					v = m.mf.NewDynamicMessage(fd.GetMessageType())
   652  				}
   653  				if err := m.setField(fd, v); err != nil {
   654  					return err
   655  				}
   656  			} else {
   657  				// not a message... explicit null makes no sense
   658  				return fmt.Errorf("message type %s cannot set field %s to null: it is not a message type", m.md.GetFullyQualifiedName(), f)
   659  			}
   660  		} else {
   661  			m.clearField(fd)
   662  		}
   663  	}
   664  
   665  	if err := r.endObject(); err != nil {
   666  		return err
   667  	}
   668  
   669  	return nil
   670  }
   671  
   672  func isWellKnownValue(fd *desc.FieldDescriptor) bool {
   673  	return !fd.IsRepeated() && fd.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE &&
   674  		fd.GetMessageType().GetFullyQualifiedName() == "google.protobuf.Value"
   675  }
   676  
   677  func isWellKnownListValue(fd *desc.FieldDescriptor) bool {
   678  	return !fd.IsRepeated() && fd.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE &&
   679  		fd.GetMessageType().GetFullyQualifiedName() == "google.protobuf.ListValue"
   680  }
   681  
   682  func unmarshalJsField(fd *desc.FieldDescriptor, r *jsReader, mf *MessageFactory, opts *jsonpb.Unmarshaler) (interface{}, error) {
   683  	t, err := r.peek()
   684  	if err != nil {
   685  		return nil, err
   686  	}
   687  	if t == nil && !isWellKnownValue(fd) {
   688  		// if value is null, just return nil
   689  		// (unless field is google.protobuf.Value, in which case
   690  		// we fall through to parse it as an instance where its
   691  		// underlying value is set to a NullValue)
   692  		r.poll()
   693  		return nil, nil
   694  	}
   695  
   696  	if t == json.Delim('{') && fd.IsMap() {
   697  		entryType := fd.GetMessageType()
   698  		keyType := entryType.FindFieldByNumber(1)
   699  		valueType := entryType.FindFieldByNumber(2)
   700  		mp := map[interface{}]interface{}{}
   701  
   702  		// TODO: if there are just two map keys "key" and "value" and they have the right type of values,
   703  		// treat this JSON object as a single map entry message. (In keeping with support of map fields as
   704  		// if they were normal repeated field of entry messages as well as supporting a transition from
   705  		// optional to repeated...)
   706  
   707  		if err := r.beginObject(); err != nil {
   708  			return nil, err
   709  		}
   710  		for r.hasNext() {
   711  			kk, err := unmarshalJsFieldElement(keyType, r, mf, opts, false)
   712  			if err != nil {
   713  				return nil, err
   714  			}
   715  			vv, err := unmarshalJsFieldElement(valueType, r, mf, opts, true)
   716  			if err != nil {
   717  				return nil, err
   718  			}
   719  			mp[kk] = vv
   720  		}
   721  		if err := r.endObject(); err != nil {
   722  			return nil, err
   723  		}
   724  
   725  		return mp, nil
   726  	} else if t == json.Delim('[') && !isWellKnownListValue(fd) {
   727  		// We support parsing an array, even if field is not repeated, to mimic support in proto
   728  		// binary wire format that supports changing an optional field to repeated and vice versa.
   729  		// If the field is not repeated, we only keep the last value in the array.
   730  
   731  		if err := r.beginArray(); err != nil {
   732  			return nil, err
   733  		}
   734  		var sl []interface{}
   735  		var v interface{}
   736  		for r.hasNext() {
   737  			var err error
   738  			v, err = unmarshalJsFieldElement(fd, r, mf, opts, false)
   739  			if err != nil {
   740  				return nil, err
   741  			}
   742  			if fd.IsRepeated() && v != nil {
   743  				sl = append(sl, v)
   744  			}
   745  		}
   746  		if err := r.endArray(); err != nil {
   747  			return nil, err
   748  		}
   749  		if fd.IsMap() {
   750  			mp := map[interface{}]interface{}{}
   751  			for _, m := range sl {
   752  				msg := m.(*Message)
   753  				kk, err := msg.TryGetFieldByNumber(1)
   754  				if err != nil {
   755  					return nil, err
   756  				}
   757  				vv, err := msg.TryGetFieldByNumber(2)
   758  				if err != nil {
   759  					return nil, err
   760  				}
   761  				mp[kk] = vv
   762  			}
   763  			return mp, nil
   764  		} else if fd.IsRepeated() {
   765  			return sl, nil
   766  		} else {
   767  			return v, nil
   768  		}
   769  	} else {
   770  		// We support parsing a singular value, even if field is repeated, to mimic support in proto
   771  		// binary wire format that supports changing an optional field to repeated and vice versa.
   772  		// If the field is repeated, we store value as singleton slice of that one value.
   773  
   774  		v, err := unmarshalJsFieldElement(fd, r, mf, opts, false)
   775  		if err != nil {
   776  			return nil, err
   777  		}
   778  		if v == nil {
   779  			return nil, nil
   780  		}
   781  		if fd.IsRepeated() {
   782  			return []interface{}{v}, nil
   783  		} else {
   784  			return v, nil
   785  		}
   786  	}
   787  }
   788  
   789  func unmarshalJsFieldElement(fd *desc.FieldDescriptor, r *jsReader, mf *MessageFactory, opts *jsonpb.Unmarshaler, allowNilMessage bool) (interface{}, error) {
   790  	t, err := r.peek()
   791  	if err != nil {
   792  		return nil, err
   793  	}
   794  
   795  	switch fd.GetType() {
   796  	case descriptor.FieldDescriptorProto_TYPE_MESSAGE,
   797  		descriptor.FieldDescriptorProto_TYPE_GROUP:
   798  
   799  		if t == nil && allowNilMessage {
   800  			// if json is simply "null" return a nil pointer
   801  			r.poll()
   802  			return nilMessage(fd.GetMessageType()), nil
   803  		}
   804  
   805  		m := mf.NewMessage(fd.GetMessageType())
   806  		if dm, ok := m.(*Message); ok {
   807  			if err := dm.unmarshalJson(r, opts); err != nil {
   808  				return nil, err
   809  			}
   810  		} else {
   811  			var msg json.RawMessage
   812  			if err := json.NewDecoder(r.unread()).Decode(&msg); err != nil {
   813  				return nil, err
   814  			}
   815  			if err := r.skip(); err != nil {
   816  				return nil, err
   817  			}
   818  			if err := opts.Unmarshal(bytes.NewReader([]byte(msg)), m); err != nil {
   819  				return nil, err
   820  			}
   821  		}
   822  		return m, nil
   823  
   824  	case descriptor.FieldDescriptorProto_TYPE_ENUM:
   825  		if e, err := r.nextNumber(); err != nil {
   826  			return nil, err
   827  		} else {
   828  			// value could be string or number
   829  			if i, err := e.Int64(); err != nil {
   830  				// number cannot be parsed, so see if it's an enum value name
   831  				vd := fd.GetEnumType().FindValueByName(string(e))
   832  				if vd != nil {
   833  					return vd.GetNumber(), nil
   834  				} else {
   835  					return nil, fmt.Errorf("enum %q does not have value named %q", fd.GetEnumType().GetFullyQualifiedName(), e)
   836  				}
   837  			} else if i > math.MaxInt32 || i < math.MinInt32 {
   838  				return nil, NumericOverflowError
   839  			} else {
   840  				return int32(i), err
   841  			}
   842  		}
   843  
   844  	case descriptor.FieldDescriptorProto_TYPE_INT32,
   845  		descriptor.FieldDescriptorProto_TYPE_SINT32,
   846  		descriptor.FieldDescriptorProto_TYPE_SFIXED32:
   847  		if i, err := r.nextInt(); err != nil {
   848  			return nil, err
   849  		} else if i > math.MaxInt32 || i < math.MinInt32 {
   850  			return nil, NumericOverflowError
   851  		} else {
   852  			return int32(i), err
   853  		}
   854  
   855  	case descriptor.FieldDescriptorProto_TYPE_INT64,
   856  		descriptor.FieldDescriptorProto_TYPE_SINT64,
   857  		descriptor.FieldDescriptorProto_TYPE_SFIXED64:
   858  		return r.nextInt()
   859  
   860  	case descriptor.FieldDescriptorProto_TYPE_UINT32,
   861  		descriptor.FieldDescriptorProto_TYPE_FIXED32:
   862  		if i, err := r.nextUint(); err != nil {
   863  			return nil, err
   864  		} else if i > math.MaxUint32 {
   865  			return nil, NumericOverflowError
   866  		} else {
   867  			return uint32(i), err
   868  		}
   869  
   870  	case descriptor.FieldDescriptorProto_TYPE_UINT64,
   871  		descriptor.FieldDescriptorProto_TYPE_FIXED64:
   872  		return r.nextUint()
   873  
   874  	case descriptor.FieldDescriptorProto_TYPE_BOOL:
   875  		if str, ok := t.(string); ok {
   876  			if str == "true" {
   877  				r.poll() // consume token
   878  				return true, err
   879  			} else if str == "false" {
   880  				r.poll() // consume token
   881  				return false, err
   882  			}
   883  		}
   884  		return r.nextBool()
   885  
   886  	case descriptor.FieldDescriptorProto_TYPE_FLOAT:
   887  		if f, err := r.nextFloat(); err != nil {
   888  			return nil, err
   889  		} else {
   890  			return float32(f), nil
   891  		}
   892  
   893  	case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
   894  		return r.nextFloat()
   895  
   896  	case descriptor.FieldDescriptorProto_TYPE_BYTES:
   897  		return r.nextBytes()
   898  
   899  	case descriptor.FieldDescriptorProto_TYPE_STRING:
   900  		return r.nextString()
   901  
   902  	default:
   903  		return nil, fmt.Errorf("unknown field type: %v", fd.GetType())
   904  	}
   905  }
   906  
   907  type jsReader struct {
   908  	reader  *bytes.Reader
   909  	dec     *json.Decoder
   910  	current json.Token
   911  	peeked  bool
   912  }
   913  
   914  func newJsReader(b []byte) *jsReader {
   915  	reader := bytes.NewReader(b)
   916  	dec := json.NewDecoder(reader)
   917  	dec.UseNumber()
   918  	return &jsReader{reader: reader, dec: dec}
   919  }
   920  
   921  func (r *jsReader) unread() io.Reader {
   922  	bufs := make([]io.Reader, 3)
   923  	var peeked []byte
   924  	if r.peeked {
   925  		if _, ok := r.current.(json.Delim); ok {
   926  			peeked = []byte(fmt.Sprintf("%v", r.current))
   927  		} else {
   928  			peeked, _ = json.Marshal(r.current)
   929  		}
   930  	}
   931  	readerCopy := *r.reader
   932  	decCopy := *r.dec
   933  
   934  	bufs[0] = bytes.NewReader(peeked)
   935  	bufs[1] = decCopy.Buffered()
   936  	bufs[2] = &readerCopy
   937  	return &concatReader{bufs: bufs}
   938  }
   939  
   940  func (r *jsReader) hasNext() bool {
   941  	return r.dec.More()
   942  }
   943  
   944  func (r *jsReader) peek() (json.Token, error) {
   945  	if r.peeked {
   946  		return r.current, nil
   947  	}
   948  	t, err := r.dec.Token()
   949  	if err != nil {
   950  		return nil, err
   951  	}
   952  	r.peeked = true
   953  	r.current = t
   954  	return t, nil
   955  }
   956  
   957  func (r *jsReader) poll() (json.Token, error) {
   958  	if r.peeked {
   959  		ret := r.current
   960  		r.current = nil
   961  		r.peeked = false
   962  		return ret, nil
   963  	}
   964  	return r.dec.Token()
   965  }
   966  
   967  func (r *jsReader) beginObject() error {
   968  	_, err := r.expect(func(t json.Token) bool { return t == json.Delim('{') }, nil, "start of JSON object: '{'")
   969  	return err
   970  }
   971  
   972  func (r *jsReader) endObject() error {
   973  	_, err := r.expect(func(t json.Token) bool { return t == json.Delim('}') }, nil, "end of JSON object: '}'")
   974  	return err
   975  }
   976  
   977  func (r *jsReader) beginArray() error {
   978  	_, err := r.expect(func(t json.Token) bool { return t == json.Delim('[') }, nil, "start of array: '['")
   979  	return err
   980  }
   981  
   982  func (r *jsReader) endArray() error {
   983  	_, err := r.expect(func(t json.Token) bool { return t == json.Delim(']') }, nil, "end of array: ']'")
   984  	return err
   985  }
   986  
   987  func (r *jsReader) nextObjectKey() (string, error) {
   988  	return r.nextString()
   989  }
   990  
   991  func (r *jsReader) nextString() (string, error) {
   992  	t, err := r.expect(func(t json.Token) bool { _, ok := t.(string); return ok }, "", "string")
   993  	if err != nil {
   994  		return "", err
   995  	}
   996  	return t.(string), nil
   997  }
   998  
   999  func (r *jsReader) nextBytes() ([]byte, error) {
  1000  	str, err := r.nextString()
  1001  	if err != nil {
  1002  		return nil, err
  1003  	}
  1004  	return base64.StdEncoding.DecodeString(str)
  1005  }
  1006  
  1007  func (r *jsReader) nextBool() (bool, error) {
  1008  	t, err := r.expect(func(t json.Token) bool { _, ok := t.(bool); return ok }, false, "boolean")
  1009  	if err != nil {
  1010  		return false, err
  1011  	}
  1012  	return t.(bool), nil
  1013  }
  1014  
  1015  func (r *jsReader) nextInt() (int64, error) {
  1016  	n, err := r.nextNumber()
  1017  	if err != nil {
  1018  		return 0, err
  1019  	}
  1020  	return n.Int64()
  1021  }
  1022  
  1023  func (r *jsReader) nextUint() (uint64, error) {
  1024  	n, err := r.nextNumber()
  1025  	if err != nil {
  1026  		return 0, err
  1027  	}
  1028  	return strconv.ParseUint(string(n), 10, 64)
  1029  }
  1030  
  1031  func (r *jsReader) nextFloat() (float64, error) {
  1032  	n, err := r.nextNumber()
  1033  	if err != nil {
  1034  		return 0, err
  1035  	}
  1036  	return n.Float64()
  1037  }
  1038  
  1039  func (r *jsReader) nextNumber() (json.Number, error) {
  1040  	t, err := r.expect(func(t json.Token) bool { return reflect.TypeOf(t).Kind() == reflect.String }, "0", "number")
  1041  	if err != nil {
  1042  		return "", err
  1043  	}
  1044  	switch t := t.(type) {
  1045  	case json.Number:
  1046  		return t, nil
  1047  	case string:
  1048  		return json.Number(t), nil
  1049  	}
  1050  	return "", fmt.Errorf("expecting a number but got %v", t)
  1051  }
  1052  
  1053  func (r *jsReader) skip() error {
  1054  	t, err := r.poll()
  1055  	if err != nil {
  1056  		return err
  1057  	}
  1058  	if t == json.Delim('[') {
  1059  		if err := r.skipArray(); err != nil {
  1060  			return err
  1061  		}
  1062  	} else if t == json.Delim('{') {
  1063  		if err := r.skipObject(); err != nil {
  1064  			return err
  1065  		}
  1066  	}
  1067  	return nil
  1068  }
  1069  
  1070  func (r *jsReader) skipArray() error {
  1071  	for r.hasNext() {
  1072  		if err := r.skip(); err != nil {
  1073  			return err
  1074  		}
  1075  	}
  1076  	if err := r.endArray(); err != nil {
  1077  		return err
  1078  	}
  1079  	return nil
  1080  }
  1081  
  1082  func (r *jsReader) skipObject() error {
  1083  	for r.hasNext() {
  1084  		// skip object key
  1085  		if err := r.skip(); err != nil {
  1086  			return err
  1087  		}
  1088  		// and value
  1089  		if err := r.skip(); err != nil {
  1090  			return err
  1091  		}
  1092  	}
  1093  	if err := r.endObject(); err != nil {
  1094  		return err
  1095  	}
  1096  	return nil
  1097  }
  1098  
  1099  func (r *jsReader) expect(predicate func(json.Token) bool, ifNil interface{}, expected string) (interface{}, error) {
  1100  	t, err := r.poll()
  1101  	if err != nil {
  1102  		return nil, err
  1103  	}
  1104  	if t == nil && ifNil != nil {
  1105  		return ifNil, nil
  1106  	}
  1107  	if !predicate(t) {
  1108  		return t, fmt.Errorf("bad input: expecting %s ; instead got %v", expected, t)
  1109  	}
  1110  	return t, nil
  1111  }
  1112  
  1113  type concatReader struct {
  1114  	bufs []io.Reader
  1115  	curr int
  1116  }
  1117  
  1118  func (r *concatReader) Read(p []byte) (n int, err error) {
  1119  	for {
  1120  		if r.curr >= len(r.bufs) {
  1121  			err = io.EOF
  1122  			return
  1123  		}
  1124  		var c int
  1125  		c, err = r.bufs[r.curr].Read(p)
  1126  		n += c
  1127  		if err != io.EOF {
  1128  			return
  1129  		}
  1130  		r.curr++
  1131  		p = p[c:]
  1132  	}
  1133  }
  1134  
  1135  // AnyResolver returns a jsonpb.AnyResolver that uses the given file descriptors
  1136  // to resolve message names. It uses the given factory, which may be nil, to
  1137  // instantiate messages. The messages that it returns when resolving a type name
  1138  // may often be dynamic messages.
  1139  func AnyResolver(mf *MessageFactory, files ...*desc.FileDescriptor) jsonpb.AnyResolver {
  1140  	return &anyResolver{mf: mf, files: files}
  1141  }
  1142  
// anyResolver implements jsonpb.AnyResolver on top of a set of file
// descriptors, optionally consulting another resolver first.
type anyResolver struct {
	// mf instantiates the messages returned by Resolve; may be nil.
	mf      *MessageFactory
	// files are searched, along with their dependencies, for type names.
	files   []*desc.FileDescriptor
	// ignored lists files this resolver does not search itself because the
	// resolver in other already covers them.
	ignored map[*desc.FileDescriptor]struct{}
	// other, when non-nil, is consulted before files are searched.
	other   jsonpb.AnyResolver
}
  1149  
  1150  func wrapResolver(r jsonpb.AnyResolver, mf *MessageFactory, f *desc.FileDescriptor) (jsonpb.AnyResolver, bool) {
  1151  	if r, ok := r.(*anyResolver); ok {
  1152  		if _, ok := r.ignored[f]; ok {
  1153  			// if the current resolver is ignoring this file, it's because another
  1154  			// (upstream) resolver is already handling it, so nothing to do
  1155  			return r, false
  1156  		}
  1157  		for _, file := range r.files {
  1158  			if file == f {
  1159  				// no need to wrap!
  1160  				return r, false
  1161  			}
  1162  		}
  1163  		// ignore files that will be checked by the resolver we're wrapping
  1164  		// (we'll just delegate and let it search those files)
  1165  		ignored := map[*desc.FileDescriptor]struct{}{}
  1166  		for i := range r.ignored {
  1167  			ignored[i] = struct{}{}
  1168  		}
  1169  		ignore(r.files, ignored)
  1170  		return &anyResolver{mf: mf, files: []*desc.FileDescriptor{f}, ignored: ignored, other: r}, true
  1171  	}
  1172  	return &anyResolver{mf: mf, files: []*desc.FileDescriptor{f}, other: r}, true
  1173  }
  1174  
  1175  func ignore(files []*desc.FileDescriptor, ignored map[*desc.FileDescriptor]struct{}) {
  1176  	for _, f := range files {
  1177  		if _, ok := ignored[f]; ok {
  1178  			continue
  1179  		}
  1180  		ignored[f] = struct{}{}
  1181  		ignore(f.GetDependencies(), ignored)
  1182  	}
  1183  }
  1184  
  1185  func (r *anyResolver) Resolve(typeUrl string) (proto.Message, error) {
  1186  	mname := typeUrl
  1187  	if slash := strings.LastIndex(mname, "/"); slash >= 0 {
  1188  		mname = mname[slash+1:]
  1189  	}
  1190  
  1191  	// see if the user-specified resolver is able to do the job
  1192  	if r.other != nil {
  1193  		msg, err := r.other.Resolve(typeUrl)
  1194  		if err == nil {
  1195  			return msg, nil
  1196  		}
  1197  	}
  1198  
  1199  	// try to find the message in our known set of files
  1200  	checked := map[*desc.FileDescriptor]struct{}{}
  1201  	for _, f := range r.files {
  1202  		md := r.findMessage(f, mname, checked)
  1203  		if md != nil {
  1204  			return r.mf.NewMessage(md), nil
  1205  		}
  1206  	}
  1207  	// failing that, see if the message factory knows about this type
  1208  	var ktr *KnownTypeRegistry
  1209  	if r.mf != nil {
  1210  		ktr = r.mf.ktr
  1211  	} else {
  1212  		ktr = (*KnownTypeRegistry)(nil)
  1213  	}
  1214  	m := ktr.CreateIfKnown(mname)
  1215  	if m != nil {
  1216  		return m, nil
  1217  	}
  1218  
  1219  	// no other resolver to fallback to? mimic default behavior
  1220  	mt := proto.MessageType(mname)
  1221  	if mt == nil {
  1222  		return nil, fmt.Errorf("unknown message type %q", mname)
  1223  	}
  1224  	return reflect.New(mt.Elem()).Interface().(proto.Message), nil
  1225  }
  1226  
  1227  func (r *anyResolver) findMessage(fd *desc.FileDescriptor, msgName string, checked map[*desc.FileDescriptor]struct{}) *desc.MessageDescriptor {
  1228  	// if this is an ignored descriptor, skip
  1229  	if _, ok := r.ignored[fd]; ok {
  1230  		return nil
  1231  	}
  1232  
  1233  	// bail if we've already checked this file
  1234  	if _, ok := checked[fd]; ok {
  1235  		return nil
  1236  	}
  1237  	checked[fd] = struct{}{}
  1238  
  1239  	// see if this file has the message
  1240  	md := fd.FindMessage(msgName)
  1241  	if md != nil {
  1242  		return md
  1243  	}
  1244  
  1245  	// if not, recursively search the file's imports
  1246  	for _, dep := range fd.GetDependencies() {
  1247  		md = r.findMessage(dep, msgName, checked)
  1248  		if md != nil {
  1249  			return md
  1250  		}
  1251  	}
  1252  	return nil
  1253  }
  1254  
  1255  var _ jsonpb.AnyResolver = (*anyResolver)(nil)