github.com/tinygo-org/tinygo@v0.31.3-0.20240404173401-90b0bf646c27/transform/gc.go (about)

     1  package transform
     2  
     3  import (
     4  	"tinygo.org/x/go-llvm"
     5  )
     6  
     7  // MakeGCStackSlots converts all calls to runtime.trackPointer to explicit
     8  // stores to stack slots that are scannable by the GC.
     9  func MakeGCStackSlots(mod llvm.Module) bool {
    10  	// Check whether there are allocations at all.
    11  	alloc := mod.NamedFunction("runtime.alloc")
    12  	if alloc.IsNil() {
    13  		// Nothing to. Make sure all remaining bits and pieces for stack
    14  		// chains are neutralized.
    15  		for _, call := range getUses(mod.NamedFunction("runtime.trackPointer")) {
    16  			call.EraseFromParentAsInstruction()
    17  		}
    18  		stackChainStart := mod.NamedGlobal("runtime.stackChainStart")
    19  		if !stackChainStart.IsNil() {
    20  			stackChainStart.SetLinkage(llvm.InternalLinkage)
    21  			stackChainStart.SetInitializer(llvm.ConstNull(stackChainStart.GlobalValueType()))
    22  			stackChainStart.SetGlobalConstant(true)
    23  		}
    24  		return false
    25  	}
    26  
    27  	trackPointer := mod.NamedFunction("runtime.trackPointer")
    28  	if trackPointer.IsNil() || trackPointer.FirstUse().IsNil() {
    29  		return false // nothing to do
    30  	}
    31  
    32  	ctx := mod.Context()
    33  	builder := ctx.NewBuilder()
    34  	defer builder.Dispose()
    35  	targetData := llvm.NewTargetData(mod.DataLayout())
    36  	defer targetData.Dispose()
    37  	uintptrType := ctx.IntType(targetData.PointerSize() * 8)
    38  
    39  	// Look at *all* functions to see whether they are free of function pointer
    40  	// calls.
    41  	// This takes less than 5ms for ~100kB of WebAssembly but would perhaps be
    42  	// faster when written in C++ (to avoid the CGo overhead).
    43  	funcsWithFPCall := map[llvm.Value]struct{}{}
    44  	n := 0
    45  	for fn := mod.FirstFunction(); !fn.IsNil(); fn = llvm.NextFunction(fn) {
    46  		n++
    47  		if _, ok := funcsWithFPCall[fn]; ok {
    48  			continue // already found
    49  		}
    50  		done := false
    51  		for bb := fn.FirstBasicBlock(); !bb.IsNil() && !done; bb = llvm.NextBasicBlock(bb) {
    52  			for call := bb.FirstInstruction(); !call.IsNil() && !done; call = llvm.NextInstruction(call) {
    53  				if call.IsACallInst().IsNil() {
    54  					continue // only looking at calls
    55  				}
    56  				called := call.CalledValue()
    57  				if !called.IsAFunction().IsNil() {
    58  					continue // only looking for function pointers
    59  				}
    60  				funcsWithFPCall[fn] = struct{}{}
    61  				markParentFunctions(funcsWithFPCall, fn)
    62  				done = true
    63  			}
    64  		}
    65  	}
    66  
    67  	// Determine which functions need stack objects. Many leaf functions don't
    68  	// need it: it only causes overhead for them.
    69  	// Actually, in one test it was only able to eliminate stack object from 12%
    70  	// of functions that had a call to runtime.trackPointer (8 out of 68
    71  	// functions), so this optimization is not as big as it may seem.
    72  	allocatingFunctions := map[llvm.Value]struct{}{} // set of allocating functions
    73  
    74  	// Work from runtime.alloc and trace all parents to check which functions do
    75  	// a heap allocation (and thus which functions do not).
    76  	markParentFunctions(allocatingFunctions, alloc)
    77  
    78  	// Also trace all functions that call a function pointer.
    79  	for fn := range funcsWithFPCall {
    80  		// Assume that functions that call a function pointer do a heap
    81  		// allocation as a conservative guess because the called function might
    82  		// do a heap allocation.
    83  		allocatingFunctions[fn] = struct{}{}
    84  		markParentFunctions(allocatingFunctions, fn)
    85  	}
    86  
    87  	// Collect some variables used below in the loop.
    88  	stackChainStart := mod.NamedGlobal("runtime.stackChainStart")
    89  	if stackChainStart.IsNil() {
    90  		// This may be reached in a weird scenario where we call runtime.alloc but the garbage collector is unreachable.
    91  		// This can be accomplished by allocating 0 bytes.
    92  		// There is no point in tracking anything.
    93  		for _, use := range getUses(trackPointer) {
    94  			use.EraseFromParentAsInstruction()
    95  		}
    96  		return false
    97  	}
    98  	stackChainStart.SetLinkage(llvm.InternalLinkage)
    99  	stackChainStartType := stackChainStart.GlobalValueType()
   100  	stackChainStart.SetInitializer(llvm.ConstNull(stackChainStartType))
   101  
   102  	// Iterate until runtime.trackPointer has no uses left.
   103  	for use := trackPointer.FirstUse(); !use.IsNil(); use = trackPointer.FirstUse() {
   104  		// Pick the first use of runtime.trackPointer.
   105  		call := use.User()
   106  		if call.IsACallInst().IsNil() {
   107  			panic("expected runtime.trackPointer use to be a call")
   108  		}
   109  
   110  		// Pick the parent function.
   111  		fn := call.InstructionParent().Parent()
   112  
   113  		if _, ok := allocatingFunctions[fn]; !ok {
   114  			// This function nor any of the functions it calls (recursively)
   115  			// allocate anything from the heap, so it will not trigger a garbage
   116  			// collection cycle. Thus, it does not need to track local pointer
   117  			// values.
   118  			// This is a useful optimization but not as big as you might guess,
   119  			// as described above (it avoids stack objects for ~12% of
   120  			// functions).
   121  			call.EraseFromParentAsInstruction()
   122  			continue
   123  		}
   124  
   125  		// Find all calls to runtime.trackPointer in this function.
   126  		var calls []llvm.Value
   127  		var returns []llvm.Value
   128  		for bb := fn.FirstBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) {
   129  			for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) {
   130  				switch inst.InstructionOpcode() {
   131  				case llvm.Call:
   132  					if inst.CalledValue() == trackPointer {
   133  						calls = append(calls, inst)
   134  					}
   135  				case llvm.Ret:
   136  					returns = append(returns, inst)
   137  				}
   138  			}
   139  		}
   140  
   141  		// Determine what to do with each call.
   142  		var pointers []llvm.Value
   143  		for _, call := range calls {
   144  			ptr := call.Operand(0)
   145  			call.EraseFromParentAsInstruction()
   146  
   147  			// Some trivial optimizations.
   148  			if ptr.IsAInstruction().IsNil() {
   149  				continue
   150  			}
   151  			switch ptr.InstructionOpcode() {
   152  			case llvm.GetElementPtr:
   153  				// Check for all zero offsets.
   154  				// Sometimes LLVM rewrites bitcasts to zero-index GEPs, and we still need to track the GEP.
   155  				n := ptr.OperandsCount()
   156  				var hasOffset bool
   157  				for i := 1; i < n; i++ {
   158  					offset := ptr.Operand(i)
   159  					if offset.IsAConstantInt().IsNil() || offset.ZExtValue() != 0 {
   160  						hasOffset = true
   161  						break
   162  					}
   163  				}
   164  
   165  				if hasOffset {
   166  					// These values do not create new values: the values already
   167  					// existed locally in this function so must have been tracked
   168  					// already.
   169  					continue
   170  				}
   171  			case llvm.PHI:
   172  				// While the value may have already been tracked, it may be overwritten in a loop.
   173  				// Therefore, a second copy must be created to ensure that it is tracked over the entirety of its lifetime.
   174  			case llvm.ExtractValue, llvm.BitCast:
   175  				// These instructions do not create new values, but their
   176  				// original value may not be tracked. So keep tracking them for
   177  				// now.
   178  				// With more analysis, it should be possible to optimize a
   179  				// significant chunk of these away.
   180  			case llvm.Call, llvm.Load, llvm.IntToPtr:
   181  				// These create new values so must be stored locally. But
   182  				// perhaps some of these can be fused when they actually refer
   183  				// to the same value.
   184  			default:
   185  				// Ambiguous. These instructions are uncommon, but perhaps could
   186  				// be optimized if needed.
   187  			}
   188  
   189  			if ptr := stripPointerCasts(ptr); !ptr.IsAAllocaInst().IsNil() {
   190  				// Allocas don't need to be tracked because they are allocated
   191  				// on the C stack which is scanned separately.
   192  				continue
   193  			}
   194  			pointers = append(pointers, ptr)
   195  		}
   196  
   197  		if len(pointers) == 0 {
   198  			// This function does not need to keep track of stack pointers.
   199  			continue
   200  		}
   201  
   202  		// Determine the type of the required stack slot.
   203  		fields := []llvm.Type{
   204  			stackChainStartType, // Pointer to parent frame.
   205  			uintptrType,         // Number of elements in this frame.
   206  		}
   207  		for _, ptr := range pointers {
   208  			fields = append(fields, ptr.Type())
   209  		}
   210  		stackObjectType := ctx.StructType(fields, false)
   211  
   212  		// Create the stack object at the function entry.
   213  		builder.SetInsertPointBefore(fn.EntryBasicBlock().FirstInstruction())
   214  		stackObject := builder.CreateAlloca(stackObjectType, "gc.stackobject")
   215  		initialStackObject := llvm.ConstNull(stackObjectType)
   216  		numSlots := (targetData.TypeAllocSize(stackObjectType) - uint64(targetData.PointerSize())*2) / uint64(targetData.ABITypeAlignment(uintptrType))
   217  		numSlotsValue := llvm.ConstInt(uintptrType, numSlots, false)
   218  		initialStackObject = builder.CreateInsertValue(initialStackObject, numSlotsValue, 1, "")
   219  		builder.CreateStore(initialStackObject, stackObject)
   220  
   221  		// Update stack start.
   222  		parent := builder.CreateLoad(stackChainStartType, stackChainStart, "")
   223  		gep := builder.CreateGEP(stackObjectType, stackObject, []llvm.Value{
   224  			llvm.ConstInt(ctx.Int32Type(), 0, false),
   225  			llvm.ConstInt(ctx.Int32Type(), 0, false),
   226  		}, "")
   227  		builder.CreateStore(parent, gep)
   228  		builder.CreateStore(stackObject, stackChainStart)
   229  
   230  		// Do a store to the stack object after each new pointer that is created.
   231  		pointerStores := make(map[llvm.Value]struct{})
   232  		for i, ptr := range pointers {
   233  			// Insert the store after the pointer value is created.
   234  			insertionPoint := llvm.NextInstruction(ptr)
   235  			for !insertionPoint.IsAPHINode().IsNil() {
   236  				// PHI nodes are required to be at the start of the block.
   237  				// Insert after the last PHI node.
   238  				insertionPoint = llvm.NextInstruction(insertionPoint)
   239  			}
   240  			builder.SetInsertPointBefore(insertionPoint)
   241  
   242  			// Extract a pointer to the appropriate section of the stack object.
   243  			gep := builder.CreateGEP(stackObjectType, stackObject, []llvm.Value{
   244  				llvm.ConstInt(ctx.Int32Type(), 0, false),
   245  				llvm.ConstInt(ctx.Int32Type(), uint64(2+i), false),
   246  			}, "")
   247  
   248  			// Store the pointer into the stack slot.
   249  			store := builder.CreateStore(ptr, gep)
   250  			pointerStores[store] = struct{}{}
   251  		}
   252  
   253  		// Make sure this stack object is popped from the linked list of stack
   254  		// objects at return.
   255  		for _, ret := range returns {
   256  			// Check for any tail calls at this return.
   257  			prev := llvm.PrevInstruction(ret)
   258  			if !prev.IsNil() && !prev.IsABitCastInst().IsNil() {
   259  				// A bitcast can appear before a tail call, so skip backwards more.
   260  				prev = llvm.PrevInstruction(prev)
   261  			}
   262  			if !prev.IsNil() && !prev.IsACallInst().IsNil() {
   263  				// This is no longer a tail call.
   264  				prev.SetTailCall(false)
   265  			}
   266  			builder.SetInsertPointBefore(ret)
   267  			builder.CreateStore(parent, stackChainStart)
   268  		}
   269  	}
   270  
   271  	return true
   272  }
   273  
   274  // markParentFunctions traverses all parent function calls (recursively) and
   275  // adds them to the set of marked functions. It only considers function calls:
   276  // any other uses of such a function is ignored.
   277  func markParentFunctions(marked map[llvm.Value]struct{}, fn llvm.Value) {
   278  	worklist := []llvm.Value{fn}
   279  	for len(worklist) != 0 {
   280  		fn := worklist[len(worklist)-1]
   281  		worklist = worklist[:len(worklist)-1]
   282  		for _, use := range getUses(fn) {
   283  			if use.IsACallInst().IsNil() || use.CalledValue() != fn {
   284  				// Not the parent function.
   285  				continue
   286  			}
   287  			parent := use.InstructionParent().Parent()
   288  			if _, ok := marked[parent]; !ok {
   289  				marked[parent] = struct{}{}
   290  				worklist = append(worklist, parent)
   291  			}
   292  		}
   293  	}
   294  }