github.com/tetratelabs/wazero@v1.7.3-0.20240513003603-48f702e154b5/internal/engine/wazevo/engine_cache.go (about) 1 package wazevo 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/sha256" 7 "encoding/binary" 8 "fmt" 9 "hash/crc32" 10 "io" 11 "runtime" 12 "unsafe" 13 14 "github.com/tetratelabs/wazero/experimental" 15 "github.com/tetratelabs/wazero/internal/engine/wazevo/backend" 16 "github.com/tetratelabs/wazero/internal/engine/wazevo/ssa" 17 "github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi" 18 "github.com/tetratelabs/wazero/internal/filecache" 19 "github.com/tetratelabs/wazero/internal/platform" 20 "github.com/tetratelabs/wazero/internal/u32" 21 "github.com/tetratelabs/wazero/internal/u64" 22 "github.com/tetratelabs/wazero/internal/wasm" 23 ) 24 25 var crc = crc32.MakeTable(crc32.Castagnoli) 26 27 // fileCacheKey returns a key for the file cache. 28 // In order to avoid collisions with the existing compiler, we do not use m.ID directly, 29 // but instead we rehash it with magic. 30 func fileCacheKey(m *wasm.Module) (ret filecache.Key) { 31 s := sha256.New() 32 s.Write(m.ID[:]) 33 s.Write(magic) 34 s.Sum(ret[:0]) 35 return 36 } 37 38 func (e *engine) addCompiledModule(module *wasm.Module, cm *compiledModule) (err error) { 39 e.addCompiledModuleToMemory(module, cm) 40 if !module.IsHostModule && e.fileCache != nil { 41 err = e.addCompiledModuleToCache(module, cm) 42 } 43 return 44 } 45 46 func (e *engine) getCompiledModule(module *wasm.Module, listeners []experimental.FunctionListener, ensureTermination bool) (cm *compiledModule, ok bool, err error) { 47 cm, ok = e.getCompiledModuleFromMemory(module) 48 if ok { 49 return 50 } 51 cm, ok, err = e.getCompiledModuleFromCache(module) 52 if ok { 53 cm.parent = e 54 cm.module = module 55 cm.sharedFunctions = e.sharedFunctions 56 cm.ensureTermination = ensureTermination 57 cm.offsets = wazevoapi.NewModuleContextOffsetData(module, len(listeners) > 0) 58 if len(listeners) > 0 { 59 cm.listeners = listeners 60 cm.listenerBeforeTrampolines = make([]*byte, len(module.TypeSection)) 61 cm.listenerAfterTrampolines = make([]*byte, len(module.TypeSection)) 62 for i := range module.TypeSection { 63 typ := &module.TypeSection[i] 64 before, after := e.getListenerTrampolineForType(typ) 65 cm.listenerBeforeTrampolines[i] = before 66 cm.listenerAfterTrampolines[i] = after 67 } 68 } 69 e.addCompiledModuleToMemory(module, cm) 70 ssaBuilder := ssa.NewBuilder() 71 machine := newMachine() 72 be := backend.NewCompiler(context.Background(), machine, ssaBuilder) 73 cm.executables.compileEntryPreambles(module, machine, be) 74 75 // Set the finalizer. 76 e.setFinalizer(cm.executables, executablesFinalizer) 77 } 78 return 79 } 80 81 func (e *engine) addCompiledModuleToMemory(m *wasm.Module, cm *compiledModule) { 82 e.mux.Lock() 83 defer e.mux.Unlock() 84 e.compiledModules[m.ID] = cm 85 if len(cm.executable) > 0 { 86 e.addCompiledModuleToSortedList(cm) 87 } 88 } 89 90 func (e *engine) getCompiledModuleFromMemory(module *wasm.Module) (cm *compiledModule, ok bool) { 91 e.mux.RLock() 92 defer e.mux.RUnlock() 93 cm, ok = e.compiledModules[module.ID] 94 return 95 } 96 97 func (e *engine) addCompiledModuleToCache(module *wasm.Module, cm *compiledModule) (err error) { 98 if e.fileCache == nil || module.IsHostModule { 99 return 100 } 101 err = e.fileCache.Add(fileCacheKey(module), serializeCompiledModule(e.wazeroVersion, cm)) 102 return 103 } 104 105 func (e *engine) getCompiledModuleFromCache(module *wasm.Module) (cm *compiledModule, hit bool, err error) { 106 if e.fileCache == nil || module.IsHostModule { 107 return 108 } 109 110 // Check if the entries exist in the external cache. 111 var cached io.ReadCloser 112 cached, hit, err = e.fileCache.Get(fileCacheKey(module)) 113 if !hit || err != nil { 114 return 115 } 116 117 // Otherwise, we hit the cache on external cache. 118 // We retrieve *code structures from `cached`. 119 var staleCache bool 120 // Note: cached.Close is ensured to be called in deserializeCodes. 121 cm, staleCache, err = deserializeCompiledModule(e.wazeroVersion, cached) 122 if err != nil { 123 hit = false 124 return 125 } else if staleCache { 126 return nil, false, e.fileCache.Delete(fileCacheKey(module)) 127 } 128 return 129 } 130 131 var magic = []byte{'W', 'A', 'Z', 'E', 'V', 'O'} 132 133 func serializeCompiledModule(wazeroVersion string, cm *compiledModule) io.Reader { 134 buf := bytes.NewBuffer(nil) 135 // First 6 byte: WAZEVO header. 136 buf.Write(magic) 137 // Next 1 byte: length of version: 138 buf.WriteByte(byte(len(wazeroVersion))) 139 // Version of wazero. 140 buf.WriteString(wazeroVersion) 141 // Number of *code (== locally defined functions in the module): 4 bytes. 142 buf.Write(u32.LeBytes(uint32(len(cm.functionOffsets)))) 143 for _, offset := range cm.functionOffsets { 144 // The offset of this function in the executable (8 bytes). 145 buf.Write(u64.LeBytes(uint64(offset))) 146 } 147 // The length of code segment (8 bytes). 148 buf.Write(u64.LeBytes(uint64(len(cm.executable)))) 149 // Append the native code. 150 buf.Write(cm.executable) 151 // Append checksum. 152 checksum := crc32.Checksum(cm.executable, crc) 153 buf.Write(u32.LeBytes(checksum)) 154 if sm := cm.sourceMap; len(sm.executableOffsets) > 0 { 155 buf.WriteByte(1) // indicates that source map is present. 156 l := len(sm.wasmBinaryOffsets) 157 buf.Write(u64.LeBytes(uint64(l))) 158 executableAddr := uintptr(unsafe.Pointer(&cm.executable[0])) 159 for i := 0; i < l; i++ { 160 buf.Write(u64.LeBytes(sm.wasmBinaryOffsets[i])) 161 // executableOffsets is absolute address, so we need to subtract executableAddr. 162 buf.Write(u64.LeBytes(uint64(sm.executableOffsets[i] - executableAddr))) 163 } 164 } else { 165 buf.WriteByte(0) // indicates that source map is not present. 166 } 167 return bytes.NewReader(buf.Bytes()) 168 } 169 170 func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser) (cm *compiledModule, staleCache bool, err error) { 171 defer reader.Close() 172 cacheHeaderSize := len(magic) + 1 /* version size */ + len(wazeroVersion) + 4 /* number of functions */ 173 174 // Read the header before the native code. 175 header := make([]byte, cacheHeaderSize) 176 n, err := reader.Read(header) 177 if err != nil { 178 return nil, false, fmt.Errorf("compilationcache: error reading header: %v", err) 179 } 180 181 if n != cacheHeaderSize { 182 return nil, false, fmt.Errorf("compilationcache: invalid header length: %d", n) 183 } 184 185 if !bytes.Equal(header[:len(magic)], magic) { 186 return nil, false, fmt.Errorf( 187 "compilationcache: invalid magic number: got %s but want %s", magic, header[:len(magic)]) 188 } 189 190 // Check the version compatibility. 191 versionSize := int(header[len(magic)]) 192 193 cachedVersionBegin, cachedVersionEnd := len(magic)+1, len(magic)+1+versionSize 194 if cachedVersionEnd >= len(header) { 195 staleCache = true 196 return 197 } else if cachedVersion := string(header[cachedVersionBegin:cachedVersionEnd]); cachedVersion != wazeroVersion { 198 staleCache = true 199 return 200 } 201 202 functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:]) 203 cm = &compiledModule{functionOffsets: make([]int, functionsNum), executables: &executables{}} 204 205 var eightBytes [8]byte 206 for i := uint32(0); i < functionsNum; i++ { 207 // Read the offset of each function in the executable. 208 var offset uint64 209 if offset, err = readUint64(reader, &eightBytes); err != nil { 210 err = fmt.Errorf("compilationcache: error reading func[%d] executable offset: %v", i, err) 211 return 212 } 213 cm.functionOffsets[i] = int(offset) 214 } 215 216 executableLen, err := readUint64(reader, &eightBytes) 217 if err != nil { 218 err = fmt.Errorf("compilationcache: error reading executable size: %v", err) 219 return 220 } 221 222 if executableLen > 0 { 223 executable, err := platform.MmapCodeSegment(int(executableLen)) 224 if err != nil { 225 err = fmt.Errorf("compilationcache: error mmapping executable (len=%d): %v", executableLen, err) 226 return nil, false, err 227 } 228 229 _, err = io.ReadFull(reader, executable) 230 if err != nil { 231 err = fmt.Errorf("compilationcache: error reading executable (len=%d): %v", executableLen, err) 232 return nil, false, err 233 } 234 235 expected := crc32.Checksum(executable, crc) 236 if _, err = io.ReadFull(reader, eightBytes[:4]); err != nil { 237 return nil, false, fmt.Errorf("compilationcache: could not read checksum: %v", err) 238 } else if checksum := binary.LittleEndian.Uint32(eightBytes[:4]); expected != checksum { 239 return nil, false, fmt.Errorf("compilationcache: checksum mismatch (expected %d, got %d)", expected, checksum) 240 } 241 242 if runtime.GOARCH == "arm64" { 243 // On arm64, we cannot give all of rwx at the same time, so we change it to exec. 244 if err = platform.MprotectRX(executable); err != nil { 245 return nil, false, err 246 } 247 } 248 cm.executable = executable 249 } 250 251 if _, err := io.ReadFull(reader, eightBytes[:1]); err != nil { 252 return nil, false, fmt.Errorf("compilationcache: error reading source map presence: %v", err) 253 } 254 255 if eightBytes[0] == 1 { 256 sm := &cm.sourceMap 257 sourceMapLen, err := readUint64(reader, &eightBytes) 258 if err != nil { 259 err = fmt.Errorf("compilationcache: error reading source map length: %v", err) 260 return nil, false, err 261 } 262 executableOffset := uintptr(unsafe.Pointer(&cm.executable[0])) 263 for i := uint64(0); i < sourceMapLen; i++ { 264 wasmBinaryOffset, err := readUint64(reader, &eightBytes) 265 if err != nil { 266 err = fmt.Errorf("compilationcache: error reading source map[%d] wasm binary offset: %v", i, err) 267 return nil, false, err 268 } 269 executableRelativeOffset, err := readUint64(reader, &eightBytes) 270 if err != nil { 271 err = fmt.Errorf("compilationcache: error reading source map[%d] executable offset: %v", i, err) 272 return nil, false, err 273 } 274 sm.wasmBinaryOffsets = append(sm.wasmBinaryOffsets, wasmBinaryOffset) 275 // executableOffsets is absolute address, so we need to add executableOffset. 276 sm.executableOffsets = append(sm.executableOffsets, uintptr(executableRelativeOffset)+executableOffset) 277 } 278 } 279 return 280 } 281 282 // readUint64 strictly reads an uint64 in little-endian byte order, using the 283 // given array as a buffer. This returns io.EOF if less than 8 bytes were read. 284 func readUint64(reader io.Reader, b *[8]byte) (uint64, error) { 285 s := b[0:8] 286 n, err := reader.Read(s) 287 if err != nil { 288 return 0, err 289 } else if n < 8 { // more strict than reader.Read 290 return 0, io.EOF 291 } 292 293 // Read the u64 from the underlying buffer. 294 ret := binary.LittleEndian.Uint64(s) 295 return ret, nil 296 }