github.com/tetratelabs/wazero@v1.7.3-0.20240513003603-48f702e154b5/internal/testing/nodiff/nodiff.go (about) 1 package nodiff 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "fmt" 8 "sort" 9 "strings" 10 "testing" 11 "unsafe" 12 13 "github.com/tetratelabs/wazero" 14 "github.com/tetratelabs/wazero/api" 15 "github.com/tetratelabs/wazero/experimental" 16 "github.com/tetratelabs/wazero/experimental/logging" 17 "github.com/tetratelabs/wazero/internal/testing/binaryencoding" 18 "github.com/tetratelabs/wazero/internal/testing/require" 19 "github.com/tetratelabs/wazero/internal/wasm" 20 ) 21 22 // We haven't had public APIs for referencing all the imported entries from wazero.CompiledModule, 23 // so we use the unsafe.Pointer and the internal memory layout to get the internal *wasm.Module 24 // from wazero.CompiledFunction. This must be synced with the struct definition of wazero.compiledModule (internal one). 25 func extractInternalWasmModuleFromCompiledModule(c wazero.CompiledModule) (*wasm.Module, error) { 26 // This is the internal representation of interface in Go. 27 // https://research.swtch.com/interfaces 28 type iface struct { 29 tp *byte 30 data unsafe.Pointer 31 } 32 33 // This corresponds to the unexported wazero.compiledModule to get *wasm.Module from wazero.CompiledModule interface. 34 type compiledModule struct { 35 module *wasm.Module 36 } 37 38 ciface := (*iface)(unsafe.Pointer(&c)) 39 if ciface == nil { 40 return nil, errors.New("invalid pointer") 41 } 42 cm := (*compiledModule)(ciface.data) 43 return cm.module, nil 44 } 45 46 // RequireNoDiffT is a wrapper of RequireNoDiff for testing.T. 47 func RequireNoDiffT(t *testing.T, wasmBin []byte, checkMemory, loggingCheck bool) { 48 RequireNoDiff(wasmBin, checkMemory, loggingCheck, func(err error) { require.NoError(t, err) }) 49 } 50 51 // RequireNoDiff ensures that the behavior is the same between the compiler and the interpreter for any given binary. 
52 func RequireNoDiff(wasmBin []byte, checkMemory, loggingCheck bool, requireNoError func(err error)) { 53 const features = api.CoreFeaturesV2 | experimental.CoreFeaturesThreads 54 compiler := wazero.NewRuntimeWithConfig(context.Background(), wazero.NewRuntimeConfigCompiler().WithCoreFeatures(features)) 55 interpreter := wazero.NewRuntimeWithConfig(context.Background(), wazero.NewRuntimeConfigInterpreter().WithCoreFeatures(features)) 56 defer compiler.Close(context.Background()) 57 defer interpreter.Close(context.Background()) 58 59 interpreterCtx, compilerCtx := context.Background(), context.Background() 60 var interPreterLoggingBuf, compilerLoggingBuf bytes.Buffer 61 var errorDuringInvocation bool 62 if loggingCheck { 63 interpreterCtx = experimental.WithFunctionListenerFactory(interpreterCtx, logging.NewLoggingListenerFactory(&interPreterLoggingBuf)) 64 compilerCtx = experimental.WithFunctionListenerFactory(compilerCtx, logging.NewLoggingListenerFactory(&compilerLoggingBuf)) 65 defer func() { 66 if !errorDuringInvocation { 67 if !bytes.Equal(compilerLoggingBuf.Bytes(), interPreterLoggingBuf.Bytes()) { 68 requireNoError(fmt.Errorf("logging mismatch\ncompiler: %s\ninterpreter: %s", 69 compilerLoggingBuf.String(), interPreterLoggingBuf.String())) 70 } 71 } 72 }() 73 } 74 75 compilerCompiled, err := compiler.CompileModule(compilerCtx, wasmBin) 76 if err != nil && strings.Contains(err.Error(), "has an empty module name") { 77 // This is the limitation wazero imposes to allow special-casing of anonymous modules. 78 return 79 } 80 requireNoError(err) 81 82 interpreterCompiled, err := interpreter.CompileModule(interpreterCtx, wasmBin) 83 requireNoError(err) 84 85 internalMod, err := extractInternalWasmModuleFromCompiledModule(compilerCompiled) 86 requireNoError(err) 87 88 if skip := ensureDummyImports(compiler, internalMod, requireNoError); skip { 89 return 90 } 91 ensureDummyImports(interpreter, internalMod, requireNoError) 92 93 // Instantiate module. 
94 compilerMod, compilerInstErr := compiler.InstantiateModule(compilerCtx, compilerCompiled, 95 wazero.NewModuleConfig().WithName(string(internalMod.ID[:]))) 96 interpreterMod, interpreterInstErr := interpreter.InstantiateModule(interpreterCtx, interpreterCompiled, 97 wazero.NewModuleConfig().WithName(string(internalMod.ID[:]))) 98 99 okToInvoke, err := ensureInstantiationError(compilerInstErr, interpreterInstErr) 100 requireNoError(err) 101 102 if okToInvoke { 103 err, errorDuringInvocation = ensureInvocationResultMatch( 104 compilerCtx, interpreterCtx, 105 compilerMod, interpreterMod, interpreterCompiled.ExportedFunctions()) 106 requireNoError(err) 107 108 compilerMem, _ := compilerMod.Memory().(*wasm.MemoryInstance) 109 interpreterMem, _ := interpreterMod.Memory().(*wasm.MemoryInstance) 110 if checkMemory && compilerMem != nil && interpreterMem != nil { 111 if !bytes.Equal(compilerMem.Buffer, interpreterMem.Buffer) { 112 requireNoError(errors.New("memory state mimsmatch")) 113 } 114 } 115 ensureMutableGlobalsMatch(compilerMod, interpreterMod, requireNoError) 116 } 117 } 118 119 func ensureMutableGlobalsMatch(compilerMod, interpreterMod api.Module, requireNoError func(err error)) { 120 ci, ii := compilerMod.(*wasm.ModuleInstance), interpreterMod.(*wasm.ModuleInstance) 121 if len(ci.Globals) == 0 { 122 return 123 } 124 var es []string 125 for i := range ci.Globals[:len(ci.Globals)-1] { // The last global is the fuel, so we can ignore it. 126 cg := ci.Globals[i] 127 ig := ii.Globals[i] 128 if !cg.Type.Mutable { 129 continue 130 } 131 132 cVal, cValHi := cg.Value() 133 iVal, iValHi := ig.Value() 134 135 var ok bool 136 switch ig.Type.ValType { 137 case wasm.ValueTypeI32, wasm.ValueTypeF32: 138 ok = uint32(cVal) == uint32(iVal) 139 case wasm.ValueTypeI64, wasm.ValueTypeF64: 140 ok = cVal == iVal 141 case wasm.ValueTypeV128: 142 ok = cVal == iVal && cValHi == iValHi 143 default: 144 ok = true // Ignore other types. 
145 } 146 147 if !ok { 148 if typ := ig.Type.ValType; typ == wasm.ValueTypeV128 { 149 es = append(es, fmt.Sprintf("\t[%d] %s: (%v,%v) != (%v,%v)", 150 i, wasm.ValueTypeName(wasm.ValueTypeV128), cVal, cValHi, iVal, iValHi)) 151 } else { 152 es = append(es, fmt.Sprintf("\t[%d] %s: %v != %v", 153 i, wasm.ValueTypeName(typ), cVal, iVal)) 154 } 155 } 156 } 157 if len(es) > 0 { 158 requireNoError(fmt.Errorf("mutable globals mismatch:\n%s", strings.Join(es, "\n"))) 159 } 160 } 161 162 // ensureDummyImports instantiates the modules which are required imports by `origin` *wasm.Module. 163 func ensureDummyImports(r wazero.Runtime, origin *wasm.Module, requireNoError func(err error)) (skip bool) { 164 impMods := make(map[string][]wasm.Import) 165 for _, imp := range origin.ImportSection { 166 if imp.Module == "" { 167 // Importing empty modules are forbidden as future work will allow multiple anonymous modules. 168 skip = true 169 return 170 } 171 impMods[imp.Module] = append(impMods[imp.Module], imp) 172 } 173 174 for mName, impMod := range impMods { 175 usedName := make(map[string]struct{}, len(impMod)) 176 m := &wasm.Module{NameSection: &wasm.NameSection{ModuleName: mName}} 177 178 for _, imp := range impMod { 179 _, ok := usedName[imp.Name] 180 if ok { 181 // Import segment can have duplicated "{module_name}.{name}" pair while it is prohibited for exports. 182 // Decision on allowing modules with these "ill" imports or not is up to embedder, and wazero chooses 183 // not to allow. Hence, we skip the entire case. 
184 // See "Note" at https://www.w3.org/TR/wasm-core-2/syntax/modules.html#imports 185 return true 186 } else { 187 usedName[imp.Name] = struct{}{} 188 } 189 190 var index uint32 191 switch imp.Type { 192 case wasm.ExternTypeFunc: 193 tp := origin.TypeSection[imp.DescFunc] 194 typeIdx := uint32(len(m.TypeSection)) 195 index = uint32(len(m.FunctionSection)) 196 m.FunctionSection = append(m.FunctionSection, typeIdx) 197 m.TypeSection = append(m.TypeSection, tp) 198 body := bytes.NewBuffer(nil) 199 for _, vt := range tp.Results { 200 switch vt { 201 case wasm.ValueTypeI32: 202 body.WriteByte(wasm.OpcodeI32Const) 203 body.WriteByte(0) 204 case wasm.ValueTypeI64: 205 body.WriteByte(wasm.OpcodeI64Const) 206 body.WriteByte(0) 207 case wasm.ValueTypeF32: 208 body.Write([]byte{wasm.OpcodeF32Const, 0, 0, 0, 0}) 209 case wasm.ValueTypeF64: 210 body.Write([]byte{wasm.OpcodeF64Const, 0, 0, 0, 0, 0, 0, 0, 0}) 211 case wasm.ValueTypeV128: 212 body.Write([]byte{ 213 wasm.OpcodeVecPrefix, wasm.OpcodeVecV128Const, 214 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 215 }) 216 case wasm.ValueTypeExternref: 217 body.Write([]byte{wasm.OpcodeRefNull, wasm.RefTypeExternref}) 218 case wasm.ValueTypeFuncref: 219 body.Write([]byte{wasm.OpcodeRefNull, wasm.RefTypeFuncref}) 220 } 221 } 222 body.WriteByte(wasm.OpcodeEnd) 223 m.CodeSection = append(m.CodeSection, wasm.Code{Body: body.Bytes()}) 224 case wasm.ExternTypeGlobal: 225 index = uint32(len(m.GlobalSection)) 226 var data []byte 227 var opcode byte 228 switch imp.DescGlobal.ValType { 229 case wasm.ValueTypeI32: 230 opcode = wasm.OpcodeI32Const 231 data = []byte{0} 232 case wasm.ValueTypeI64: 233 opcode = wasm.OpcodeI64Const 234 data = []byte{0} 235 case wasm.ValueTypeF32: 236 opcode = wasm.OpcodeF32Const 237 data = []byte{0, 0, 0, 0} 238 case wasm.ValueTypeF64: 239 opcode = wasm.OpcodeF64Const 240 data = []byte{0, 0, 0, 0, 0, 0, 0, 0} 241 case wasm.ValueTypeV128: 242 opcode = wasm.OpcodeVecPrefix 243 data = []byte{wasm.OpcodeVecV128Const, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} 244 case wasm.ValueTypeExternref: 245 opcode = wasm.OpcodeRefNull 246 data = []byte{wasm.RefTypeExternref} 247 case wasm.ValueTypeFuncref: 248 opcode = wasm.OpcodeRefNull 249 data = []byte{wasm.RefTypeFuncref} 250 } 251 m.GlobalSection = append(m.GlobalSection, wasm.Global{ 252 Type: imp.DescGlobal, Init: wasm.ConstantExpression{Opcode: opcode, Data: data}, 253 }) 254 case wasm.ExternTypeMemory: 255 m.MemorySection = imp.DescMem 256 index = 0 257 case wasm.ExternTypeTable: 258 index = uint32(len(m.TableSection)) 259 m.TableSection = append(m.TableSection, imp.DescTable) 260 } 261 m.ExportSection = append(m.ExportSection, wasm.Export{Type: imp.Type, Name: imp.Name, Index: index}) 262 } 263 _, err := r.Instantiate(context.Background(), binaryencoding.EncodeModule(m)) 264 requireNoError(err) 265 } 266 return 267 } 268 269 const valueTypeVector = 0x7b 270 271 // ensureInvocationResultMatch invokes all the exported functions from the module, and compare all the results between compiler vs interpreter. 272 func ensureInvocationResultMatch( 273 compilerCtx, interpreterCtx context.Context, compiledMod, interpreterMod api.Module, 274 exportedFunctions map[string]api.FunctionDefinition, 275 ) (err error, errorDuringInvocation bool) { 276 // In order to do the deterministic execution, we need to sort the exported functions. 277 var names []string 278 for f := range exportedFunctions { 279 names = append(names, f) 280 } 281 sort.Strings(names) 282 283 outer: 284 for _, name := range names { 285 def := exportedFunctions[name] 286 resultTypes := def.ResultTypes() 287 for _, rt := range resultTypes { 288 switch rt { 289 case api.ValueTypeI32, api.ValueTypeI64, api.ValueTypeF32, api.ValueTypeF64, valueTypeVector: 290 default: 291 // For the sake of simplicity in the assertion, we only invoke the function with the basic types. 
292 continue outer 293 } 294 } 295 296 cmpF := compiledMod.ExportedFunction(name) 297 intF := interpreterMod.ExportedFunction(name) 298 299 params := getDummyValues(def.ParamTypes()) 300 cmpRes, cmpErr := cmpF.Call(compilerCtx, params...) 301 intRes, intErr := intF.Call(interpreterCtx, params...) 302 errorDuringInvocation = errorDuringInvocation || cmpErr != nil || intErr != nil 303 if errMismatch := ensureInvocationError(cmpErr, intErr); errMismatch != nil { 304 err = errors.Join(err, fmt.Errorf("error mismatch on invoking %s: %v", name, errMismatch)) 305 continue 306 } 307 308 matched := true 309 var typesIndex int 310 for i := 0; i < len(cmpRes); i++ { 311 switch resultTypes[typesIndex] { 312 case api.ValueTypeI32, api.ValueTypeF32: 313 matched = matched && uint32(cmpRes[i]) == uint32(intRes[i]) 314 case api.ValueTypeI64, api.ValueTypeF64: 315 matched = matched && cmpRes[i] == intRes[i] 316 case valueTypeVector: 317 matched = matched && cmpRes[i] == intRes[i] && cmpRes[i+1] == intRes[i+1] 318 i++ // We need to advance twice (lower and higher 64bits) 319 } 320 typesIndex++ 321 } 322 323 if !matched { 324 err = errors.Join(err, fmt.Errorf("result mismatch on invoking '%s':\n\tinterpreter got: %v\n\tcompiler got: %v", name, intRes, cmpRes)) 325 } 326 } 327 return 328 } 329 330 // getDummyValues returns a dummy input values for function invocations. 331 func getDummyValues(valueTypes []api.ValueType) (ret []uint64) { 332 for _, vt := range valueTypes { 333 if vt != 0x7b { // v128 334 ret = append(ret, 0) 335 } else { 336 ret = append(ret, 0, 0) 337 } 338 } 339 return 340 } 341 342 // ensureInvocationError ensures that function invocation errors returned by interpreter and compiler match each other's. 
// Only the first line of each error message is compared, so anything after the first newline is
// ignored. Two situations are tolerated on purpose: the compiler reporting "stack overflow" while the
// interpreter reports "unreachable", and both engines reporting "stack overflow".
// It returns nil when the errors are considered equivalent, and a descriptive error otherwise.
func ensureInvocationError(compilerErr, interpErr error) error {
	switch {
	case compilerErr == nil && interpErr == nil:
		return nil
	case compilerErr == nil:
		return fmt.Errorf("compiler returned no error, but interpreter got: %w", interpErr)
	case interpErr == nil:
		return fmt.Errorf("interpreter returned no error, but compiler got: %w", compilerErr)
	}

	// Keep only the first line of each message; strings.Cut returns the whole string when there is no newline.
	compilerErrMsg, _, _ := strings.Cut(compilerErr.Error(), "\n")
	interpErrMsg, _, _ := strings.Cut(interpErr.Error(), "\n")

	if strings.Contains(compilerErrMsg, "stack overflow") {
		if strings.Contains(interpErrMsg, "unreachable") {
			// Compiler is more likely to reach stack overflow than interpreter, so we allow this case. This case is most likely
			// that interpreter reached the unreachable out of "fuel".
			return nil
		}
		if strings.Contains(interpErrMsg, "stack overflow") {
			// Both compiler and interpreter reached stack overflow, so we ignore diff in the content of the traces.
			return nil
		}
	}

	if compilerErrMsg != interpErrMsg {
		return fmt.Errorf("error mismatch:\n\tinterpreter: %v\n\tcompiler: %v", interpErr, compilerErr)
	}
	return nil
}

// ensureInstantiationError ensures that instantiation errors returned by interpreter and compiler match each other's.
// Only the first line of each error message is compared. okToInvoke is true only when both
// instantiations succeeded. When both failed, err is nil if the compiler reported "stack overflow"
// while the interpreter reported "unreachable", or if both first lines are identical and accepted by
// allowedErrorDuringInstantiation; otherwise err describes the mismatch.
func ensureInstantiationError(compilerErr, interpErr error) (okToInvoke bool, err error) {
	switch {
	case compilerErr == nil && interpErr == nil:
		return true, nil
	case compilerErr == nil:
		return false, fmt.Errorf("compiler returned no error, but interpreter got: %w", interpErr)
	case interpErr == nil:
		return false, fmt.Errorf("interpreter returned no error, but compiler got: %w", compilerErr)
	}

	// Keep only the first line of each message; strings.Cut returns the whole string when there is no newline.
	compilerErrMsg, _, _ := strings.Cut(compilerErr.Error(), "\n")
	interpErrMsg, _, _ := strings.Cut(interpErr.Error(), "\n")

	if strings.Contains(compilerErrMsg, "stack overflow") && strings.Contains(interpErrMsg, "unreachable") {
		// This is the case where the compiler reached stack overflow, but the interpreter reached the unreachable out of "fuel" during
		// start function invocation. This is fine.
		return false, nil
	}

	if !allowedErrorDuringInstantiation(compilerErrMsg) {
		return false, fmt.Errorf("invalid error occur with compiler: %v vs interpreter: %v", compilerErr, interpErr)
	}
	if !allowedErrorDuringInstantiation(interpErrMsg) {
		return false, fmt.Errorf("invalid error occur with interpreter: %v vs compiler: %v", interpErr, compilerErr)
	}

	if compilerErrMsg != interpErrMsg {
		return false, fmt.Errorf("error mismatch:\n\tinterpreter: %v\n\tcompiler: %v", interpErr, compilerErr)
	}
	return false, nil
}

// allowedErrorDuringInstantiation checks if the error message is considered sane.
// It returns true for data-segment out-of-bounds messages and for start-function failures that carry
// a wasm error, and false for anything else.
func allowedErrorDuringInstantiation(errMsg string) bool {
	// This happens when data segment causes out of bound, but it is considered as runtime-error in WebAssembly 2.0
	// which is fine.
	if strings.HasPrefix(errMsg, "data[") && strings.HasSuffix(errMsg, "]: out of bounds memory access") {
		return true
	}

	// Start function failure is neither instantiation nor compilation error, but rather a runtime error, so that is fine.
	return strings.HasPrefix(errMsg, "start function[") && strings.Contains(errMsg, "failed: wasm error:")
}