github.com/segmentio/encoding@v0.4.0/proto/reflect.go (about)

     1  package proto
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"strconv"
     7  	"strings"
     8  	"sync"
     9  	"sync/atomic"
    10  )
    11  
    12  // Kind is an enumeration representing the various data types supported by the
    13  // protobuf language.
    14  type Kind int
    15  
    16  const (
    17  	Bool Kind = iota
    18  	Int32
    19  	Int64
    20  	Sint32
    21  	Sint64
    22  	Uint32
    23  	Uint64
    24  	Fix32
    25  	Fix64
    26  	Sfix32
    27  	Sfix64
    28  	Float
    29  	Double
    30  	String
    31  	Bytes
    32  	Map
    33  	Struct
    34  )
    35  
    36  // Type is an interface similar to reflect.Type. Values implementing this
    37  // interface represent high level protobuf types.
    38  //
    39  // Type values are safe to use concurrently from multiple goroutines.
    40  //
    41  // Types are comparable value.
    42  type Type interface {
    43  	// Returns a human-readable representation of the type.
    44  	String() string
    45  
    46  	// Returns the name of the type.
    47  	Name() string
    48  
    49  	// Kind returns the kind of protobuf values that are represented.
    50  	Kind() Kind
    51  
    52  	// When the Type represents a protobuf map, calling this method returns the
    53  	// type of the map keys.
    54  	//
    55  	// If the Type is not a map type, the method panics.
    56  	Key() Type
    57  
    58  	// When the Type represents a protobuf map, calling this method returns the
    59  	// type of the map values.
    60  	//
    61  	// If the Type is not a map type, the method panics.
    62  	Elem() Type
    63  
    64  	// Returns the protobuf wire type for the Type it is called on.
    65  	WireType() WireType
    66  
    67  	// Returns the number of fields in the protobuf message.
    68  	//
    69  	// If the Type does not represent a struct type, the method returns zero.
    70  	NumField() int
    71  
    72  	// Returns the Field at the given in Type.
    73  	//
    74  	// If the Type does not represent a struct type, the method panics.
    75  	Field(int) Field
    76  
    77  	// Returns the Field with the given name in Type.
    78  	//
    79  	// If the Type does not represent a struct type, or if the field does not
    80  	// exist, the method panics.
    81  	FieldByName(string) Field
    82  
    83  	// Returns the Field with the given number in Type.
    84  	//
    85  	// If the Type does not represent a struct type, or if the field does not
    86  	// exist, the method panics.
    87  	FieldByNumber(FieldNumber) Field
    88  
    89  	// For unsigned types, convert to their zig-zag form.
    90  	//
    91  	// The method uses the following table to perform the conversion:
    92  	//
    93  	//  base    | zig-zag
    94  	//	--------+---------
    95  	//	int32   | sint32
    96  	//	int64   | sint64
    97  	//	uint32  | sint32
    98  	//	uint64  | sint64
    99  	//	fixed32 | sfixed32
   100  	//	fixed64 | sfixed64
   101  	//
   102  	// If the type cannot be converted to a zig-zag type, the method panics.
   103  	ZigZag() Type
   104  }
   105  
   106  // TypeOf returns the protobuf type used to represent a go type.
   107  //
   108  // The function uses the following table to map Go types to Protobuf:
   109  //
   110  //	Go      | Protobuf
   111  //	--------+---------
   112  //	bool    | bool
   113  //	int     | int64
   114  //	int32   | int32
   115  //	int64   | int64
   116  //	uint    | uint64
   117  //	uint32  | uint32
   118  //	uint64  | uint64
   119  //	float32 | float
   120  //	float64 | double
   121  //	string  | string
   122  //	[]byte  | bytes
   123  //	map     | map
   124  //	struct  | message
   125  //
   126  // Pointer types are also supported and automatically dereferenced.
   127  func TypeOf(t reflect.Type) Type {
   128  	cache, _ := typesCache.Load().(map[reflect.Type]Type)
   129  	if r, ok := cache[t]; ok {
   130  		return r
   131  	}
   132  
   133  	typesMutex.Lock()
   134  	defer typesMutex.Unlock()
   135  
   136  	cache, _ = typesCache.Load().(map[reflect.Type]Type)
   137  	if r, ok := cache[t]; ok {
   138  		return r
   139  	}
   140  
   141  	seen := map[reflect.Type]Type{}
   142  	r := typeOf(t, seen)
   143  
   144  	newCache := make(map[reflect.Type]Type, len(cache)+len(seen))
   145  	for t, r := range cache {
   146  		newCache[t] = r
   147  	}
   148  
   149  	for t, r := range seen {
   150  		if x, ok := newCache[t]; ok {
   151  			r = x
   152  		} else {
   153  			newCache[t] = r
   154  		}
   155  	}
   156  
   157  	if x, ok := newCache[t]; ok {
   158  		r = x
   159  	} else {
   160  		newCache[t] = r
   161  	}
   162  
   163  	typesCache.Store(newCache)
   164  	return r
   165  }
   166  
   167  func typeOf(t reflect.Type, seen map[reflect.Type]Type) Type {
   168  	if r, ok := seen[t]; ok {
   169  		return r
   170  	}
   171  
   172  	switch {
   173  	case implements(t, messageType):
   174  		return &opaqueMessageType{}
   175  	case implements(t, customMessageType) && !implements(t, protoMessageType):
   176  		return &primitiveTypes[Bytes]
   177  	}
   178  
   179  	switch t.Kind() {
   180  	case reflect.Bool:
   181  		return &primitiveTypes[Bool]
   182  	case reflect.Int:
   183  		return &primitiveTypes[Int64]
   184  	case reflect.Int32:
   185  		return &primitiveTypes[Int32]
   186  	case reflect.Int64:
   187  		return &primitiveTypes[Int64]
   188  	case reflect.Uint:
   189  		return &primitiveTypes[Uint64]
   190  	case reflect.Uint32:
   191  		return &primitiveTypes[Uint32]
   192  	case reflect.Uint64:
   193  		return &primitiveTypes[Uint64]
   194  	case reflect.Float32:
   195  		return &primitiveTypes[Float]
   196  	case reflect.Float64:
   197  		return &primitiveTypes[Double]
   198  	case reflect.String:
   199  		return &primitiveTypes[String]
   200  	case reflect.Slice, reflect.Array:
   201  		if t.Elem().Kind() == reflect.Uint8 {
   202  			return &primitiveTypes[Bytes]
   203  		}
   204  	case reflect.Map:
   205  		return mapTypeOf(t, seen)
   206  	case reflect.Struct:
   207  		return structTypeOf(t, seen)
   208  	case reflect.Ptr:
   209  		return typeOf(t.Elem(), seen)
   210  	}
   211  
   212  	panic(fmt.Errorf("cannot construct protobuf type from go value of type %s", t))
   213  }
   214  
   215  var (
   216  	typesMutex sync.Mutex
   217  	typesCache atomic.Value // map[reflect.Type]Type{}
   218  )
   219  
