github.com/go-asm/go@v1.21.1-0.20240213172139-40c5ead50c48/cmd/compile/ssagen/nowb.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssagen 6 7 import ( 8 "fmt" 9 "strings" 10 11 "github.com/go-asm/go/cmd/compile/base" 12 "github.com/go-asm/go/cmd/compile/ir" 13 "github.com/go-asm/go/cmd/compile/typecheck" 14 "github.com/go-asm/go/cmd/compile/types" 15 "github.com/go-asm/go/cmd/obj" 16 "github.com/go-asm/go/cmd/src" 17 ) 18 19 func EnableNoWriteBarrierRecCheck() { 20 nowritebarrierrecCheck = newNowritebarrierrecChecker() 21 } 22 23 func NoWriteBarrierRecCheck() { 24 // Write barriers are now known. Check the 25 // call graph. 26 nowritebarrierrecCheck.check() 27 nowritebarrierrecCheck = nil 28 } 29 30 var nowritebarrierrecCheck *nowritebarrierrecChecker 31 32 type nowritebarrierrecChecker struct { 33 // extraCalls contains extra function calls that may not be 34 // visible during later analysis. It maps from the ODCLFUNC of 35 // the caller to a list of callees. 36 extraCalls map[*ir.Func][]nowritebarrierrecCall 37 38 // curfn is the current function during AST walks. 39 curfn *ir.Func 40 } 41 42 type nowritebarrierrecCall struct { 43 target *ir.Func // caller or callee 44 lineno src.XPos // line of call 45 } 46 47 // newNowritebarrierrecChecker creates a nowritebarrierrecChecker. It 48 // must be called before walk. 49 func newNowritebarrierrecChecker() *nowritebarrierrecChecker { 50 c := &nowritebarrierrecChecker{ 51 extraCalls: make(map[*ir.Func][]nowritebarrierrecCall), 52 } 53 54 // Find all systemstack calls and record their targets. In 55 // general, flow analysis can't see into systemstack, but it's 56 // important to handle it for this check, so we model it 57 // directly. This has to happen before transforming closures in walk since 58 // it's a lot harder to work out the argument after. 59 for _, n := range typecheck.Target.Funcs { 60 c.curfn = n 61 if c.curfn.ABIWrapper() { 62 // We only want "real" calls to these 63 // functions, not the generated ones within 64 // their own ABI wrappers. 65 continue 66 } 67 ir.Visit(n, c.findExtraCalls) 68 } 69 c.curfn = nil 70 return c 71 } 72 73 func (c *nowritebarrierrecChecker) findExtraCalls(nn ir.Node) { 74 if nn.Op() != ir.OCALLFUNC { 75 return 76 } 77 n := nn.(*ir.CallExpr) 78 if n.Fun == nil || n.Fun.Op() != ir.ONAME { 79 return 80 } 81 fn := n.Fun.(*ir.Name) 82 if fn.Class != ir.PFUNC || fn.Defn == nil { 83 return 84 } 85 if types.RuntimeSymName(fn.Sym()) != "systemstack" { 86 return 87 } 88 89 var callee *ir.Func 90 arg := n.Args[0] 91 switch arg.Op() { 92 case ir.ONAME: 93 arg := arg.(*ir.Name) 94 callee = arg.Defn.(*ir.Func) 95 case ir.OCLOSURE: 96 arg := arg.(*ir.ClosureExpr) 97 callee = arg.Func 98 default: 99 base.Fatalf("expected ONAME or OCLOSURE node, got %+v", arg) 100 } 101 c.extraCalls[c.curfn] = append(c.extraCalls[c.curfn], nowritebarrierrecCall{callee, n.Pos()}) 102 } 103 104 // recordCall records a call from ODCLFUNC node "from", to function 105 // symbol "to" at position pos. 106 // 107 // This should be done as late as possible during compilation to 108 // capture precise call graphs. The target of the call is an LSym 109 // because that's all we know after we start SSA. 110 // 111 // This can be called concurrently for different from Nodes. 112 func (c *nowritebarrierrecChecker) recordCall(fn *ir.Func, to *obj.LSym, pos src.XPos) { 113 // We record this information on the *Func so this is concurrent-safe. 114 if fn.NWBRCalls == nil { 115 fn.NWBRCalls = new([]ir.SymAndPos) 116 } 117 *fn.NWBRCalls = append(*fn.NWBRCalls, ir.SymAndPos{Sym: to, Pos: pos}) 118 } 119 120 func (c *nowritebarrierrecChecker) check() { 121 // We walk the call graph as late as possible so we can 122 // capture all calls created by lowering, but this means we 123 // only get to see the obj.LSyms of calls. symToFunc lets us 124 // get back to the ODCLFUNCs. 125 symToFunc := make(map[*obj.LSym]*ir.Func) 126 // funcs records the back-edges of the BFS call graph walk. It 127 // maps from the ODCLFUNC of each function that must not have 128 // write barriers to the call that inhibits them. Functions 129 // that are directly marked go:nowritebarrierrec are in this 130 // map with a zero-valued nowritebarrierrecCall. This also 131 // acts as the set of marks for the BFS of the call graph. 132 funcs := make(map[*ir.Func]nowritebarrierrecCall) 133 // q is the queue of ODCLFUNC Nodes to visit in BFS order. 134 var q ir.NameQueue 135 136 for _, fn := range typecheck.Target.Funcs { 137 symToFunc[fn.LSym] = fn 138 139 // Make nowritebarrierrec functions BFS roots. 140 if fn.Pragma&ir.Nowritebarrierrec != 0 { 141 funcs[fn] = nowritebarrierrecCall{} 142 q.PushRight(fn.Nname) 143 } 144 // Check go:nowritebarrier functions. 145 if fn.Pragma&ir.Nowritebarrier != 0 && fn.WBPos.IsKnown() { 146 base.ErrorfAt(fn.WBPos, 0, "write barrier prohibited") 147 } 148 } 149 150 // Perform a BFS of the call graph from all 151 // go:nowritebarrierrec functions. 152 enqueue := func(src, target *ir.Func, pos src.XPos) { 153 if target.Pragma&ir.Yeswritebarrierrec != 0 { 154 // Don't flow into this function. 155 return 156 } 157 if _, ok := funcs[target]; ok { 158 // Already found a path to target. 159 return 160 } 161 162 // Record the path. 163 funcs[target] = nowritebarrierrecCall{target: src, lineno: pos} 164 q.PushRight(target.Nname) 165 } 166 for !q.Empty() { 167 fn := q.PopLeft().Func 168 169 // Check fn. 170 if fn.WBPos.IsKnown() { 171 var err strings.Builder 172 call := funcs[fn] 173 for call.target != nil { 174 fmt.Fprintf(&err, "\n\t%v: called by %v", base.FmtPos(call.lineno), call.target.Nname) 175 call = funcs[call.target] 176 } 177 base.ErrorfAt(fn.WBPos, 0, "write barrier prohibited by caller; %v%s", fn.Nname, err.String()) 178 continue 179 } 180 181 // Enqueue fn's calls. 182 for _, callee := range c.extraCalls[fn] { 183 enqueue(fn, callee.target, callee.lineno) 184 } 185 if fn.NWBRCalls == nil { 186 continue 187 } 188 for _, callee := range *fn.NWBRCalls { 189 target := symToFunc[callee.Sym] 190 if target != nil { 191 enqueue(fn, target, callee.Pos) 192 } 193 } 194 } 195 }