
     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     5  package ssa
     7  import (
     8  	"fmt"
     9  	"math"
    10  )
    12  type branch int
    14  const (
    15  	unknown = iota
    16  	positive
    17  	negative
    18  )
    20  // relation represents the set of possible relations between
    21  // pairs of variables (v, w). Without a priori knowledge the
    22  // mask is lt | eq | gt meaning v can be less than, equal to or
    23  // greater than w. When the execution path branches on the condition
    24  // `v op w` the set of relations is updated to exclude any
    25  // relation not possible due to `v op w` being true (or false).
    26  //
    27  // E.g.
    28  //
    29  // r := relation(...)
    30  //
    31  // if v < w {
    32  //   newR := r & lt
    33  // }
    34  // if v >= w {
    35  //   newR := r & (eq|gt)
    36  // }
    37  // if v != w {
    38  //   newR := r & (lt|gt)
    39  // }
    40  type relation uint
    42  const (
    43  	lt relation = 1 << iota
    44  	eq
    45  	gt
    46  )
    48  // domain represents the domain of a variable pair in which a set
    49  // of relations is known.  For example, relations learned for unsigned
    50  // pairs cannot be transferred to signed pairs because the same bit
    51  // representation can mean something else.
    52  type domain uint
    54  const (
    55  	signed domain = 1 << iota
    56  	unsigned
    57  	pointer
    58  	boolean
    59  )
    61  type pair struct {
    62  	v, w *Value // a pair of values, ordered by ID.
    63  	// v can be nil, to mean the zero value.
    64  	// for booleans the zero value (v == nil) is false.
    65  	d domain
    66  }
    68  // fact is a pair plus a relation for that pair.
    69  type fact struct {
    70  	p pair
    71  	r relation
    72  }
    74  // a limit records known upper and lower bounds for a value.
    75  type limit struct {
    76  	min, max   int64  // min <= value <= max, signed
    77  	umin, umax uint64 // umin <= value <= umax, unsigned
    78  }
    80  func (l limit) String() string {
    81  	return fmt.Sprintf("sm,SM,um,UM=%d,%d,%d,%d", l.min, l.max, l.umin, l.umax)
    82  }
    84  var noLimit = limit{math.MinInt64, math.MaxInt64, 0, math.MaxUint64}
    86  // a limitFact is a limit known for a particular value.
    87  type limitFact struct {
    88  	vid   ID
    89  	limit limit
    90  }
    92  // factsTable keeps track of relations between pairs of values.
    93  type factsTable struct {
    94  	facts map[pair]relation // current known set of relation
    95  	stack []fact            // previous sets of relations
    97  	// known lower and upper bounds on individual values.
    98  	limits     map[ID]limit
    99  	limitStack []limitFact // previous entries
   100  }
   102  // checkpointFact is an invalid value used for checkpointing
   103  // and restoring factsTable.
   104  var checkpointFact = fact{}
   105  var checkpointBound = limitFact{}
   107  func newFactsTable() *factsTable {
   108  	ft := &factsTable{}
   109  	ft.facts = make(map[pair]relation)
   110  	ft.stack = make([]fact, 4)
   111  	ft.limits = make(map[ID]limit)
   112  	ft.limitStack = make([]limitFact, 4)
   113  	return ft
   114  }
