github.com/RomiChan/protobuf@v0.1.1-0.20230204044148-2ed269a2e54d/proto/walker.go (about) 1 package proto 2 3 import ( 4 "reflect" 5 "unsafe" 6 ) 7 8 var ( 9 optionBoolType = reflect.TypeOf((*Option[bool])(nil)).Elem() 10 optionInt32Type = reflect.TypeOf((*Option[int32])(nil)).Elem() 11 optionInt64Type = reflect.TypeOf((*Option[int64])(nil)).Elem() 12 optionUInt32Type = reflect.TypeOf((*Option[uint32])(nil)).Elem() 13 optionUInt64Type = reflect.TypeOf((*Option[uint64])(nil)).Elem() 14 optionFloat32Type = reflect.TypeOf((*Option[float32])(nil)).Elem() 15 optionFloat64Type = reflect.TypeOf((*Option[float64])(nil)).Elem() 16 optionStringType = reflect.TypeOf((*Option[string])(nil)).Elem() 17 ) 18 19 type walker struct { 20 codecs map[reflect.Type]*codec 21 infos map[reflect.Type]*structInfo 22 } 23 24 type walkerConfig struct { 25 zigzag bool 26 required bool 27 } 28 29 func (w *walker) codec(t reflect.Type, conf *walkerConfig) *codec { 30 if c, ok := w.codecs[t]; ok { 31 return c 32 } 33 if conf.required { 34 return w.required(t, conf) 35 } 36 switch t.Kind() { 37 case reflect.Bool: 38 return &boolCodec 39 case reflect.Int32: 40 if conf.zigzag { 41 return &zigzag32Codec 42 } 43 return &int32Codec 44 case reflect.Int64: 45 if conf.zigzag { 46 return &zigzag64Codec 47 } 48 return &int64Codec 49 case reflect.Uint32: 50 return &uint32Codec 51 case reflect.Uint64: 52 return &uint64Codec 53 case reflect.Float32: 54 return &float32Codec 55 case reflect.Float64: 56 return &float64Codec 57 case reflect.String: 58 if conf.required { 59 return &stringRequiredCodec 60 } 61 return &stringCodec 62 case reflect.Slice: 63 elem := t.Elem() 64 switch elem.Kind() { 65 case reflect.Uint8: 66 return &bytesCodec 67 } 68 case reflect.Struct: 69 return w.structCodec(t) 70 case reflect.Ptr: 71 return w.pointer(t, conf) 72 } 73 74 panic("unsupported type: " + t.String()) 75 } 76 77 func (w *walker) structCodec(t reflect.Type) *codec { 78 if c, ok := codecCache.Load(pointer(t)); ok { 79 return c.(*codec) 80 } 81 if c, ok := w.codecs[t]; ok { 82 return c 83 } 84 c := new(codec) 85 w.codecs[t] = c 86 elem := t.Elem() 87 info := w.structInfo(elem) 88 c.size = func(p unsafe.Pointer, f *structField) int { 89 p = deref(p) 90 if p != nil { 91 n := info.size(p) 92 n += sizeOfVarint(uint64(n)) + f.tagsize 93 return n 94 } 95 return 0 96 } 97 c.encode = func(b []byte, p unsafe.Pointer, f *structField) []byte { 98 p = deref(p) 99 if p != nil { 100 b = appendVarint(b, f.wiretag) 101 n := info.size(p) 102 b = appendVarint(b, uint64(n)) 103 return info.encode(b, p) 104 } 105 return b 106 } 107 c.decode = func(b []byte, p unsafe.Pointer) (int, error) { 108 v := (*unsafe.Pointer)(p) 109 if *v == nil { 110 *v = unsafe.Pointer(reflect.New(elem).Pointer()) 111 } 112 _, n, err := decodeVarint(b) 113 if err != nil { 114 return n, err 115 } 116 l, err := info.decode(b[n:], *v) 117 return n + l, err 118 } 119 actualCodec, _ := codecCache.LoadOrStore(pointer(t), c) 120 return actualCodec.(*codec) 121 } 122 123 func baseKindOf(t reflect.Type) reflect.Kind { 124 return baseTypeOf(t).Kind() 125 } 126 127 func baseTypeOf(t reflect.Type) reflect.Type { 128 for t.Kind() == reflect.Ptr { 129 t = t.Elem() 130 } 131 return t 132 } 133 134 func (w *walker) structInfo(t reflect.Type) *structInfo { 135 if info, ok := structInfoCache.Load(pointer(t)); ok { 136 return info 137 } 138 if i, ok := w.infos[t]; ok { 139 return i 140 } 141 142 info := new(structInfo) 143 w.infos[t] = info 144 numField := t.NumField() 145 fields := make([]*structField, 0, numField) 146 for i := 0; i < numField; i++ { 147 f := t.Field(i) 148 if f.PkgPath != "" { 149 continue // unexported 150 } 151 152 tag, ok := f.Tag.Lookup("protobuf") 153 if !ok { 154 continue // no tag 155 } 156 157 field := structField{ 158 offset: f.Offset, 159 } 160 161 t, err := parseStructTag(tag) 162 if err != nil { 163 panic(err) 164 } 165 field.wiretag = uint64(t.fieldNumber)<<3 | uint64(t.wireType) 166 switch t.wireType { 167 case fixed32: 168 switch f.Type { 169 case optionFloat32Type: 170 field.codec = &float32OptionCodec 171 case optionUInt32Type: 172 field.codec = &fixed32OptionCodec 173 } 174 switch baseKindOf(f.Type) { 175 case reflect.Uint32: 176 field.codec = &fixed32Codec 177 case reflect.Float32: 178 field.codec = &float32Codec 179 } 180 case fixed64: 181 switch f.Type { 182 case optionUInt64Type: 183 field.codec = &fixed64OptionCodec 184 case optionFloat64Type: 185 field.codec = &float64OptionCodec 186 } 187 switch baseKindOf(f.Type) { 188 case reflect.Uint64: 189 field.codec = &fixed64Codec 190 case reflect.Float64: 191 field.codec = &float64Codec 192 } 193 } 194 if field.codec == nil { 195 switch f.Type { 196 case optionBoolType: 197 field.codec = &boolOptionCodec 198 case optionInt32Type: 199 field.codec = &int32OptionCodec 200 if t.zigzag { 201 field.codec = &zigzag32OptionCodec 202 } 203 case optionInt64Type: 204 field.codec = &int64OptionCodec 205 if t.zigzag { 206 field.codec = &zigzag64OptionCodec 207 } 208 case optionUInt32Type: 209 field.codec = &uint32OptionCodec 210 case optionUInt64Type: 211 field.codec = &uint64OptionCodec 212 case optionStringType: 213 field.codec = &stringOptionCodec 214 } 215 } 216 if field.codec == nil { 217 conf := &walkerConfig{ 218 zigzag: t.zigzag, 219 // required: t.required, 220 } 221 switch baseKindOf(f.Type) { 222 case reflect.Struct: 223 field.codec = w.codec(f.Type, conf) 224 225 case reflect.Slice: 226 elem := f.Type.Elem() 227 if elem.Kind() == reflect.Uint8 { // []byte 228 field.codec = &bytesCodec 229 } else { 230 conf.required = true 231 field.codec = w.codec(elem, conf) 232 field.codec = sliceCodecOf(f.Type, field.codec, w) 233 } 234 235 case reflect.Map: 236 conf.required = true // map key and val should be encoded always 237 key, val := f.Type.Key(), f.Type.Elem() 238 m := &mapField{wiretag: field.wiretag} 239 240 t, _ := parseStructTag(f.Tag.Get("protobuf_key")) 241 keyField := &structField{wiretag: uint64(t.fieldNumber)<<3 | uint64(t.wireType)} 242 keyField.tagsize = sizeOfVarint(keyField.wiretag) 243 conf.zigzag = t.zigzag 244 keyField.codec = w.codec(key, conf) 245 246 t, _ = parseStructTag(f.Tag.Get("protobuf_val")) 247 valFiled := &structField{wiretag: uint64(t.fieldNumber)<<3 | uint64(t.wireType)} 248 valFiled.tagsize = sizeOfVarint(valFiled.wiretag) 249 conf.zigzag = t.zigzag 250 valFiled.codec = w.codec(val, conf) 251 252 m.keyField = keyField 253 m.valField = valFiled 254 field.codec = w.mapCodec(f.Type, m) 255 256 default: 257 field.codec = w.codec(f.Type, conf) 258 } 259 } 260 field.tagsize = sizeOfVarint(field.wiretag) 261 fields = append(fields, &field) 262 } 263 264 // copy to save capacity 265 fields2 := make([]*structField, len(fields)) 266 copy(fields2, fields) 267 info.fields = fields2 268 269 info.fieldIndex = make(map[fieldNumber]*structField, len(info.fields)) 270 for _, f := range info.fields { 271 info.fieldIndex[f.fieldNumber()] = f 272 } 273 274 structInfoCache.Store(pointer(t), info) 275 return info 276 } 277 278 // @@@ Pointers @@@ 279 280 func deref(p unsafe.Pointer) unsafe.Pointer { 281 return *(*unsafe.Pointer)(p) 282 } 283 284 func (w *walker) pointer(t reflect.Type, conf *walkerConfig) *codec { 285 switch t.Elem().Kind() { 286 case reflect.Struct: 287 return w.structCodec(t) 288 } 289 // common value 290 p := new(codec) 291 w.codecs[t] = p 292 c := w.codec(t.Elem(), conf) 293 p.size = pointerSizeFuncOf(t, c) 294 p.encode = pointerEncodeFuncOf(t, c) 295 p.decode = pointerDecodeFuncOf(t, c) 296 return p 297 } 298 299 func (w *walker) required(t reflect.Type, conf *walkerConfig) *codec { 300 if c, ok := w.codecs[t]; ok { 301 return c 302 } 303 304 switch t.Kind() { 305 case reflect.Bool: 306 return &boolRequiredCodec 307 case reflect.Int32: 308 if conf.zigzag { 309 return &zigzag32RequiredCodec 310 } 311 return &int32RequiredCodec 312 case reflect.Int64: 313 if conf.zigzag { 314 return &zigzag64RequiredCodec 315 } 316 return &int64RequiredCodec 317 case reflect.Uint32: 318 return &uint32RequiredCodec 319 case reflect.Uint64: 320 return &uint64RequiredCodec 321 case reflect.Float32: 322 return &float32RequiredCodec 323 case reflect.Float64: 324 return &float64RequiredCodec 325 case reflect.String: 326 return &stringRequiredCodec 327 case reflect.Slice: 328 elem := t.Elem() 329 switch elem.Kind() { 330 case reflect.Uint8: 331 return &bytesCodec 332 } 333 case reflect.Struct: 334 panic("nested message must be pointer:" + t.String()) 335 case reflect.Ptr: 336 return w.pointer(t, conf) 337 } 338 339 panic("unsupported type: " + t.String()) 340 } 341 342 func pointerSizeFuncOf(_ reflect.Type, c *codec) sizeFunc { 343 return func(p unsafe.Pointer, f *structField) int { 344 p = deref(p) 345 if p != nil { 346 return c.size(p, f) 347 } 348 return 0 349 } 350 } 351 352 func pointerEncodeFuncOf(_ reflect.Type, c *codec) encodeFunc { 353 return func(b []byte, p unsafe.Pointer, f *structField) []byte { 354 p = deref(p) 355 if p != nil { 356 return c.encode(b, p, f) 357 } 358 return b 359 } 360 } 361 362 func pointerDecodeFuncOf(t reflect.Type, c *codec) decodeFunc { 363 t = t.Elem() 364 return func(b []byte, p unsafe.Pointer) (int, error) { 365 v := (*unsafe.Pointer)(p) 366 if *v == nil { 367 *v = unsafe.Pointer(reflect.New(t).Pointer()) 368 } 369 return c.decode(b, *v) 370 } 371 }