github.com/arnodel/golua@v0.0.0-20230215163904-e0b5347eaaa1/astcomp/compstat.go (about)

     1  package astcomp
     2  
     3  import (
     4  	"github.com/arnodel/golua/ast"
     5  	"github.com/arnodel/golua/ir"
     6  	"github.com/arnodel/golua/ops"
     7  )
     8  
     9  //
    10  // Statement compilation
    11  //
    12  
// Static check that no statement is overlooked: this declaration only
// compiles if *compiler implements every method of ast.StatProcessor, so a
// missing ProcessXxx method is reported at build time.
var _ ast.StatProcessor = (*compiler)(nil)
    15  
    16  // ProcessAssignStat compiles a AssignStat.
    17  func (c *compiler) ProcessAssignStat(s ast.AssignStat) {
    18  
    19  	// Evaluate the right hand side
    20  	resultRegs := make([]ir.Register, len(s.Dest))
    21  	c.compileExpList(s.Src, resultRegs)
    22  
    23  	// Compile the lvalues and assignments
    24  	c.compileAssignments(s.Dest, resultRegs)
    25  }
    26  
    27  // ProcessBlockStat compiles a BlockStat.
    28  func (c *compiler) ProcessBlockStat(s ast.BlockStat) {
    29  	c.PushContext()
    30  	c.compileBlock(s)
    31  	c.PopContext()
    32  }
    33  
// ProcessBreakStat compiles a BreakStat as a jump to the label named
// breakLblName. Each loop statement in this file (for, for-in, while and
// repeat) declares that label in a context of its own and emits it right
// after the loop. Name lookup presumably resolves to the innermost enclosing
// loop, so the jump leaves that loop.
func (c *compiler) ProcessBreakStat(s ast.BreakStat) {
	c.emitJump(s, breakLblName)
}
    38  
// ProcessEmptyStat compiles an EmptyStat, which produces no instructions.
func (c *compiler) ProcessEmptyStat(s ast.EmptyStat) {
	// Nothing to compile!
}
    43  
// ProcessForInStat compiles a ForInStat (generic for loop). The emitted code
// is roughly equivalent to:
//
//	f, s, var, close = <Params>      -- evaluated into 4 registers
//	loop:
//	  local v1, ..., vn = f(s, var)
//	  if v1 == nil then goto break end
//	  var = v1
//	  <Body>
//	  goto loop
//	break:
//
// where close is registered as a close action of the loop's context.
func (c *compiler) ProcessForInStat(s ast.ForInStat) {
	// compileExpList fills exactly 4 registers; presumably missing values are
	// nil and extra values are dropped — TODO confirm in compileExpList.
	initRegs := make([]ir.Register, 4)
	c.compileExpList(s.Params, initRegs)
	fReg := initRegs[0]
	sReg := initRegs[1]
	varReg := initRegs[2]
	closeReg := initRegs[3]

	c.PushContext()
	c.PushCloseAction(closeReg) // Now closeReg is no longer needed
	// Bind the iterator function, state and control variable to internal
	// names so the synthesized call below can refer to them as ordinary names.
	c.DeclareLocal(loopFRegName, fReg)
	c.DeclareLocal(loopSRegName, sReg)
	c.DeclareLocal(loopVarRegName, varReg)

	loopLbl := c.GetNewLabel()
	must(c.EmitLabelNoLine(loopLbl))

	// Compile "local v1, ..., vn = f(s, var)" by synthesizing a LocalStat AST
	// node, reusing the ordinary local-declaration and call machinery.
	nameAttribs := make([]ast.NameAttrib, len(s.Vars))
	for i, name := range s.Vars {
		nameAttribs[i] = ast.NewNameAttrib(name, nil, ast.NoAttrib)
	}
	c.CompileStat(ast.LocalStat{
		NameAttribs: nameAttribs,
		Values: []ast.ExpNode{ast.FunctionCall{BFunctionCall: &ast.BFunctionCall{
			Location: s.Params[0].Locate(), // To report the line where the function is if it fails
			Target:   ast.Name{Location: s.Location, Val: string(loopFRegName)},
			Args: []ast.ExpNode{
				ast.Name{Location: s.Location, Val: string(loopSRegName)},
				ast.Name{Location: s.Location, Val: string(loopVarRegName)},
			},
		}}},
	})
	// v1 was just declared by the LocalStat above, so the lookup's ok flag is
	// not checked.
	var1, _ := c.GetRegister(ir.Name(s.Vars[0].Val))

	// testReg <- (v1 == nil); leave the loop when it holds.
	testReg := c.GetFreeRegister()
	c.emitLoadConst(s, ir.NilType{}, testReg)
	c.emitInstr(s, ir.Combine{
		Dst:  testReg,
		Op:   ops.OpEq,
		Lsrc: var1,
		Rsrc: testReg,
	})
	endLbl := c.DeclareGotoLabelNoLine(breakLblName)
	c.emitInstr(s, ir.JumpIf{Cond: testReg, Label: endLbl})
	// var <- v1, copied before the body runs so that assignments to v1 in the
	// body do not change the control value passed to f on the next iteration.
	c.emitInstr(s, ir.Transform{Dst: varReg, Op: ops.OpId, Src: var1})
	c.compileBlock(s.Body)

	c.emitInstr(s, ir.Jump{Label: loopLbl})

	// break: target of the nil test above and of any break in the body.
	must(c.EmitGotoLabel(breakLblName))
	c.PopContext()

}
    98  
