github.com/aykevl/tinygo@v0.5.0/compiler/defer.go (about)

     1  package compiler
     2  
     3  // This file implements the 'defer' keyword in Go.
     4  // Defer statements are implemented by transforming the function in the
     5  // following way:
     6  //   * Creating an alloca in the entry block that contains a pointer (initially
     7  //     null) to the linked list of defer frames.
     8  //   * Every time a defer statement is executed, a new defer frame is created
     9  //     using alloca with a pointer to the previous defer frame, and the head
    10  //     pointer in the entry block is replaced with a pointer to this defer
    11  //     frame.
//   * On return, emitRunDefers inserts a loop that pops defer frames from the
//     head of the linked list and calls each deferred function, until the
//     list is empty.
    15  
    16  import (
    17  	"github.com/tinygo-org/tinygo/ir"
    18  	"golang.org/x/tools/go/ssa"
    19  	"tinygo.org/x/go-llvm"
    20  )
    21  
    22  // deferInitFunc sets up this function for future deferred calls. It must be
    23  // called from within the entry block when this function contains deferred
    24  // calls.
    25  func (c *Compiler) deferInitFunc(frame *Frame) {
    26  	// Some setup.
    27  	frame.deferFuncs = make(map[*ir.Function]int)
    28  	frame.deferInvokeFuncs = make(map[string]int)
    29  	frame.deferClosureFuncs = make(map[*ir.Function]int)
    30  
    31  	// Create defer list pointer.
    32  	deferType := llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)
    33  	frame.deferPtr = c.builder.CreateAlloca(deferType, "deferPtr")
    34  	c.builder.CreateStore(llvm.ConstPointerNull(deferType), frame.deferPtr)
    35  }
    36  
    37  // emitDefer emits a single defer instruction, to be run when this function
    38  // returns.
    39  func (c *Compiler) emitDefer(frame *Frame, instr *ssa.Defer) error {
    40  	// The pointer to the previous defer struct, which we will replace to
    41  	// make a linked list.
    42  	next := c.builder.CreateLoad(frame.deferPtr, "defer.next")
    43  
    44  	var values []llvm.Value
    45  	valueTypes := []llvm.Type{c.uintptrType, next.Type()}
    46  	if instr.Call.IsInvoke() {
    47  		// Method call on an interface.
    48  
    49  		// Get callback type number.
    50  		methodName := instr.Call.Method.FullName()
    51  		if _, ok := frame.deferInvokeFuncs[methodName]; !ok {
    52  			frame.deferInvokeFuncs[methodName] = len(frame.allDeferFuncs)
    53  			frame.allDeferFuncs = append(frame.allDeferFuncs, &instr.Call)
    54  		}
    55  		callback := llvm.ConstInt(c.uintptrType, uint64(frame.deferInvokeFuncs[methodName]), false)
    56  
    57  		// Collect all values to be put in the struct (starting with
    58  		// runtime._defer fields, followed by the call parameters).
    59  		itf, err := c.parseExpr(frame, instr.Call.Value) // interface
    60  		if err != nil {
    61  			return err
    62  		}
    63  		receiverValue := c.builder.CreateExtractValue(itf, 1, "invoke.func.receiver")
    64  		values = []llvm.Value{callback, next, receiverValue}
    65  		valueTypes = append(valueTypes, c.i8ptrType)
    66  		for _, arg := range instr.Call.Args {
    67  			val, err := c.parseExpr(frame, arg)
    68  			if err != nil {
    69  				return err
    70  			}
    71  			values = append(values, val)
    72  			valueTypes = append(valueTypes, val.Type())
    73  		}
    74  
    75  	} else if callee, ok := instr.Call.Value.(*ssa.Function); ok {
    76  		// Regular function call.
    77  		fn := c.ir.GetFunction(callee)
    78  
    79  		if _, ok := frame.deferFuncs[fn]; !ok {
    80  			frame.deferFuncs[fn] = len(frame.allDeferFuncs)
    81  			frame.allDeferFuncs = append(frame.allDeferFuncs, fn)
    82  		}
    83  		callback := llvm.ConstInt(c.uintptrType, uint64(frame.deferFuncs[fn]), false)
    84  
    85  		// Collect all values to be put in the struct (starting with
    86  		// runtime._defer fields).
    87  		values = []llvm.Value{callback, next}
    88  		for _, param := range instr.Call.Args {
    89  			llvmParam, err := c.parseExpr(frame, param)
    90  			if err != nil {
    91  				return err
    92  			}
    93  			values = append(values, llvmParam)
    94  			valueTypes = append(valueTypes, llvmParam.Type())
    95  		}
    96  
    97  	} else if makeClosure, ok := instr.Call.Value.(*ssa.MakeClosure); ok {
    98  		// Immediately applied function literal with free variables.
    99  
   100  		// Extract the context from the closure. We won't need the function
   101  		// pointer.
   102  		// TODO: ignore this closure entirely and put pointers to the free
   103  		// variables directly in the defer struct, avoiding a memory allocation.
   104  		closure, err := c.parseExpr(frame, instr.Call.Value)
   105  		if err != nil {
   106  			return err
   107  		}
   108  		context := c.builder.CreateExtractValue(closure, 0, "")
   109  
   110  		// Get the callback number.
   111  		fn := c.ir.GetFunction(makeClosure.Fn.(*ssa.Function))
   112  		if _, ok := frame.deferClosureFuncs[fn]; !ok {
   113  			frame.deferClosureFuncs[fn] = len(frame.allDeferFuncs)
   114  			frame.allDeferFuncs = append(frame.allDeferFuncs, makeClosure)
   115  		}
   116  		callback := llvm.ConstInt(c.uintptrType, uint64(frame.deferClosureFuncs[fn]), false)
   117  
   118  		// Collect all values to be put in the struct (starting with
   119  		// runtime._defer fields, followed by all parameters including the
   120  		// context pointer).
   121  		values = []llvm.Value{callback, next}
   122  		for _, param := range instr.Call.Args {
   123  			llvmParam, err := c.parseExpr(frame, param)
   124  			if err != nil {
   125  				return err
   126  			}
   127  			values = append(values, llvmParam)
   128  			valueTypes = append(valueTypes, llvmParam.Type())
   129  		}
   130  		values = append(values, context)
   131  		valueTypes = append(valueTypes, context.Type())
   132  
   133  	} else {
   134  		return c.makeError(instr.Pos(), "todo: defer on uncommon function call type")
   135  	}
   136  
   137  	// Make a struct out of the collected values to put in the defer frame.
   138  	deferFrameType := c.ctx.StructType(valueTypes, false)
   139  	deferFrame, err := c.getZeroValue(deferFrameType)
   140  	if err != nil {
   141  		return err
   142  	}
   143  	for i, value := range values {
   144  		deferFrame = c.builder.CreateInsertValue(deferFrame, value, i, "")
   145  	}
   146  
   147  	// Put this struct in an alloca.
   148  	alloca := c.builder.CreateAlloca(deferFrameType, "defer.alloca")
   149  	c.builder.CreateStore(deferFrame, alloca)
   150  
   151  	// Push it on top of the linked list by replacing deferPtr.
   152  	allocaCast := c.builder.CreateBitCast(alloca, next.Type(), "defer.alloca.cast")
   153  	c.builder.CreateStore(allocaCast, frame.deferPtr)
   154  	return nil
   155  }
   156  
   157  // emitRunDefers emits code to run all deferred functions.
   158  func (c *Compiler) emitRunDefers(frame *Frame) error {
   159  	// Add a loop like the following:
   160  	//     for stack != nil {
   161  	//         _stack := stack
   162  	//         stack = stack.next
   163  	//         switch _stack.callback {
   164  	//         case 0:
   165  	//             // run first deferred call
   166  	//         case 1:
   167  	//             // run second deferred call
   168  	//             // etc.
   169  	//         default:
   170  	//             unreachable
   171  	//         }
   172  	//     }
   173  
   174  	// Create loop.
   175  	loophead := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.loophead")
   176  	loop := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.loop")
   177  	unreachable := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.default")
   178  	end := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.end")
   179  	c.builder.CreateBr(loophead)
   180  
   181  	// Create loop head:
   182  	//     for stack != nil {
   183  	c.builder.SetInsertPointAtEnd(loophead)
   184  	deferData := c.builder.CreateLoad(frame.deferPtr, "")
   185  	stackIsNil := c.builder.CreateICmp(llvm.IntEQ, deferData, llvm.ConstPointerNull(deferData.Type()), "stackIsNil")
   186  	c.builder.CreateCondBr(stackIsNil, end, loop)
   187  
   188  	// Create loop body:
   189  	//     _stack := stack
   190  	//     stack = stack.next
   191  	//     switch stack.callback {
   192  	c.builder.SetInsertPointAtEnd(loop)
   193  	nextStackGEP := c.builder.CreateGEP(deferData, []llvm.Value{
   194  		llvm.ConstInt(c.ctx.Int32Type(), 0, false),
   195  		llvm.ConstInt(c.ctx.Int32Type(), 1, false), // .next field
   196  	}, "stack.next.gep")
   197  	nextStack := c.builder.CreateLoad(nextStackGEP, "stack.next")
   198  	c.builder.CreateStore(nextStack, frame.deferPtr)
   199  	gep := c.builder.CreateGEP(deferData, []llvm.Value{
   200  		llvm.ConstInt(c.ctx.Int32Type(), 0, false),
   201  		llvm.ConstInt(c.ctx.Int32Type(), 0, false), // .callback field
   202  	}, "callback.gep")
   203  	callback := c.builder.CreateLoad(gep, "callback")
   204  	sw := c.builder.CreateSwitch(callback, unreachable, len(frame.allDeferFuncs))
   205  
   206  	for i, callback := range frame.allDeferFuncs {
   207  		// Create switch case, for example:
   208  		//     case 0:
   209  		//         // run first deferred call
   210  		block := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.callback")
   211  		sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(i), false), block)
   212  		c.builder.SetInsertPointAtEnd(block)
   213  		switch callback := callback.(type) {
   214  		case *ssa.CallCommon:
   215  			// Call on an interface value.
   216  			if !callback.IsInvoke() {
   217  				panic("expected an invoke call, not a direct call")
   218  			}
   219  
   220  			// Get the real defer struct type and cast to it.
   221  			valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0), c.i8ptrType}
   222  			for _, arg := range callback.Args {
   223  				llvmType, err := c.getLLVMType(arg.Type())
   224  				if err != nil {
   225  					return err
   226  				}
   227  				valueTypes = append(valueTypes, llvmType)
   228  			}
   229  			deferFrameType := c.ctx.StructType(valueTypes, false)
   230  			deferFramePtr := c.builder.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame")
   231  
   232  			// Extract the params from the struct (including receiver).
   233  			forwardParams := []llvm.Value{}
   234  			zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false)
   235  			for i := 2; i < len(valueTypes); i++ {
   236  				gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i), false)}, "gep")
   237  				forwardParam := c.builder.CreateLoad(gep, "param")
   238  				forwardParams = append(forwardParams, forwardParam)
   239  			}
   240  
   241  			// Add the context parameter. An interface call cannot also be a
   242  			// closure but we have to supply the parameter anyway for platforms
   243  			// with a strict calling convention.
   244  			forwardParams = append(forwardParams, llvm.Undef(c.i8ptrType))
   245  
   246  			// Parent coroutine handle.
   247  			forwardParams = append(forwardParams, llvm.Undef(c.i8ptrType))
   248  
   249  			fnPtr, _, err := c.getInvokeCall(frame, callback)
   250  			if err != nil {
   251  				return err
   252  			}
   253  			c.createCall(fnPtr, forwardParams, "")
   254  
   255  		case *ir.Function:
   256  			// Direct call.
   257  
   258  			// Get the real defer struct type and cast to it.
   259  			valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)}
   260  			for _, param := range callback.Params {
   261  				llvmType, err := c.getLLVMType(param.Type())
   262  				if err != nil {
   263  					return err
   264  				}
   265  				valueTypes = append(valueTypes, llvmType)
   266  			}
   267  			deferFrameType := c.ctx.StructType(valueTypes, false)
   268  			deferFramePtr := c.builder.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame")
   269  
   270  			// Extract the params from the struct.
   271  			forwardParams := []llvm.Value{}
   272  			zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false)
   273  			for i := range callback.Params {
   274  				gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i+2), false)}, "gep")
   275  				forwardParam := c.builder.CreateLoad(gep, "param")
   276  				forwardParams = append(forwardParams, forwardParam)
   277  			}
   278  
   279  			// Add the context parameter. We know it is ignored by the receiving
   280  			// function, but we have to pass one anyway.
   281  			forwardParams = append(forwardParams, llvm.Undef(c.i8ptrType))
   282  
   283  			// Parent coroutine handle.
   284  			forwardParams = append(forwardParams, llvm.Undef(c.i8ptrType))
   285  
   286  			// Call real function.
   287  			c.createCall(callback.LLVMFn, forwardParams, "")
   288  
   289  		case *ssa.MakeClosure:
   290  			// Get the real defer struct type and cast to it.
   291  			fn := c.ir.GetFunction(callback.Fn.(*ssa.Function))
   292  			valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)}
   293  			params := fn.Signature.Params()
   294  			for i := 0; i < params.Len(); i++ {
   295  				llvmType, err := c.getLLVMType(params.At(i).Type())
   296  				if err != nil {
   297  					return err
   298  				}
   299  				valueTypes = append(valueTypes, llvmType)
   300  			}
   301  			valueTypes = append(valueTypes, c.i8ptrType) // closure
   302  			deferFrameType := c.ctx.StructType(valueTypes, false)
   303  			deferFramePtr := c.builder.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame")
   304  
   305  			// Extract the params from the struct.
   306  			forwardParams := []llvm.Value{}
   307  			zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false)
   308  			for i := 2; i < len(valueTypes); i++ {
   309  				gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i), false)}, "")
   310  				forwardParam := c.builder.CreateLoad(gep, "param")
   311  				forwardParams = append(forwardParams, forwardParam)
   312  			}
   313  
   314  			// Parent coroutine handle.
   315  			forwardParams = append(forwardParams, llvm.Undef(c.i8ptrType))
   316  
   317  			// Call deferred function.
   318  			c.createCall(fn.LLVMFn, forwardParams, "")
   319  
   320  		default:
   321  			panic("unknown deferred function type")
   322  		}
   323  
   324  		// Branch back to the start of the loop.
   325  		c.builder.CreateBr(loophead)
   326  	}
   327  
   328  	// Create default unreachable block:
   329  	//     default:
   330  	//         unreachable
   331  	//     }
   332  	c.builder.SetInsertPointAtEnd(unreachable)
   333  	c.builder.CreateUnreachable()
   334  
   335  	// End of loop.
   336  	c.builder.SetInsertPointAtEnd(end)
   337  	return nil
   338  }