github.com/n1ghtfa1l/go-vnt@v0.6.4-alpha.6/core/wavm/compile.go (about)

     1  // Copyright 2017 The go-interpreter Authors.  All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package compile is used internally by wagon to convert standard structured
     6  // WebAssembly bytecode into an unstructured form suitable for execution by
     7  // its VM.
     8  // The conversion process consists of translating block instruction sequences
     9  // and branch operators (br, br_if, br_table) to absolute jumps to PC values.
    10  // For instance, an instruction sequence like:
    11  //     loop
    12  //       i32.const 1
    13  //       get_local 0
    14  //       i32.add
    15  //       set_local 0
    16  //       get_local 1
    17  //       i32.const 1
    18  //       i32.add
    19  //       tee_local 1
    20  //       get_local 2
    21  //       i32.eq
    22  //       br_if 0
    23  //     end
    24  // Is "compiled" to:
    25  //     i32.const 1
    26  //     i32.add
    27  //     set_local 0
    28  //     get_local 1
    29  //     i32.const 1
    30  //     i32.add
    31  //     tee_local 1
    32  //     get_local 2
    33  //     i32.eq
    34  //     jmpnz <addr> <preserve> <discard>
    35  // Where jmpnz is a jump-if-not-zero operator that takes certain arguments
    36  // plus the jump address as immediates.
    37  // This is in contrast with original WebAssembly bytecode, where the target
    38  // of branch operators are relative block depths instead.
    39  package wavm
    40  
    41  import (
    42  	"bytes"
    43  	"encoding/binary"
    44  
    45  	"github.com/vntchain/go-vnt/core/wavm/gas"
    46  	"github.com/vntchain/go-vnt/core/wavm/utils"
    47  	"github.com/vntchain/go-vnt/log"
    48  	"github.com/vntchain/vnt-wasm/disasm"
    49  	"github.com/vntchain/vnt-wasm/vnt"
    50  	"github.com/vntchain/vnt-wasm/wasm"
    51  	ops "github.com/vntchain/vnt-wasm/wasm/operators"
    52  )
    53  
    54  // A small note on the usage of discard instructions:
    55  // A control operator sequence isn't allowed to access nor modify (pop) operands
    56  // that were pushed outside it. Therefore, each sequence has its own stack
    57  // that may or may not push a value to the original stack, depending on the
    58  // block's signature.
    59  // Instead of creating a new stack every time we enter a control structure,
    60  // we record the current stack height on encountering a control operator.
    61  // After we leave the sequence, the stack height is restored using the discard
    62  // operator. A block with a signature will push a value of that type on the parent
    63  // stack (that is, the stack of the parent block where this block started). The
    64  // OpDiscardPreserveTop operator allows us to preserve this value while
    65  // discarding the remaining ones.
    66  
    67  // Branches are rewritten as
    68  //     <jmp> <addr>
    69  // Where the address is an 8 byte address, initially set to zero. It is
    70  // later "patched" by patchOffset.
    71  
// Internal opcodes emitted by the compiler into the flattened bytecode.
// They occupy code points unused by standard WebAssembly operators at these
// positions in the compiled stream.
var (
	// OpJmp unconditionally jumps to the provided address.
	OpJmp byte = 0x0c
	// OpJmpZ jumps to the given address if the value at the top of the stack is zero.
	OpJmpZ byte = 0x03
	// OpJmpNz jumps to the given address if the value at the top of the
	// stack is not zero. It also discards elements and optionally preserves
	// the topmost value on the stack
	OpJmpNz byte = 0x0d
	// OpDiscard discards a given number of elements from the execution stack.
	OpDiscard byte = 0x0b
	// OpDiscardPreserveTop discards a given number of elements from the
	// execution stack, while preserving the value on the top of the stack.
	OpDiscardPreserveTop byte = 0x05
)
    87  
