github.com/sunvim/utils@v0.1.0/linear_ac/linear_ac.go (about)

     1  // reference to
     2  // https://github.com/crazybie/linear_ac
     3  
     4  package linear_ac
     5  
     6  import (
     7  	"fmt"
     8  	"math"
     9  	"reflect"
    10  	"runtime"
    11  	"sync"
    12  	"unsafe"
    13  )
    14  
var (
	// DbgMode enables pointer-escape checking and shrinks chunks.
	// NOTE(review): it is read by init(), so it only takes effect when
	// set before package initialization — effectively a build-time switch.
	DbgMode         = false
	// DisableLinearAc makes newly created allocators fall back to the
	// native Go heap.
	DisableLinearAc = false
	// ChunkSize is the default capacity, in bytes, of each memory chunk.
	ChunkSize       = 1024 * 32
)
    20  
// Chunk

// chunk is one contiguous byte buffer the allocator bumps into:
// len is the used portion, cap is the reserved portion.
type chunk []byte
    24  
// chunkPool recycles chunks across allocators so a reset does not
// re-allocate ChunkSize buffers.
var chunkPool = syncPool{
	New: func() interface{} {
		ck := make(chunk, 0, ChunkSize)
		return &ck
	},
}
    31  
// Objects in sync.Pool will be recycled on demand by the system (usually after two full-gc).
// We can put chunks here to make pointers live longer,
// which is useful for diagnosing use-after-free bugs (DbgMode only).
var diagnosisChunkPool = sync.Pool{}
    36  
func init() {
	// Smaller chunks in debug mode exercise chunk-boundary code paths
	// more often. NOTE(review): DbgMode defaults to false and this runs
	// at load time, so the branch only fires if DbgMode is changed at
	// compile time or by another init in this package — confirm intent.
	if DbgMode {
		ChunkSize /= 8
	}
}
    42  
// Allocator

// Allocator is a linear (bump-pointer) allocator backed by a list of
// recycled chunks. It is not safe for concurrent use.
type Allocator struct {
	disabled bool                        // when true, delegate to the native Go heap
	chunks   []*chunk                    // backing buffers; memory is bumped off chunks[curChunk]
	curChunk int                         // index of the chunk currently being bumped
	scanObjs []interface{}               // DbgMode only: allocated structs to scan in debugCheck
	maps     map[unsafe.Pointer]struct{} // maps created via NewMap (whitelist for debugCheck)
}
    52  
// buildInAc switches to native allocator.
// Get() returns it for goroutines that have no bound allocator.
var buildInAc = &Allocator{disabled: true}
    55  
// acPool recycles Allocator instances handed out by BindNew and
// returned by Release.
var acPool = syncPool{
	New: func() interface{} {
		return NewLinearAc()
	},
}
    61  
    62  func NewLinearAc() *Allocator {
    63  	ac := &Allocator{
    64  		disabled: DisableLinearAc,
    65  		maps:     map[unsafe.Pointer]struct{}{},
    66  	}
    67  	return ac
    68  }
    69  
// Bind allocator to goroutine

// acMap maps goroutine id -> *Allocator for BindNew / Get / Unbind.
var acMap = sync.Map{}
    73  
    74  func BindNew() *Allocator {
    75  	ac := acPool.get().(*Allocator)
    76  	acMap.Store(goRoutineId(), ac)
    77  	return ac
    78  }
    79  
    80  func Get() *Allocator {
    81  	if val, ok := acMap.Load(goRoutineId()); ok {
    82  		return val.(*Allocator)
    83  	}
    84  	return buildInAc
    85  }
    86  
    87  func (ac *Allocator) Unbind() {
    88  	if Get() == ac {
    89  		acMap.Delete(goRoutineId())
    90  	}
    91  }
    92  
    93  func (ac *Allocator) Release() {
    94  	if ac == buildInAc {
    95  		return
    96  	}
    97  	ac.Unbind()
    98  	ac.reset()
    99  	acPool.put(ac)
   100  }
   101  