// ProcessForStat compiles a numeric ForStat. The emitted code has the shape
//
//	start, stop, step = <Start>, <Stop>, <Step>
//	PrepForLoop(start, stop, step)
//	loop:
//	  if not start then goto break end
//	  local <Var> = start          -- fresh copy for the body
//	  <Body>
//	  AdvForLoop(start, stop, step)
//	  if start then goto loop end
//	break:
//
// The code relies on PrepForLoop and AdvForLoop leaving start nil when there
// are no (more) iterations to run. The three control registers are taken for
// the duration of the loop and released at the end.
func (c *compiler) ProcessForStat(s ast.ForStat) {

	// Get register for current value of i and initialise it
	startReg := c.GetFreeRegister()
	r := c.compileExp(s.Start, startReg)
	ir.EmitMoveNoLine(c.CodeBuilder, startReg, r)
	c.TakeRegister(startReg)

	// Get register for the stop value and initialise it
	stopReg := c.GetFreeRegister()
	r = c.compileExp(s.Stop, stopReg)
	ir.EmitMoveNoLine(c.CodeBuilder, stopReg, r)
	c.TakeRegister(stopReg)

	// Get register for the step value and initialise it
	stepReg := c.GetFreeRegister()
	r = c.compileExp(s.Step, stepReg)
	ir.EmitMoveNoLine(c.CodeBuilder, stepReg, r)
	c.TakeRegister(stepReg)

	// Prepare the for loop
	c.emitInstr(s, ir.PrepForLoop{
		Start: startReg,
		Stop:  stopReg,
		Step:  stepReg,
	})

	c.PushContext()
	// NOTE(review): the back edge below only jumps to loopLbl when startReg is
	// non-nil, so re-running the nil test on every iteration looks redundant;
	// emitting loopLbl after the test would save one instruction per iteration.
	// Confirm nothing depends on the current layout before changing it.
	loopLbl := c.GetNewLabel()
	must(c.EmitLabelNoLine(loopLbl))
	endLbl := c.DeclareGotoLabelNoLine(breakLblName)

	// If startReg is nil, then there are no iterations in the loop
	c.EmitNoLine(ir.JumpIf{
		Cond:  startReg,
		Label: endLbl,
		Not:   true,
	})

	// Here compile the loop body
	c.PushContext()
	iterReg := c.GetFreeRegister()
	// We copy the loop variable because the body may change it
	// iter <- start
	ir.EmitMoveNoLine(c.CodeBuilder, iterReg, startReg)
	c.DeclareLocal(ir.Name(s.Var.Val), iterReg)
	c.compileBlock(s.Body)
	c.PopContext()

	// Advance the for loop
	c.emitInstr(s, ir.AdvForLoop{
		Start: startReg,
		Stop:  stopReg,
		Step:  stepReg,
	})
	// If startReg is not nil, it means the loop continues
	c.EmitNoLine(ir.JumpIf{
		Cond:  startReg,
		Label: loopLbl,
	})

	// break: target of the initial nil test and of any break in the body.
	must(c.EmitGotoLabel(breakLblName))
	c.PopContext()

	// The control registers were taken above; give them back now that the
	// loop is finished.
	c.ReleaseRegister(startReg)
	c.ReleaseRegister(stopReg)
	c.ReleaseRegister(stepReg)
}
   169  
   170  // ProcessFunctionCallStat compiles a FunctionCallStat.
   171  func (c *compiler) ProcessFunctionCallStat(f ast.FunctionCall) {
   172  	c.compileCall(*f.BFunctionCall, false)
   173  	c.emitInstr(f, ir.Receive{})
   174  }
   175  
   176  // ProcessGotoStat compiles a GotoStat.
   177  func (c *compiler) ProcessGotoStat(s ast.GotoStat) {
   178  	c.emitJump(s, ir.Name(s.Label.Val))
   179  }
   180  
   181  // ProcessIfStat compiles a IfStat.
   182  func (c *compiler) ProcessIfStat(s ast.IfStat) {
   183  	endLbl := c.GetNewLabel()
   184  	lbl := c.GetNewLabel()
   185  	c.compileCond(s.If, lbl)
   186  	for _, s := range s.ElseIfs {
   187  		c.emitInstr(s.Cond, ir.Jump{Label: endLbl}) // TODO: better location
   188  		must(c.EmitLabelNoLine(lbl))
   189  		lbl = c.GetNewLabel()
   190  		c.compileCond(s, lbl)
   191  	}
   192  	if s.Else != nil {
   193  		c.emitInstr(s, ir.Jump{Label: endLbl}) // TODO: better location
   194  		must(c.EmitLabelNoLine(lbl))
   195  		c.CompileStat(s.Else)
   196  	} else {
   197  		must(c.EmitLabelNoLine(lbl))
   198  	}
   199  	must(c.EmitLabelNoLine(endLbl))
   200  }
   201  
   202  func (c *compiler) compileCond(s ast.CondStat, lbl ir.Label) {
   203  	condReg := c.compileExpNoDestHint(s.Cond)
   204  	c.emitInstr(s.Cond, ir.JumpIf{Cond: condReg, Label: lbl, Not: true})
   205  	c.CompileStat(s.Body)
   206  }
   207  
   208  // ProcessLabelStat compiles a LabelStat.
   209  func (c *compiler) ProcessLabelStat(s ast.LabelStat) {
   210  	if err := c.EmitGotoLabel(ir.Name(s.Name.Val)); err != nil {
   211  		panic(Error{
   212  			Where:   s,
   213  			Message: err.Error(),
   214  		})
   215  	}
   216  }
   217  
   218  // ProcessLocalFunctionStat compiles a LocalFunctionStat.
   219  func (c *compiler) ProcessLocalFunctionStat(s ast.LocalFunctionStat) {
   220  	fReg := c.GetFreeRegister()
   221  	c.DeclareLocal(ir.Name(s.Name.Val), fReg)
   222  	c.compileExpInto(s.Function, fReg)
   223  }
   224  
   225  // ProcessLocalStat compiles a LocalStat.
   226  func (c *compiler) ProcessLocalStat(s ast.LocalStat) {
   227  	localRegs := make([]ir.Register, len(s.NameAttribs))
   228  	c.compileExpList(s.Values, localRegs)
   229  	for i, reg := range localRegs {
   230  		c.ReleaseRegister(reg)
   231  		c.DeclareLocal(ir.Name(s.NameAttribs[i].Name.Val), reg)
   232  		switch s.NameAttribs[i].Attrib {
   233  		case ast.NoAttrib:
   234  			// Nothing to do
   235  		case ast.ConstAttrib:
   236  			c.MarkConstantReg(reg)
   237  		case ast.CloseAttrib:
   238  			c.MarkConstantReg(reg)
   239  			c.PushCloseAction(reg)
   240  		default:
   241  			panic(compilerBug{})
   242  		}
   243  	}
   244  }
   245  