// block stores the information relevant for a block created by a control operator
// sequence (if...else...end, loop...end, and block...end)
type block struct {
	// the byte offset to which the continuation of the label
	// created by the block operator is located
	// for 'loop', this is the offset of the loop operator itself
	// for 'if', 'else', 'block', this is the 'end' operator
	offset int64

	// Whether this block is created by an 'if' operator
	// in that case, the 'offset' field is set to the byte offset
	// of the else branch, once the else operator is reached.
	ifBlock bool
	// if ... else ... end is compiled to
	// jmpnz <else-addr> ... jmp <end-addr> ... <discard>
	// elseAddrOffset is the byte offset of the else-addr address
	// in the new/compiled byte buffer.
	elseAddrOffset int64

	// Whether this block is created by a 'loop' operator
	// in that case, the 'offset' field is set at the end of the block
	loopBlock bool

	// A list of offsets in the bytecode stream that need to be patched
	// with the correct jump addresses once this block's target is known.
	patchOffsets []int64

	discard      disasm.StackInfo   // Information about the stack created in this block, used while creating Discard instructions
	branchTables []*vnt.BranchTable // All branch tables that were defined in this block.
}
   116  
   117  type Mutable map[uint32]bool
   118  
// CodeBlock tracks one expression stack per block nesting depth, used by
// Compile to reconstruct which instructions produced the operands currently
// on the WebAssembly value stack.
type CodeBlock struct {
	// stack maps a block nesting depth to that block's expression stack.
	stack map[int]*Stack
}
   122  
// Code is a node in an expression tree: Body is the instruction that
// consumed its Children's results (operands are stored in push order).
type Code struct {
	Body     disasm.Instr
	Children []*Code
}
   127  
   128  func (c Code) Recursive() []disasm.Instr {
   129  	if len(c.Children) == 0 {
   130  		return []disasm.Instr{c.Body}
   131  	} else {
   132  		merge := []disasm.Instr{c.Body}
   133  		for _, v := range c.Children {
   134  			child := v.Body
   135  			tmp := merge
   136  			merge = append([]disasm.Instr{child}, tmp...)
   137  		}
   138  		return merge
   139  	}
   140  }
   141  
   142  // func (c Code) String() string {
   143  // 	if len(c.Children) == 0 {
   144  // 		return fmt.Sprintf("( %s )", c.Body)
   145  // 	} else {
   146  // 		merge := fmt.Sprintf("( %s", c.Body)
   147  // 		for _, v := range c.Children {
   148  // 			child := fmt.Sprintf("%s", v.String())
   149  // 			merge = fmt.Sprintf("%s\n%s", merge, child)
   150  // 		}
   151  // 		merge = fmt.Sprintf("%s\n)", merge)
   152  // 		return merge
   153  // 	}
   154  // }
   155  
// Stack is a LIFO stack of expression-tree nodes.
type Stack struct {
	slice []*Code
}
   159  
// Push places b on top of the stack.
func (s *Stack) Push(b *Code) {
	s.slice = append(s.slice, b)
}
   163  
   164  func (s *Stack) Pop() *Code {
   165  	v := s.Top()
   166  	s.slice = s.slice[:len(s.slice)-1]
   167  	return v
   168  }
   169  
   170  func (s *Stack) Top() *Code {
   171  	return s.slice[len(s.slice)-1]
   172  }
   173  
// Len reports the number of elements currently on the stack.
func (s Stack) Len() int {
	return len(s.slice)
}
   177  
   178  func (cb *CodeBlock) buildCode(blockDepth int, n int) *Code {
   179  	code := &Code{}
   180  	stack := cb.stack[blockDepth]
   181  	for i := 0; i < n; i++ {
   182  		pop := stack.Pop()
   183  		tmp := code.Children
   184  		code.Children = append([]*Code{pop}, tmp...)
   185  		//w.code.Children = append(w.code.Children, pop)
   186  	}
   187  
   188  	return code
   189  }
   190  
   191  func CompileModule(module *wasm.Module, chainctx ChainContext, mutable Mutable) ([]vnt.Compiled, error) {
   192  	Compiled := make([]vnt.Compiled, len(module.FunctionIndexSpace))
   193  	for i, fn := range module.FunctionIndexSpace {
   194  		// Skip native methods as they need not be
   195  		// disassembled; simply add them at the end
   196  		// of the `funcs` array as is, as specified
   197  		// in the spec. See the "host functions"
   198  		// section of:
   199  		// https://webassembly.github.io/spec/core/exec/modules.html#allocation
   200  		if fn.IsHost() {
   201  			continue
   202  		}
   203  
   204  		var code []byte
   205  		var table []*vnt.BranchTable
   206  		var maxDepth int
   207  		totalLocalVars := 0
   208  
   209  		disassembly, err := disasm.Disassemble(fn, module)
   210  		if err != nil {
   211  			return nil, err
   212  		}
   213  
   214  		maxDepth = disassembly.MaxDepth
   215  
   216  		totalLocalVars += len(fn.Sig.ParamTypes)
   217  		for _, entry := range fn.Body.Locals {
   218  			totalLocalVars += int(entry.Count)
   219  		}
   220  		disassembly.Code = gas.InjectCounter(disassembly.Code, module, chainctx.GasRule)
   221  		code, table = Compile(disassembly.Code, module, mutable)
   222  		Compiled[i] = vnt.Compiled{
   223  			Code:           code,
   224  			Table:          table,
   225  			MaxDepth:       maxDepth,
   226  			TotalLocalVars: totalLocalVars,
   227  		}
   228  	}
   229  	return Compiled, nil
   230  }
   231  
   232  // func (cb *CodeBlock) addChild() {
   233  // 	cb.code.Children = append([]code)
   234  // }
   235  