// reset recycles all memory held by ac so the allocator can be reused.
// In DbgMode it first validates that no allocated struct holds an
// external pointer, then parks chunks in diagnosisChunkPool (kept alive
// until GC) instead of the regular chunk pool.
func (ac *Allocator) reset() {
	if ac.disabled {
		return
	}

	if DbgMode {
		// Check and invalidate outstanding pointers before the memory
		// is handed back.
		ac.debugCheck(true)
		ac.scanObjs = ac.scanObjs[:0]
	}

	for _, ck := range ac.chunks {
		*ck = (*ck)[:0]
		if DbgMode {
			diagnosisChunkPool.Put(ck)
		} else {
			chunkPool.put(ck)
		}
	}
	// clear all ref
	// Nil out the full capacity (not just the length) so recycled
	// *chunk pointers are not kept alive by this slice's backing array.
	ac.chunks = ac.chunks[:cap(ac.chunks)]
	for i := 0; i < cap(ac.chunks); i++ {
		ac.chunks[i] = nil
	}
	ac.chunks = ac.chunks[:0]
	ac.curChunk = 0

	for k := range ac.maps {
		delete(ac.maps, k)
	}

	// Re-sample the global switch so a Release/reuse cycle picks up
	// changes to DisableLinearAc.
	ac.disabled = DisableLinearAc
}
   134  
   135  func (ac *Allocator) New(ptrToPtr interface{}) {
   136  	tmp := noEscape(ptrToPtr)
   137  
   138  	if ac.disabled {
   139  		tp := reflect.TypeOf(tmp).Elem().Elem()
   140  		reflect.ValueOf(tmp).Elem().Set(reflect.New(tp))
   141  		return
   142  	}
   143  
   144  	tp := reflect.TypeOf(tmp).Elem()
   145  	v := ac.typedNew(tp, true)
   146  	reflect.ValueOf(tmp).Elem().Set(reflect.ValueOf(v))
   147  }
   148  
// NewCopy is useful for code migration.
// native mode is slower than new() due to the additional memory move from stack to heap,
// this is on purpose to avoid heap alloc in linear mode.
//
// ptr must be a pointer to the value to copy; the returned interface
// holds a pointer of the same type to a fresh copy of *ptr.
func (ac *Allocator) NewCopy(ptr interface{}) (ret interface{}) {
	ptrTemp := noEscape(ptr)
	ptrType := reflect.TypeOf(ptrTemp)
	tp := ptrType.Elem()

	if ac.disabled {
		ret = reflect.New(tp).Interface()
		// GC-aware copy of tp.Size() bytes into the heap object.
		reflect_typedmemmove(data(tp), data(ret), data(ptrTemp))
	} else {
		// Linear memory is not scanned per-object by the GC, so a raw
		// byte copy is sufficient here.
		ret = ac.typedNew(ptrType, false)
		copyBytes(data(ptrTemp), data(ret), int(tp.Size()))
	}
	return
}
   166  
// typedNew allocates one object of type ptrTp.Elem() from the linear
// buffer and returns it boxed as an interface whose dynamic type is
// ptrTp. zero controls whether the memory is cleared.
func (ac *Allocator) typedNew(ptrTp reflect.Type, zero bool) (ret interface{}) {
	objType := ptrTp.Elem()
	ptr := ac.alloc(int(objType.Size()), zero)
	// Hand-build the interface header: type word = ptrTp, data = ptr.
	*(*emptyInterface)(unsafe.Pointer(&ret)) = emptyInterface{data(ptrTp), ptr}
	if DbgMode {
		if objType.Kind() == reflect.Struct {
			// Remember structs so debugCheck can scan their fields.
			ac.scanObjs = append(ac.scanObjs, ret)
		}
	}
	return
}
   178  
