github.com/tetratelabs/wazero@v1.2.1/internal/engine/compiler/engine_cache.go (about)

     1  package compiler
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"runtime"
     9  
    10  	"github.com/tetratelabs/wazero/experimental"
    11  	"github.com/tetratelabs/wazero/internal/platform"
    12  	"github.com/tetratelabs/wazero/internal/u32"
    13  	"github.com/tetratelabs/wazero/internal/u64"
    14  	"github.com/tetratelabs/wazero/internal/wasm"
    15  )
    16  
    17  func (e *engine) deleteCompiledModule(module *wasm.Module) {
    18  	e.mux.Lock()
    19  	defer e.mux.Unlock()
    20  	delete(e.codes, module.ID)
    21  
    22  	// Note: we do not call e.Cache.Delete, as the lifetime of
    23  	// the content is up to the implementation of extencache.Cache interface.
    24  }
    25  
    26  func (e *engine) addCompiledModule(module *wasm.Module, cm *compiledModule, withGoFunc bool) (err error) {
    27  	e.addCompiledModuleToMemory(module, cm)
    28  	if !withGoFunc {
    29  		err = e.addCompiledModuleToCache(module, cm)
    30  	}
    31  	return
    32  }
    33  
    34  func (e *engine) getCompiledModule(module *wasm.Module, listeners []experimental.FunctionListener) (cm *compiledModule, ok bool, err error) {
    35  	cm, ok = e.getCompiledModuleFromMemory(module)
    36  	if ok {
    37  		return
    38  	}
    39  	cm, ok, err = e.getCompiledModuleFromCache(module)
    40  	if ok {
    41  		e.addCompiledModuleToMemory(module, cm)
    42  		if len(listeners) > 0 {
    43  			// Files do not contain the actual listener instances (it's impossible to cache them as files!), so assign each here.
    44  			for i := range cm.functions {
    45  				cm.functions[i].listener = listeners[i]
    46  			}
    47  		}
    48  	}
    49  	return
    50  }
    51  
    52  func (e *engine) addCompiledModuleToMemory(module *wasm.Module, cm *compiledModule) {
    53  	e.mux.Lock()
    54  	defer e.mux.Unlock()
    55  	e.codes[module.ID] = cm
    56  }
    57  
    58  func (e *engine) getCompiledModuleFromMemory(module *wasm.Module) (cm *compiledModule, ok bool) {
    59  	e.mux.RLock()
    60  	defer e.mux.RUnlock()
    61  	cm, ok = e.codes[module.ID]
    62  	return
    63  }
    64  
    65  func (e *engine) addCompiledModuleToCache(module *wasm.Module, cm *compiledModule) (err error) {
    66  	if e.fileCache == nil || module.IsHostModule {
    67  		return
    68  	}
    69  	err = e.fileCache.Add(module.ID, serializeCompiledModule(e.wazeroVersion, cm))
    70  	return
    71  }
    72  
    73  func (e *engine) getCompiledModuleFromCache(module *wasm.Module) (cm *compiledModule, hit bool, err error) {
    74  	if e.fileCache == nil || module.IsHostModule {
    75  		return
    76  	}
    77  
    78  	// Check if the entries exist in the external cache.
    79  	var cached io.ReadCloser
    80  	cached, hit, err = e.fileCache.Get(module.ID)
    81  	if !hit || err != nil {
    82  		return
    83  	}
    84  
    85  	// Otherwise, we hit the cache on external cache.
    86  	// We retrieve *code structures from `cached`.
    87  	var staleCache bool
    88  	// Note: cached.Close is ensured to be called in deserializeCodes.
    89  	cm, staleCache, err = deserializeCompiledModule(e.wazeroVersion, cached, module)
    90  	if err != nil {
    91  		hit = false
    92  		return
    93  	} else if staleCache {
    94  		return nil, false, e.fileCache.Delete(module.ID)
    95  	}
    96  
    97  	cm.source = module
    98  	return
    99  }
   100  
