github.com/RomiChan/protobuf@v0.1.1-0.20230204044148-2ed269a2e54d/proto/map.go (about)

     1  package proto
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"sync"
     7  	"unsafe"
     8  
     9  	. "github.com/RomiChan/protobuf/internal/runtime_reflect"
    10  )
    11  
// zeroSize is the encoded size, in bytes, of the varint value 0. It is added
// to the tag size when a map with no entries is sized.
const zeroSize = 1 // sizeOfVarint(0)
    13  
    14  type mapField struct {
    15  	wiretag  uint64
    16  	keyField *structField
    17  	valField *structField
    18  }
    19  
    20  func (w *walker) mapCodec(t reflect.Type, f *mapField) *codec {
    21  	m := new(codec)
    22  	w.codecs[t] = m
    23  
    24  	m.size = mapSizeFuncOf(t, f)
    25  	m.encode = mapEncodeFuncOf(t, f)
    26  	m.decode = mapDecodeFuncOf(t, f, w)
    27  	return m
    28  }
    29  
    30  func mapSizeFuncOf(t reflect.Type, f *mapField) sizeFunc {
    31  	mapTagSize := sizeOfVarint(f.wiretag)
    32  	keyCodec := f.keyField.codec
    33  	valCodec := f.valField.codec
    34  
    35  	return func(p unsafe.Pointer, sf *structField) int {
    36  		if p == nil {
    37  			return 0
    38  		}
    39  
    40  		p = *(*unsafe.Pointer)(p)
    41  
    42  		n := 0
    43  		m := MapIter{}
    44  		defer m.Done()
    45  
    46  		for m.Init(pointer(t), p); m.HasNext(); m.Next() {
    47  			keySize := keyCodec.size(m.Key(), f.keyField)
    48  			valSize := valCodec.size(m.Value(), f.valField)
    49  			n += mapTagSize + sizeOfVarint(uint64(keySize+valSize)) + keySize + valSize
    50  		}
    51  		if n == 0 {
    52  			n = mapTagSize + zeroSize
    53  		}
    54  		return n
    55  	}
    56  }
    57  
    58  func mapEncodeFuncOf(t reflect.Type, f *mapField) encodeFunc {
    59  	mapTag := appendVarint(nil, f.wiretag)
    60  	zero := append(mapTag, 0)
    61  	keyCodec := f.keyField.codec
    62  	valCodec := f.valField.codec
    63  
    64  	return func(b []byte, p unsafe.Pointer, sf *structField) []byte {
    65  		if p == nil {
    66  			return b
    67  		}
    68  		p = *(*unsafe.Pointer)(p)
    69  
    70  		origLen := len(b)
    71  
    72  		m := MapIter{}
    73  		defer m.Done()
    74  
    75  		for m.Init(pointer(t), p); m.HasNext(); m.Next() {
    76  			key := m.Key()
    77  			val := m.Value()
    78  
    79  			keySize := keyCodec.size(key, f.keyField)
    80  			valSize := keyCodec.size(val, f.valField)
    81  			elemSize := keySize + valSize
    82  
    83  			b = append(b, mapTag...)
    84  			b = appendVarint(b, uint64(elemSize))
    85  			b = keyCodec.encode(b, key, f.keyField)
    86  			b = valCodec.encode(b, val, f.valField)
    87  		}
    88  
    89  		if len(b) == origLen {
    90  			b = append(b, zero...)
    91  		}
    92  		return b
    93  	}
    94  }
    95  
    96  func formatWireTag(wire uint64) reflect.StructTag {
    97  	return reflect.StructTag(fmt.Sprintf(`protobuf:"%s,%d,opt"`, wireType(wire&7), wire>>3))
    98  }
    99  
   100  func mapDecodeFuncOf(t reflect.Type, m *mapField, w *walker) decodeFunc {
   101  	structType := reflect.StructOf([]reflect.StructField{
   102  		{Name: "Key", Type: t.Key(), Tag: formatWireTag(m.keyField.wiretag)},
   103  		{Name: "Elem", Type: t.Elem(), Tag: formatWireTag(m.valField.wiretag)},
   104  	})
   105  
   106  	info := w.structInfo(structType)
   107  	structPool := new(sync.Pool)
   108  	structZero := pointer(reflect.Zero(structType).Interface())
   109  
   110  	valueType := t.Elem()
   111  	valueOffset := structType.Field(1).Offset
   112  
   113  	mtype := pointer(t)
   114  	stype := pointer(structType)
   115  	vtype := pointer(valueType)
   116  
   117  	return func(b []byte, p unsafe.Pointer) (int, error) {
   118  		m := (*unsafe.Pointer)(p)
   119  		if *m == nil {
   120  			*m = MakeMap(mtype, 10)
   121  		}
   122  		if len(b) == 0 {
   123  			return 0, nil
   124  		}
   125  
   126  		s := pointer(structPool.Get())
   127  		if s == nil {
   128  			s = unsafe.Pointer(reflect.New(structType).Pointer())
   129  		}
   130  
   131  		_, nl, err := decodeVarint(b)
   132  		if err != nil {
   133  			return 0, err
   134  		}
   135  		n, err := info.decode(b[nl:], s)
   136  		if err == nil {
   137  			v := MapAssign(mtype, *m, s)
   138  			Assign(vtype, v, unsafe.Pointer(uintptr(s)+valueOffset))
   139  		}
   140  		Assign(stype, s, structZero)
   141  		structPool.Put(s)
   142  		return n + nl, err
   143  	}
   144  }