// ProcessRepeatStat compiles a RepeatStat. The emitted code has the shape
//
//	loop:
//	  <Body>
//	  if not <Cond> then goto loop end
//	break:
//
// The body is compiled with compileBlockNoPop so that its locals are still in
// scope while the condition is compiled. The contexts they opened are popped
// only after the condition has been evaluated.
func (c *compiler) ProcessRepeatStat(s ast.RepeatStat) {
	c.PushContext()
	c.DeclareGotoLabelNoLine(breakLblName)

	loopLbl := c.GetNewLabel()
	must(c.EmitLabelNoLine(loopLbl))
	// complete=false: no trailing "back labels" are hoisted, because the
	// condition below still belongs to the body's scope.
	pop := c.compileBlockNoPop(s.Body, false)
	condReg := c.compileExpNoDestHint(s.Cond)
	// Compute "not cond" into its own register before pop(). This is
	// presumably because condReg may belong to a body local whose context
	// ends at pop() — TODO confirm.
	negReg := c.GetFreeRegister()
	c.emitInstr(s.Cond, ir.Transform{
		Op:  ops.OpNot,
		Dst: negReg,
		Src: condReg,
	})
	pop()
	c.emitInstr(s.Cond, ir.JumpIf{
		Cond:  negReg,
		Label: loopLbl,
	})

	// break: target of any break in the body.
	must(c.EmitGotoLabel(breakLblName))
	c.PopContext()
}
   270  
   271  // ProcessWhileStat compiles a WhileStat.
   272  func (c *compiler) ProcessWhileStat(s ast.WhileStat) {
   273  	c.PushContext()
   274  	stopLbl := c.DeclareGotoLabelNoLine(breakLblName)
   275  
   276  	loopLbl := c.GetNewLabel()
   277  	must(c.EmitLabelNoLine(loopLbl))
   278  
   279  	c.compileCond(s.CondStat, stopLbl)
   280  
   281  	c.emitInstr(s, ir.Jump{Label: loopLbl}) // TODO: better location
   282  
   283  	must(c.EmitGotoLabel(breakLblName))
   284  	c.PopContext()
   285  }
   286  
