github.com/AR1011/wazero@v1.0.5/internal/engine/wazevo/engine_cache.go (about)

     1  package wazevo
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/sha256"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"io"
    10  	"runtime"
    11  	"unsafe"
    12  
    13  	"github.com/AR1011/wazero/experimental"
    14  	"github.com/AR1011/wazero/internal/engine/wazevo/backend"
    15  	"github.com/AR1011/wazero/internal/engine/wazevo/ssa"
    16  	"github.com/AR1011/wazero/internal/engine/wazevo/wazevoapi"
    17  	"github.com/AR1011/wazero/internal/filecache"
    18  	"github.com/AR1011/wazero/internal/platform"
    19  	"github.com/AR1011/wazero/internal/u32"
    20  	"github.com/AR1011/wazero/internal/u64"
    21  	"github.com/AR1011/wazero/internal/wasm"
    22  )
    23  
    24  // fileCacheKey returns a key for the file cache.
    25  // In order to avoid collisions with the existing compiler, we do not use m.ID directly,
    26  // but instead we rehash it with magic.
    27  func fileCacheKey(m *wasm.Module) (ret filecache.Key) {
    28  	s := sha256.New()
    29  	s.Write(m.ID[:])
    30  	s.Write(magic)
    31  	s.Sum(ret[:0])
    32  	return
    33  }
    34  
    35  func (e *engine) addCompiledModule(module *wasm.Module, cm *compiledModule) (err error) {
    36  	e.addCompiledModuleToMemory(module, cm)
    37  	if !module.IsHostModule && e.fileCache != nil {
    38  		err = e.addCompiledModuleToCache(module, cm)
    39  	}
    40  	return
    41  }
    42  
    43  func (e *engine) getCompiledModule(module *wasm.Module, listeners []experimental.FunctionListener, ensureTermination bool) (cm *compiledModule, ok bool, err error) {
    44  	cm, ok = e.getCompiledModuleFromMemory(module)
    45  	if ok {
    46  		return
    47  	}
    48  	cm, ok, err = e.getCompiledModuleFromCache(module)
    49  	if ok {
    50  		cm.parent = e
    51  		cm.module = module
    52  		cm.sharedFunctions = e.sharedFunctions
    53  		cm.ensureTermination = ensureTermination
    54  		cm.offsets = wazevoapi.NewModuleContextOffsetData(module, len(listeners) > 0)
    55  		if len(listeners) > 0 {
    56  			cm.listeners = listeners
    57  			cm.listenerBeforeTrampolines = make([]*byte, len(module.TypeSection))
    58  			cm.listenerAfterTrampolines = make([]*byte, len(module.TypeSection))
    59  			for i := range module.TypeSection {
    60  				typ := &module.TypeSection[i]
    61  				before, after := e.getListenerTrampolineForType(typ)
    62  				cm.listenerBeforeTrampolines[i] = before
    63  				cm.listenerAfterTrampolines[i] = after
    64  			}
    65  		}
    66  		e.addCompiledModuleToMemory(module, cm)
    67  		ssaBuilder := ssa.NewBuilder()
    68  		machine := newMachine()
    69  		be := backend.NewCompiler(context.Background(), machine, ssaBuilder)
    70  		cm.executables.compileEntryPreambles(module, machine, be)
    71  
    72  		// Set the finalizer.
    73  		e.setFinalizer(cm.executables, executablesFinalizer)
    74  	}
    75  	return
    76  }
    77  
    78  func (e *engine) addCompiledModuleToMemory(m *wasm.Module, cm *compiledModule) {
    79  	e.mux.Lock()
    80  	defer e.mux.Unlock()
    81  	e.compiledModules[m.ID] = cm
    82  	if len(cm.executable) > 0 {
    83  		e.addCompiledModuleToSortedList(cm)
    84  	}
    85  }
    86  
    87  func (e *engine) getCompiledModuleFromMemory(module *wasm.Module) (cm *compiledModule, ok bool) {
    88  	e.mux.RLock()
    89  	defer e.mux.RUnlock()
    90  	cm, ok = e.compiledModules[module.ID]
    91  	return
    92  }
    93  
    94  func (e *engine) addCompiledModuleToCache(module *wasm.Module, cm *compiledModule) (err error) {
    95  	if e.fileCache == nil || module.IsHostModule {
    96  		return
    97  	}
    98  	err = e.fileCache.Add(fileCacheKey(module), serializeCompiledModule(e.wazeroVersion, cm))
    99  	return
   100  }
   101  
   102  func (e *engine) getCompiledModuleFromCache(module *wasm.Module) (cm *compiledModule, hit bool, err error) {
   103  	if e.fileCache == nil || module.IsHostModule {
   104  		return
   105  	}
   106  
   107  	// Check if the entries exist in the external cache.
   108  	var cached io.ReadCloser
   109  	cached, hit, err = e.fileCache.Get(fileCacheKey(module))
   110  	if !hit || err != nil {
   111  		return
   112  	}
   113  
   114  	// Otherwise, we hit the cache on external cache.
   115  	// We retrieve *code structures from `cached`.
   116  	var staleCache bool
   117  	// Note: cached.Close is ensured to be called in deserializeCodes.
   118  	cm, staleCache, err = deserializeCompiledModule(e.wazeroVersion, cached)
   119  	if err != nil {
   120  		hit = false
   121  		return
   122  	} else if staleCache {
   123  		return nil, false, e.fileCache.Delete(fileCacheKey(module))
   124  	}
   125  	return
   126  }
   127  
   128  var magic = []byte{'W', 'A', 'Z', 'E', 'V', 'O'}
   129  
   130  func serializeCompiledModule(wazeroVersion string, cm *compiledModule) io.Reader {
   131  	buf := bytes.NewBuffer(nil)
   132  	// First 6 byte: WAZEVO header.
   133  	buf.Write(magic)
   134  	// Next 1 byte: length of version:
   135  	buf.WriteByte(byte(len(wazeroVersion)))
   136  	// Version of wazero.
   137  	buf.WriteString(wazeroVersion)
   138  	// Number of *code (== locally defined functions in the module): 4 bytes.
   139  	buf.Write(u32.LeBytes(uint32(len(cm.functionOffsets))))
   140  	for _, offset := range cm.functionOffsets {
   141  		// The offset of this function in the executable (8 bytes).
   142  		buf.Write(u64.LeBytes(uint64(offset)))
   143  	}
   144  	// The length of code segment (8 bytes).
   145  	buf.Write(u64.LeBytes(uint64(len(cm.executable))))
   146  	// Append the native code.
   147  	buf.Write(cm.executable)
   148  	if sm := cm.sourceMap; len(sm.executableOffsets) > 0 {
   149  		buf.WriteByte(1) // indicates that source map is present.
   150  		l := len(sm.wasmBinaryOffsets)
   151  		buf.Write(u64.LeBytes(uint64(l)))
   152  		executableAddr := uintptr(unsafe.Pointer(&cm.executable[0]))
   153  		for i := 0; i < l; i++ {
   154  			buf.Write(u64.LeBytes(sm.wasmBinaryOffsets[i]))
   155  			// executableOffsets is absolute address, so we need to subtract executableAddr.
   156  			buf.Write(u64.LeBytes(uint64(sm.executableOffsets[i] - executableAddr)))
   157  		}
   158  	} else {
   159  		buf.WriteByte(0) // indicates that source map is not present.
   160  	}
   161  	return bytes.NewReader(buf.Bytes())
   162  }
   163  
   164  func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser) (cm *compiledModule, staleCache bool, err error) {
   165  	defer reader.Close()
   166  	cacheHeaderSize := len(magic) + 1 /* version size */ + len(wazeroVersion) + 4 /* number of functions */
   167  
   168  	// Read the header before the native code.
   169  	header := make([]byte, cacheHeaderSize)
   170  	n, err := reader.Read(header)
   171  	if err != nil {
   172  		return nil, false, fmt.Errorf("compilationcache: error reading header: %v", err)
   173  	}
   174  
   175  	if n != cacheHeaderSize {
   176  		return nil, false, fmt.Errorf("compilationcache: invalid header length: %d", n)
   177  	}
   178  
   179  	if !bytes.Equal(header[:len(magic)], magic) {
   180  		return nil, false, fmt.Errorf(
   181  			"compilationcache: invalid magic number: got %s but want %s", magic, header[:len(magic)])
   182  	}
   183  
   184  	// Check the version compatibility.
   185  	versionSize := int(header[len(magic)])
   186  
   187  	cachedVersionBegin, cachedVersionEnd := len(magic)+1, len(magic)+1+versionSize
   188  	if cachedVersionEnd >= len(header) {
   189  		staleCache = true
   190  		return
   191  	} else if cachedVersion := string(header[cachedVersionBegin:cachedVersionEnd]); cachedVersion != wazeroVersion {
   192  		staleCache = true
   193  		return
   194  	}
   195  
   196  	functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:])
   197  	cm = &compiledModule{functionOffsets: make([]int, functionsNum), executables: &executables{}}
   198  
   199  	var eightBytes [8]byte
   200  	for i := uint32(0); i < functionsNum; i++ {
   201  		// Read the offset of each function in the executable.
   202  		var offset uint64
   203  		if offset, err = readUint64(reader, &eightBytes); err != nil {
   204  			err = fmt.Errorf("compilationcache: error reading func[%d] executable offset: %v", i, err)
   205  			return
   206  		}
   207  		cm.functionOffsets[i] = int(offset)
   208  	}
   209  
   210  	executableLen, err := readUint64(reader, &eightBytes)
   211  	if err != nil {
   212  		err = fmt.Errorf("compilationcache: error reading executable size: %v", err)
   213  		return
   214  	}
   215  
   216  	if executableLen > 0 {
   217  		executable, err := platform.MmapCodeSegment(int(executableLen))
   218  		if err != nil {
   219  			err = fmt.Errorf("compilationcache: error mmapping executable (len=%d): %v", executableLen, err)
   220  			return nil, false, err
   221  		}
   222  
   223  		_, err = io.ReadFull(reader, executable)
   224  		if err != nil {
   225  			err = fmt.Errorf("compilationcache: error reading executable (len=%d): %v", executableLen, err)
   226  			return nil, false, err
   227  		}
   228  
   229  		if runtime.GOARCH == "arm64" {
   230  			// On arm64, we cannot give all of rwx at the same time, so we change it to exec.
   231  			if err = platform.MprotectRX(executable); err != nil {
   232  				return nil, false, err
   233  			}
   234  		}
   235  		cm.executable = executable
   236  	}
   237  
   238  	if _, err := io.ReadFull(reader, eightBytes[:1]); err != nil {
   239  		return nil, false, fmt.Errorf("compilationcache: error reading source map presence: %v", err)
   240  	}
   241  
   242  	if eightBytes[0] == 1 {
   243  		sm := &cm.sourceMap
   244  		sourceMapLen, err := readUint64(reader, &eightBytes)
   245  		if err != nil {
   246  			err = fmt.Errorf("compilationcache: error reading source map length: %v", err)
   247  			return nil, false, err
   248  		}
   249  		executableOffset := uintptr(unsafe.Pointer(&cm.executable[0]))
   250  		for i := uint64(0); i < sourceMapLen; i++ {
   251  			wasmBinaryOffset, err := readUint64(reader, &eightBytes)
   252  			if err != nil {
   253  				err = fmt.Errorf("compilationcache: error reading source map[%d] wasm binary offset: %v", i, err)
   254  				return nil, false, err
   255  			}
   256  			executableRelativeOffset, err := readUint64(reader, &eightBytes)
   257  			if err != nil {
   258  				err = fmt.Errorf("compilationcache: error reading source map[%d] executable offset: %v", i, err)
   259  				return nil, false, err
   260  			}
   261  			sm.wasmBinaryOffsets = append(sm.wasmBinaryOffsets, wasmBinaryOffset)
   262  			// executableOffsets is absolute address, so we need to add executableOffset.
   263  			sm.executableOffsets = append(sm.executableOffsets, uintptr(executableRelativeOffset)+executableOffset)
   264  		}
   265  	}
   266  	return
   267  }
   268  
   269  // readUint64 strictly reads an uint64 in little-endian byte order, using the
   270  // given array as a buffer. This returns io.EOF if less than 8 bytes were read.
   271  func readUint64(reader io.Reader, b *[8]byte) (uint64, error) {
   272  	s := b[0:8]
   273  	n, err := reader.Read(s)
   274  	if err != nil {
   275  		return 0, err
   276  	} else if n < 8 { // more strict than reader.Read
   277  		return 0, io.EOF
   278  	}
   279  
   280  	// Read the u64 from the underlying buffer.
   281  	ret := binary.LittleEndian.Uint64(s)
   282  	return ret, nil
   283  }