// alloc carves `need` bytes out of the current chunk, moving to (or
// creating) a following chunk when the current one cannot fit the
// request. The returned memory is cleared only when zero is true.
// NOTE(review): the result is not explicitly aligned; this presumably
// relies on chunk data and prior allocation sizes keeping adequate
// alignment — confirm for types with alignment > 1.
func (ac *Allocator) alloc(need int, zero bool) unsafe.Pointer {
	if len(ac.chunks) == 0 {
		ac.chunks = append(ac.chunks, chunkPool.get().(*chunk))
	}
start:
	cur := ac.chunks[ac.curChunk]
	used := len(*cur)
	if used+need > cap(*cur) {
		if ac.curChunk == len(ac.chunks)-1 {
			// No next chunk yet: take one from the pool, or make a
			// one-off oversized chunk for very large requests.
			var ck *chunk
			if need > ChunkSize {
				c := make(chunk, 0, need)
				ck = &c
			} else {
				ck = chunkPool.get().(*chunk)
			}
			ac.chunks = append(ac.chunks, ck)
		} else if cap(*ac.chunks[ac.curChunk+1]) < need {
			// The next recycled chunk is too small for this request:
			// return it to the pool and replace it with an exact-size one.
			chunkPool.put(ac.chunks[ac.curChunk+1])
			ck := make(chunk, 0, need)
			ac.chunks[ac.curChunk+1] = &ck
		}
		ac.curChunk++
		goto start
	}
	// Bump: extend the used region and return a pointer to its start.
	*cur = (*cur)[:used+need]
	ptr := add((*sliceHeader)(unsafe.Pointer(cur)).Data, used)
	if zero {
		clearBytes(ptr, need)
	}
	return ptr
}
   211  
// NewString copies the bytes of v into the linear buffer and returns a
// string whose data pointer refers to that copy. In disabled mode the
// input is returned unchanged.
func (ac *Allocator) NewString(v string) string {
	if ac.disabled {
		return v
	}
	h := (*stringHeader)(unsafe.Pointer(&v))
	ptr := ac.alloc(h.Len, false)
	copyBytes(h.Data, ptr, h.Len)
	// Repoint the local header at the linear copy; returning v then
	// yields the rewritten string.
	h.Data = ptr
	return v
}
   222  
   223  // NewMap use build-in allocator
   224  func (ac *Allocator) NewMap(mapPtr interface{}) {
   225  	mapPtrTemp := noEscape(mapPtr)
   226  
   227  	if ac.disabled {
   228  		tp := reflect.TypeOf(mapPtrTemp).Elem()
   229  		reflect.ValueOf(mapPtrTemp).Elem().Set(reflect.MakeMap(tp))
   230  		return
   231  	}
   232  
   233  	m := reflect.MakeMap(reflect.TypeOf(mapPtrTemp).Elem())
   234  	reflect.ValueOf(mapPtrTemp).Elem().Set(m)
   235  	ac.maps[data(m.Interface())] = struct{}{}
   236  }
   237  
   238  func (ac *Allocator) NewSlice(slicePtr interface{}, len, cap int) {
   239  	slicePtrTmp := noEscape(slicePtr)
   240  
   241  	if ac.disabled {
   242  		v := reflect.MakeSlice(reflect.TypeOf(slicePtrTmp).Elem(), len, cap)
   243  		reflect.ValueOf(slicePtrTmp).Elem().Set(v)
   244  		return
   245  	}
   246  
   247  	slicePtrType := reflect.TypeOf(slicePtrTmp)
   248  	if slicePtrType.Kind() != reflect.Ptr || slicePtrType.Elem().Kind() != reflect.Slice {
   249  		panic("need a pointer to slice")
   250  	}
   251  
   252  	slice := (*sliceHeader)(data(slicePtrTmp))
   253  	if cap < len {
   254  		cap = len
   255  	}
   256  	slice.Data = ac.alloc(cap*int(slicePtrType.Elem().Elem().Size()), false)
   257  	slice.Len = len
   258  	slice.Cap = cap
   259  }
   260  
   261  // CopySlice is useful to create simple slice (simple type as element)
   262  func (ac *Allocator) CopySlice(slice interface{}) (ret interface{}) {
   263  	sliceTmp := noEscape(slice)
   264  	if ac.disabled {
   265  		return sliceTmp
   266  	}
   267  
   268  	sliceType := reflect.TypeOf(sliceTmp)
   269  	if sliceType.Kind() != reflect.Slice {
   270  		panic("need a slice")
   271  	}
   272  	elemType := sliceType.Elem()
   273  	switch elemType.Kind() {
   274  	case reflect.Int, reflect.Int32, reflect.Int64,
   275  		reflect.Uint, reflect.Uint32, reflect.Uint64,
   276  		reflect.Float32, reflect.Float64:
   277  	default:
   278  		panic("must be simple type")
   279  	}
   280  
   281  	// input is a temp copy, directly use it.
   282  	ret = sliceTmp
   283  	header := (*sliceHeader)(data(sliceTmp))
   284  	size := int(elemType.Size()) * header.Len
   285  	dst := ac.alloc(size, false)
   286  	copyBytes(header.Data, dst, size)
   287  	header.Data = dst
   288  
   289  	runtime.KeepAlive(slice)
   290  	return ret
   291  }
   292  