// Compile rewrites WebAssembly bytecode from its disassembly.
// TODO(vibhavp): Add options for optimizing code. Operators like i32.reinterpret/f32
// are no-ops, and can be safely removed.
//
// The lowering performed here:
//   - block/loop/if...end structures become absolute jumps (OpJmp, OpJmpZ,
//     OpJmpNz) plus OpDiscard/OpDiscardPreserveTop stack fixups;
//   - jump targets are first written as zero and patched via patchOffset
//     once the target offset is known;
//   - br_table instructions are collected into the returned branch tables;
//   - loads/stores whose base address is an i32.const are bracketed with a
//     call to the module's read/write helper (readIndex/writeIndex from
//     utils.GetIndex), when that helper exists.
//
// NOTE(review): the mutable parameter is not read in this function body —
// confirm whether it is still required by the signature's callers.
func Compile(disassembly []disasm.Instr, module *wasm.Module, mutable Mutable) ([]byte, []*vnt.BranchTable) {
	buffer := new(bytes.Buffer)
	branchTables := []*vnt.BranchTable{}

	curBlockDepth := -1
	blocks := make(map[int]*block) // maps nesting depths (labels) to blocks

	// Depth -1 stands for the implicit block formed by the function body.
	blocks[-1] = &block{}

	writeIndex, readIndex, _ := utils.GetIndex(module)
	codeBlock := &CodeBlock{stack: map[int]*Stack{}}

	// newInstr mirrors everything written to buffer as disasm.Instr values.
	// NOTE(review): it is appended to throughout but never read or returned.
	newInstr := []disasm.Instr{}

	for _, instr := range disassembly {
		// fmt.Printf("compile instr %+v blockinfo %+v\n", instr, instr.Block)
		// readInstr/writeInstr hold instrumentation emitted before/after the
		// current instruction's bytes (see the load/store cases below).
		var readInstr []disasm.Instr
		var writeInstr []disasm.Instr
		if instr.Unreachable {
			continue
		}
		// Lazily create the expression stack for the current nesting depth.
		if codeBlock.stack[curBlockDepth] == nil {
			codeBlock.stack[curBlockDepth] = &Stack{}
		}
		// Each case keeps the per-depth expression stacks in sync (pushing
		// producers, popping consumers). Control-flow cases also emit the
		// lowered bytecode themselves and `continue`, skipping the default
		// emission at the bottom of the loop.
		switch instr.Op.Code {
		case ops.I32Const, ops.I64Const, ops.F32Const, ops.F64Const:
			// Constants produce one value and consume none.
			codeBlock.stack[curBlockDepth].Push(&Code{Body: instr})
		case ops.I32Add, ops.I32Sub, ops.I32Mul, ops.I32DivS, ops.I32DivU, ops.I32RemS, ops.I32RemU, ops.I32And, ops.I32Or, ops.I32Xor, ops.I32Shl, ops.I32ShrS, ops.I32ShrU, ops.I32Rotl, ops.I32Rotr,
			ops.I32Eq, ops.I32Ne, ops.I32LtS, ops.I32LtU, ops.I32LeS, ops.I32LeU, ops.I32GtS, ops.I32GtU, ops.I32GeS, ops.I32GeU,
			ops.I64Add, ops.I64Sub, ops.I64Mul, ops.I64DivS, ops.I64DivU, ops.I64RemS, ops.I64RemU, ops.I64And, ops.I64Or, ops.I64Xor, ops.I64Shl, ops.I64ShrS, ops.I64ShrU, ops.I64Rotl, ops.I64Rotr,
			ops.I64Eq, ops.I64Ne, ops.I64LtS, ops.I64LtU, ops.I64LeS, ops.I64LeU, ops.I64GtS, ops.I64GtU, ops.I64GeS, ops.I64GeU,
			ops.F32Add, ops.F32Sub, ops.F32Mul, ops.F32Div, ops.F32Min, ops.F32Max, ops.F32Copysign,
			ops.F32Eq, ops.F32Ne, ops.F32Lt, ops.F32Le, ops.F32Gt, ops.F32Ge,
			ops.F64Add, ops.F64Sub, ops.F64Mul, ops.F64Div, ops.F64Min, ops.F64Max, ops.F64Copysign,
			ops.F64Eq, ops.F64Ne, ops.F64Lt, ops.F64Le, ops.F64Gt, ops.F64Ge:
			// Binary operators: pop two operands, push the combined node.
			code := codeBlock.buildCode(curBlockDepth, 2)
			code.Body = instr
			codeBlock.stack[curBlockDepth].Push(code)
		case ops.I32Clz, ops.I32Ctz, ops.I32Popcnt, ops.I32Eqz,
			ops.I64Clz, ops.I64Ctz, ops.I64Popcnt, ops.I64Eqz,
			ops.F32Sqrt, ops.F32Ceil, ops.F32Floor, ops.F32Trunc, ops.F32Nearest, ops.F32Abs, ops.F32Neg,
			ops.F64Sqrt, ops.F64Ceil, ops.F64Floor, ops.F64Trunc, ops.F64Nearest, ops.F64Abs, ops.F64Neg,
			ops.I32WrapI64, ops.I64ExtendUI32, ops.I64ExtendSI32,
			ops.I32TruncUF32, ops.I32TruncUF64, ops.I64TruncUF32, ops.I64TruncUF64,
			ops.I32TruncSF32, ops.I32TruncSF64, ops.I64TruncSF32, ops.I64TruncSF64,
			ops.F32DemoteF64, ops.F64PromoteF32,
			ops.F32ConvertUI32, ops.F32ConvertUI64, ops.F64ConvertUI32, ops.F64ConvertUI64,
			ops.F32ConvertSI32, ops.F32ConvertSI64, ops.F64ConvertSI32, ops.F64ConvertSI64,
			ops.I32ReinterpretF32, ops.I64ReinterpretF64,
			ops.F32ReinterpretI32, ops.F64ReinterpretI64:
			// Unary operators: pop one operand, push the combined node.
			code := codeBlock.buildCode(curBlockDepth, 1)
			code.Body = instr
			codeBlock.stack[curBlockDepth].Push(code)
		case ops.Drop:
			// Pops one value; the node is intentionally not pushed back
			// since drop produces nothing.
			code := codeBlock.buildCode(curBlockDepth, 1)
			code.Body = instr
		case ops.GetLocal, ops.GetGlobal:
			codeBlock.stack[curBlockDepth].Push(&Code{Body: instr})
		case ops.SetLocal, ops.SetGlobal:
			// Consumes one value, produces none.
			code := codeBlock.buildCode(curBlockDepth, 1)
			code.Body = instr
		case ops.TeeLocal:
			// Consumes one value and leaves it on the stack.
			code := codeBlock.buildCode(curBlockDepth, 1)
			code.Body = instr
			codeBlock.stack[curBlockDepth].Push(code)
		case ops.I32Load, ops.I64Load, ops.F32Load, ops.F64Load, ops.I32Load8s, ops.I32Load8u, ops.I32Load16s, ops.I32Load16u, ops.I64Load8s, ops.I64Load8u, ops.I64Load16s, ops.I64Load16u, ops.I64Load32s, ops.I64Load32u:
			// memory_immediate has two fields, the alignment and the offset.
			// The former is simply an optimization hint and can be safely
			// discarded.
			instr.Immediates = []interface{}{instr.Immediates[1].(uint32)}

			// When the load's base address is a constant, emit
			// <base const> <offset const> <call readIndex> before the load
			// itself (see the readInstr emission after this switch).
			arg := codeBlock.stack[curBlockDepth].slice[codeBlock.stack[curBlockDepth].Len()-1]
			if arg.Body.Op.Code == ops.I32Const {
				constBaseInstr := arg.Body
				constOffsetOp, _ := ops.New(ops.I32Const)
				constInstr := disasm.Instr{Op: constOffsetOp, Immediates: []interface{}{int32(instr.Immediates[0].(uint32))}}
				callOp, _ := ops.New(ops.Call)
				callInstr := disasm.Instr{Op: callOp, Immediates: []interface{}{uint32(readIndex)}}
				readInstr = []disasm.Instr{constBaseInstr, constInstr, callInstr}
			}
			code := codeBlock.buildCode(curBlockDepth, 1)
			code.Body = instr
			codeBlock.stack[curBlockDepth].Push(code)
		case ops.I32Store, ops.I64Store, ops.F32Store, ops.F64Store, ops.I32Store8, ops.I32Store16, ops.I64Store8, ops.I64Store16, ops.I64Store32:
			// memory_immediate has two fields, the alignment and the offset.
			// The former is simply an optimization hint and can be safely
			// discarded.
			instr.Immediates = []interface{}{instr.Immediates[1].(uint32)}

			// The base address sits below the stored value (Len()-2). When
			// it is constant, emit a call to writeIndex after the store
			// (see the writeInstr emission at the bottom of the loop).
			arg := codeBlock.stack[curBlockDepth].slice[codeBlock.stack[curBlockDepth].Len()-2]
			if arg.Body.Op.Code == ops.I32Const {
				constBaseInstr := arg.Body
				constOffsetOp, _ := ops.New(ops.I32Const)
				constoffsetInstr := disasm.Instr{Op: constOffsetOp, Immediates: []interface{}{int32(instr.Immediates[0].(uint32))}}
				callOp, _ := ops.New(ops.Call)
				callInstr := disasm.Instr{Op: callOp, Immediates: []interface{}{uint32(writeIndex)}}
				writeInstr = []disasm.Instr{constBaseInstr, constoffsetInstr, callInstr}
			}
			code := codeBlock.buildCode(curBlockDepth, 2)
			code.Body = instr
		case ops.Call, ops.CallIndirect:
			// NOTE(review): for call_indirect the immediate is a type index,
			// yet GetFunction is still called with it first — confirm this
			// cannot index out of range for valid modules.
			index := instr.Immediates[0].(uint32)
			sig := module.GetFunction(int(index)).Sig
			if instr.Op.Code == ops.CallIndirect {
				sig = &module.Types.Entries[int(index)]
			}
			parms := len(sig.ParamTypes)
			returns := len(sig.ReturnTypes)
			code := codeBlock.buildCode(curBlockDepth, parms)
			code.Body = instr
			//codeBlock.stack.Push(codeBlock.code)
			if returns != 0 {
				codeBlock.stack[curBlockDepth].Push(code)
			}
		case ops.If:
			curBlockDepth++
			buffer.WriteByte(OpJmpZ)
			blocks[curBlockDepth] = &block{
				ifBlock:        true,
				elseAddrOffset: int64(buffer.Len()),
			}
			// the address to jump to if the condition for `if` is false
			// (i.e when the value on the top of the stack is 0)
			binary.Write(buffer, binary.LittleEndian, int64(0))

			op, err := ops.New(OpJmpZ)
			if err != nil {
				panic(err)
			}
			ins := disasm.Instr{
				Op:         op,
				Immediates: [](interface{}){},
			}
			ins.Immediates = append(ins.Immediates, int64(0))
			newInstr = append(newInstr, ins)

			// The condition was pushed at the parent depth; pop it there.
			sig := instr.Immediates[0].(wasm.BlockType)
			code := codeBlock.buildCode(curBlockDepth-1, 1)
			if sig != wasm.BlockTypeEmpty {
				code.Body = instr
				codeBlock.stack[curBlockDepth-1].Push(code)
			}
			// else {
			// 	if curBlockDepth == 0 {
			// 		code := &Code{Body: instr}
			// 		codeBlock.stack[curBlockDepth].Push(code)
			// 	} else {
			// 		parentCode := codeBlock.stack[curBlockDepth-1].Top()
			// 		code := &Code{Body: instr}
			// 		parentCode.Children = append(parentCode.Children, code)
			// 	}
			// }
			continue
		case ops.Loop:
			// there is no condition for entering a loop block
			curBlockDepth++
			blocks[curBlockDepth] = &block{
				offset:    int64(buffer.Len()),
				ifBlock:   false,
				loopBlock: true,
				discard:   *instr.NewStack,
			}

			sig := instr.Immediates[0].(wasm.BlockType)
			if sig != wasm.BlockTypeEmpty {
				code := &Code{Body: instr}
				codeBlock.stack[curBlockDepth-1].Push(code)
			}
			// else {
			// 	if curBlockDepth == 0 {
			// 		code := &Code{Body: instr}
			// 		codeBlock.stack[curBlockDepth].Push(code)
			// 	} else {
			// 		parentCode := codeBlock.stack[curBlockDepth-1].Top()
			// 		code := &Code{Body: instr}
			// 		parentCode.Children = append(parentCode.Children, code)
			// 	}
			// }

			continue
		case ops.Block:
			curBlockDepth++
			blocks[curBlockDepth] = &block{
				ifBlock: false,
				discard: *instr.NewStack,
			}

			sig := instr.Immediates[0].(wasm.BlockType)
			if sig != wasm.BlockTypeEmpty {
				code := &Code{Body: instr}
				codeBlock.stack[curBlockDepth-1].Push(code)
			}
			// else {
			// 	if curBlockDepth == 0 {
			// 		code := &Code{Body: instr}
			// 		codeBlock.stack[curBlockDepth].Push(code)
			// 	} else {
			// 		parentCode := codeBlock.stack[curBlockDepth-1].Top()
			// 		code := &Code{Body: instr}
			// 		parentCode.Children = append(parentCode.Children, code)
			// 	}
			// }
			continue
		case ops.Else:
			ifInstr := disassembly[instr.Block.ElseIfIndex] // the corresponding `if` instruction for this else
			if ifInstr.NewStack != nil && ifInstr.NewStack.StackTopDiff != 0 {
				// add code for jumping out of a taken if branch
				if ifInstr.NewStack.PreserveTop {
					buffer.WriteByte(OpDiscardPreserveTop)
				} else {
					buffer.WriteByte(OpDiscard)
				}
				binary.Write(buffer, binary.LittleEndian, ifInstr.NewStack.StackTopDiff)
			}
			// Unconditional jump over the else branch; its target (the
			// block's end) is patched when the matching End is reached.
			buffer.WriteByte(OpJmp)
			ifBlockEndOffset := int64(buffer.Len())
			binary.Write(buffer, binary.LittleEndian, int64(0))

			// The if's false-target now points here, at the else body.
			curOffset := int64(buffer.Len())
			ifBlock := blocks[curBlockDepth]
			code := buffer.Bytes()

			buffer = patchOffset(code, ifBlock.elseAddrOffset, curOffset)
			// this is no longer an if block
			ifBlock.ifBlock = false
			ifBlock.patchOffsets = append(ifBlock.patchOffsets, ifBlockEndOffset)

			op, err := ops.New(OpJmp)
			if err != nil {
				panic(err)
			}
			ins := disasm.Instr{
				Op:         op,
				Immediates: [](interface{}){},
			}
			ins.Immediates = append(ins.Immediates, int64(0))
			newInstr = append(newInstr, ins)

			continue
		case ops.End:
			depth := curBlockDepth
			block := blocks[depth]

			// NOTE(review): unlike the Br case, NewStack is dereferenced
			// here without a nil check — the disassembler presumably always
			// sets it for End; confirm.
			if instr.NewStack.StackTopDiff != 0 {
				// when exiting a block, discard elements to
				// restore stack height.
				var op ops.Op
				var err error
				var ins disasm.Instr
				if instr.NewStack.PreserveTop {
					// this is true when the block has a
					// signature, and therefore pushes
					// a value on to the stack
					buffer.WriteByte(OpDiscardPreserveTop)

					op, err = ops.New(OpDiscardPreserveTop)
					if err != nil {
						panic(err)
					}
					ins = disasm.Instr{
						Op:         op,
						Immediates: [](interface{}){},
					}

				} else {
					buffer.WriteByte(OpDiscard)

					op, err = ops.New(OpDiscard)
					if err != nil {
						panic(err)
					}
					ins = disasm.Instr{
						Op:         op,
						Immediates: [](interface{}){},
					}
				}
				binary.Write(buffer, binary.LittleEndian, instr.NewStack.StackTopDiff)

				ins.Immediates = append(ins.Immediates, instr.NewStack.StackTopDiff)
				newInstr = append(newInstr, ins)
			}

			if !block.loopBlock { // is a normal block
				block.offset = int64(buffer.Len())
				if block.ifBlock {
					code := buffer.Bytes()
					buffer = patchOffset(code, block.elseAddrOffset, int64(block.offset))
				}
			}

			// Resolve every branch recorded against this block.
			for _, offset := range block.patchOffsets {
				code := buffer.Bytes()
				buffer = patchOffset(code, offset, block.offset)
			}

			for _, table := range block.branchTables {
				table.PatchTable(table.BlocksLen-depth-1, int64(block.offset))
			}

			delete(blocks, curBlockDepth)
			curBlockDepth--
			continue
		case ops.Br:
			if instr.NewStack != nil && instr.NewStack.StackTopDiff != 0 {
				var op ops.Op
				var err error
				var ins disasm.Instr
				if instr.NewStack.PreserveTop {
					buffer.WriteByte(OpDiscardPreserveTop)

					op, err = ops.New(OpDiscardPreserveTop)
					if err != nil {
						panic(err)
					}
					ins = disasm.Instr{
						Op:         op,
						Immediates: [](interface{}){},
					}

				} else {
					buffer.WriteByte(OpDiscard)

					op, err = ops.New(OpDiscard)
					if err != nil {
						panic(err)
					}
					ins = disasm.Instr{
						Op:         op,
						Immediates: [](interface{}){},
					}
				}
				binary.Write(buffer, binary.LittleEndian, instr.NewStack.StackTopDiff)

				ins.Immediates = append(ins.Immediates, instr.NewStack.StackTopDiff)
				newInstr = append(newInstr, ins)
			}
			// Relative label -> absolute block; target patched at End.
			buffer.WriteByte(OpJmp)
			label := int(instr.Immediates[0].(uint32))
			block := blocks[curBlockDepth-int(label)]
			block.patchOffsets = append(block.patchOffsets, int64(buffer.Len()))
			// write the jump address
			binary.Write(buffer, binary.LittleEndian, int64(0))

			op, err := ops.New(OpJmp)
			if err != nil {
				panic(err)
			}
			ins := disasm.Instr{
				Op:         op,
				Immediates: [](interface{}){},
			}
			ins.Immediates = append(ins.Immediates, int64(0))
			newInstr = append(newInstr, ins)
			continue
		case ops.BrIf:
			// Layout: OpJmpNz <addr:8> <preserve:1> <discard:8>
			buffer.WriteByte(OpJmpNz)
			label := int(instr.Immediates[0].(uint32))
			block := blocks[curBlockDepth-int(label)]
			block.patchOffsets = append(block.patchOffsets, int64(buffer.Len()))
			// write the jump address
			binary.Write(buffer, binary.LittleEndian, int64(0))

			op, err := ops.New(OpJmpNz)
			if err != nil {
				panic(err)
			}
			ins := disasm.Instr{
				Op:         op,
				Immediates: [](interface{}){},
			}

			var stackTopDiff int64
			// write whether we need to preserve the top
			if instr.NewStack == nil || !instr.NewStack.PreserveTop || instr.NewStack.StackTopDiff == 0 {
				buffer.WriteByte(byte(0))
				ins.Immediates = append(ins.Immediates, false)
			} else {
				stackTopDiff = instr.NewStack.StackTopDiff
				buffer.WriteByte(byte(1))
				ins.Immediates = append(ins.Immediates, true)
			}
			// write the number of elements on the stack we need to discard
			binary.Write(buffer, binary.LittleEndian, stackTopDiff)

			ins.Immediates = append(ins.Immediates, int64(0))
			ins.Immediates = append(ins.Immediates, stackTopDiff)
			newInstr = append(newInstr, ins)
			continue
		case ops.BrTable:
			branchTable := &vnt.BranchTable{
				// we subtract one for the implicit block created by
				// the function body
				BlocksLen: len(blocks) - 1,
			}
			targetCount := instr.Immediates[0].(uint32)
			branchTable.Targets = make([]vnt.Target, targetCount)
			for i := range branchTable.Targets {
				// The first immediates is the number of targets, so we ignore that
				label := int64(instr.Immediates[i+1].(uint32))
				branchTable.Targets[i].Addr = label
				branch := instr.Branches[i]

				branchTable.Targets[i].Return = branch.IsReturn
				branchTable.Targets[i].Discard = branch.StackTopDiff
				branchTable.Targets[i].PreserveTop = branch.PreserveTop
			}
			defaultLabel := int64(instr.Immediates[len(instr.Immediates)-1].(uint32))
			branchTable.DefaultTarget.Addr = defaultLabel
			defaultBranch := instr.Branches[targetCount]
			branchTable.DefaultTarget.Return = defaultBranch.IsReturn
			branchTable.DefaultTarget.Discard = defaultBranch.StackTopDiff
			branchTable.DefaultTarget.PreserveTop = defaultBranch.PreserveTop
			branchTables = append(branchTables, branchTable)
			// Register the table with every live block so labels can be
			// patched as those blocks end.
			for _, block := range blocks {
				block.branchTables = append(block.branchTables, branchTable)
			}

			// The operand is an index into branchTables, not the table itself.
			buffer.WriteByte(ops.BrTable)
			binary.Write(buffer, binary.LittleEndian, int64(len(branchTables)-1))

			op, err := ops.New(ops.BrTable)
			if err != nil {
				panic(err)
			}
			ins := disasm.Instr{
				Op:         op,
				Immediates: [](interface{}){},
			}
			ins.Immediates = append(ins.Immediates, int64(len(branchTables)-1))
			newInstr = append(newInstr, ins)
		}
		// Emit read instrumentation (if any) before the instruction itself.
		if len(readInstr) != 0 {
			if readIndex != -1 {
				for _, instr := range readInstr {
					buffer.WriteByte(instr.Op.Code)
					for _, imm := range instr.Immediates {
						err := binary.Write(buffer, binary.LittleEndian, imm)
						if err != nil {
							panic(err)
						}
					}
				}
				newInstr = append(newInstr, readInstr...)
			} else {
				log.Warn("Compile warning", "Msg", "Can't find ReadWithPointer env function!!")
			}
		}
		// Default emission: opcode followed by its immediates.
		buffer.WriteByte(instr.Op.Code)
		for _, imm := range instr.Immediates {
			err := binary.Write(buffer, binary.LittleEndian, imm)
			if err != nil {
				panic(err)
			}
		}
		newInstr = append(newInstr, instr)
		// Emit write instrumentation (if any) after the instruction.
		if len(writeInstr) != 0 {
			if writeIndex != -1 {
				for _, instr := range writeInstr {
					buffer.WriteByte(instr.Op.Code)
					for _, imm := range instr.Immediates {
						err := binary.Write(buffer, binary.LittleEndian, imm)
						if err != nil {
							panic(err)
						}
					}
				}
				newInstr = append(newInstr, writeInstr...)
			} else {
				log.Warn("Compile warning", "Msg", "Can't find WriteWithPointer env function!!")
			}
		}
	}

	// writing nop as the last instructions allows us to branch out of the
	// function (ie, return)
	addr := buffer.Len()
	buffer.WriteByte(ops.Nop)

	// patch all references to the "root" block of the function body
	for _, offset := range blocks[-1].patchOffsets {
		code := buffer.Bytes()
		buffer = patchOffset(code, offset, int64(addr))
	}

	for _, table := range branchTables {
		table.PatchedAddrs = nil
	}
	return buffer.Bytes(), branchTables
}
   728  
   729  // replace the address starting at start with addr
   730  func patchOffset(code []byte, start int64, addr int64) *bytes.Buffer {
   731  	var shift uint
   732  	for i := int64(0); i < 8; i++ {
   733  		code[start+i] = byte(addr >> shift)
   734  		shift += 8
   735  	}
   736  
   737  	buf := new(bytes.Buffer)
   738  	buf.Write(code)
   739  	return buf
   740  }