github.com/RomiChan/protobuf@v0.1.1-0.20230204044148-2ed269a2e54d/proto/walker.go (about)

     1  package proto
     2  
     3  import (
     4  	"reflect"
     5  	"unsafe"
     6  )
     7  
     8  var (
     9  	optionBoolType    = reflect.TypeOf((*Option[bool])(nil)).Elem()
    10  	optionInt32Type   = reflect.TypeOf((*Option[int32])(nil)).Elem()
    11  	optionInt64Type   = reflect.TypeOf((*Option[int64])(nil)).Elem()
    12  	optionUInt32Type  = reflect.TypeOf((*Option[uint32])(nil)).Elem()
    13  	optionUInt64Type  = reflect.TypeOf((*Option[uint64])(nil)).Elem()
    14  	optionFloat32Type = reflect.TypeOf((*Option[float32])(nil)).Elem()
    15  	optionFloat64Type = reflect.TypeOf((*Option[float64])(nil)).Elem()
    16  	optionStringType  = reflect.TypeOf((*Option[string])(nil)).Elem()
    17  )
    18  
    19  type walker struct {
    20  	codecs map[reflect.Type]*codec
    21  	infos  map[reflect.Type]*structInfo
    22  }
    23  
    24  type walkerConfig struct {
    25  	zigzag   bool
    26  	required bool
    27  }
    28  
    29  func (w *walker) codec(t reflect.Type, conf *walkerConfig) *codec {
    30  	if c, ok := w.codecs[t]; ok {
    31  		return c
    32  	}
    33  	if conf.required {
    34  		return w.required(t, conf)
    35  	}
    36  	switch t.Kind() {
    37  	case reflect.Bool:
    38  		return &boolCodec
    39  	case reflect.Int32:
    40  		if conf.zigzag {
    41  			return &zigzag32Codec
    42  		}
    43  		return &int32Codec
    44  	case reflect.Int64:
    45  		if conf.zigzag {
    46  			return &zigzag64Codec
    47  		}
    48  		return &int64Codec
    49  	case reflect.Uint32:
    50  		return &uint32Codec
    51  	case reflect.Uint64:
    52  		return &uint64Codec
    53  	case reflect.Float32:
    54  		return &float32Codec
    55  	case reflect.Float64:
    56  		return &float64Codec
    57  	case reflect.String:
    58  		if conf.required {
    59  			return &stringRequiredCodec
    60  		}
    61  		return &stringCodec
    62  	case reflect.Slice:
    63  		elem := t.Elem()
    64  		switch elem.Kind() {
    65  		case reflect.Uint8:
    66  			return &bytesCodec
    67  		}
    68  	case reflect.Struct:
    69  		return w.structCodec(t)
    70  	case reflect.Ptr:
    71  		return w.pointer(t, conf)
    72  	}
    73  
    74  	panic("unsupported type: " + t.String())
    75  }
    76  
    77  func (w *walker) structCodec(t reflect.Type) *codec {
    78  	if c, ok := codecCache.Load(pointer(t)); ok {
    79  		return c.(*codec)
    80  	}
    81  	if c, ok := w.codecs[t]; ok {
    82  		return c
    83  	}
    84  	c := new(codec)
    85  	w.codecs[t] = c
    86  	elem := t.Elem()
    87  	info := w.structInfo(elem)
    88  	c.size = func(p unsafe.Pointer, f *structField) int {
    89  		p = deref(p)
    90  		if p != nil {
    91  			n := info.size(p)
    92  			n += sizeOfVarint(uint64(n)) + f.tagsize
    93  			return n
    94  		}
    95  		return 0
    96  	}
    97  	c.encode = func(b []byte, p unsafe.Pointer, f *structField) []byte {
    98  		p = deref(p)
    99  		if p != nil {
   100  			b = appendVarint(b, f.wiretag)
   101  			n := info.size(p)
   102  			b = appendVarint(b, uint64(n))
   103  			return info.encode(b, p)
   104  		}
   105  		return b
   106  	}
   107  	c.decode = func(b []byte, p unsafe.Pointer) (int, error) {
   108  		v := (*unsafe.Pointer)(p)
   109  		if *v == nil {
   110  			*v = unsafe.Pointer(reflect.New(elem).Pointer())
   111  		}
   112  		_, n, err := decodeVarint(b)
   113  		if err != nil {
   114  			return n, err
   115  		}
   116  		l, err := info.decode(b[n:], *v)
   117  		return n + l, err
   118  	}
   119  	actualCodec, _ := codecCache.LoadOrStore(pointer(t), c)
   120  	return actualCodec.(*codec)
   121  }
   122  
   123  func baseKindOf(t reflect.Type) reflect.Kind {
   124  	return baseTypeOf(t).Kind()
   125  }
   126  
   127  func baseTypeOf(t reflect.Type) reflect.Type {
   128  	for t.Kind() == reflect.Ptr {
   129  		t = t.Elem()
   130  	}
   131  	return t
   132  }
   133  
   134  func (w *walker) structInfo(t reflect.Type) *structInfo {
   135  	if info, ok := structInfoCache.Load(pointer(t)); ok {
   136  		return info
   137  	}
   138  	if i, ok := w.infos[t]; ok {
   139  		return i
   140  	}
   141  
   142  	info := new(structInfo)
   143  	w.infos[t] = info
   144  	numField := t.NumField()
   145  	fields := make([]*structField, 0, numField)
   146  	for i := 0; i < numField; i++ {
   147  		f := t.Field(i)
   148  		if f.PkgPath != "" {
   149  			continue // unexported
   150  		}
   151  
   152  		tag, ok := f.Tag.Lookup("protobuf")
   153  		if !ok {
   154  			continue // no tag
   155  		}
   156  
   157  		field := structField{
   158  			offset: f.Offset,
   159  		}
   160  
   161  		t, err := parseStructTag(tag)
   162  		if err != nil {
   163  			panic(err)
   164  		}
   165  		field.wiretag = uint64(t.fieldNumber)<<3 | uint64(t.wireType)
   166  		switch t.wireType {
   167  		case fixed32:
   168  			switch f.Type {
   169  			case optionFloat32Type:
   170  				field.codec = &float32OptionCodec
   171  			case optionUInt32Type:
   172  				field.codec = &fixed32OptionCodec
   173  			}
   174  			switch baseKindOf(f.Type) {
   175  			case reflect.Uint32:
   176  				field.codec = &fixed32Codec
   177  			case reflect.Float32:
   178  				field.codec = &float32Codec
   179  			}
   180  		case fixed64:
   181  			switch f.Type {
   182  			case optionUInt64Type:
   183  				field.codec = &fixed64OptionCodec
   184  			case optionFloat64Type:
   185  				field.codec = &float64OptionCodec
   186  			}
   187  			switch baseKindOf(f.Type) {
   188  			case reflect.Uint64:
   189  				field.codec = &fixed64Codec
   190  			case reflect.Float64:
   191  				field.codec = &float64Codec
   192  			}
   193  		}
   194  		if field.codec == nil {
   195  			switch f.Type {
   196  			case optionBoolType:
   197  				field.codec = &boolOptionCodec
   198  			case optionInt32Type:
   199  				field.codec = &int32OptionCodec
   200  				if t.zigzag {
   201  					field.codec = &zigzag32OptionCodec
   202  				}
   203  			case optionInt64Type:
   204  				field.codec = &int64OptionCodec
   205  				if t.zigzag {
   206  					field.codec = &zigzag64OptionCodec
   207  				}
   208  			case optionUInt32Type:
   209  				field.codec = &uint32OptionCodec
   210  			case optionUInt64Type:
   211  				field.codec = &uint64OptionCodec
   212  			case optionStringType:
   213  				field.codec = &stringOptionCodec
   214  			}
   215  		}
   216  		if field.codec == nil {
   217  			conf := &walkerConfig{
   218  				zigzag: t.zigzag,
   219  				// required: t.required,
   220  			}
   221  			switch baseKindOf(f.Type) {
   222  			case reflect.Struct:
   223  				field.codec = w.codec(f.Type, conf)
   224  
   225  			case reflect.Slice:
   226  				elem := f.Type.Elem()
   227  				if elem.Kind() == reflect.Uint8 { // []byte
   228  					field.codec = &bytesCodec
   229  				} else {
   230  					conf.required = true
   231  					field.codec = w.codec(elem, conf)
   232  					field.codec = sliceCodecOf(f.Type, field.codec, w)
   233  				}
   234  
   235  			case reflect.Map:
   236  				conf.required = true // map key and val should be encoded always
   237  				key, val := f.Type.Key(), f.Type.Elem()
   238  				m := &mapField{wiretag: field.wiretag}
   239  
   240  				t, _ := parseStructTag(f.Tag.Get("protobuf_key"))
   241  				keyField := &structField{wiretag: uint64(t.fieldNumber)<<3 | uint64(t.wireType)}
   242  				keyField.tagsize = sizeOfVarint(keyField.wiretag)
   243  				conf.zigzag = t.zigzag
   244  				keyField.codec = w.codec(key, conf)
   245  
   246  				t, _ = parseStructTag(f.Tag.Get("protobuf_val"))
   247  				valFiled := &structField{wiretag: uint64(t.fieldNumber)<<3 | uint64(t.wireType)}
   248  				valFiled.tagsize = sizeOfVarint(valFiled.wiretag)
   249  				conf.zigzag = t.zigzag
   250  				valFiled.codec = w.codec(val, conf)
   251  
   252  				m.keyField = keyField
   253  				m.valField = valFiled
   254  				field.codec = w.mapCodec(f.Type, m)
   255  
   256  			default:
   257  				field.codec = w.codec(f.Type, conf)
   258  			}
   259  		}
   260  		field.tagsize = sizeOfVarint(field.wiretag)
   261  		fields = append(fields, &field)
   262  	}
   263  
   264  	// copy to save capacity
   265  	fields2 := make([]*structField, len(fields))
   266  	copy(fields2, fields)
   267  	info.fields = fields2
   268  
   269  	info.fieldIndex = make(map[fieldNumber]*structField, len(info.fields))
   270  	for _, f := range info.fields {
   271  		info.fieldIndex[f.fieldNumber()] = f
   272  	}
   273  
   274  	structInfoCache.Store(pointer(t), info)
   275  	return info
   276  }
   277  
   278  // @@@ Pointers @@@
   279  
   280  func deref(p unsafe.Pointer) unsafe.Pointer {
   281  	return *(*unsafe.Pointer)(p)
   282  }
   283  
   284  func (w *walker) pointer(t reflect.Type, conf *walkerConfig) *codec {
   285  	switch t.Elem().Kind() {
   286  	case reflect.Struct:
   287  		return w.structCodec(t)
   288  	}
   289  	// common value
   290  	p := new(codec)
   291  	w.codecs[t] = p
   292  	c := w.codec(t.Elem(), conf)
   293  	p.size = pointerSizeFuncOf(t, c)
   294  	p.encode = pointerEncodeFuncOf(t, c)
   295  	p.decode = pointerDecodeFuncOf(t, c)
   296  	return p
   297  }
   298  
   299  func (w *walker) required(t reflect.Type, conf *walkerConfig) *codec {
   300  	if c, ok := w.codecs[t]; ok {
   301  		return c
   302  	}
   303  
   304  	switch t.Kind() {
   305  	case reflect.Bool:
   306  		return &boolRequiredCodec
   307  	case reflect.Int32:
   308  		if conf.zigzag {
   309  			return &zigzag32RequiredCodec
   310  		}
   311  		return &int32RequiredCodec
   312  	case reflect.Int64:
   313  		if conf.zigzag {
   314  			return &zigzag64RequiredCodec
   315  		}
   316  		return &int64RequiredCodec
   317  	case reflect.Uint32:
   318  		return &uint32RequiredCodec
   319  	case reflect.Uint64:
   320  		return &uint64RequiredCodec
   321  	case reflect.Float32:
   322  		return &float32RequiredCodec
   323  	case reflect.Float64:
   324  		return &float64RequiredCodec
   325  	case reflect.String:
   326  		return &stringRequiredCodec
   327  	case reflect.Slice:
   328  		elem := t.Elem()
   329  		switch elem.Kind() {
   330  		case reflect.Uint8:
   331  			return &bytesCodec
   332  		}
   333  	case reflect.Struct:
   334  		panic("nested message must be pointer:" + t.String())
   335  	case reflect.Ptr:
   336  		return w.pointer(t, conf)
   337  	}
   338  
   339  	panic("unsupported type: " + t.String())
   340  }
   341  
   342  func pointerSizeFuncOf(_ reflect.Type, c *codec) sizeFunc {
   343  	return func(p unsafe.Pointer, f *structField) int {
   344  		p = deref(p)
   345  		if p != nil {
   346  			return c.size(p, f)
   347  		}
   348  		return 0
   349  	}
   350  }
   351  
   352  func pointerEncodeFuncOf(_ reflect.Type, c *codec) encodeFunc {
   353  	return func(b []byte, p unsafe.Pointer, f *structField) []byte {
   354  		p = deref(p)
   355  		if p != nil {
   356  			return c.encode(b, p, f)
   357  		}
   358  		return b
   359  	}
   360  }
   361  
   362  func pointerDecodeFuncOf(t reflect.Type, c *codec) decodeFunc {
   363  	t = t.Elem()
   364  	return func(b []byte, p unsafe.Pointer) (int, error) {
   365  		v := (*unsafe.Pointer)(p)
   366  		if *v == nil {
   367  			*v = unsafe.Pointer(reflect.New(t).Pointer())
   368  		}
   369  		return c.decode(b, *v)
   370  	}
   371  }