github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/compiler/compiler_drop_test.go (about)

     1  package compiler
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/bananabytelabs/wazero/internal/asm"
     8  	"github.com/bananabytelabs/wazero/internal/testing/require"
     9  	"github.com/bananabytelabs/wazero/internal/wasm"
    10  	"github.com/bananabytelabs/wazero/internal/wazeroir"
    11  )
    12  
    13  func Test_compileDropRange(t *testing.T) {
    14  	t.Run("nop range", func(t *testing.T) {
    15  		c := newCompiler()
    16  
    17  		err := compileDropRange(c, wazeroir.NopInclusiveRange.AsU64())
    18  		require.NoError(t, err)
    19  	})
    20  
    21  	t.Run("start at the top", func(t *testing.T) {
    22  		c := newCompiler()
    23  		c.Init(&wasm.FunctionType{}, nil, false)
    24  
    25  		// Use up all unreserved registers.
    26  		for _, reg := range unreservedGeneralPurposeRegisters {
    27  			c.pushRuntimeValueLocationOnRegister(reg, runtimeValueTypeI32)
    28  		}
    29  		for i, vreg := range unreservedVectorRegisters {
    30  			// Mix and match scalar float and vector values.
    31  			if i%2 == 0 {
    32  				c.pushVectorRuntimeValueLocationOnRegister(vreg)
    33  			} else {
    34  				c.pushRuntimeValueLocationOnRegister(vreg, runtimeValueTypeF32)
    35  			}
    36  		}
    37  
    38  		unreservedRegisterTotal := len(unreservedGeneralPurposeRegisters) + len(unreservedVectorRegisters)
    39  		ls := c.runtimeValueLocationStack()
    40  		require.Equal(t, unreservedRegisterTotal, len(ls.usedRegisters.list()))
    41  
    42  		// Drop all the values.
    43  		err := compileDropRange(c, wazeroir.InclusiveRange{Start: 0, End: int32(ls.sp - 1)}.AsU64())
    44  		require.NoError(t, err)
    45  
    46  		// All the registers must be marked unused.
    47  		require.Equal(t, 0, len(ls.usedRegisters.list()))
    48  		// Also, stack pointer must be zero.
    49  		require.Equal(t, 0, int(ls.sp))
    50  	})
    51  }
    52  
    53  func TestRuntimeValueLocationStack_dropsLivesForInclusiveRange(t *testing.T) {
    54  	tests := []struct {
    55  		v            *runtimeValueLocationStack
    56  		ir           wazeroir.InclusiveRange
    57  		lives, drops []runtimeValueLocation
    58  	}{
    59  		{
    60  			v: &runtimeValueLocationStack{
    61  				stack: []runtimeValueLocation{{register: 0}, {register: 1} /* drop target */, {register: 2}},
    62  				sp:    3,
    63  			},
    64  			ir:    wazeroir.InclusiveRange{Start: 1, End: 1},
    65  			drops: []runtimeValueLocation{{register: 1}},
    66  			lives: []runtimeValueLocation{{register: 2}},
    67  		},
    68  		{
    69  			v: &runtimeValueLocationStack{
    70  				stack: []runtimeValueLocation{
    71  					{register: 0},
    72  					{register: 1},
    73  					{register: 2}, // drop target
    74  					{register: 3}, // drop target
    75  					{register: 4}, // drop target
    76  					{register: 5},
    77  					{register: 6},
    78  				},
    79  				sp: 7,
    80  			},
    81  			ir:    wazeroir.InclusiveRange{Start: 2, End: 4},
    82  			drops: []runtimeValueLocation{{register: 2}, {register: 3}, {register: 4}},
    83  			lives: []runtimeValueLocation{{register: 5}, {register: 6}},
    84  		},
    85  	}
    86  
    87  	for _, tc := range tests {
    88  		actualDrops, actualLives := tc.v.dropsLivesForInclusiveRange(tc.ir)
    89  		require.Equal(t, tc.drops, actualDrops)
    90  		require.Equal(t, tc.lives, actualLives)
    91  	}
    92  }
    93  
    94  func Test_getTemporariesForStackedLiveValues(t *testing.T) {
    95  	t.Run("no stacked values", func(t *testing.T) {
    96  		liveValues := []runtimeValueLocation{{register: 1}, {register: 2}}
    97  		c := newCompiler()
    98  		c.Init(&wasm.FunctionType{}, nil, false)
    99  
   100  		gpTmp, vecTmp, err := getTemporariesForStackedLiveValues(c, liveValues)
   101  		require.NoError(t, err)
   102  
   103  		require.Equal(t, asm.NilRegister, gpTmp)
   104  		require.Equal(t, asm.NilRegister, vecTmp)
   105  	})
   106  	t.Run("general purpose needed", func(t *testing.T) {
   107  		for _, freeRegisterExists := range []bool{false, true} {
   108  			freeRegisterExists := freeRegisterExists
   109  			t.Run(fmt.Sprintf("free register exists=%v", freeRegisterExists), func(t *testing.T) {
   110  				liveValues := []runtimeValueLocation{
   111  					// Even multiple integer values are alive and on stack,
   112  					// only one general purpose register should be chosen.
   113  					{valueType: runtimeValueTypeI32},
   114  					{valueType: runtimeValueTypeI64},
   115  				}
   116  				c := newCompiler()
   117  				c.Init(&wasm.FunctionType{}, nil, false)
   118  
   119  				if !freeRegisterExists {
   120  					// Use up all the unreserved gp registers.
   121  					for _, reg := range unreservedGeneralPurposeRegisters {
   122  						c.pushRuntimeValueLocationOnRegister(reg, runtimeValueTypeI32)
   123  					}
   124  					// Ensures actually we used them up all.
   125  					require.Equal(t, len(c.runtimeValueLocationStack().usedRegisters.list()),
   126  						len(unreservedGeneralPurposeRegisters))
   127  				}
   128  
   129  				gpTmp, vecTmp, err := getTemporariesForStackedLiveValues(c, liveValues)
   130  				require.NoError(t, err)
   131  
   132  				if !freeRegisterExists {
   133  					// At this point, one register should be marked as unused.
   134  					require.Equal(t, len(c.runtimeValueLocationStack().usedRegisters.list()),
   135  						len(unreservedGeneralPurposeRegisters)-1)
   136  				}
   137  
   138  				require.NotEqual(t, asm.NilRegister, gpTmp)
   139  				require.Equal(t, asm.NilRegister, vecTmp)
   140  			})
   141  		}
   142  	})
   143  
   144  	t.Run("vector needed", func(t *testing.T) {
   145  		for _, freeRegisterExists := range []bool{false, true} {
   146  			freeRegisterExists := freeRegisterExists
   147  			t.Run(fmt.Sprintf("free register exists=%v", freeRegisterExists), func(t *testing.T) {
   148  				liveValues := []runtimeValueLocation{
   149  					// Even multiple vectors are alive and on stack,
   150  					// only one vector register should be chosen.
   151  					{valueType: runtimeValueTypeF32},
   152  					{valueType: runtimeValueTypeV128Lo},
   153  					{valueType: runtimeValueTypeV128Hi},
   154  					{valueType: runtimeValueTypeV128Lo},
   155  					{valueType: runtimeValueTypeV128Hi},
   156  				}
   157  				c := newCompiler()
   158  				c.Init(&wasm.FunctionType{}, nil, false)
   159  
   160  				if !freeRegisterExists {
   161  					// Use up all the unreserved gp registers.
   162  					for _, reg := range unreservedVectorRegisters {
   163  						c.pushVectorRuntimeValueLocationOnRegister(reg)
   164  					}
   165  					// Ensures actually we used them up all.
   166  					require.Equal(t, len(c.runtimeValueLocationStack().usedRegisters.list()),
   167  						len(unreservedVectorRegisters))
   168  				}
   169  
   170  				gpTmp, vecTmp, err := getTemporariesForStackedLiveValues(c, liveValues)
   171  				require.NoError(t, err)
   172  
   173  				if !freeRegisterExists {
   174  					// At this point, one register should be marked as unused.
   175  					require.Equal(t, len(c.runtimeValueLocationStack().usedRegisters.list()),
   176  						len(unreservedVectorRegisters)-1)
   177  				}
   178  
   179  				require.Equal(t, asm.NilRegister, gpTmp)
   180  				require.NotEqual(t, asm.NilRegister, vecTmp)
   181  			})
   182  		}
   183  	})
   184  }
   185  
   186  func Test_migrateLiveValue(t *testing.T) {
   187  	t.Run("v128.hi", func(t *testing.T) {
   188  		migrateLiveValue(nil, &runtimeValueLocation{valueType: runtimeValueTypeV128Hi}, asm.NilRegister, asm.NilRegister)
   189  	})
   190  	t.Run("already on register", func(t *testing.T) {
   191  		// This case, we don't use tmp registers.
   192  		c := newCompiler()
   193  		c.Init(&wasm.FunctionType{}, nil, false)
   194  
   195  		// Push the dummy values.
   196  		for i := 0; i < 10; i++ {
   197  			_ = c.runtimeValueLocationStack().pushRuntimeValueLocationOnStack()
   198  		}
   199  
   200  		gpReg := unreservedGeneralPurposeRegisters[0]
   201  		vReg := unreservedVectorRegisters[0]
   202  		c.pushRuntimeValueLocationOnRegister(gpReg, runtimeValueTypeI64)
   203  		c.pushVectorRuntimeValueLocationOnRegister(vReg)
   204  
   205  		// Emulate the compileDrop
   206  		ls := c.runtimeValueLocationStack()
   207  		vLive, gpLive := ls.popV128(), ls.pop()
   208  		const dropNum = 5
   209  		ls.sp -= dropNum
   210  
   211  		// Migrate these two values.
   212  		migrateLiveValue(c, gpLive, asm.NilRegister, asm.NilRegister)
   213  		migrateLiveValue(c, vLive, asm.NilRegister, asm.NilRegister)
   214  
   215  		// Check the new stack location.
   216  		vectorMigrated, gpMigrated := ls.popV128(), ls.pop()
   217  		require.Equal(t, uint64(5), gpMigrated.stackPointer)
   218  		require.Equal(t, uint64(6), vectorMigrated.stackPointer)
   219  
   220  		require.Equal(t, gpLive.register, gpMigrated.register)
   221  		require.Equal(t, vLive.register, vectorMigrated.register)
   222  	})
   223  }