github.com/segmentio/encoding@v0.4.0/proto/proto.go (about)

     1  package proto
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"sync/atomic"
     7  	"unsafe"
     8  )
     9  
    10  func Size(v interface{}) int {
    11  	t, p := inspect(v)
    12  	c := cachedCodecOf(t)
    13  	return c.size(p, inline|toplevel)
    14  }
    15  
    16  func Marshal(v interface{}) ([]byte, error) {
    17  	t, p := inspect(v)
    18  	c := cachedCodecOf(t)
    19  	b := make([]byte, c.size(p, inline|toplevel))
    20  	_, err := c.encode(b, p, inline|toplevel)
    21  	if err != nil {
    22  		return nil, fmt.Errorf("proto.Marshal(%T): %w", v, err)
    23  	}
    24  	return b, nil
    25  }
    26  
    27  func MarshalTo(b []byte, v interface{}) (int, error) {
    28  	t, p := inspect(v)
    29  	c := cachedCodecOf(t)
    30  	n, err := c.encode(b, p, inline|toplevel)
    31  	if err != nil {
    32  		err = fmt.Errorf("proto.MarshalTo: %w", err)
    33  	}
    34  	return n, err
    35  }
    36  
    37  func Unmarshal(b []byte, v interface{}) error {
    38  	if len(b) == 0 {
    39  		// An empty input is a valid protobuf message with all fields set to the
    40  		// zero-value.
    41  		reflect.ValueOf(v).Elem().Set(reflect.Zero(reflect.TypeOf(v).Elem()))
    42  		return nil
    43  	}
    44  
    45  	t, p := inspect(v)
    46  	t = t.Elem() // Unmarshal must be passed a pointer
    47  	c := cachedCodecOf(t)
    48  
    49  	n, err := c.decode(b, p, toplevel)
    50  	if err != nil {
    51  		return err
    52  	}
    53  	if n < len(b) {
    54  		return fmt.Errorf("proto.Unmarshal(%T): read=%d < buffer=%d", v, n, len(b))
    55  	}
    56  	return nil
    57  }
    58  
    59  type flags uintptr
    60  
    61  const (
    62  	noflags  flags = 0
    63  	inline   flags = 1 << 0
    64  	wantzero flags = 1 << 1
    65  	// Shared with structField.flags in struct.go:
    66  	// zigzag flags = 1 << 2
    67  	toplevel flags = 1 << 3
    68  )
    69  
    70  func (f flags) has(x flags) bool {
    71  	return (f & x) != 0
    72  }
    73  
    74  func (f flags) with(x flags) flags {
    75  	return f | x
    76  }
    77  
    78  func (f flags) without(x flags) flags {
    79  	return f & ^x
    80  }
    81  
    82  func (f flags) uint64(i int64) uint64 {
    83  	if f.has(zigzag) {
    84  		return encodeZigZag64(i)
    85  	} else {
    86  		return uint64(i)
    87  	}
    88  }
    89  
    90  func (f flags) int64(u uint64) int64 {
    91  	if f.has(zigzag) {
    92  		return decodeZigZag64(u)
    93  	} else {
    94  		return int64(u)
    95  	}
    96  }
    97  
    98  type iface struct {
    99  	typ unsafe.Pointer
   100  	ptr unsafe.Pointer
   101  }
   102  
   103  func inspect(v interface{}) (reflect.Type, unsafe.Pointer) {
   104  	return reflect.TypeOf(v), pointer(v)
   105  }
   106  
   107  func pointer(v interface{}) unsafe.Pointer {
   108  	return (*iface)(unsafe.Pointer(&v)).ptr
   109  }
   110  
   111  func inlined(t reflect.Type) bool {
   112  	switch t.Kind() {
   113  	case reflect.Ptr:
   114  		return true
   115  	case reflect.Map:
   116  		return true
   117  	case reflect.Struct:
   118  		return t.NumField() == 1 && inlined(t.Field(0).Type)
   119  	default:
   120  		return false
   121  	}
   122  }
   123  
   124  type fieldNumber uint
   125  
   126  type wireType uint
   127  
   128  const (
   129  	varint  wireType = 0
   130  	fixed64 wireType = 1
   131  	varlen  wireType = 2
   132  	fixed32 wireType = 5
   133  )
   134  
   135  func (wt wireType) String() string {
   136  	switch wt {
   137  	case varint:
   138  		return "varint"
   139  	case varlen:
   140  		return "varlen"
   141  	case fixed32:
   142  		return "fixed32"
   143  	case fixed64:
   144  		return "fixed64"
   145  	default:
   146  		return "unknown"
   147  	}
   148  }
   149  
   150  type codec struct {
   151  	wire   wireType
   152  	size   sizeFunc
   153  	encode encodeFunc
   154  	decode decodeFunc
   155  }
   156  
   157  var codecCache atomic.Value // map[unsafe.Pointer]*codec
   158  
   159  func loadCachedCodec(t reflect.Type) (*codec, map[unsafe.Pointer]*codec) {
   160  	cache, _ := codecCache.Load().(map[unsafe.Pointer]*codec)
   161  	return cache[pointer(t)], cache
   162  }
   163  
   164  func storeCachedCodec(newCache map[unsafe.Pointer]*codec) {
   165  	codecCache.Store(newCache)
   166  }
   167  
   168  func cachedCodecOf(t reflect.Type) *codec {
   169  	c, oldCache := loadCachedCodec(t)
   170  	if c != nil {
   171  		return c
   172  	}
   173  
   174  	var p reflect.Type
   175  	isPtr := t.Kind() == reflect.Ptr
   176  	if isPtr {
   177  		p = t
   178  		t = t.Elem()
   179  	} else {
   180  		p = reflect.PtrTo(t)
   181  	}
   182  
   183  	seen := make(map[reflect.Type]*codec)
   184  	c1 := codecOf(t, seen)
   185  	c2 := codecOf(p, seen)
   186  
   187  	newCache := make(map[unsafe.Pointer]*codec, len(oldCache)+2)
   188  	for p, c := range oldCache {
   189  		newCache[p] = c
   190  	}
   191  
   192  	newCache[pointer(t)] = c1
   193  	newCache[pointer(p)] = c2
   194  	storeCachedCodec(newCache)
   195  
   196  	if isPtr {
   197  		return c2
   198  	} else {
   199  		return c1
   200  	}
   201  }
   202  
   203  func codecOf(t reflect.Type, seen map[reflect.Type]*codec) *codec {
   204  	if c := seen[t]; c != nil {
   205  		return c
   206  	}
   207  
   208  	switch {
   209  	case implements(t, messageType):
   210  		return messageCodecOf(t)
   211  	case implements(t, customMessageType) && !implements(t, protoMessageType):
   212  		return customCodecOf(t)
   213  	}
   214  
   215  	switch t.Kind() {
   216  	case reflect.Bool:
   217  		return &boolCodec
   218  	case reflect.Int:
   219  		return &intCodec
   220  	case reflect.Int32:
   221  		return &int32Codec
   222  	case reflect.Int64:
   223  		return &int64Codec
   224  	case reflect.Uint:
   225  		return &uintCodec
   226  	case reflect.Uint32:
   227  		return &uint32Codec
   228  	case reflect.Uint64:
   229  		return &uint64Codec
   230  	case reflect.Float32:
   231  		return &float32Codec
   232  	case reflect.Float64:
   233  		return &float64Codec
   234  	case reflect.String:
   235  		return &stringCodec
   236  	case reflect.Array:
   237  		elem := t.Elem()
   238  		switch elem.Kind() {
   239  		case reflect.Uint8:
   240  			return byteArrayCodecOf(t, seen)
   241  		}
   242  	case reflect.Slice:
   243  		elem := t.Elem()
   244  		switch elem.Kind() {
   245  		case reflect.Uint8:
   246  			return &bytesCodec
   247  		}
   248  	case reflect.Struct:
   249  		return structCodecOf(t, seen)
   250  	case reflect.Ptr:
   251  		return pointerCodecOf(t, seen)
   252  	}
   253  
   254  	panic("unsupported type: " + t.String())
   255  }
   256  
   257  // backward compatibility with gogoproto custom types.
   258  type customMessage interface {
   259  	Size() int
   260  	MarshalTo([]byte) (int, error)
   261  	Unmarshal([]byte) error
   262  }
   263  
   264  type protoMessage interface {
   265  	ProtoMessage()
   266  }
   267  
   268  var (
   269  	messageType       = reflect.TypeOf((*Message)(nil)).Elem()
   270  	customMessageType = reflect.TypeOf((*customMessage)(nil)).Elem()
   271  	protoMessageType  = reflect.TypeOf((*protoMessage)(nil)).Elem()
   272  )
   273  
   274  func implements(t, iface reflect.Type) bool {
   275  	return t.Implements(iface) || reflect.PtrTo(t).Implements(iface)
   276  }