github.com/segmentio/encoding@v0.4.0/proto/map.go (about) 1 package proto 2 3 import ( 4 "io" 5 "reflect" 6 "sync" 7 "unsafe" 8 9 . "github.com/segmentio/encoding/internal/runtime_reflect" 10 ) 11 12 const ( 13 zeroSize = 1 // sizeOfVarint(0) 14 ) 15 16 type mapField struct { 17 number uint16 18 keyFlags uint8 19 valFlags uint8 20 keyCodec *codec 21 valCodec *codec 22 } 23 24 func mapCodecOf(t reflect.Type, f *mapField, seen map[reflect.Type]*codec) *codec { 25 m := new(codec) 26 seen[t] = m 27 28 m.wire = varlen 29 m.size = mapSizeFuncOf(t, f) 30 m.encode = mapEncodeFuncOf(t, f) 31 m.decode = mapDecodeFuncOf(t, f, seen) 32 return m 33 } 34 35 func mapSizeFuncOf(t reflect.Type, f *mapField) sizeFunc { 36 mapTagSize := sizeOfTag(fieldNumber(f.number), varlen) 37 keyTagSize := sizeOfTag(1, wireType(f.keyCodec.wire)) 38 valTagSize := sizeOfTag(2, wireType(f.valCodec.wire)) 39 return func(p unsafe.Pointer, flags flags) int { 40 if p == nil { 41 return 0 42 } 43 44 if !flags.has(inline) { 45 p = *(*unsafe.Pointer)(p) 46 } 47 48 n := 0 49 m := MapIter{} 50 defer m.Done() 51 52 for m.Init(pointer(t), p); m.HasNext(); m.Next() { 53 keySize := f.keyCodec.size(m.Key(), wantzero) 54 valSize := f.valCodec.size(m.Value(), wantzero) 55 56 if keySize > 0 { 57 n += keyTagSize + keySize 58 if (f.keyFlags & embedded) != 0 { 59 n += sizeOfVarint(uint64(keySize)) 60 } 61 } 62 63 if valSize > 0 { 64 n += valTagSize + valSize 65 if (f.valFlags & embedded) != 0 { 66 n += sizeOfVarint(uint64(valSize)) 67 } 68 } 69 70 n += mapTagSize + sizeOfVarint(uint64(keySize+valSize)) 71 } 72 73 if n == 0 { 74 n = mapTagSize + zeroSize 75 } 76 77 return n 78 } 79 } 80 81 func mapEncodeFuncOf(t reflect.Type, f *mapField) encodeFunc { 82 keyTag := [1]byte{} 83 valTag := [1]byte{} 84 encodeTag(keyTag[:], 1, f.keyCodec.wire) 85 encodeTag(valTag[:], 2, f.valCodec.wire) 86 87 number := fieldNumber(f.number) 88 mapTag := make([]byte, sizeOfTag(number, varlen)+zeroSize) 89 encodeTag(mapTag, number, varlen) 90 91 zero := mapTag 92 mapTag = mapTag[:len(mapTag)-1] 93 94 return func(b []byte, p unsafe.Pointer, flags flags) (int, error) { 95 if p == nil { 96 return 0, nil 97 } 98 99 if !flags.has(inline) { 100 p = *(*unsafe.Pointer)(p) 101 } 102 103 offset := 0 104 m := MapIter{} 105 defer m.Done() 106 107 for m.Init(pointer(t), p); m.HasNext(); m.Next() { 108 key := m.Key() 109 val := m.Value() 110 111 keySize := f.keyCodec.size(key, wantzero) 112 valSize := f.valCodec.size(val, wantzero) 113 elemSize := keySize + valSize 114 115 if keySize > 0 { 116 elemSize += len(keyTag) 117 if (f.keyFlags & embedded) != 0 { 118 elemSize += sizeOfVarint(uint64(keySize)) 119 } 120 } 121 122 if valSize > 0 { 123 elemSize += len(valTag) 124 if (f.valFlags & embedded) != 0 { 125 elemSize += sizeOfVarint(uint64(valSize)) 126 } 127 } 128 129 n := copy(b[offset:], mapTag) 130 offset += n 131 if n < len(mapTag) { 132 return offset, io.ErrShortBuffer 133 } 134 n, err := encodeVarint(b[offset:], uint64(elemSize)) 135 offset += n 136 if err != nil { 137 return offset, err 138 } 139 140 if keySize > 0 { 141 n := copy(b[offset:], keyTag[:]) 142 offset += n 143 if n < len(keyTag) { 144 return offset, io.ErrShortBuffer 145 } 146 147 if (f.keyFlags & embedded) != 0 { 148 n, err := encodeVarint(b[offset:], uint64(keySize)) 149 offset += n 150 if err != nil { 151 return offset, err 152 } 153 } 154 155 if (len(b) - offset) < keySize { 156 return len(b), io.ErrShortBuffer 157 } 158 159 n, err := f.keyCodec.encode(b[offset:offset+keySize], key, wantzero) 160 offset += n 161 if err != nil { 162 return offset, err 163 } 164 } 165 166 if valSize > 0 { 167 n := copy(b[offset:], valTag[:]) 168 offset += n 169 if n < len(valTag) { 170 return n, io.ErrShortBuffer 171 } 172 173 if (f.valFlags & embedded) != 0 { 174 n, err := encodeVarint(b[offset:], uint64(valSize)) 175 offset += n 176 if err != nil { 177 return offset, err 178 } 179 } 180 181 if (len(b) - offset) < valSize { 182 return len(b), io.ErrShortBuffer 183 } 184 185 n, err := f.valCodec.encode(b[offset:offset+valSize], val, wantzero) 186 offset += n 187 if err != nil { 188 return offset, err 189 } 190 } 191 } 192 193 if offset == 0 { 194 if offset = copy(b, zero); offset < len(zero) { 195 return offset, io.ErrShortBuffer 196 } 197 } 198 199 return offset, nil 200 } 201 } 202 203 func mapDecodeFuncOf(t reflect.Type, f *mapField, seen map[reflect.Type]*codec) decodeFunc { 204 structType := reflect.StructOf([]reflect.StructField{ 205 {Name: "Key", Type: t.Key()}, 206 {Name: "Elem", Type: t.Elem()}, 207 }) 208 209 structCodec := codecOf(structType, seen) 210 structPool := new(sync.Pool) 211 structZero := pointer(reflect.Zero(structType).Interface()) 212 213 valueType := t.Elem() 214 valueOffset := structType.Field(1).Offset 215 216 mtype := pointer(t) 217 stype := pointer(structType) 218 vtype := pointer(valueType) 219 220 return func(b []byte, p unsafe.Pointer, _ flags) (int, error) { 221 m := (*unsafe.Pointer)(p) 222 if *m == nil { 223 *m = MakeMap(mtype, 10) 224 } 225 if len(b) == 0 { 226 return 0, nil 227 } 228 229 s := pointer(structPool.Get()) 230 if s == nil { 231 s = unsafe.Pointer(reflect.New(structType).Pointer()) 232 } 233 234 n, err := structCodec.decode(b, s, noflags) 235 if err == nil { 236 v := MapAssign(mtype, *m, s) 237 Assign(vtype, v, unsafe.Pointer(uintptr(s)+valueOffset)) 238 } 239 240 Assign(stype, s, structZero) 241 structPool.Put(s) 242 return n, err 243 } 244 }