github.com/euank/go@v0.0.0-20160829210321-495514729181/src/cmd/compile/internal/ssa/prove.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
     7  import "math"
     8  
     9  type branch int
    10  
    11  const (
    12  	unknown = iota
    13  	positive
    14  	negative
    15  )
    16  
    17  // relation represents the set of possible relations between
    18  // pairs of variables (v, w). Without a priori knowledge the
    19  // mask is lt | eq | gt meaning v can be less than, equal to or
    20  // greater than w. When the execution path branches on the condition
    21  // `v op w` the set of relations is updated to exclude any
    22  // relation not possible due to `v op w` being true (or false).
    23  //
    24  // E.g.
    25  //
    26  // r := relation(...)
    27  //
    28  // if v < w {
    29  //   newR := r & lt
    30  // }
    31  // if v >= w {
    32  //   newR := r & (eq|gt)
    33  // }
    34  // if v != w {
    35  //   newR := r & (lt|gt)
    36  // }
    37  type relation uint
    38  
    39  const (
    40  	lt relation = 1 << iota
    41  	eq
    42  	gt
    43  )
    44  
    45  // domain represents the domain of a variable pair in which a set
    46  // of relations is known.  For example, relations learned for unsigned
    47  // pairs cannot be transferred to signed pairs because the same bit
    48  // representation can mean something else.
    49  type domain uint
    50  
    51  const (
    52  	signed domain = 1 << iota
    53  	unsigned
    54  	pointer
    55  	boolean
    56  )
    57  
    58  type pair struct {
    59  	v, w *Value // a pair of values, ordered by ID.
    60  	// v can be nil, to mean the zero value.
    61  	// for booleans the zero value (v == nil) is false.
    62  	d domain
    63  }
    64  
    65  // fact is a pair plus a relation for that pair.
    66  type fact struct {
    67  	p pair
    68  	r relation
    69  }
    70  
    71  // a limit records known upper and lower bounds for a value.
    72  type limit struct {
    73  	min, max   int64  // min <= value <= max, signed
    74  	umin, umax uint64 // umin <= value <= umax, unsigned
    75  }
    76  
    77  var noLimit = limit{math.MinInt64, math.MaxInt64, 0, math.MaxUint64}
    78  
    79  // a limitFact is a limit known for a particular value.
    80  type limitFact struct {
    81  	vid   ID
    82  	limit limit
    83  }
    84  
    85  // factsTable keeps track of relations between pairs of values.
    86  type factsTable struct {
    87  	facts map[pair]relation // current known set of relation
    88  	stack []fact            // previous sets of relations
    89  
    90  	// known lower and upper bounds on individual values.
    91  	limits     map[ID]limit
    92  	limitStack []limitFact // previous entries
    93  }
    94  
    95  // checkpointFact is an invalid value used for checkpointing
    96  // and restoring factsTable.
    97  var checkpointFact = fact{}
    98  var checkpointBound = limitFact{}
    99  
   100  func newFactsTable() *factsTable {
   101  	ft := &factsTable{}
   102  	ft.facts = make(map[pair]relation)
   103  	ft.stack = make([]fact, 4)
   104  	ft.limits = make(map[ID]limit)
   105  	ft.limitStack = make([]limitFact, 4)
   106  	return ft
   107  }
   108  
   109  // get returns the known possible relations between v and w.
   110  // If v and w are not in the map it returns lt|eq|gt, i.e. any order.
   111  func (ft *factsTable) get(v, w *Value, d domain) relation {
   112  	if v.isGenericIntConst() || w.isGenericIntConst() {
   113  		reversed := false
   114  		if v.isGenericIntConst() {
   115  			v, w = w, v
   116  			reversed = true
   117  		}
   118  		r := lt | eq | gt
   119  		lim, ok := ft.limits[v.ID]
   120  		if !ok {
   121  			return r
   122  		}
   123  		c := w.AuxInt
   124  		switch d {
   125  		case signed:
   126  			switch {
   127  			case c < lim.min:
   128  				r = gt
   129  			case c > lim.max:
   130  				r = lt
   131  			case c == lim.min && c == lim.max:
   132  				r = eq
   133  			case c == lim.min:
   134  				r = gt | eq
   135  			case c == lim.max:
   136  				r = lt | eq
   137  			}
   138  		case unsigned:
   139  			// TODO: also use signed data if lim.min >= 0?
   140  			var uc uint64
   141  			switch w.Op {
   142  			case OpConst64:
   143  				uc = uint64(c)
   144  			case OpConst32:
   145  				uc = uint64(uint32(c))
   146  			case OpConst16:
   147  				uc = uint64(uint16(c))
   148  			case OpConst8:
   149  				uc = uint64(uint8(c))
   150  			}
   151  			switch {
   152  			case uc < lim.umin:
   153  				r = gt
   154  			case uc > lim.umax:
   155  				r = lt
   156  			case uc == lim.umin && uc == lim.umax:
   157  				r = eq
   158  			case uc == lim.umin:
   159  				r = gt | eq
   160  			case uc == lim.umax:
   161  				r = lt | eq
   162  			}
   163  		}
   164  		if reversed {
   165  			return reverseBits[r]
   166  		}
   167  		return r
   168  	}
   169  
   170  	reversed := false
   171  	if lessByID(w, v) {
   172  		v, w = w, v
   173  		reversed = !reversed
   174  	}
   175  
   176  	p := pair{v, w, d}
   177  	r, ok := ft.facts[p]
   178  	if !ok {
   179  		if p.v == p.w {
   180  			r = eq
   181  		} else {
   182  			r = lt | eq | gt
   183  		}
   184  	}
   185  
   186  	if reversed {
   187  		return reverseBits[r]
   188  	}
   189  	return r
   190  }
   191  
   192  // update updates the set of relations between v and w in domain d
   193  // restricting it to r.
   194  func (ft *factsTable) update(v, w *Value, d domain, r relation) {
   195  	if lessByID(w, v) {
   196  		v, w = w, v
   197  		r = reverseBits[r]
   198  	}
   199  
   200  	p := pair{v, w, d}
   201  	oldR := ft.get(v, w, d)
   202  	ft.stack = append(ft.stack, fact{p, oldR})
   203  	ft.facts[p] = oldR & r
   204  
   205  	// Extract bounds when comparing against constants
   206  	if v.isGenericIntConst() {
   207  		v, w = w, v
   208  		r = reverseBits[r]
   209  	}
   210  	if v != nil && w.isGenericIntConst() {
   211  		c := w.AuxInt
   212  		// Note: all the +1/-1 below could overflow/underflow. Either will
   213  		// still generate correct results, it will just lead to imprecision.
   214  		// In fact if there is overflow/underflow, the corresponding
   215  		// code is unreachable because the known range is outside the range
   216  		// of the value's type.
   217  		old, ok := ft.limits[v.ID]
   218  		if !ok {
   219  			old = noLimit
   220  		}
   221  		lim := old
   222  		// Update lim with the new information we know.
   223  		switch d {
   224  		case signed:
   225  			switch r {
   226  			case lt:
   227  				if c-1 < lim.max {
   228  					lim.max = c - 1
   229  				}
   230  			case lt | eq:
   231  				if c < lim.max {
   232  					lim.max = c
   233  				}
   234  			case gt | eq:
   235  				if c > lim.min {
   236  					lim.min = c
   237  				}
   238  			case gt:
   239  				if c+1 > lim.min {
   240  					lim.min = c + 1
   241  				}
   242  			case lt | gt:
   243  				if c == lim.min {
   244  					lim.min++
   245  				}
   246  				if c == lim.max {
   247  					lim.max--
   248  				}
   249  			case eq:
   250  				lim.min = c
   251  				lim.max = c
   252  			}
   253  		case unsigned:
   254  			var uc uint64
   255  			switch w.Op {
   256  			case OpConst64:
   257  				uc = uint64(c)
   258  			case OpConst32:
   259  				uc = uint64(uint32(c))
   260  			case OpConst16:
   261  				uc = uint64(uint16(c))
   262  			case OpConst8:
   263  				uc = uint64(uint8(c))
   264  			}
   265  			switch r {
   266  			case lt:
   267  				if uc-1 < lim.umax {
   268  					lim.umax = uc - 1
   269  				}
   270  			case lt | eq:
   271  				if uc < lim.umax {
   272  					lim.umax = uc
   273  				}
   274  			case gt | eq:
   275  				if uc > lim.umin {
   276  					lim.umin = uc
   277  				}
   278  			case gt:
   279  				if uc+1 > lim.umin {
   280  					lim.umin = uc + 1
   281  				}
   282  			case lt | gt:
   283  				if uc == lim.umin {
   284  					lim.umin++
   285  				}
   286  				if uc == lim.umax {
   287  					lim.umax--
   288  				}
   289  			case eq:
   290  				lim.umin = uc
   291  				lim.umax = uc
   292  			}
   293  		}
   294  		ft.limitStack = append(ft.limitStack, limitFact{v.ID, old})
   295  		ft.limits[v.ID] = lim
   296  	}
   297  }
   298  
   299  // isNonNegative returns true if v is known to be non-negative.
   300  func (ft *factsTable) isNonNegative(v *Value) bool {
   301  	if isNonNegative(v) {
   302  		return true
   303  	}
   304  	l, has := ft.limits[v.ID]
   305  	return has && (l.min >= 0 || l.umax <= math.MaxInt64)
   306  }
   307  
   308  // checkpoint saves the current state of known relations.
   309  // Called when descending on a branch.
   310  func (ft *factsTable) checkpoint() {
   311  	ft.stack = append(ft.stack, checkpointFact)
   312  	ft.limitStack = append(ft.limitStack, checkpointBound)
   313  }
   314  
   315  // restore restores known relation to the state just
   316  // before the previous checkpoint.
   317  // Called when backing up on a branch.
   318  func (ft *factsTable) restore() {
   319  	for {
   320  		old := ft.stack[len(ft.stack)-1]
   321  		ft.stack = ft.stack[:len(ft.stack)-1]
   322  		if old == checkpointFact {
   323  			break
   324  		}
   325  		if old.r == lt|eq|gt {
   326  			delete(ft.facts, old.p)
   327  		} else {
   328  			ft.facts[old.p] = old.r
   329  		}
   330  	}
   331  	for {
   332  		old := ft.limitStack[len(ft.limitStack)-1]
   333  		ft.limitStack = ft.limitStack[:len(ft.limitStack)-1]
   334  		if old.vid == 0 { // checkpointBound
   335  			break
   336  		}
   337  		if old.limit == noLimit {
   338  			delete(ft.limits, old.vid)
   339  		} else {
   340  			ft.limits[old.vid] = old.limit
   341  		}
   342  	}
   343  }
   344  
   345  func lessByID(v, w *Value) bool {
   346  	if v == nil && w == nil {
   347  		// Should not happen, but just in case.
   348  		return false
   349  	}
   350  	if v == nil {
   351  		return true
   352  	}
   353  	return w != nil && v.ID < w.ID
   354  }
   355  
   356  var (
   357  	reverseBits = [...]relation{0, 4, 2, 6, 1, 5, 3, 7}
   358  
   359  	// maps what we learn when the positive branch is taken.
   360  	// For example:
   361  	//      OpLess8:   {signed, lt},
   362  	//	v1 = (OpLess8 v2 v3).
   363  	// If v1 branch is taken than we learn that the rangeMaks
   364  	// can be at most lt.
   365  	domainRelationTable = map[Op]struct {
   366  		d domain
   367  		r relation
   368  	}{
   369  		OpEq8:   {signed | unsigned, eq},
   370  		OpEq16:  {signed | unsigned, eq},
   371  		OpEq32:  {signed | unsigned, eq},
   372  		OpEq64:  {signed | unsigned, eq},
   373  		OpEqPtr: {pointer, eq},
   374  
   375  		OpNeq8:   {signed | unsigned, lt | gt},
   376  		OpNeq16:  {signed | unsigned, lt | gt},
   377  		OpNeq32:  {signed | unsigned, lt | gt},
   378  		OpNeq64:  {signed | unsigned, lt | gt},
   379  		OpNeqPtr: {pointer, lt | gt},
   380  
   381  		OpLess8:   {signed, lt},
   382  		OpLess8U:  {unsigned, lt},
   383  		OpLess16:  {signed, lt},
   384  		OpLess16U: {unsigned, lt},
   385  		OpLess32:  {signed, lt},
   386  		OpLess32U: {unsigned, lt},
   387  		OpLess64:  {signed, lt},
   388  		OpLess64U: {unsigned, lt},
   389  
   390  		OpLeq8:   {signed, lt | eq},
   391  		OpLeq8U:  {unsigned, lt | eq},
   392  		OpLeq16:  {signed, lt | eq},
   393  		OpLeq16U: {unsigned, lt | eq},
   394  		OpLeq32:  {signed, lt | eq},
   395  		OpLeq32U: {unsigned, lt | eq},
   396  		OpLeq64:  {signed, lt | eq},
   397  		OpLeq64U: {unsigned, lt | eq},
   398  
   399  		OpGeq8:   {signed, eq | gt},
   400  		OpGeq8U:  {unsigned, eq | gt},
   401  		OpGeq16:  {signed, eq | gt},
   402  		OpGeq16U: {unsigned, eq | gt},
   403  		OpGeq32:  {signed, eq | gt},
   404  		OpGeq32U: {unsigned, eq | gt},
   405  		OpGeq64:  {signed, eq | gt},
   406  		OpGeq64U: {unsigned, eq | gt},
   407  
   408  		OpGreater8:   {signed, gt},
   409  		OpGreater8U:  {unsigned, gt},
   410  		OpGreater16:  {signed, gt},
   411  		OpGreater16U: {unsigned, gt},
   412  		OpGreater32:  {signed, gt},
   413  		OpGreater32U: {unsigned, gt},
   414  		OpGreater64:  {signed, gt},
   415  		OpGreater64U: {unsigned, gt},
   416  
   417  		// TODO: OpIsInBounds actually test 0 <= a < b. This means
   418  		// that the positive branch learns signed/LT and unsigned/LT
   419  		// but the negative branch only learns unsigned/GE.
   420  		OpIsInBounds:      {unsigned, lt},
   421  		OpIsSliceInBounds: {unsigned, lt | eq},
   422  	}
   423  )
   424  
   425  // prove removes redundant BlockIf controls that can be inferred in a straight line.
   426  //
   427  // By far, the most common redundant pair are generated by bounds checking.
   428  // For example for the code:
   429  //
   430  //    a[i] = 4
   431  //    foo(a[i])
   432  //
   433  // The compiler will generate the following code:
   434  //
   435  //    if i >= len(a) {
   436  //        panic("not in bounds")
   437  //    }
   438  //    a[i] = 4
   439  //    if i >= len(a) {
   440  //        panic("not in bounds")
   441  //    }
   442  //    foo(a[i])
   443  //
   444  // The second comparison i >= len(a) is clearly redundant because if the
   445  // else branch of the first comparison is executed, we already know that i < len(a).
   446  // The code for the second panic can be removed.
   447  func prove(f *Func) {
   448  	// current node state
   449  	type walkState int
   450  	const (
   451  		descend walkState = iota
   452  		simplify
   453  	)
   454  	// work maintains the DFS stack.
   455  	type bp struct {
   456  		block *Block    // current handled block
   457  		state walkState // what's to do
   458  	}
   459  	work := make([]bp, 0, 256)
   460  	work = append(work, bp{
   461  		block: f.Entry,
   462  		state: descend,
   463  	})
   464  
   465  	ft := newFactsTable()
   466  
   467  	// DFS on the dominator tree.
   468  	for len(work) > 0 {
   469  		node := work[len(work)-1]
   470  		work = work[:len(work)-1]
   471  		parent := f.idom[node.block.ID]
   472  		branch := getBranch(f.sdom, parent, node.block)
   473  
   474  		switch node.state {
   475  		case descend:
   476  			if branch != unknown {
   477  				ft.checkpoint()
   478  				c := parent.Control
   479  				updateRestrictions(ft, boolean, nil, c, lt|gt, branch)
   480  				if tr, has := domainRelationTable[parent.Control.Op]; has {
   481  					// When we branched from parent we learned a new set of
   482  					// restrictions. Update the factsTable accordingly.
   483  					updateRestrictions(ft, tr.d, c.Args[0], c.Args[1], tr.r, branch)
   484  				}
   485  			}
   486  
   487  			work = append(work, bp{
   488  				block: node.block,
   489  				state: simplify,
   490  			})
   491  			for s := f.sdom.Child(node.block); s != nil; s = f.sdom.Sibling(s) {
   492  				work = append(work, bp{
   493  					block: s,
   494  					state: descend,
   495  				})
   496  			}
   497  
   498  		case simplify:
   499  			succ := simplifyBlock(ft, node.block)
   500  			if succ != unknown {
   501  				b := node.block
   502  				b.Kind = BlockFirst
   503  				b.SetControl(nil)
   504  				if succ == negative {
   505  					b.swapSuccessors()
   506  				}
   507  			}
   508  
   509  			if branch != unknown {
   510  				ft.restore()
   511  			}
   512  		}
   513  	}
   514  }
   515  
   516  // getBranch returns the range restrictions added by p
   517  // when reaching b. p is the immediate dominator of b.
   518  func getBranch(sdom SparseTree, p *Block, b *Block) branch {
   519  	if p == nil || p.Kind != BlockIf {
   520  		return unknown
   521  	}
   522  	// If p and p.Succs[0] are dominators it means that every path
   523  	// from entry to b passes through p and p.Succs[0]. We care that
   524  	// no path from entry to b passes through p.Succs[1]. If p.Succs[0]
   525  	// has one predecessor then (apart from the degenerate case),
   526  	// there is no path from entry that can reach b through p.Succs[1].
   527  	// TODO: how about p->yes->b->yes, i.e. a loop in yes.
   528  	if sdom.isAncestorEq(p.Succs[0].b, b) && len(p.Succs[0].b.Preds) == 1 {
   529  		return positive
   530  	}
   531  	if sdom.isAncestorEq(p.Succs[1].b, b) && len(p.Succs[1].b.Preds) == 1 {
   532  		return negative
   533  	}
   534  	return unknown
   535  }
   536  
   537  // updateRestrictions updates restrictions from the immediate
   538  // dominating block (p) using r. r is adjusted according to the branch taken.
   539  func updateRestrictions(ft *factsTable, t domain, v, w *Value, r relation, branch branch) {
   540  	if t == 0 || branch == unknown {
   541  		// Trivial case: nothing to do, or branch unknown.
   542  		// Shoult not happen, but just in case.
   543  		return
   544  	}
   545  	if branch == negative {
   546  		// Negative branch taken, complement the relations.
   547  		r = (lt | eq | gt) ^ r
   548  	}
   549  	for i := domain(1); i <= t; i <<= 1 {
   550  		if t&i != 0 {
   551  			ft.update(v, w, i, r)
   552  		}
   553  	}
   554  }
   555  
   556  // simplifyBlock simplifies block known the restrictions in ft.
   557  // Returns which branch must always be taken.
   558  func simplifyBlock(ft *factsTable, b *Block) branch {
   559  	if b.Kind != BlockIf {
   560  		return unknown
   561  	}
   562  
   563  	// First, checks if the condition itself is redundant.
   564  	m := ft.get(nil, b.Control, boolean)
   565  	if m == lt|gt {
   566  		if b.Func.pass.debug > 0 {
   567  			b.Func.Config.Warnl(b.Line, "Proved boolean %s", b.Control.Op)
   568  		}
   569  		return positive
   570  	}
   571  	if m == eq {
   572  		if b.Func.pass.debug > 0 {
   573  			b.Func.Config.Warnl(b.Line, "Disproved boolean %s", b.Control.Op)
   574  		}
   575  		return negative
   576  	}
   577  
   578  	// Next look check equalities.
   579  	c := b.Control
   580  	tr, has := domainRelationTable[c.Op]
   581  	if !has {
   582  		return unknown
   583  	}
   584  
   585  	a0, a1 := c.Args[0], c.Args[1]
   586  	for d := domain(1); d <= tr.d; d <<= 1 {
   587  		if d&tr.d == 0 {
   588  			continue
   589  		}
   590  
   591  		// tr.r represents in which case the positive branch is taken.
   592  		// m represents which cases are possible because of previous relations.
   593  		// If the set of possible relations m is included in the set of relations
   594  		// need to take the positive branch (or negative) then that branch will
   595  		// always be taken.
   596  		// For shortcut, if m == 0 then this block is dead code.
   597  		m := ft.get(a0, a1, d)
   598  		if m != 0 && tr.r&m == m {
   599  			if b.Func.pass.debug > 0 {
   600  				b.Func.Config.Warnl(b.Line, "Proved %s", c.Op)
   601  			}
   602  			return positive
   603  		}
   604  		if m != 0 && ((lt|eq|gt)^tr.r)&m == m {
   605  			if b.Func.pass.debug > 0 {
   606  				b.Func.Config.Warnl(b.Line, "Disproved %s", c.Op)
   607  			}
   608  			return negative
   609  		}
   610  	}
   611  
   612  	// HACK: If the first argument of IsInBounds or IsSliceInBounds
   613  	// is a constant and we already know that constant is smaller (or equal)
   614  	// to the upper bound than this is proven. Most useful in cases such as:
   615  	// if len(a) <= 1 { return }
   616  	// do something with a[1]
   617  	if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) && ft.isNonNegative(c.Args[0]) {
   618  		m := ft.get(a0, a1, signed)
   619  		if m != 0 && tr.r&m == m {
   620  			if b.Func.pass.debug > 0 {
   621  				b.Func.Config.Warnl(b.Line, "Proved non-negative bounds %s", c.Op)
   622  			}
   623  			return positive
   624  		}
   625  	}
   626  
   627  	return unknown
   628  }
   629  
   630  // isNonNegative returns true is v is known to be greater or equal to zero.
   631  func isNonNegative(v *Value) bool {
   632  	switch v.Op {
   633  	case OpConst64:
   634  		return v.AuxInt >= 0
   635  
   636  	case OpStringLen, OpSliceLen, OpSliceCap,
   637  		OpZeroExt8to64, OpZeroExt16to64, OpZeroExt32to64:
   638  		return true
   639  
   640  	case OpRsh64x64:
   641  		return isNonNegative(v.Args[0])
   642  	}
   643  	return false
   644  }