wa-lang.org/wazero@v1.0.2/internal/engine/compiler/engine_cache.go (about)

     1  package compiler
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  
     9  	"wa-lang.org/wazero/internal/platform"
    10  	"wa-lang.org/wazero/internal/u32"
    11  	"wa-lang.org/wazero/internal/u64"
    12  	"wa-lang.org/wazero/internal/wasm"
    13  )
    14  
    15  func (e *engine) deleteCodes(module *wasm.Module) {
    16  	e.mux.Lock()
    17  	defer e.mux.Unlock()
    18  	delete(e.codes, module.ID)
    19  
    20  	// Note: we do not call e.Cache.Delete, as the lifetime of
    21  	// the content is up to the implementation of extencache.Cache interface.
    22  }
    23  
    24  func (e *engine) addCodes(module *wasm.Module, codes []*code) (err error) {
    25  	e.addCodesToMemory(module, codes)
    26  	err = e.addCodesToCache(module, codes)
    27  	return
    28  }
    29  
    30  func (e *engine) getCodes(module *wasm.Module) (codes []*code, ok bool, err error) {
    31  	codes, ok = e.getCodesFromMemory(module)
    32  	if ok {
    33  		return
    34  	}
    35  	codes, ok, err = e.getCodesFromCache(module)
    36  	if ok {
    37  		e.addCodesToMemory(module, codes)
    38  	}
    39  	return
    40  }
    41  
    42  func (e *engine) addCodesToMemory(module *wasm.Module, codes []*code) {
    43  	e.mux.Lock()
    44  	defer e.mux.Unlock()
    45  	e.codes[module.ID] = codes
    46  }
    47  
    48  func (e *engine) getCodesFromMemory(module *wasm.Module) (codes []*code, ok bool) {
    49  	e.mux.RLock()
    50  	defer e.mux.RUnlock()
    51  	codes, ok = e.codes[module.ID]
    52  	return
    53  }
    54  
    55  func (e *engine) addCodesToCache(module *wasm.Module, codes []*code) (err error) {
    56  	if e.Cache == nil {
    57  		return
    58  	}
    59  	err = e.Cache.Add(module.ID, serializeCodes(e.wazeroVersion, codes))
    60  	return
    61  }
    62  
    63  func (e *engine) getCodesFromCache(module *wasm.Module) (codes []*code, hit bool, err error) {
    64  	if e.Cache == nil {
    65  		return
    66  	}
    67  
    68  	// Check if the entries exist in the external cache.
    69  	var cached io.ReadCloser
    70  	cached, hit, err = e.Cache.Get(module.ID)
    71  	if !hit || err != nil {
    72  		return
    73  	}
    74  	defer cached.Close()
    75  
    76  	// Otherwise, we hit the cache on external cache.
    77  	// We retrieve *code structures from `cached`.
    78  	var staleCache bool
    79  	codes, staleCache, err = deserializeCodes(e.wazeroVersion, cached)
    80  	if err != nil {
    81  		hit = false
    82  		return
    83  	} else if staleCache {
    84  		return nil, false, e.Cache.Delete(module.ID)
    85  	}
    86  
    87  	for i, c := range codes {
    88  		c.indexInModule = wasm.Index(i)
    89  		c.sourceModule = module
    90  	}
    91  	return
    92  }
    93  