// wazeroMagic is the fixed prefix written at the very start of every
// serialized compiled module. It is not itself a version: the wazero version
// string is written right after it (preceded by a one-byte length) and that
// string is what gets compared when an entry is read back.
var wazeroMagic = "WAZERO"
   102  
   103  func serializeCompiledModule(wazeroVersion string, cm *compiledModule) io.Reader {
   104  	buf := bytes.NewBuffer(nil)
   105  	// First 6 byte: WAZERO header.
   106  	buf.WriteString(wazeroMagic)
   107  	// Next 1 byte: length of version:
   108  	buf.WriteByte(byte(len(wazeroVersion)))
   109  	// Version of wazero.
   110  	buf.WriteString(wazeroVersion)
   111  	if cm.ensureTermination {
   112  		buf.WriteByte(1)
   113  	} else {
   114  		buf.WriteByte(0)
   115  	}
   116  	// Number of *code (== locally defined functions in the module): 4 bytes.
   117  	buf.Write(u32.LeBytes(uint32(len(cm.functions))))
   118  	for i := 0; i < len(cm.functions); i++ {
   119  		f := &cm.functions[i]
   120  		// The stack pointer ceil (8 bytes).
   121  		buf.Write(u64.LeBytes(f.stackPointerCeil))
   122  		// The offset of this function in the executable (8 bytes).
   123  		buf.Write(u64.LeBytes(uint64(f.executableOffset)))
   124  	}
   125  	// The length of code segment (8 bytes).
   126  	buf.Write(u64.LeBytes(uint64(cm.executable.Len())))
   127  	// Append the native code.
   128  	buf.Write(cm.executable.Bytes())
   129  	return bytes.NewReader(buf.Bytes())
   130  }
   131  
   132  func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser, module *wasm.Module) (cm *compiledModule, staleCache bool, err error) {
   133  	defer reader.Close()
   134  	cacheHeaderSize := len(wazeroMagic) + 1 /* version size */ + len(wazeroVersion) + 1 /* ensure termination */ + 4 /* number of functions */
   135  
   136  	// Read the header before the native code.
   137  	header := make([]byte, cacheHeaderSize)
   138  	n, err := reader.Read(header)
   139  	if err != nil {
   140  		return nil, false, fmt.Errorf("compilationcache: error reading header: %v", err)
   141  	}
   142  
   143  	if n != cacheHeaderSize {
   144  		return nil, false, fmt.Errorf("compilationcache: invalid header length: %d", n)
   145  	}
   146  
   147  	// Check the version compatibility.
   148  	versionSize := int(header[len(wazeroMagic)])
   149  
   150  	cachedVersionBegin, cachedVersionEnd := len(wazeroMagic)+1, len(wazeroMagic)+1+versionSize
   151  	if cachedVersionEnd >= len(header) {
   152  		staleCache = true
   153  		return
   154  	} else if cachedVersion := string(header[cachedVersionBegin:cachedVersionEnd]); cachedVersion != wazeroVersion {
   155  		staleCache = true
   156  		return
   157  	}
   158  
   159  	ensureTermination := header[cachedVersionEnd] != 0
   160  	functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:])
   161  	cm = &compiledModule{functions: make([]compiledFunction, functionsNum), ensureTermination: ensureTermination}
   162  
   163  	imported := module.ImportFunctionCount
   164  
   165  	var eightBytes [8]byte
   166  	for i := uint32(0); i < functionsNum; i++ {
   167  		f := &cm.functions[i]
   168  		f.parent = cm
   169  
   170  		// Read the stack pointer ceil.
   171  		if f.stackPointerCeil, err = readUint64(reader, &eightBytes); err != nil {
   172  			err = fmt.Errorf("compilationcache: error reading func[%d] stack pointer ceil: %v", i, err)
   173  			return
   174  		}
   175  
   176  		// Read the offset of each function in the executable.
   177  		var offset uint64
   178  		if offset, err = readUint64(reader, &eightBytes); err != nil {
   179  			err = fmt.Errorf("compilationcache: error reading func[%d] executable offset: %v", i, err)
   180  			return
   181  		}
   182  		f.executableOffset = uintptr(offset)
   183  		f.index = imported + i
   184  	}
   185  
   186  	executableLen, err := readUint64(reader, &eightBytes)
   187  	if err != nil {
   188  		err = fmt.Errorf("compilationcache: error reading executable size: %v", err)
   189  		return
   190  	}
   191  
   192  	if executableLen > 0 {
   193  		if err = cm.executable.Map(int(executableLen)); err != nil {
   194  			err = fmt.Errorf("compilationcache: error mmapping executable (len=%d): %v", executableLen, err)
   195  			return
   196  		}
   197  
   198  		_, err = io.ReadFull(reader, cm.executable.Bytes())
   199  		if err != nil {
   200  			err = fmt.Errorf("compilationcache: error reading executable (len=%d): %v", executableLen, err)
   201  			return
   202  		}
   203  
   204  		if runtime.GOARCH == "arm64" {
   205  			// On arm64, we cannot give all of rwx at the same time, so we change it to exec.
   206  			if err = platform.MprotectRX(cm.executable.Bytes()); err != nil {
   207  				return
   208  			}
   209  		}
   210  	}
   211  	return
   212  }
   213  
   214  // readUint64 strictly reads an uint64 in little-endian byte order, using the
   215  // given array as a buffer. This returns io.EOF if less than 8 bytes were read.
   216  func readUint64(reader io.Reader, b *[8]byte) (uint64, error) {
   217  	s := b[0:8]
   218  	n, err := reader.Read(s)
   219  	if err != nil {
   220  		return 0, err
   221  	} else if n < 8 { // more strict than reader.Read
   222  		return 0, io.EOF
   223  	}
   224  
   225  	// Read the u64 from the underlying buffer.
   226  	ret := binary.LittleEndian.Uint64(s)
   227  
   228  	// Clear the underlying array.
   229  	for i := 0; i < 8; i++ {
   230  		b[i] = 0
   231  	}
   232  	return ret, nil
   233  }