github.com/RomiChan/protobuf@v0.1.1-0.20230204044148-2ed269a2e54d/proto/proto.go (about)

     1  package proto
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"sync"
     7  	"unsafe"
     8  
     9  	"github.com/RomiChan/syncx"
    10  )
    11  
    12  //go:generate go run ./gen/option
    13  //go:generate go run ./gen/required
    14  
    15  func Size(v interface{}) int {
    16  	t, p := inspect(v)
    17  	if t.Kind() != reflect.Ptr {
    18  		panic(fmt.Errorf("proto.Marshal(%T): not a pointer", v))
    19  	}
    20  	t = t.Elem()
    21  	info := cachedStructInfoOf(t)
    22  	return info.size(p)
    23  }
    24  
    25  func Marshal(v interface{}) ([]byte, error) {
    26  	t, p := inspect(v)
    27  	if t.Kind() != reflect.Ptr {
    28  		return nil, fmt.Errorf("proto.Marshal(%T): not a pointer", v)
    29  	}
    30  	t = t.Elem()
    31  	info := cachedStructInfoOf(t)
    32  	b := make([]byte, 0, info.size(p))
    33  	b = info.encode(b, p)
    34  	return b, nil
    35  }
    36  
    37  func Unmarshal(b []byte, v interface{}) error {
    38  	if len(b) == 0 {
    39  		// nothing to do
    40  		return nil
    41  	}
    42  
    43  	t, p := inspect(v)
    44  	if t.Kind() != reflect.Pointer || p == nil {
    45  		return &InvalidUnmarshalError{Type: t}
    46  	}
    47  	elem := t.Elem()
    48  	if elem.Kind() != reflect.Struct {
    49  		return &InvalidUnmarshalError{Type: t}
    50  	}
    51  	c := cachedStructInfoOf(elem)
    52  
    53  	n, err := c.decode(b, p)
    54  	if err != nil {
    55  		return err
    56  	}
    57  	if n < len(b) {
    58  		return fmt.Errorf("proto.Unmarshal(%T): read=%d < buffer=%d", v, n, len(b))
    59  	}
    60  	return nil
    61  }
    62  
// iface is the two-word view that pointer overlays on an interface{} value:
// pointer casts the address of the interface to *iface and reads ptr. The
// field order and sizes therefore have to line up with how the Go runtime
// stores an interface value (type word first, data word second) and must not
// be changed.
type iface struct {
	typ unsafe.Pointer // first word of the interface value (not read in this file's visible code)
	ptr unsafe.Pointer // second word of the interface value; returned by pointer
}
    67  
    68  func inspect(v interface{}) (reflect.Type, unsafe.Pointer) {
    69  	return reflect.TypeOf(v), pointer(v)
    70  }
    71  
    72  func pointer(v interface{}) unsafe.Pointer {
    73  	return (*iface)(unsafe.Pointer(&v)).ptr
    74  }
    75  
// fieldNumber is an unsigned integer type; presumably it carries a protobuf
// field number — it is not used in the visible part of this file, so confirm
// against its callers.
type fieldNumber uint
    77  
    78  type wireType uint
    79  
    80  const (
    81  	varint  wireType = 0
    82  	fixed64 wireType = 1
    83  	varlen  wireType = 2
    84  	fixed32 wireType = 5
    85  )
    86  
    87  func (wt wireType) String() string {
    88  	switch wt {
    89  	case varint:
    90  		return "varint"
    91  	case varlen:
    92  		return "bytes"
    93  	case fixed32:
    94  		return "fixed32"
    95  	case fixed64:
    96  		return "fixed64"
    97  	default:
    98  		return "unknown"
    99  	}
   100  }
   101  
// codec bundles the three function values used to handle one kind of value:
// one to size it, one to encode it and one to decode it. The function types
// are declared elsewhere in the package.
type codec struct {
	size   sizeFunc   // reports a size for a value (see sizeFunc for the exact contract)
	encode encodeFunc // writes a value's encoding (see encodeFunc)
	decode decodeFunc // reads a value from its encoding (see decodeFunc)
}
   107  
   108  var structInfoCache syncx.Map[unsafe.Pointer, *structInfo] // map[unsafe.Pointer]*structInfo
   109  var codecCache sync.Map                                    // map[reflect.Type]codec
   110  
   111  func cachedStructInfoOf(t reflect.Type) *structInfo {
   112  	c, ok := structInfoCache.Load(pointer(t))
   113  	if ok {
   114  		return c
   115  	}
   116  
   117  	w := &walker{
   118  		codecs: make(map[reflect.Type]*codec),
   119  		infos:  make(map[reflect.Type]*structInfo),
   120  	}
   121  
   122  	info := w.structInfo(t)
   123  	actual, _ := structInfoCache.LoadOrStore(pointer(t), info)
   124  	return actual
   125  }