github.com/aykevl/tinygo@v0.5.0/compiler/func-lowering.go (about)

     1  package compiler
     2  
     3  // This file lowers func values into their final form. This is necessary for
     4  // funcValueSwitch, which needs full program analysis.
     5  
     6  import (
     7  	"sort"
     8  	"strconv"
     9  
    10  	"tinygo.org/x/go-llvm"
    11  )
    12  
// funcSignatureInfo keeps information about a single signature and its uses.
// LowerFuncValues creates one entry per distinct signature name found while
// scanning the module's globals.
type funcSignatureInfo struct {
	// sig is the signature identifier (*uint8), read from field 1 of a
	// funcValueWithSignature global's initializer.
	sig llvm.Value // *uint8 to identify the signature
	// funcValueWithSignatures lists every global whose type is a pointer to
	// runtime.funcValueWithSignature and whose signature has this name.
	funcValueWithSignatures []llvm.Value // slice of runtime.funcValueWithSignature
}
    18  
// funcWithUses keeps information about a single function used as func value and
// the assigned function ID. More commonly used functions are assigned a lower
// ID: LowerFuncValues sorts these by descending useCount and numbers them
// starting at 1.
type funcWithUses struct {
	funcPtr  llvm.Value // the function itself, taken from the global's initializer
	useCount int        // how often this function is used in a func value
	id       int        // assigned ID (position in the sorted list plus one)
}
    27  
    28  // Slice to sort functions by their use counts, or else their name if they're
    29  // used equally often.
    30  type funcWithUsesList []*funcWithUses
    31  
    32  func (l funcWithUsesList) Len() int { return len(l) }
    33  func (l funcWithUsesList) Less(i, j int) bool {
    34  	if l[i].useCount != l[j].useCount {
    35  		// return the reverse: we want the highest use counts sorted first
    36  		return l[i].useCount > l[j].useCount
    37  	}
    38  	iName := l[i].funcPtr.Name()
    39  	jName := l[j].funcPtr.Name()
    40  	return iName < jName
    41  }
    42  func (l funcWithUsesList) Swap(i, j int) {
    43  	l[i], l[j] = l[j], l[i]
    44  }
    45  
    46  // LowerFuncValue lowers the runtime.funcValueWithSignature type and
    47  // runtime.getFuncPtr function to their final form.
    48  func (c *Compiler) LowerFuncValues() {
    49  	if c.funcImplementation() != funcValueSwitch {
    50  		return
    51  	}
    52  
    53  	// Find all func values used in the program with their signatures.
    54  	funcValueWithSignaturePtr := llvm.PointerType(c.mod.GetTypeByName("runtime.funcValueWithSignature"), 0)
    55  	signatures := map[string]*funcSignatureInfo{}
    56  	for global := c.mod.FirstGlobal(); !global.IsNil(); global = llvm.NextGlobal(global) {
    57  		if global.Type() != funcValueWithSignaturePtr {
    58  			continue
    59  		}
    60  		sig := llvm.ConstExtractValue(global.Initializer(), []uint32{1})
    61  		name := sig.Name()
    62  		if info, ok := signatures[name]; ok {
    63  			info.funcValueWithSignatures = append(info.funcValueWithSignatures, global)
    64  		} else {
    65  			signatures[name] = &funcSignatureInfo{
    66  				sig:                     sig,
    67  				funcValueWithSignatures: []llvm.Value{global},
    68  			}
    69  		}
    70  	}
    71  
    72  	// Sort the signatures, for deterministic execution.
    73  	names := make([]string, 0, len(signatures))
    74  	for name := range signatures {
    75  		names = append(names, name)
    76  	}
    77  	sort.Strings(names)
    78  
    79  	for _, name := range names {
    80  		info := signatures[name]
    81  		functions := make(funcWithUsesList, len(info.funcValueWithSignatures))
    82  		for i, use := range info.funcValueWithSignatures {
    83  			var useCount int
    84  			for _, use2 := range getUses(use) {
    85  				useCount += len(getUses(use2))
    86  			}
    87  			functions[i] = &funcWithUses{
    88  				funcPtr:  llvm.ConstExtractValue(use.Initializer(), []uint32{0}).Operand(0),
    89  				useCount: useCount,
    90  			}
    91  		}
    92  		sort.Sort(functions)
    93  
    94  		for i, fn := range functions {
    95  			fn.id = i + 1
    96  			for _, ptrtoint := range getUses(fn.funcPtr) {
    97  				if ptrtoint.IsAConstantExpr().IsNil() || ptrtoint.Opcode() != llvm.PtrToInt {
    98  					continue
    99  				}
   100  				for _, funcValueWithSignatureConstant := range getUses(ptrtoint) {
   101  					for _, funcValueWithSignatureGlobal := range getUses(funcValueWithSignatureConstant) {
   102  						for _, use := range getUses(funcValueWithSignatureGlobal) {
   103  							if ptrtoint.IsAConstantExpr().IsNil() || ptrtoint.Opcode() != llvm.PtrToInt {
   104  								panic("expected const ptrtoint")
   105  							}
   106  							use.ReplaceAllUsesWith(llvm.ConstInt(c.uintptrType, uint64(fn.id), false))
   107  						}
   108  					}
   109  				}
   110  			}
   111  		}
   112  
   113  		for _, getFuncPtrCall := range getUses(info.sig) {
   114  			if getFuncPtrCall.IsACallInst().IsNil() {
   115  				continue
   116  			}
   117  			if getFuncPtrCall.CalledValue().Name() != "runtime.getFuncPtr" {
   118  				panic("expected all call uses to be runtime.getFuncPtr")
   119  			}
   120  			funcID := getFuncPtrCall.Operand(1)
   121  			switch len(functions) {
   122  			case 0:
   123  				// There are no functions used in a func value that implement
   124  				// this signature. The only possible value is a nil value.
   125  				for _, inttoptr := range getUses(getFuncPtrCall) {
   126  					if inttoptr.IsAIntToPtrInst().IsNil() {
   127  						panic("expected inttoptr")
   128  					}
   129  					nilptr := llvm.ConstPointerNull(inttoptr.Type())
   130  					inttoptr.ReplaceAllUsesWith(nilptr)
   131  					inttoptr.EraseFromParentAsInstruction()
   132  				}
   133  				getFuncPtrCall.EraseFromParentAsInstruction()
   134  			case 1:
   135  				// There is exactly one function with this signature that is
   136  				// used in a func value. The func value itself can be either nil
   137  				// or this one function.
   138  				c.builder.SetInsertPointBefore(getFuncPtrCall)
   139  				zero := llvm.ConstInt(c.uintptrType, 0, false)
   140  				isnil := c.builder.CreateICmp(llvm.IntEQ, funcID, zero, "")
   141  				funcPtrNil := llvm.ConstPointerNull(functions[0].funcPtr.Type())
   142  				funcPtr := c.builder.CreateSelect(isnil, funcPtrNil, functions[0].funcPtr, "")
   143  				for _, inttoptr := range getUses(getFuncPtrCall) {
   144  					if inttoptr.IsAIntToPtrInst().IsNil() {
   145  						panic("expected inttoptr")
   146  					}
   147  					inttoptr.ReplaceAllUsesWith(funcPtr)
   148  					inttoptr.EraseFromParentAsInstruction()
   149  				}
   150  				getFuncPtrCall.EraseFromParentAsInstruction()
   151  			default:
   152  				// There are multiple functions used in a func value that
   153  				// implement this signature.
   154  				// What we'll do is transform the following:
   155  				//     rawPtr := runtime.getFuncPtr(fn)
   156  				//     if func.rawPtr == nil {
   157  				//         runtime.nilpanic()
   158  				//     }
   159  				//     result := func.rawPtr(...args, func.context)
   160  				// into this:
   161  				//     if false {
   162  				//         runtime.nilpanic()
   163  				//     }
   164  				//     var result // Phi
   165  				//     switch fn.id {
   166  				//     case 0:
   167  				//         runtime.nilpanic()
   168  				//     case 1:
   169  				//         result = call first implementation...
   170  				//     case 2:
   171  				//         result = call second implementation...
   172  				//     default:
   173  				//         unreachable
   174  				//     }
   175  
   176  				// Remove some casts, checks, and the old call which we're going
   177  				// to replace.
   178  				var funcCall llvm.Value
   179  				for _, inttoptr := range getUses(getFuncPtrCall) {
   180  					if inttoptr.IsAIntToPtrInst().IsNil() {
   181  						panic("expected inttoptr")
   182  					}
   183  					for _, ptrUse := range getUses(inttoptr) {
   184  						if !ptrUse.IsABitCastInst().IsNil() {
   185  							for _, bitcastUse := range getUses(ptrUse) {
   186  								if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().Name() != "runtime.isnil" {
   187  									panic("expected a call to runtime.isnil")
   188  								}
   189  								bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false))
   190  								bitcastUse.EraseFromParentAsInstruction()
   191  							}
   192  							ptrUse.EraseFromParentAsInstruction()
   193  						} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == inttoptr {
   194  							if !funcCall.IsNil() {
   195  								panic("multiple calls on a single runtime.getFuncPtr")
   196  							}
   197  							funcCall = ptrUse
   198  						} else {
   199  							panic("unexpected getFuncPtrCall")
   200  						}
   201  					}
   202  				}
   203  				if funcCall.IsNil() {
   204  					panic("expected exactly one call use of a runtime.getFuncPtr")
   205  				}
   206  
   207  				// The block that cannot be reached with correct funcValues (to
   208  				// help the optimizer).
   209  				c.builder.SetInsertPointBefore(funcCall)
   210  				defaultBlock := llvm.AddBasicBlock(funcCall.InstructionParent().Parent(), "func.default")
   211  				c.builder.SetInsertPointAtEnd(defaultBlock)
   212  				c.builder.CreateUnreachable()
   213  
   214  				// Create the switch.
   215  				c.builder.SetInsertPointBefore(funcCall)
   216  				sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1)
   217  
   218  				// Split right after the switch. We will need to insert a few
   219  				// basic blocks in this gap.
   220  				nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next")
   221  
   222  				// The 0 case, which is actually a nil check.
   223  				nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil")
   224  				c.builder.SetInsertPointAtEnd(nilBlock)
   225  				c.createRuntimeCall("nilpanic", nil, "")
   226  				c.builder.CreateUnreachable()
   227  				sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock)
   228  
   229  				// Gather the list of parameters for every call we're going to
   230  				// make.
   231  				callParams := make([]llvm.Value, funcCall.OperandsCount()-1)
   232  				for i := range callParams {
   233  					callParams[i] = funcCall.Operand(i)
   234  				}
   235  
   236  				// If the call produces a value, we need to get it using a PHI
   237  				// node.
   238  				phiBlocks := make([]llvm.BasicBlock, len(functions))
   239  				phiValues := make([]llvm.Value, len(functions))
   240  				for i, fn := range functions {
   241  					// Insert a switch case.
   242  					bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id))
   243  					c.builder.SetInsertPointAtEnd(bb)
   244  					result := c.builder.CreateCall(fn.funcPtr, callParams, "")
   245  					c.builder.CreateBr(nextBlock)
   246  					sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb)
   247  					phiBlocks[i] = bb
   248  					phiValues[i] = result
   249  				}
   250  				// Create the PHI node so that the call result flows into the
   251  				// next block (after the split). This is only necessary when the
   252  				// call produced a value.
   253  				if funcCall.Type().TypeKind() != llvm.VoidTypeKind {
   254  					c.builder.SetInsertPointBefore(nextBlock.FirstInstruction())
   255  					phi := c.builder.CreatePHI(funcCall.Type(), "")
   256  					phi.AddIncoming(phiValues, phiBlocks)
   257  					funcCall.ReplaceAllUsesWith(phi)
   258  				}
   259  
   260  				// Finally, remove the old instructions.
   261  				funcCall.EraseFromParentAsInstruction()
   262  				for _, inttoptr := range getUses(getFuncPtrCall) {
   263  					inttoptr.EraseFromParentAsInstruction()
   264  				}
   265  				getFuncPtrCall.EraseFromParentAsInstruction()
   266  			}
   267  		}
   268  	}
   269  }