github.com/miolini/go@v0.0.0-20160405192216-fca68c8cb408/src/cmd/compile/internal/ssa/prove.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
     7  import "math"
     8  
     // branch says which outcome of a BlockIf's condition is known to hold
// when reaching a block: unknown (nothing known, or not a BlockIf),
// positive (Succs[0] was taken) or negative (Succs[1] was taken).
type branch int

// The constants are typed so they carry the branch type rather than
// being plain untyped integers.
const (
	unknown branch = iota
	positive
	negative
)
    16  
    // relation is a bit set over the three possible orderings of a pair of
// variables (v, w): lt (v < w), eq (v == w) and gt (v > w).
//
// With no prior knowledge every ordering is possible, i.e. lt|eq|gt.
// When control flow branches on a condition `v op w`, the set is
// intersected with the orderings consistent with the outcome:
//
//	r := relation(...)
//
//	if v < w {
//		newR := r & lt
//	}
//	if v >= w {
//		newR := r & (eq|gt)
//	}
//	if v != w {
//		newR := r & (lt|gt)
//	}
type relation uint

const (
	lt relation = 1 << 0 // v < w
	eq relation = 1 << 1 // v == w
	gt relation = 1 << 2 // v > w
)
    44  
    // domain identifies how the two values of a pair are interpreted when
// they are compared. Relations learned in one domain cannot be reused in
// another: for example, the same bit pattern orders differently when read
// as a signed or as an unsigned integer.
type domain uint

const (
	signed   domain = 1 << 0
	unsigned domain = 1 << 1
	pointer  domain = 1 << 2
	boolean  domain = 1 << 3
)
    57  
    58  type pair struct {
    59  	v, w *Value // a pair of values, ordered by ID.
    60  	// v can be nil, to mean the zero value.
    61  	// for booleans the zero value (v == nil) is false.
    62  	d domain
    63  }
    64  
    65  // fact is a pair plus a relation for that pair.
    66  type fact struct {
    67  	p pair
    68  	r relation
    69  }
    70  
    // limit holds the known bounds on a single value, both when the value is
// read as a signed integer and when it is read as an unsigned one.
type limit struct {
	min  int64  // signed lower bound: min <= value
	max  int64  // signed upper bound: value <= max
	umin uint64 // unsigned lower bound: umin <= value
	umax uint64 // unsigned upper bound: value <= umax
}

