github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/analysis/dfa/dfa.go (about)

     1  // Package dfa provides types and functions for implementing data-flow analyses.
     2  package dfa
     3  
     4  import (
     5  	"cmp"
     6  	"fmt"
     7  	"log"
     8  	"math/bits"
     9  	"slices"
    10  	"strings"
    11  	"sync"
    12  
    13  	"github.com/amarpal/go-tools/go/ir"
    14  	"golang.org/x/exp/constraints"
    15  )
    16  
    17  const debugging = false
    18  
    19  func debugf(f string, args ...any) {
    20  	if debugging {
    21  		log.Printf(f, args...)
    22  	}
    23  }
    24  
    25  // Join defines the [∨] operation for a [join-semilattice]. It must implement a commutative and associative binary operation
    26  // that returns the least upper bound of two states from S.
    27  //
    28  // Code that calls Join functions is expected to handle the [⊥ and ⊤ elements], as well as implement idempotency. That is,
    29  // the following properties will be enforced:
    30  //
    31  //   - x ∨ ⊥ = x
    32  //   - x ∨ ⊤ = ⊤
    33  //   - x ∨ x = x
    34  //
    35  // Simple table-based join functions can be created using [JoinTable].
    36  //
    37  // [∨]: https://en.wikipedia.org/wiki/Join_and_meet
    38  // [join-semilattice]: https://en.wikipedia.org/wiki/Semilattice
    39  // [⊥ and ⊤ elements]: https://en.wikipedia.org/wiki/Greatest_element_and_least_element#Top_and_bottom
    40  type Join[S comparable] func(S, S) S
    41  
    42  // Mapping maps a single [ir.Value] to an abstract state.
    43  type Mapping[S comparable] struct {
    44  	Value    ir.Value
    45  	State    S
    46  	Decision Decision
    47  }
    48  
    49  // Decision describes how a mapping from an [ir.Value] to an abstract state came to be.
    50  // Decisions are provided by transfer functions when they create mappings.
    51  type Decision struct {
    52  	// The relevant values that the transfer function used to make the decision.
    53  	Inputs []ir.Value
    54  	// A human-readable description of the decision.
    55  	Description string
    56  	// Whether this is the source of an abstract state. For example, in a taint analysis, the call to a function that
    57  	// produces a tainted value would be the source of the taint state, and any instructions that operate on
    58  	// and propagate tainted values would not be sources.
    59  	Source bool
    60  }
    61  
    62  func (m Mapping[S]) String() string {
    63  	return fmt.Sprintf("%s = %v", m.Value.Name(), m.State)
    64  }
    65  
    66  // M is a helper for constructing instances of [Mapping].
    67  func M[S comparable](v ir.Value, s S, d Decision) Mapping[S] {
    68  	return Mapping[S]{Value: v, State: s, Decision: d}
    69  }
    70  
    71  // Ms is a helper for constructing slices of mappings.
    72  //
    73  // Example:
    74  //
    75  //	Ms(M(v1, d1, ...), M(v2, d2, ...))
    76  func Ms[S comparable](ms ...Mapping[S]) []Mapping[S] {
    77  	return ms
    78  }
    79  
    80  // Framework describes a monotone data-flow framework ⟨S, ∨, Transfer⟩ using a bounded join-semilattice ⟨S, ∨⟩ and a
    81  // monotonic transfer function.
    82  //
    83  // Transfer implements the transfer function. Given an instruction, it should return zero or more mappings from IR
    84  // values to abstract values, i.e. values from the semilattice. Transfer must be monotonic. ϕ instructions are handled
    85  // automatically and do not cause Transfer to be called.
    86  //
    87  // The set S is defined implicitly by the values returned by Join and Transfer and needn't be finite. In addition, it
    88  // contains the elements ⊥ and ⊤ (Bottom and Top) with Join(x, ⊥) = x and Join(x, ⊤) = ⊤. The provided Join function is
    89  // wrapped to handle these elements automatically. All IR values start in the ⊥ state.
    90  //
    91  // Abstract states are associated with IR values. As such, the analysis is sparse and favours the partitioned variable
    92  // lattice (PVL) property.
    93  type Framework[S comparable] struct {
    94  	Join     Join[S]
    95  	Transfer func(*Instance[S], ir.Instruction) []Mapping[S]
    96  	Bottom   S
    97  	Top      S
    98  }
    99  
   100  // Start returns a new instance of the framework. See also [Framework.Forward].
   101  func (fw *Framework[S]) Start() *Instance[S] {
   102  	if fw.Bottom == fw.Top {
   103  		panic("framework's ⊥ and ⊤ are identical; did you forget to specify them?")
   104  	}
   105  
   106  	return &Instance[S]{
   107  		Framework: fw,
   108  		Mapping:   map[ir.Value]Mapping[S]{},
   109  	}
   110  }
   111  
   112  // Forward runs an intraprocedural forward data flow analysis, using an iterative fixed-point algorithm, given the
   113  // functions specified in the framework. It combines [Framework.Start] and [Instance.Forward].
   114  func (fw *Framework[S]) Forward(fn *ir.Function) *Instance[S] {
   115  	ins := fw.Start()
   116  	ins.Forward(fn)
   117  	return ins
   118  }
   119  
   120  // Dot returns a directed graph in [Graphviz] format that represents the finite join-semilattice ⟨S, ≤⟩.
   121  // Vertices represent elements in S and edges represent the ≤ relation between elements.
   122  // We map from ⟨S, ∨⟩ to ⟨S, ≤⟩ by computing x ∨ y for all elements in [S]², where x ≤ y iff x ∨ y == y.
   123  //
   124  // The resulting graph can be filtered through [tred] to compute the transitive reduction of the graph, the
   125  // visualisation of which corresponds to the Hasse diagram of the semilattice.
   126  //
   127  // The set of states should not include the ⊥ and ⊤ elements.
   128  //
   129  // [Graphviz]: https://graphviz.org/
   130  // [tred]: https://graphviz.org/docs/cli/tred/
   131  func Dot[S comparable](fn Join[S], states []S, bottom, top S) string {
   132  	var sb strings.Builder
   133  	sb.WriteString("digraph{\n")
   134  	sb.WriteString("rankdir=\"BT\"\n")
   135  
   136  	for i, v := range states {
   137  		if vs, ok := any(v).(fmt.Stringer); ok {
   138  			fmt.Fprintf(&sb, "n%d [label=%q]\n", i, vs)
   139  		} else {
   140  			fmt.Fprintf(&sb, "n%d [label=%q]\n", i, fmt.Sprintf("%v", v))
   141  		}
   142  	}
   143  
   144  	for dx, x := range states {
   145  		for dy, y := range states {
   146  			if dx == dy {
   147  				continue
   148  			}
   149  
   150  			if join(fn, x, y, bottom, top) == y {
   151  				fmt.Fprintf(&sb, "n%d -> n%d\n", dx, dy)
   152  			}
   153  		}
   154  	}
   155  
   156  	sb.WriteString("}")
   157  	return sb.String()
   158  }
   159  
   160  // Instance is an instance of a data-flow analysis. It is created by [Framework.Forward].
   161  type Instance[S comparable] struct {
   162  	Framework *Framework[S]
   163  	// Mapping is the result of the analysis. Consider using Instance.Value instead of accessing Mapping
   164  	// directly, as it correctly returns ⊥ for missing values.
   165  	Mapping map[ir.Value]Mapping[S]
   166  }
   167  
   168  // Set maps v to the abstract value d. It does not apply any checks. This should only be used before calling [Instance.Forward], to set
   169  // initial states of values.
   170  func (ins *Instance[S]) Set(v ir.Value, d S) {
   171  	ins.Mapping[v] = Mapping[S]{Value: v, State: d}
   172  }
   173  
   174  // Value returns the abstract value for v. If none was set, it returns ⊥.
   175  func (ins *Instance[S]) Value(v ir.Value) S {
   176  	m, ok := ins.Mapping[v]
   177  	if ok {
   178  		return m.State
   179  	} else {
   180  		return ins.Framework.Bottom
   181  	}
   182  }
   183  
   184  // Decision returns the decision of the mapping for v, if any.
   185  func (ins *Instance[S]) Decision(v ir.Value) Decision {
   186  	return ins.Mapping[v].Decision
   187  }
   188  
   189  var dfsDebugMu sync.Mutex
   190  
   191  func join[S comparable](fn Join[S], a, b, bottom, top S) S {
   192  	switch {
   193  	case a == top || b == top:
   194  		return top
   195  	case a == bottom:
   196  		return b
   197  	case b == bottom:
   198  		return a
   199  	case a == b:
   200  		return a
   201  	default:
   202  		return fn(a, b)
   203  	}
   204  }
   205  
   206  // Forward runs a forward data-flow analysis on fn.
   207  func (ins *Instance[S]) Forward(fn *ir.Function) {
   208  	if debugging {
   209  		dfsDebugMu.Lock()
   210  		defer dfsDebugMu.Unlock()
   211  	}
   212  
   213  	debugf("Analyzing %s\n", fn)
   214  	if ins.Mapping == nil {
   215  		ins.Mapping = map[ir.Value]Mapping[S]{}
   216  	}
   217  
   218  	worklist := map[ir.Instruction]struct{}{}
   219  	for _, b := range fn.Blocks {
   220  		for _, instr := range b.Instrs {
   221  			worklist[instr] = struct{}{}
   222  		}
   223  	}
   224  	for len(worklist) > 0 {
   225  		var instr ir.Instruction
   226  		for instr = range worklist {
   227  			break
   228  		}
   229  		delete(worklist, instr)
   230  
   231  		var ds []Mapping[S]
   232  		if phi, ok := instr.(*ir.Phi); ok {
   233  			d := ins.Framework.Bottom
   234  			for _, edge := range phi.Edges {
   235  				a, b := d, ins.Value(edge)
   236  				d = join(ins.Framework.Join, a, b, ins.Framework.Bottom, ins.Framework.Top)
   237  				debugf("join(%v, %v) = %v", a, b, d)
   238  			}
   239  			ds = []Mapping[S]{{Value: phi, State: d, Decision: Decision{Inputs: phi.Edges, Description: "this variable merges the results of multiple branches"}}}
   240  		} else {
   241  			ds = ins.Framework.Transfer(ins, instr)
   242  		}
   243  		if len(ds) > 0 {
   244  			if v, ok := instr.(ir.Value); ok {
   245  				debugf("transfer(%s = %s) = %v", v.Name(), instr, ds)
   246  			} else {
   247  				debugf("transfer(%s) = %v", instr, ds)
   248  			}
   249  		}
   250  		for i, d := range ds {
   251  			old := ins.Value(d.Value)
   252  			dd := d.State
   253  			if dd != old {
   254  				if j := join(ins.Framework.Join, old, dd, ins.Framework.Bottom, ins.Framework.Top); j != dd {
   255  					panic(fmt.Sprintf("transfer function isn't monotonic; Transfer(%v)[%d] = %v; join(%v, %v) = %v", instr, i, dd, old, dd, j))
   256  				}
   257  				ins.Mapping[d.Value] = Mapping[S]{Value: d.Value, State: dd, Decision: d.Decision}
   258  
   259  				for _, ref := range *instr.Referrers() {
   260  					worklist[ref] = struct{}{}
   261  				}
   262  			}
   263  		}
   264  		printMapping(fn, ins.Mapping)
   265  	}
   266  }
   267  
   268  // Propagate is a helper for creating a [Mapping] that propagates the abstract state of src to dst.
   269  // The desc parameter is used as the value of Decision.Description.
   270  func (ins *Instance[S]) Propagate(dst, src ir.Value, desc string) Mapping[S] {
   271  	return M(dst, ins.Value(src), Decision{Inputs: []ir.Value{src}, Description: desc})
   272  }
   273  
   274  func (ins *Instance[S]) Transform(dst ir.Value, s S, src ir.Value, desc string) Mapping[S] {
   275  	return M(dst, s, Decision{Inputs: []ir.Value{src}, Description: desc})
   276  }
   277  
   278  func printMapping[S any](fn *ir.Function, m map[ir.Value]S) {
   279  	if !debugging {
   280  		return
   281  	}
   282  
   283  	debugf("Mapping for %s:\n", fn)
   284  	var keys []ir.Value
   285  	for k := range m {
   286  		keys = append(keys, k)
   287  	}
   288  	slices.SortFunc(keys, func(a, b ir.Value) int {
   289  		return cmp.Compare(a.ID(), b.ID())
   290  	})
   291  	for _, k := range keys {
   292  		v := m[k]
   293  		debugf("\t%v\n", v)
   294  	}
   295  }
   296  
   297  // BinaryTable returns a binary operator based on the provided mapping.
   298  // For missing pairs of values, the default value will be returned.
   299  func BinaryTable[S comparable](default_ S, m map[[2]S]S) func(S, S) S {
   300  	return func(a, b S) S {
   301  		if d, ok := m[[2]S{a, b}]; ok {
   302  			return d
   303  		} else if d, ok := m[[2]S{b, a}]; ok {
   304  			return d
   305  		} else {
   306  			return default_
   307  		}
   308  	}
   309  }
   310  
   311  // JoinTable returns a [Join] function based on the provided mapping.
   312  // For missing pairs of values, the default value will be returned.
   313  func JoinTable[S comparable](top S, m map[[2]S]S) Join[S] {
   314  	return func(a, b S) S {
   315  		if d, ok := m[[2]S{a, b}]; ok {
   316  			return d
   317  		} else if d, ok := m[[2]S{b, a}]; ok {
   318  			return d
   319  		} else {
   320  			return top
   321  		}
   322  	}
   323  }
   324  
   325  func PowerSet[S constraints.Integer](all S) []S {
   326  	out := make([]S, all+1)
   327  	for i := range out {
   328  		out[i] = S(i)
   329  	}
   330  	return out
   331  }
   332  
   333  func MapSet[S constraints.Integer](set S, fn func(S) S) S {
   334  	bits := 64 - bits.LeadingZeros64(uint64(set))
   335  	var out S
   336  	for i := 0; i < bits; i++ {
   337  		if b := (set & (1 << i)); b != 0 {
   338  			out |= fn(b)
   339  		}
   340  	}
   341  	return out
   342  }
   343  
   344  func MapCartesianProduct[S constraints.Integer](x, y S, fn func(S, S) S) S {
   345  	bitsX := 64 - bits.LeadingZeros64(uint64(x))
   346  	bitsY := 64 - bits.LeadingZeros64(uint64(y))
   347  
   348  	var out S
   349  	for i := 0; i < bitsX; i++ {
   350  		for j := 0; j < bitsY; j++ {
   351  			bx := x & (1 << i)
   352  			by := y & (1 << j)
   353  
   354  			if bx != 0 && by != 0 {
   355  				out |= fn(bx, by)
   356  			}
   357  		}
   358  	}
   359  
   360  	return out
   361  }