github.com/RomiChan/protobuf@v0.1.1-0.20230204044148-2ed269a2e54d/proto/map.go (about) 1 package proto 2 3 import ( 4 "fmt" 5 "reflect" 6 "sync" 7 "unsafe" 8 9 . "github.com/RomiChan/protobuf/internal/runtime_reflect" 10 ) 11 12 const zeroSize = 1 // sizeOfVarint(0) 13 14 type mapField struct { 15 wiretag uint64 16 keyField *structField 17 valField *structField 18 } 19 20 func (w *walker) mapCodec(t reflect.Type, f *mapField) *codec { 21 m := new(codec) 22 w.codecs[t] = m 23 24 m.size = mapSizeFuncOf(t, f) 25 m.encode = mapEncodeFuncOf(t, f) 26 m.decode = mapDecodeFuncOf(t, f, w) 27 return m 28 } 29 30 func mapSizeFuncOf(t reflect.Type, f *mapField) sizeFunc { 31 mapTagSize := sizeOfVarint(f.wiretag) 32 keyCodec := f.keyField.codec 33 valCodec := f.valField.codec 34 35 return func(p unsafe.Pointer, sf *structField) int { 36 if p == nil { 37 return 0 38 } 39 40 p = *(*unsafe.Pointer)(p) 41 42 n := 0 43 m := MapIter{} 44 defer m.Done() 45 46 for m.Init(pointer(t), p); m.HasNext(); m.Next() { 47 keySize := keyCodec.size(m.Key(), f.keyField) 48 valSize := valCodec.size(m.Value(), f.valField) 49 n += mapTagSize + sizeOfVarint(uint64(keySize+valSize)) + keySize + valSize 50 } 51 if n == 0 { 52 n = mapTagSize + zeroSize 53 } 54 return n 55 } 56 } 57 58 func mapEncodeFuncOf(t reflect.Type, f *mapField) encodeFunc { 59 mapTag := appendVarint(nil, f.wiretag) 60 zero := append(mapTag, 0) 61 keyCodec := f.keyField.codec 62 valCodec := f.valField.codec 63 64 return func(b []byte, p unsafe.Pointer, sf *structField) []byte { 65 if p == nil { 66 return b 67 } 68 p = *(*unsafe.Pointer)(p) 69 70 origLen := len(b) 71 72 m := MapIter{} 73 defer m.Done() 74 75 for m.Init(pointer(t), p); m.HasNext(); m.Next() { 76 key := m.Key() 77 val := m.Value() 78 79 keySize := keyCodec.size(key, f.keyField) 80 valSize := keyCodec.size(val, f.valField) 81 elemSize := keySize + valSize 82 83 b = append(b, mapTag...) 84 b = appendVarint(b, uint64(elemSize)) 85 b = keyCodec.encode(b, key, f.keyField) 86 b = valCodec.encode(b, val, f.valField) 87 } 88 89 if len(b) == origLen { 90 b = append(b, zero...) 91 } 92 return b 93 } 94 } 95 96 func formatWireTag(wire uint64) reflect.StructTag { 97 return reflect.StructTag(fmt.Sprintf(`protobuf:"%s,%d,opt"`, wireType(wire&7), wire>>3)) 98 } 99 100 func mapDecodeFuncOf(t reflect.Type, m *mapField, w *walker) decodeFunc { 101 structType := reflect.StructOf([]reflect.StructField{ 102 {Name: "Key", Type: t.Key(), Tag: formatWireTag(m.keyField.wiretag)}, 103 {Name: "Elem", Type: t.Elem(), Tag: formatWireTag(m.valField.wiretag)}, 104 }) 105 106 info := w.structInfo(structType) 107 structPool := new(sync.Pool) 108 structZero := pointer(reflect.Zero(structType).Interface()) 109 110 valueType := t.Elem() 111 valueOffset := structType.Field(1).Offset 112 113 mtype := pointer(t) 114 stype := pointer(structType) 115 vtype := pointer(valueType) 116 117 return func(b []byte, p unsafe.Pointer) (int, error) { 118 m := (*unsafe.Pointer)(p) 119 if *m == nil { 120 *m = MakeMap(mtype, 10) 121 } 122 if len(b) == 0 { 123 return 0, nil 124 } 125 126 s := pointer(structPool.Get()) 127 if s == nil { 128 s = unsafe.Pointer(reflect.New(structType).Pointer()) 129 } 130 131 _, nl, err := decodeVarint(b) 132 if err != nil { 133 return 0, err 134 } 135 n, err := info.decode(b[nl:], s) 136 if err == nil { 137 v := MapAssign(mtype, *m, s) 138 Assign(vtype, v, unsafe.Pointer(uintptr(s)+valueOffset)) 139 } 140 Assign(stype, s, structZero) 141 structPool.Put(s) 142 return n + nl, err 143 } 144 }