github.com/tetratelabs/wazero@v1.7.3-0.20240513003603-48f702e154b5/internal/testing/nodiff/nodiff.go (about)

     1  package nodiff
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"sort"
     9  	"strings"
    10  	"testing"
    11  	"unsafe"
    12  
    13  	"github.com/tetratelabs/wazero"
    14  	"github.com/tetratelabs/wazero/api"
    15  	"github.com/tetratelabs/wazero/experimental"
    16  	"github.com/tetratelabs/wazero/experimental/logging"
    17  	"github.com/tetratelabs/wazero/internal/testing/binaryencoding"
    18  	"github.com/tetratelabs/wazero/internal/testing/require"
    19  	"github.com/tetratelabs/wazero/internal/wasm"
    20  )
    21  
    22  // We haven't had public APIs for referencing all the imported entries from wazero.CompiledModule,
    23  // so we use the unsafe.Pointer and the internal memory layout to get the internal *wasm.Module
    24  // from wazero.CompiledFunction.  This must be synced with the struct definition of wazero.compiledModule (internal one).
    25  func extractInternalWasmModuleFromCompiledModule(c wazero.CompiledModule) (*wasm.Module, error) {
    26  	// This is the internal representation of interface in Go.
    27  	// https://research.swtch.com/interfaces
    28  	type iface struct {
    29  		tp   *byte
    30  		data unsafe.Pointer
    31  	}
    32  
    33  	// This corresponds to the unexported wazero.compiledModule to get *wasm.Module from wazero.CompiledModule interface.
    34  	type compiledModule struct {
    35  		module *wasm.Module
    36  	}
    37  
    38  	ciface := (*iface)(unsafe.Pointer(&c))
    39  	if ciface == nil {
    40  		return nil, errors.New("invalid pointer")
    41  	}
    42  	cm := (*compiledModule)(ciface.data)
    43  	return cm.module, nil
    44  }
    45  
    46  // RequireNoDiffT is a wrapper of RequireNoDiff for testing.T.
    47  func RequireNoDiffT(t *testing.T, wasmBin []byte, checkMemory, loggingCheck bool) {
    48  	RequireNoDiff(wasmBin, checkMemory, loggingCheck, func(err error) { require.NoError(t, err) })
    49  }
    50  
    51  // RequireNoDiff ensures that the behavior is the same between the compiler and the interpreter for any given binary.
    52  func RequireNoDiff(wasmBin []byte, checkMemory, loggingCheck bool, requireNoError func(err error)) {
    53  	const features = api.CoreFeaturesV2 | experimental.CoreFeaturesThreads
    54  	compiler := wazero.NewRuntimeWithConfig(context.Background(), wazero.NewRuntimeConfigCompiler().WithCoreFeatures(features))
    55  	interpreter := wazero.NewRuntimeWithConfig(context.Background(), wazero.NewRuntimeConfigInterpreter().WithCoreFeatures(features))
    56  	defer compiler.Close(context.Background())
    57  	defer interpreter.Close(context.Background())
    58  
    59  	interpreterCtx, compilerCtx := context.Background(), context.Background()
    60  	var interPreterLoggingBuf, compilerLoggingBuf bytes.Buffer
    61  	var errorDuringInvocation bool
    62  	if loggingCheck {
    63  		interpreterCtx = experimental.WithFunctionListenerFactory(interpreterCtx, logging.NewLoggingListenerFactory(&interPreterLoggingBuf))
    64  		compilerCtx = experimental.WithFunctionListenerFactory(compilerCtx, logging.NewLoggingListenerFactory(&compilerLoggingBuf))
    65  		defer func() {
    66  			if !errorDuringInvocation {
    67  				if !bytes.Equal(compilerLoggingBuf.Bytes(), interPreterLoggingBuf.Bytes()) {
    68  					requireNoError(fmt.Errorf("logging mismatch\ncompiler: %s\ninterpreter: %s",
    69  						compilerLoggingBuf.String(), interPreterLoggingBuf.String()))
    70  				}
    71  			}
    72  		}()
    73  	}
    74  
    75  	compilerCompiled, err := compiler.CompileModule(compilerCtx, wasmBin)
    76  	if err != nil && strings.Contains(err.Error(), "has an empty module name") {
    77  		// This is the limitation wazero imposes to allow special-casing of anonymous modules.
    78  		return
    79  	}
    80  	requireNoError(err)
    81  
    82  	interpreterCompiled, err := interpreter.CompileModule(interpreterCtx, wasmBin)
    83  	requireNoError(err)
    84  
    85  	internalMod, err := extractInternalWasmModuleFromCompiledModule(compilerCompiled)
    86  	requireNoError(err)
    87  
    88  	if skip := ensureDummyImports(compiler, internalMod, requireNoError); skip {
    89  		return
    90  	}
    91  	ensureDummyImports(interpreter, internalMod, requireNoError)
    92  
    93  	// Instantiate module.
    94  	compilerMod, compilerInstErr := compiler.InstantiateModule(compilerCtx, compilerCompiled,
    95  		wazero.NewModuleConfig().WithName(string(internalMod.ID[:])))
    96  	interpreterMod, interpreterInstErr := interpreter.InstantiateModule(interpreterCtx, interpreterCompiled,
    97  		wazero.NewModuleConfig().WithName(string(internalMod.ID[:])))
    98  
    99  	okToInvoke, err := ensureInstantiationError(compilerInstErr, interpreterInstErr)
   100  	requireNoError(err)
   101  
   102  	if okToInvoke {
   103  		err, errorDuringInvocation = ensureInvocationResultMatch(
   104  			compilerCtx, interpreterCtx,
   105  			compilerMod, interpreterMod, interpreterCompiled.ExportedFunctions())
   106  		requireNoError(err)
   107  
   108  		compilerMem, _ := compilerMod.Memory().(*wasm.MemoryInstance)
   109  		interpreterMem, _ := interpreterMod.Memory().(*wasm.MemoryInstance)
   110  		if checkMemory && compilerMem != nil && interpreterMem != nil {
   111  			if !bytes.Equal(compilerMem.Buffer, interpreterMem.Buffer) {
   112  				requireNoError(errors.New("memory state mimsmatch"))
   113  			}
   114  		}
   115  		ensureMutableGlobalsMatch(compilerMod, interpreterMod, requireNoError)
   116  	}
   117  }
   118  
   119  func ensureMutableGlobalsMatch(compilerMod, interpreterMod api.Module, requireNoError func(err error)) {
   120  	ci, ii := compilerMod.(*wasm.ModuleInstance), interpreterMod.(*wasm.ModuleInstance)
   121  	if len(ci.Globals) == 0 {
   122  		return
   123  	}
   124  	var es []string
   125  	for i := range ci.Globals[:len(ci.Globals)-1] { // The last global is the fuel, so we can ignore it.
   126  		cg := ci.Globals[i]
   127  		ig := ii.Globals[i]
   128  		if !cg.Type.Mutable {
   129  			continue
   130  		}
   131  
   132  		cVal, cValHi := cg.Value()
   133  		iVal, iValHi := ig.Value()
   134  
   135  		var ok bool
   136  		switch ig.Type.ValType {
   137  		case wasm.ValueTypeI32, wasm.ValueTypeF32:
   138  			ok = uint32(cVal) == uint32(iVal)
   139  		case wasm.ValueTypeI64, wasm.ValueTypeF64:
   140  			ok = cVal == iVal
   141  		case wasm.ValueTypeV128:
   142  			ok = cVal == iVal && cValHi == iValHi
   143  		default:
   144  			ok = true // Ignore other types.
   145  		}
   146  
   147  		if !ok {
   148  			if typ := ig.Type.ValType; typ == wasm.ValueTypeV128 {
   149  				es = append(es, fmt.Sprintf("\t[%d] %s: (%v,%v) != (%v,%v)",
   150  					i, wasm.ValueTypeName(wasm.ValueTypeV128), cVal, cValHi, iVal, iValHi))
   151  			} else {
   152  				es = append(es, fmt.Sprintf("\t[%d] %s: %v != %v",
   153  					i, wasm.ValueTypeName(typ), cVal, iVal))
   154  			}
   155  		}
   156  	}
   157  	if len(es) > 0 {
   158  		requireNoError(fmt.Errorf("mutable globals mismatch:\n%s", strings.Join(es, "\n")))
   159  	}
   160  }
   161  
   162  // ensureDummyImports instantiates the modules which are required imports by `origin` *wasm.Module.
   163  func ensureDummyImports(r wazero.Runtime, origin *wasm.Module, requireNoError func(err error)) (skip bool) {
   164  	impMods := make(map[string][]wasm.Import)
   165  	for _, imp := range origin.ImportSection {
   166  		if imp.Module == "" {
   167  			// Importing empty modules are forbidden as future work will allow multiple anonymous modules.
   168  			skip = true
   169  			return
   170  		}
   171  		impMods[imp.Module] = append(impMods[imp.Module], imp)
   172  	}
   173  
   174  	for mName, impMod := range impMods {
   175  		usedName := make(map[string]struct{}, len(impMod))
   176  		m := &wasm.Module{NameSection: &wasm.NameSection{ModuleName: mName}}
   177  
   178  		for _, imp := range impMod {
   179  			_, ok := usedName[imp.Name]
   180  			if ok {
   181  				// Import segment can have duplicated "{module_name}.{name}" pair while it is prohibited for exports.
   182  				// Decision on allowing modules with these "ill" imports or not is up to embedder, and wazero chooses
   183  				// not to allow. Hence, we skip the entire case.
   184  				// See "Note" at https://www.w3.org/TR/wasm-core-2/syntax/modules.html#imports
   185  				return true
   186  			} else {
   187  				usedName[imp.Name] = struct{}{}
   188  			}
   189  
   190  			var index uint32
   191  			switch imp.Type {
   192  			case wasm.ExternTypeFunc:
   193  				tp := origin.TypeSection[imp.DescFunc]
   194  				typeIdx := uint32(len(m.TypeSection))
   195  				index = uint32(len(m.FunctionSection))
   196  				m.FunctionSection = append(m.FunctionSection, typeIdx)
   197  				m.TypeSection = append(m.TypeSection, tp)
   198  				body := bytes.NewBuffer(nil)
   199  				for _, vt := range tp.Results {
   200  					switch vt {
   201  					case wasm.ValueTypeI32:
   202  						body.WriteByte(wasm.OpcodeI32Const)
   203  						body.WriteByte(0)
   204  					case wasm.ValueTypeI64:
   205  						body.WriteByte(wasm.OpcodeI64Const)
   206  						body.WriteByte(0)
   207  					case wasm.ValueTypeF32:
   208  						body.Write([]byte{wasm.OpcodeF32Const, 0, 0, 0, 0})
   209  					case wasm.ValueTypeF64:
   210  						body.Write([]byte{wasm.OpcodeF64Const, 0, 0, 0, 0, 0, 0, 0, 0})
   211  					case wasm.ValueTypeV128:
   212  						body.Write([]byte{
   213  							wasm.OpcodeVecPrefix, wasm.OpcodeVecV128Const,
   214  							0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
   215  						})
   216  					case wasm.ValueTypeExternref:
   217  						body.Write([]byte{wasm.OpcodeRefNull, wasm.RefTypeExternref})
   218  					case wasm.ValueTypeFuncref:
   219  						body.Write([]byte{wasm.OpcodeRefNull, wasm.RefTypeFuncref})
   220  					}
   221  				}
   222  				body.WriteByte(wasm.OpcodeEnd)
   223  				m.CodeSection = append(m.CodeSection, wasm.Code{Body: body.Bytes()})
   224  			case wasm.ExternTypeGlobal:
   225  				index = uint32(len(m.GlobalSection))
   226  				var data []byte
   227  				var opcode byte
   228  				switch imp.DescGlobal.ValType {
   229  				case wasm.ValueTypeI32:
   230  					opcode = wasm.OpcodeI32Const
   231  					data = []byte{0}
   232  				case wasm.ValueTypeI64:
   233  					opcode = wasm.OpcodeI64Const
   234  					data = []byte{0}
   235  				case wasm.ValueTypeF32:
   236  					opcode = wasm.OpcodeF32Const
   237  					data = []byte{0, 0, 0, 0}
   238  				case wasm.ValueTypeF64:
   239  					opcode = wasm.OpcodeF64Const
   240  					data = []byte{0, 0, 0, 0, 0, 0, 0, 0}
   241  				case wasm.ValueTypeV128:
   242  					opcode = wasm.OpcodeVecPrefix
   243  					data = []byte{wasm.OpcodeVecV128Const, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
   244  				case wasm.ValueTypeExternref:
   245  					opcode = wasm.OpcodeRefNull
   246  					data = []byte{wasm.RefTypeExternref}
   247  				case wasm.ValueTypeFuncref:
   248  					opcode = wasm.OpcodeRefNull
   249  					data = []byte{wasm.RefTypeFuncref}
   250  				}
   251  				m.GlobalSection = append(m.GlobalSection, wasm.Global{
   252  					Type: imp.DescGlobal, Init: wasm.ConstantExpression{Opcode: opcode, Data: data},
   253  				})
   254  			case wasm.ExternTypeMemory:
   255  				m.MemorySection = imp.DescMem
   256  				index = 0
   257  			case wasm.ExternTypeTable:
   258  				index = uint32(len(m.TableSection))
   259  				m.TableSection = append(m.TableSection, imp.DescTable)
   260  			}
   261  			m.ExportSection = append(m.ExportSection, wasm.Export{Type: imp.Type, Name: imp.Name, Index: index})
   262  		}
   263  		_, err := r.Instantiate(context.Background(), binaryencoding.EncodeModule(m))
   264  		requireNoError(err)
   265  	}
   266  	return
   267  }
   268  
   269  const valueTypeVector = 0x7b
   270  
   271  // ensureInvocationResultMatch invokes all the exported functions from the module, and compare all the results between compiler vs interpreter.
   272  func ensureInvocationResultMatch(
   273  	compilerCtx, interpreterCtx context.Context, compiledMod, interpreterMod api.Module,
   274  	exportedFunctions map[string]api.FunctionDefinition,
   275  ) (err error, errorDuringInvocation bool) {
   276  	// In order to do the deterministic execution, we need to sort the exported functions.
   277  	var names []string
   278  	for f := range exportedFunctions {
   279  		names = append(names, f)
   280  	}
   281  	sort.Strings(names)
   282  
   283  outer:
   284  	for _, name := range names {
   285  		def := exportedFunctions[name]
   286  		resultTypes := def.ResultTypes()
   287  		for _, rt := range resultTypes {
   288  			switch rt {
   289  			case api.ValueTypeI32, api.ValueTypeI64, api.ValueTypeF32, api.ValueTypeF64, valueTypeVector:
   290  			default:
   291  				// For the sake of simplicity in the assertion, we only invoke the function with the basic types.
   292  				continue outer
   293  			}
   294  		}
   295  
   296  		cmpF := compiledMod.ExportedFunction(name)
   297  		intF := interpreterMod.ExportedFunction(name)
   298  
   299  		params := getDummyValues(def.ParamTypes())
   300  		cmpRes, cmpErr := cmpF.Call(compilerCtx, params...)
   301  		intRes, intErr := intF.Call(interpreterCtx, params...)
   302  		errorDuringInvocation = errorDuringInvocation || cmpErr != nil || intErr != nil
   303  		if errMismatch := ensureInvocationError(cmpErr, intErr); errMismatch != nil {
   304  			err = errors.Join(err, fmt.Errorf("error mismatch on invoking %s: %v", name, errMismatch))
   305  			continue
   306  		}
   307  
   308  		matched := true
   309  		var typesIndex int
   310  		for i := 0; i < len(cmpRes); i++ {
   311  			switch resultTypes[typesIndex] {
   312  			case api.ValueTypeI32, api.ValueTypeF32:
   313  				matched = matched && uint32(cmpRes[i]) == uint32(intRes[i])
   314  			case api.ValueTypeI64, api.ValueTypeF64:
   315  				matched = matched && cmpRes[i] == intRes[i]
   316  			case valueTypeVector:
   317  				matched = matched && cmpRes[i] == intRes[i] && cmpRes[i+1] == intRes[i+1]
   318  				i++ // We need to advance twice (lower and higher 64bits)
   319  			}
   320  			typesIndex++
   321  		}
   322  
   323  		if !matched {
   324  			err = errors.Join(err, fmt.Errorf("result mismatch on invoking '%s':\n\tinterpreter got: %v\n\tcompiler got: %v", name, intRes, cmpRes))
   325  		}
   326  	}
   327  	return
   328  }
   329  
   330  // getDummyValues returns a dummy input values for function invocations.
   331  func getDummyValues(valueTypes []api.ValueType) (ret []uint64) {
   332  	for _, vt := range valueTypes {
   333  		if vt != 0x7b { // v128
   334  			ret = append(ret, 0)
   335  		} else {
   336  			ret = append(ret, 0, 0)
   337  		}
   338  	}
   339  	return
   340  }
   341  
   342  // ensureInvocationError ensures that function invocation errors returned by interpreter and compiler match each other's.
   343  func ensureInvocationError(compilerErr, interpErr error) error {
   344  	if compilerErr == nil && interpErr == nil {
   345  		return nil
   346  	} else if compilerErr == nil && interpErr != nil {
   347  		return fmt.Errorf("compiler returned no error, but interpreter got: %w", interpErr)
   348  	} else if compilerErr != nil && interpErr == nil {
   349  		return fmt.Errorf("interpreter returned no error, but compiler got: %w", compilerErr)
   350  	}
   351  
   352  	compilerErrMsg, interpErrMsg := compilerErr.Error(), interpErr.Error()
   353  	if idx := strings.Index(compilerErrMsg, "\n"); idx >= 0 {
   354  		compilerErrMsg = compilerErrMsg[:strings.Index(compilerErrMsg, "\n")]
   355  	}
   356  	if idx := strings.Index(interpErrMsg, "\n"); idx >= 0 {
   357  		interpErrMsg = interpErrMsg[:strings.Index(interpErrMsg, "\n")]
   358  	}
   359  
   360  	if compiledStackOverFlow := strings.Contains(compilerErrMsg, "stack overflow"); compiledStackOverFlow && strings.Contains(interpErrMsg, "unreachable") {
   361  		// Compiler is more likely to reach stack overflow than interpreter, so we allow this case. This case is most likely
   362  		// that interpreter reached the unreachable out of "fuel".
   363  		return nil
   364  	} else if interpreterStackOverFlow := strings.Contains(interpErrMsg, "stack overflow"); compiledStackOverFlow && interpreterStackOverFlow {
   365  		// Both compiler and interpreter reached stack overflow, so we ignore diff in the content of the traces.
   366  		return nil
   367  	}
   368  
   369  	if compilerErrMsg != interpErrMsg {
   370  		return fmt.Errorf("error mismatch:\n\tinterpreter: %v\n\tcompiler: %v", interpErr, compilerErr)
   371  	}
   372  	return nil
   373  }
   374  
   375  // ensureInstantiationError ensures that instantiation errors returned by interpreter and compiler match each other's.
   376  func ensureInstantiationError(compilerErr, interpErr error) (okToInvoke bool, err error) {
   377  	if compilerErr == nil && interpErr == nil {
   378  		return true, nil
   379  	} else if compilerErr == nil && interpErr != nil {
   380  		return false, fmt.Errorf("compiler returned no error, but interpreter got: %w", interpErr)
   381  	} else if compilerErr != nil && interpErr == nil {
   382  		return false, fmt.Errorf("interpreter returned no error, but compiler got: %w", compilerErr)
   383  	}
   384  
   385  	compilerErrMsg, interpErrMsg := compilerErr.Error(), interpErr.Error()
   386  	if idx := strings.Index(compilerErrMsg, "\n"); idx >= 0 {
   387  		compilerErrMsg = compilerErrMsg[:strings.Index(compilerErrMsg, "\n")]
   388  	}
   389  	if idx := strings.Index(interpErrMsg, "\n"); idx >= 0 {
   390  		interpErrMsg = interpErrMsg[:strings.Index(interpErrMsg, "\n")]
   391  	}
   392  
   393  	if strings.Contains(compilerErrMsg, "stack overflow") && strings.Contains(interpErrMsg, "unreachable") {
   394  		// This is the case where the compiler reached stack overflow, but the interpreter reached the unreachable out of "fuel" during
   395  		// start function invocation. This is fine.
   396  		return false, nil
   397  	}
   398  
   399  	if !allowedErrorDuringInstantiation(compilerErrMsg) {
   400  		return false, fmt.Errorf("invalid error occur with compiler: %v vs interpreter: %v", compilerErr, interpErr)
   401  	} else if !allowedErrorDuringInstantiation(interpErrMsg) {
   402  		return false, fmt.Errorf("invalid error occur with interpreter: %v vs compiler: %v", interpErr, compilerErr)
   403  	}
   404  
   405  	if compilerErrMsg != interpErrMsg {
   406  		return false, fmt.Errorf("error mismatch:\n\tinterpreter: %v\n\tcompiler: %v", interpErr, compilerErr)
   407  	}
   408  	return false, nil
   409  }
   410  
   411  // allowedErrorDuringInstantiation checks if the error message is considered sane.
   412  func allowedErrorDuringInstantiation(errMsg string) bool {
   413  	// This happens when data segment causes out of bound, but it is considered as runtime-error in WebAssembly 2.0
   414  	// which is fine.
   415  	if strings.HasPrefix(errMsg, "data[") && strings.HasSuffix(errMsg, "]: out of bounds memory access") {
   416  		return true
   417  	}
   418  
   419  	// Start function failure is neither instantiation nor compilation error, but rather a runtime error, so that is fine.
   420  	if strings.HasPrefix(errMsg, "start function[") && strings.Contains(errMsg, "failed: wasm error:") {
   421  		return true
   422  	}
   423  	return false
   424  }