github.com/grailbio/base@v0.0.11/diagnostic/memsize/deep_size.go

package memsize

import (
	"reflect"
	"unsafe"
)

// DeepSize estimates the amount of memory used by a Go value. It is intended
// as a memory-usage debugging aid. The argument must be a pointer to the
// value.
//
// Not thread-safe. Behavior is undefined if any value reachable from the
// argument is concurrently mutated. In general, do not call this in
// production.
//
// Behavior:
// * Recursively descends into contained values (struct fields, slice elements,
// etc.), tracking visited values (by memory address) to handle cycles.
// * Counts only slice length, not unused capacity.
// * Counts only map key and value sizes, not map overhead.
// * Does not count functions or channels.
//
// The implementation relies on the Go garbage collector being non-compacting
// (not moving values in memory), since values are identified by address during
// the scan. This is true as of Go 1.13, but could change in the future.
func DeepSize(x interface{}) (numBytes int) {
	if x == nil {
		return 0
	}
	v := reflect.ValueOf(x)
	if v.Kind() != reflect.Ptr {
		panic("must be a pointer")
	}
	if v.IsNil() {
		return 0
	}
	scanner := &memoryScanner{
		memory:  &intervalSet{},
		visited: make(map[memoryAndKind]struct{}),
	}

	unaddressableBytes := scanner.scan(v.Elem(), true)
	return scanner.memory.totalCovered() + unaddressableBytes
}
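
// deepSizeUsageSketch is an illustrative sketch of calling DeepSize (this
// function and the node type are hypothetical, not part of this package's
// API). Exact byte counts depend on architecture and alignment.
func deepSizeUsageSketch() int {
	type node struct {
		payload []byte
		next    *node
	}
	n := &node{payload: make([]byte, 100)}
	n.next = n // a cycle: visited tracking ensures n is counted only once
	// Counts the node struct itself plus the 100-byte backing array of
	// payload; unused slice capacity would not be counted.
	return DeepSize(n)
}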

// memoryAndKind identifies a scanned region by its memory interval together
// with the kind of value stored there, so that, for example, a struct and its
// first field (same address, possibly different kind) are tracked as distinct
// visits.
type memoryAndKind struct {
	interval
	reflect.Kind
}

// getMemoryAndType returns the memory interval and kind of x.
// x must be addressable.
func getMemoryAndType(x reflect.Value) memoryAndKind {
	start := x.UnsafeAddr()
	size := int64(x.Type().Size())
	kind := x.Kind()
	return memoryAndKind{
		interval: interval{start: start, length: size},
		Kind:     kind,
	}
}

// memoryScanner recursively scans the memory used by a reflect.Value.
// Not thread-safe; scan should be called only once per scanner.
type memoryScanner struct {
	memory  *intervalSet               // set of memory intervals recorded during the scan
	visited map[memoryAndKind]struct{} // locations already visited by scan, to break cycles
}
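
// intervalSet and interval are defined elsewhere in this package; this file
// uses them only through add and totalCovered. A minimal sketch consistent
// with that usage (an assumption for illustration, not the actual
// implementation):
//
//	type interval struct {
//		start  uintptr // starting address
//		length int64   // size in bytes
//	}
//
//	type intervalSet struct{ /* ... */ }
//
//	// add records a memory region; overlapping regions must be merged so
//	// that shared memory (e.g. a slice and a subslice) is counted once.
//	func (s *intervalSet) add(i interval)
//
//	// totalCovered returns the total number of bytes covered by the set.
//	func (s *intervalSet) totalCovered() int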

// scan recursively traverses a reflect.Value and records the memory it
// occupies in s.memory.
// x is the value whose size is to be counted.
// includeX indicates whether the bytes for x itself should be counted.
// Returns the number of bytes that could not be recorded by address.
func (s *memoryScanner) scan(x reflect.Value, includeX bool) (unaddressableBytes int) {
	if x.CanAddr() {
		memtype := getMemoryAndType(x)
		if _, ok := s.visited[memtype]; ok {
			return // already counted; stopping here handles cycles
		}
		s.visited[memtype] = struct{}{}
		s.memory.add(memtype.interval)
	} else if includeX {
		unaddressableBytes += int(x.Type().Size())
	}

	switch x.Kind() {
	case reflect.String:
		m := x.String()
		hdr := (*reflect.StringHeader)(unsafe.Pointer(&m))
		s.memory.add(interval{start: hdr.Data, length: int64(hdr.Len)})
	case reflect.Array:
		if containsPointers(x.Type()) { // must scan each element individually
			for i := 0; i < x.Len(); i++ {
				unaddressableBytes += s.scan(x.Index(i), false)
			}
		}
	case reflect.Slice:
		if x.Len() > 0 {
			if containsPointers(x.Index(0).Type()) { // must scan each element individually
				for i := 0; i < x.Len(); i++ {
					unaddressableBytes += s.scan(x.Index(i), true)
				}
			} else { // add the content of the slice to the memory counter
				start := x.Pointer()
				size := int64(x.Index(0).Type().Size()) * int64(x.Len())
				s.memory.add(interval{start: start, length: size})
			}
		}
	case reflect.Interface, reflect.Ptr:
		if !x.IsNil() {
			unaddressableBytes += s.scan(x.Elem(), true)
		}
	case reflect.Struct:
		for _, fieldI := range structChild(x) {
			unaddressableBytes += s.scan(fieldI, false)
		}
	case reflect.Map:
		for _, key := range x.MapKeys() {
			val := x.MapIndex(key)
			unaddressableBytes += s.scan(key, true)
			unaddressableBytes += s.scan(val, true)
		}
	case reflect.Func, reflect.Chan:
		// Not counted: function and channel internals are opaque to reflect.
	default:
	}
	return
}
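
// Note on unaddressable bytes: reflect returns copies for map keys and
// values, so scan cannot record them by address and instead tallies their
// sizes directly. For illustration (hypothetical value):
//
//	m := map[string]int{"a": 1}
//	// Scanning m records the key string's character data by address via
//	// the String case above, but the key and value themselves are counted
//	// as unaddressable bytes.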

// containsPointers reports whether a value of type x can contain pointers,
// directly or through nested fields or elements.
func containsPointers(x reflect.Type) bool {
	switch x.Kind() {
	case reflect.String, reflect.Slice, reflect.Map, reflect.Interface, reflect.Ptr:
		return true
	case reflect.Array:
		if x.Len() > 0 {
			return containsPointers(x.Elem())
		}
	case reflect.Struct:
		for i, n := 0, x.NumField(); i < n; i++ {
			if containsPointers(x.Field(i).Type) {
				return true
			}
		}
	}
	return false
}
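
// For illustration, given the cases above:
//
//	containsPointers(reflect.TypeOf([8]int{}))             // false: plain integers
//	containsPointers(reflect.TypeOf([]byte(nil)))          // true: slice header points at data
//	containsPointers(reflect.TypeOf(struct{ s string }{})) // true: string field holds a pointer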

// structChild returns all the fields of x (recursively flattening nested
// structs) that can contain pointers: pointers, strings, interfaces, slices,
// and maps.
// x must be a struct kind.
func structChild(x reflect.Value) []reflect.Value {
	var ret []reflect.Value
	for i, n := 0, x.NumField(); i < n; i++ {
		fieldI := x.Field(i)
		switch fieldI.Kind() {
		case reflect.Struct:
			ret = append(ret, structChild(fieldI)...)
		case reflect.Ptr, reflect.String, reflect.Interface, reflect.Slice, reflect.Map:
			ret = append(ret, fieldI)
		}
	}
	return ret
}
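
// For illustration, given (hypothetical types)
//
//	type inner struct{ p *int }
//	type outer struct {
//		n int
//		i inner
//		s string
//	}
//
// structChild on an outer value returns the values for i.p and s: the nested
// struct is flattened, and the int field is skipped because it cannot contain
// pointers.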