github.com/aclements/go-misc@v0.0.0-20240129233631-2f6ede80790c/rtcheck/val.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"fmt"
     9  	"go/constant"
    10  	"go/token"
    11  	"go/types"
    12  	"io"
    13  	"log"
    14  
    15  	"golang.org/x/tools/go/ssa"
    16  )
    17  
    18  // ValState tracks the known dynamic values of instructions and heap
    19  // objects. Either field may be nil, meaning no bindings of that kind.
    20  type ValState struct {
    21  	frame *frameValState // bindings of ssa.Values for the current frame
    22  	heap  *heapValState  // bindings of tracked heap objects
    23  }
    24  
    25  // frameValState tracks the known dynamic values of ssa.Values in a
    26  // particular execution path of a single stack frame.
    27  //
    28  // frameValState is effectively a persistent map represented as a
    29  // linked list possibly terminated in a Go map. For efficiency, the
    30  // linked list is compacted into a Go map when the size of the linked
    31  // list and the size of the map at the end of the chain are the same.
    32  type frameValState struct {
    33  	parent *frameValState // next-older link in the chain; nil at the end
    34  	budget int            // counts down per link; a link created with budget <= 0 is flattened
    35  
    36  	// If flat is non-nil, parent, bind, and val must all be nil.
    37  	flat map[ssa.Value]DynValue
    38  
    39  	bind ssa.Value // the value bound (or unbound) by this link
    40  	val  DynValue // nil to unbind this value
    41  }
    42  
    43  // heapValState tracks the known dynamic values of heap objects. It
    44  // uses the same linked-list-ending-in-a-map layout as frameValState.
    45  type heapValState struct {
    46  	parent *heapValState // next-older link in the chain; nil at the end
    47  	budget int           // counts down per link; a link created with budget <= 0 is flattened
    48  
    49  	// If flat is non-nil, parent, bind, and val must all be nil.
    50  	flat map[*HeapObject]DynValue
    51  
    52  	bind *HeapObject // A value in the heap
    53  	val  DynValue    // nil to unbind this object
    54  }
    54  
    55  // Get returns the dynamic value of val, or nil if unknown. val may be
    56  // a constant ssa.Value, in which case it will be resolved directly to
    57  // a DynValue if possible. Otherwise, Get will look up the value bound
    58  // to val by a previous call to Extend.
    59  func (vs ValState) Get(val ssa.Value) DynValue {
    60  	switch val := val.(type) {
    61  	case *ssa.Const:
    62  		if val.Value == nil {
    63  			return DynNil{}
    64  		}
    65  		return DynConst{val.Value}
    66  	case *ssa.Global:
    67  		return DynGlobal{val}
    68  	}
    69  	for frame := vs.frame; frame != nil; frame = frame.parent {
    70  		if frame.flat != nil {
    71  			return frame.flat[val]
    72  		}
    73  		if frame.bind == val {
    74  			return frame.val
    75  		}
    76  	}
    77  	return nil
    78  }
    79  
    80  // GetHeap returns the dynamic value of a heap object, or nil if
    81  // unknown.
    82  func (vs ValState) GetHeap(h *HeapObject) DynValue {
    83  	for heap := vs.heap; heap != nil; heap = heap.parent {
    84  		if heap.flat != nil {
    85  			return heap.flat[h]
    86  		}
    87  		if heap.bind == h {
    88  			return heap.val
    89  		}
    90  	}
    91  	return nil
    92  }
    93  
    94  // Extend returns a new ValState that is like vs, but with bind bound
    95  // to dynamic value val. If dyn is dynUnknown, Extend unbinds val.
    96  // Extend is a no-op if called with a constant ssa.Value.
    97  func (vs ValState) Extend(val ssa.Value, dyn DynValue) ValState {
    98  	if _, ok := dyn.(dynUnknown); ok {
    99  		// "Unbind" val.
   100  		if vs.Get(val) == nil {
   101  			return vs
   102  		}
   103  		dyn = nil
   104  	}
   105  	switch val.(type) {
   106  	case *ssa.Const, *ssa.Global, *ssa.Function, *ssa.Builtin:
   107  		return vs
   108  	}
   109  
   110  	budget := 4
   111  	if vs.frame != nil {
   112  		budget = vs.frame.budget - 1
   113  	}
   114  	vs = ValState{&frameValState{vs.frame, budget, nil, val, dyn}, vs.heap}
   115  	if vs.frame.budget <= 0 {
   116  		vs.frame.flatten()
   117  	}
   118  	return vs
   119  }
   120  
   121  // ExtendHeap returns a new ValState that is like vs, but with heap
   122  // object h bound to dynamic value val.
   123  func (vs ValState) ExtendHeap(h *HeapObject, dyn DynValue) ValState {
   124  	if _, ok := dyn.(dynUnknown); ok {
   125  		// "Unbind" val.
   126  		if vs.GetHeap(h) == nil {
   127  			return vs
   128  		}
   129  		dyn = nil
   130  	}
   131  
   132  	budget := 4
   133  	if vs.heap != nil {
   134  		budget = vs.heap.budget - 1
   135  	}
   136  	vs = ValState{vs.frame, &heapValState{vs.heap, budget, nil, h, dyn}}
   137  	if vs.heap.budget <= 0 {
   138  		vs.heap.flatten()
   139  	}
   140  	return vs
   141  }
   142  
   143  // LimitToHeap returns a ValState containing only the heap bindings in
   144  // vs.
   145  func (vs ValState) LimitToHeap() ValState {
   146  	return ValState{nil, vs.heap}
   147  }
   148  
   149  // Do applies the effect of instr to the value state and returns an
   150  // Extended ValState.
   151  func (vs ValState) Do(instr ssa.Instruction) ValState {
   152  	switch instr := instr.(type) {
   153  	case *ssa.BinOp:
   154  		if x, y := vs.Get(instr.X), vs.Get(instr.Y); x != nil && y != nil {
   155  			return vs.Extend(instr, x.BinOp(instr.Op, y))
   156  		}
   157  
   158  	case *ssa.UnOp:
   159  		if x := vs.Get(instr.X); x != nil {
   160  			return vs.Extend(instr, x.UnOp(instr.Op, vs))
   161  		}
   162  
   163  	case *ssa.ChangeType:
   164  		if x := vs.Get(instr.X); x != nil {
   165  			return vs.Extend(instr, x)
   166  		}
   167  
   168  	case *ssa.FieldAddr:
   169  		if x := vs.Get(instr.X); x != nil {
   170  			switch x := x.(type) {
   171  			case DynGlobal:
   172  				return vs.Extend(instr, DynFieldAddr{x.global, instr.Field})
   173  			case DynHeapPtr:
   174  				return vs.Extend(instr, x.FieldAddr(vs, instr))
   175  			}
   176  		}
   177  
   178  	case *ssa.Store:
   179  		// Handle stores to tracked heap objects.
   180  		//
   181  		// TODO: This could be storing to something in the
   182  		// known heap, but we may have failed to track the
   183  		// aliasing of it and think that this is untracked.
   184  		if addr := vs.Get(instr.Addr); addr != nil {
   185  			if addr, ok := addr.(DynHeapPtr); ok {
   186  				val := vs.Get(instr.Val)
   187  				if val == nil {
   188  					val = dynUnknown{}
   189  				}
   190  				return vs.ExtendHeap(addr.elem, val)
   191  			}
   192  		}
   193  
   194  		// TODO: ssa.Convert, ssa.Field
   195  	}
   196  	return vs
   197  }
   198  
   199  func (fs *frameValState) flatten() map[ssa.Value]DynValue {
   200  	if fs == nil {
   201  		return nil
   202  	}
   203  	if fs.flat != nil {
   204  		return fs.flat
   205  	}
   206  	// Collect bindings into a map.
   207  	flat := make(map[ssa.Value]DynValue)
   208  	for fs2 := fs; fs2 != nil; fs2 = fs2.parent {
   209  		if fs2.flat != nil {
   210  			for k, v := range fs2.flat {
   211  				if _, ok := flat[k]; !ok {
   212  					flat[k] = v
   213  				}
   214  			}
   215  			break
   216  		}
   217  		if _, ok := flat[fs2.bind]; !ok {
   218  			flat[fs2.bind] = fs2.val
   219  		}
   220  	}
   221  	// Eliminate unbound values.
   222  	for k, v := range flat {
   223  		if v == nil {
   224  			delete(flat, k)
   225  		}
   226  	}
   227  	fs.flat = flat
   228  	fs.budget = len(flat) + 1
   229  	fs.parent, fs.bind, fs.val = nil, nil, nil
   230  	return fs.flat
   231  }
   232  
   233  func (hs *heapValState) flatten() map[*HeapObject]DynValue {
   234  	if hs == nil {
   235  		return nil
   236  	}
   237  	if hs.flat != nil {
   238  		return hs.flat
   239  	}
   240  	// Collect bindings into a map.
   241  	flat := make(map[*HeapObject]DynValue)
   242  	for hs2 := hs; hs2 != nil; hs2 = hs2.parent {
   243  		if hs2.flat != nil {
   244  			for k, v := range hs2.flat {
   245  				if _, ok := flat[k]; !ok {
   246  					flat[k] = v
   247  				}
   248  			}
   249  			break
   250  		}
   251  		if _, ok := flat[hs2.bind]; !ok {
   252  			flat[hs2.bind] = hs2.val
   253  		}
   254  	}
   255  	// Eliminate unbound values.
   256  	for k, v := range flat {
   257  		if v == nil {
   258  			delete(flat, k)
   259  		}
   260  	}
   261  	hs.flat = flat
   262  	hs.budget = len(flat) + 1
   263  	hs.parent, hs.bind, hs.val = nil, nil, nil
   264  	return hs.flat
   265  }
   266  
   267  // EqualAt returns true if vs and o have equal dynamic values for each
   268  // value in at, and equal heap values for all heap objects.
   269  func (vs ValState) EqualAt(o ValState, at map[ssa.Value]struct{}) bool {
   270  	if len(at) != 0 {
   271  		// Check frame state.
   272  		i1, i2 := vs.frame.flatten(), o.frame.flatten()
   273  		for k := range at {
   274  			v1, ok1 := i1[k]
   275  			v2, ok2 := i2[k]
   276  			if ok1 != ok2 || (ok1 && !v1.Equal(v2)) {
   277  				return false
   278  			}
   279  		}
   280  	}
   281  	// Check heap state.
   282  	h1, h2 := vs.heap.flatten(), o.heap.flatten()
   283  	if len(h1) != len(h2) {
   284  		return false
   285  	}
   286  	for k1, v1 := range h1 {
   287  		if v2, ok := h2[k1]; !ok || !v1.Equal(v2) {
   288  			return false
   289  		}
   290  	}
   291  	return true
   292  }
   293  
   294  // WriteTo writes a debug representation of vs to w.
   295  func (vs ValState) WriteTo(w io.Writer) {
   296  	// TODO: Sort.
   297  	h := vs.heap.flatten()
   298  	for bind, val := range h {
   299  		fmt.Fprintf(w, "%s = %v\n", bind, val)
   300  	}
   301  	f := vs.frame.flatten()
   302  	for bind, val := range f {
   303  		fmt.Fprintf(w, "%s = %v\n", bind.Name(), val)
   304  	}
   305  }
   306  
   307  // A DynValue is the dynamic value of an ssa.Value on a particular
   308  // execution path. It can track any scalar value and addresses that
   309  // cannot alias (e.g., addresses of globals).
   310  type DynValue interface {
   311  	Equal(other DynValue) bool // reports whether other is the same dynamic value
   312  	BinOp(op token.Token, other DynValue) DynValue // result of "receiver op other"; may be dynUnknown
   313  	UnOp(op token.Token, vs ValState) DynValue // result of "op receiver"; vs supplies heap contents for loads
   314  }
   315  
   316  type dynUnknown struct{}
   317  
   318  func (dynUnknown) Equal(other DynValue) bool {
   319  	panic("Equal on unknown dynamic value")
   320  }
   321  
   322  func (dynUnknown) BinOp(op token.Token, other DynValue) DynValue {
   323  	panic("BinOp on unknown dynamic value")
   324  }
   325  
   326  func (dynUnknown) UnOp(op token.Token, vs ValState) DynValue {
   327  	panic("UnOp on unknown dynamic value")
   328  }
   329  
   330  // BUG: DynConst is infinite precision. It should track its type and
   331  // truncate the results of every operation.
   332  
   333  type DynConst struct {
   334  	c constant.Value
   335  }
   336  
   337  func (x DynConst) Equal(y DynValue) bool {
   338  	return constant.Compare(x.c, token.EQL, y.(DynConst).c)
   339  }
   340  
   341  func (x DynConst) BinOp(op token.Token, y DynValue) DynValue {
   342  	yc := y.(DynConst).c
   343  	switch op {
   344  	case token.EQL, token.NEQ,
   345  		token.LSS, token.LEQ,
   346  		token.GTR, token.GEQ:
   347  		// Bleh. constant.BinaryOp doesn't work on comparison
   348  		// operations.
   349  		result := constant.Compare(x.c, op, yc)
   350  		return DynConst{constant.MakeBool(result)}
   351  	case token.SHL, token.SHR:
   352  		s, exact := constant.Uint64Val(yc)
   353  		if !exact {
   354  			log.Fatalf("bad shift %v", y)
   355  		}
   356  		return DynConst{constant.Shift(x.c, op, uint(s))}
   357  	case token.QUO:
   358  		if constant.Sign(yc) == 0 {
   359  			// TODO: It would be nice if we could report
   360  			// this for real.
   361  			log.Print("division by zero")
   362  			return dynUnknown{}
   363  		}
   364  		fallthrough
   365  	default:
   366  		return DynConst{constant.BinaryOp(x.c, op, yc)}
   367  	}
   368  }
   369  
   370  func (x DynConst) UnOp(op token.Token, vs ValState) DynValue {
   371  	return DynConst{constant.UnaryOp(op, x.c, 64)}
   372  }
   373  
   374  // comparableBinOp implements DynValue.BinOp for values that support
   375  // only comparison operators.
   376  func comparableBinOp(x DynValue, op token.Token, y DynValue) DynValue {
   377  	equal := x.Equal(y)
   378  	switch op {
   379  	case token.EQL:
   380  		return DynConst{constant.MakeBool(equal)}
   381  	case token.NEQ:
   382  		return DynConst{constant.MakeBool(!equal)}
   383  	}
   384  	log.Fatalf("bad pointer operation: %v", op)
   385  	panic("unreachable")
   386  }
   387  
   388  func addrUnOp(op token.Token) DynValue {
   389  	switch op {
   390  	case token.MUL:
   391  		return dynUnknown{}
   392  	}
   393  	log.Fatalf("bad pointer operation: %v", op)
   394  	panic("unreachable")
   395  }
   396  
   397  // DynNil is a nil pointer.
   398  type DynNil struct{}
   399  
   400  func (x DynNil) Equal(y DynValue) bool {
   401  	_, isNil := y.(DynNil)
   402  	return isNil
   403  }
   404  
   405  func (x DynNil) BinOp(op token.Token, y DynValue) DynValue {
   406  	return comparableBinOp(x, op, y)
   407  }
   408  
   409  func (x DynNil) UnOp(op token.Token, vs ValState) DynValue {
   410  	return addrUnOp(op)
   411  }
   412  
   413  // DynGlobal is the address of a global. Because it's the address of a
   414  // global, it can only alias other DynGlobals.
   415  type DynGlobal struct {
   416  	global *ssa.Global
   417  }
   418  
   419  func (x DynGlobal) Equal(y DynValue) bool {
   420  	yg, isGlobal := y.(DynGlobal)
   421  	return isGlobal && x.global == yg.global
   422  }
   423  
   424  func (x DynGlobal) BinOp(op token.Token, y DynValue) DynValue {
   425  	return comparableBinOp(x, op, y)
   426  }
   427  
   428  func (x DynGlobal) UnOp(op token.Token, vs ValState) DynValue {
   429  	return addrUnOp(op)
   430  }
   431  
   432  // DynFieldAddr is the address of a field in a global. Because it is
   433  // only fields in globals, it can only alias other DynFieldAddrs.
   434  //
   435  // TODO: We could unify DynFieldAddr and DynHeapAddr if we created
   436  // (and cached) HeapObjects for globals and fields of globals as
   437  // needed.
   438  type DynFieldAddr struct {
   439  	object *ssa.Global
   440  	field  int
   441  }
   442  
   443  func (x DynFieldAddr) Equal(y DynValue) bool {
   444  	y2, isFieldAddr := y.(DynFieldAddr)
   445  	return isFieldAddr && x.object == y2.object && x.field == y2.field
   446  }
   447  
   448  func (x DynFieldAddr) BinOp(op token.Token, y DynValue) DynValue {
   449  	return comparableBinOp(x, op, y)
   450  }
   451  
   452  func (x DynFieldAddr) UnOp(op token.Token, vs ValState) DynValue {
   453  	return addrUnOp(op)
   454  }
   455  
   456  // DynHeapPtr is a pointer to a tracked heap object. Because globals
   457  // and heap objects are tracked separately, a DynHeapPtr can only
   458  // alias other DynHeapPtrs.
   459  type DynHeapPtr struct {
   460  	elem *HeapObject
   461  }
   462  
   463  func (x DynHeapPtr) String() string {
   464  	return "&" + x.elem.String()
   465  }
   466  
   467  func (x DynHeapPtr) Equal(y DynValue) bool {
   468  	y2, isHeapPtr := y.(DynHeapPtr)
   469  	return isHeapPtr && x.elem == y2.elem
   470  }
   471  
   472  func (x DynHeapPtr) BinOp(op token.Token, y DynValue) DynValue {
   473  	return comparableBinOp(x, op, y)
   474  }
   475  
   476  func (x DynHeapPtr) UnOp(op token.Token, vs ValState) DynValue {
   477  	if op == token.MUL {
   478  		return vs.GetHeap(x.elem)
   479  	}
   480  	return addrUnOp(op)
   481  }
   482  
   483  func (x DynHeapPtr) FieldAddr(vs ValState, instr *ssa.FieldAddr) DynValue {
   484  	obj := vs.GetHeap(x.elem)
   485  	if obj == nil {
   486  		return dynUnknown{}
   487  	}
   488  	strct := obj.(DynStruct)
   489  	fieldName := instr.X.Type().(*types.Pointer).Elem().Underlying().(*types.Struct).Field(instr.Field).Name()
   490  	if fieldVal, ok := strct[fieldName]; ok {
   491  		return DynHeapPtr{fieldVal}
   492  	}
   493  	return dynUnknown{}
   494  }
   495  
   496  // DynStruct is a struct value consisting of heap objects. It maps
   497  // from field name to heap object. Note that each tracked field is its
   498  // own heap object; e.g., even if it's just an int field, it's
   499  // considered a HeapObject. This makes it possible to track pointers
   500  // to fields.
   501  type DynStruct map[string]*HeapObject
   502  
   503  func (x DynStruct) Equal(y DynValue) bool {
   504  	y2, ok := y.(DynStruct)
   505  	if !ok || len(x) != len(y2) {
   506  		return false
   507  	}
   508  	for k, v := range x {
   509  		if y2[k] != v {
   510  			return false
   511  		}
   512  	}
   513  	return true
   514  }
   515  
   516  func (x DynStruct) BinOp(op token.Token, y DynValue) DynValue {
   517  	return comparableBinOp(x, op, y)
   518  }
   519  
   520  func (x DynStruct) UnOp(op token.Token, vs ValState) DynValue {
   521  	log.Fatal("bad struct operation: %v", op)
   522  	panic("unreachable")
   523  }
   524  
   525  // A HeapObject is a tracked object in the heap. HeapObjects have
   526  // identity; that is, for two *HeapObjects x and y, they refer to the
   527  // same heap object if and only if x == y. HeapObjects have a string
   528  // label for debugging purposes, but this label does not affect
   529  // identity.
   530  type HeapObject struct {
   531  	label string
   532  }
   533  
   534  func NewHeapObject(label string) *HeapObject {
   535  	return &HeapObject{label}
   536  }
   537  
   538  func (h *HeapObject) String() string {
   539  	return "heap:" + h.label
   540  }