github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/compiler/engine_cache_test.go (about) 1 package compiler 2 3 import ( 4 "bytes" 5 "crypto/sha256" 6 "encoding/binary" 7 "errors" 8 "io" 9 "math" 10 "testing" 11 "testing/iotest" 12 13 "github.com/bananabytelabs/wazero/internal/asm" 14 "github.com/bananabytelabs/wazero/internal/filecache" 15 "github.com/bananabytelabs/wazero/internal/testing/require" 16 "github.com/bananabytelabs/wazero/internal/u32" 17 "github.com/bananabytelabs/wazero/internal/u64" 18 "github.com/bananabytelabs/wazero/internal/wasm" 19 ) 20 21 var testVersion = "" 22 23 func concat(ins ...[]byte) (ret []byte) { 24 for _, in := range ins { 25 ret = append(ret, in...) 26 } 27 return 28 } 29 30 func makeCodeSegment(bytes ...byte) asm.CodeSegment { 31 return *asm.NewCodeSegment(bytes) 32 } 33 34 func TestSerializeCompiledModule(t *testing.T) { 35 tests := []struct { 36 in *compiledModule 37 exp []byte 38 }{ 39 { 40 in: &compiledModule{ 41 compiledCode: &compiledCode{ 42 executable: makeCodeSegment(1, 2, 3, 4, 5), 43 }, 44 functions: []compiledFunction{ 45 {executableOffset: 0, stackPointerCeil: 12345}, 46 }, 47 }, 48 exp: concat( 49 []byte(wazeroMagic), 50 []byte{byte(len(testVersion))}, 51 []byte(testVersion), 52 []byte{0}, // ensure termination. 53 u32.LeBytes(1), // number of functions. 54 u64.LeBytes(12345), // stack pointer ceil. 55 u64.LeBytes(0), // offset. 56 u64.LeBytes(5), // length of code. 57 []byte{1, 2, 3, 4, 5}, // code. 58 ), 59 }, 60 { 61 in: &compiledModule{ 62 compiledCode: &compiledCode{ 63 executable: makeCodeSegment(1, 2, 3, 4, 5), 64 }, 65 functions: []compiledFunction{ 66 {executableOffset: 0, stackPointerCeil: 12345}, 67 }, 68 ensureTermination: true, 69 }, 70 exp: concat( 71 []byte(wazeroMagic), 72 []byte{byte(len(testVersion))}, 73 []byte(testVersion), 74 []byte{1}, // ensure termination. 75 u32.LeBytes(1), // number of functions. 76 u64.LeBytes(12345), // stack pointer ceil. 77 u64.LeBytes(0), // offset. 78 u64.LeBytes(5), // length of code. 79 []byte{1, 2, 3, 4, 5}, // code. 80 ), 81 }, 82 { 83 in: &compiledModule{ 84 compiledCode: &compiledCode{ 85 executable: makeCodeSegment(1, 2, 3, 4, 5, 1, 2, 3), 86 }, 87 functions: []compiledFunction{ 88 {executableOffset: 0, stackPointerCeil: 12345}, 89 {executableOffset: 5, stackPointerCeil: 0xffffffff}, 90 }, 91 ensureTermination: true, 92 }, 93 exp: concat( 94 []byte(wazeroMagic), 95 []byte{byte(len(testVersion))}, 96 []byte(testVersion), 97 []byte{1}, // ensure termination. 98 u32.LeBytes(2), // number of functions. 99 // Function index = 0. 100 u64.LeBytes(12345), // stack pointer ceil. 101 u64.LeBytes(0), // offset. 102 // Function index = 1. 103 u64.LeBytes(0xffffffff), // stack pointer ceil. 104 u64.LeBytes(5), // offset. 105 // Executable. 106 u64.LeBytes(8), // length of code. 107 []byte{1, 2, 3, 4, 5, 1, 2, 3}, // code. 108 ), 109 }, 110 } 111 112 for i, tc := range tests { 113 actual, err := io.ReadAll(serializeCompiledModule(testVersion, tc.in)) 114 require.NoError(t, err, i) 115 require.Equal(t, tc.exp, actual, i) 116 } 117 } 118 119 func TestDeserializeCompiledModule(t *testing.T) { 120 tests := []struct { 121 name string 122 in []byte 123 importedFunctionCount uint32 124 expCompiledModule *compiledModule 125 expStaleCache bool 126 expErr string 127 }{ 128 { 129 name: "invalid header", 130 in: []byte{1}, 131 expErr: "compilationcache: invalid header length: 1", 132 }, 133 { 134 name: "version mismatch", 135 in: concat( 136 []byte(wazeroMagic), 137 []byte{byte(len("1233123.1.1"))}, 138 []byte("1233123.1.1"), 139 u32.LeBytes(1), // number of functions. 140 ), 141 expStaleCache: true, 142 }, 143 { 144 name: "version mismatch", 145 in: concat( 146 []byte(wazeroMagic), 147 []byte{byte(len("1"))}, 148 []byte("1"), 149 u32.LeBytes(1), // number of functions. 150 ), 151 expStaleCache: true, 152 }, 153 { 154 name: "one function", 155 in: concat( 156 []byte(wazeroMagic), 157 []byte{byte(len(testVersion))}, 158 []byte(testVersion), 159 []byte{0}, // ensure termination. 160 u32.LeBytes(1), // number of functions. 161 u64.LeBytes(12345), // stack pointer ceil. 162 u64.LeBytes(0), // offset. 163 // Executable. 164 u64.LeBytes(5), // size. 165 []byte{1, 2, 3, 4, 5}, // machine code. 166 ), 167 expCompiledModule: &compiledModule{ 168 compiledCode: &compiledCode{ 169 executable: makeCodeSegment(1, 2, 3, 4, 5), 170 }, 171 functions: []compiledFunction{ 172 {executableOffset: 0, stackPointerCeil: 12345, index: 0}, 173 }, 174 }, 175 expStaleCache: false, 176 expErr: "", 177 }, 178 { 179 name: "one function with ensure termination", 180 in: concat( 181 []byte(wazeroMagic), 182 []byte{byte(len(testVersion))}, 183 []byte(testVersion), 184 []byte{1}, // ensure termination. 185 u32.LeBytes(1), // number of functions. 186 u64.LeBytes(12345), // stack pointer ceil. 187 u64.LeBytes(0), // offset. 188 u64.LeBytes(5), // length of code. 189 []byte{1, 2, 3, 4, 5}, // code. 190 ), 191 expCompiledModule: &compiledModule{ 192 compiledCode: &compiledCode{ 193 executable: makeCodeSegment(1, 2, 3, 4, 5), 194 }, 195 functions: []compiledFunction{{executableOffset: 0, stackPointerCeil: 12345, index: 0}}, 196 ensureTermination: true, 197 }, 198 expStaleCache: false, 199 expErr: "", 200 }, 201 { 202 name: "two functions", 203 in: concat( 204 []byte(wazeroMagic), 205 []byte{byte(len(testVersion))}, 206 []byte(testVersion), 207 []byte{0}, // ensure termination. 208 u32.LeBytes(2), // number of functions. 209 // Function index = 0. 210 u64.LeBytes(12345), // stack pointer ceil. 211 u64.LeBytes(0), // offset. 212 // Function index = 1. 213 u64.LeBytes(0xffffffff), // stack pointer ceil. 214 u64.LeBytes(7), // offset. 215 // Executable. 216 u64.LeBytes(10), // size. 217 []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // machine code. 218 ), 219 importedFunctionCount: 1, 220 expCompiledModule: &compiledModule{ 221 compiledCode: &compiledCode{ 222 executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 223 }, 224 functions: []compiledFunction{ 225 {executableOffset: 0, stackPointerCeil: 12345, index: 1}, 226 {executableOffset: 7, stackPointerCeil: 0xffffffff, index: 2}, 227 }, 228 }, 229 expStaleCache: false, 230 expErr: "", 231 }, 232 { 233 name: "reading stack pointer", 234 in: concat( 235 []byte(wazeroMagic), 236 []byte{byte(len(testVersion))}, 237 []byte(testVersion), 238 []byte{0}, // ensure termination. 239 u32.LeBytes(2), // number of functions. 240 // Function index = 0. 241 u64.LeBytes(12345), // stack pointer ceil. 242 u64.LeBytes(5), // offset. 243 // Function index = 1. 244 ), 245 expErr: "compilationcache: error reading func[1] stack pointer ceil: EOF", 246 }, 247 { 248 name: "reading executable offset", 249 in: concat( 250 []byte(wazeroMagic), 251 []byte{byte(len(testVersion))}, 252 []byte(testVersion), 253 []byte{0}, // ensure termination. 254 u32.LeBytes(2), // number of functions. 255 // Function index = 0. 256 u64.LeBytes(12345), // stack pointer ceil. 257 u64.LeBytes(5), // offset. 258 // Function index = 1. 259 u64.LeBytes(12345), // stack pointer ceil. 260 ), 261 expErr: "compilationcache: error reading func[1] executable offset: EOF", 262 }, 263 { 264 name: "mmapping", 265 in: concat( 266 []byte(wazeroMagic), 267 []byte{byte(len(testVersion))}, 268 []byte(testVersion), 269 []byte{0}, // ensure termination. 270 u32.LeBytes(2), // number of functions. 271 // Function index = 0. 272 u64.LeBytes(12345), // stack pointer ceil. 273 u64.LeBytes(0), // offset. 274 // Function index = 1. 275 u64.LeBytes(12345), // stack pointer ceil. 276 u64.LeBytes(5), // offset. 277 // Executable. 278 u64.LeBytes(5), // size of the executable. 279 // Lack of machine code here. 280 ), 281 expErr: "compilationcache: error reading executable (len=5): EOF", 282 }, 283 } 284 285 for _, tc := range tests { 286 tc := tc 287 t.Run(tc.name, func(t *testing.T) { 288 cm, staleCache, err := deserializeCompiledModule(testVersion, io.NopCloser(bytes.NewReader(tc.in)), 289 &wasm.Module{ImportFunctionCount: tc.importedFunctionCount}) 290 291 if tc.expCompiledModule != nil { 292 require.Equal(t, len(tc.expCompiledModule.functions), len(cm.functions)) 293 for i := 0; i < len(cm.functions); i++ { 294 require.Equal(t, cm.compiledCode, cm.functions[i].parent) 295 tc.expCompiledModule.functions[i].parent = cm.compiledCode 296 } 297 } 298 299 if tc.expErr != "" { 300 require.EqualError(t, err, tc.expErr) 301 } else { 302 require.NoError(t, err) 303 require.Equal(t, tc.expCompiledModule, cm) 304 } 305 306 require.Equal(t, tc.expStaleCache, staleCache) 307 }) 308 } 309 } 310 311 func TestEngine_getCompiledModuleFromCache(t *testing.T) { 312 valid := concat( 313 []byte(wazeroMagic), 314 []byte{byte(len(testVersion))}, 315 []byte(testVersion), 316 []byte{0}, // ensure termination. 317 u32.LeBytes(2), // number of functions. 318 // Function index = 0. 319 u64.LeBytes(12345), // stack pointer ceil. 320 u64.LeBytes(0), // offset. 321 // Function index = 1. 322 u64.LeBytes(0xffffffff), // stack pointer ceil. 323 u64.LeBytes(5), // offset. 324 // executables. 325 u64.LeBytes(10), // length of code. 326 []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // code. 327 ) 328 329 tests := []struct { 330 name string 331 ext map[wasm.ModuleID][]byte 332 key wasm.ModuleID 333 isHostMod bool 334 expCompiledModule *compiledModule 335 expHit bool 336 expErr string 337 expDeleted bool 338 }{ 339 {name: "extern cache not given"}, 340 { 341 name: "not hit", 342 ext: map[wasm.ModuleID][]byte{}, 343 }, 344 { 345 name: "host module", 346 ext: map[wasm.ModuleID][]byte{{}: valid}, 347 isHostMod: true, 348 }, 349 { 350 name: "error in Cache.Get", 351 ext: map[wasm.ModuleID][]byte{{}: {}}, 352 expErr: "compilationcache: error reading header: EOF", 353 }, 354 { 355 name: "error in deserialization", 356 ext: map[wasm.ModuleID][]byte{{}: {1, 2, 3}}, 357 expErr: "compilationcache: invalid header length: 3", 358 }, 359 { 360 name: "stale cache", 361 ext: map[wasm.ModuleID][]byte{{}: concat( 362 []byte(wazeroMagic), 363 []byte{byte(len("1233123.1.1"))}, 364 []byte("1233123.1.1"), 365 u32.LeBytes(1), // number of functions. 366 )}, 367 expDeleted: true, 368 }, 369 { 370 name: "hit", 371 ext: map[wasm.ModuleID][]byte{ 372 {}: valid, 373 }, 374 expHit: true, 375 expCompiledModule: &compiledModule{ 376 compiledCode: &compiledCode{ 377 executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 378 }, 379 functions: []compiledFunction{ 380 {stackPointerCeil: 12345, executableOffset: 0, index: 0}, 381 {stackPointerCeil: 0xffffffff, executableOffset: 5, index: 1}, 382 }, 383 }, 384 }, 385 } 386 387 for _, tc := range tests { 388 tc := tc 389 t.Run(tc.name, func(t *testing.T) { 390 m := &wasm.Module{ID: tc.key, IsHostModule: tc.isHostMod} 391 if exp := tc.expCompiledModule; exp != nil { 392 exp.source = m 393 for i := range tc.expCompiledModule.functions { 394 tc.expCompiledModule.functions[i].parent = exp.compiledCode 395 } 396 } 397 398 e := engine{} 399 if tc.ext != nil { 400 tmp := t.TempDir() 401 e.fileCache = filecache.New(tmp) 402 for key, value := range tc.ext { 403 err := e.fileCache.Add(key, bytes.NewReader(value)) 404 require.NoError(t, err) 405 } 406 } 407 408 codes, hit, err := e.getCompiledModuleFromCache(m) 409 if tc.expErr != "" { 410 require.EqualError(t, err, tc.expErr) 411 } else { 412 require.NoError(t, err) 413 } 414 415 require.Equal(t, tc.expHit, hit) 416 require.Equal(t, tc.expCompiledModule, codes) 417 418 if tc.ext != nil && tc.expDeleted { 419 _, hit, err := e.fileCache.Get(tc.key) 420 require.NoError(t, err) 421 require.False(t, hit) 422 } 423 }) 424 } 425 } 426 427 func TestEngine_addCompiledModuleToCache(t *testing.T) { 428 t.Run("not defined", func(t *testing.T) { 429 e := engine{} 430 err := e.addCompiledModuleToCache(nil, nil) 431 require.NoError(t, err) 432 }) 433 t.Run("host module", func(t *testing.T) { 434 tc := filecache.New(t.TempDir()) 435 e := engine{fileCache: tc} 436 cm := &compiledModule{ 437 compiledCode: &compiledCode{ 438 executable: makeCodeSegment(1, 2, 3), 439 }, 440 functions: []compiledFunction{{stackPointerCeil: 123}}, 441 } 442 m := &wasm.Module{ID: sha256.Sum256(nil), IsHostModule: true} // Host module! 443 err := e.addCompiledModuleToCache(m, cm) 444 require.NoError(t, err) 445 // Check the host module not cached. 446 _, hit, err := tc.Get(m.ID) 447 require.NoError(t, err) 448 require.False(t, hit) 449 }) 450 t.Run("add", func(t *testing.T) { 451 tc := filecache.New(t.TempDir()) 452 e := engine{fileCache: tc} 453 m := &wasm.Module{} 454 cm := &compiledModule{ 455 compiledCode: &compiledCode{ 456 executable: makeCodeSegment(1, 2, 3), 457 }, 458 functions: []compiledFunction{{stackPointerCeil: 123}}, 459 } 460 err := e.addCompiledModuleToCache(m, cm) 461 require.NoError(t, err) 462 463 content, ok, err := tc.Get(m.ID) 464 require.NoError(t, err) 465 require.True(t, ok) 466 actual, err := io.ReadAll(content) 467 require.NoError(t, err) 468 require.Equal(t, concat( 469 []byte(wazeroMagic), 470 []byte{byte(len(testVersion))}, 471 []byte(testVersion), 472 []byte{0}, 473 u32.LeBytes(1), // number of functions. 474 u64.LeBytes(123), // stack pointer ceil. 475 u64.LeBytes(0), // offset. 476 u64.LeBytes(3), // size of executable. 477 []byte{1, 2, 3}, 478 ), actual) 479 require.NoError(t, content.Close()) 480 }) 481 } 482 483 func Test_readUint64(t *testing.T) { 484 tests := []struct { 485 name string 486 input uint64 487 }{ 488 { 489 name: "zero", 490 input: 0, 491 }, 492 { 493 name: "half", 494 input: math.MaxUint32, 495 }, 496 { 497 name: "max", 498 input: math.MaxUint64, 499 }, 500 } 501 502 for _, tt := range tests { 503 tc := tt 504 505 t.Run(tc.name, func(t *testing.T) { 506 input := make([]byte, 8) 507 binary.LittleEndian.PutUint64(input, tc.input) 508 509 var b [8]byte 510 n, err := readUint64(bytes.NewReader(input), &b) 511 require.NoError(t, err) 512 require.Equal(t, tc.input, n) 513 514 // ensure the buffer was cleared 515 var expectedB [8]byte 516 require.Equal(t, expectedB, b) 517 }) 518 } 519 } 520 521 func Test_readUint64_errors(t *testing.T) { 522 tests := []struct { 523 name string 524 input io.Reader 525 expectedErr string 526 }{ 527 { 528 name: "zero", 529 input: bytes.NewReader([]byte{}), 530 expectedErr: "EOF", 531 }, 532 { 533 name: "not enough", 534 input: bytes.NewReader([]byte{1, 2}), 535 expectedErr: "EOF", 536 }, 537 { 538 name: "error reading", 539 input: iotest.ErrReader(errors.New("ice cream")), 540 expectedErr: "ice cream", 541 }, 542 } 543 544 for _, tt := range tests { 545 tc := tt 546 547 t.Run(tc.name, func(t *testing.T) { 548 var b [8]byte 549 _, err := readUint64(tc.input, &b) 550 require.EqualError(t, err, tc.expectedErr) 551 }) 552 } 553 }