// get returns the set of relations still possible between v and w in
// domain d. A pair with no recorded fact yields lt|eq|gt (any order),
// except that a value compared with itself yields eq.
//
// When either operand is an integer constant, the answer is derived only
// from the bounds recorded in ft.limits for the other operand (and only
// for the signed and unsigned domains); ft.facts is not consulted.
func (ft *factsTable) get(v, w *Value, d domain) relation {
	if v.isGenericIntConst() || w.isGenericIntConst() {
		// Normalize so that w is the constant; the result is computed as
		// "v relative to w" and flipped back at the end if we swapped.
		reversed := false
		if v.isGenericIntConst() {
			v, w = w, v
			reversed = true
		}
		r := lt | eq | gt
		lim, ok := ft.limits[v.ID]
		if !ok {
			return r
		}
		c := w.AuxInt
		switch d {
		case signed:
			// Cases are tested in order; the first match wins.
			switch {
			case c < lim.min:
				r = gt
			case c > lim.max:
				r = lt
			case c == lim.min && c == lim.max:
				r = eq
			case c == lim.min:
				r = gt | eq
			case c == lim.max:
				r = lt | eq
			}
		case unsigned:
			// TODO: also use signed data if lim.min >= 0?
			// Reinterpret the constant as unsigned at the width given by
			// its op; other ops leave uc == 0.
			var uc uint64
			switch w.Op {
			case OpConst64:
				uc = uint64(c)
			case OpConst32:
				uc = uint64(uint32(c))
			case OpConst16:
				uc = uint64(uint16(c))
			case OpConst8:
				uc = uint64(uint8(c))
			}
			switch {
			case uc < lim.umin:
				r = gt
			case uc > lim.umax:
				r = lt
			case uc == lim.umin && uc == lim.umax:
				r = eq
			case uc == lim.umin:
				r = gt | eq
			case uc == lim.umax:
				r = lt | eq
			}
		}
		// Any other domain learns nothing from limits: r is still lt|eq|gt.
		if reversed {
			return reverseBits[r]
		}
		return r
	}

	// Facts are stored with the pair ordered by ID; normalize the lookup
	// and flip the result back if the operands were swapped.
	reversed := false
	if lessByID(w, v) {
		v, w = w, v
		reversed = !reversed
	}

	p := pair{v, w, d}
	r, ok := ft.facts[p]
	if !ok {
		if p.v == p.w {
			r = eq
		} else {
			r = lt | eq | gt
		}
	}

	if reversed {
		return reverseBits[r]
	}
	return r
}
// update restricts the relations between v and w in domain d to r. The
// pair's new relation set is the intersection of what get reported before
// and r, and that previous set is pushed on ft.stack so restore can undo it.
//
// If one operand is an integer constant, the bounds of the other operand
// in ft.limits are tightened too (signed and unsigned domains only). The
// previous bounds are pushed on ft.limitStack. parent is used only for
// the debug message.
func (ft *factsTable) update(parent *Block, v, w *Value, d domain, r relation) {
	// Store the pair ordered by ID; flip r to keep it describing "v vs w".
	if lessByID(w, v) {
		v, w = w, v
		r = reverseBits[r]
	}

	p := pair{v, w, d}
	oldR := ft.get(v, w, d)
	ft.stack = append(ft.stack, fact{p, oldR})
	ft.facts[p] = oldR & r

	// Extract bounds when comparing against constants.
	// Normalize so that w is the constant, again flipping r to match.
	if v.isGenericIntConst() {
		v, w = w, v
		r = reverseBits[r]
	}
	if v != nil && w.isGenericIntConst() {
		c := w.AuxInt
		// Note: all the +1/-1 below could overflow/underflow. Either will
		// still generate correct results, it will just lead to imprecision.
		// In fact if there is overflow/underflow, the corresponding
		// code is unreachable because the known range is outside the range
		// of the value's type.
		old, ok := ft.limits[v.ID]
		if !ok {
			old = noLimit
		}
		lim := old
		// Update lim with the new information we know. Values of r not
		// listed in a switch below leave lim unchanged.
		switch d {
		case signed:
			switch r {
			case lt: // v < c, i.e. v <= c-1
				if c-1 < lim.max {
					lim.max = c - 1
				}
			case lt | eq: // v <= c
				if c < lim.max {
					lim.max = c
				}
			case gt | eq: // v >= c
				if c > lim.min {
					lim.min = c
				}
			case gt: // v > c, i.e. v >= c+1
				if c+1 > lim.min {
					lim.min = c + 1
				}
			case lt | gt: // v != c: shrink the range if c is one of its ends
				if c == lim.min {
					lim.min++
				}
				if c == lim.max {
					lim.max--
				}
			case eq: // v == c
				lim.min = c
				lim.max = c
			}
		case unsigned:
			// Reinterpret the constant as unsigned at the width given by
			// its op; other ops leave uc == 0.
			var uc uint64
			switch w.Op {
			case OpConst64:
				uc = uint64(c)
			case OpConst32:
				uc = uint64(uint32(c))
			case OpConst16:
				uc = uint64(uint16(c))
			case OpConst8:
				uc = uint64(uint8(c))
			}
			switch r {
			case lt:
				if uc-1 < lim.umax {
					lim.umax = uc - 1
				}
			case lt | eq:
				if uc < lim.umax {
					lim.umax = uc
				}
			case gt | eq:
				if uc > lim.umin {
					lim.umin = uc
				}
			case gt:
				if uc+1 > lim.umin {
					lim.umin = uc + 1
				}
			case lt | gt:
				if uc == lim.umin {
					lim.umin++
				}
				if uc == lim.umax {
					lim.umax--
				}
			case eq:
				lim.umin = uc
				lim.umax = uc
			}
		}
		// Push the previous bounds so restore can reinstate them, even
		// when lim did not change.
		ft.limitStack = append(ft.limitStack, limitFact{v.ID, old})
		ft.limits[v.ID] = lim
		if v.Block.Func.pass.debug > 2 {
			v.Block.Func.Config.Warnl(parent.Line, "parent=%s, new limits %s %s %s", parent, v, w, lim.String())
		}
	}
}
   309  // isNonNegative returns true if v is known to be non-negative.
   310  func (ft *factsTable) isNonNegative(v *Value) bool {
   311  	if isNonNegative(v) {
   312  		return true
   313  	}
   314  	l, has := ft.limits[v.ID]
   315  	return has && (l.min >= 0 || l.umax <= math.MaxInt64)
   316  }
   318  // checkpoint saves the current state of known relations.
   319  // Called when descending on a branch.
   320  func (ft *factsTable) checkpoint() {
   321  	ft.stack = append(ft.stack, checkpointFact)
   322  	ft.limitStack = append(ft.limitStack, checkpointBound)
   323  }
   325  // restore restores known relation to the state just
   326  // before the previous checkpoint.
   327  // Called when backing up on a branch.
   328  func (ft *factsTable) restore() {
   329  	for {
   330  		old := ft.stack[len(ft.stack)-1]
   331  		ft.stack = ft.stack[:len(ft.stack)-1]
   332  		if old == checkpointFact {
   333  			break
   334  		}
   335  		if old.r == lt|eq|gt {
   336  			delete(ft.facts, old.p)
   337  		} else {
   338  			ft.facts[old.p] = old.r
   339  		}
   340  	}
   341  	for {
   342  		old := ft.limitStack[len(ft.limitStack)-1]
   343  		ft.limitStack = ft.limitStack[:len(ft.limitStack)-1]
   344  		if old.vid == 0 { // checkpointBound
   345  			break
   346  		}
   347  		if old.limit == noLimit {
   348  			delete(ft.limits, old.vid)
   349  		} else {
   350  			ft.limits[old.vid] = old.limit
   351  		}
   352  	}
   353  }
   355  func lessByID(v, w *Value) bool {
   356  	if v == nil && w == nil {
   357  		// Should not happen, but just in case.
   358  		return false
   359  	}
   360  	if v == nil {
   361  		return true
   362  	}
   363  	return w != nil && v.ID < w.ID
   364  }