// SliceAppend appends elem to the slice pointed to by slicePtr, growing
// the backing array from the linear buffer when it is full. elem's type
// must match the slice's element type (nil is accepted for pointer
// elements).
func (ac *Allocator) SliceAppend(slicePtr interface{}, elem interface{}) {
	slicePtrTmp := noEscape(slicePtr)

	if ac.disabled {
		s := reflect.ValueOf(slicePtrTmp).Elem()
		v := reflect.Append(s, reflect.ValueOf(elem))
		s.Set(v)
		return
	}

	slicePtrTp := reflect.TypeOf(slicePtrTmp)
	if slicePtrTp.Kind() != reflect.Ptr || slicePtrTp.Elem().Kind() != reflect.Slice {
		panic("expect pointer to slice")
	}
	inputElemTp := reflect.TypeOf(elem)
	sliceElemTp := slicePtrTp.Elem().Elem()
	if sliceElemTp != inputElemTp && elem != nil {
		panic("elem type not match with slice")
	}

	header := (*sliceHeader)(data(slicePtrTmp))
	elemSz := int(sliceElemTp.Size())

	// grow: double below 16 elements, then 1.5x. The old backing array
	// is simply abandoned inside the linear buffer (reclaimed at reset).
	if header.Len >= header.Cap {
		pre := *header
		if header.Cap >= 16 {
			header.Cap = int(float32(header.Cap) * 1.5)
		} else {
			header.Cap *= 2
		}
		if header.Cap == 0 {
			header.Cap = 1
		}
		header.Data = ac.alloc(header.Cap*elemSz, false)
		copyBytes(pre.Data, header.Data, pre.Len*elemSz)
	}

	// append
	if header.Len < header.Cap {
		elemData := data(elem)
		dst := add(header.Data, elemSz*header.Len)
		if sliceElemTp.Kind() == reflect.Ptr {
			// Pointer elements: store the pointer value itself.
			// NOTE(review): for nil elem, data(elem) is presumably nil
			// here — confirm against the data() helper.
			*(*unsafe.Pointer)(dst) = elemData
		} else {
			copyBytes(elemData, dst, elemSz)
		}
		header.Len++
	}
}
   343  
   344  func (ac *Allocator) Enum(e interface{}) interface{} {
   345  	temp := noEscape(e)
   346  	if ac.disabled {
   347  		r := reflect.New(reflect.TypeOf(temp))
   348  		r.Elem().Set(reflect.ValueOf(temp))
   349  		return r.Interface()
   350  	}
   351  	tp := reflect.TypeOf(temp)
   352  	r := ac.typedNew(reflect.PtrTo(tp), false)
   353  	copyBytes(data(temp), data(r), int(tp.Size()))
   354  	return r
   355  }
   356  
