github.com/taubyte/vm-wasm-utils@v1.0.2/binary/code.go (about)

     1  package binary
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"math"
     8  
     9  	wasm "github.com/taubyte/vm-wasm-utils"
    10  	"github.com/taubyte/vm-wasm-utils/leb128"
    11  )
    12  
    13  func decodeCode(r *bytes.Reader) (*wasm.Code, error) {
    14  	ss, _, err := leb128.DecodeUint32(r)
    15  	if err != nil {
    16  		return nil, fmt.Errorf("get the size of code: %w", err)
    17  	}
    18  	remaining := int64(ss)
    19  
    20  	// parse locals
    21  	ls, bytesRead, err := leb128.DecodeUint32(r)
    22  	remaining -= int64(bytesRead)
    23  	if err != nil {
    24  		return nil, fmt.Errorf("get the size locals: %v", err)
    25  	} else if remaining < 0 {
    26  		return nil, io.EOF
    27  	}
    28  
    29  	var nums []uint64
    30  	var types []wasm.ValueType
    31  	var sum uint64
    32  	var n uint32
    33  	for i := uint32(0); i < ls; i++ {
    34  		n, bytesRead, err = leb128.DecodeUint32(r)
    35  		remaining -= int64(bytesRead) + 1 // +1 for the subsequent ReadByte
    36  		if err != nil {
    37  			return nil, fmt.Errorf("read n of locals: %v", err)
    38  		} else if remaining < 0 {
    39  			return nil, io.EOF
    40  		}
    41  
    42  		sum += uint64(n)
    43  		nums = append(nums, uint64(n))
    44  
    45  		b, err := r.ReadByte()
    46  		if err != nil {
    47  			return nil, fmt.Errorf("read type of local: %v", err)
    48  		}
    49  		switch vt := b; vt {
    50  		case wasm.ValueTypeI32, wasm.ValueTypeF32, wasm.ValueTypeI64, wasm.ValueTypeF64,
    51  			wasm.ValueTypeFuncref, wasm.ValueTypeExternref, wasm.ValueTypeV128:
    52  			types = append(types, vt)
    53  		default:
    54  			return nil, fmt.Errorf("invalid local type: 0x%x", vt)
    55  		}
    56  	}
    57  
    58  	if sum > math.MaxUint32 {
    59  		return nil, fmt.Errorf("too many locals: %d", sum)
    60  	}
    61  
    62  	var localTypes []wasm.ValueType
    63  	for i, num := range nums {
    64  		t := types[i]
    65  		for j := uint64(0); j < num; j++ {
    66  			localTypes = append(localTypes, t)
    67  		}
    68  	}
    69  
    70  	body := make([]byte, remaining)
    71  	if _, err = io.ReadFull(r, body); err != nil {
    72  		return nil, fmt.Errorf("read body: %w", err)
    73  	}
    74  
    75  	if endIndex := len(body) - 1; endIndex < 0 || body[endIndex] != wasm.OpcodeEnd {
    76  		return nil, fmt.Errorf("expr not end with OpcodeEnd")
    77  	}
    78  
    79  	return &wasm.Code{Body: body, LocalTypes: localTypes}, nil
    80  }
    81  
    82  // encodeCode returns the wasm.Code encoded in WebAssembly 1.0 (20191205) Binary Format.
    83  //
    84  // See https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#binary-code
    85  func encodeCode(c *wasm.Code) []byte {
    86  	if c.GoFunc != nil {
    87  		panic("BUG: GoFunc is not encodable")
    88  	}
    89  
    90  	// local blocks compress locals while preserving index order by grouping locals of the same type.
    91  	// https://www.w3.org/TR/2019/REC-wasm-core-1-20191205/#code-section%E2%91%A0
    92  	localBlockCount := uint32(0) // how many blocks of locals with the same type (types can repeat!)
    93  	var localBlocks []byte
    94  	localTypeLen := len(c.LocalTypes)
    95  	if localTypeLen > 0 {
    96  		i := localTypeLen - 1
    97  		var runCount uint32              // count of the same type
    98  		var lastValueType wasm.ValueType // initialize to an invalid type 0
    99  
   100  		// iterate backwards so it is easier to size prefix
   101  		for ; i >= 0; i-- {
   102  			vt := c.LocalTypes[i]
   103  			if lastValueType != vt {
   104  				if runCount != 0 { // Only on the first iteration, this is zero when vt is compared against invalid
   105  					localBlocks = append(leb128.EncodeUint32(runCount), localBlocks...)
   106  				}
   107  				lastValueType = vt
   108  				localBlocks = append(leb128.EncodeUint32(uint32(vt)), localBlocks...) // reuse the EncodeUint32 cache
   109  				localBlockCount++
   110  				runCount = 1
   111  			} else {
   112  				runCount++
   113  			}
   114  		}
   115  		localBlocks = append(leb128.EncodeUint32(runCount), localBlocks...)
   116  		localBlocks = append(leb128.EncodeUint32(localBlockCount), localBlocks...)
   117  	} else {
   118  		localBlocks = leb128.EncodeUint32(0)
   119  	}
   120  	code := append(localBlocks, c.Body...)
   121  	return append(leb128.EncodeUint32(uint32(len(code))), code...)
   122  }