// noLimit is the widest possible limit; it carries no information.
var noLimit = limit{
	min:  math.MinInt64,
	max:  math.MaxInt64,
	umin: 0,
	umax: math.MaxUint64,
}
    78  
    79  // a limitFact is a limit known for a particular value.
    80  type limitFact struct {
    81  	vid   ID
    82  	limit limit
    83  }
    84  
    85  // factsTable keeps track of relations between pairs of values.
    86  type factsTable struct {
    87  	facts map[pair]relation // current known set of relation
    88  	stack []fact            // previous sets of relations
    89  
    90  	// known lower and upper bounds on individual values.
    91  	limits     map[ID]limit
    92  	limitStack []limitFact // previous entries
    93  }
    94  
    95  // checkpointFact is an invalid value used for checkpointing
    96  // and restoring factsTable.
    97  var checkpointFact = fact{}
    98  var checkpointBound = limitFact{}
    99  
   100  func newFactsTable() *factsTable {
   101  	ft := &factsTable{}
   102  	ft.facts = make(map[pair]relation)
   103  	ft.stack = make([]fact, 4)
   104  	ft.limits = make(map[ID]limit)
   105  	ft.limitStack = make([]limitFact, 4)
   106  	return ft
   107  }
   108  
   109  // get returns the known possible relations between v and w.
   110  // If v and w are not in the map it returns lt|eq|gt, i.e. any order.
   111  func (ft *factsTable) get(v, w *Value, d domain) relation {
   112  	if v.isGenericIntConst() || w.isGenericIntConst() {
   113  		reversed := false
   114  		if v.isGenericIntConst() {
   115  			v, w = w, v
   116  			reversed = true
   117  		}
   118  		r := lt | eq | gt
   119  		lim, ok := ft.limits[v.ID]
   120  		if !ok {
   121  			return r
   122  		}
   123  		c := w.AuxInt
   124  		switch d {
   125  		case signed:
   126  			switch {
   127  			case c < lim.min:
   128  				r = gt
   129  			case c > lim.max:
   130  				r = lt
   131  			case c == lim.min && c == lim.max:
   132  				r = eq
   133  			case c == lim.min:
   134  				r = gt | eq
   135  			case c == lim.max:
   136  				r = lt | eq
   137  			}
   138  		case unsigned:
   139  			// TODO: also use signed data if lim.min >= 0?
   140  			var uc uint64
   141  			switch w.Op {
   142  			case OpConst64:
   143  				uc = uint64(c)
   144  			case OpConst32:
   145  				uc = uint64(uint32(c))
   146  			case OpConst16:
   147  				uc = uint64(uint16(c))
   148  			case OpConst8:
   149  				uc = uint64(uint8(c))
   150  			}
   151  			switch {
   152  			case uc < lim.umin:
   153  				r = gt
   154  			case uc > lim.umax:
   155  				r = lt
   156  			case uc == lim.umin && uc == lim.umax:
   157  				r = eq
   158  			case uc == lim.umin:
   159  				r = gt | eq
   160  			case uc == lim.umax:
   161  				r = lt | eq
   162  			}
   163  		}
   164  		if reversed {
   165  			return reverseBits[r]
   166  		}
   167  		return r
   168  	}
   169  
   170  	reversed := false
   171  	if lessByID(w, v) {
   172  		v, w = w, v
   173  		reversed = !reversed
   174  	}
   175  
   176  	p := pair{v, w, d}
   177  	r, ok := ft.facts[p]
   178  	if !ok {
   179  		if p.v == p.w {
   180  			r = eq
   181  		} else {
   182  			r = lt | eq | gt
   183  		}
   184  	}
   185  
   186  	if reversed {
   187  		return reverseBits[r]
   188  	}
   189  	return r
   190  }
   191  
   192  // update updates the set of relations between v and w in domain d
   193  // restricting it to r.
   194  func (ft *factsTable) update(v, w *Value, d domain, r relation) {
   195  	if lessByID(w, v) {
   196  		v, w = w, v
   197  		r = reverseBits[r]
   198  	}
   199  
   200  	p := pair{v, w, d}
   201  	oldR := ft.get(v, w, d)
   202  	ft.stack = append(ft.stack, fact{p, oldR})
   203  	ft.facts[p] = oldR & r
   204  
   205  	// Extract bounds when comparing against constants
   206  	if v.isGenericIntConst() {
   207  		v, w = w, v
   208  		r = reverseBits[r]
   209  	}
   210  	if v != nil && w.isGenericIntConst() {
   211  		c := w.AuxInt
   212  		// Note: all the +1/-1 below could overflow/underflow. Either will
   213  		// still generate correct results, it will just lead to imprecision.
   214  		// In fact if there is overflow/underflow, the corresponding
   215  		// code is unreachable because the known range is outside the range
   216  		// of the value's type.
   217  		old, ok := ft.limits[v.ID]
   218  		if !ok {
   219  			old = noLimit
   220  		}
   221  		lim := old
   222  		// Update lim with the new information we know.
   223  		switch d {
   224  		case signed:
   225  			switch r {
   226  			case lt:
   227  				if c-1 < lim.max {
   228  					lim.max = c - 1
   229  				}
   230  			case lt | eq:
   231  				if c < lim.max {
   232  					lim.max = c
   233  				}
   234  			case gt | eq:
   235  				if c > lim.min {
   236  					lim.min = c
   237  				}
   238  			case gt:
   239  				if c+1 > lim.min {
   240  					lim.min = c + 1
   241  				}
   242  			case lt | gt:
   243  				if c == lim.min {
   244  					lim.min++
   245  				}
   246  				if c == lim.max {
   247  					lim.max--
   248  				}
   249  			case eq:
   250  				lim.min = c
   251  				lim.max = c
   252  			}
   253  		case unsigned:
   254  			var uc uint64
   255  			switch w.Op {
   256  			case OpConst64:
   257  				uc = uint64(c)
   258  			case OpConst32:
   259  				uc = uint64(uint32(c))
   260  			case OpConst16:
   261  				uc = uint64(uint16(c))
   262  			case OpConst8:
   263  				uc = uint64(uint8(c))
   264  			}
   265  			switch r {
   266  			case lt:
   267  				if uc-1 < lim.umax {
   268  					lim.umax = uc - 1
   269  				}
   270  			case lt | eq:
   271  				if uc < lim.umax {
   272  					lim.umax = uc
   273  				}
   274  			case gt | eq:
   275  				if uc > lim.umin {
   276  					lim.umin = uc
   277  				}
   278  			case gt:
   279  				if uc+1 > lim.umin {
   280  					lim.umin = uc + 1
   281  				}
   282  			case lt | gt:
   283  				if uc == lim.umin {
   284  					lim.umin++
   285  				}
   286  				if uc == lim.umax {
   287  					lim.umax--
   288  				}
   289  			case eq:
   290  				lim.umin = uc
   291  				lim.umax = uc
   292  			}
   293  		}
   294  		ft.limitStack = append(ft.limitStack, limitFact{v.ID, old})
   295  		ft.limits[v.ID] = lim
   296  	}
   297  }
   298  
   299  // isNonNegative returns true if v is known to be non-negative.
   300  func (ft *factsTable) isNonNegative(v *Value) bool {
   301  	if isNonNegative(v) {
   302  		return true
   303  	}
   304  	l, has := ft.limits[v.ID]
   305  	return has && (l.min >= 0 || l.umax <= math.MaxInt64)
   306  }
   307  
   308  // checkpoint saves the current state of known relations.
   309  // Called when descending on a branch.
   310  func (ft *factsTable) checkpoint() {
   311  	ft.stack = append(ft.stack, checkpointFact)
   312  	ft.limitStack = append(ft.limitStack, checkpointBound)
   313  }
   314  
   315  // restore restores known relation to the state just
   316  // before the previous checkpoint.
   317  // Called when backing up on a branch.
   318  func (ft *factsTable) restore() {
   319  	for {
   320  		old := ft.stack[len(ft.stack)-1]
   321  		ft.stack = ft.stack[:len(ft.stack)-1]
   322  		if old == checkpointFact {
   323  			break
   324  		}
   325  		if old.r == lt|eq|gt {
   326  			delete(ft.facts, old.p)
   327  		} else {
   328  			ft.facts[old.p] = old.r
   329  		}
   330  	}
   331  	for {
   332  		old := ft.limitStack[len(ft.limitStack)-1]
   333  		ft.limitStack = ft.limitStack[:len(ft.limitStack)-1]
   334  		if old.vid == 0 { // checkpointBound
   335  			break
   336  		}
   337  		if old.limit == noLimit {
   338  			delete(ft.limits, old.vid)
   339  		} else {
   340  			ft.limits[old.vid] = old.limit
   341  		}
   342  	}
   343  }
   344  
   345  func lessByID(v, w *Value) bool {
   346  	if v == nil && w == nil {
   347  		// Should not happen, but just in case.
   348  		return false
   349  	}
   350  	if v == nil {
   351  		return true
   352  	}
   353  	return w != nil && v.ID < w.ID
   354  }
   355  
   356  var (
   357  	reverseBits = [...]relation{0, 4, 2, 6, 1, 5, 3, 7}
   358  
   359  	// maps what we learn when the positive branch is taken.
   360  	// For example:
   361  	//      OpLess8:   {signed, lt},
   362  	//	v1 = (OpLess8 v2 v3).
   363  	// If v1 branch is taken than we learn that the rangeMaks
   364  	// can be at most lt.
   365  	domainRelationTable = map[Op]struct {
   366  		d domain
   367  		r relation
   368  	}{
   369  		OpEq8:   {signed | unsigned, eq},
   370  		OpEq16:  {signed | unsigned, eq},
   371  		OpEq32:  {signed | unsigned, eq},
   372  		OpEq64:  {signed | unsigned, eq},
   373  		OpEqPtr: {pointer, eq},
   374  
   375  		OpNeq8:   {signed | unsigned, lt | gt},
   376  		OpNeq16:  {signed | unsigned, lt | gt},
   377  		OpNeq32:  {signed | unsigned, lt | gt},
   378  		OpNeq64:  {signed | unsigned, lt | gt},
   379  		OpNeqPtr: {pointer, lt | gt},
   380  
   381  		OpLess8:   {signed, lt},
   382  		OpLess8U:  {unsigned, lt},
   383  		OpLess16:  {signed, lt},
   384  		OpLess16U: {unsigned, lt},
   385  		OpLess32:  {signed, lt},
   386  		OpLess32U: {unsigned, lt},
   387  		OpLess64:  {signed, lt},
   388  		OpLess64U: {unsigned, lt},
   389  
   390  		OpLeq8:   {signed, lt | eq},
   391  		OpLeq8U:  {unsigned, lt | eq},
   392  		OpLeq16:  {signed, lt | eq},
   393  		OpLeq16U: {unsigned, lt | eq},
   394  		OpLeq32:  {signed, lt | eq},
   395  		OpLeq32U: {unsigned, lt | eq},
   396  		OpLeq64:  {signed, lt | eq},
   397  		OpLeq64U: {unsigned, lt | eq},
   398  
   399  		OpGeq8:   {signed, eq | gt},
   400  		OpGeq8U:  {unsigned, eq | gt},
   401  		OpGeq16:  {signed, eq | gt},
   402  		OpGeq16U: {unsigned, eq | gt},
   403  		OpGeq32:  {signed, eq | gt},
   404  		OpGeq32U: {unsigned, eq | gt},
   405  		OpGeq64:  {signed, eq | gt},
   406  		OpGeq64U: {unsigned, eq | gt},
   407  
   408  		OpGreater8:   {signed, gt},
   409  		OpGreater8U:  {unsigned, gt},
   410  		OpGreater16:  {signed, gt},
   411  		OpGreater16U: {unsigned, gt},
   412  		OpGreater32:  {signed, gt},
   413  		OpGreater32U: {unsigned, gt},
   414  		OpGreater64:  {signed, gt},
   415  		OpGreater64U: {unsigned, gt},
   416  
   417  		// TODO: OpIsInBounds actually test 0 <= a < b. This means
   418  		// that the positive branch learns signed/LT and unsigned/LT
   419  		// but the negative branch only learns unsigned/GE.
   420  		OpIsInBounds:      {unsigned, lt},
   421  		OpIsSliceInBounds: {unsigned, lt | eq},
   422  	}
   423  )
   424  
   425  // prove removes redundant BlockIf controls that can be inferred in a straight line.
   426  //
   427  // By far, the most common redundant pair are generated by bounds checking.
   428  // For example for the code:
   429  //
   430  //    a[i] = 4
   431  //    foo(a[i])
   432  //
   433  // The compiler will generate the following code:
   434  //
   435  //    if i >= len(a) {
   436  //        panic("not in bounds")
   437  //    }
   438  //    a[i] = 4
   439  //    if i >= len(a) {
   440  //        panic("not in bounds")
   441  //    }
   442  //    foo(a[i])
   443  //
   444  // The second comparison i >= len(a) is clearly redundant because if the
   445  // else branch of the first comparison is executed, we already know that i < len(a).
   446  // The code for the second panic can be removed.
   447  func prove(f *Func) {
   448  	idom := dominators(f)
   449  	sdom := newSparseTree(f, idom)
   450  
   451  	// current node state
   452  	type walkState int
   453  	const (
   454  		descend walkState = iota
   455  		simplify
   456  	)
   457  	// work maintains the DFS stack.
   458  	type bp struct {
   459  		block *Block    // current handled block
   460  		state walkState // what's to do
   461  	}
   462  	work := make([]bp, 0, 256)
   463  	work = append(work, bp{
   464  		block: f.Entry,
   465  		state: descend,
   466  	})
   467  
   468  	ft := newFactsTable()
   469  
   470  	// DFS on the dominator tree.
   471  	for len(work) > 0 {
   472  		node := work[len(work)-1]
   473  		work = work[:len(work)-1]
   474  		parent := idom[node.block.ID]
   475  		branch := getBranch(sdom, parent, node.block)
   476  
   477  		switch node.state {
   478  		case descend:
   479  			if branch != unknown {
   480  				ft.checkpoint()
   481  				c := parent.Control
   482  				updateRestrictions(ft, boolean, nil, c, lt|gt, branch)
   483  				if tr, has := domainRelationTable[parent.Control.Op]; has {
   484  					// When we branched from parent we learned a new set of
   485  					// restrictions. Update the factsTable accordingly.
   486  					updateRestrictions(ft, tr.d, c.Args[0], c.Args[1], tr.r, branch)
   487  				}
   488  			}
   489  
   490  			work = append(work, bp{
   491  				block: node.block,
   492  				state: simplify,
   493  			})
   494  			for s := sdom.Child(node.block); s != nil; s = sdom.Sibling(s) {
   495  				work = append(work, bp{
   496  					block: s,
   497  					state: descend,
   498  				})
   499  			}
   500  
   501  		case simplify:
   502  			succ := simplifyBlock(ft, node.block)
   503  			if succ != unknown {
   504  				b := node.block
   505  				b.Kind = BlockFirst
   506  				b.SetControl(nil)
   507  				if succ == negative {
   508  					b.Succs[0], b.Succs[1] = b.Succs[1], b.Succs[0]
   509  				}
   510  			}
   511  
   512  			if branch != unknown {
   513  				ft.restore()
   514  			}
   515  		}
   516  	}
   517  }
   518  
   519  // getBranch returns the range restrictions added by p
   520  // when reaching b. p is the immediate dominator of b.
   521  func getBranch(sdom sparseTree, p *Block, b *Block) branch {
   522  	if p == nil || p.Kind != BlockIf {
   523  		return unknown
   524  	}
   525  	// If p and p.Succs[0] are dominators it means that every path
   526  	// from entry to b passes through p and p.Succs[0]. We care that
   527  	// no path from entry to b passes through p.Succs[1]. If p.Succs[0]
   528  	// has one predecessor then (apart from the degenerate case),
   529  	// there is no path from entry that can reach b through p.Succs[1].
   530  	// TODO: how about p->yes->b->yes, i.e. a loop in yes.
   531  	if sdom.isAncestorEq(p.Succs[0], b) && len(p.Succs[0].Preds) == 1 {
   532  		return positive
   533  	}
   534  	if sdom.isAncestorEq(p.Succs[1], b) && len(p.Succs[1].Preds) == 1 {
   535  		return negative
   536  	}
   537  	return unknown
   538  }
   539  
   540  // updateRestrictions updates restrictions from the immediate
   541  // dominating block (p) using r. r is adjusted according to the branch taken.
   542  func updateRestrictions(ft *factsTable, t domain, v, w *Value, r relation, branch branch) {
   543  	if t == 0 || branch == unknown {
   544  		// Trivial case: nothing to do, or branch unknown.
   545  		// Shoult not happen, but just in case.
   546  		return
   547  	}
   548  	if branch == negative {
   549  		// Negative branch taken, complement the relations.
   550  		r = (lt | eq | gt) ^ r
   551  	}
   552  	for i := domain(1); i <= t; i <<= 1 {
   553  		if t&i != 0 {
   554  			ft.update(v, w, i, r)
   555  		}
   556  	}
   557  }
   558  
   559  // simplifyBlock simplifies block known the restrictions in ft.
   560  // Returns which branch must always be taken.
   561  func simplifyBlock(ft *factsTable, b *Block) branch {
   562  	if b.Kind != BlockIf {
   563  		return unknown
   564  	}
   565  
   566  	// First, checks if the condition itself is redundant.
   567  	m := ft.get(nil, b.Control, boolean)
   568  	if m == lt|gt {
   569  		if b.Func.pass.debug > 0 {
   570  			b.Func.Config.Warnl(b.Line, "Proved boolean %s", b.Control.Op)
   571  		}
   572  		return positive
   573  	}
   574  	if m == eq {
   575  		if b.Func.pass.debug > 0 {
   576  			b.Func.Config.Warnl(b.Line, "Disproved boolean %s", b.Control.Op)
   577  		}
   578  		return negative
   579  	}
   580  
   581  	// Next look check equalities.
   582  	c := b.Control
   583  	tr, has := domainRelationTable[c.Op]
   584  	if !has {
   585  		return unknown
   586  	}
   587  
   588  	a0, a1 := c.Args[0], c.Args[1]
   589  	for d := domain(1); d <= tr.d; d <<= 1 {
   590  		if d&tr.d == 0 {
   591  			continue
   592  		}
   593  
   594  		// tr.r represents in which case the positive branch is taken.
   595  		// m represents which cases are possible because of previous relations.
   596  		// If the set of possible relations m is included in the set of relations
   597  		// need to take the positive branch (or negative) then that branch will
   598  		// always be taken.
   599  		// For shortcut, if m == 0 then this block is dead code.
   600  		m := ft.get(a0, a1, d)
   601  		if m != 0 && tr.r&m == m {
   602  			if b.Func.pass.debug > 0 {
   603  				b.Func.Config.Warnl(b.Line, "Proved %s", c.Op)
   604  			}
   605  			return positive
   606  		}
   607  		if m != 0 && ((lt|eq|gt)^tr.r)&m == m {
   608  			if b.Func.pass.debug > 0 {
   609  				b.Func.Config.Warnl(b.Line, "Disproved %s", c.Op)
   610  			}
   611  			return negative
   612  		}
   613  	}
   614  
   615  	// HACK: If the first argument of IsInBounds or IsSliceInBounds
   616  	// is a constant and we already know that constant is smaller (or equal)
   617  	// to the upper bound than this is proven. Most useful in cases such as:
   618  	// if len(a) <= 1 { return }
   619  	// do something with a[1]
   620  	if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) && ft.isNonNegative(c.Args[0]) {
   621  		m := ft.get(a0, a1, signed)
   622  		if m != 0 && tr.r&m == m {
   623  			if b.Func.pass.debug > 0 {
   624  				b.Func.Config.Warnl(b.Line, "Proved non-negative bounds %s", c.Op)
   625  			}
   626  			return positive
   627  		}
   628  	}
   629  
   630  	return unknown
   631  }
   632  
   633  // isNonNegative returns true is v is known to be greater or equal to zero.
   634  func isNonNegative(v *Value) bool {
   635  	switch v.Op {
   636  	case OpConst64:
   637  		return v.AuxInt >= 0
   638  
   639  	case OpStringLen, OpSliceLen, OpSliceCap,
   640  		OpZeroExt8to64, OpZeroExt16to64, OpZeroExt32to64:
   641  		return true
   642  
   643  	case OpRsh64x64:
   644  		return isNonNegative(v.Args[0])
   645  	}
   646  	return false
   647  }