github.com/go-asm/go@v1.21.1-0.20240213172139-40c5ead50c48/cmd/compile/inline/inlheur/analyze_func_callsites.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package inlheur
     6  
     7  import (
     8  	"fmt"
     9  	"os"
    10  	"strings"
    11  
    12  	"github.com/go-asm/go/cmd/compile/ir"
    13  	"github.com/go-asm/go/cmd/compile/pgo"
    14  	"github.com/go-asm/go/cmd/compile/typecheck"
    15  )
    16  
// callSiteAnalyzer bundles a function with an embedded name finder.
// It is created by makeCallSiteAnalyzer, which builds the name finder
// from the same function value that is stored in fn (this value may
// be nil; see UpdateCallsiteTable).
type callSiteAnalyzer struct {
	fn *ir.Func
	*nameFinder
}
    21  
// callSiteTableBuilder holds the state used while walking a region of
// a function to collect its call sites. Fields:
//
//   - fn is the function whose region is being walked.
//   - nameFinder is the embedded helper used to classify call
//     arguments (see propsForArg).
//   - cstab is the table being populated, keyed by call expression;
//     it may start out nil and is allocated lazily by addCallSite.
//   - ptab maps nodes to their pstate; it is consulted by
//     determinePanicPathBits.
//   - nstack is the stack of nodes currently being visited; it is
//     seeded with fn, pushed in nodeVisitPre and popped in
//     nodeVisitPost.
//   - loopNest is the current loop nesting depth, as adjusted by
//     nodeVisitPre and nodeVisitPost.
//   - isInit is true when fn is a package init function or its symbol
//     name begins with "init.".
type callSiteTableBuilder struct {
	fn *ir.Func
	*nameFinder
	cstab    CallSiteTab
	ptab     map[ir.Node]pstate
	nstack   []ir.Node
	loopNest int
	isInit   bool
}
    31  
    32  func makeCallSiteAnalyzer(fn *ir.Func) *callSiteAnalyzer {
    33  	return &callSiteAnalyzer{
    34  		fn:         fn,
    35  		nameFinder: newNameFinder(fn),
    36  	}
    37  }
    38  
    39  func makeCallSiteTableBuilder(fn *ir.Func, cstab CallSiteTab, ptab map[ir.Node]pstate, loopNestingLevel int, nf *nameFinder) *callSiteTableBuilder {
    40  	isInit := fn.IsPackageInit() || strings.HasPrefix(fn.Sym().Name, "init.")
    41  	return &callSiteTableBuilder{
    42  		fn:         fn,
    43  		cstab:      cstab,
    44  		ptab:       ptab,
    45  		isInit:     isInit,
    46  		loopNest:   loopNestingLevel,
    47  		nstack:     []ir.Node{fn},
    48  		nameFinder: nf,
    49  	}
    50  }
    51  
    52  // computeCallSiteTable builds and returns a table of call sites for
    53  // the specified region in function fn. A region here corresponds to a
    54  // specific subtree within the AST for a function. The main intended
    55  // use cases are for 'region' to be either A) an entire function body,
    56  // or B) an inlined call expression.
    57  func computeCallSiteTable(fn *ir.Func, region ir.Nodes, cstab CallSiteTab, ptab map[ir.Node]pstate, loopNestingLevel int, nf *nameFinder) CallSiteTab {
    58  	cstb := makeCallSiteTableBuilder(fn, cstab, ptab, loopNestingLevel, nf)
    59  	var doNode func(ir.Node) bool
    60  	doNode = func(n ir.Node) bool {
    61  		cstb.nodeVisitPre(n)
    62  		ir.DoChildren(n, doNode)
    63  		cstb.nodeVisitPost(n)
    64  		return false
    65  	}
    66  	for _, n := range region {
    67  		doNode(n)
    68  	}
    69  	return cstb.cstab
    70  }
    71  
    72  func (cstb *callSiteTableBuilder) flagsForNode(call *ir.CallExpr) CSPropBits {
    73  	var r CSPropBits
    74  
    75  	if debugTrace&debugTraceCalls != 0 {
    76  		fmt.Fprintf(os.Stderr, "=-= analyzing call at %s\n",
    77  			fmtFullPos(call.Pos()))
    78  	}
    79  
    80  	// Set a bit if this call is within a loop.
    81  	if cstb.loopNest > 0 {
    82  		r |= CallSiteInLoop
    83  	}
    84  
    85  	// Set a bit if the call is within an init function (either
    86  	// compiler-generated or user-written).
    87  	if cstb.isInit {
    88  		r |= CallSiteInInitFunc
    89  	}
    90  
    91  	// Decide whether to apply the panic path heuristic. Hack: don't
    92  	// apply this heuristic in the function "main.main" (mostly just
    93  	// to avoid annoying users).
    94  	if !isMainMain(cstb.fn) {
    95  		r = cstb.determinePanicPathBits(call, r)
    96  	}
    97  
    98  	return r
    99  }
   100  
   101  // determinePanicPathBits updates the CallSiteOnPanicPath bit within
   102  // "r" if we think this call is on an unconditional path to
   103  // panic/exit. Do this by walking back up the node stack to see if we
   104  // can find either A) an enclosing panic, or B) a statement node that
   105  // we've determined leads to a panic/exit.
   106  func (cstb *callSiteTableBuilder) determinePanicPathBits(call ir.Node, r CSPropBits) CSPropBits {
   107  	cstb.nstack = append(cstb.nstack, call)
   108  	defer func() {
   109  		cstb.nstack = cstb.nstack[:len(cstb.nstack)-1]
   110  	}()
   111  
   112  	for ri := range cstb.nstack[:len(cstb.nstack)-1] {
   113  		i := len(cstb.nstack) - ri - 1
   114  		n := cstb.nstack[i]
   115  		_, isCallExpr := n.(*ir.CallExpr)
   116  		_, isStmt := n.(ir.Stmt)
   117  		if isCallExpr {
   118  			isStmt = false
   119  		}
   120  
   121  		if debugTrace&debugTraceCalls != 0 {
   122  			ps, inps := cstb.ptab[n]
   123  			fmt.Fprintf(os.Stderr, "=-= callpar %d op=%s ps=%s inptab=%v stmt=%v\n", i, n.Op().String(), ps.String(), inps, isStmt)
   124  		}
   125  
   126  		if n.Op() == ir.OPANIC {
   127  			r |= CallSiteOnPanicPath
   128  			break
   129  		}
   130  		if v, ok := cstb.ptab[n]; ok {
   131  			if v == psCallsPanic {
   132  				r |= CallSiteOnPanicPath
   133  				break
   134  			}
   135  			if isStmt {
   136  				break
   137  			}
   138  		}
   139  	}
   140  	return r
   141  }
   142  
   143  // propsForArg returns property bits for a given call argument expression arg.
   144  func (cstb *callSiteTableBuilder) propsForArg(arg ir.Node) ActualExprPropBits {
   145  	if cval := cstb.constValue(arg); cval != nil {
   146  		return ActualExprConstant
   147  	}
   148  	if cstb.isConcreteConvIface(arg) {
   149  		return ActualExprIsConcreteConvIface
   150  	}
   151  	fname := cstb.funcName(arg)
   152  	if fname != nil {
   153  		if fn := fname.Func; fn != nil && typecheck.HaveInlineBody(fn) {
   154  			return ActualExprIsInlinableFunc
   155  		}
   156  		return ActualExprIsFunc
   157  	}
   158  	return 0
   159  }
   160  
   161  // argPropsForCall returns a slice of argument properties for the
   162  // expressions being passed to the callee in the specific call
   163  // expression; these will be stored in the CallSite object for a given
   164  // call and then consulted when scoring. If no arg has any interesting
   165  // properties we try to save some space and return a nil slice.
   166  func (cstb *callSiteTableBuilder) argPropsForCall(ce *ir.CallExpr) []ActualExprPropBits {
   167  	rv := make([]ActualExprPropBits, len(ce.Args))
   168  	somethingInteresting := false
   169  	for idx := range ce.Args {
   170  		argProp := cstb.propsForArg(ce.Args[idx])
   171  		somethingInteresting = somethingInteresting || (argProp != 0)
   172  		rv[idx] = argProp
   173  	}
   174  	if !somethingInteresting {
   175  		return nil
   176  	}
   177  	return rv
   178  }
   179  
   180  func (cstb *callSiteTableBuilder) addCallSite(callee *ir.Func, call *ir.CallExpr) {
   181  	flags := cstb.flagsForNode(call)
   182  	argProps := cstb.argPropsForCall(call)
   183  	if debugTrace&debugTraceCalls != 0 {
   184  		fmt.Fprintf(os.Stderr, "=-= props %+v for call %v\n", argProps, call)
   185  	}
   186  	// FIXME: maybe bulk-allocate these?
   187  	cs := &CallSite{
   188  		Call:     call,
   189  		Callee:   callee,
   190  		Assign:   cstb.containingAssignment(call),
   191  		ArgProps: argProps,
   192  		Flags:    flags,
   193  		ID:       uint(len(cstb.cstab)),
   194  	}
   195  	if _, ok := cstb.cstab[call]; ok {
   196  		fmt.Fprintf(os.Stderr, "*** cstab duplicate entry at: %s\n",
   197  			fmtFullPos(call.Pos()))
   198  		fmt.Fprintf(os.Stderr, "*** call: %+v\n", call)
   199  		panic("bad")
   200  	}
   201  	// Set initial score for callsite to the cost computed
   202  	// by CanInline; this score will be refined later based
   203  	// on heuristics.
   204  	cs.Score = int(callee.Inl.Cost)
   205  
   206  	if cstb.cstab == nil {
   207  		cstb.cstab = make(CallSiteTab)
   208  	}
   209  	cstb.cstab[call] = cs
   210  	if debugTrace&debugTraceCalls != 0 {
   211  		fmt.Fprintf(os.Stderr, "=-= added callsite: caller=%v callee=%v n=%s\n",
   212  			cstb.fn, callee, fmtFullPos(call.Pos()))
   213  	}
   214  }
   215  
   216  func (cstb *callSiteTableBuilder) nodeVisitPre(n ir.Node) {
   217  	switch n.Op() {
   218  	case ir.ORANGE, ir.OFOR:
   219  		if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
   220  			cstb.loopNest++
   221  		}
   222  	case ir.OCALLFUNC:
   223  		ce := n.(*ir.CallExpr)
   224  		callee := pgo.DirectCallee(ce.Fun)
   225  		if callee != nil && callee.Inl != nil {
   226  			cstb.addCallSite(callee, ce)
   227  		}
   228  	}
   229  	cstb.nstack = append(cstb.nstack, n)
   230  }
   231  
   232  func (cstb *callSiteTableBuilder) nodeVisitPost(n ir.Node) {
   233  	cstb.nstack = cstb.nstack[:len(cstb.nstack)-1]
   234  	switch n.Op() {
   235  	case ir.ORANGE, ir.OFOR:
   236  		if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
   237  			cstb.loopNest--
   238  		}
   239  	}
   240  }
   241  
   242  func loopBody(n ir.Node) ir.Nodes {
   243  	if forst, ok := n.(*ir.ForStmt); ok {
   244  		return forst.Body
   245  	}
   246  	if rst, ok := n.(*ir.RangeStmt); ok {
   247  		return rst.Body
   248  	}
   249  	return nil
   250  }
   251  
   252  // hasTopLevelLoopBodyReturnOrBreak examines the body of a "for" or
   253  // "range" loop to try to verify that it is a real loop, as opposed to
   254  // a construct that is syntactically loopy but doesn't actually iterate
   255  // multiple times, like:
   256  //
   257  //	for {
   258  //	  blah()
   259  //	  return 1
   260  //	}
   261  //
   262  // [Remark: the pattern above crops up quite a bit in the source code
   263  // for the compiler itself, e.g. the auto-generated rewrite code]
   264  //
   265  // Note that we don't look for GOTO statements here, so it's possible
   266  // we'll get the wrong result for a loop with complicated control
   267  // jumps via gotos.
   268  func hasTopLevelLoopBodyReturnOrBreak(loopBody ir.Nodes) bool {
   269  	for _, n := range loopBody {
   270  		if n.Op() == ir.ORETURN || n.Op() == ir.OBREAK {
   271  			return true
   272  		}
   273  	}
   274  	return false
   275  }
   276  
   277  // containingAssignment returns the top-level assignment statement
   278  // for a statement level function call "n". Examples:
   279  //
   280  //	x := foo()
   281  //	x, y := bar(z, baz())
   282  //	if blah() { ...
   283  //
   284  // Here the top-level assignment statement for the foo() call is the
   285  // statement assigning to "x"; the top-level assignment for "bar()"
   286  // call is the assignment to x,y. For the baz() and blah() calls,
   287  // there is no top level assignment statement.
   288  //
   289  // The unstated goal here is that we want to use the containing
   290  // assignment to establish a connection between a given call and the
   291  // variables to which its results/returns are being assigned.
   292  //
   293  // Note that for the "bar" command above, the front end sometimes
   294  // decomposes this into two assignments, the first one assigning the
   295  // call to a pair of auto-temps, then the second one assigning the
   296  // auto-temps to the user-visible vars. This helper will return the
   297  // second (outer) of these two.
   298  func (cstb *callSiteTableBuilder) containingAssignment(n ir.Node) ir.Node {
   299  	parent := cstb.nstack[len(cstb.nstack)-1]
   300  
   301  	// assignsOnlyAutoTemps returns TRUE of the specified OAS2FUNC
   302  	// node assigns only auto-temps.
   303  	assignsOnlyAutoTemps := func(x ir.Node) bool {
   304  		alst := x.(*ir.AssignListStmt)
   305  		oa2init := alst.Init()
   306  		if len(oa2init) == 0 {
   307  			return false
   308  		}
   309  		for _, v := range oa2init {
   310  			d := v.(*ir.Decl)
   311  			if !ir.IsAutoTmp(d.X) {
   312  				return false
   313  			}
   314  		}
   315  		return true
   316  	}
   317  
   318  	// Simple case: x := foo()
   319  	if parent.Op() == ir.OAS {
   320  		return parent
   321  	}
   322  
   323  	// Multi-return case: x, y := bar()
   324  	if parent.Op() == ir.OAS2FUNC {
   325  		// Hack city: if the result vars are auto-temps, try looking
   326  		// for an outer assignment in the tree. The code shape we're
   327  		// looking for here is:
   328  		//
   329  		// OAS1({x,y},OCONVNOP(OAS2FUNC({auto1,auto2},OCALLFUNC(bar))))
   330  		//
   331  		if assignsOnlyAutoTemps(parent) {
   332  			par2 := cstb.nstack[len(cstb.nstack)-2]
   333  			if par2.Op() == ir.OAS2 {
   334  				return par2
   335  			}
   336  			if par2.Op() == ir.OCONVNOP {
   337  				par3 := cstb.nstack[len(cstb.nstack)-3]
   338  				if par3.Op() == ir.OAS2 {
   339  					return par3
   340  				}
   341  			}
   342  		}
   343  	}
   344  
   345  	return nil
   346  }
   347  
   348  // UpdateCallsiteTable handles updating of callerfn's call site table
   349  // after an inlined has been carried out, e.g. the call at 'n' as been
   350  // turned into the inlined call expression 'ic' within function
   351  // callerfn. The chief thing of interest here is to make sure that any
   352  // call nodes within 'ic' are added to the call site table for
   353  // 'callerfn' and scored appropriately.
   354  func UpdateCallsiteTable(callerfn *ir.Func, n *ir.CallExpr, ic *ir.InlinedCallExpr) {
   355  	enableDebugTraceIfEnv()
   356  	defer disableDebugTrace()
   357  
   358  	funcInlHeur, ok := fpmap[callerfn]
   359  	if !ok {
   360  		// This can happen for compiler-generated wrappers.
   361  		if debugTrace&debugTraceCalls != 0 {
   362  			fmt.Fprintf(os.Stderr, "=-= early exit, no entry for caller fn %v\n", callerfn)
   363  		}
   364  		return
   365  	}
   366  
   367  	if debugTrace&debugTraceCalls != 0 {
   368  		fmt.Fprintf(os.Stderr, "=-= UpdateCallsiteTable(caller=%v, cs=%s)\n",
   369  			callerfn, fmtFullPos(n.Pos()))
   370  	}
   371  
   372  	// Mark the call in question as inlined.
   373  	oldcs, ok := funcInlHeur.cstab[n]
   374  	if !ok {
   375  		// This can happen for compiler-generated wrappers.
   376  		return
   377  	}
   378  	oldcs.aux |= csAuxInlined
   379  
   380  	if debugTrace&debugTraceCalls != 0 {
   381  		fmt.Fprintf(os.Stderr, "=-= marked as inlined: callee=%v %s\n",
   382  			oldcs.Callee, EncodeCallSiteKey(oldcs))
   383  	}
   384  
   385  	// Walk the inlined call region to collect new callsites.
   386  	var icp pstate
   387  	if oldcs.Flags&CallSiteOnPanicPath != 0 {
   388  		icp = psCallsPanic
   389  	}
   390  	var loopNestLevel int
   391  	if oldcs.Flags&CallSiteInLoop != 0 {
   392  		loopNestLevel = 1
   393  	}
   394  	ptab := map[ir.Node]pstate{ic: icp}
   395  	nf := newNameFinder(nil)
   396  	icstab := computeCallSiteTable(callerfn, ic.Body, nil, ptab, loopNestLevel, nf)
   397  
   398  	// Record parent callsite. This is primarily for debug output.
   399  	for _, cs := range icstab {
   400  		cs.parent = oldcs
   401  	}
   402  
   403  	// Score the calls in the inlined body. Note the setting of
   404  	// "doCallResults" to false here: at the moment there isn't any
   405  	// easy way to localize or region-ize the work done by
   406  	// "rescoreBasedOnCallResultUses", which currently does a walk
   407  	// over the entire function to look for uses of a given set of
   408  	// results. Similarly we're passing nil to makeCallSiteAnalyzer,
   409  	// so as to run name finding without the use of static value &
   410  	// friends.
   411  	csa := makeCallSiteAnalyzer(nil)
   412  	const doCallResults = false
   413  	csa.scoreCallsRegion(callerfn, ic.Body, icstab, doCallResults, ic)
   414  }