var (
	// reverseBits[r] is r seen with the roles of v and w swapped: the lt
	// and gt bits are exchanged and the eq bit is kept.
	reverseBits = [...]relation{0, 4, 2, 6, 1, 5, 3, 7}

	// domainRelationTable maps each comparison op to the domain(s) in
	// which it compares its two arguments and the relations that hold
	// between them when the positive branch is taken. For example, given
	//	OpLess8:   {signed, lt},
	// and v1 = (OpLess8 v2 v3): if the positive branch on v1 is taken,
	// then we learn that the signed relation between v2 and v3 can be at
	// most lt.
	domainRelationTable = map[Op]struct {
		d domain
		r relation
	}{
		OpEq8:   {signed | unsigned, eq},
		OpEq16:  {signed | unsigned, eq},
		OpEq32:  {signed | unsigned, eq},
		OpEq64:  {signed | unsigned, eq},
		OpEqPtr: {pointer, eq},

		OpNeq8:   {signed | unsigned, lt | gt},
		OpNeq16:  {signed | unsigned, lt | gt},
		OpNeq32:  {signed | unsigned, lt | gt},
		OpNeq64:  {signed | unsigned, lt | gt},
		OpNeqPtr: {pointer, lt | gt},

		OpLess8:   {signed, lt},
		OpLess8U:  {unsigned, lt},
		OpLess16:  {signed, lt},
		OpLess16U: {unsigned, lt},
		OpLess32:  {signed, lt},
		OpLess32U: {unsigned, lt},
		OpLess64:  {signed, lt},
		OpLess64U: {unsigned, lt},

		OpLeq8:   {signed, lt | eq},
		OpLeq8U:  {unsigned, lt | eq},
		OpLeq16:  {signed, lt | eq},
		OpLeq16U: {unsigned, lt | eq},
		OpLeq32:  {signed, lt | eq},
		OpLeq32U: {unsigned, lt | eq},
		OpLeq64:  {signed, lt | eq},
		OpLeq64U: {unsigned, lt | eq},

		OpGeq8:   {signed, eq | gt},
		OpGeq8U:  {unsigned, eq | gt},
		OpGeq16:  {signed, eq | gt},
		OpGeq16U: {unsigned, eq | gt},
		OpGeq32:  {signed, eq | gt},
		OpGeq32U: {unsigned, eq | gt},
		OpGeq64:  {signed, eq | gt},
		OpGeq64U: {unsigned, eq | gt},

		OpGreater8:   {signed, gt},
		OpGreater8U:  {unsigned, gt},
		OpGreater16:  {signed, gt},
		OpGreater16U: {unsigned, gt},
		OpGreater32:  {signed, gt},
		OpGreater32U: {unsigned, gt},
		OpGreater64:  {signed, gt},
		OpGreater64U: {unsigned, gt},

		// TODO: OpIsInBounds actually test 0 <= a < b. This means
		// that the positive branch learns signed/LT and unsigned/LT
		// but the negative branch only learns unsigned/GE.
		OpIsInBounds:      {unsigned, lt},
		OpIsSliceInBounds: {unsigned, lt | eq},
	}
)
   435  // prove removes redundant BlockIf controls that can be inferred in a straight line.
   436  //
   437  // By far, the most common redundant pair are generated by bounds checking.
   438  // For example for the code:
   439  //
   440  //    a[i] = 4
   441  //    foo(a[i])
   442  //
   443  // The compiler will generate the following code:
   444  //
   445  //    if i >= len(a) {
   446  //        panic("not in bounds")
   447  //    }
   448  //    a[i] = 4
   449  //    if i >= len(a) {
   450  //        panic("not in bounds")
   451  //    }
   452  //    foo(a[i])
   453  //
   454  // The second comparison i >= len(a) is clearly redundant because if the
   455  // else branch of the first comparison is executed, we already know that i < len(a).
   456  // The code for the second panic can be removed.
   457  func prove(f *Func) {
   458  	// current node state
   459  	type walkState int
   460  	const (
   461  		descend walkState = iota
   462  		simplify
   463  	)
   464  	// work maintains the DFS stack.
   465  	type bp struct {
   466  		block *Block    // current handled block
   467  		state walkState // what's to do
   468  	}
   469  	work := make([]bp, 0, 256)
   470  	work = append(work, bp{
   471  		block: f.Entry,
   472  		state: descend,
   473  	})
   475  	ft := newFactsTable()
   476  	idom := f.Idom()
   477  	sdom := f.sdom()
   479  	// DFS on the dominator tree.
   480  	for len(work) > 0 {
   481  		node := work[len(work)-1]
   482  		work = work[:len(work)-1]
   483  		parent := idom[node.block.ID]
   484  		branch := getBranch(sdom, parent, node.block)
   486  		switch node.state {
   487  		case descend:
   488  			if branch != unknown {
   489  				ft.checkpoint()
   490  				c := parent.Control
   491  				updateRestrictions(parent, ft, boolean, nil, c, lt|gt, branch)
   492  				if tr, has := domainRelationTable[parent.Control.Op]; has {
   493  					// When we branched from parent we learned a new set of
   494  					// restrictions. Update the factsTable accordingly.
   495  					updateRestrictions(parent, ft, tr.d, c.Args[0], c.Args[1], tr.r, branch)
   496  				}
   497  			}
   499  			work = append(work, bp{
   500  				block: node.block,
   501  				state: simplify,
   502  			})
   503  			for s := sdom.Child(node.block); s != nil; s = sdom.Sibling(s) {
   504  				work = append(work, bp{
   505  					block: s,
   506  					state: descend,
   507  				})
   508  			}
   510  		case simplify:
   511  			succ := simplifyBlock(ft, node.block)
   512  			if succ != unknown {
   513  				b := node.block
   514  				b.Kind = BlockFirst
   515  				b.SetControl(nil)
   516  				if succ == negative {
   517  					b.swapSuccessors()
   518  				}
   519  			}
   521  			if branch != unknown {
   522  				ft.restore()
   523  			}
   524  		}
   525  	}
   526  }
   528  // getBranch returns the range restrictions added by p
   529  // when reaching b. p is the immediate dominator of b.
   530  func getBranch(sdom SparseTree, p *Block, b *Block) branch {
   531  	if p == nil || p.Kind != BlockIf {
   532  		return unknown
   533  	}
   534  	// If p and p.Succs[0] are dominators it means that every path
   535  	// from entry to b passes through p and p.Succs[0]. We care that
   536  	// no path from entry to b passes through p.Succs[1]. If p.Succs[0]
   537  	// has one predecessor then (apart from the degenerate case),
   538  	// there is no path from entry that can reach b through p.Succs[1].
   539  	// TODO: how about p->yes->b->yes, i.e. a loop in yes.
   540  	if sdom.isAncestorEq(p.Succs[0].b, b) && len(p.Succs[0].b.Preds) == 1 {
   541  		return positive
   542  	}
   543  	if sdom.isAncestorEq(p.Succs[1].b, b) && len(p.Succs[1].b.Preds) == 1 {
   544  		return negative
   545  	}
   546  	return unknown
   547  }
   549  // updateRestrictions updates restrictions from the immediate
   550  // dominating block (p) using r. r is adjusted according to the branch taken.
   551  func updateRestrictions(parent *Block, ft *factsTable, t domain, v, w *Value, r relation, branch branch) {
   552  	if t == 0 || branch == unknown {
   553  		// Trivial case: nothing to do, or branch unknown.
   554  		// Shoult not happen, but just in case.
   555  		return
   556  	}
   557  	if branch == negative {
   558  		// Negative branch taken, complement the relations.
   559  		r = (lt | eq | gt) ^ r
   560  	}
   561  	for i := domain(1); i <= t; i <<= 1 {
   562  		if t&i != 0 {
   563  			ft.update(parent, v, w, i, r)
   564  		}
   565  	}
   566  }
   568  // simplifyBlock simplifies block known the restrictions in ft.
   569  // Returns which branch must always be taken.
   570  func simplifyBlock(ft *factsTable, b *Block) branch {
   571  	for _, v := range b.Values {
   572  		if v.Op != OpSlicemask {
   573  			continue
   574  		}
   575  		add := v.Args[0]
   576  		if add.Op != OpAdd64 && add.Op != OpAdd32 {
   577  			continue
   578  		}
   579  		// Note that the arg of slicemask was originally a sub, but
   580  		// was rewritten to an add by generic.rules (if the thing
   581  		// being subtracted was a constant).
   582  		x := add.Args[0]
   583  		y := add.Args[1]
   584  		if x.Op == OpConst64 || x.Op == OpConst32 {
   585  			x, y = y, x
   586  		}
   587  		if y.Op != OpConst64 && y.Op != OpConst32 {
   588  			continue
   589  		}
   590  		// slicemask(x + y)
   591  		// if x is larger than -y (y is negative), then slicemask is -1.
   592  		lim, ok := ft.limits[x.ID]
   593  		if !ok {
   594  			continue
   595  		}
   596  		if lim.umin > uint64(-y.AuxInt) {
   597  			if v.Args[0].Op == OpAdd64 {
   598  				v.reset(OpConst64)
   599  			} else {
   600  				v.reset(OpConst32)
   601  			}
   602  			if b.Func.pass.debug > 0 {
   603  				b.Func.Config.Warnl(v.Line, "Proved slicemask not needed")
   604  			}
   605  			v.AuxInt = -1
   606  		}
   607  	}
   609  	if b.Kind != BlockIf {
   610  		return unknown
   611  	}
   613  	// First, checks if the condition itself is redundant.
   614  	m := ft.get(nil, b.Control, boolean)
   615  	if m == lt|gt {
   616  		if b.Func.pass.debug > 0 {
   617  			if b.Func.pass.debug > 1 {
   618  				b.Func.Config.Warnl(b.Line, "Proved boolean %s (%s)", b.Control.Op, b.Control)
   619  			} else {
   620  				b.Func.Config.Warnl(b.Line, "Proved boolean %s", b.Control.Op)
   621  			}
   622  		}
   623  		return positive
   624  	}
   625  	if m == eq {
   626  		if b.Func.pass.debug > 0 {
   627  			if b.Func.pass.debug > 1 {
   628  				b.Func.Config.Warnl(b.Line, "Disproved boolean %s (%s)", b.Control.Op, b.Control)
   629  			} else {
   630  				b.Func.Config.Warnl(b.Line, "Disproved boolean %s", b.Control.Op)
   631  			}
   632  		}
   633  		return negative
   634  	}
   636  	// Next look check equalities.
   637  	c := b.Control
   638  	tr, has := domainRelationTable[c.Op]
   639  	if !has {
   640  		return unknown
   641  	}
   643  	a0, a1 := c.Args[0], c.Args[1]
   644  	for d := domain(1); d <= tr.d; d <<= 1 {
   645  		if d&tr.d == 0 {
   646  			continue
   647  		}
   649  		// tr.r represents in which case the positive branch is taken.
   650  		// m represents which cases are possible because of previous relations.
   651  		// If the set of possible relations m is included in the set of relations
   652  		// need to take the positive branch (or negative) then that branch will
   653  		// always be taken.
   654  		// For shortcut, if m == 0 then this block is dead code.
   655  		m := ft.get(a0, a1, d)
   656  		if m != 0 && tr.r&m == m {
   657  			if b.Func.pass.debug > 0 {
   658  				if b.Func.pass.debug > 1 {
   659  					b.Func.Config.Warnl(b.Line, "Proved %s (%s)", c.Op, c)
   660  				} else {
   661  					b.Func.Config.Warnl(b.Line, "Proved %s", c.Op)
   662  				}
   663  			}
   664  			return positive
   665  		}
   666  		if m != 0 && ((lt|eq|gt)^tr.r)&m == m {
   667  			if b.Func.pass.debug > 0 {
   668  				if b.Func.pass.debug > 1 {
   669  					b.Func.Config.Warnl(b.Line, "Disproved %s (%s)", c.Op, c)
   670  				} else {
   671  					b.Func.Config.Warnl(b.Line, "Disproved %s", c.Op)
   672  				}
   673  			}
   674  			return negative
   675  		}
   676  	}
   678  	// HACK: If the first argument of IsInBounds or IsSliceInBounds
   679  	// is a constant and we already know that constant is smaller (or equal)
   680  	// to the upper bound than this is proven. Most useful in cases such as:
   681  	// if len(a) <= 1 { return }
   682  	// do something with a[1]
   683  	if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) && ft.isNonNegative(c.Args[0]) {
   684  		m := ft.get(a0, a1, signed)
   685  		if m != 0 && tr.r&m == m {
   686  			if b.Func.pass.debug > 0 {
   687  				if b.Func.pass.debug > 1 {
   688  					b.Func.Config.Warnl(b.Line, "Proved non-negative bounds %s (%s)", c.Op, c)
   689  				} else {
   690  					b.Func.Config.Warnl(b.Line, "Proved non-negative bounds %s", c.Op)
   691  				}
   692  			}
   693  			return positive
   694  		}
   695  	}
   697  	return unknown
   698  }
   700  // isNonNegative returns true is v is known to be greater or equal to zero.
   701  func isNonNegative(v *Value) bool {
   702  	switch v.Op {
   703  	case OpConst64:
   704  		return v.AuxInt >= 0
   706  	case OpConst32:
   707  		return int32(v.AuxInt) >= 0
   709  	case OpStringLen, OpSliceLen, OpSliceCap,
   710  		OpZeroExt8to64, OpZeroExt16to64, OpZeroExt32to64:
   711  		return true
   713  	case OpRsh64x64:
   714  		return isNonNegative(v.Args[0])
   715  	}
   716  	return false
   717  }