github.com/aykevl/tinygo@v0.5.0/compiler/func-lowering.go (about) 1 package compiler 2 3 // This file lowers func values into their final form. This is necessary for 4 // funcValueSwitch, which needs full program analysis. 5 6 import ( 7 "sort" 8 "strconv" 9 10 "tinygo.org/x/go-llvm" 11 ) 12 13 // funcSignatureInfo keeps information about a single signature and its uses. 14 type funcSignatureInfo struct { 15 sig llvm.Value // *uint8 to identify the signature 16 funcValueWithSignatures []llvm.Value // slice of runtime.funcValueWithSignature 17 } 18 19 // funcWithUses keeps information about a single function used as func value and 20 // the assigned function ID. More commonly used functions are assigned a lower 21 // ID. 22 type funcWithUses struct { 23 funcPtr llvm.Value 24 useCount int // how often this function is used in a func value 25 id int // assigned ID 26 } 27 28 // Slice to sort functions by their use counts, or else their name if they're 29 // used equally often. 30 type funcWithUsesList []*funcWithUses 31 32 func (l funcWithUsesList) Len() int { return len(l) } 33 func (l funcWithUsesList) Less(i, j int) bool { 34 if l[i].useCount != l[j].useCount { 35 // return the reverse: we want the highest use counts sorted first 36 return l[i].useCount > l[j].useCount 37 } 38 iName := l[i].funcPtr.Name() 39 jName := l[j].funcPtr.Name() 40 return iName < jName 41 } 42 func (l funcWithUsesList) Swap(i, j int) { 43 l[i], l[j] = l[j], l[i] 44 } 45 46 // LowerFuncValue lowers the runtime.funcValueWithSignature type and 47 // runtime.getFuncPtr function to their final form. 48 func (c *Compiler) LowerFuncValues() { 49 if c.funcImplementation() != funcValueSwitch { 50 return 51 } 52 53 // Find all func values used in the program with their signatures. 54 funcValueWithSignaturePtr := llvm.PointerType(c.mod.GetTypeByName("runtime.funcValueWithSignature"), 0) 55 signatures := map[string]*funcSignatureInfo{} 56 for global := c.mod.FirstGlobal(); !global.IsNil(); global = llvm.NextGlobal(global) { 57 if global.Type() != funcValueWithSignaturePtr { 58 continue 59 } 60 sig := llvm.ConstExtractValue(global.Initializer(), []uint32{1}) 61 name := sig.Name() 62 if info, ok := signatures[name]; ok { 63 info.funcValueWithSignatures = append(info.funcValueWithSignatures, global) 64 } else { 65 signatures[name] = &funcSignatureInfo{ 66 sig: sig, 67 funcValueWithSignatures: []llvm.Value{global}, 68 } 69 } 70 } 71 72 // Sort the signatures, for deterministic execution. 73 names := make([]string, 0, len(signatures)) 74 for name := range signatures { 75 names = append(names, name) 76 } 77 sort.Strings(names) 78 79 for _, name := range names { 80 info := signatures[name] 81 functions := make(funcWithUsesList, len(info.funcValueWithSignatures)) 82 for i, use := range info.funcValueWithSignatures { 83 var useCount int 84 for _, use2 := range getUses(use) { 85 useCount += len(getUses(use2)) 86 } 87 functions[i] = &funcWithUses{ 88 funcPtr: llvm.ConstExtractValue(use.Initializer(), []uint32{0}).Operand(0), 89 useCount: useCount, 90 } 91 } 92 sort.Sort(functions) 93 94 for i, fn := range functions { 95 fn.id = i + 1 96 for _, ptrtoint := range getUses(fn.funcPtr) { 97 if ptrtoint.IsAConstantExpr().IsNil() || ptrtoint.Opcode() != llvm.PtrToInt { 98 continue 99 } 100 for _, funcValueWithSignatureConstant := range getUses(ptrtoint) { 101 for _, funcValueWithSignatureGlobal := range getUses(funcValueWithSignatureConstant) { 102 for _, use := range getUses(funcValueWithSignatureGlobal) { 103 if ptrtoint.IsAConstantExpr().IsNil() || ptrtoint.Opcode() != llvm.PtrToInt { 104 panic("expected const ptrtoint") 105 } 106 use.ReplaceAllUsesWith(llvm.ConstInt(c.uintptrType, uint64(fn.id), false)) 107 } 108 } 109 } 110 } 111 } 112 113 for _, getFuncPtrCall := range getUses(info.sig) { 114 if getFuncPtrCall.IsACallInst().IsNil() { 115 continue 116 } 117 if getFuncPtrCall.CalledValue().Name() != "runtime.getFuncPtr" { 118 panic("expected all call uses to be runtime.getFuncPtr") 119 } 120 funcID := getFuncPtrCall.Operand(1) 121 switch len(functions) { 122 case 0: 123 // There are no functions used in a func value that implement 124 // this signature. The only possible value is a nil value. 125 for _, inttoptr := range getUses(getFuncPtrCall) { 126 if inttoptr.IsAIntToPtrInst().IsNil() { 127 panic("expected inttoptr") 128 } 129 nilptr := llvm.ConstPointerNull(inttoptr.Type()) 130 inttoptr.ReplaceAllUsesWith(nilptr) 131 inttoptr.EraseFromParentAsInstruction() 132 } 133 getFuncPtrCall.EraseFromParentAsInstruction() 134 case 1: 135 // There is exactly one function with this signature that is 136 // used in a func value. The func value itself can be either nil 137 // or this one function. 138 c.builder.SetInsertPointBefore(getFuncPtrCall) 139 zero := llvm.ConstInt(c.uintptrType, 0, false) 140 isnil := c.builder.CreateICmp(llvm.IntEQ, funcID, zero, "") 141 funcPtrNil := llvm.ConstPointerNull(functions[0].funcPtr.Type()) 142 funcPtr := c.builder.CreateSelect(isnil, funcPtrNil, functions[0].funcPtr, "") 143 for _, inttoptr := range getUses(getFuncPtrCall) { 144 if inttoptr.IsAIntToPtrInst().IsNil() { 145 panic("expected inttoptr") 146 } 147 inttoptr.ReplaceAllUsesWith(funcPtr) 148 inttoptr.EraseFromParentAsInstruction() 149 } 150 getFuncPtrCall.EraseFromParentAsInstruction() 151 default: 152 // There are multiple functions used in a func value that 153 // implement this signature. 154 // What we'll do is transform the following: 155 // rawPtr := runtime.getFuncPtr(fn) 156 // if func.rawPtr == nil { 157 // runtime.nilpanic() 158 // } 159 // result := func.rawPtr(...args, func.context) 160 // into this: 161 // if false { 162 // runtime.nilpanic() 163 // } 164 // var result // Phi 165 // switch fn.id { 166 // case 0: 167 // runtime.nilpanic() 168 // case 1: 169 // result = call first implementation... 170 // case 2: 171 // result = call second implementation... 172 // default: 173 // unreachable 174 // } 175 176 // Remove some casts, checks, and the old call which we're going 177 // to replace. 178 var funcCall llvm.Value 179 for _, inttoptr := range getUses(getFuncPtrCall) { 180 if inttoptr.IsAIntToPtrInst().IsNil() { 181 panic("expected inttoptr") 182 } 183 for _, ptrUse := range getUses(inttoptr) { 184 if !ptrUse.IsABitCastInst().IsNil() { 185 for _, bitcastUse := range getUses(ptrUse) { 186 if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().Name() != "runtime.isnil" { 187 panic("expected a call to runtime.isnil") 188 } 189 bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false)) 190 bitcastUse.EraseFromParentAsInstruction() 191 } 192 ptrUse.EraseFromParentAsInstruction() 193 } else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == inttoptr { 194 if !funcCall.IsNil() { 195 panic("multiple calls on a single runtime.getFuncPtr") 196 } 197 funcCall = ptrUse 198 } else { 199 panic("unexpected getFuncPtrCall") 200 } 201 } 202 } 203 if funcCall.IsNil() { 204 panic("expected exactly one call use of a runtime.getFuncPtr") 205 } 206 207 // The block that cannot be reached with correct funcValues (to 208 // help the optimizer). 209 c.builder.SetInsertPointBefore(funcCall) 210 defaultBlock := llvm.AddBasicBlock(funcCall.InstructionParent().Parent(), "func.default") 211 c.builder.SetInsertPointAtEnd(defaultBlock) 212 c.builder.CreateUnreachable() 213 214 // Create the switch. 215 c.builder.SetInsertPointBefore(funcCall) 216 sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1) 217 218 // Split right after the switch. We will need to insert a few 219 // basic blocks in this gap. 220 nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next") 221 222 // The 0 case, which is actually a nil check. 223 nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil") 224 c.builder.SetInsertPointAtEnd(nilBlock) 225 c.createRuntimeCall("nilpanic", nil, "") 226 c.builder.CreateUnreachable() 227 sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock) 228 229 // Gather the list of parameters for every call we're going to 230 // make. 231 callParams := make([]llvm.Value, funcCall.OperandsCount()-1) 232 for i := range callParams { 233 callParams[i] = funcCall.Operand(i) 234 } 235 236 // If the call produces a value, we need to get it using a PHI 237 // node. 238 phiBlocks := make([]llvm.BasicBlock, len(functions)) 239 phiValues := make([]llvm.Value, len(functions)) 240 for i, fn := range functions { 241 // Insert a switch case. 242 bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id)) 243 c.builder.SetInsertPointAtEnd(bb) 244 result := c.builder.CreateCall(fn.funcPtr, callParams, "") 245 c.builder.CreateBr(nextBlock) 246 sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb) 247 phiBlocks[i] = bb 248 phiValues[i] = result 249 } 250 // Create the PHI node so that the call result flows into the 251 // next block (after the split). This is only necessary when the 252 // call produced a value. 253 if funcCall.Type().TypeKind() != llvm.VoidTypeKind { 254 c.builder.SetInsertPointBefore(nextBlock.FirstInstruction()) 255 phi := c.builder.CreatePHI(funcCall.Type(), "") 256 phi.AddIncoming(phiValues, phiBlocks) 257 funcCall.ReplaceAllUsesWith(phi) 258 } 259 260 // Finally, remove the old instructions. 261 funcCall.EraseFromParentAsInstruction() 262 for _, inttoptr := range getUses(getFuncPtrCall) { 263 inttoptr.EraseFromParentAsInstruction() 264 } 265 getFuncPtrCall.EraseFromParentAsInstruction() 266 } 267 } 268 } 269 }