// trickyAddress is stamped over invalidated pointers in DbgMode.
// Use 1 instead of nil or MaxUint64 to
// 1. make non-nil check pass.
// 2. generate a recoverable panic.
const trickyAddress = uintptr(1)
   361  
   362  func (ac *Allocator) internalPointer(addr uintptr) bool {
   363  	if addr == 0 || addr == trickyAddress {
   364  		return true
   365  	}
   366  	for _, c := range ac.chunks {
   367  		h := (*sliceHeader)(unsafe.Pointer(c))
   368  		if addr >= uintptr(h.Data) && addr < uintptr(h.Data)+uintptr(h.Cap) {
   369  			return true
   370  		}
   371  	}
   372  	return false
   373  }
   374  
   375  // NOTE: all memories must be referenced by structs.
   376  func (ac *Allocator) debugCheck(invalidatePointers bool) {
   377  	checked := map[interface{}]struct{}{}
   378  	// reverse order to bypass obfuscated pointers
   379  	for i := len(ac.scanObjs) - 1; i >= 0; i-- {
   380  		ptr := ac.scanObjs[i]
   381  		if _, ok := checked[ptr]; ok {
   382  			continue
   383  		}
   384  		if err := ac.checkRecursively(reflect.ValueOf(ptr), checked, invalidatePointers); err != nil {
   385  			panic(err)
   386  		}
   387  	}
   388  }
   389  
// CheckExternalPointers is useful if you want to check for external
// pointers without invalidating them,
// e.g. when using ac as a global config-memory allocator.
func (ac *Allocator) CheckExternalPointers() {
	ac.debugCheck(false)
}
   395  
