github.com/tetratelabs/wazero@v1.7.3-0.20240513003603-48f702e154b5/internal/engine/wazevo/engine_cache.go (about)

     1  package wazevo
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/sha256"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"hash/crc32"
    10  	"io"
    11  	"runtime"
    12  	"unsafe"
    13  
    14  	"github.com/tetratelabs/wazero/experimental"
    15  	"github.com/tetratelabs/wazero/internal/engine/wazevo/backend"
    16  	"github.com/tetratelabs/wazero/internal/engine/wazevo/ssa"
    17  	"github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
    18  	"github.com/tetratelabs/wazero/internal/filecache"
    19  	"github.com/tetratelabs/wazero/internal/platform"
    20  	"github.com/tetratelabs/wazero/internal/u32"
    21  	"github.com/tetratelabs/wazero/internal/u64"
    22  	"github.com/tetratelabs/wazero/internal/wasm"
    23  )
    24  
    25  var crc = crc32.MakeTable(crc32.Castagnoli)
    26  
    27  // fileCacheKey returns a key for the file cache.
    28  // In order to avoid collisions with the existing compiler, we do not use m.ID directly,
    29  // but instead we rehash it with magic.
    30  func fileCacheKey(m *wasm.Module) (ret filecache.Key) {
    31  	s := sha256.New()
    32  	s.Write(m.ID[:])
    33  	s.Write(magic)
    34  	s.Sum(ret[:0])
    35  	return
    36  }
    37  
    38  func (e *engine) addCompiledModule(module *wasm.Module, cm *compiledModule) (err error) {
    39  	e.addCompiledModuleToMemory(module, cm)
    40  	if !module.IsHostModule && e.fileCache != nil {
    41  		err = e.addCompiledModuleToCache(module, cm)
    42  	}
    43  	return
    44  }
    45  
    46  func (e *engine) getCompiledModule(module *wasm.Module, listeners []experimental.FunctionListener, ensureTermination bool) (cm *compiledModule, ok bool, err error) {
    47  	cm, ok = e.getCompiledModuleFromMemory(module)
    48  	if ok {
    49  		return
    50  	}
    51  	cm, ok, err = e.getCompiledModuleFromCache(module)
    52  	if ok {
    53  		cm.parent = e
    54  		cm.module = module
    55  		cm.sharedFunctions = e.sharedFunctions
    56  		cm.ensureTermination = ensureTermination
    57  		cm.offsets = wazevoapi.NewModuleContextOffsetData(module, len(listeners) > 0)
    58  		if len(listeners) > 0 {
    59  			cm.listeners = listeners
    60  			cm.listenerBeforeTrampolines = make([]*byte, len(module.TypeSection))
    61  			cm.listenerAfterTrampolines = make([]*byte, len(module.TypeSection))
    62  			for i := range module.TypeSection {
    63  				typ := &module.TypeSection[i]
    64  				before, after := e.getListenerTrampolineForType(typ)
    65  				cm.listenerBeforeTrampolines[i] = before
    66  				cm.listenerAfterTrampolines[i] = after
    67  			}
    68  		}
    69  		e.addCompiledModuleToMemory(module, cm)
    70  		ssaBuilder := ssa.NewBuilder()
    71  		machine := newMachine()
    72  		be := backend.NewCompiler(context.Background(), machine, ssaBuilder)
    73  		cm.executables.compileEntryPreambles(module, machine, be)
    74  
    75  		// Set the finalizer.
    76  		e.setFinalizer(cm.executables, executablesFinalizer)
    77  	}
    78  	return
    79  }
    80  
    81  func (e *engine) addCompiledModuleToMemory(m *wasm.Module, cm *compiledModule) {
    82  	e.mux.Lock()
    83  	defer e.mux.Unlock()
    84  	e.compiledModules[m.ID] = cm
    85  	if len(cm.executable) > 0 {
    86  		e.addCompiledModuleToSortedList(cm)
    87  	}
    88  }
    89  
    90  func (e *engine) getCompiledModuleFromMemory(module *wasm.Module) (cm *compiledModule, ok bool) {
    91  	e.mux.RLock()
    92  	defer e.mux.RUnlock()
    93  	cm, ok = e.compiledModules[module.ID]
    94  	return
    95  }
    96  
    97  func (e *engine) addCompiledModuleToCache(module *wasm.Module, cm *compiledModule) (err error) {
    98  	if e.fileCache == nil || module.IsHostModule {
    99  		return
   100  	}
   101  	err = e.fileCache.Add(fileCacheKey(module), serializeCompiledModule(e.wazeroVersion, cm))
   102  	return
   103  }
   104  
   105  func (e *engine) getCompiledModuleFromCache(module *wasm.Module) (cm *compiledModule, hit bool, err error) {
   106  	if e.fileCache == nil || module.IsHostModule {
   107  		return
   108  	}
   109  
   110  	// Check if the entries exist in the external cache.
   111  	var cached io.ReadCloser
   112  	cached, hit, err = e.fileCache.Get(fileCacheKey(module))
   113  	if !hit || err != nil {
   114  		return
   115  	}
   116  
   117  	// Otherwise, we hit the cache on external cache.
   118  	// We retrieve *code structures from `cached`.
   119  	var staleCache bool
   120  	// Note: cached.Close is ensured to be called in deserializeCodes.
   121  	cm, staleCache, err = deserializeCompiledModule(e.wazeroVersion, cached)
   122  	if err != nil {
   123  		hit = false
   124  		return
   125  	} else if staleCache {
   126  		return nil, false, e.fileCache.Delete(fileCacheKey(module))
   127  	}
   128  	return
   129  }
   130  
   131  var magic = []byte{'W', 'A', 'Z', 'E', 'V', 'O'}
   132  
   133  func serializeCompiledModule(wazeroVersion string, cm *compiledModule) io.Reader {
   134  	buf := bytes.NewBuffer(nil)
   135  	// First 6 byte: WAZEVO header.
   136  	buf.Write(magic)
   137  	// Next 1 byte: length of version:
   138  	buf.WriteByte(byte(len(wazeroVersion)))
   139  	// Version of wazero.
   140  	buf.WriteString(wazeroVersion)
   141  	// Number of *code (== locally defined functions in the module): 4 bytes.
   142  	buf.Write(u32.LeBytes(uint32(len(cm.functionOffsets))))
   143  	for _, offset := range cm.functionOffsets {
   144  		// The offset of this function in the executable (8 bytes).
   145  		buf.Write(u64.LeBytes(uint64(offset)))
   146  	}
   147  	// The length of code segment (8 bytes).
   148  	buf.Write(u64.LeBytes(uint64(len(cm.executable))))
   149  	// Append the native code.
   150  	buf.Write(cm.executable)
   151  	// Append checksum.
   152  	checksum := crc32.Checksum(cm.executable, crc)
   153  	buf.Write(u32.LeBytes(checksum))
   154  	if sm := cm.sourceMap; len(sm.executableOffsets) > 0 {
   155  		buf.WriteByte(1) // indicates that source map is present.
   156  		l := len(sm.wasmBinaryOffsets)
   157  		buf.Write(u64.LeBytes(uint64(l)))
   158  		executableAddr := uintptr(unsafe.Pointer(&cm.executable[0]))
   159  		for i := 0; i < l; i++ {
   160  			buf.Write(u64.LeBytes(sm.wasmBinaryOffsets[i]))
   161  			// executableOffsets is absolute address, so we need to subtract executableAddr.
   162  			buf.Write(u64.LeBytes(uint64(sm.executableOffsets[i] - executableAddr)))
   163  		}
   164  	} else {
   165  		buf.WriteByte(0) // indicates that source map is not present.
   166  	}
   167  	return bytes.NewReader(buf.Bytes())
   168  }
   169  
   170  func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser) (cm *compiledModule, staleCache bool, err error) {
   171  	defer reader.Close()
   172  	cacheHeaderSize := len(magic) + 1 /* version size */ + len(wazeroVersion) + 4 /* number of functions */
   173  
   174  	// Read the header before the native code.
   175  	header := make([]byte, cacheHeaderSize)
   176  	n, err := reader.Read(header)
   177  	if err != nil {
   178  		return nil, false, fmt.Errorf("compilationcache: error reading header: %v", err)
   179  	}
   180  
   181  	if n != cacheHeaderSize {
   182  		return nil, false, fmt.Errorf("compilationcache: invalid header length: %d", n)
   183  	}
   184  
   185  	if !bytes.Equal(header[:len(magic)], magic) {
   186  		return nil, false, fmt.Errorf(
   187  			"compilationcache: invalid magic number: got %s but want %s", magic, header[:len(magic)])
   188  	}
   189  
   190  	// Check the version compatibility.
   191  	versionSize := int(header[len(magic)])
   192  
   193  	cachedVersionBegin, cachedVersionEnd := len(magic)+1, len(magic)+1+versionSize
   194  	if cachedVersionEnd >= len(header) {
   195  		staleCache = true
   196  		return
   197  	} else if cachedVersion := string(header[cachedVersionBegin:cachedVersionEnd]); cachedVersion != wazeroVersion {
   198  		staleCache = true
   199  		return
   200  	}
   201  
   202  	functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:])
   203  	cm = &compiledModule{functionOffsets: make([]int, functionsNum), executables: &executables{}}
   204  
   205  	var eightBytes [8]byte
   206  	for i := uint32(0); i < functionsNum; i++ {
   207  		// Read the offset of each function in the executable.
   208  		var offset uint64
   209  		if offset, err = readUint64(reader, &eightBytes); err != nil {
   210  			err = fmt.Errorf("compilationcache: error reading func[%d] executable offset: %v", i, err)
   211  			return
   212  		}
   213  		cm.functionOffsets[i] = int(offset)
   214  	}
   215  
   216  	executableLen, err := readUint64(reader, &eightBytes)
   217  	if err != nil {
   218  		err = fmt.Errorf("compilationcache: error reading executable size: %v", err)
   219  		return
   220  	}
   221  
   222  	if executableLen > 0 {
   223  		executable, err := platform.MmapCodeSegment(int(executableLen))
   224  		if err != nil {
   225  			err = fmt.Errorf("compilationcache: error mmapping executable (len=%d): %v", executableLen, err)
   226  			return nil, false, err
   227  		}
   228  
   229  		_, err = io.ReadFull(reader, executable)
   230  		if err != nil {
   231  			err = fmt.Errorf("compilationcache: error reading executable (len=%d): %v", executableLen, err)
   232  			return nil, false, err
   233  		}
   234  
   235  		expected := crc32.Checksum(executable, crc)
   236  		if _, err = io.ReadFull(reader, eightBytes[:4]); err != nil {
   237  			return nil, false, fmt.Errorf("compilationcache: could not read checksum: %v", err)
   238  		} else if checksum := binary.LittleEndian.Uint32(eightBytes[:4]); expected != checksum {
   239  			return nil, false, fmt.Errorf("compilationcache: checksum mismatch (expected %d, got %d)", expected, checksum)
   240  		}
   241  
   242  		if runtime.GOARCH == "arm64" {
   243  			// On arm64, we cannot give all of rwx at the same time, so we change it to exec.
   244  			if err = platform.MprotectRX(executable); err != nil {
   245  				return nil, false, err
   246  			}
   247  		}
   248  		cm.executable = executable
   249  	}
   250  
   251  	if _, err := io.ReadFull(reader, eightBytes[:1]); err != nil {
   252  		return nil, false, fmt.Errorf("compilationcache: error reading source map presence: %v", err)
   253  	}
   254  
   255  	if eightBytes[0] == 1 {
   256  		sm := &cm.sourceMap
   257  		sourceMapLen, err := readUint64(reader, &eightBytes)
   258  		if err != nil {
   259  			err = fmt.Errorf("compilationcache: error reading source map length: %v", err)
   260  			return nil, false, err
   261  		}
   262  		executableOffset := uintptr(unsafe.Pointer(&cm.executable[0]))
   263  		for i := uint64(0); i < sourceMapLen; i++ {
   264  			wasmBinaryOffset, err := readUint64(reader, &eightBytes)
   265  			if err != nil {
   266  				err = fmt.Errorf("compilationcache: error reading source map[%d] wasm binary offset: %v", i, err)
   267  				return nil, false, err
   268  			}
   269  			executableRelativeOffset, err := readUint64(reader, &eightBytes)
   270  			if err != nil {
   271  				err = fmt.Errorf("compilationcache: error reading source map[%d] executable offset: %v", i, err)
   272  				return nil, false, err
   273  			}
   274  			sm.wasmBinaryOffsets = append(sm.wasmBinaryOffsets, wasmBinaryOffset)
   275  			// executableOffsets is absolute address, so we need to add executableOffset.
   276  			sm.executableOffsets = append(sm.executableOffsets, uintptr(executableRelativeOffset)+executableOffset)
   277  		}
   278  	}
   279  	return
   280  }
   281  
   282  // readUint64 strictly reads an uint64 in little-endian byte order, using the
   283  // given array as a buffer. This returns io.EOF if less than 8 bytes were read.
   284  func readUint64(reader io.Reader, b *[8]byte) (uint64, error) {
   285  	s := b[0:8]
   286  	n, err := reader.Read(s)
   287  	if err != nil {
   288  		return 0, err
   289  	} else if n < 8 { // more strict than reader.Read
   290  		return 0, io.EOF
   291  	}
   292  
   293  	// Read the u64 from the underlying buffer.
   294  	ret := binary.LittleEndian.Uint64(s)
   295  	return ret, nil
   296  }