github.com/go-asm/go@v1.21.1-0.20240213172139-40c5ead50c48/cmd/compile/inline/inlheur/score_callresult_uses.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package inlheur
     6  
     7  import (
     8  	"fmt"
     9  	"os"
    10  
    11  	"github.com/go-asm/go/cmd/compile/ir"
    12  )
    13  
    14  // This file contains code to re-score callsites based on how the
    15  // results of the call were used.  Example:
    16  //
    17  //    func foo() {
    18  //       x, fptr := bar()
    19  //       switch x {
    20  //         case 10: fptr = baz()
    21  //         default: blix()
    22  //       }
    23  //       fptr(100)
    24  //     }
    25  //
    26  // The initial scoring pass will assign a score to "bar()" based on
    27  // various criteria, however once the first pass of scoring is done,
    28  // we look at the flags on the result from bar, and check to see
    29  // how those results are used. If bar() always returns the same constant
    30  // for its first result, and if the variable receiving that result
    31  // isn't redefined, and if that variable feeds into an if/switch
    32  // condition, then we will try to adjust the score for "bar" (on the
    33  // theory that if we inlined, we can constant fold / deadcode).
    34  
    35  type resultPropAndCS struct {
    36  	defcs *CallSite
    37  	props ResultPropBits
    38  }
    39  
    40  type resultUseAnalyzer struct {
    41  	resultNameTab map[*ir.Name]resultPropAndCS
    42  	fn            *ir.Func
    43  	cstab         CallSiteTab
    44  	*condLevelTracker
    45  }
    46  
    47  // rescoreBasedOnCallResultUses examines how call results are used,
    48  // and tries to update the scores of calls based on how their results
    49  // are used in the function.
    50  func (csa *callSiteAnalyzer) rescoreBasedOnCallResultUses(fn *ir.Func, resultNameTab map[*ir.Name]resultPropAndCS, cstab CallSiteTab) {
    51  	enableDebugTraceIfEnv()
    52  	rua := &resultUseAnalyzer{
    53  		resultNameTab:    resultNameTab,
    54  		fn:               fn,
    55  		cstab:            cstab,
    56  		condLevelTracker: new(condLevelTracker),
    57  	}
    58  	var doNode func(ir.Node) bool
    59  	doNode = func(n ir.Node) bool {
    60  		rua.nodeVisitPre(n)
    61  		ir.DoChildren(n, doNode)
    62  		rua.nodeVisitPost(n)
    63  		return false
    64  	}
    65  	doNode(fn)
    66  	disableDebugTrace()
    67  }
    68  
    69  func (csa *callSiteAnalyzer) examineCallResults(cs *CallSite, resultNameTab map[*ir.Name]resultPropAndCS) map[*ir.Name]resultPropAndCS {
    70  	if debugTrace&debugTraceScoring != 0 {
    71  		fmt.Fprintf(os.Stderr, "=-= examining call results for %q\n",
    72  			EncodeCallSiteKey(cs))
    73  	}
    74  
    75  	// Invoke a helper to pick out the specific ir.Name's the results
    76  	// from this call are assigned into, e.g. "x, y := fooBar()". If
    77  	// the call is not part of an assignment statement, or if the
    78  	// variables in question are not newly defined, then we'll receive
    79  	// an empty list here.
    80  	//
    81  	names, autoTemps, props := namesDefined(cs)
    82  	if len(names) == 0 {
    83  		return resultNameTab
    84  	}
    85  
    86  	if debugTrace&debugTraceScoring != 0 {
    87  		fmt.Fprintf(os.Stderr, "=-= %d names defined\n", len(names))
    88  	}
    89  
    90  	// For each returned value, if the value has interesting
    91  	// properties (ex: always returns the same constant), and the name
    92  	// in question is never redefined, then make an entry in the
    93  	// result table for it.
    94  	const interesting = (ResultIsConcreteTypeConvertedToInterface |
    95  		ResultAlwaysSameConstant | ResultAlwaysSameInlinableFunc | ResultAlwaysSameFunc)
    96  	for idx, n := range names {
    97  		rprop := props.ResultFlags[idx]
    98  
    99  		if debugTrace&debugTraceScoring != 0 {
   100  			fmt.Fprintf(os.Stderr, "=-= props for ret %d %q: %s\n",
   101  				idx, n.Sym().Name, rprop.String())
   102  		}
   103  
   104  		if rprop&interesting == 0 {
   105  			continue
   106  		}
   107  		if csa.nameFinder.reassigned(n) {
   108  			continue
   109  		}
   110  		if resultNameTab == nil {
   111  			resultNameTab = make(map[*ir.Name]resultPropAndCS)
   112  		} else if _, ok := resultNameTab[n]; ok {
   113  			panic("should never happen")
   114  		}
   115  		entry := resultPropAndCS{
   116  			defcs: cs,
   117  			props: rprop,
   118  		}
   119  		resultNameTab[n] = entry
   120  		if autoTemps[idx] != nil {
   121  			resultNameTab[autoTemps[idx]] = entry
   122  		}
   123  		if debugTrace&debugTraceScoring != 0 {
   124  			fmt.Fprintf(os.Stderr, "=-= add resultNameTab table entry n=%v autotemp=%v props=%s\n", n, autoTemps[idx], rprop.String())
   125  		}
   126  	}
   127  	return resultNameTab
   128  }
   129  
   130  // namesDefined returns a list of ir.Name's corresponding to locals
   131  // that receive the results from the call at site 'cs', plus the
   132  // properties object for the called function. If a given result
   133  // isn't cleanly assigned to a newly defined local, the
   134  // slot for that result in the returned list will be nil. Example:
   135  //
   136  //	call                             returned name list
   137  //
   138  //	x := foo()                       [ x ]
   139  //	z, y := bar()                    [ nil, nil ]
   140  //	_, q := baz()                    [ nil, q ]
   141  //
   142  // In the case of a multi-return call, such as "x, y := foo()",
   143  // the pattern we see from the front end will be a call op
   144  // assigning to auto-temps, and then an assignment of the auto-temps
   145  // to the user-level variables. In such cases we return
   146  // first the user-level variable (in the first func result)
   147  // and then the auto-temp name in the second result.
   148  func namesDefined(cs *CallSite) ([]*ir.Name, []*ir.Name, *FuncProps) {
   149  	// If this call doesn't feed into an assignment (and of course not
   150  	// all calls do), then we don't have anything to work with here.
   151  	if cs.Assign == nil {
   152  		return nil, nil, nil
   153  	}
   154  	funcInlHeur, ok := fpmap[cs.Callee]
   155  	if !ok {
   156  		// TODO: add an assert/panic here.
   157  		return nil, nil, nil
   158  	}
   159  	if len(funcInlHeur.props.ResultFlags) == 0 {
   160  		return nil, nil, nil
   161  	}
   162  
   163  	// Single return case.
   164  	if len(funcInlHeur.props.ResultFlags) == 1 {
   165  		asgn, ok := cs.Assign.(*ir.AssignStmt)
   166  		if !ok {
   167  			return nil, nil, nil
   168  		}
   169  		// locate name being assigned
   170  		aname, ok := asgn.X.(*ir.Name)
   171  		if !ok {
   172  			return nil, nil, nil
   173  		}
   174  		return []*ir.Name{aname}, []*ir.Name{nil}, funcInlHeur.props
   175  	}
   176  
   177  	// Multi-return case
   178  	asgn, ok := cs.Assign.(*ir.AssignListStmt)
   179  	if !ok || !asgn.Def {
   180  		return nil, nil, nil
   181  	}
   182  	userVars := make([]*ir.Name, len(funcInlHeur.props.ResultFlags))
   183  	autoTemps := make([]*ir.Name, len(funcInlHeur.props.ResultFlags))
   184  	for idx, x := range asgn.Lhs {
   185  		if n, ok := x.(*ir.Name); ok {
   186  			userVars[idx] = n
   187  			r := asgn.Rhs[idx]
   188  			if r.Op() == ir.OCONVNOP {
   189  				r = r.(*ir.ConvExpr).X
   190  			}
   191  			if ir.IsAutoTmp(r) {
   192  				autoTemps[idx] = r.(*ir.Name)
   193  			}
   194  			if debugTrace&debugTraceScoring != 0 {
   195  				fmt.Fprintf(os.Stderr, "=-= multi-ret namedef uv=%v at=%v\n",
   196  					x, autoTemps[idx])
   197  			}
   198  		} else {
   199  			return nil, nil, nil
   200  		}
   201  	}
   202  	return userVars, autoTemps, funcInlHeur.props
   203  }
   204  
   205  func (rua *resultUseAnalyzer) nodeVisitPost(n ir.Node) {
   206  	rua.condLevelTracker.post(n)
   207  }
   208  
   209  func (rua *resultUseAnalyzer) nodeVisitPre(n ir.Node) {
   210  	rua.condLevelTracker.pre(n)
   211  	switch n.Op() {
   212  	case ir.OCALLINTER:
   213  		if debugTrace&debugTraceScoring != 0 {
   214  			fmt.Fprintf(os.Stderr, "=-= rescore examine iface call %v:\n", n)
   215  		}
   216  		rua.callTargetCheckResults(n)
   217  	case ir.OCALLFUNC:
   218  		if debugTrace&debugTraceScoring != 0 {
   219  			fmt.Fprintf(os.Stderr, "=-= rescore examine call %v:\n", n)
   220  		}
   221  		rua.callTargetCheckResults(n)
   222  	case ir.OIF:
   223  		ifst := n.(*ir.IfStmt)
   224  		rua.foldCheckResults(ifst.Cond)
   225  	case ir.OSWITCH:
   226  		swst := n.(*ir.SwitchStmt)
   227  		if swst.Tag != nil {
   228  			rua.foldCheckResults(swst.Tag)
   229  		}
   230  
   231  	}
   232  }
   233  
   234  // callTargetCheckResults examines a given call to see whether the
   235  // callee expression is potentially an inlinable function returned
   236  // from a potentially inlinable call. Examples:
   237  //
   238  //	Scenario 1: named intermediate
   239  //
   240  //	   fn1 := foo()         conc := bar()
   241  //	   fn1("blah")          conc.MyMethod()
   242  //
   243  //	Scenario 2: returned func or concrete object feeds directly to call
   244  //
   245  //	   foo()("blah")        bar().MyMethod()
   246  //
   247  // In the second case although at the source level the result of the
   248  // direct call feeds right into the method call or indirect call,
   249  // we're relying on the front end having inserted an auto-temp to
   250  // capture the value.
   251  func (rua *resultUseAnalyzer) callTargetCheckResults(call ir.Node) {
   252  	ce := call.(*ir.CallExpr)
   253  	rname := rua.getCallResultName(ce)
   254  	if rname == nil {
   255  		return
   256  	}
   257  	if debugTrace&debugTraceScoring != 0 {
   258  		fmt.Fprintf(os.Stderr, "=-= staticvalue returns %v:\n",
   259  			rname)
   260  	}
   261  	if rname.Class != ir.PAUTO {
   262  		return
   263  	}
   264  	switch call.Op() {
   265  	case ir.OCALLINTER:
   266  		if debugTrace&debugTraceScoring != 0 {
   267  			fmt.Fprintf(os.Stderr, "=-= in %s checking %v for cci prop:\n",
   268  				rua.fn.Sym().Name, rname)
   269  		}
   270  		if cs := rua.returnHasProp(rname, ResultIsConcreteTypeConvertedToInterface); cs != nil {
   271  
   272  			adj := returnFeedsConcreteToInterfaceCallAdj
   273  			cs.Score, cs.ScoreMask = adjustScore(adj, cs.Score, cs.ScoreMask)
   274  		}
   275  	case ir.OCALLFUNC:
   276  		if debugTrace&debugTraceScoring != 0 {
   277  			fmt.Fprintf(os.Stderr, "=-= in %s checking %v for samefunc props:\n",
   278  				rua.fn.Sym().Name, rname)
   279  			v, ok := rua.resultNameTab[rname]
   280  			if !ok {
   281  				fmt.Fprintf(os.Stderr, "=-= no entry for %v in rt\n", rname)
   282  			} else {
   283  				fmt.Fprintf(os.Stderr, "=-= props for %v: %q\n", rname, v.props.String())
   284  			}
   285  		}
   286  		if cs := rua.returnHasProp(rname, ResultAlwaysSameInlinableFunc); cs != nil {
   287  			adj := returnFeedsInlinableFuncToIndCallAdj
   288  			cs.Score, cs.ScoreMask = adjustScore(adj, cs.Score, cs.ScoreMask)
   289  		} else if cs := rua.returnHasProp(rname, ResultAlwaysSameFunc); cs != nil {
   290  			adj := returnFeedsFuncToIndCallAdj
   291  			cs.Score, cs.ScoreMask = adjustScore(adj, cs.Score, cs.ScoreMask)
   292  
   293  		}
   294  	}
   295  }
   296  
   297  // foldCheckResults examines the specified if/switch condition 'cond'
   298  // to see if it refers to locals defined by a (potentially inlinable)
   299  // function call at call site C, and if so, whether 'cond' contains
   300  // only combinations of simple references to all of the names in
   301  // 'names' with selected constants + operators. If these criteria are
   302  // met, then we adjust the score for call site C to reflect the
   303  // fact that inlining will enable deadcode and/or constant propagation.
   304  // Note: for this heuristic to kick in, the names in question have to
   305  // be all from the same callsite. Examples:
   306  //
   307  //	  q, r := baz()	    x, y := foo()
   308  //	  switch q+r {		a, b, c := bar()
   309  //		...			    if x && y && a && b && c {
   310  //	  }					   ...
   311  //					    }
   312  //
   313  // For the call to "baz" above we apply a score adjustment, but not
   314  // for the calls to "foo" or "bar".
   315  func (rua *resultUseAnalyzer) foldCheckResults(cond ir.Node) {
   316  	namesUsed := collectNamesUsed(cond)
   317  	if len(namesUsed) == 0 {
   318  		return
   319  	}
   320  	var cs *CallSite
   321  	for _, n := range namesUsed {
   322  		rpcs, found := rua.resultNameTab[n]
   323  		if !found {
   324  			return
   325  		}
   326  		if cs != nil && rpcs.defcs != cs {
   327  			return
   328  		}
   329  		cs = rpcs.defcs
   330  		if rpcs.props&ResultAlwaysSameConstant == 0 {
   331  			return
   332  		}
   333  	}
   334  	if debugTrace&debugTraceScoring != 0 {
   335  		nls := func(nl []*ir.Name) string {
   336  			r := ""
   337  			for _, n := range nl {
   338  				r += " " + n.Sym().Name
   339  			}
   340  			return r
   341  		}
   342  		fmt.Fprintf(os.Stderr, "=-= calling ShouldFoldIfNameConstant on names={%s} cond=%v\n", nls(namesUsed), cond)
   343  	}
   344  
   345  	if !ShouldFoldIfNameConstant(cond, namesUsed) {
   346  		return
   347  	}
   348  	adj := returnFeedsConstToIfAdj
   349  	cs.Score, cs.ScoreMask = adjustScore(adj, cs.Score, cs.ScoreMask)
   350  }
   351  
   352  func collectNamesUsed(expr ir.Node) []*ir.Name {
   353  	res := []*ir.Name{}
   354  	ir.Visit(expr, func(n ir.Node) {
   355  		if n.Op() != ir.ONAME {
   356  			return
   357  		}
   358  		nn := n.(*ir.Name)
   359  		if nn.Class != ir.PAUTO {
   360  			return
   361  		}
   362  		res = append(res, nn)
   363  	})
   364  	return res
   365  }
   366  
   367  func (rua *resultUseAnalyzer) returnHasProp(name *ir.Name, prop ResultPropBits) *CallSite {
   368  	v, ok := rua.resultNameTab[name]
   369  	if !ok {
   370  		return nil
   371  	}
   372  	if v.props&prop == 0 {
   373  		return nil
   374  	}
   375  	return v.defcs
   376  }
   377  
   378  func (rua *resultUseAnalyzer) getCallResultName(ce *ir.CallExpr) *ir.Name {
   379  	var callTarg ir.Node
   380  	if sel, ok := ce.Fun.(*ir.SelectorExpr); ok {
   381  		// method call
   382  		callTarg = sel.X
   383  	} else if ctarg, ok := ce.Fun.(*ir.Name); ok {
   384  		// regular call
   385  		callTarg = ctarg
   386  	} else {
   387  		return nil
   388  	}
   389  	r := ir.StaticValue(callTarg)
   390  	if debugTrace&debugTraceScoring != 0 {
   391  		fmt.Fprintf(os.Stderr, "=-= staticname on %v returns %v:\n",
   392  			callTarg, r)
   393  	}
   394  	if r.Op() == ir.OCALLFUNC {
   395  		// This corresponds to the "x := foo()" case; here
   396  		// ir.StaticValue has brought us all the way back to
   397  		// the call expression itself. We need to back off to
   398  		// the name defined by the call; do this by looking up
   399  		// the callsite.
   400  		ce := r.(*ir.CallExpr)
   401  		cs, ok := rua.cstab[ce]
   402  		if !ok {
   403  			return nil
   404  		}
   405  		names, _, _ := namesDefined(cs)
   406  		if len(names) == 0 {
   407  			return nil
   408  		}
   409  		return names[0]
   410  	} else if r.Op() == ir.ONAME {
   411  		return r.(*ir.Name)
   412  	}
   413  	return nil
   414  }