// checkRecursively verifies that every pointer reachable from val stays
// inside ac's chunks (or in ac.maps for maps), returning an error that
// names the offending field. When invalidatePointers is true it also
// stamps pointer fields with trickyAddress and poisons slice headers so
// use-after-reset fails loudly.
func (ac *Allocator) checkRecursively(val reflect.Value, checked map[interface{}]struct{}, invalidatePointers bool) error {
	if val.Kind() == reflect.Ptr {
		// Skip already-invalidated (trickyAddress) and nil pointers.
		if val.Pointer() != trickyAddress && !val.IsNil() {
			if !ac.internalPointer(val.Pointer()) {
				return fmt.Errorf("unexpected external pointer: %+v", val)
			}
			// Only struct pointees have fields worth recursing into.
			if val.Elem().Type().Kind() == reflect.Struct {
				if err := ac.checkRecursively(val.Elem(), checked, invalidatePointers); err != nil {
					return err
				}
				checked[val.Interface()] = struct{}{}
			}
		}
		return nil
	}

	tp := val.Type()
	fieldName := func(i int) string {
		return fmt.Sprintf("%v.%v", tp.Name(), tp.Field(i).Name)
	}

	if val.Kind() == reflect.Struct {
		for i := 0; i < val.NumField(); i++ {
			f := val.Field(i)

			switch f.Kind() {
			case reflect.Ptr:
				if err := ac.checkRecursively(f, checked, invalidatePointers); err != nil {
					return fmt.Errorf("%v: %v", fieldName(i), err)
				}
				if invalidatePointers {
					// Stamp the field with the recoverable-panic sentinel.
					*(*uintptr)(unsafe.Pointer(f.UnsafeAddr())) = trickyAddress
				}

			case reflect.Slice:
				h := (*sliceHeader)(unsafe.Pointer(f.UnsafeAddr()))
				if f.Len() > 0 && h.Data != nil {
					if !ac.internalPointer((uintptr)(h.Data)) {
						return fmt.Errorf("%s: unexpected external slice: %s", fieldName(i), f.String())
					}
					for j := 0; j < f.Len(); j++ {
						if err := ac.checkRecursively(f.Index(j), checked, invalidatePointers); err != nil {
							return fmt.Errorf("%v: %v", fieldName(i), err)
						}
					}
				}
				if invalidatePointers {
					// Poison the header: nil data with huge len/cap makes
					// any later access panic instead of silently reading
					// recycled memory.
					h.Data = nil
					h.Len = math.MaxInt32
					h.Cap = math.MaxInt32
				}

			case reflect.Array:
				for j := 0; j < f.Len(); j++ {
					if err := ac.checkRecursively(f.Index(j), checked, invalidatePointers); err != nil {
						return fmt.Errorf("%v: %v", fieldName(i), err)
					}
				}

			case reflect.Map:
				// Maps must have been created via NewMap (whitelisted).
				m := *(*unsafe.Pointer)(unsafe.Pointer(f.UnsafeAddr()))
				if _, ok := ac.maps[m]; !ok {
					return fmt.Errorf("%v: unexpected external map: %+v", fieldName(i), f)
				}
				for iter := f.MapRange(); iter.Next(); {
					if err := ac.checkRecursively(iter.Value(), checked, invalidatePointers); err != nil {
						return fmt.Errorf("%v: %v", fieldName(i), err)
					}
				}
			}
		}
	}
	return nil
}
   470  
   471  func (ac *Allocator) Bool(v bool) (r *bool) {
   472  	if ac.disabled {
   473  		r = new(bool)
   474  	} else {
   475  		r = ac.typedNew(boolPtrType, false).(*bool)
   476  	}
   477  	*r = v
   478  	return
   479  }
   480  
   481  func (ac *Allocator) Int(v int) (r *int) {
   482  	if ac.disabled {
   483  		r = new(int)
   484  	} else {
   485  		r = ac.typedNew(intPtrType, false).(*int)
   486  	}
   487  	*r = v
   488  	return
   489  }
   490  
   491  func (ac *Allocator) Int32(v int32) (r *int32) {
   492  	if ac.disabled {
   493  		r = new(int32)
   494  	} else {
   495  		r = ac.typedNew(i32PtrType, false).(*int32)
   496  	}
   497  	*r = v
   498  	return
   499  }
   500  
   501  func (ac *Allocator) Uint32(v uint32) (r *uint32) {
   502  	if ac.disabled {
   503  		r = new(uint32)
   504  	} else {
   505  		r = ac.typedNew(u32PtrType, false).(*uint32)
   506  	}
   507  	*r = v
   508  	return
   509  }
   510  
   511  func (ac *Allocator) Int64(v int64) (r *int64) {
   512  	if ac.disabled {
   513  		r = new(int64)
   514  	} else {
   515  		r = ac.typedNew(i64PtrType, false).(*int64)
   516  	}
   517  	*r = v
   518  	return
   519  }
   520  
   521  func (ac *Allocator) Uint64(v uint64) (r *uint64) {
   522  	if ac.disabled {
   523  		r = new(uint64)
   524  	} else {
   525  		r = ac.typedNew(u64PtrType, false).(*uint64)
   526  	}
   527  	*r = v
   528  	return
   529  }
   530  
   531  func (ac *Allocator) Float32(v float32) (r *float32) {
   532  	if ac.disabled {
   533  		r = new(float32)
   534  	} else {
   535  		r = ac.typedNew(f32PtrType, false).(*float32)
   536  	}
   537  	*r = v
   538  	return
   539  }
   540  
   541  func (ac *Allocator) Float64(v float64) (r *float64) {
   542  	if ac.disabled {
   543  		r = new(float64)
   544  	} else {
   545  		r = ac.typedNew(f64PtrType, false).(*float64)
   546  	}
   547  	*r = v
   548  	return
   549  }
   550  
   551  func (ac *Allocator) String(v string) (r *string) {
   552  	if ac.disabled {
   553  		r = new(string)
   554  		*r = v
   555  	} else {
   556  		r = ac.typedNew(strPtrType, false).(*string)
   557  		*r = ac.NewString(v)
   558  	}
   559  	return
   560  }