// wazeroMagic is the fixed prefix that serializeCodes writes at the start of
// every serialized entry; deserializeCodes uses its length to locate the
// version field that follows it.
var wazeroMagic = "WAZERO"
    95  
    96  func serializeCodes(wazeroVersion string, codes []*code) io.Reader {
    97  	buf := bytes.NewBuffer(nil)
    98  	// First 6 byte: WAZERO header.
    99  	buf.WriteString(wazeroMagic)
   100  	// Next 1 byte: length of version:
   101  	buf.WriteByte(byte(len(wazeroVersion)))
   102  	// Version of wazero.
   103  	buf.WriteString(wazeroVersion)
   104  	// Number of *code (== locally defined functions in the module): 4 bytes.
   105  	buf.Write(u32.LeBytes(uint32(len(codes))))
   106  	for _, c := range codes {
   107  		// The stack pointer ceil (8 bytes).
   108  		buf.Write(u64.LeBytes(c.stackPointerCeil))
   109  		// The length of code segment (8 bytes).
   110  		buf.Write(u64.LeBytes(uint64(len(c.codeSegment))))
   111  		// Append the native code.
   112  		buf.Write(c.codeSegment)
   113  	}
   114  	return bytes.NewReader(buf.Bytes())
   115  }
   116  
   117  func deserializeCodes(wazeroVersion string, reader io.Reader) (codes []*code, staleCache bool, err error) {
   118  	cacheHeaderSize := len(wazeroMagic) + 1 /* version size */ + len(wazeroVersion) + 4 /* number of functions */
   119  
   120  	// Read the header before the native code.
   121  	header := make([]byte, cacheHeaderSize)
   122  	n, err := reader.Read(header)
   123  	if err != nil {
   124  		return nil, false, fmt.Errorf("compilationcache: error reading header: %v", err)
   125  	}
   126  
   127  	if n != cacheHeaderSize {
   128  		return nil, false, fmt.Errorf("compilationcache: invalid header length: %d", n)
   129  	}
   130  
   131  	// Check the version compatibility.
   132  	versionSize := int(header[len(wazeroMagic)])
   133  
   134  	cachedVersionBegin, cachedVersionEnd := len(wazeroMagic)+1, len(wazeroMagic)+1+versionSize
   135  	if cachedVersionEnd >= len(header) {
   136  		staleCache = true
   137  		return
   138  	} else if cachedVersion := string(header[cachedVersionBegin:cachedVersionEnd]); cachedVersion != wazeroVersion {
   139  		staleCache = true
   140  		return
   141  	}
   142  
   143  	functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:])
   144  	codes = make([]*code, 0, functionsNum)
   145  
   146  	var eightBytes [8]byte
   147  	var nativeCodeLen uint64
   148  	for i := uint32(0); i < functionsNum; i++ {
   149  		c := &code{}
   150  
   151  		// Read the stack pointer ceil.
   152  		if c.stackPointerCeil, err = readUint64(reader, &eightBytes); err != nil {
   153  			err = fmt.Errorf("compilationcache: error reading func[%d] stack pointer ceil: %v", i, err)
   154  			break
   155  		}
   156  
   157  		// Read (and mmap) the native code.
   158  		if nativeCodeLen, err = readUint64(reader, &eightBytes); err != nil {
   159  			err = fmt.Errorf("compilationcache: error reading func[%d] reading native code size: %v", i, err)
   160  			break
   161  		}
   162  
   163  		if c.codeSegment, err = platform.MmapCodeSegment(reader, int(nativeCodeLen)); err != nil {
   164  			err = fmt.Errorf("compilationcache: error mmapping func[%d] code (len=%d): %v", i, nativeCodeLen, err)
   165  			break
   166  		}
   167  
   168  		codes = append(codes, c)
   169  	}
   170  
   171  	if err != nil {
   172  		for _, c := range codes {
   173  			if errMunmap := platform.MunmapCodeSegment(c.codeSegment); errMunmap != nil {
   174  				// Munmap failure shouldn't happen.
   175  				panic(errMunmap)
   176  			}
   177  		}
   178  		codes = nil
   179  	}
   180  	return
   181  }
   182  
   183  // readUint64 strictly reads a uint64 in little-endian byte order, using the
   184  // given array as a buffer. This returns io.EOF if less than 8 bytes were read.
   185  func readUint64(reader io.Reader, b *[8]byte) (uint64, error) {
   186  	s := b[0:8]
   187  	n, err := reader.Read(s)
   188  	if err != nil {
   189  		return 0, err
   190  	} else if n < 8 { // more strict than reader.Read
   191  		return 0, io.EOF
   192  	}
   193  
   194  	// read the u64 from the underlying buffer
   195  	ret := binary.LittleEndian.Uint64(s)
   196  
   197  	// clear the underlying array
   198  	for i := 0; i < 8; i++ {
   199  		b[i] = 0
   200  	}
   201  	return ret, nil
   202  }