github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/compiler/engine_cache.go (about)

     1  package compiler
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"runtime"
     9  
    10  	"github.com/bananabytelabs/wazero/experimental"
    11  	"github.com/bananabytelabs/wazero/internal/platform"
    12  	"github.com/bananabytelabs/wazero/internal/u32"
    13  	"github.com/bananabytelabs/wazero/internal/u64"
    14  	"github.com/bananabytelabs/wazero/internal/wasm"
    15  )
    16  
    17  func (e *engine) deleteCompiledModule(module *wasm.Module) {
    18  	e.mux.Lock()
    19  	defer e.mux.Unlock()
    20  
    21  	delete(e.codes, module.ID)
    22  
    23  	// Note: we do not call e.Cache.Delete, as the lifetime of
    24  	// the content is up to the implementation of extencache.Cache interface.
    25  }
    26  
    27  func (e *engine) addCompiledModule(module *wasm.Module, cm *compiledModule, withGoFunc bool) (err error) {
    28  	e.addCompiledModuleToMemory(module, cm)
    29  	if !withGoFunc {
    30  		err = e.addCompiledModuleToCache(module, cm)
    31  	}
    32  	return
    33  }
    34  
    35  func (e *engine) getCompiledModule(module *wasm.Module, listeners []experimental.FunctionListener) (cm *compiledModule, ok bool, err error) {
    36  	cm, ok = e.getCompiledModuleFromMemory(module)
    37  	if ok {
    38  		return
    39  	}
    40  	cm, ok, err = e.getCompiledModuleFromCache(module)
    41  	if ok {
    42  		e.addCompiledModuleToMemory(module, cm)
    43  		if len(listeners) > 0 {
    44  			// Files do not contain the actual listener instances (it's impossible to cache them as files!), so assign each here.
    45  			for i := range cm.functions {
    46  				cm.functions[i].listener = listeners[i]
    47  			}
    48  		}
    49  
    50  		// As this uses mmap, we need to munmap on the compiled machine code when it's GCed.
    51  		e.setFinalizer(cm, releaseCompiledModule)
    52  	}
    53  	return
    54  }
    55  
    56  func (e *engine) addCompiledModuleToMemory(module *wasm.Module, cm *compiledModule) {
    57  	e.mux.Lock()
    58  	defer e.mux.Unlock()
    59  	e.codes[module.ID] = cm
    60  }
    61  
    62  func (e *engine) getCompiledModuleFromMemory(module *wasm.Module) (cm *compiledModule, ok bool) {
    63  	e.mux.RLock()
    64  	defer e.mux.RUnlock()
    65  	cm, ok = e.codes[module.ID]
    66  	return
    67  }
    68  
    69  func (e *engine) addCompiledModuleToCache(module *wasm.Module, cm *compiledModule) (err error) {
    70  	if e.fileCache == nil || module.IsHostModule {
    71  		return
    72  	}
    73  	err = e.fileCache.Add(module.ID, serializeCompiledModule(e.wazeroVersion, cm))
    74  	return
    75  }
    76  
    77  func (e *engine) getCompiledModuleFromCache(module *wasm.Module) (cm *compiledModule, hit bool, err error) {
    78  	if e.fileCache == nil || module.IsHostModule {
    79  		return
    80  	}
    81  
    82  	// Check if the entries exist in the external cache.
    83  	var cached io.ReadCloser
    84  	cached, hit, err = e.fileCache.Get(module.ID)
    85  	if !hit || err != nil {
    86  		return
    87  	}
    88  
    89  	// Otherwise, we hit the cache on external cache.
    90  	// We retrieve *code structures from `cached`.
    91  	var staleCache bool
    92  	// Note: cached.Close is ensured to be called in deserializeCodes.
    93  	cm, staleCache, err = deserializeCompiledModule(e.wazeroVersion, cached, module)
    94  	if err != nil {
    95  		hit = false
    96  		return
    97  	} else if staleCache {
    98  		return nil, false, e.fileCache.Delete(module.ID)
    99  	}
   100  
   101  	cm.source = module
   102  	return
   103  }
   104  
   105  var wazeroMagic = "WAZERO" // version must be synced with the tag of the wazero library.
   106  
   107  func serializeCompiledModule(wazeroVersion string, cm *compiledModule) io.Reader {
   108  	buf := bytes.NewBuffer(nil)
   109  	// First 6 byte: WAZERO header.
   110  	buf.WriteString(wazeroMagic)
   111  	// Next 1 byte: length of version:
   112  	buf.WriteByte(byte(len(wazeroVersion)))
   113  	// Version of wazero.
   114  	buf.WriteString(wazeroVersion)
   115  	if cm.ensureTermination {
   116  		buf.WriteByte(1)
   117  	} else {
   118  		buf.WriteByte(0)
   119  	}
   120  	// Number of *code (== locally defined functions in the module): 4 bytes.
   121  	buf.Write(u32.LeBytes(uint32(len(cm.functions))))
   122  	for i := 0; i < len(cm.functions); i++ {
   123  		f := &cm.functions[i]
   124  		// The stack pointer ceil (8 bytes).
   125  		buf.Write(u64.LeBytes(f.stackPointerCeil))
   126  		// The offset of this function in the executable (8 bytes).
   127  		buf.Write(u64.LeBytes(uint64(f.executableOffset)))
   128  	}
   129  	// The length of code segment (8 bytes).
   130  	buf.Write(u64.LeBytes(uint64(cm.executable.Len())))
   131  	// Append the native code.
   132  	buf.Write(cm.executable.Bytes())
   133  	return bytes.NewReader(buf.Bytes())
   134  }
   135  
   136  func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser, module *wasm.Module) (cm *compiledModule, staleCache bool, err error) {
   137  	defer reader.Close()
   138  	cacheHeaderSize := len(wazeroMagic) + 1 /* version size */ + len(wazeroVersion) + 1 /* ensure termination */ + 4 /* number of functions */
   139  
   140  	// Read the header before the native code.
   141  	header := make([]byte, cacheHeaderSize)
   142  	n, err := reader.Read(header)
   143  	if err != nil {
   144  		return nil, false, fmt.Errorf("compilationcache: error reading header: %v", err)
   145  	}
   146  
   147  	if n != cacheHeaderSize {
   148  		return nil, false, fmt.Errorf("compilationcache: invalid header length: %d", n)
   149  	}
   150  
   151  	// Check the version compatibility.
   152  	versionSize := int(header[len(wazeroMagic)])
   153  
   154  	cachedVersionBegin, cachedVersionEnd := len(wazeroMagic)+1, len(wazeroMagic)+1+versionSize
   155  	if cachedVersionEnd >= len(header) {
   156  		staleCache = true
   157  		return
   158  	} else if cachedVersion := string(header[cachedVersionBegin:cachedVersionEnd]); cachedVersion != wazeroVersion {
   159  		staleCache = true
   160  		return
   161  	}
   162  
   163  	ensureTermination := header[cachedVersionEnd] != 0
   164  	functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:])
   165  	cm = &compiledModule{
   166  		compiledCode:      new(compiledCode),
   167  		functions:         make([]compiledFunction, functionsNum),
   168  		ensureTermination: ensureTermination,
   169  	}
   170  
   171  	imported := module.ImportFunctionCount
   172  
   173  	var eightBytes [8]byte
   174  	for i := uint32(0); i < functionsNum; i++ {
   175  		f := &cm.functions[i]
   176  		f.parent = cm.compiledCode
   177  
   178  		// Read the stack pointer ceil.
   179  		if f.stackPointerCeil, err = readUint64(reader, &eightBytes); err != nil {
   180  			err = fmt.Errorf("compilationcache: error reading func[%d] stack pointer ceil: %v", i, err)
   181  			return
   182  		}
   183  
   184  		// Read the offset of each function in the executable.
   185  		var offset uint64
   186  		if offset, err = readUint64(reader, &eightBytes); err != nil {
   187  			err = fmt.Errorf("compilationcache: error reading func[%d] executable offset: %v", i, err)
   188  			return
   189  		}
   190  		f.executableOffset = uintptr(offset)
   191  		f.index = imported + i
   192  	}
   193  
   194  	executableLen, err := readUint64(reader, &eightBytes)
   195  	if err != nil {
   196  		err = fmt.Errorf("compilationcache: error reading executable size: %v", err)
   197  		return
   198  	}
   199  
   200  	if executableLen > 0 {
   201  		if err = cm.executable.Map(int(executableLen)); err != nil {
   202  			err = fmt.Errorf("compilationcache: error mmapping executable (len=%d): %v", executableLen, err)
   203  			return
   204  		}
   205  
   206  		_, err = io.ReadFull(reader, cm.executable.Bytes())
   207  		if err != nil {
   208  			err = fmt.Errorf("compilationcache: error reading executable (len=%d): %v", executableLen, err)
   209  			return
   210  		}
   211  
   212  		if runtime.GOARCH == "arm64" {
   213  			// On arm64, we cannot give all of rwx at the same time, so we change it to exec.
   214  			if err = platform.MprotectRX(cm.executable.Bytes()); err != nil {
   215  				return
   216  			}
   217  		}
   218  	}
   219  	return
   220  }
   221  
   222  // readUint64 strictly reads an uint64 in little-endian byte order, using the
   223  // given array as a buffer. This returns io.EOF if less than 8 bytes were read.
   224  func readUint64(reader io.Reader, b *[8]byte) (uint64, error) {
   225  	s := b[0:8]
   226  	n, err := reader.Read(s)
   227  	if err != nil {
   228  		return 0, err
   229  	} else if n < 8 { // more strict than reader.Read
   230  		return 0, io.EOF
   231  	}
   232  
   233  	// Read the u64 from the underlying buffer.
   234  	ret := binary.LittleEndian.Uint64(s)
   235  
   236  	// Clear the underlying array.
   237  	for i := 0; i < 8; i++ {
   238  		b[i] = 0
   239  	}
   240  	return ret, nil
   241  }