// Field describes one field of a protobuf message, as returned by the Field,
// FieldByName and FieldByNumber methods of Type.
type Field struct {
	// Index is the position of the field in the message's field list, i.e.
	// the argument accepted by Type.Field. Unexported Go struct fields are
	// skipped, so it may differ from the Go struct field index.
	Index int
	// Number is the protobuf field number, used by Type.FieldByNumber.
	Number FieldNumber
	// Name is the field name, used by Type.FieldByName.
	Name string
	// Type is the protobuf type of the field. For repeated fields it is the
	// type of a single element.
	Type Type
	// Repeated reports whether the field is repeated. Map fields are never
	// reported as repeated.
	Repeated bool
}
   227  
   228  type primitiveType struct {
   229  	name   string
   230  	kind   Kind
   231  	wire   WireType
   232  	zigzag Kind
   233  }
   234  
   235  func (t *primitiveType) String() string {
   236  	return t.name
   237  }
   238  
   239  func (t *primitiveType) Name() string {
   240  	return t.name
   241  }
   242  
   243  func (t *primitiveType) Kind() Kind {
   244  	return t.kind
   245  }
   246  
   247  func (t *primitiveType) Key() Type {
   248  	panic(fmt.Errorf("proto.Type.Key: called on unsupported type: %s", t))
   249  }
   250  
   251  func (t *primitiveType) Elem() Type {
   252  	panic(fmt.Errorf("proto.Type.Elem: called on unsupported type: %s", t))
   253  }
   254  
   255  func (t *primitiveType) WireType() WireType {
   256  	return t.wire
   257  }
   258  
   259  func (t *primitiveType) NumField() int {
   260  	return 0
   261  }
   262  
   263  func (t *primitiveType) Field(int) Field {
   264  	panic(fmt.Errorf("proto.Type.Field: called on unsupported type: %s", t))
   265  }
   266  
   267  func (t *primitiveType) FieldByName(string) Field {
   268  	panic(fmt.Errorf("proto.Type.FieldByName: called on unsupported type: %s", t))
   269  }
   270  
   271  func (t *primitiveType) FieldByNumber(FieldNumber) Field {
   272  	panic(fmt.Errorf("proto.Type.FieldByNumber: called on unsupported type: %s", t))
   273  }
   274  
   275  func (t *primitiveType) ZigZag() Type {
   276  	if t.zigzag == 0 {
   277  		panic(fmt.Errorf("proto.Type.ZigZag: called on unsupported type: %s", t))
   278  	}
   279  	return &primitiveTypes[t.zigzag]
   280  }
   281  
   282  var primitiveTypes = [...]primitiveType{
   283  	{name: "bool", kind: Bool, wire: Varint},
   284  	{name: "int32", kind: Int32, wire: Varint, zigzag: Sint32},
   285  	{name: "int64", kind: Int64, wire: Varint, zigzag: Sint64},
   286  	{name: "sint32", kind: Sint32, wire: Varint},
   287  	{name: "sint64", kind: Sint64, wire: Varint},
   288  	{name: "uint32", kind: Uint32, wire: Varint, zigzag: Sint32},
   289  	{name: "uint64", kind: Uint64, wire: Varint, zigzag: Sint64},
   290  	{name: "fixed32", kind: Fix32, wire: Fixed32, zigzag: Sfix32},
   291  	{name: "fixed64", kind: Fix64, wire: Fixed64, zigzag: Sfix64},
   292  	{name: "sfixed32", kind: Sfix32, wire: Fixed32},
   293  	{name: "sfixed64", kind: Sfix64, wire: Fixed64},
   294  	{name: "float", kind: Float, wire: Fixed32},
   295  	{name: "double", kind: Double, wire: Fixed64},
   296  	{name: "string", kind: String, wire: Varlen},
   297  	{name: "bytes", kind: Bytes, wire: Varlen},
   298  }
   299  
   300  func mapTypeOf(t reflect.Type, seen map[reflect.Type]Type) *mapType {
   301  	mt := &mapType{}
   302  	seen[t] = mt
   303  	mt.key = typeOf(t.Key(), seen)
   304  	mt.elem = typeOf(t.Elem(), seen)
   305  	return mt
   306  }
   307  
   308  type mapType struct {
   309  	key  Type
   310  	elem Type
   311  }
   312  
   313  func (t *mapType) String() string {
   314  	return fmt.Sprintf("map<%s, %s>", t.key.Name(), t.elem.Name())
   315  }
   316  
   317  func (t *mapType) Name() string {
   318  	return t.String()
   319  }
   320  
   321  func (t *mapType) Kind() Kind {
   322  	return Map
   323  }
   324  
   325  func (t *mapType) Key() Type {
   326  	return t.key
   327  }
   328  
   329  func (t *mapType) Elem() Type {
   330  	return t.elem
   331  }
   332  
   333  func (t *mapType) WireType() WireType {
   334  	return Varlen
   335  }
   336  
   337  func (t *mapType) NumField() int {
   338  	return 0
   339  }
   340  
   341  func (t *mapType) Field(int) Field {
   342  	panic(fmt.Errorf("proto.Type.Field: called on unsupported type: %s", t))
   343  }
   344  
   345  func (t *mapType) FieldByName(string) Field {
   346  	panic(fmt.Errorf("proto.Type.FieldByName: called on unsupported type: %s", t))
   347  }
   348  
   349  func (t *mapType) FieldByNumber(FieldNumber) Field {
   350  	panic(fmt.Errorf("proto.Type.FieldByNumber: called on unsupported type: %s", t))
   351  }
   352  
   353  func (t *mapType) ZigZag() Type {
   354  	panic(fmt.Errorf("proto.Type.ZigZag: called on unsupported type: %s", t))
   355  }
   356  
   357  func structTypeOf(t reflect.Type, seen map[reflect.Type]Type) *structType {
   358  	st := &structType{
   359  		name:           t.Name(),
   360  		fieldsByName:   make(map[string]int),
   361  		fieldsByNumber: make(map[FieldNumber]int),
   362  	}
   363  
   364  	seen[t] = st
   365  
   366  	fieldNumber := FieldNumber(0)
   367  	taggedFields := FieldNumber(0)
   368  
   369  	for i, n := 0, t.NumField(); i < n; i++ {
   370  		f := t.Field(i)
   371  
   372  		if f.PkgPath != "" {
   373  			continue // unexported
   374  		}
   375  
   376  		repeated := false
   377  		if f.Type.Kind() == reflect.Slice && f.Type.Elem().Kind() != reflect.Uint8 {
   378  			repeated = true
   379  			f.Type = f.Type.Elem() // for typeOf
   380  		}
   381  
   382  		fieldName := f.Name
   383  		fieldType := typeOf(f.Type, seen)
   384  
   385  		if tag, ok := f.Tag.Lookup("protobuf"); ok {
   386  			if fieldNumber != taggedFields {
   387  				panic("conflicting use of struct tag and naked fields")
   388  			}
   389  			t, err := parseStructTag(tag)
   390  			if err != nil {
   391  				panic(err)
   392  			}
   393  
   394  			fieldName = t.name
   395  			fieldNumber = t.fieldNumber
   396  			taggedFields = t.fieldNumber
   397  			// Because maps are represented as repeated varlen fields on the
   398  			// wire, the generated protobuf code sets the `rep` attribute on
   399  			// the struct fields.
   400  			repeated = t.repeated && f.Type.Kind() != reflect.Map
   401  
   402  			if t.zigzag {
   403  				fieldType = fieldType.ZigZag()
   404  			}
   405  		} else if fieldNumber == 0 && len(st.fields) != 0 {
   406  			panic("conflicting use of struct tag and naked fields")
   407  		} else {
   408  			fieldNumber++
   409  		}
   410  
   411  		index := len(st.fields)
   412  		st.fields = append(st.fields, Field{
   413  			Index:    index,
   414  			Number:   fieldNumber,
   415  			Name:     fieldName,
   416  			Type:     fieldType,
   417  			Repeated: repeated,
   418  		})
   419  		st.fieldsByName[fieldName] = index
   420  		st.fieldsByNumber[fieldNumber] = index
   421  	}
   422  
   423  	return st
   424  }
   425  
   426  type structType struct {
   427  	name           string
   428  	fields         []Field
   429  	fieldsByName   map[string]int
   430  	fieldsByNumber map[FieldNumber]int
   431  }
   432  
   433  func (t *structType) String() string {
   434  	s := strings.Builder{}
   435  	s.WriteString("message ")
   436  
   437  	if t.name != "" {
   438  		s.WriteString(t.name)
   439  		s.WriteString(" ")
   440  	}
   441  
   442  	s.WriteString("{")
   443  
   444  	for _, f := range t.fields {
   445  		s.WriteString("\n  ")
   446  
   447  		if f.Repeated {
   448  			s.WriteString("repeated ")
   449  		} else {
   450  		}
   451  
   452  		s.WriteString(f.Type.Name())
   453  		s.WriteString(" ")
   454  		s.WriteString(f.Name)
   455  		s.WriteString(" = ")
   456  		s.WriteString(strconv.Itoa(int(f.Number)))
   457  		s.WriteString(";")
   458  	}
   459  
   460  	if len(t.fields) == 0 {
   461  		s.WriteString("}")
   462  	} else {
   463  		s.WriteString("\n}")
   464  	}
   465  
   466  	return s.String()
   467  }
   468  
   469  func (t *structType) Name() string {
   470  	return t.name
   471  }
   472  
   473  func (t *structType) Kind() Kind {
   474  	return Struct
   475  }
   476  
   477  func (t *structType) Key() Type {
   478  	panic(fmt.Errorf("proto.Type.Key: called on unsupported type: %s", t.name))
   479  }
   480  
   481  func (t *structType) Elem() Type {
   482  	panic(fmt.Errorf("proto.Type.Elem: called on unsupported type: %s", t.name))
   483  }
   484  
   485  func (t *structType) WireType() WireType {
   486  	return Varlen
   487  }
   488  
   489  func (t *structType) NumField() int {
   490  	return len(t.fields)
   491  }
   492  
   493  func (t *structType) Field(index int) Field {
   494  	if index >= 0 && index < len(t.fields) {
   495  		return t.fields[index]
   496  	}
   497  	panic(fmt.Errorf("proto.Type.Field: protobuf message field out of bounds: %d/%d", index, len(t.fields)))
   498  }
   499  
   500  func (t *structType) FieldByName(name string) Field {
   501  	i, ok := t.fieldsByName[name]
   502  	if ok {
   503  		return t.fields[i]
   504  	}
   505  	panic(fmt.Errorf("proto.Type.FieldByName: protobuf message has not field named %q", name))
   506  }
   507  
   508  func (t *structType) FieldByNumber(number FieldNumber) Field {
   509  	i, ok := t.fieldsByNumber[number]
   510  	if ok {
   511  		return t.fields[i]
   512  	}
   513  	panic(fmt.Errorf("proto.Type.FieldByNumber: protobuf message has no field number %d", number))
   514  }
   515  
   516  func (t *structType) ZigZag() Type {
   517  	panic(fmt.Errorf("proto.Type.ZigZag: called on unsupported type: %s", t.name))
   518  }
   519  
   520  type structTag struct {
   521  	name        string
   522  	enum        string
   523  	json        string
   524  	version     int
   525  	wireType    WireType
   526  	fieldNumber FieldNumber
   527  	extensions  map[string]string
   528  	repeated    bool
   529  	zigzag      bool
   530  }
   531  
   532  func parseStructTag(tag string) (structTag, error) {
   533  	t := structTag{
   534  		version:    2,
   535  		extensions: make(map[string]string),
   536  	}
   537  
   538  	for i, f := range splitFields(tag) {
   539  		switch i {
   540  		case 0:
   541  			switch f {
   542  			case "varint":
   543  				t.wireType = Varint
   544  			case "bytes":
   545  				t.wireType = Varlen
   546  			case "fixed32":
   547  				t.wireType = Fixed32
   548  			case "fixed64":
   549  				t.wireType = Fixed64
   550  			case "zigzag32":
   551  				t.wireType = Varint
   552  				t.zigzag = true
   553  			case "zigzag64":
   554  				t.wireType = Varint
   555  				t.zigzag = true
   556  			default:
   557  				return t, fmt.Errorf("unsupported wire type in struct tag %q: %s", tag, f)
   558  			}
   559  
   560  		case 1:
   561  			n, err := strconv.Atoi(f)
   562  			if err != nil {
   563  				return t, fmt.Errorf("unsupported field number in struct tag %q: %w", tag, err)
   564  			}
   565  			t.fieldNumber = FieldNumber(n)
   566  
   567  		case 2:
   568  			switch f {
   569  			case "opt":
   570  				// not sure what this is for
   571  			case "rep":
   572  				t.repeated = true
   573  			default:
   574  				return t, fmt.Errorf("unsupported field option in struct tag %q: %s", tag, f)
   575  			}
   576  
   577  		default:
   578  			name, value := splitNameValue(f)
   579  			switch name {
   580  			case "name":
   581  				t.name = value
   582  			case "enum":
   583  				t.enum = value
   584  			case "json":
   585  				t.json = value
   586  			case "proto3":
   587  				t.version = 3
   588  			default:
   589  				t.extensions[name] = value
   590  			}
   591  		}
   592  	}
   593  
   594  	return t, nil
   595  }
   596  
   597  func splitFields(s string) []string {
   598  	return strings.Split(s, ",")
   599  }
   600  
   601  func splitNameValue(s string) (name, value string) {
   602  	i := strings.IndexByte(s, '=')
   603  	if i < 0 {
   604  		return strings.TrimSpace(s), ""
   605  	} else {
   606  		return strings.TrimSpace(s[:i]), strings.TrimSpace(s[i+1:])
   607  	}
   608  }
   609  
   610  type opaqueMessageType struct{}
   611  
   612  func (t *opaqueMessageType) String() string {
   613  	return "bytes"
   614  }
   615  
   616  func (t *opaqueMessageType) Name() string {
   617  	return "bytes"
   618  }
   619  
   620  func (t *opaqueMessageType) Kind() Kind {
   621  	return Struct
   622  }
   623  
   624  func (t *opaqueMessageType) Key() Type {
   625  	panic(fmt.Errorf("proto.Type.Key: called on unsupported type: %s", t))
   626  }
   627  
   628  func (t *opaqueMessageType) Elem() Type {
   629  	panic(fmt.Errorf("proto.Type.Elem: called on unsupported type: %s", t))
   630  }
   631  
   632  func (t *opaqueMessageType) WireType() WireType {
   633  	return Varlen
   634  }
   635  
   636  func (t *opaqueMessageType) NumField() int {
   637  	return 0
   638  }
   639  
   640  func (t *opaqueMessageType) Field(int) Field {
   641  	panic(fmt.Errorf("proto.Type.Field: called on unsupported type: %s", t))
   642  }
   643  
   644  func (t *opaqueMessageType) FieldByName(string) Field {
   645  	panic(fmt.Errorf("proto.Type.FieldByName: called on unsupported type: %s", t))
   646  }
   647  
   648  func (t *opaqueMessageType) FieldByNumber(FieldNumber) Field {
   649  	panic(fmt.Errorf("proto.Type.FieldByNumber: called on unsupported type: %s", t))
   650  }
   651  
   652  func (t *opaqueMessageType) ZigZag() Type {
   653  	panic(fmt.Errorf("proto.Type.ZigZag: called on unsupported type: %s", t))
   654  }