wa-lang.org/wazero@v1.0.2/internal/engine/compiler/impl_amd64_test.go (about)

     1  package compiler
     2  
     3  import (
     4  	"testing"
     5  	"unsafe"
     6  
     7  	"wa-lang.org/wazero/internal/asm"
     8  	"wa-lang.org/wazero/internal/asm/amd64"
     9  	"wa-lang.org/wazero/internal/testing/require"
    10  	"wa-lang.org/wazero/internal/wasm"
    11  	"wa-lang.org/wazero/internal/wazeroir"
    12  )
    13  
    14  // TestAmd64Compiler_indirectCallWithTargetOnCallingConvReg is the regression test for #526.
    15  // In short, the offset register for call_indirect might be the same as amd64CallingConventionDestinationFunctionModuleInstanceAddressRegister
    16  // and that must not be a failure.
    17  func TestAmd64Compiler_indirectCallWithTargetOnCallingConvReg(t *testing.T) {
    18  	env := newCompilerEnvironment()
    19  	table := make([]wasm.Reference, 1)
    20  	env.addTable(&wasm.TableInstance{References: table})
    21  	// Ensure that the module instance has the type information for targetOperation.TypeIndex,
    22  	// and the typeID  matches the table[targetOffset]'s type ID.
    23  	operation := &wazeroir.OperationCallIndirect{TypeIndex: 0}
    24  	env.module().TypeIDs = []wasm.FunctionTypeID{0}
    25  	env.module().Engine = &moduleEngine{functions: []*function{}}
    26  
    27  	me := env.moduleEngine()
    28  	{ // Compiling call target.
    29  		compiler := env.requireNewCompiler(t, newCompiler, nil)
    30  		err := compiler.compilePreamble()
    31  		require.NoError(t, err)
    32  		err = compiler.compileReturnFunction()
    33  		require.NoError(t, err)
    34  
    35  		c, _, err := compiler.compile()
    36  		require.NoError(t, err)
    37  
    38  		f := &function{
    39  			parent:                &code{codeSegment: c},
    40  			codeInitialAddress:    uintptr(unsafe.Pointer(&c[0])),
    41  			moduleInstanceAddress: uintptr(unsafe.Pointer(env.moduleInstance)),
    42  			source:                &wasm.FunctionInstance{TypeID: 0},
    43  		}
    44  		me.functions = append(me.functions, f)
    45  		table[0] = uintptr(unsafe.Pointer(f))
    46  	}
    47  
    48  	compiler := env.requireNewCompiler(t, newCompiler, &wazeroir.CompilationResult{
    49  		Signature: &wasm.FunctionType{},
    50  		Types:     []*wasm.FunctionType{{}},
    51  		HasTable:  true,
    52  	}).(*amd64Compiler)
    53  	err := compiler.compilePreamble()
    54  	require.NoError(t, err)
    55  
    56  	// Place the offset into the calling-convention reserved register.
    57  	offsetLoc := compiler.pushRuntimeValueLocationOnRegister(amd64CallingConventionDestinationFunctionModuleInstanceAddressRegister,
    58  		runtimeValueTypeI32)
    59  	compiler.assembler.CompileConstToRegister(amd64.MOVQ, 0, offsetLoc.register)
    60  
    61  	require.NoError(t, compiler.compileCallIndirect(operation))
    62  
    63  	err = compiler.compileReturnFunction()
    64  	require.NoError(t, err)
    65  
    66  	// Generate the code under test and run.
    67  	code, _, err := compiler.compile()
    68  	require.NoError(t, err)
    69  	env.exec(code)
    70  }
    71  
    72  func TestAmd64Compiler_compile_Mul_Div_Rem(t *testing.T) {
    73  	for _, kind := range []wazeroir.OperationKind{
    74  		wazeroir.OperationKindMul,
    75  		wazeroir.OperationKindDiv,
    76  		wazeroir.OperationKindRem,
    77  	} {
    78  		kind := kind
    79  		t.Run(kind.String(), func(t *testing.T) {
    80  			t.Run("int32", func(t *testing.T) {
    81  				tests := []struct {
    82  					name         string
    83  					x1Reg, x2Reg asm.Register
    84  				}{
    85  					{
    86  						name:  "x1:ax,x2:random_reg",
    87  						x1Reg: amd64.RegAX,
    88  						x2Reg: amd64.RegR10,
    89  					},
    90  					{
    91  						name:  "x1:ax,x2:stack",
    92  						x1Reg: amd64.RegAX,
    93  						x2Reg: asm.NilRegister,
    94  					},
    95  					{
    96  						name:  "x1:random_reg,x2:ax",
    97  						x1Reg: amd64.RegR10,
    98  						x2Reg: amd64.RegAX,
    99  					},
   100  					{
   101  						name:  "x1:stack,x2:ax",
   102  						x1Reg: asm.NilRegister,
   103  						x2Reg: amd64.RegAX,
   104  					},
   105  					{
   106  						name:  "x1:random_reg,x2:random_reg",
   107  						x1Reg: amd64.RegR10,
   108  						x2Reg: amd64.RegR9,
   109  					},
   110  					{
   111  						name:  "x1:stack,x2:random_reg",
   112  						x1Reg: asm.NilRegister,
   113  						x2Reg: amd64.RegR9,
   114  					},
   115  					{
   116  						name:  "x1:random_reg,x2:stack",
   117  						x1Reg: amd64.RegR9,
   118  						x2Reg: asm.NilRegister,
   119  					},
   120  					{
   121  						name:  "x1:stack,x2:stack",
   122  						x1Reg: asm.NilRegister,
   123  						x2Reg: asm.NilRegister,
   124  					},
   125  				}
   126  
   127  				for _, tt := range tests {
   128  					tc := tt
   129  					t.Run(tc.name, func(t *testing.T) {
   130  						env := newCompilerEnvironment()
   131  
   132  						const x1Value uint32 = 1 << 11
   133  						const x2Value uint32 = 51
   134  						const dxValue uint64 = 111111
   135  
   136  						compiler := env.requireNewCompiler(t, newAmd64Compiler, nil).(*amd64Compiler)
   137  						err := compiler.compilePreamble()
   138  						require.NoError(t, err)
   139  
   140  						// Pretend there was an existing value on the DX register. We expect compileMul to save this to the stack.
   141  						// Here, we put it just before two operands as ["any value used by DX", x1, x2]
   142  						// but in reality, it can exist in any position of stack.
   143  						compiler.assembler.CompileConstToRegister(amd64.MOVQ, int64(dxValue), amd64.RegDX)
   144  						prevOnDX := compiler.pushRuntimeValueLocationOnRegister(amd64.RegDX, runtimeValueTypeI32)
   145  
   146  						// Setup values.
   147  						if tc.x1Reg != asm.NilRegister {
   148  							compiler.assembler.CompileConstToRegister(amd64.MOVQ, int64(x1Value), tc.x1Reg)
   149  							compiler.pushRuntimeValueLocationOnRegister(tc.x1Reg, runtimeValueTypeI32)
   150  						} else {
   151  							loc := compiler.runtimeValueLocationStack().pushRuntimeValueLocationOnStack()
   152  							env.stack()[loc.stackPointer] = uint64(x1Value)
   153  						}
   154  						if tc.x2Reg != asm.NilRegister {
   155  							compiler.assembler.CompileConstToRegister(amd64.MOVQ, int64(x2Value), tc.x2Reg)
   156  							compiler.pushRuntimeValueLocationOnRegister(tc.x2Reg, runtimeValueTypeI32)
   157  						} else {
   158  							loc := compiler.runtimeValueLocationStack().pushRuntimeValueLocationOnStack()
   159  							env.stack()[loc.stackPointer] = uint64(x2Value)
   160  						}
   161  
   162  						switch kind {
   163  						case wazeroir.OperationKindDiv:
   164  							err = compiler.compileDiv(&wazeroir.OperationDiv{Type: wazeroir.SignedTypeUint32})
   165  						case wazeroir.OperationKindMul:
   166  							err = compiler.compileMul(&wazeroir.OperationMul{Type: wazeroir.UnsignedTypeI32})
   167  						case wazeroir.OperationKindRem:
   168  							err = compiler.compileRem(&wazeroir.OperationRem{Type: wazeroir.SignedUint32})
   169  						}
   170  						require.NoError(t, err)
   171  
   172  						require.Equal(t, registerTypeGeneralPurpose, compiler.runtimeValueLocationStack().peek().getRegisterType())
   173  						requireRuntimeLocationStackPointerEqual(t, uint64(2), compiler)
   174  						require.Equal(t, 1, len(compiler.runtimeValueLocationStack().usedRegisters))
   175  						// At this point, the previous value on the DX register is saved to the stack.
   176  						require.True(t, prevOnDX.onStack())
   177  
   178  						// We add the value previously on the DX with the multiplication result
   179  						// in order to ensure that not saving existing DX value would cause
   180  						// the failure in a subsequent instruction.
   181  						err = compiler.compileAdd(&wazeroir.OperationAdd{Type: wazeroir.UnsignedTypeI32})
   182  						require.NoError(t, err)
   183  
   184  						require.NoError(t, compiler.compileReturnFunction())
   185  
   186  						// Generate the code under test.
   187  						code, _, err := compiler.compile()
   188  						require.NoError(t, err)
   189  						// Run code.
   190  						env.exec(code)
   191  
   192  						// Verify the stack is in the form of ["any value previously used by DX" + the result of operation]
   193  						require.Equal(t, uint64(1), env.stackPointer())
   194  						switch kind {
   195  						case wazeroir.OperationKindDiv:
   196  							require.Equal(t, x1Value/x2Value+uint32(dxValue), env.stackTopAsUint32())
   197  						case wazeroir.OperationKindMul:
   198  							require.Equal(t, x1Value*x2Value+uint32(dxValue), env.stackTopAsUint32())
   199  						case wazeroir.OperationKindRem:
   200  							require.Equal(t, x1Value%x2Value+uint32(dxValue), env.stackTopAsUint32())
   201  						}
   202  					})
   203  				}
   204  			})
   205  			t.Run("int64", func(t *testing.T) {
   206  				tests := []struct {
   207  					name         string
   208  					x1Reg, x2Reg asm.Register
   209  				}{
   210  					{
   211  						name:  "x1:ax,x2:random_reg",
   212  						x1Reg: amd64.RegAX,
   213  						x2Reg: amd64.RegR10,
   214  					},
   215  					{
   216  						name:  "x1:ax,x2:stack",
   217  						x1Reg: amd64.RegAX,
   218  						x2Reg: asm.NilRegister,
   219  					},
   220  					{
   221  						name:  "x1:random_reg,x2:ax",
   222  						x1Reg: amd64.RegR10,
   223  						x2Reg: amd64.RegAX,
   224  					},
   225  					{
   226  						name:  "x1:stack,x2:ax",
   227  						x1Reg: asm.NilRegister,
   228  						x2Reg: amd64.RegAX,
   229  					},
   230  					{
   231  						name:  "x1:random_reg,x2:random_reg",
   232  						x1Reg: amd64.RegR10,
   233  						x2Reg: amd64.RegR9,
   234  					},
   235  					{
   236  						name:  "x1:stack,x2:random_reg",
   237  						x1Reg: asm.NilRegister,
   238  						x2Reg: amd64.RegR9,
   239  					},
   240  					{
   241  						name:  "x1:random_reg,x2:stack",
   242  						x1Reg: amd64.RegR9,
   243  						x2Reg: asm.NilRegister,
   244  					},
   245  					{
   246  						name:  "x1:stack,x2:stack",
   247  						x1Reg: asm.NilRegister,
   248  						x2Reg: asm.NilRegister,
   249  					},
   250  				}
   251  
   252  				for _, tt := range tests {
   253  					tc := tt
   254  					t.Run(tc.name, func(t *testing.T) {
   255  						const x1Value uint64 = 1 << 35
   256  						const x2Value uint64 = 51
   257  						const dxValue uint64 = 111111
   258  
   259  						env := newCompilerEnvironment()
   260  						compiler := env.requireNewCompiler(t, newAmd64Compiler, nil).(*amd64Compiler)
   261  						err := compiler.compilePreamble()
   262  						require.NoError(t, err)
   263  
   264  						// Pretend there was an existing value on the DX register. We expect compileMul to save this to the stack.
   265  						// Here, we put it just before two operands as ["any value used by DX", x1, x2]
   266  						// but in reality, it can exist in any position of stack.
   267  						compiler.assembler.CompileConstToRegister(amd64.MOVQ, int64(dxValue), amd64.RegDX)
   268  						prevOnDX := compiler.pushRuntimeValueLocationOnRegister(amd64.RegDX, runtimeValueTypeI64)
   269  
   270  						// Setup values.
   271  						if tc.x1Reg != asm.NilRegister {
   272  							compiler.assembler.CompileConstToRegister(amd64.MOVQ, int64(x1Value), tc.x1Reg)
   273  							compiler.pushRuntimeValueLocationOnRegister(tc.x1Reg, runtimeValueTypeI64)
   274  						} else {
   275  							loc := compiler.runtimeValueLocationStack().pushRuntimeValueLocationOnStack()
   276  							loc.valueType = runtimeValueTypeI64
   277  							env.stack()[loc.stackPointer] = uint64(x1Value)
   278  						}
   279  						if tc.x2Reg != asm.NilRegister {
   280  							compiler.assembler.CompileConstToRegister(amd64.MOVQ, int64(x2Value), tc.x2Reg)
   281  							compiler.pushRuntimeValueLocationOnRegister(tc.x2Reg, runtimeValueTypeI64)
   282  						} else {
   283  							loc := compiler.runtimeValueLocationStack().pushRuntimeValueLocationOnStack()
   284  							loc.valueType = runtimeValueTypeI64
   285  							env.stack()[loc.stackPointer] = uint64(x2Value)
   286  						}
   287  
   288  						switch kind {
   289  						case wazeroir.OperationKindDiv:
   290  							err = compiler.compileDiv(&wazeroir.OperationDiv{Type: wazeroir.SignedTypeInt64})
   291  						case wazeroir.OperationKindMul:
   292  							err = compiler.compileMul(&wazeroir.OperationMul{Type: wazeroir.UnsignedTypeI64})
   293  						case wazeroir.OperationKindRem:
   294  							err = compiler.compileRem(&wazeroir.OperationRem{Type: wazeroir.SignedUint64})
   295  						}
   296  						require.NoError(t, err)
   297  
   298  						require.Equal(t, registerTypeGeneralPurpose, compiler.runtimeValueLocationStack().peek().getRegisterType())
   299  						requireRuntimeLocationStackPointerEqual(t, uint64(2), compiler)
   300  						require.Equal(t, 1, len(compiler.runtimeValueLocationStack().usedRegisters))
   301  						// At this point, the previous value on the DX register is saved to the stack.
   302  						require.True(t, prevOnDX.onStack())
   303  
   304  						// We add the value previously on the DX with the multiplication result
   305  						// in order to ensure that not saving existing DX value would cause
   306  						// the failure in a subsequent instruction.
   307  						err = compiler.compileAdd(&wazeroir.OperationAdd{Type: wazeroir.UnsignedTypeI64})
   308  						require.NoError(t, err)
   309  
   310  						require.NoError(t, compiler.compileReturnFunction())
   311  
   312  						// Generate the code under test.
   313  						code, _, err := compiler.compile()
   314  						require.NoError(t, err)
   315  
   316  						// Run code.
   317  						env.exec(code)
   318  
   319  						// Verify the stack is in the form of ["any value previously used by DX" + the result of operation]
   320  						switch kind {
   321  						case wazeroir.OperationKindDiv:
   322  							require.Equal(t, uint64(1), env.stackPointer())
   323  							require.Equal(t, uint64(x1Value/x2Value)+dxValue, env.stackTopAsUint64())
   324  						case wazeroir.OperationKindMul:
   325  							require.Equal(t, uint64(1), env.stackPointer())
   326  							require.Equal(t, uint64(x1Value*x2Value)+dxValue, env.stackTopAsUint64())
   327  						case wazeroir.OperationKindRem:
   328  							require.Equal(t, uint64(1), env.stackPointer())
   329  							require.Equal(t, x1Value%x2Value+dxValue, env.stackTopAsUint64())
   330  						}
   331  					})
   332  				}
   333  			})
   334  		})
   335  	}
   336  }
   337  
   338  func TestAmd64Compiler_readInstructionAddress(t *testing.T) {
   339  	t.Run("invalid", func(t *testing.T) {
   340  		env := newCompilerEnvironment()
   341  		compiler := env.requireNewCompiler(t, newAmd64Compiler, nil).(*amd64Compiler)
   342  
   343  		err := compiler.compilePreamble()
   344  		require.NoError(t, err)
   345  
   346  		// Set the acquisition target instruction to the one after JMP.
   347  		compiler.assembler.CompileReadInstructionAddress(amd64.RegAX, amd64.JMP)
   348  
   349  		// If generate the code without JMP after readInstructionAddress,
   350  		// the call back added must return error.
   351  		_, _, err = compiler.compile()
   352  		require.Error(t, err)
   353  	})
   354  
   355  	t.Run("ok", func(t *testing.T) {
   356  		env := newCompilerEnvironment()
   357  		compiler := env.requireNewCompiler(t, newAmd64Compiler, nil).(*amd64Compiler)
   358  
   359  		err := compiler.compilePreamble()
   360  		require.NoError(t, err)
   361  
   362  		const destinationRegister = amd64.RegAX
   363  		// Set the acquisition target instruction to the one after RET,
   364  		// and read the absolute address into destinationRegister.
   365  		compiler.assembler.CompileReadInstructionAddress(destinationRegister, amd64.RET)
   366  
   367  		// Jump to the instruction after RET below via the absolute
   368  		// address stored in destinationRegister.
   369  		compiler.assembler.CompileJumpToRegister(amd64.JMP, destinationRegister)
   370  
   371  		compiler.assembler.CompileStandAlone(amd64.RET)
   372  
   373  		// This could be the read instruction target as this is the
   374  		// right after RET. Therefore, the jmp instruction above
   375  		// must target here.
   376  		const expectedReturnValue uint32 = 10000
   377  		err = compiler.compileConstI32(&wazeroir.OperationConstI32{Value: expectedReturnValue})
   378  		require.NoError(t, err)
   379  
   380  		err = compiler.compileReturnFunction()
   381  		require.NoError(t, err)
   382  
   383  		// Generate the code under test.
   384  		code, _, err := compiler.compile()
   385  		require.NoError(t, err)
   386  
   387  		// Run code.
   388  		env.exec(code)
   389  
   390  		require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus())
   391  		require.Equal(t, uint64(1), env.stackPointer())
   392  		require.Equal(t, expectedReturnValue, env.stackTopAsUint32())
   393  	})
   394  }
   395  
   396  func TestAmd64Compiler_preventCrossedTargetdRegisters(t *testing.T) {
   397  	env := newCompilerEnvironment()
   398  	compiler := env.requireNewCompiler(t, newAmd64Compiler, nil).(*amd64Compiler)
   399  
   400  	tests := []struct {
   401  		initial           []*runtimeValueLocation
   402  		desired, expected []asm.Register
   403  	}{
   404  		{
   405  			initial:  []*runtimeValueLocation{{register: amd64.RegAX}, {register: amd64.RegCX}, {register: amd64.RegDX}},
   406  			desired:  []asm.Register{amd64.RegDX, amd64.RegCX, amd64.RegAX},
   407  			expected: []asm.Register{amd64.RegDX, amd64.RegCX, amd64.RegAX},
   408  		},
   409  		{
   410  			initial:  []*runtimeValueLocation{{register: amd64.RegAX}, {register: amd64.RegCX}, {register: amd64.RegDX}},
   411  			desired:  []asm.Register{amd64.RegDX, amd64.RegAX, amd64.RegCX},
   412  			expected: []asm.Register{amd64.RegDX, amd64.RegAX, amd64.RegCX},
   413  		},
   414  		{
   415  			initial:  []*runtimeValueLocation{{register: amd64.RegR8}, {register: amd64.RegR9}, {register: amd64.RegR10}},
   416  			desired:  []asm.Register{amd64.RegR8, amd64.RegR9, amd64.RegR10},
   417  			expected: []asm.Register{amd64.RegR8, amd64.RegR9, amd64.RegR10},
   418  		},
   419  		{
   420  			initial:  []*runtimeValueLocation{{register: amd64.RegBX}, {register: amd64.RegDX}, {register: amd64.RegCX}},
   421  			desired:  []asm.Register{amd64.RegR8, amd64.RegR9, amd64.RegR10},
   422  			expected: []asm.Register{amd64.RegBX, amd64.RegDX, amd64.RegCX},
   423  		},
   424  		{
   425  			initial:  []*runtimeValueLocation{{register: amd64.RegR8}, {register: amd64.RegR9}, {register: amd64.RegR10}},
   426  			desired:  []asm.Register{amd64.RegAX, amd64.RegCX, amd64.RegR9},
   427  			expected: []asm.Register{amd64.RegR8, amd64.RegR10, amd64.RegR9},
   428  		},
   429  	}
   430  
   431  	for _, tt := range tests {
   432  		initialRegisters := collectRegistersFromRuntimeValues(tt.initial)
   433  		restoreCrossing := compiler.compilePreventCrossedTargetRegisters(tt.initial, tt.desired)
   434  		// Required expected state after prevented crossing.
   435  		require.Equal(t, tt.expected, collectRegistersFromRuntimeValues(tt.initial))
   436  		restoreCrossing()
   437  		// Require initial state after restoring.
   438  		require.Equal(t, initialRegisters, collectRegistersFromRuntimeValues(tt.initial))
   439  	}
   440  }
   441  
   442  // collectRegistersFromRuntimeValues returns the registers occupied by locs.
   443  func collectRegistersFromRuntimeValues(locs []*runtimeValueLocation) []asm.Register {
   444  	out := make([]asm.Register, len(locs))
   445  	for i := range locs {
   446  		out[i] = locs[i].register
   447  	}
   448  	return out
   449  }
   450  
   451  // compile implements compilerImpl.getOnStackPointerCeilDeterminedCallBack for the amd64 architecture.
   452  func (c *amd64Compiler) getOnStackPointerCeilDeterminedCallBack() func(uint64) {
   453  	return c.onStackPointerCeilDeterminedCallBack
   454  }
   455  
   456  // compile implements compilerImpl.setStackPointerCeil for the amd64 architecture.
   457  func (c *amd64Compiler) setStackPointerCeil(v uint64) {
   458  	c.stackPointerCeil = v
   459  }
   460  
   461  // compile implements compilerImpl.setRuntimeValueLocationStack for the amd64 architecture.
   462  func (c *amd64Compiler) setRuntimeValueLocationStack(s *runtimeValueLocationStack) {
   463  	c.locationStack = s
   464  }