github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/internal/zero/zero.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package zero provides facilities for efficiently zeroing Go values.
     6  package zero
     7  
     8  import (
     9  	"reflect"
    10  	"sync"
    11  	"unsafe"
    12  )
    13  
    14  var cache sync.Map // map[reflect.Type]func(ptr uintptr, n int)
    15  
    16  // Slice zeroes the elements 0 <= i < v.Len() of the provided slice.
    17  // Slice panics if the value is not a slice. f
    18  func Slice(v interface{}) {
    19  	SliceValue(reflect.ValueOf(v))
    20  }
    21  
    22  // SliceValue zeroes the elements 0 <= i < v.Len() of the provided slice
    23  // value. Slice panics if the value is not a slice. f
    24  func SliceValue(v reflect.Value) {
    25  	if v.Kind() != reflect.Slice {
    26  		panic("zero.Slice: called on non-slice value")
    27  	}
    28  	Unsafe(v.Type().Elem(), unsafe.Pointer(v.Pointer()), v.Len())
    29  }
    30  
    31  // Unsafe zeroes n elements starting at the address ptr. Elements
    32  // must be of type t.
    33  func Unsafe(t reflect.Type, ptr unsafe.Pointer, n int) {
    34  	zi, ok := cache.Load(t)
    35  	if !ok {
    36  		zi, _ = cache.LoadOrStore(t, slice(t))
    37  	}
    38  	z := zi.(func(ptr unsafe.Pointer, n int))
    39  	z(ptr, n)
    40  }
    41  
    42  func slice(elem reflect.Type) func(ptr unsafe.Pointer, n int) {
    43  	switch kind := elem.Kind(); {
    44  	case isValueType(elem):
    45  		return sliceValue(elem)
    46  	case kind == reflect.String:
    47  		return func(ptr unsafe.Pointer, n int) {
    48  			var strs []string
    49  			strsHdr := (*reflect.SliceHeader)(unsafe.Pointer(&strs))
    50  			strsHdr.Data = uintptr(ptr)
    51  			strsHdr.Len = n
    52  			strsHdr.Cap = n
    53  			for i := range strs {
    54  				strs[i] = ""
    55  			}
    56  		}
    57  	case kind == reflect.Slice:
    58  		return func(ptr unsafe.Pointer, n int) {
    59  			var slices []reflect.SliceHeader
    60  			slicesHdr := (*reflect.SliceHeader)(unsafe.Pointer(&slices))
    61  			slicesHdr.Data = uintptr(ptr)
    62  			slicesHdr.Len = n
    63  			slicesHdr.Cap = n
    64  			for i := range slices {
    65  				slices[i].Data = uintptr(0)
    66  				slices[i].Len = 0
    67  				slices[i].Cap = 0
    68  			}
    69  		}
    70  	case kind == reflect.Ptr:
    71  		return func(ptr unsafe.Pointer, n int) {
    72  			var ps []unsafe.Pointer
    73  			psHdr := (*reflect.SliceHeader)(unsafe.Pointer(&ps))
    74  			psHdr.Data = uintptr(ptr)
    75  			psHdr.Len = n
    76  			psHdr.Cap = n
    77  			for i := range ps {
    78  				ps[i] = nil
    79  			}
    80  		}
    81  	default:
    82  		// Slow case: use reflection API.
    83  		zero := reflect.Zero(elem)
    84  		return func(ptr unsafe.Pointer, n int) {
    85  			v := reflect.NewAt(reflect.ArrayOf(n, elem), ptr).Elem()
    86  			for i := 0; i < n; i++ {
    87  				v.Index(i).Set(zero)
    88  			}
    89  		}
    90  	}
    91  }
    92  
    93  func sliceValue(elem reflect.Type) func(ptr unsafe.Pointer, n int) {
    94  	switch size := elem.Size(); size {
    95  	case 8:
    96  		return func(ptr unsafe.Pointer, n int) {
    97  			var vs []int64
    98  			vsHdr := (*reflect.SliceHeader)(unsafe.Pointer(&vs))
    99  			vsHdr.Data = uintptr(ptr)
   100  			vsHdr.Len = n
   101  			vsHdr.Cap = n
   102  			for i := range vs {
   103  				vs[i] = 0
   104  			}
   105  		}
   106  	case 4:
   107  		return func(ptr unsafe.Pointer, n int) {
   108  			var vs []int32
   109  			vsHdr := (*reflect.SliceHeader)(unsafe.Pointer(&vs))
   110  			vsHdr.Data = uintptr(ptr)
   111  			vsHdr.Len = n
   112  			vsHdr.Cap = n
   113  			for i := range vs {
   114  				vs[i] = 0
   115  			}
   116  		}
   117  	case 2:
   118  		return func(ptr unsafe.Pointer, n int) {
   119  			var vs []int16
   120  			vsHdr := (*reflect.SliceHeader)(unsafe.Pointer(&vs))
   121  			vsHdr.Data = uintptr(ptr)
   122  			vsHdr.Len = n
   123  			vsHdr.Cap = n
   124  			for i := range vs {
   125  				vs[i] = 0
   126  			}
   127  		}
   128  	case 1:
   129  		return func(ptr unsafe.Pointer, n int) {
   130  			var vs []int8
   131  			vsHdr := (*reflect.SliceHeader)(unsafe.Pointer(&vs))
   132  			vsHdr.Data = uintptr(ptr)
   133  			vsHdr.Len = n
   134  			vsHdr.Cap = n
   135  			for i := range vs {
   136  				vs[i] = 0
   137  			}
   138  		}
   139  	default:
   140  		// Slow case: reinterpret to []byte, and set that. Note that the
   141  		// compiler should be able to optimize this too. In this case
   142  		// it's always a value type, so this is always safe to do.
   143  		return func(ptr unsafe.Pointer, n int) {
   144  			var b []byte
   145  			bHdr := (*reflect.SliceHeader)(unsafe.Pointer(&b))
   146  			bHdr.Data = uintptr(ptr)
   147  			bHdr.Len = int(size) * n
   148  			bHdr.Cap = bHdr.Len
   149  			for i := range b {
   150  				b[i] = 0
   151  			}
   152  		}
   153  	}
   154  }
   155  
   156  func isValueType(t reflect.Type) bool {
   157  	switch t.Kind() {
   158  	case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   159  		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
   160  		reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
   161  		return true
   162  	case reflect.Array:
   163  		return isValueType(t.Elem())
   164  	case reflect.Struct:
   165  		for i := 0; i < t.NumField(); i++ {
   166  			if !isValueType(t.Field(i).Type) {
   167  				return false
   168  			}
   169  		}
   170  		return true
   171  	}
   172  	return false
   173  }