github.com/cloudwego/dynamicgo@v0.2.6-0.20240519101509-707f41b6b834/thrift/utils.go (about)

     1  /**
     2   * Copyright 2023 CloudWeGo Authors.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package thrift
    18  
    19  import (
    20  	"fmt"
    21  	"runtime"
    22  	"sync"
    23  	"unsafe"
    24  
    25  	"github.com/cloudwego/dynamicgo/internal/caching"
    26  	"github.com/cloudwego/dynamicgo/internal/rt"
    27  	"github.com/cloudwego/dynamicgo/meta"
    28  )
    29  
const (
	// defaultMaxBucketSize bounds the average number of keys that share the same
	// byte at one key position: FieldNameMap.Build only builds a trie when some
	// position's average is strictly below this value.
	defaultMaxBucketSize     float64 = 10
	// defaultMapSize is not referenced in this part of the file.
	// NOTE(review): presumably a default initial map capacity — confirm at its call sites.
	defaultMapSize           int     = 4
	// defaultHashMapLoadFactor is the load factor passed to caching.NewHashMap
	// when FieldNameMap.Build falls back to a hash map.
	defaultHashMapLoadFactor int     = 4
	// defaultMaxFieldID sizes the initial capacity of pooled RequiresBitmaps
	// (defaultMaxFieldID/64+1 words); larger ids still work because the bitmap grows.
	defaultMaxFieldID                = 256
	// defaultMaxNestedDepth is not referenced in this part of the file.
	// NOTE(review): presumably a nesting/recursion limit — confirm at its call sites.
	defaultMaxNestedDepth            = 1024
)
    37  
// FieldNameMap maps field names to field descriptors.
//
// Keys are accumulated with Set. Build must then be called to construct the
// lookup structure (a trie or a hash map) that Get and Size consult.
type FieldNameMap struct {
	// maxKeyLength is the length of the longest key passed to Set.
	maxKeyLength int
	// all holds one key/descriptor pair per distinct key passed to Set, in
	// first-insertion order.
	all          []caching.Pair
	// trie is the lookup structure built by Build when some key position
	// disperses the keys well enough; nil otherwise.
	trie         *caching.TrieTree
	// hash is the fallback lookup structure built by Build when no trie is
	// built; nil otherwise.
	hash         *caching.HashMap
}
    45  
    46  // Set sets the field descriptor for the given key
    47  func (ft *FieldNameMap) Set(key string, field *FieldDescriptor) (exist bool) {
    48  	if len(key) > ft.maxKeyLength {
    49  		ft.maxKeyLength = len(key)
    50  	}
    51  	for i, v := range ft.all {
    52  		if v.Key == key {
    53  			exist = true
    54  			ft.all[i].Val = unsafe.Pointer(field)
    55  			return
    56  		}
    57  	}
    58  	ft.all = append(ft.all, caching.Pair{Val: unsafe.Pointer(field), Key: key})
    59  	return
    60  }
    61  
    62  // Get gets the field descriptor for the given key
    63  func (ft FieldNameMap) Get(k string) *FieldDescriptor {
    64  	if ft.trie != nil {
    65  		return (*FieldDescriptor)(ft.trie.Get(k))
    66  	} else if ft.hash != nil {
    67  		return (*FieldDescriptor)(ft.hash.Get(k))
    68  	}
    69  	return nil
    70  }
    71  
    72  // All returns all field descriptors
    73  func (ft FieldNameMap) All() []*FieldDescriptor {
    74  	return *(*[]*FieldDescriptor)(unsafe.Pointer(&ft.all))
    75  }
    76  
    77  // Size returns the size of the map
    78  func (ft FieldNameMap) Size() int {
    79  	if ft.hash != nil {
    80  		return ft.hash.Size()
    81  	} else {
    82  		return ft.trie.Size()
    83  	}
    84  }
    85  
    86  // Build builds the map.
    87  // It will try to build a trie tree if the dispersion of keys is higher enough (min).
    88  func (ft *FieldNameMap) Build() {
    89  	var empty unsafe.Pointer
    90  
    91  	// statistics the distrubution for each position:
    92  	//   - primary slice store the position as its index
    93  	//   - secondary map used to merge values with same char at the same position
    94  	var positionDispersion = make([]map[byte][]int, ft.maxKeyLength)
    95  
    96  	for i, v := range ft.all {
    97  		for j := ft.maxKeyLength - 1; j >= 0; j-- {
    98  			if v.Key == "" {
    99  				// empty key, especially store
   100  				empty = v.Val
   101  			}
   102  			// get the char at the position, defualt (position beyonds key range) is ASCII 0
   103  			var c = byte(0)
   104  			if j < len(v.Key) {
   105  				c = v.Key[j]
   106  			}
   107  
   108  			if positionDispersion[j] == nil {
   109  				positionDispersion[j] = make(map[byte][]int, 16)
   110  			}
   111  			// recoder the index i of the value with same char c at the same position j
   112  			positionDispersion[j][c] = append(positionDispersion[j][c], i)
   113  		}
   114  	}
   115  
   116  	// calculate the best position which has the highest dispersion
   117  	var idealPos = -1
   118  	var min = defaultMaxBucketSize
   119  	var count = len(ft.all)
   120  
   121  	for i := ft.maxKeyLength - 1; i >= 0; i-- {
   122  		cd := positionDispersion[i]
   123  		l := len(cd)
   124  		// calculate the dispersion (average bucket size)
   125  		f := float64(count) / float64(l)
   126  		if f < min {
   127  			min = f
   128  			idealPos = i
   129  		}
   130  		// 1 means all the value store in different bucket, no need to continue calulating
   131  		if min == 1 {
   132  			break
   133  		}
   134  	}
   135  
   136  	if idealPos != -1 {
   137  		// find the best position, build a trie tree
   138  		ft.hash = nil
   139  		ft.trie = &caching.TrieTree{}
   140  		// NOTICE: we only use a two-layer tree here, for better performance
   141  		ft.trie.Positions = append(ft.trie.Positions, idealPos)
   142  		// set all key-values to the trie tree
   143  		for _, v := range ft.all {
   144  			ft.trie.Set(v.Key, v.Val)
   145  		}
   146  		if empty != nil {
   147  			ft.trie.Empty = empty
   148  		}
   149  
   150  	} else {
   151  		// no ideal position, build a hash map
   152  		ft.trie = nil
   153  		ft.hash = caching.NewHashMap(len(ft.all), defaultHashMapLoadFactor)
   154  		// set all key-values to the trie tree
   155  		for _, v := range ft.all {
   156  			// caching.HashMap does not support duplicate key, so must check if the key exists before set
   157  			// WARN: if the key exists, the value WON'T be replaced
   158  			o := ft.hash.Get(v.Key)
   159  			if o == nil {
   160  				ft.hash.Set(v.Key, v.Val)
   161  			}
   162  		}
   163  		if empty != nil {
   164  			ft.hash.Set("", empty)
   165  		}
   166  	}
   167  }
   168  
   169  // FieldIDMap is a map from field id to field descriptor
   170  type FieldIDMap struct {
   171  	m   []*FieldDescriptor
   172  	all []*FieldDescriptor
   173  }
   174  
   175  // All returns all field descriptors
   176  func (fd FieldIDMap) All() (ret []*FieldDescriptor) {
   177  	return fd.all
   178  }
   179  
   180  // Size returns the size of the map
   181  func (fd FieldIDMap) Size() int {
   182  	return len(fd.m)
   183  }
   184  
   185  // Get gets the field descriptor for the given id
   186  func (fd FieldIDMap) Get(id FieldID) *FieldDescriptor {
   187  	if int(id) >= len(fd.m) {
   188  		return nil
   189  	}
   190  	return fd.m[id]
   191  }
   192  
   193  // Set sets the field descriptor for the given id
   194  func (fd *FieldIDMap) Set(id FieldID, f *FieldDescriptor) {
   195  	if int(id) >= len(fd.m) {
   196  		len := int(id) + 1
   197  		tmp := make([]*FieldDescriptor, len)
   198  		copy(tmp, fd.m)
   199  		fd.m = tmp
   200  	}
   201  	o := (fd.m)[id]
   202  	if o == nil {
   203  		fd.all = append(fd.all, f)
   204  	} else {
   205  		for i, v := range fd.all {
   206  			if v == o {
   207  				fd.all[i] = f
   208  				break
   209  			}
   210  		}
   211  	}
   212  	fd.m[id] = f
   213  }
   214  
// RequiresBitmap is a bitmap to mark fields by id: bit id%64 of word id/64
// corresponds to the field with that id (see Set and IsSet).
type RequiresBitmap []uint64
   217  
const (
	// int64BitSize is the number of bits in one RequiresBitmap word.
	int64BitSize  = 64
	// int64ByteSize is the size in bytes of one RequiresBitmap word; it is used
	// as the stride when addressing words through unsafe pointers.
	int64ByteSize = 8
)
   222  
   223  var bitmapPool = sync.Pool{
   224  	New: func() interface{} {
   225  		ret := RequiresBitmap(make([]uint64, 0, defaultMaxFieldID/int64BitSize+1))
   226  		return &ret
   227  	},
   228  }
   229  
   230  // Set mark the bit corresponding the given id, with the given requireness
   231  //   - RequiredRequireness|DefaultRequireness mark the bit as 1
   232  //   - OptionalRequireness mark the bit as 0
   233  func (b *RequiresBitmap) Set(id FieldID, val Requireness) {
   234  	i := int(id) / int64BitSize
   235  	if len(*b) <= i {
   236  		b.malloc(int32(id))
   237  	}
   238  	p := unsafe.Pointer(uintptr((*rt.GoSlice)(unsafe.Pointer(b)).Ptr) + uintptr(i)*int64ByteSize)
   239  	switch val {
   240  	case RequiredRequireness, DefaultRequireness:
   241  		*(*uint64)(p) |= (0b1 << (id % int64BitSize))
   242  	case OptionalRequireness:
   243  		*(*uint64)(p) &= ^(0b1 << (id % int64BitSize))
   244  	default:
   245  		panic("invalid requireness")
   246  	}
   247  }
   248  
   249  // IsSet tells if the bit corresponding the given id is marked
   250  func (b RequiresBitmap) IsSet(id FieldID) bool {
   251  	i := int(id) / int64BitSize
   252  	if i >= len(b) {
   253  		panic("bitmap id out of range")
   254  	}
   255  	return (b[i] & (0b1 << (id % int64BitSize))) != 0
   256  }
   257  
   258  func (b *RequiresBitmap) malloc(id int32) {
   259  	if n := int32(id / int64BitSize); int(n) >= len(*b) {
   260  		buf := make([]uint64, n+1, int32((n+1)*2))
   261  		copy(buf, *b)
   262  		*b = buf
   263  	}
   264  }
   265  
   266  // CopyTo copy the bitmap to a given bitmap
   267  func (b RequiresBitmap) CopyTo(to *RequiresBitmap) {
   268  	c := cap(*to)
   269  	l := len(b)
   270  	if l > c {
   271  		*to = make([]uint64, l)
   272  	}
   273  	*to = (*to)[:l]
   274  	copy(*to, b)
   275  }
   276  
   277  // NewRequiresBitmap get bitmap from pool, if pool is empty, create a new one
   278  // WARN: memory from pool maybe dirty!
   279  func NewRequiresBitmap() *RequiresBitmap {
   280  	return bitmapPool.Get().(*RequiresBitmap)
   281  }
   282  
   283  // FreeRequiresBitmap free the bitmap, but not clear its memory
   284  func FreeRequiresBitmap(b *RequiresBitmap) {
   285  	// memclrNoHeapPointers(*(*unsafe.Pointer)(unsafe.Pointer(b)), uintptr(len(*b))*uint64TypeSize)
   286  	*b = (*b)[:0]
   287  	bitmapPool.Put(b)
   288  }
   289  
   290  //go:nocheckptr
   291  // CheckRequires scan every bit of the bitmap. When a bit is marked, it will:
   292  //   - if the corresponding field is required-requireness, it reports error
   293  //   - if the corresponding is not required-requireness but writeDefault is true, it will call handler to handle this field
   294  func (b RequiresBitmap) CheckRequires(desc *StructDescriptor, writeDefault bool, handler func(field *FieldDescriptor) error) error {
   295  	// handle bitmap first
   296  	n := len(b)
   297  	s := (*rt.GoSlice)(unsafe.Pointer(&b)).Ptr
   298  
   299  	// test 64 bits once
   300  	for i := 0; i < n; i++ {
   301  		v := *(*uint64)(s)
   302  		for j := 0; v != 0 && j < int64BitSize; j++ {
   303  			if v%2 == 1 {
   304  				id := FieldID(i*int64BitSize + j)
   305  				f := desc.FieldById(id)
   306  				if f == nil {
   307  					return errInvalidBitmapId(id, desc)
   308  				}
   309  				if f.Required() == RequiredRequireness {
   310  					return errMissRequiredField(f, desc)
   311  				} else if !writeDefault {
   312  					v >>= 1
   313  					continue
   314  				}
   315  				if err := handler(f); err != nil {
   316  					return err
   317  				}
   318  			}
   319  			v >>= 1
   320  		}
   321  		s = unsafe.Pointer(uintptr(s) + int64ByteSize)
   322  	}
   323  	runtime.KeepAlive(s)
   324  	return nil
   325  }
   326  
   327  //go:nocheckptr
   328  // CheckRequires scan every bit of the bitmap. When a bit is marked, it will:
   329  //   - if the corresponding field is required-requireness and writeRquired is true, it will call handler to handle this field, otherwise report error
   330  //   - if the corresponding is default-requireness and writeDefault is true, it will call handler to handle this field
   331  //   - if the corresponding is optional-requireness and writeOptional is true, it will call handler to handle this field
   332  func (b RequiresBitmap) HandleRequires(desc *StructDescriptor, writeRquired bool, writeDefault bool, writeOptional bool, handler func(field *FieldDescriptor) error) error {
   333  	// handle bitmap first
   334  	n := len(b)
   335  	s := (*rt.GoSlice)(unsafe.Pointer(&b)).Ptr
   336  	// test 64 bits once
   337  	for i := 0; i < n; i++ {
   338  		v := *(*uint64)(s)
   339  		for j := 0; v != 0 && j < int64BitSize; j++ {
   340  			if v%2 == 1 {
   341  				f := desc.FieldById(FieldID(i*int64BitSize + j))
   342  				if f.Required() == RequiredRequireness && !writeRquired {
   343  					return errMissRequiredField(f, desc)
   344  				}
   345  				if (f.Required() == DefaultRequireness && !writeDefault) || (f.Required() == OptionalRequireness && !writeOptional) {
   346  					v >>= 1
   347  					continue
   348  				}
   349  				if err := handler(f); err != nil {
   350  					return err
   351  				}
   352  			}
   353  			v >>= 1
   354  		}
   355  		s = unsafe.Pointer(uintptr(s) + int64ByteSize)
   356  	}
   357  	runtime.KeepAlive(s)
   358  	return nil
   359  }
   360  
   361  func errMissRequiredField(field *FieldDescriptor, st *StructDescriptor) error {
   362  	return meta.NewError(meta.ErrMissRequiredField, fmt.Sprintf("miss required field '%s' of struct '%s'", field.Name(), st.Name()), nil)
   363  }
   364  
   365  func errInvalidBitmapId(id FieldID, st *StructDescriptor) error {
   366  	return meta.NewError(meta.ErrInvalidParam, fmt.Sprintf("invalid field id %d of struct '%s'", id, st.Name()), nil)
   367  }
   368  
   369  // DefaultValue is the default value of a field
   370  type DefaultValue struct {
   371  	goValue      interface{}
   372  	jsonValue    string
   373  	thriftBinary string
   374  }
   375  
   376  // GoValue return the go runtime representation of the default value
   377  func (d DefaultValue) GoValue() interface{} {
   378  	return d.goValue
   379  }
   380  
   381  // JSONValue return the json-encoded representation of the default value
   382  func (d DefaultValue) JSONValue() string {
   383  	return d.jsonValue
   384  }
   385  
   386  // ThriftBinary return the thrift-binary-encoded representation of the default value
   387  func (d DefaultValue) ThriftBinary() string {
   388  	return d.thriftBinary
   389  }