// CompileStat compiles a single statement. It lets the statement dispatch
// itself, via s.ProcessStat(c), to the matching ProcessXxx method of c.
func (c *compiler) CompileStat(s ast.Stat) {
	s.ProcessStat(c)
}
   290  
   291  //
   292  // Helper functions
   293  //
   294  
   295  func (c *compiler) compileBlock(s ast.BlockStat) {
   296  	c.compileBlockNoPop(s, true)()
   297  }
   298  
// compileBlockNoPop compiles the statements of s, followed by its return
// statement if it has one. It returns a function that pops the contexts it
// pushed. Callers use it instead of compileBlock when they must emit more code
// while the block's locals are still in scope; ProcessRepeatStat does this
// for the "until" condition.
//
// Every LocalStat or LocalFunctionStat pushes a new context, which stays open
// until the returned function is called. Goto labels are declared before the
// statements are compiled, so gotos can refer to labels that appear later.
// Labels that follow a local declaration are declared in that declaration's
// context. This is presumably so that a goto placed before the declaration
// cannot see them, since reaching them would jump into the local's scope.
//
// Trailing labels are hoisted when all of the following hold:
//   - complete is true;
//   - the block contains a local declaration;
//   - the block has no return.
//
// In that case the trailing run of labels and empty statements is declared
// up front in the current (outermost) context by getBackLabels, and later
// getLabels calls skip it. This appears to implement Lua 5.4's rule that a
// local's scope ends at the last non-void statement of its block, which makes
// such labels reachable from anywhere in the block. ProcessRepeatStat passes
// complete=false.
func (c *compiler) compileBlockNoPop(s ast.BlockStat, complete bool) func() {
	// Number of contexts pushed for local declarations, popped by the
	// returned function.
	totalDepth := 0
	// Declare the labels preceding the first local declaration (all of them
	// if there is none, in which case noBackLabels is true).
	noBackLabels := getLabels(c.CodeBuilder, s.Stats)
	// Statements at index >= truncLen are trailing labels/empty statements
	// already declared by getBackLabels.
	truncLen := len(s.Stats)
	if complete && !noBackLabels && s.Return == nil {
		truncLen -= getBackLabels(c.CodeBuilder, s.Stats)
	}
	for i, stat := range s.Stats {
		switch stat.(type) {
		case ast.LocalStat, ast.LocalFunctionStat:
			// The local opens a new scope: push a context for it and declare
			// the labels up to the next local declaration inside it.
			totalDepth++
			c.PushContext()
			getLabels(c.CodeBuilder, s.Stats[i+1:truncLen])
		}
		c.CompileStat(stat)
	}
	if s.Return != nil {
		if fc, ok := c.getTailCall(s.Return); ok {
			// "return f(...)" with nothing to close: compile as a tail call.
			c.compileCall(*fc.BFunctionCall, true)
		} else {
			// Push the returned values to the caller's continuation register
			// and transfer control to it with a tail Call.
			contReg := c.getCallerReg()
			c.compilePushArgs(s.Return, contReg)
			var loc ast.Locator
			if len(s.Return) > 0 {
				loc = s.Return[0]
			}
			c.emitInstr(loc, ir.Call{
				Cont: contReg,
				Tail: true,
			})
		}
	}
	// The closure counts the captured totalDepth down to zero, so calling it
	// a second time pops nothing.
	return func() {
		for ; totalDepth > 0; totalDepth-- {
			c.PopContext()
		}
	}
}
   337  
   338  // Declares goto labels for the statements in order, stopping when encountering
   339  // a local variable declaration.  Return true if the whole slice was processed
   340  // (so no need to get back labels)
   341  func getLabels(c *ir.CodeBuilder, statements []ast.Stat) bool {
   342  	for _, stat := range statements {
   343  		switch s := stat.(type) {
   344  		case ast.LabelStat:
   345  			_, err := c.DeclareUniqueGotoLabel(ir.Name(s.Name.Val), s.Name.StartPos().Line)
   346  			if err != nil {
   347  				panic(Error{
   348  					Where:   s.Name,
   349  					Message: err.Error(),
   350  				})
   351  			}
   352  		case ast.LocalStat, ast.LocalFunctionStat:
   353  			return false
   354  		}
   355  	}
   356  	return true
   357  }
   358  
   359  // Process the statements in reverse order to declare "back labels".  Return the
   360  // number of statements processed.
   361  func getBackLabels(c *ir.CodeBuilder, statements []ast.Stat) int {
   362  	count := 0
   363  	for i := len(statements) - 1; i >= 0; i-- {
   364  		switch s := statements[i].(type) {
   365  		case ast.EmptyStat:
   366  			// That doesn't count
   367  		case ast.LabelStat:
   368  			_, err := c.DeclareUniqueGotoLabel(ir.Name(s.Name.Val), s.Name.StartPos().Line)
   369  			if err != nil {
   370  				panic(Error{
   371  					Where:   s.Name,
   372  					Message: err.Error(),
   373  				})
   374  			}
   375  		default:
   376  			return count
   377  		}
   378  		count++
   379  	}
   380  	return count
   381  }
   382  
   383  func (c *compiler) getTailCall(rtn []ast.ExpNode) (ast.FunctionCall, bool) {
   384  	if len(rtn) != 1 || c.HasPendingCloseActions() {
   385  		return ast.FunctionCall{}, false
   386  	}
   387  	fc, ok := rtn[0].(ast.FunctionCall)
   388  	return fc, ok
   389  }